refactor: 移除antlr4减小包体积&ai助手优化

This commit is contained in:
meilin.huang
2026-05-08 20:45:13 +08:00
parent 3768cef62d
commit f23b243fc5
154 changed files with 13054 additions and 396804 deletions

View File

@@ -0,0 +1,18 @@
package dm
import (
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"mayfly-go/pkg/logx"
)
type DmParser struct {
}
func (*DmParser) Parse(stmt string) (sqlstmt.Stmt, error) {
defer func() {
if e := recover(); e != nil {
logx.ErrorTrace("dm sql parser err: ", e)
}
}()
return NewParser(stmt).Parse()
}

View File

@@ -0,0 +1,691 @@
package dm
import (
"strconv"
"strings"
"mayfly-go/internal/db/dbm/sqlparser/base"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"mayfly-go/internal/db/dbm/sqlparser/tokenizer"
)
// Parser 达梦方言 SQL 解析器
type Parser struct {
*base.Lexer
}
// NewParser 创建达梦解析器
func NewParser(sql string) *Parser {
return &Parser{
Lexer: base.NewLexer(sql, tokenizer.DialectConfig{
DoubleQuoteAsIdentifier: true,
}),
}
}
// Parse 解析单条 SQL 语句
func (p *Parser) Parse() (sqlstmt.Stmt, error) {
p.SkipSemicolons()
if p.Current().IsEOF() {
return nil, nil
}
stmt := p.parseStatement()
return stmt, nil
}
func (p *Parser) parseStatement() sqlstmt.Stmt {
tok := p.Current()
switch {
case tok.IsKeyword("SELECT") || tok.Value == "(":
return p.parseSelect()
case tok.IsKeyword("INSERT"):
return p.parseInsert()
case tok.IsKeyword("UPDATE"):
return p.parseUpdate()
case tok.IsKeyword("DELETE"):
return p.parseDelete()
case tok.IsKeyword("CREATE"):
return p.parseCreate()
case tok.IsKeyword("DROP"):
return p.parseDrop()
case tok.IsKeyword("ALTER"):
return p.parseAlter()
case tok.IsKeyword("WITH"):
return p.parseWith()
case tok.IsKeyword("TRUNCATE"):
return p.parseGenericDdl()
default:
return p.parseGenericStmt()
}
}
// ---------- SELECT 解析 ----------
func (p *Parser) parseSelect() sqlstmt.Stmt {
start := p.Pos
var selectStmt *sqlstmt.SelectStmt
if p.Current().Value == "(" {
p.Consume()
if p.Current().IsKeyword("SELECT") {
selectStmt = p.parseSelectBody()
} else {
p.SkipParentheses()
selectStmt = &sqlstmt.SelectStmt{}
}
p.ExpectValue(")")
} else {
selectStmt = p.parseSelectBody()
}
// UNION 解析
selectStmt = p.parseUnions(selectStmt)
// ORDER BY
if len(selectStmt.OrderBy) == 0 && p.Current().IsKeyword("ORDER") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
selectStmt.OrderBy = p.parseOrderBy()
}
// LIMIT达梦支持 LIMIT offset, count 或 LIMIT count OFFSET offset
if p.Current().IsKeyword("LIMIT") {
limit := &sqlstmt.Limit{}
limitStart := p.Pos
p.Consume()
if p.Current().Type == tokenizer.TokenNumber {
firstNum := p.Consume().Value
firstVal, _ := strconv.Atoi(firstNum)
// 检查是 LIMIT offset, count 还是 LIMIT count OFFSET offset
if p.Current().Value == "," {
// LIMIT offset, count 格式
limit.Offset = firstVal
p.Consume()
if p.Current().Type == tokenizer.TokenNumber {
countStr := p.Consume().Value
limit.Count, _ = strconv.Atoi(countStr)
}
} else if p.Current().IsKeyword("OFFSET") {
// LIMIT count OFFSET offset 格式
limit.Count = firstVal
p.Consume()
if p.Current().Type == tokenizer.TokenNumber {
offsetStr := p.Consume().Value
limit.Offset, _ = strconv.Atoi(offsetStr)
}
} else {
// 只有 LIMIT count
limit.Count = firstVal
limit.Offset = 0
}
}
limit.Text = p.TextFrom(limitStart)
selectStmt.Limit = limit
}
// FOR UPDATE
if p.Current().IsKeyword("FOR") {
p.Consume()
if p.Current().IsKeyword("UPDATE") {
p.Consume()
}
}
// 更新完整文本
selectStmt.Base = sqlstmt.Base{Text: p.TextFrom(start)}
return selectStmt
}
func (p *Parser) parseSelectBody() *sqlstmt.SelectStmt {
start := p.Pos
distinct := false
// 支持 SELECT TOP n达梦特性
topCount := ""
if p.Current().IsKeyword("SELECT") {
p.Consume()
if p.Current().IsKeyword("TOP") {
p.Consume()
if p.Current().Type == tokenizer.TokenNumber {
topCount = p.Consume().Value
}
} else if p.Current().IsKeyword("DISTINCT") {
p.Consume()
distinct = true
} else if p.Current().IsKeyword("ALL") {
p.Consume()
}
}
// SELECT 项
var items []sqlstmt.SelectItem
for !p.Current().IsEOF() && !p.IsSelectClauseEnd() {
if p.Current().Value == "," {
p.Consume()
continue
}
itemStart := p.Pos
p.SkipExpr()
text := base.TrimTrailingComma(p.TextFromExclusive(itemStart))
col, alias := p.ExtractColumnAndAlias(text)
items = append(items, sqlstmt.SelectItem{
Text: text,
ColumnName: col,
Alias: alias,
})
}
selectStmt := &sqlstmt.SelectStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Distinct: distinct,
Items: items,
}
// FROM
if p.Current().IsKeyword("FROM") {
p.Consume()
selectStmt.From = p.parseFromClause()
}
// JOIN
for p.IsJoinStart() || p.Current().IsKeyword("JOIN") {
if join := p.parseJoinClause(); join != nil {
selectStmt.Joins = append(selectStmt.Joins, *join)
}
}
// WHERE
if p.Current().IsKeyword("WHERE") {
p.Consume()
whereStart := p.Pos
p.SkipExpr()
selectStmt.Where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
}
// GROUP BY
if p.Current().IsKeyword("GROUP") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
p.SkipGroupByExpr()
}
// HAVING
if p.Current().IsKeyword("HAVING") {
p.Consume()
p.SkipExpr()
}
// ORDER BY
if p.Current().IsKeyword("ORDER") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
selectStmt.OrderBy = p.parseOrderBy()
}
// LIMIT 在 SELECT 子句中(如果有)
if p.Current().IsKeyword("LIMIT") {
limit := &sqlstmt.Limit{}
limitStart := p.Pos
p.Consume()
if p.Current().Type == tokenizer.TokenNumber {
firstNum := p.Consume().Value
firstVal, _ := strconv.Atoi(firstNum)
if p.Current().Value == "," {
limit.Offset = firstVal
p.Consume()
if p.Current().Type == tokenizer.TokenNumber {
countStr := p.Consume().Value
limit.Count, _ = strconv.Atoi(countStr)
}
} else if p.Current().IsKeyword("OFFSET") {
limit.Count = firstVal
p.Consume()
if p.Current().Type == tokenizer.TokenNumber {
offsetStr := p.Consume().Value
limit.Offset, _ = strconv.Atoi(offsetStr)
}
} else {
limit.Count = firstVal
limit.Offset = 0
}
}
limit.Text = p.TextFrom(limitStart)
selectStmt.Limit = limit
}
_ = topCount // 暂时未使用
return selectStmt
}
func (p *Parser) parseUnions(selectStmt *sqlstmt.SelectStmt) *sqlstmt.SelectStmt {
for p.Current().IsKeyword("UNION") {
p.Consume()
all := false
if p.Current().IsKeyword("ALL") {
p.Consume()
all = true
} else if p.Current().IsKeyword("DISTINCT") {
p.Consume()
}
var nextSelect *sqlstmt.SelectStmt
if p.Current().Value == "(" {
p.Consume()
if p.Current().IsKeyword("SELECT") {
nextSelect = p.parseSelectBody()
}
p.ExpectValue(")")
} else if p.Current().IsKeyword("SELECT") {
nextSelect = p.parseSelectBody()
}
if nextSelect != nil {
// 提取最后一个 unionSelect 的 ORDER BY 和 LIMIT 到外层
if len(nextSelect.OrderBy) > 0 {
selectStmt.OrderBy = nextSelect.OrderBy
nextSelect.OrderBy = nil
}
if nextSelect.Limit != nil {
selectStmt.Limit = nextSelect.Limit
nextSelect.Limit = nil
}
selectStmt.Unions = append(selectStmt.Unions, sqlstmt.UnionClause{
Select: nextSelect,
All: all,
})
}
}
return selectStmt
}
func (p *Parser) parseOrderBy() []sqlstmt.OrderByItem {
var items []sqlstmt.OrderByItem
for !p.Current().IsEOF() && !p.IsExprEnd() {
if p.Current().Value == "," {
p.Consume()
continue
}
start := p.Pos
p.SkipExpr()
text := p.TextFromExclusive(start)
desc := false
upper := strings.ToUpper(text)
if strings.HasSuffix(upper, " DESC") {
desc = true
text = strings.TrimSpace(text[:len(text)-5])
} else if strings.HasSuffix(upper, " ASC") {
text = strings.TrimSpace(text[:len(text)-4])
}
items = append(items, sqlstmt.OrderByItem{
Text: text,
Desc: desc,
})
}
return items
}
// ---------- FROM 解析 ----------
func (p *Parser) parseFromClause() []sqlstmt.TableRef {
var tables []sqlstmt.TableRef
for !p.Current().IsEOF() && !p.IsFromClauseEnd() {
if p.Current().Value == "," {
p.Consume()
continue
}
if p.IsJoinStart() || p.Current().IsKeyword("JOIN") {
break
}
ref := p.parseTableRef()
if ref.Name != "" {
tables = append(tables, ref)
}
}
return tables
}
func (p *Parser) parseTableRef() sqlstmt.TableRef {
start := p.Pos
// 子查询
if p.Current().Value == "(" {
p.Consume()
if p.Current().IsKeyword("SELECT") {
p.parseSelectBody()
} else {
p.SkipParentheses()
}
p.ExpectValue(")")
var alias string
if p.Current().IsKeyword("AS") {
p.Consume()
}
if p.Current().Type == tokenizer.TokenIdentifier {
alias = p.Unquote(p.Consume().Value)
}
return sqlstmt.TableRef{
Name: p.TextFrom(start),
Alias: alias,
}
}
if p.Current().Type != tokenizer.TokenIdentifier && p.Current().Type != tokenizer.TokenString {
return sqlstmt.TableRef{}
}
ref := sqlstmt.TableRef{}
part1 := p.Consume().Value
if p.Current().Value == "." {
p.Consume()
if p.Current().Type == tokenizer.TokenIdentifier || p.Current().Type == tokenizer.TokenString {
part2 := p.Consume().Value
ref.Schema = p.Unquote(part1)
ref.Name = p.Unquote(part2)
} else {
ref.Name = p.Unquote(part1)
}
} else {
ref.Name = p.Unquote(part1)
}
if p.Current().IsKeyword("AS") {
p.Consume()
}
if p.Current().Type == tokenizer.TokenIdentifier {
ref.Alias = p.Unquote(p.Consume().Value)
}
return ref
}
func (p *Parser) parseJoinClause() *sqlstmt.JoinClause {
start := p.Pos
joinType := sqlstmt.JoinKindInner
if p.Current().IsKeyword("LEFT") {
p.Consume()
if p.Current().IsKeyword("OUTER") {
p.Consume()
}
joinType = sqlstmt.JoinKindLeft
} else if p.Current().IsKeyword("RIGHT") {
p.Consume()
if p.Current().IsKeyword("OUTER") {
p.Consume()
}
joinType = sqlstmt.JoinKindRight
} else if p.Current().IsKeyword("FULL") {
p.Consume()
if p.Current().IsKeyword("OUTER") {
p.Consume()
}
joinType = sqlstmt.JoinKindFull
} else if p.Current().IsKeyword("CROSS") {
p.Consume()
joinType = sqlstmt.JoinKindCross
} else if p.Current().IsKeyword("INNER") {
p.Consume()
}
if !p.Current().IsKeyword("JOIN") {
p.Pos = start
return nil
}
p.Consume()
tableRef := p.parseTableRef()
if tableRef.Name == "" {
p.Pos = start
return nil
}
var onExpr *sqlstmt.Expr
if p.Current().IsKeyword("ON") {
p.Consume()
onStart := p.Pos
p.SkipExpr()
onExpr = &sqlstmt.Expr{Text: p.TextFromExclusive(onStart)}
} else if p.Current().IsKeyword("USING") {
p.Consume()
if p.Current().Value == "(" {
p.SkipParentheses()
}
}
return &sqlstmt.JoinClause{
Kind: joinType,
Table: tableRef,
On: onExpr,
Text: p.TextFrom(start),
}
}
// ---------- INSERT 解析 ----------
func (p *Parser) parseInsert() sqlstmt.Stmt {
start := p.Pos
p.Consume() // INSERT
if p.Current().IsKeyword("INTO") {
p.Consume()
}
tableRef := p.parseTableRef()
// 解析列名列表
columns := []string{}
if p.Current().Value == "(" {
p.Consume()
for !p.Current().IsEOF() && p.Current().Value != ")" {
if p.Current().Value == "," {
p.Consume()
continue
}
if p.Current().Type == tokenizer.TokenIdentifier {
columns = append(columns, p.Unquote(p.Consume().Value))
} else {
p.Consume()
}
}
if p.Current().Value == ")" {
p.Consume()
}
}
if p.Current().IsKeyword("VALUES") {
p.Consume()
for p.Current().Value == "(" {
p.SkipParentheses()
if p.Current().Value == "," {
p.Consume()
}
}
} else if p.Current().IsKeyword("SELECT") {
p.parseSelect()
}
return &sqlstmt.InsertStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Table: tableRef,
Columns: columns,
}
}
// ---------- UPDATE 解析 ----------
func (p *Parser) parseUpdate() sqlstmt.Stmt {
start := p.Pos
p.Consume() // UPDATE
tableRef := p.parseTableRef()
tables := []sqlstmt.TableRef{tableRef}
// SET - 解析字段赋值
assignments := make([]sqlstmt.Assignment, 0)
if p.Current().IsKeyword("SET") {
p.Consume()
for !p.Current().IsEOF() {
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
break
}
assign := p.parseAssignment()
if assign != nil {
assignments = append(assignments, *assign)
}
if p.Current().Value == "," {
p.Consume()
continue
}
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
break
}
}
}
var where *sqlstmt.Expr
if p.Current().IsKeyword("WHERE") {
p.Consume()
whereStart := p.Pos
p.SkipExpr()
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
}
return &sqlstmt.UpdateStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Tables: tables,
Set: assignments,
Where: where,
}
}
// ---------- DELETE 解析 ----------
func (p *Parser) parseDelete() sqlstmt.Stmt {
start := p.Pos
p.Consume() // DELETE
if p.Current().IsKeyword("FROM") {
p.Consume()
}
// 使用 parseTableRef 正确解析表名
tableRef := p.parseTableRef()
tables := []sqlstmt.TableRef{tableRef}
// WHERE
var where *sqlstmt.Expr
if p.Current().IsKeyword("WHERE") {
p.Consume()
whereStart := p.Pos
p.SkipExpr()
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
}
return &sqlstmt.DeleteStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Tables: tables,
Where: where,
}
}
// ---------- DDL 解析 ----------
func (p *Parser) parseCreate() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "CREATE",
}
}
func (p *Parser) parseDrop() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "DROP",
}
}
func (p *Parser) parseAlter() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "ALTER",
}
}
func (p *Parser) parseGenericDdl() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "DDL",
}
}
// ---------- WITH 解析 ----------
func (p *Parser) parseWith() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.WithStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
}
}
// ---------- 通用语句解析 ----------
func (p *Parser) parseGenericStmt() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.OtherStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
}
}
// parseAssignment 解析 SET 字段赋值column = value
func (p *Parser) parseAssignment() *sqlstmt.Assignment {
start := p.Pos
colText := ""
for !p.Current().IsEOF() && p.Current().Value != "=" && p.Current().Value != "," &&
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
colText += p.Consume().Value
}
if p.Current().Value != "=" {
p.Pos = start
return nil
}
p.Consume() // =
valStart := p.Pos
for !p.Current().IsEOF() && p.Current().Value != "," &&
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
if p.Current().Value == "(" {
p.SkipParentheses()
continue
}
p.Consume()
}
return &sqlstmt.Assignment{
Column: p.Unquote(strings.TrimSpace(colText)),
Value: &sqlstmt.Expr{Text: p.TextFromExclusive(valStart)},
Text: p.TextFromExclusive(start),
}
}

View File

@@ -0,0 +1,569 @@
package dm
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== INSERT 测试 ==========
func TestDmInsertBasic(t *testing.T) {
sql := "INSERT INTO users (name, age) VALUES ('John', 30)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt, ok := stmt.(*sqlstmt.InsertStmt)
if !ok {
t.Fatal("expected InsertStmt")
}
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, insertStmt.GetText())
}
}
func TestDmInsertMultipleRows(t *testing.T) {
sql := "INSERT INTO users (name, age) VALUES ('John', 30), ('Jane', 25), ('Bob', 35)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
// ========== UPDATE 测试 ==========
func TestDmUpdateBasic(t *testing.T) {
sql := "UPDATE users SET name = 'John' WHERE id = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, updateStmt.GetText())
}
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
}
if updateStmt.Where == nil || updateStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1'")
}
}
func TestDmUpdateMultipleColumns(t *testing.T) {
sql := "UPDATE users SET name = 'John', age = 30, email = 'john@example.com' WHERE id = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.Where == nil || updateStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1'")
}
}
func TestDmUpdateReturning(t *testing.T) {
sql := "UPDATE users SET status = 0 WHERE id = 1 RETURNING id, name, status"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.Where == nil || updateStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1'")
}
}
// ========== DELETE 测试 ==========
func TestDmDeleteBasic(t *testing.T) {
sql := "DELETE FROM users WHERE id = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, deleteStmt.GetText())
}
if len(deleteStmt.Tables) != 1 || deleteStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users'")
}
if deleteStmt.Where == nil || deleteStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1'")
}
}
func TestDmDeleteReturning(t *testing.T) {
sql := "DELETE FROM users WHERE status = 0 RETURNING id, name"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil || deleteStmt.Where.Text != "status = 0" {
t.Errorf("expected WHERE='status = 0'")
}
}
// ========== 复杂 DML 测试 ==========
func TestDmComplexUpdateWithSubquery(t *testing.T) {
sql := "UPDATE users SET total_orders = (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
}
}
func TestDmComplexDeleteWithSubquery(t *testing.T) {
sql := "DELETE FROM users WHERE id NOT IN (SELECT DISTINCT user_id FROM orders)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil {
t.Fatal("expected WHERE")
}
}
// ========== 双引号标识符测试 ==========
func TestDmDoubleQuoteDML(t *testing.T) {
sql := `UPDATE "users" SET "name" = 'John' WHERE "id" = 1`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
// 注意DM 解析器对于双引号标识符可能保留引号
if len(updateStmt.Tables) != 1 {
t.Fatal("expected 1 table")
}
t.Logf("UPDATE table: %s", updateStmt.Tables[0].Name)
t.Logf("UPDATE WHERE: %+v", updateStmt.Where)
}
// ========== 复杂 INSERT 测试 ==========
func TestDmInsertWithDoubleQuotes(t *testing.T) {
sql := `INSERT INTO "users" ("name", "age", "email") VALUES ('John', 30, 'john@example.com')`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
func TestDmInsertWithSpecialChars(t *testing.T) {
sql := `INSERT INTO "logs" ("message", "level") VALUES ('Error: connection failed!', 'ERROR')`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
t.Logf("INSERT text: %s", insertStmt.GetText())
}
func TestDmInsertMultipleRowsComplex(t *testing.T) {
// 多行插入包含特殊字符
sql := `INSERT INTO "users" ("name", "email") VALUES ('John', 'john@example.com'), ('Jane', 'jane@test.com'), ('Bob', 'bob@demo.com')`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
t.Logf("Multiple INSERT: %s", insertStmt.GetText())
}
func TestDmInsertReturningComplex(t *testing.T) {
sql := `INSERT INTO "users" ("name", "email") VALUES ('John', 'john@example.com') RETURNING "id", "name", "created_at"`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
t.Logf("INSERT RETURNING: %s", insertStmt.GetText())
}
// ========== 复杂 UPDATE 测试 ==========
func TestDmUpdateWithDoubleQuotes(t *testing.T) {
sql := `UPDATE "users" SET "name" = 'John', "age" = 30 WHERE "id" = 1`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if len(updateStmt.Tables) != 1 {
t.Fatal("expected 1 table")
}
t.Logf("UPDATE table: %s", updateStmt.Tables[0].Name)
}
func TestDmUpdateWithComplexWhere(t *testing.T) {
sql := `UPDATE "orders" SET "status" = 'cancelled' WHERE "status" = 'pending' AND "created_at" < '2024-01-01' AND ("amount" < 100 OR "user_id" IS NULL)`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.Where == nil {
t.Fatal("expected WHERE")
}
t.Logf("Complex WHERE: %s", updateStmt.Where.Text)
}
func TestDmUpdateWithFunctions(t *testing.T) {
sql := `UPDATE "users" SET "updated_at" = NOW(), "login_count" = "login_count" + 1 WHERE "id" = 1`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
t.Logf("UPDATE with functions: %s", updateStmt.GetText())
}
func TestDmUpdateReturningComplex(t *testing.T) {
sql := `UPDATE "users" SET "status" = 0 WHERE "id" = 1 RETURNING "id", "name", "status"`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.Where == nil || updateStmt.Where.Text != `"id" = 1` {
t.Errorf("expected WHERE")
}
}
func TestDmUpdateWithSubquery(t *testing.T) {
sql := `UPDATE "users" SET "total" = (SELECT SUM("amount") FROM "orders" WHERE "user_id" = "users"."id") WHERE "status" = 'active'`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
}
}
// ========== 复杂 DELETE 测试 ==========
func TestDmDeleteWithDoubleQuotes(t *testing.T) {
sql := `DELETE FROM "users" WHERE "id" = 1`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if len(deleteStmt.Tables) != 1 || deleteStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users'")
}
}
func TestDmDeleteWithComplexWhere(t *testing.T) {
sql := `DELETE FROM "logs" WHERE "created_at" < '2024-01-01' AND ("level" = 'DEBUG' OR "level" = 'INFO')`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil {
t.Fatal("expected WHERE")
}
t.Logf("Complex DELETE WHERE: %s", deleteStmt.Where.Text)
}
func TestDmDeleteReturningComplex(t *testing.T) {
sql := `DELETE FROM "users" WHERE "status" = 0 RETURNING "id", "name"`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil || deleteStmt.Where.Text != `"status" = 0` {
t.Errorf("expected WHERE")
}
}
func TestDmDeleteWithSubquery(t *testing.T) {
sql := `DELETE FROM "users" WHERE "id" NOT IN (SELECT DISTINCT "user_id" FROM "orders")`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil {
t.Fatal("expected WHERE")
}
}
// ========== 复杂 DDL 测试 ==========
func TestDmDDLCreateTableWithQuotes(t *testing.T) {
sql := `CREATE TABLE "users" ("id" INT PRIMARY KEY, "name" VARCHAR(50) NOT NULL, "email" VARCHAR(255) UNIQUE)`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "CREATE" {
t.Errorf("expected DdlKind='CREATE'")
}
}
func TestDmDDLCreateTableWithComments(t *testing.T) {
sql := `CREATE TABLE "orders" ("id" INT COMMENT '订单ID', "amount" DECIMAL(10,2) COMMENT '金额') COMMENT='订单表'`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
t.Logf("CREATE TABLE with comments: %s", ddlStmt.GetText())
}
func TestDmDDLAlterTableAddColumn(t *testing.T) {
sql := `ALTER TABLE "users" ADD COLUMN "email" VARCHAR(255)`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "ALTER" {
t.Errorf("expected DdlKind='ALTER'")
}
}
func TestDmDDLDropIfExists(t *testing.T) {
sql := `DROP TABLE IF EXISTS "users"`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "DROP" {
t.Errorf("expected DdlKind='DROP'")
}
}
// ========== 注释风格测试 ==========
func TestDmDMLWithSingleLineComment(t *testing.T) {
// DM 单行注释 --
sql := "-- 更新用户信息\nUPDATE \"users\" SET \"name\" = 'John' WHERE \"id\" = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if len(updateStmt.Tables) != 1 {
t.Fatal("expected 1 table")
}
t.Logf("UPDATE with -- comment: %s", updateStmt.GetText())
}
func TestDmDMLWithMultiLineComment(t *testing.T) {
// 多行注释 /* */
sql := "/* 删除日志 */ DELETE FROM \"logs\" WHERE \"created_at\" < '2024-01-01'"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil {
t.Fatal("expected WHERE")
}
t.Logf("DELETE with /* */ comment: %s", deleteStmt.GetText())
}
func TestDmDMLWithInlineComment(t *testing.T) {
// 行内注释
sql := `SELECT "id", "name" /* 用户名 */ FROM "users" WHERE "status" = 1`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Items) != 2 {
t.Fatalf("expected 2 items")
}
t.Logf("SELECT with inline comment: %s", selectStmt.GetText())
}
func TestDmDMLWithMultipleComments(t *testing.T) {
// 多个注释
sql := `-- 查询订单
/* 只查询已支付的 */
SELECT "id", "amount" FROM "orders"
WHERE "status" = 'paid' AND "created_at" > '2024-01-01'
ORDER BY "created_at" DESC`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil {
t.Fatal("expected WHERE")
}
t.Logf("Multiple comments: %s", selectStmt.GetText())
}
func TestDmInsertWithComment(t *testing.T) {
sql := "-- 插入数据\nINSERT INTO \"users\" (\"name\", \"email\") VALUES ('John', 'john@example.com')"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
t.Logf("INSERT with comment: %s", insertStmt.GetText())
}
// ========== Schema.Table 测试 ==========
func TestDmInsertWithSchema(t *testing.T) {
sql := `INSERT INTO "TEST"."users" ("name", "age") VALUES ('John', 30)`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
if insertStmt.Table.Schema != "TEST" {
t.Errorf("expected schema='TEST', got '%s'", insertStmt.Table.Schema)
}
if insertStmt.Table.Name != "users" {
t.Errorf("expected table='users', got '%s'", insertStmt.Table.Name)
}
}
func TestDmUpdateWithSchema(t *testing.T) {
sql := `UPDATE "TEST"."t_db" SET "name" = 'test' WHERE "id" = 5`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
t.Logf("UPDATE text: %s", updateStmt.GetText())
t.Logf("UPDATE tables: %+v", updateStmt.Tables)
t.Logf("UPDATE WHERE: %+v", updateStmt.Where)
if len(updateStmt.Tables) != 1 {
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
}
if updateStmt.Tables[0].Schema != "TEST" {
t.Errorf("expected schema='TEST', got '%s'", updateStmt.Tables[0].Schema)
}
if updateStmt.Tables[0].Name != "t_db" {
t.Errorf("expected table='t_db', got '%s'", updateStmt.Tables[0].Name)
}
if updateStmt.Where == nil {
t.Errorf("expected WHERE")
}
}
func TestDmDeleteWithSchema(t *testing.T) {
sql := `DELETE FROM "TEST"."logs" WHERE "created_at" < '2024-01-01'`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if len(deleteStmt.Tables) != 1 {
t.Fatalf("expected 1 table, got %d", len(deleteStmt.Tables))
}
if deleteStmt.Tables[0].Schema != "TEST" {
t.Errorf("expected schema='TEST', got '%s'", deleteStmt.Tables[0].Schema)
}
if deleteStmt.Tables[0].Name != "logs" {
t.Errorf("expected table='logs', got '%s'", deleteStmt.Tables[0].Name)
}
if deleteStmt.Where == nil {
t.Errorf("expected WHERE")
}
}

View File

@@ -0,0 +1,472 @@
package dm
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== 简单分页测试 ==========
func TestDmPaginationSimple(t *testing.T) {
sql := "SELECT * FROM users LIMIT 0, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
func TestDmPaginationWithOffset(t *testing.T) {
sql := "SELECT id, name FROM users LIMIT 20, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 20 {
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
func TestDmPaginationKeywordOffset(t *testing.T) {
// DM 也支持 LIMIT count OFFSET offset
sql := "SELECT * FROM products LIMIT 15 OFFSET 30"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
if selectStmt.Limit.Offset != 30 {
t.Errorf("expected offset=30, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 15 {
t.Errorf("expected count=15, got %d", selectStmt.Limit.Count)
}
}
func TestDmPaginationOnlyLimit(t *testing.T) {
sql := "SELECT * FROM orders LIMIT 50"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 50 {
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
}
}
// ========== SELECT TOP 分页测试 ==========
func TestDmPaginationTopN(t *testing.T) {
// DM 支持 SELECT TOP n
sql := "SELECT TOP 10 id, name FROM users ORDER BY id"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Items) != 2 {
t.Fatalf("expected 2 items")
}
t.Logf("SELECT TOP parsed successfully")
}
func TestDmPaginationTopWithWhere(t *testing.T) {
sql := "SELECT TOP 20 id, name FROM users WHERE status = 1 ORDER BY created_at DESC"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
t.Logf("SELECT TOP with WHERE parsed successfully")
}
// ========== 复杂分页测试 ==========
func TestDmPaginationWithWhere(t *testing.T) {
sql := "SELECT id, name, email FROM users WHERE status = 1 AND age > 18 LIMIT 10, 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 20 {
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
}
}
func TestDmPaginationWithOrderBy(t *testing.T) {
sql := "SELECT * FROM users ORDER BY created_at DESC, id ASC LIMIT 0, 100"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.OrderBy) != 2 {
t.Fatalf("expected 2 ORDER BY clauses")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 100 {
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
}
}
func TestDmPaginationWithWhereOrderBy(t *testing.T) {
sql := "SELECT id, name FROM users WHERE status = 'active' ORDER BY score DESC LIMIT 50, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 'active'" {
t.Errorf("expected WHERE='status = 'active''")
}
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "score" {
t.Errorf("expected ORDER BY score")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 50 {
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
func TestDmPaginationWithJoin(t *testing.T) {
sql := "SELECT u.id, u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1 ORDER BY o.created_at DESC LIMIT 100, 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatal("expected 1 JOIN")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 100 {
t.Errorf("expected offset=100, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 20 {
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
}
}
func TestDmPaginationWithGroupBy(t *testing.T) {
sql := "SELECT user_id, COUNT(*) as order_count FROM orders GROUP BY user_id ORDER BY order_count DESC LIMIT 0, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
// ========== UNION 分页测试 ==========
func TestDmPaginationWithUnion(t *testing.T) {
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC LIMIT 0, 50"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 1 {
t.Fatal("expected 1 UNION")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause for UNION")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 50 {
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
}
}
func TestDmPaginationUnionAll(t *testing.T) {
sql := "SELECT name FROM products UNION ALL SELECT name FROM services LIMIT 10, 30"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 1 {
t.Fatal("expected 1 UNION ALL")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 30 {
t.Errorf("expected count=30, got %d", selectStmt.Limit.Count)
}
}
func TestDmMultipleUnionsPagination(t *testing.T) {
sql := "SELECT id FROM t1 UNION SELECT id FROM t2 UNION SELECT id FROM t3 ORDER BY 1 LIMIT 50, 100"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 2 {
t.Fatalf("expected 2 UNIONs, got %d", len(selectStmt.Unions))
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 50 {
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 100 {
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
}
}
// ========== 子查询分页测试 ==========
func TestDmPaginationWithSubquery(t *testing.T) {
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1 ORDER BY id LIMIT 0, 100) AS tmp LIMIT 20, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.From) != 1 {
t.Fatal("expected 1 FROM table")
}
if selectStmt.Limit == nil {
t.Fatal("expected outer LIMIT clause")
}
if selectStmt.Limit.Offset != 20 {
t.Errorf("expected outer offset=20, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected outer count=10, got %d", selectStmt.Limit.Count)
}
}
func TestDmNestedPagination(t *testing.T) {
// 嵌套分页查询
sql := "SELECT * FROM (SELECT * FROM users LIMIT 0, 100) AS tmp LIMIT 10, 5"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected outer LIMIT clause")
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected outer offset=10, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 5 {
t.Errorf("expected outer count=5, got %d", selectStmt.Limit.Count)
}
}
func TestDmPaginationWithExists(t *testing.T) {
sql := "SELECT * FROM users u WHERE EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id LIMIT 1) LIMIT 20, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 20 {
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
// ========== 大数据量分页测试 ==========
func TestDmLargeOffsetPagination(t *testing.T) {
sql := "SELECT * FROM logs ORDER BY id LIMIT 1000000, 100"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 1000000 {
t.Errorf("expected offset=1000000, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 100 {
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
}
}
// ========== FOR UPDATE 分页测试 ==========
func TestDmPaginationForUpdate(t *testing.T) {
sql := "SELECT * FROM users WHERE status = 0 ORDER BY id LIMIT 10 OFFSET 0 FOR UPDATE"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
// ========== 双引号标识符分页测试 ==========
func TestDmPaginationDoubleQuote(t *testing.T) {
sql := `SELECT "id", "name" FROM "users" WHERE "status" = 1 ORDER BY "id" DESC LIMIT 20, 10`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 20 {
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}

View File

@@ -0,0 +1,276 @@
package dm
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== SELECT 测试 ==========
func TestDmSelectBasic(t *testing.T) {
sql := "SELECT id, name FROM users WHERE status = 1 ORDER BY id DESC LIMIT 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 1" {
t.Errorf("expected WHERE='status = 1'")
}
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "id" || !selectStmt.OrderBy[0].Desc {
t.Errorf("expected ORDER BY id DESC")
}
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT 10" || selectStmt.Limit.Count != 10 {
t.Errorf("expected LIMIT 10")
}
}
func TestDmSelectTop(t *testing.T) {
sql := "SELECT TOP 10 id, name FROM users ORDER BY id"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Items) != 2 {
t.Fatalf("expected 2 items")
}
}
func TestDmSelectDistinct(t *testing.T) {
sql := "SELECT DISTINCT department_id FROM users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if !selectStmt.Distinct {
t.Error("expected Distinct=true")
}
}
func TestDmSelectStar(t *testing.T) {
sql := "SELECT * FROM users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Items) != 1 || selectStmt.Items[0].Text != "*" {
t.Errorf("expected SELECT *")
}
}
func TestDmSelectWithJoin(t *testing.T) {
sql := "SELECT u.id, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 || selectStmt.Joins[0].Table.Name != "orders" {
t.Errorf("expected JOIN orders")
}
if selectStmt.Where == nil || selectStmt.Where.Text != "u.status = 1" {
t.Errorf("expected WHERE='u.status = 1'")
}
}
func TestDmSelectMultipleJoins(t *testing.T) {
sql := "SELECT u.id, o.amount, p.name FROM users u LEFT JOIN orders o ON u.id = o.user_id INNER JOIN products p ON o.product_id = p.id WHERE u.status = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 2 {
t.Fatalf("expected 2 joins")
}
}
func TestDmSelectComplexWhere(t *testing.T) {
sql := "SELECT * FROM users WHERE status = 1 AND (age > 18 OR role = 'admin') ORDER BY name"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil {
t.Fatal("expected WHERE")
}
expectedWhere := "status = 1 AND (age > 18 OR role = 'admin')"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE='%s'", expectedWhere)
}
}
// ========== LIMIT/OFFSET 测试 ==========
func TestDmLimitOffset(t *testing.T) {
sql := "SELECT id FROM users ORDER BY id LIMIT 20, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT 20, 10" {
t.Errorf("expected LIMIT text='LIMIT 20, 10'")
}
}
func TestDmLimitOnly(t *testing.T) {
sql := "SELECT * FROM users LIMIT 5"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT 5" || selectStmt.Limit.Count != 5 {
t.Errorf("expected LIMIT 5")
}
}
// ========== UNION 测试 ==========
func TestDmUnion(t *testing.T) {
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 1 {
t.Fatalf("expected 1 union")
}
if len(selectStmt.OrderBy) != 1 {
t.Fatalf("expected ORDER BY")
}
}
func TestDmUnionWithOrderBy(t *testing.T) {
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 1 {
t.Fatalf("expected 1 union")
}
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "1" || !selectStmt.OrderBy[0].Desc {
t.Errorf("expected ORDER BY 1 DESC")
}
}
func TestDmUnionWithLimit(t *testing.T) {
sql := "SELECT id FROM users UNION ALL SELECT id FROM admins LIMIT 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT 20" || selectStmt.Limit.Count != 20 {
t.Errorf("expected LIMIT 20")
}
}
// ========== FOR UPDATE 测试 ==========
func TestDmForUpdate(t *testing.T) {
sql := "SELECT id, name FROM users WHERE id = 1 FOR UPDATE"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil || selectStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1'")
}
}
// ========== WITH (CTE) 测试 ==========
func TestDmWithClause(t *testing.T) {
sql := "WITH temp_users AS (SELECT * FROM users WHERE status = 1) SELECT * FROM temp_users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
}
func TestDmMultipleWith(t *testing.T) {
sql := `WITH
active_users AS (SELECT * FROM users WHERE status = 1),
recent_orders AS (SELECT * FROM orders WHERE created_at > NOW())
SELECT * FROM active_users u JOIN recent_orders o ON u.id = o.user_id`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
}
// ========== DDL 测试 ==========
func TestDmDDL(t *testing.T) {
sqls := []string{
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(50))",
"DROP TABLE users",
"ALTER TABLE users ADD COLUMN email VARCHAR(255)",
}
for _, sql := range sqls {
t.Run(sql[:10], func(t *testing.T) {
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
})
}
}

View File

@@ -0,0 +1,297 @@
package dm
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== 子查询测试 ==========
func TestDmSubqueryInFrom(t *testing.T) {
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1) AS u WHERE u.id > 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("FROM 表数量: %d", len(selectStmt.From))
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 table in FROM, got %d", len(selectStmt.From))
}
t.Logf("FROM[0] Name: %s", selectStmt.From[0].Name)
t.Logf("FROM[0] Alias: %s", selectStmt.From[0].Alias)
if selectStmt.From[0].Alias != "u" {
t.Errorf("expected alias='u', got '%s'", selectStmt.From[0].Alias)
}
if selectStmt.Where == nil {
t.Fatal("expected outer WHERE clause")
}
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where.Text != "u.id > 10" {
t.Errorf("expected outer WHERE text='u.id > 10', got '%s'", selectStmt.Where.Text)
}
}
func TestDmSubqueryInWhere(t *testing.T) {
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders WHERE amount > 100)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "id IN (SELECT user_id FROM orders WHERE amount > 100)"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestDmSubqueryInSelect(t *testing.T) {
sql := "SELECT id, name, (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) AS order_count FROM users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
for i, item := range selectStmt.Items {
t.Logf(" Item[%d]: Text='%s', ColumnName='%s', Alias='%s'", i, item.Text, item.ColumnName, item.Alias)
}
if len(selectStmt.Items) != 3 {
t.Fatalf("expected 3 items, got %d", len(selectStmt.Items))
}
if selectStmt.Items[2].Alias != "order_count" {
t.Errorf("expected item[2] alias='order_count', got '%s'", selectStmt.Items[2].Alias)
}
}
func TestDmNestedSubquery(t *testing.T) {
sql := "SELECT * FROM (SELECT * FROM (SELECT id FROM users WHERE status = 1) AS inner_u) AS outer_u WHERE outer_u.id > 5"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected outer WHERE")
}
if selectStmt.Where.Text != "outer_u.id > 5" {
t.Errorf("expected WHERE text='outer_u.id > 5', got '%s'", selectStmt.Where.Text)
}
}
func TestDmCorrelatedSubquery(t *testing.T) {
sql := "SELECT u.id, u.name FROM users u WHERE u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestDmSubqueryWithExists(t *testing.T) {
sql := "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestDmMultipleSubqueries(t *testing.T) {
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestDmSubqueryWithLimit(t *testing.T) {
sql := "SELECT * FROM (SELECT id, name FROM users ORDER BY created_at DESC LIMIT 10) AS top_users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("FROM 数量: %d", len(selectStmt.From))
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 table, got %d", len(selectStmt.From))
}
if selectStmt.From[0].Alias != "top_users" {
t.Errorf("expected alias='top_users', got '%s'", selectStmt.From[0].Alias)
}
}
func TestDmComplexSubquery(t *testing.T) {
sql := `SELECT
u.id,
u.name,
(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS order_count,
(SELECT SUM(amount) FROM orders o WHERE o.user_id = u.id AND o.status = 'completed') AS total_amount
FROM users u
WHERE u.status = 1
AND u.id IN (SELECT user_id FROM user_groups WHERE group_id = 5)
ORDER BY u.name`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
if len(selectStmt.Items) != 4 {
t.Fatalf("expected 4 items, got %d", len(selectStmt.Items))
}
t.Logf("Item[0]: %s", selectStmt.Items[0].Text)
t.Logf("Item[1]: %s", selectStmt.Items[1].Text)
t.Logf("Item[2]: %s (alias: %s)", selectStmt.Items[2].Text, selectStmt.Items[2].Alias)
t.Logf("Item[3]: %s (alias: %s)", selectStmt.Items[3].Text, selectStmt.Items[3].Alias)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
t.Logf("WHERE: %s", selectStmt.Where.Text)
if len(selectStmt.OrderBy) != 1 {
t.Fatalf("expected 1 order by, got %d", len(selectStmt.OrderBy))
}
if selectStmt.OrderBy[0].Text != "u.name" {
t.Errorf("expected order by text='u.name', got '%s'", selectStmt.OrderBy[0].Text)
}
}
// ========== JOIN 测试 ==========
func TestDmMultipleJoins(t *testing.T) {
sql := "SELECT u.id, o.amount, p.name FROM users u LEFT JOIN orders o ON u.id = o.user_id INNER JOIN products p ON o.product_id = p.id WHERE u.status = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("JOIN 数量: %d", len(selectStmt.Joins))
if len(selectStmt.Joins) != 2 {
t.Fatalf("expected 2 joins, got %d", len(selectStmt.Joins))
}
if selectStmt.Joins[0].Table.Name != "orders" {
t.Errorf("expected first join table='orders', got '%s'", selectStmt.Joins[0].Table.Name)
}
if selectStmt.Joins[1].Table.Name != "products" {
t.Errorf("expected second join table='products', got '%s'", selectStmt.Joins[1].Table.Name)
}
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
if selectStmt.Where.Text != "u.status = 1" {
t.Errorf("expected WHERE text='u.status = 1', got '%s'", selectStmt.Where.Text)
}
}
func TestDmRightJoin(t *testing.T) {
sql := "SELECT * FROM users u RIGHT JOIN orders o ON u.id = o.user_id"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
}
}
func TestDmCrossJoin(t *testing.T) {
sql := "SELECT * FROM users CROSS JOIN roles"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
}
}