mirror of
https://gitee.com/dromara/mayfly-go
synced 2026-05-20 01:45:48 +08:00
413 lines
11 KiB
Go
413 lines
11 KiB
Go
|
|
package pgsql
|
|||
|
|
|
|||
|
|
import (
|
|||
|
|
"strings"
|
|||
|
|
"testing"
|
|||
|
|
|
|||
|
|
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
// ========== SELECT 测试 ==========
|
|||
|
|
|
|||
|
|
func TestPgSelectBasic(t *testing.T) {
|
|||
|
|
sql := "SELECT id, name FROM users WHERE status = 1 ORDER BY id DESC LIMIT 10"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
|||
|
|
|
|||
|
|
// 验证完整文本
|
|||
|
|
if selectStmt.GetText() != sql {
|
|||
|
|
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 SELECT 字段
|
|||
|
|
if len(selectStmt.Items) != 2 {
|
|||
|
|
t.Fatalf("expected 2 items, got %d", len(selectStmt.Items))
|
|||
|
|
}
|
|||
|
|
if selectStmt.Items[0].Text != "id" {
|
|||
|
|
t.Errorf("expected Items[0]='id', got '%s'", selectStmt.Items[0].Text)
|
|||
|
|
}
|
|||
|
|
if selectStmt.Items[1].Text != "name" {
|
|||
|
|
t.Errorf("expected Items[1]='name', got '%s'", selectStmt.Items[1].Text)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 FROM 表名
|
|||
|
|
if len(selectStmt.From) != 1 {
|
|||
|
|
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
|
|||
|
|
}
|
|||
|
|
if selectStmt.From[0].Name != "users" {
|
|||
|
|
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 WHERE
|
|||
|
|
if selectStmt.Where == nil {
|
|||
|
|
t.Fatal("expected WHERE")
|
|||
|
|
}
|
|||
|
|
if selectStmt.Where.Text != "status = 1" {
|
|||
|
|
t.Errorf("expected WHERE='status = 1', got '%s'", selectStmt.Where.Text)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 ORDER BY
|
|||
|
|
if len(selectStmt.OrderBy) != 1 {
|
|||
|
|
t.Fatalf("expected 1 OrderBy, got %d", len(selectStmt.OrderBy))
|
|||
|
|
}
|
|||
|
|
if selectStmt.OrderBy[0].Text != "id" {
|
|||
|
|
t.Errorf("expected OrderBy[0].Text='id', got '%s'", selectStmt.OrderBy[0].Text)
|
|||
|
|
}
|
|||
|
|
if !selectStmt.OrderBy[0].Desc {
|
|||
|
|
t.Error("expected OrderBy[0].Desc=true")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 LIMIT
|
|||
|
|
if selectStmt.Limit == nil {
|
|||
|
|
t.Fatal("expected LIMIT")
|
|||
|
|
}
|
|||
|
|
if selectStmt.Limit.Text != "LIMIT 10" {
|
|||
|
|
t.Errorf("expected LIMIT.Text='LIMIT 10', got '%s'", selectStmt.Limit.Text)
|
|||
|
|
}
|
|||
|
|
if selectStmt.Limit.Count != 10 {
|
|||
|
|
t.Errorf("expected LIMIT.Count=10, got %d", selectStmt.Limit.Count)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestPgSelectDistinct(t *testing.T) {
|
|||
|
|
sql := "SELECT DISTINCT id, name FROM users"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
|||
|
|
if !selectStmt.Distinct {
|
|||
|
|
t.Error("expected Distinct=true")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestPgSelectStar(t *testing.T) {
|
|||
|
|
sql := "SELECT * FROM users"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
|||
|
|
|
|||
|
|
// 验证完整文本
|
|||
|
|
if selectStmt.GetText() != sql {
|
|||
|
|
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 SELECT *
|
|||
|
|
if len(selectStmt.Items) != 1 {
|
|||
|
|
t.Fatalf("expected 1 item, got %d", len(selectStmt.Items))
|
|||
|
|
}
|
|||
|
|
if selectStmt.Items[0].Text != "*" {
|
|||
|
|
t.Errorf("expected Items[0]='*', got '%s'", selectStmt.Items[0].Text)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 FROM
|
|||
|
|
if len(selectStmt.From) != 1 {
|
|||
|
|
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
|
|||
|
|
}
|
|||
|
|
if selectStmt.From[0].Name != "users" {
|
|||
|
|
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestPgSelectWithJoin(t *testing.T) {
|
|||
|
|
sql := "SELECT u.id, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
|||
|
|
|
|||
|
|
// 验证完整文本
|
|||
|
|
if selectStmt.GetText() != sql {
|
|||
|
|
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证主表
|
|||
|
|
if len(selectStmt.From) != 1 {
|
|||
|
|
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
|
|||
|
|
}
|
|||
|
|
if selectStmt.From[0].Name != "users" {
|
|||
|
|
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
|
|||
|
|
}
|
|||
|
|
if selectStmt.From[0].Alias != "u" {
|
|||
|
|
t.Errorf("expected From[0].Alias='u', got '%s'", selectStmt.From[0].Alias)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 JOIN
|
|||
|
|
if len(selectStmt.Joins) != 1 {
|
|||
|
|
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
|
|||
|
|
}
|
|||
|
|
if selectStmt.Joins[0].Table.Name != "orders" {
|
|||
|
|
t.Errorf("expected Join[0].Table.Name='orders', got '%s'", selectStmt.Joins[0].Table.Name)
|
|||
|
|
}
|
|||
|
|
if selectStmt.Joins[0].Table.Alias != "o" {
|
|||
|
|
t.Errorf("expected Join[0].Table.Alias='o', got '%s'", selectStmt.Joins[0].Table.Alias)
|
|||
|
|
}
|
|||
|
|
if selectStmt.Joins[0].Kind != sqlstmt.JoinKindLeft {
|
|||
|
|
t.Errorf("expected Join[0].Kind=JoinKindLeft, got '%d'", selectStmt.Joins[0].Kind)
|
|||
|
|
}
|
|||
|
|
if selectStmt.Joins[0].On == nil {
|
|||
|
|
t.Fatal("expected Join[0].On")
|
|||
|
|
}
|
|||
|
|
if selectStmt.Joins[0].On.Text != "u.id = o.user_id" {
|
|||
|
|
t.Errorf("expected Join[0].On.Text='u.id = o.user_id', got '%s'", selectStmt.Joins[0].On.Text)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 WHERE
|
|||
|
|
if selectStmt.Where == nil {
|
|||
|
|
t.Fatal("expected WHERE")
|
|||
|
|
}
|
|||
|
|
if selectStmt.Where.Text != "u.status = 1" {
|
|||
|
|
t.Errorf("expected WHERE='u.status = 1', got '%s'", selectStmt.Where.Text)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestPgSelectComplexWhere(t *testing.T) {
|
|||
|
|
sql := "SELECT * FROM users WHERE status = 1 AND (age > 18 OR role = 'admin') ORDER BY name"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
|||
|
|
|
|||
|
|
// 验证完整文本
|
|||
|
|
if selectStmt.GetText() != sql {
|
|||
|
|
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 FROM
|
|||
|
|
if len(selectStmt.From) != 1 {
|
|||
|
|
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
|
|||
|
|
}
|
|||
|
|
if selectStmt.From[0].Name != "users" {
|
|||
|
|
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 WHERE
|
|||
|
|
if selectStmt.Where == nil {
|
|||
|
|
t.Fatal("expected WHERE")
|
|||
|
|
}
|
|||
|
|
expectedWhere := "status = 1 AND (age > 18 OR role = 'admin')"
|
|||
|
|
if selectStmt.Where.Text != expectedWhere {
|
|||
|
|
t.Errorf("expected WHERE='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 ORDER BY
|
|||
|
|
if len(selectStmt.OrderBy) != 1 {
|
|||
|
|
t.Fatalf("expected 1 OrderBy, got %d", len(selectStmt.OrderBy))
|
|||
|
|
}
|
|||
|
|
if selectStmt.OrderBy[0].Text != "name" {
|
|||
|
|
t.Errorf("expected OrderBy[0].Text='name', got '%s'", selectStmt.OrderBy[0].Text)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestPgOffsetLimit(t *testing.T) {
|
|||
|
|
sql := "SELECT * FROM users OFFSET 10 LIMIT 20"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
|||
|
|
|
|||
|
|
// 验证完整文本
|
|||
|
|
if selectStmt.GetText() != sql {
|
|||
|
|
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 LIMIT/OFFSET
|
|||
|
|
if selectStmt.Limit == nil {
|
|||
|
|
t.Fatal("expected LIMIT")
|
|||
|
|
}
|
|||
|
|
if selectStmt.Limit.Text != "OFFSET 10 LIMIT 20" {
|
|||
|
|
t.Errorf("expected LIMIT.Text='OFFSET 10 LIMIT 20', got '%s'", selectStmt.Limit.Text)
|
|||
|
|
}
|
|||
|
|
if selectStmt.Limit.Count != 20 {
|
|||
|
|
t.Errorf("expected LIMIT.Count=20, got %d", selectStmt.Limit.Count)
|
|||
|
|
}
|
|||
|
|
if selectStmt.Limit.Offset != 10 {
|
|||
|
|
t.Errorf("expected LIMIT.Offset=10, got %d", selectStmt.Limit.Offset)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestPgLimitAll(t *testing.T) {
|
|||
|
|
sql := "SELECT * FROM users LIMIT ALL"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
|||
|
|
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT ALL" {
|
|||
|
|
t.Errorf("expected LIMIT text='LIMIT ALL'")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestPgUnion(t *testing.T) {
|
|||
|
|
sql := "SELECT 1 UNION SELECT 2 UNION ALL SELECT 3"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
|||
|
|
|
|||
|
|
// 验证完整文本
|
|||
|
|
if selectStmt.GetText() != sql {
|
|||
|
|
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 UNION 数量
|
|||
|
|
if len(selectStmt.Unions) != 2 {
|
|||
|
|
t.Fatalf("expected 2 unions, got %d", len(selectStmt.Unions))
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证第一个 UNION(DISTINCT)
|
|||
|
|
if selectStmt.Unions[0].All {
|
|||
|
|
t.Error("expected Unions[0].All=false (DISTINCT)")
|
|||
|
|
}
|
|||
|
|
if selectStmt.Unions[0].Select == nil {
|
|||
|
|
t.Fatal("expected Unions[0].Select")
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证第二个 UNION(ALL)
|
|||
|
|
if !selectStmt.Unions[1].All {
|
|||
|
|
t.Error("expected Unions[1].All=true")
|
|||
|
|
}
|
|||
|
|
if selectStmt.Unions[1].Select == nil {
|
|||
|
|
t.Fatal("expected Unions[1].Select")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ========== FOR UPDATE 测试 ==========
|
|||
|
|
|
|||
|
|
func TestPgForUpdate(t *testing.T) {
|
|||
|
|
sql := "SELECT id FROM users WHERE id = 1 FOR UPDATE"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
|||
|
|
|
|||
|
|
// 验证完整文本
|
|||
|
|
if selectStmt.GetText() != sql {
|
|||
|
|
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 FROM
|
|||
|
|
if len(selectStmt.From) != 1 {
|
|||
|
|
t.Fatalf("expected 1 from table, got %d", len(selectStmt.From))
|
|||
|
|
}
|
|||
|
|
if selectStmt.From[0].Name != "users" {
|
|||
|
|
t.Errorf("expected From[0].Name='users', got '%s'", selectStmt.From[0].Name)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 验证 WHERE
|
|||
|
|
if selectStmt.Where == nil {
|
|||
|
|
t.Fatal("expected WHERE")
|
|||
|
|
}
|
|||
|
|
if selectStmt.Where.Text != "id = 1" {
|
|||
|
|
t.Errorf("expected WHERE='id = 1', got '%s'", selectStmt.Where.Text)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// 注意:FOR UPDATE 标记可能在 Base.Text 中体现
|
|||
|
|
// 只要完整文本包含 FOR UPDATE 即可
|
|||
|
|
if !strings.Contains(selectStmt.GetText(), "FOR UPDATE") {
|
|||
|
|
t.Error("expected text to contain 'FOR UPDATE'")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestPgForUpdateSkipLocked(t *testing.T) {
|
|||
|
|
sql := "SELECT id FROM users WHERE status = 1 FOR UPDATE SKIP LOCKED"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
|||
|
|
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 1" {
|
|||
|
|
t.Errorf("expected WHERE='status = 1'")
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ========== RETURNING 测试 ==========
|
|||
|
|
|
|||
|
|
func TestPgUpdateReturning(t *testing.T) {
|
|||
|
|
sql := "UPDATE users SET name = 'John' WHERE id = 1 RETURNING id, name"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if stmt == nil {
|
|||
|
|
t.Fatalf("expected stmt not nil")
|
|||
|
|
}
|
|||
|
|
if stmt.GetText() != sql {
|
|||
|
|
t.Errorf("expected text='%s'", sql)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
func TestPgDeleteReturning(t *testing.T) {
|
|||
|
|
sql := "DELETE FROM users WHERE id = 1 RETURNING *"
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
if stmt == nil {
|
|||
|
|
t.Fatalf("expected stmt not nil")
|
|||
|
|
}
|
|||
|
|
if stmt.GetText() != sql {
|
|||
|
|
t.Errorf("expected text='%s'", sql)
|
|||
|
|
}
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
// ========== DDL 测试 ==========
|
|||
|
|
|
|||
|
|
func TestPgDDL(t *testing.T) {
|
|||
|
|
sqls := []string{
|
|||
|
|
"CREATE TABLE users (id INT PRIMARY KEY)",
|
|||
|
|
"DROP TABLE users",
|
|||
|
|
"ALTER TABLE users ADD COLUMN email VARCHAR(255)",
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
for _, sql := range sqls {
|
|||
|
|
t.Run(sql[:10], func(t *testing.T) {
|
|||
|
|
parser := NewParser(sql)
|
|||
|
|
stmt, err := parser.Parse()
|
|||
|
|
if err != nil {
|
|||
|
|
t.Fatalf("parse error: %v", err)
|
|||
|
|
}
|
|||
|
|
if stmt == nil {
|
|||
|
|
t.Fatalf("expected stmt not nil")
|
|||
|
|
}
|
|||
|
|
if stmt.GetText() != sql {
|
|||
|
|
t.Errorf("expected text='%s'", sql)
|
|||
|
|
}
|
|||
|
|
})
|
|||
|
|
}
|
|||
|
|
}
|