mirror of
https://gitee.com/dromara/mayfly-go
synced 2026-05-19 01:15:40 +08:00
847 lines
24 KiB
Go
847 lines
24 KiB
Go
package pgsql
|
||
|
||
import (
|
||
"testing"
|
||
|
||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||
)
|
||
|
||
// ========== INSERT 测试 ==========
|
||
|
||
func TestPgInsertBasic(t *testing.T) {
|
||
sql := "INSERT INTO users (name, age) VALUES ('John', 30)"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt, ok := stmt.(*sqlstmt.InsertStmt)
|
||
if !ok {
|
||
t.Fatal("expected InsertStmt")
|
||
}
|
||
|
||
// 验证完整文本
|
||
if insertStmt.GetText() != sql {
|
||
t.Errorf("expected text='%s', got '%s'", sql, insertStmt.GetText())
|
||
}
|
||
|
||
// 验证表名
|
||
if insertStmt.Table.Name != "users" {
|
||
t.Errorf("expected table='users', got '%s'", insertStmt.Table.Name)
|
||
}
|
||
|
||
// 验证列名
|
||
if len(insertStmt.Columns) != 2 {
|
||
t.Fatalf("expected 2 columns, got %d", len(insertStmt.Columns))
|
||
}
|
||
if insertStmt.Columns[0] != "name" {
|
||
t.Errorf("expected Columns[0]='name', got '%s'", insertStmt.Columns[0])
|
||
}
|
||
if insertStmt.Columns[1] != "age" {
|
||
t.Errorf("expected Columns[1]='age', got '%s'", insertStmt.Columns[1])
|
||
}
|
||
}
|
||
|
||
func TestPgInsertMultipleRows(t *testing.T) {
|
||
sql := "INSERT INTO users (name, age) VALUES ('John', 30), ('Jane', 25), ('Bob', 35)"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
if insertStmt.GetText() != sql {
|
||
t.Errorf("expected text='%s'", sql)
|
||
}
|
||
}
|
||
|
||
func TestPgInsertReturning(t *testing.T) {
|
||
sql := "INSERT INTO users (name, age) VALUES ('John', 30) RETURNING id, name"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
if insertStmt.GetText() != sql {
|
||
t.Errorf("expected text='%s'", sql)
|
||
}
|
||
}
|
||
|
||
func TestPgInsertFromSelect(t *testing.T) {
|
||
sql := "INSERT INTO users_backup SELECT * FROM users WHERE status = 1"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
if insertStmt.GetText() != sql {
|
||
t.Errorf("expected text='%s'", sql)
|
||
}
|
||
}
|
||
|
||
// ========== UPDATE 测试 ==========
|
||
|
||
func TestPgUpdateBasic(t *testing.T) {
|
||
sql := "UPDATE users SET name = 'John' WHERE id = 1"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
|
||
// 验证完整文本
|
||
if updateStmt.GetText() != sql {
|
||
t.Errorf("expected text='%s', got '%s'", sql, updateStmt.GetText())
|
||
}
|
||
|
||
// 验证表名
|
||
if len(updateStmt.Tables) != 1 {
|
||
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
|
||
}
|
||
if updateStmt.Tables[0].Name != "users" {
|
||
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
|
||
}
|
||
|
||
// 验证 SET 字段
|
||
if len(updateStmt.Set) != 1 {
|
||
t.Fatalf("expected 1 assignment, got %d", len(updateStmt.Set))
|
||
}
|
||
if updateStmt.Set[0].Column != "name" {
|
||
t.Errorf("expected Set[0].Column='name', got '%s'", updateStmt.Set[0].Column)
|
||
}
|
||
if updateStmt.Set[0].Value == nil || updateStmt.Set[0].Value.Text != "'John'" {
|
||
t.Errorf("expected Set[0].Value=''John'', got '%s'", updateStmt.Set[0].Value.Text)
|
||
}
|
||
|
||
// 验证 WHERE
|
||
if updateStmt.Where == nil {
|
||
t.Fatal("expected WHERE")
|
||
}
|
||
if updateStmt.Where.Text != "id = 1" {
|
||
t.Errorf("expected WHERE='id = 1', got '%s'", updateStmt.Where.Text)
|
||
}
|
||
}
|
||
|
||
func TestPgUpdateMultipleColumns(t *testing.T) {
|
||
sql := "UPDATE users SET name = 'John', age = 30, email = 'john@example.com' WHERE id = 1"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
|
||
// 验证完整文本
|
||
if updateStmt.GetText() != sql {
|
||
t.Errorf("expected text='%s', got '%s'", sql, updateStmt.GetText())
|
||
}
|
||
|
||
// 验证表名
|
||
if len(updateStmt.Tables) != 1 {
|
||
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
|
||
}
|
||
if updateStmt.Tables[0].Name != "users" {
|
||
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
|
||
}
|
||
|
||
// 验证 SET 字段(3个赋值)
|
||
if len(updateStmt.Set) != 3 {
|
||
t.Fatalf("expected 3 assignments, got %d", len(updateStmt.Set))
|
||
}
|
||
|
||
// 验证第一个赋值:name = 'John'
|
||
if updateStmt.Set[0].Column != "name" {
|
||
t.Errorf("expected Set[0].Column='name', got '%s'", updateStmt.Set[0].Column)
|
||
}
|
||
if updateStmt.Set[0].Value == nil || updateStmt.Set[0].Value.Text != "'John'" {
|
||
t.Errorf("expected Set[0].Value=''John'', got '%s'", updateStmt.Set[0].Value.Text)
|
||
}
|
||
if updateStmt.Set[0].Text != "name = 'John'" {
|
||
t.Errorf("expected Set[0].Text=\"name = 'John'\", got '%s'", updateStmt.Set[0].Text)
|
||
}
|
||
|
||
// 验证第二个赋值:age = 30
|
||
if updateStmt.Set[1].Column != "age" {
|
||
t.Errorf("expected Set[1].Column='age', got '%s'", updateStmt.Set[1].Column)
|
||
}
|
||
if updateStmt.Set[1].Value == nil || updateStmt.Set[1].Value.Text != "30" {
|
||
t.Errorf("expected Set[1].Value='30', got '%s'", updateStmt.Set[1].Value.Text)
|
||
}
|
||
if updateStmt.Set[1].Text != "age = 30" {
|
||
t.Errorf("expected Set[1].Text='age = 30', got '%s'", updateStmt.Set[1].Text)
|
||
}
|
||
|
||
// 验证第三个赋值:email = 'john@example.com'
|
||
if updateStmt.Set[2].Column != "email" {
|
||
t.Errorf("expected Set[2].Column='email', got '%s'", updateStmt.Set[2].Column)
|
||
}
|
||
if updateStmt.Set[2].Value == nil || updateStmt.Set[2].Value.Text != "'john@example.com'" {
|
||
t.Errorf("expected Set[2].Value=''john@example.com'', got '%s'", updateStmt.Set[2].Value.Text)
|
||
}
|
||
if updateStmt.Set[2].Text != "email = 'john@example.com'" {
|
||
t.Errorf("expected Set[2].Text=\"email = 'john@example.com'\", got '%s'", updateStmt.Set[2].Text)
|
||
}
|
||
|
||
// 验证 WHERE
|
||
if updateStmt.Where == nil {
|
||
t.Fatal("expected WHERE")
|
||
}
|
||
if updateStmt.Where.Text != "id = 1" {
|
||
t.Errorf("expected WHERE='id = 1', got '%s'", updateStmt.Where.Text)
|
||
}
|
||
}
|
||
|
||
func TestPgUpdateFromJoin(t *testing.T) {
|
||
// PostgreSQL UPDATE FROM 语法
|
||
sql := "UPDATE orders SET status = 'shipped' FROM users WHERE orders.user_id = users.id AND users.status = 1"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
t.Logf("UPDATE text: %s", updateStmt.GetText())
|
||
t.Logf("UPDATE tables: %+v", updateStmt.Tables)
|
||
|
||
// 主表名应该正确解析
|
||
if len(updateStmt.Tables) < 1 {
|
||
t.Fatal("expected at least 1 table")
|
||
}
|
||
// 注意:当前解析器对于 UPDATE FROM 语法支持有限,可能表名包含别名
|
||
// 但至少验证能成功解析而不报错
|
||
t.Logf("Successfully parsed UPDATE FROM statement")
|
||
}
|
||
|
||
// ========== DELETE 测试 ==========
|
||
|
||
func TestPgDeleteBasic(t *testing.T) {
|
||
sql := "DELETE FROM users WHERE id = 1"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||
|
||
// 验证完整文本
|
||
if deleteStmt.GetText() != sql {
|
||
t.Errorf("expected text='%s', got '%s'", sql, deleteStmt.GetText())
|
||
}
|
||
|
||
// 验证表名
|
||
if len(deleteStmt.Tables) != 1 {
|
||
t.Fatalf("expected 1 table, got %d", len(deleteStmt.Tables))
|
||
}
|
||
if deleteStmt.Tables[0].Name != "users" {
|
||
t.Errorf("expected table='users', got '%s'", deleteStmt.Tables[0].Name)
|
||
}
|
||
|
||
// 验证 WHERE
|
||
if deleteStmt.Where == nil {
|
||
t.Fatal("expected WHERE")
|
||
}
|
||
if deleteStmt.Where.Text != "id = 1" {
|
||
t.Errorf("expected WHERE='id = 1', got '%s'", deleteStmt.Where.Text)
|
||
}
|
||
}
|
||
|
||
func TestPgDeleteUsing(t *testing.T) {
|
||
sql := "DELETE FROM orders o USING users u WHERE o.user_id = u.id AND u.status = 0"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||
if len(deleteStmt.Tables) < 1 || deleteStmt.Tables[0].Name != "orders" {
|
||
t.Errorf("expected table='orders'")
|
||
}
|
||
}
|
||
|
||
// ========== Schema.Table 测试 ==========
|
||
|
||
func TestPgInsertWithSchema(t *testing.T) {
|
||
sql := `INSERT INTO "public"."users" ("name", "age") VALUES ('John', 30)`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
if insertStmt.Table.Schema != "public" {
|
||
t.Errorf("expected schema='public', got '%s'", insertStmt.Table.Schema)
|
||
}
|
||
if insertStmt.Table.Name != "users" {
|
||
t.Errorf("expected table='users', got '%s'", insertStmt.Table.Name)
|
||
}
|
||
}
|
||
|
||
func TestPgUpdateWithSchema(t *testing.T) {
|
||
sql := `UPDATE "public"."t_db" SET "name" = 'fsdfds3' WHERE "id" = 5`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
|
||
t.Logf("UPDATE text: %s", updateStmt.GetText())
|
||
t.Logf("UPDATE tables: %+v", updateStmt.Tables)
|
||
t.Logf("UPDATE SET: %+v", updateStmt.Set)
|
||
t.Logf("UPDATE WHERE: %+v", updateStmt.Where)
|
||
|
||
if len(updateStmt.Tables) != 1 {
|
||
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
|
||
}
|
||
if updateStmt.Tables[0].Schema != "public" {
|
||
t.Errorf("expected schema='public', got '%s'", updateStmt.Tables[0].Schema)
|
||
}
|
||
if updateStmt.Tables[0].Name != "t_db" {
|
||
t.Errorf("expected table='t_db', got '%s'", updateStmt.Tables[0].Name)
|
||
}
|
||
// 验证 SET 字段
|
||
if len(updateStmt.Set) != 1 {
|
||
t.Fatalf("expected 1 assignment, got %d", len(updateStmt.Set))
|
||
}
|
||
if updateStmt.Set[0].Column != "name" {
|
||
t.Errorf("expected column='name', got '%s'", updateStmt.Set[0].Column)
|
||
}
|
||
if updateStmt.Set[0].Value == nil || updateStmt.Set[0].Value.Text != "'fsdfds3'" {
|
||
t.Errorf("expected value=''fsdfds3'', got '%s'", updateStmt.Set[0].Value.Text)
|
||
}
|
||
if updateStmt.Where == nil || updateStmt.Where.Text != `"id" = 5` {
|
||
t.Errorf("expected WHERE='\"id\" = 5', got '%s'", updateStmt.Where.Text)
|
||
}
|
||
}
|
||
|
||
func TestPgDeleteWithSchema(t *testing.T) {
|
||
sql := `DELETE FROM "public"."logs" WHERE "created_at" < '2024-01-01'`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||
if len(deleteStmt.Tables) != 1 {
|
||
t.Fatalf("expected 1 table, got %d", len(deleteStmt.Tables))
|
||
}
|
||
if deleteStmt.Tables[0].Schema != "public" {
|
||
t.Errorf("expected schema='public', got '%s'", deleteStmt.Tables[0].Schema)
|
||
}
|
||
if deleteStmt.Tables[0].Name != "logs" {
|
||
t.Errorf("expected table='logs', got '%s'", deleteStmt.Tables[0].Name)
|
||
}
|
||
if deleteStmt.Where == nil || deleteStmt.Where.Text != `"created_at" < '2024-01-01'` {
|
||
t.Errorf("expected WHERE text")
|
||
}
|
||
}
|
||
|
||
// ========== DDL 测试 ==========
|
||
|
||
func TestPgDDLCreate(t *testing.T) {
|
||
sql := "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(50), age INT)"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||
|
||
if ddlStmt.DdlKind != "CREATE" {
|
||
t.Errorf("expected DdlKind='CREATE', got '%s'", ddlStmt.DdlKind)
|
||
}
|
||
if ddlStmt.GetText() != sql {
|
||
t.Errorf("expected text='%s'", sql)
|
||
}
|
||
}
|
||
|
||
func TestPgDDLDrop(t *testing.T) {
|
||
sql := "DROP TABLE users"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||
|
||
if ddlStmt.DdlKind != "DROP" {
|
||
t.Errorf("expected DdlKind='DROP', got '%s'", ddlStmt.DdlKind)
|
||
}
|
||
}
|
||
|
||
func TestPgDDLAlter(t *testing.T) {
|
||
sql := "ALTER TABLE users ADD COLUMN email VARCHAR(255)"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||
|
||
if ddlStmt.DdlKind != "ALTER" {
|
||
t.Errorf("expected DdlKind='ALTER', got '%s'", ddlStmt.DdlKind)
|
||
}
|
||
}
|
||
|
||
func TestPgDDLTruncate(t *testing.T) {
|
||
sql := "TRUNCATE TABLE users"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
if stmt == nil {
|
||
t.Fatalf("expected stmt not nil")
|
||
}
|
||
t.Logf("TRUNCATE stmt type: %T", stmt)
|
||
t.Logf("TRUNCATE text: %s", stmt.GetText())
|
||
}
|
||
|
||
// ========== 复杂 DML 测试 ==========
|
||
|
||
func TestPgComplexUpdateWithSubquery(t *testing.T) {
|
||
sql := "UPDATE users SET total_orders = (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||
t.Errorf("expected table='users'")
|
||
}
|
||
}
|
||
|
||
func TestPgComplexDeleteWithSubquery(t *testing.T) {
|
||
sql := "DELETE FROM users WHERE id NOT IN (SELECT DISTINCT user_id FROM orders)"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||
if deleteStmt.Where == nil {
|
||
t.Fatal("expected WHERE")
|
||
}
|
||
}
|
||
|
||
func TestPgInsertOnConflict(t *testing.T) {
|
||
sql := "INSERT INTO users (id, name) VALUES (1, 'John') ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
t.Logf("Actual text: %s", insertStmt.GetText())
|
||
if insertStmt.GetText() != sql {
|
||
t.Errorf("expected text='%s'", sql)
|
||
}
|
||
}
|
||
|
||
// ========== 复杂 INSERT 测试 ==========
|
||
|
||
func TestPgInsertWithDoubleQuotes(t *testing.T) {
|
||
// PostgreSQL 双引号标识符
|
||
sql := `INSERT INTO "users" ("name", "age", "email") VALUES ('John', 30, 'john@example.com')`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
if insertStmt.GetText() != sql {
|
||
t.Errorf("expected text='%s'", sql)
|
||
}
|
||
}
|
||
|
||
func TestPgInsertWithSpecialChars(t *testing.T) {
|
||
// 包含特殊字符
|
||
sql := `INSERT INTO "logs" ("message", "level") VALUES ('Error: connection failed!', 'ERROR')`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
t.Logf("INSERT text: %s", insertStmt.GetText())
|
||
}
|
||
|
||
func TestPgInsertReturningMultiple(t *testing.T) {
|
||
// RETURNING 多个字段
|
||
sql := `INSERT INTO "users" ("name", "email") VALUES ('John', 'john@example.com') RETURNING "id", "name", "created_at"`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
t.Logf("INSERT RETURNING: %s", insertStmt.GetText())
|
||
}
|
||
|
||
func TestPgInsertFromSelectComplex(t *testing.T) {
|
||
// INSERT FROM SELECT 复杂查询
|
||
sql := `INSERT INTO "users_backup" SELECT * FROM "users" WHERE "status" = 1 AND "created_at" > '2024-01-01'`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
t.Logf("INSERT FROM SELECT: %s", insertStmt.GetText())
|
||
}
|
||
|
||
func TestPgInsertOnConflictDoNothing(t *testing.T) {
|
||
// ON CONFLICT DO NOTHING
|
||
sql := `INSERT INTO "users" ("id", "name") VALUES (1, 'John') ON CONFLICT ("id") DO NOTHING`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
t.Logf("INSERT ON CONFLICT DO NOTHING: %s", insertStmt.GetText())
|
||
}
|
||
|
||
// ========== 复杂 UPDATE 测试 ==========
|
||
|
||
func TestPgUpdateWithDoubleQuotes(t *testing.T) {
|
||
sql := `UPDATE "users" SET "name" = 'John', "age" = 30 WHERE "id" = 1`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
|
||
}
|
||
if updateStmt.Where == nil || updateStmt.Where.Text != `"id" = 1` {
|
||
t.Errorf("expected WHERE")
|
||
}
|
||
}
|
||
|
||
func TestPgUpdateWithComplexWhere(t *testing.T) {
|
||
sql := `UPDATE "orders" SET "status" = 'cancelled' WHERE "status" = 'pending' AND "created_at" < '2024-01-01' AND ("amount" < 100 OR "user_id" IS NULL)`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
if updateStmt.Where == nil {
|
||
t.Fatal("expected WHERE")
|
||
}
|
||
t.Logf("Complex WHERE: %s", updateStmt.Where.Text)
|
||
}
|
||
|
||
func TestPgUpdateWithFunctions(t *testing.T) {
|
||
sql := `UPDATE "users" SET "updated_at" = NOW(), "login_count" = "login_count" + 1 WHERE "id" = 1`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
t.Logf("UPDATE with functions: %s", updateStmt.GetText())
|
||
}
|
||
|
||
func TestPgUpdateReturningComplex(t *testing.T) {
|
||
sql := `UPDATE "users" SET "status" = 0 WHERE "status" = 1 RETURNING "id", "name", "old_status"`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
if updateStmt.Where == nil || updateStmt.Where.Text != `"status" = 1` {
|
||
t.Errorf("expected WHERE")
|
||
}
|
||
}
|
||
|
||
func TestPgUpdateWithSubquery(t *testing.T) {
|
||
sql := `UPDATE "users" SET "total" = (SELECT SUM("amount") FROM "orders" WHERE "user_id" = "users"."id") WHERE "status" = 'active'`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||
t.Errorf("expected table='users'")
|
||
}
|
||
}
|
||
|
||
func TestPgUpdateFromComplex(t *testing.T) {
|
||
// PostgreSQL UPDATE FROM 复杂场景
|
||
sql := `UPDATE "orders" o SET "status" = 'shipped' FROM "users" u WHERE o."user_id" = u."id" AND u."status" = 1 AND o."created_at" > '2024-01-01'`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
t.Logf("UPDATE FROM: %s", updateStmt.GetText())
|
||
}
|
||
|
||
// ========== 复杂 DELETE 测试 ==========
|
||
|
||
func TestPgDeleteWithDoubleQuotes(t *testing.T) {
|
||
sql := `DELETE FROM "users" WHERE "id" = 1`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||
if len(deleteStmt.Tables) != 1 || deleteStmt.Tables[0].Name != "users" {
|
||
t.Errorf("expected table='users'")
|
||
}
|
||
}
|
||
|
||
func TestPgDeleteWithComplexWhere(t *testing.T) {
|
||
sql := `DELETE FROM "logs" WHERE "created_at" < '2024-01-01' AND ("level" = 'DEBUG' OR "level" = 'INFO')`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||
if deleteStmt.Where == nil {
|
||
t.Fatal("expected WHERE")
|
||
}
|
||
t.Logf("Complex DELETE WHERE: %s", deleteStmt.Where.Text)
|
||
}
|
||
|
||
func TestPgDeleteReturningComplex(t *testing.T) {
|
||
sql := `DELETE FROM "users" WHERE "status" = 0 RETURNING "id", "name", "deleted_at"`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||
if deleteStmt.Where == nil || deleteStmt.Where.Text != `"status" = 0` {
|
||
t.Errorf("expected WHERE")
|
||
}
|
||
}
|
||
|
||
func TestPgDeleteWithSubquery(t *testing.T) {
|
||
sql := `DELETE FROM "users" WHERE "id" NOT IN (SELECT DISTINCT "user_id" FROM "orders")`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||
if deleteStmt.Where == nil {
|
||
t.Fatal("expected WHERE")
|
||
}
|
||
}
|
||
|
||
func TestPgDeleteUsingComplex(t *testing.T) {
|
||
// PostgreSQL DELETE USING 复杂场景
|
||
sql := `DELETE FROM "orders" o USING "users" u WHERE o."user_id" = u."id" AND u."status" = 0 AND o."created_at" < '2024-01-01'`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||
t.Logf("DELETE USING: %s", deleteStmt.GetText())
|
||
}
|
||
|
||
// ========== 复杂 DDL 测试 ==========
|
||
|
||
func TestPgDDLCreateTableWithQuotes(t *testing.T) {
|
||
sql := `CREATE TABLE "users" ("id" INT PRIMARY KEY, "name" VARCHAR(50) NOT NULL, "email" VARCHAR(255) UNIQUE)`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||
if ddlStmt.DdlKind != "CREATE" {
|
||
t.Errorf("expected DdlKind='CREATE'")
|
||
}
|
||
}
|
||
|
||
func TestPgDDLCreateTableWithSerial(t *testing.T) {
|
||
sql := `CREATE TABLE "orders" ("id" SERIAL PRIMARY KEY, "amount" DECIMAL(10,2), "created_at" TIMESTAMP DEFAULT NOW())`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||
t.Logf("CREATE TABLE with SERIAL: %s", ddlStmt.GetText())
|
||
}
|
||
|
||
func TestPgDDLAlterTableAddColumn(t *testing.T) {
|
||
sql := `ALTER TABLE "users" ADD COLUMN "email" VARCHAR(255)`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||
if ddlStmt.DdlKind != "ALTER" {
|
||
t.Errorf("expected DdlKind='ALTER'")
|
||
}
|
||
}
|
||
|
||
func TestPgDDLDropIfExists(t *testing.T) {
|
||
sql := `DROP TABLE IF EXISTS "users"`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||
if ddlStmt.DdlKind != "DROP" {
|
||
t.Errorf("expected DdlKind='DROP'")
|
||
}
|
||
}
|
||
|
||
// ========== 注释风格测试 ==========
|
||
|
||
func TestPgDMLWithSingleLineComment(t *testing.T) {
|
||
// PostgreSQL 单行注释 --
|
||
sql := "-- 更新用户信息\nUPDATE \"users\" SET \"name\" = 'John' WHERE \"id\" = 1"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||
t.Errorf("expected table='users'")
|
||
}
|
||
t.Logf("UPDATE with -- comment: %s", updateStmt.GetText())
|
||
}
|
||
|
||
func TestPgDMLWithMultiLineComment(t *testing.T) {
|
||
// 多行注释 /* */
|
||
sql := "/* 删除过期订单 */ DELETE FROM \"orders\" WHERE \"created_at\" < '2024-01-01'"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||
if deleteStmt.Where == nil {
|
||
t.Fatal("expected WHERE")
|
||
}
|
||
t.Logf("DELETE with /* */ comment: %s", deleteStmt.GetText())
|
||
}
|
||
|
||
func TestPgDMLWithInlineComment(t *testing.T) {
|
||
// 行内注释
|
||
sql := `SELECT "id", "name" /* 用户名 */ FROM "users" WHERE "status" = 1`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||
if len(selectStmt.Items) != 2 {
|
||
t.Fatalf("expected 2 items")
|
||
}
|
||
t.Logf("SELECT with inline comment: %s", selectStmt.GetText())
|
||
}
|
||
|
||
func TestPgDMLWithMultipleComments(t *testing.T) {
|
||
// 多个注释
|
||
sql := `-- 查询活跃用户
|
||
/* 只查询最近注册的 */
|
||
SELECT "id", "name" FROM "users"
|
||
WHERE "status" = 1 AND "created_at" > '2024-01-01'
|
||
ORDER BY "created_at" DESC`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||
if selectStmt.Where == nil {
|
||
t.Fatal("expected WHERE")
|
||
}
|
||
t.Logf("Multiple comments: %s", selectStmt.GetText())
|
||
}
|
||
|
||
func TestPgInsertWithComment(t *testing.T) {
|
||
sql := "-- 插入新用户\nINSERT INTO \"users\" (\"name\", \"email\") VALUES ('John', 'john@example.com')"
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||
t.Logf("INSERT with comment: %s", insertStmt.GetText())
|
||
}
|
||
|
||
func TestPgUpdateWithComment(t *testing.T) {
|
||
sql := `/* 批量更新状态 */
|
||
UPDATE "orders" SET "status" = 'cancelled'
|
||
WHERE "status" = 'pending' -- 只更新待处理的订单
|
||
AND "created_at" < '2024-01-01'`
|
||
parser := NewParser(sql)
|
||
stmt, err := parser.Parse()
|
||
if err != nil {
|
||
t.Fatalf("parse error: %v", err)
|
||
}
|
||
|
||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||
if updateStmt.Where == nil {
|
||
t.Fatal("expected WHERE")
|
||
}
|
||
t.Logf("UPDATE with multiple comments: %s", updateStmt.GetText())
|
||
}
|