mirror of
https://gitee.com/dromara/mayfly-go
synced 2026-05-21 18:35:19 +08:00
refactor: 移除antlr4减小包体积&ai助手优化
This commit is contained in:
18
server/internal/db/dbm/sqlparser/dm/dm.go
Normal file
18
server/internal/db/dbm/sqlparser/dm/dm.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
"mayfly-go/pkg/logx"
|
||||
)
|
||||
|
||||
type DmParser struct {
|
||||
}
|
||||
|
||||
func (*DmParser) Parse(stmt string) (sqlstmt.Stmt, error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
logx.ErrorTrace("dm sql parser err: ", e)
|
||||
}
|
||||
}()
|
||||
return NewParser(stmt).Parse()
|
||||
}
|
||||
691
server/internal/db/dbm/sqlparser/dm/parser.go
Normal file
691
server/internal/db/dbm/sqlparser/dm/parser.go
Normal file
@@ -0,0 +1,691 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/base"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/tokenizer"
|
||||
)
|
||||
|
||||
// Parser 达梦方言 SQL 解析器
|
||||
type Parser struct {
|
||||
*base.Lexer
|
||||
}
|
||||
|
||||
// NewParser 创建达梦解析器
|
||||
func NewParser(sql string) *Parser {
|
||||
return &Parser{
|
||||
Lexer: base.NewLexer(sql, tokenizer.DialectConfig{
|
||||
DoubleQuoteAsIdentifier: true,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// Parse 解析单条 SQL 语句
|
||||
func (p *Parser) Parse() (sqlstmt.Stmt, error) {
|
||||
p.SkipSemicolons()
|
||||
if p.Current().IsEOF() {
|
||||
return nil, nil
|
||||
}
|
||||
stmt := p.parseStatement()
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseStatement() sqlstmt.Stmt {
|
||||
tok := p.Current()
|
||||
switch {
|
||||
case tok.IsKeyword("SELECT") || tok.Value == "(":
|
||||
return p.parseSelect()
|
||||
case tok.IsKeyword("INSERT"):
|
||||
return p.parseInsert()
|
||||
case tok.IsKeyword("UPDATE"):
|
||||
return p.parseUpdate()
|
||||
case tok.IsKeyword("DELETE"):
|
||||
return p.parseDelete()
|
||||
case tok.IsKeyword("CREATE"):
|
||||
return p.parseCreate()
|
||||
case tok.IsKeyword("DROP"):
|
||||
return p.parseDrop()
|
||||
case tok.IsKeyword("ALTER"):
|
||||
return p.parseAlter()
|
||||
case tok.IsKeyword("WITH"):
|
||||
return p.parseWith()
|
||||
case tok.IsKeyword("TRUNCATE"):
|
||||
return p.parseGenericDdl()
|
||||
default:
|
||||
return p.parseGenericStmt()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- SELECT 解析 ----------
|
||||
|
||||
func (p *Parser) parseSelect() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
var selectStmt *sqlstmt.SelectStmt
|
||||
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("SELECT") {
|
||||
selectStmt = p.parseSelectBody()
|
||||
} else {
|
||||
p.SkipParentheses()
|
||||
selectStmt = &sqlstmt.SelectStmt{}
|
||||
}
|
||||
p.ExpectValue(")")
|
||||
} else {
|
||||
selectStmt = p.parseSelectBody()
|
||||
}
|
||||
|
||||
// UNION 解析
|
||||
selectStmt = p.parseUnions(selectStmt)
|
||||
|
||||
// ORDER BY
|
||||
if len(selectStmt.OrderBy) == 0 && p.Current().IsKeyword("ORDER") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
selectStmt.OrderBy = p.parseOrderBy()
|
||||
}
|
||||
|
||||
// LIMIT(达梦支持 LIMIT offset, count 或 LIMIT count OFFSET offset)
|
||||
if p.Current().IsKeyword("LIMIT") {
|
||||
limit := &sqlstmt.Limit{}
|
||||
limitStart := p.Pos
|
||||
p.Consume()
|
||||
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
firstNum := p.Consume().Value
|
||||
firstVal, _ := strconv.Atoi(firstNum)
|
||||
|
||||
// 检查是 LIMIT offset, count 还是 LIMIT count OFFSET offset
|
||||
if p.Current().Value == "," {
|
||||
// LIMIT offset, count 格式
|
||||
limit.Offset = firstVal
|
||||
p.Consume()
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
countStr := p.Consume().Value
|
||||
limit.Count, _ = strconv.Atoi(countStr)
|
||||
}
|
||||
} else if p.Current().IsKeyword("OFFSET") {
|
||||
// LIMIT count OFFSET offset 格式
|
||||
limit.Count = firstVal
|
||||
p.Consume()
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
offsetStr := p.Consume().Value
|
||||
limit.Offset, _ = strconv.Atoi(offsetStr)
|
||||
}
|
||||
} else {
|
||||
// 只有 LIMIT count
|
||||
limit.Count = firstVal
|
||||
limit.Offset = 0
|
||||
}
|
||||
}
|
||||
limit.Text = p.TextFrom(limitStart)
|
||||
selectStmt.Limit = limit
|
||||
}
|
||||
|
||||
// FOR UPDATE
|
||||
if p.Current().IsKeyword("FOR") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("UPDATE") {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
// 更新完整文本
|
||||
selectStmt.Base = sqlstmt.Base{Text: p.TextFrom(start)}
|
||||
return selectStmt
|
||||
}
|
||||
|
||||
func (p *Parser) parseSelectBody() *sqlstmt.SelectStmt {
|
||||
start := p.Pos
|
||||
|
||||
distinct := false
|
||||
// 支持 SELECT TOP n(达梦特性)
|
||||
topCount := ""
|
||||
if p.Current().IsKeyword("SELECT") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("TOP") {
|
||||
p.Consume()
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
topCount = p.Consume().Value
|
||||
}
|
||||
} else if p.Current().IsKeyword("DISTINCT") {
|
||||
p.Consume()
|
||||
distinct = true
|
||||
} else if p.Current().IsKeyword("ALL") {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
// SELECT 项
|
||||
var items []sqlstmt.SelectItem
|
||||
for !p.Current().IsEOF() && !p.IsSelectClauseEnd() {
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
itemStart := p.Pos
|
||||
p.SkipExpr()
|
||||
text := base.TrimTrailingComma(p.TextFromExclusive(itemStart))
|
||||
col, alias := p.ExtractColumnAndAlias(text)
|
||||
items = append(items, sqlstmt.SelectItem{
|
||||
Text: text,
|
||||
ColumnName: col,
|
||||
Alias: alias,
|
||||
})
|
||||
}
|
||||
|
||||
selectStmt := &sqlstmt.SelectStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Distinct: distinct,
|
||||
Items: items,
|
||||
}
|
||||
|
||||
// FROM
|
||||
if p.Current().IsKeyword("FROM") {
|
||||
p.Consume()
|
||||
selectStmt.From = p.parseFromClause()
|
||||
}
|
||||
|
||||
// JOIN
|
||||
for p.IsJoinStart() || p.Current().IsKeyword("JOIN") {
|
||||
if join := p.parseJoinClause(); join != nil {
|
||||
selectStmt.Joins = append(selectStmt.Joins, *join)
|
||||
}
|
||||
}
|
||||
|
||||
// WHERE
|
||||
if p.Current().IsKeyword("WHERE") {
|
||||
p.Consume()
|
||||
whereStart := p.Pos
|
||||
p.SkipExpr()
|
||||
selectStmt.Where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
|
||||
}
|
||||
|
||||
// GROUP BY
|
||||
if p.Current().IsKeyword("GROUP") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
p.SkipGroupByExpr()
|
||||
}
|
||||
|
||||
// HAVING
|
||||
if p.Current().IsKeyword("HAVING") {
|
||||
p.Consume()
|
||||
p.SkipExpr()
|
||||
}
|
||||
|
||||
// ORDER BY
|
||||
if p.Current().IsKeyword("ORDER") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
selectStmt.OrderBy = p.parseOrderBy()
|
||||
}
|
||||
|
||||
// LIMIT 在 SELECT 子句中(如果有)
|
||||
if p.Current().IsKeyword("LIMIT") {
|
||||
limit := &sqlstmt.Limit{}
|
||||
limitStart := p.Pos
|
||||
p.Consume()
|
||||
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
firstNum := p.Consume().Value
|
||||
firstVal, _ := strconv.Atoi(firstNum)
|
||||
|
||||
if p.Current().Value == "," {
|
||||
limit.Offset = firstVal
|
||||
p.Consume()
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
countStr := p.Consume().Value
|
||||
limit.Count, _ = strconv.Atoi(countStr)
|
||||
}
|
||||
} else if p.Current().IsKeyword("OFFSET") {
|
||||
limit.Count = firstVal
|
||||
p.Consume()
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
offsetStr := p.Consume().Value
|
||||
limit.Offset, _ = strconv.Atoi(offsetStr)
|
||||
}
|
||||
} else {
|
||||
limit.Count = firstVal
|
||||
limit.Offset = 0
|
||||
}
|
||||
}
|
||||
limit.Text = p.TextFrom(limitStart)
|
||||
selectStmt.Limit = limit
|
||||
}
|
||||
|
||||
_ = topCount // 暂时未使用
|
||||
return selectStmt
|
||||
}
|
||||
|
||||
func (p *Parser) parseUnions(selectStmt *sqlstmt.SelectStmt) *sqlstmt.SelectStmt {
|
||||
for p.Current().IsKeyword("UNION") {
|
||||
p.Consume()
|
||||
all := false
|
||||
if p.Current().IsKeyword("ALL") {
|
||||
p.Consume()
|
||||
all = true
|
||||
} else if p.Current().IsKeyword("DISTINCT") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
var nextSelect *sqlstmt.SelectStmt
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("SELECT") {
|
||||
nextSelect = p.parseSelectBody()
|
||||
}
|
||||
p.ExpectValue(")")
|
||||
} else if p.Current().IsKeyword("SELECT") {
|
||||
nextSelect = p.parseSelectBody()
|
||||
}
|
||||
|
||||
if nextSelect != nil {
|
||||
// 提取最后一个 unionSelect 的 ORDER BY 和 LIMIT 到外层
|
||||
if len(nextSelect.OrderBy) > 0 {
|
||||
selectStmt.OrderBy = nextSelect.OrderBy
|
||||
nextSelect.OrderBy = nil
|
||||
}
|
||||
if nextSelect.Limit != nil {
|
||||
selectStmt.Limit = nextSelect.Limit
|
||||
nextSelect.Limit = nil
|
||||
}
|
||||
selectStmt.Unions = append(selectStmt.Unions, sqlstmt.UnionClause{
|
||||
Select: nextSelect,
|
||||
All: all,
|
||||
})
|
||||
}
|
||||
}
|
||||
return selectStmt
|
||||
}
|
||||
|
||||
func (p *Parser) parseOrderBy() []sqlstmt.OrderByItem {
|
||||
var items []sqlstmt.OrderByItem
|
||||
for !p.Current().IsEOF() && !p.IsExprEnd() {
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
start := p.Pos
|
||||
p.SkipExpr()
|
||||
text := p.TextFromExclusive(start)
|
||||
|
||||
desc := false
|
||||
upper := strings.ToUpper(text)
|
||||
if strings.HasSuffix(upper, " DESC") {
|
||||
desc = true
|
||||
text = strings.TrimSpace(text[:len(text)-5])
|
||||
} else if strings.HasSuffix(upper, " ASC") {
|
||||
text = strings.TrimSpace(text[:len(text)-4])
|
||||
}
|
||||
|
||||
items = append(items, sqlstmt.OrderByItem{
|
||||
Text: text,
|
||||
Desc: desc,
|
||||
})
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
// ---------- FROM 解析 ----------
|
||||
|
||||
func (p *Parser) parseFromClause() []sqlstmt.TableRef {
|
||||
var tables []sqlstmt.TableRef
|
||||
for !p.Current().IsEOF() && !p.IsFromClauseEnd() {
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.IsJoinStart() || p.Current().IsKeyword("JOIN") {
|
||||
break
|
||||
}
|
||||
ref := p.parseTableRef()
|
||||
if ref.Name != "" {
|
||||
tables = append(tables, ref)
|
||||
}
|
||||
}
|
||||
return tables
|
||||
}
|
||||
|
||||
func (p *Parser) parseTableRef() sqlstmt.TableRef {
|
||||
start := p.Pos
|
||||
|
||||
// 子查询
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("SELECT") {
|
||||
p.parseSelectBody()
|
||||
} else {
|
||||
p.SkipParentheses()
|
||||
}
|
||||
p.ExpectValue(")")
|
||||
|
||||
var alias string
|
||||
if p.Current().IsKeyword("AS") {
|
||||
p.Consume()
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
alias = p.Unquote(p.Consume().Value)
|
||||
}
|
||||
return sqlstmt.TableRef{
|
||||
Name: p.TextFrom(start),
|
||||
Alias: alias,
|
||||
}
|
||||
}
|
||||
|
||||
if p.Current().Type != tokenizer.TokenIdentifier && p.Current().Type != tokenizer.TokenString {
|
||||
return sqlstmt.TableRef{}
|
||||
}
|
||||
|
||||
ref := sqlstmt.TableRef{}
|
||||
part1 := p.Consume().Value
|
||||
|
||||
if p.Current().Value == "." {
|
||||
p.Consume()
|
||||
if p.Current().Type == tokenizer.TokenIdentifier || p.Current().Type == tokenizer.TokenString {
|
||||
part2 := p.Consume().Value
|
||||
ref.Schema = p.Unquote(part1)
|
||||
ref.Name = p.Unquote(part2)
|
||||
} else {
|
||||
ref.Name = p.Unquote(part1)
|
||||
}
|
||||
} else {
|
||||
ref.Name = p.Unquote(part1)
|
||||
}
|
||||
|
||||
if p.Current().IsKeyword("AS") {
|
||||
p.Consume()
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
ref.Alias = p.Unquote(p.Consume().Value)
|
||||
}
|
||||
|
||||
return ref
|
||||
}
|
||||
|
||||
func (p *Parser) parseJoinClause() *sqlstmt.JoinClause {
|
||||
start := p.Pos
|
||||
joinType := sqlstmt.JoinKindInner
|
||||
if p.Current().IsKeyword("LEFT") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("OUTER") {
|
||||
p.Consume()
|
||||
}
|
||||
joinType = sqlstmt.JoinKindLeft
|
||||
} else if p.Current().IsKeyword("RIGHT") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("OUTER") {
|
||||
p.Consume()
|
||||
}
|
||||
joinType = sqlstmt.JoinKindRight
|
||||
} else if p.Current().IsKeyword("FULL") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("OUTER") {
|
||||
p.Consume()
|
||||
}
|
||||
joinType = sqlstmt.JoinKindFull
|
||||
} else if p.Current().IsKeyword("CROSS") {
|
||||
p.Consume()
|
||||
joinType = sqlstmt.JoinKindCross
|
||||
} else if p.Current().IsKeyword("INNER") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
if !p.Current().IsKeyword("JOIN") {
|
||||
p.Pos = start
|
||||
return nil
|
||||
}
|
||||
p.Consume()
|
||||
|
||||
tableRef := p.parseTableRef()
|
||||
if tableRef.Name == "" {
|
||||
p.Pos = start
|
||||
return nil
|
||||
}
|
||||
|
||||
var onExpr *sqlstmt.Expr
|
||||
if p.Current().IsKeyword("ON") {
|
||||
p.Consume()
|
||||
onStart := p.Pos
|
||||
p.SkipExpr()
|
||||
onExpr = &sqlstmt.Expr{Text: p.TextFromExclusive(onStart)}
|
||||
} else if p.Current().IsKeyword("USING") {
|
||||
p.Consume()
|
||||
if p.Current().Value == "(" {
|
||||
p.SkipParentheses()
|
||||
}
|
||||
}
|
||||
|
||||
return &sqlstmt.JoinClause{
|
||||
Kind: joinType,
|
||||
Table: tableRef,
|
||||
On: onExpr,
|
||||
Text: p.TextFrom(start),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- INSERT 解析 ----------
|
||||
|
||||
func (p *Parser) parseInsert() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.Consume() // INSERT
|
||||
|
||||
if p.Current().IsKeyword("INTO") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
tableRef := p.parseTableRef()
|
||||
|
||||
// 解析列名列表
|
||||
columns := []string{}
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume()
|
||||
for !p.Current().IsEOF() && p.Current().Value != ")" {
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
columns = append(columns, p.Unquote(p.Consume().Value))
|
||||
} else {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
if p.Current().Value == ")" {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
if p.Current().IsKeyword("VALUES") {
|
||||
p.Consume()
|
||||
for p.Current().Value == "(" {
|
||||
p.SkipParentheses()
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
} else if p.Current().IsKeyword("SELECT") {
|
||||
p.parseSelect()
|
||||
}
|
||||
|
||||
return &sqlstmt.InsertStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Table: tableRef,
|
||||
Columns: columns,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- UPDATE 解析 ----------
|
||||
|
||||
func (p *Parser) parseUpdate() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.Consume() // UPDATE
|
||||
|
||||
tableRef := p.parseTableRef()
|
||||
tables := []sqlstmt.TableRef{tableRef}
|
||||
|
||||
// SET - 解析字段赋值
|
||||
assignments := make([]sqlstmt.Assignment, 0)
|
||||
if p.Current().IsKeyword("SET") {
|
||||
p.Consume()
|
||||
for !p.Current().IsEOF() {
|
||||
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
|
||||
break
|
||||
}
|
||||
assign := p.parseAssignment()
|
||||
if assign != nil {
|
||||
assignments = append(assignments, *assign)
|
||||
}
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var where *sqlstmt.Expr
|
||||
if p.Current().IsKeyword("WHERE") {
|
||||
p.Consume()
|
||||
whereStart := p.Pos
|
||||
p.SkipExpr()
|
||||
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
|
||||
}
|
||||
|
||||
return &sqlstmt.UpdateStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Tables: tables,
|
||||
Set: assignments,
|
||||
Where: where,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- DELETE 解析 ----------
|
||||
|
||||
func (p *Parser) parseDelete() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.Consume() // DELETE
|
||||
|
||||
if p.Current().IsKeyword("FROM") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
// 使用 parseTableRef 正确解析表名
|
||||
tableRef := p.parseTableRef()
|
||||
tables := []sqlstmt.TableRef{tableRef}
|
||||
|
||||
// WHERE
|
||||
var where *sqlstmt.Expr
|
||||
if p.Current().IsKeyword("WHERE") {
|
||||
p.Consume()
|
||||
whereStart := p.Pos
|
||||
p.SkipExpr()
|
||||
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
|
||||
}
|
||||
|
||||
return &sqlstmt.DeleteStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Tables: tables,
|
||||
Where: where,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- DDL 解析 ----------
|
||||
|
||||
func (p *Parser) parseCreate() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "CREATE",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseDrop() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "DROP",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseAlter() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "ALTER",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseGenericDdl() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "DDL",
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- WITH 解析 ----------
|
||||
|
||||
func (p *Parser) parseWith() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.WithStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- 通用语句解析 ----------
|
||||
|
||||
func (p *Parser) parseGenericStmt() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.OtherStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
}
|
||||
}
|
||||
|
||||
// parseAssignment 解析 SET 字段赋值:column = value
|
||||
func (p *Parser) parseAssignment() *sqlstmt.Assignment {
|
||||
start := p.Pos
|
||||
colText := ""
|
||||
for !p.Current().IsEOF() && p.Current().Value != "=" && p.Current().Value != "," &&
|
||||
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
|
||||
colText += p.Consume().Value
|
||||
}
|
||||
if p.Current().Value != "=" {
|
||||
p.Pos = start
|
||||
return nil
|
||||
}
|
||||
p.Consume() // =
|
||||
|
||||
valStart := p.Pos
|
||||
for !p.Current().IsEOF() && p.Current().Value != "," &&
|
||||
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
|
||||
if p.Current().Value == "(" {
|
||||
p.SkipParentheses()
|
||||
continue
|
||||
}
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
return &sqlstmt.Assignment{
|
||||
Column: p.Unquote(strings.TrimSpace(colText)),
|
||||
Value: &sqlstmt.Expr{Text: p.TextFromExclusive(valStart)},
|
||||
Text: p.TextFromExclusive(start),
|
||||
}
|
||||
}
|
||||
569
server/internal/db/dbm/sqlparser/dm/parser_dml_test.go
Normal file
569
server/internal/db/dbm/sqlparser/dm/parser_dml_test.go
Normal file
@@ -0,0 +1,569 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== INSERT 测试 ==========
|
||||
|
||||
func TestDmInsertBasic(t *testing.T) {
|
||||
sql := "INSERT INTO users (name, age) VALUES ('John', 30)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt, ok := stmt.(*sqlstmt.InsertStmt)
|
||||
if !ok {
|
||||
t.Fatal("expected InsertStmt")
|
||||
}
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, insertStmt.GetText())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmInsertMultipleRows(t *testing.T) {
|
||||
sql := "INSERT INTO users (name, age) VALUES ('John', 30), ('Jane', 25), ('Bob', 35)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== UPDATE 测试 ==========
|
||||
|
||||
func TestDmUpdateBasic(t *testing.T) {
|
||||
sql := "UPDATE users SET name = 'John' WHERE id = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
|
||||
if updateStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, updateStmt.GetText())
|
||||
}
|
||||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
|
||||
}
|
||||
if updateStmt.Where == nil || updateStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmUpdateMultipleColumns(t *testing.T) {
|
||||
sql := "UPDATE users SET name = 'John', age = 30, email = 'john@example.com' WHERE id = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if updateStmt.Where == nil || updateStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmUpdateReturning(t *testing.T) {
|
||||
sql := "UPDATE users SET status = 0 WHERE id = 1 RETURNING id, name, status"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if updateStmt.Where == nil || updateStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== DELETE 测试 ==========
|
||||
|
||||
func TestDmDeleteBasic(t *testing.T) {
|
||||
sql := "DELETE FROM users WHERE id = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
|
||||
if deleteStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, deleteStmt.GetText())
|
||||
}
|
||||
if len(deleteStmt.Tables) != 1 || deleteStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users'")
|
||||
}
|
||||
if deleteStmt.Where == nil || deleteStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmDeleteReturning(t *testing.T) {
|
||||
sql := "DELETE FROM users WHERE status = 0 RETURNING id, name"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil || deleteStmt.Where.Text != "status = 0" {
|
||||
t.Errorf("expected WHERE='status = 0'")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 复杂 DML 测试 ==========
|
||||
|
||||
func TestDmComplexUpdateWithSubquery(t *testing.T) {
|
||||
sql := "UPDATE users SET total_orders = (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmComplexDeleteWithSubquery(t *testing.T) {
|
||||
sql := "DELETE FROM users WHERE id NOT IN (SELECT DISTINCT user_id FROM orders)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 双引号标识符测试 ==========
|
||||
|
||||
func TestDmDoubleQuoteDML(t *testing.T) {
|
||||
sql := `UPDATE "users" SET "name" = 'John' WHERE "id" = 1`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
// 注意:DM 解析器对于双引号标识符可能保留引号
|
||||
if len(updateStmt.Tables) != 1 {
|
||||
t.Fatal("expected 1 table")
|
||||
}
|
||||
t.Logf("UPDATE table: %s", updateStmt.Tables[0].Name)
|
||||
t.Logf("UPDATE WHERE: %+v", updateStmt.Where)
|
||||
}
|
||||
|
||||
// ========== 复杂 INSERT 测试 ==========
|
||||
|
||||
func TestDmInsertWithDoubleQuotes(t *testing.T) {
|
||||
sql := `INSERT INTO "users" ("name", "age", "email") VALUES ('John', 30, 'john@example.com')`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmInsertWithSpecialChars(t *testing.T) {
|
||||
sql := `INSERT INTO "logs" ("message", "level") VALUES ('Error: connection failed!', 'ERROR')`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
t.Logf("INSERT text: %s", insertStmt.GetText())
|
||||
}
|
||||
|
||||
func TestDmInsertMultipleRowsComplex(t *testing.T) {
|
||||
// 多行插入包含特殊字符
|
||||
sql := `INSERT INTO "users" ("name", "email") VALUES ('John', 'john@example.com'), ('Jane', 'jane@test.com'), ('Bob', 'bob@demo.com')`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
t.Logf("Multiple INSERT: %s", insertStmt.GetText())
|
||||
}
|
||||
|
||||
func TestDmInsertReturningComplex(t *testing.T) {
|
||||
sql := `INSERT INTO "users" ("name", "email") VALUES ('John', 'john@example.com') RETURNING "id", "name", "created_at"`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
t.Logf("INSERT RETURNING: %s", insertStmt.GetText())
|
||||
}
|
||||
|
||||
// ========== 复杂 UPDATE 测试 ==========
|
||||
|
||||
func TestDmUpdateWithDoubleQuotes(t *testing.T) {
|
||||
sql := `UPDATE "users" SET "name" = 'John', "age" = 30 WHERE "id" = 1`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if len(updateStmt.Tables) != 1 {
|
||||
t.Fatal("expected 1 table")
|
||||
}
|
||||
t.Logf("UPDATE table: %s", updateStmt.Tables[0].Name)
|
||||
}
|
||||
|
||||
func TestDmUpdateWithComplexWhere(t *testing.T) {
|
||||
sql := `UPDATE "orders" SET "status" = 'cancelled' WHERE "status" = 'pending' AND "created_at" < '2024-01-01' AND ("amount" < 100 OR "user_id" IS NULL)`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if updateStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
t.Logf("Complex WHERE: %s", updateStmt.Where.Text)
|
||||
}
|
||||
|
||||
func TestDmUpdateWithFunctions(t *testing.T) {
|
||||
sql := `UPDATE "users" SET "updated_at" = NOW(), "login_count" = "login_count" + 1 WHERE "id" = 1`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
t.Logf("UPDATE with functions: %s", updateStmt.GetText())
|
||||
}
|
||||
|
||||
func TestDmUpdateReturningComplex(t *testing.T) {
|
||||
sql := `UPDATE "users" SET "status" = 0 WHERE "id" = 1 RETURNING "id", "name", "status"`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if updateStmt.Where == nil || updateStmt.Where.Text != `"id" = 1` {
|
||||
t.Errorf("expected WHERE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmUpdateWithSubquery(t *testing.T) {
|
||||
sql := `UPDATE "users" SET "total" = (SELECT SUM("amount") FROM "orders" WHERE "user_id" = "users"."id") WHERE "status" = 'active'`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users', got '%s'", updateStmt.Tables[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 复杂 DELETE 测试 ==========
|
||||
|
||||
func TestDmDeleteWithDoubleQuotes(t *testing.T) {
|
||||
sql := `DELETE FROM "users" WHERE "id" = 1`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if len(deleteStmt.Tables) != 1 || deleteStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmDeleteWithComplexWhere(t *testing.T) {
|
||||
sql := `DELETE FROM "logs" WHERE "created_at" < '2024-01-01' AND ("level" = 'DEBUG' OR "level" = 'INFO')`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
t.Logf("Complex DELETE WHERE: %s", deleteStmt.Where.Text)
|
||||
}
|
||||
|
||||
func TestDmDeleteReturningComplex(t *testing.T) {
|
||||
sql := `DELETE FROM "users" WHERE "status" = 0 RETURNING "id", "name"`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil || deleteStmt.Where.Text != `"status" = 0` {
|
||||
t.Errorf("expected WHERE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmDeleteWithSubquery(t *testing.T) {
|
||||
sql := `DELETE FROM "users" WHERE "id" NOT IN (SELECT DISTINCT "user_id" FROM "orders")`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 复杂 DDL 测试 ==========
|
||||
|
||||
func TestDmDDLCreateTableWithQuotes(t *testing.T) {
|
||||
sql := `CREATE TABLE "users" ("id" INT PRIMARY KEY, "name" VARCHAR(50) NOT NULL, "email" VARCHAR(255) UNIQUE)`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
if ddlStmt.DdlKind != "CREATE" {
|
||||
t.Errorf("expected DdlKind='CREATE'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmDDLCreateTableWithComments(t *testing.T) {
|
||||
sql := `CREATE TABLE "orders" ("id" INT COMMENT '订单ID', "amount" DECIMAL(10,2) COMMENT '金额') COMMENT='订单表'`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
t.Logf("CREATE TABLE with comments: %s", ddlStmt.GetText())
|
||||
}
|
||||
|
||||
func TestDmDDLAlterTableAddColumn(t *testing.T) {
|
||||
sql := `ALTER TABLE "users" ADD COLUMN "email" VARCHAR(255)`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
if ddlStmt.DdlKind != "ALTER" {
|
||||
t.Errorf("expected DdlKind='ALTER'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmDDLDropIfExists(t *testing.T) {
|
||||
sql := `DROP TABLE IF EXISTS "users"`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
if ddlStmt.DdlKind != "DROP" {
|
||||
t.Errorf("expected DdlKind='DROP'")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 注释风格测试 ==========
|
||||
|
||||
func TestDmDMLWithSingleLineComment(t *testing.T) {
|
||||
// DM 单行注释 --
|
||||
sql := "-- 更新用户信息\nUPDATE \"users\" SET \"name\" = 'John' WHERE \"id\" = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if len(updateStmt.Tables) != 1 {
|
||||
t.Fatal("expected 1 table")
|
||||
}
|
||||
t.Logf("UPDATE with -- comment: %s", updateStmt.GetText())
|
||||
}
|
||||
|
||||
func TestDmDMLWithMultiLineComment(t *testing.T) {
|
||||
// 多行注释 /* */
|
||||
sql := "/* 删除日志 */ DELETE FROM \"logs\" WHERE \"created_at\" < '2024-01-01'"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if deleteStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
t.Logf("DELETE with /* */ comment: %s", deleteStmt.GetText())
|
||||
}
|
||||
|
||||
func TestDmDMLWithInlineComment(t *testing.T) {
|
||||
// 行内注释
|
||||
sql := `SELECT "id", "name" /* 用户名 */ FROM "users" WHERE "status" = 1`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if len(selectStmt.Items) != 2 {
|
||||
t.Fatalf("expected 2 items")
|
||||
}
|
||||
t.Logf("SELECT with inline comment: %s", selectStmt.GetText())
|
||||
}
|
||||
|
||||
func TestDmDMLWithMultipleComments(t *testing.T) {
|
||||
// 多个注释
|
||||
sql := `-- 查询订单
|
||||
/* 只查询已支付的 */
|
||||
SELECT "id", "amount" FROM "orders"
|
||||
WHERE "status" = 'paid' AND "created_at" > '2024-01-01'
|
||||
ORDER BY "created_at" DESC`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
t.Logf("Multiple comments: %s", selectStmt.GetText())
|
||||
}
|
||||
|
||||
func TestDmInsertWithComment(t *testing.T) {
|
||||
sql := "-- 插入数据\nINSERT INTO \"users\" (\"name\", \"email\") VALUES ('John', 'john@example.com')"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
t.Logf("INSERT with comment: %s", insertStmt.GetText())
|
||||
}
|
||||
|
||||
// ========== Schema.Table 测试 ==========
|
||||
|
||||
func TestDmInsertWithSchema(t *testing.T) {
|
||||
sql := `INSERT INTO "TEST"."users" ("name", "age") VALUES ('John', 30)`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
if insertStmt.Table.Schema != "TEST" {
|
||||
t.Errorf("expected schema='TEST', got '%s'", insertStmt.Table.Schema)
|
||||
}
|
||||
if insertStmt.Table.Name != "users" {
|
||||
t.Errorf("expected table='users', got '%s'", insertStmt.Table.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmUpdateWithSchema(t *testing.T) {
|
||||
sql := `UPDATE "TEST"."t_db" SET "name" = 'test' WHERE "id" = 5`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
|
||||
t.Logf("UPDATE text: %s", updateStmt.GetText())
|
||||
t.Logf("UPDATE tables: %+v", updateStmt.Tables)
|
||||
t.Logf("UPDATE WHERE: %+v", updateStmt.Where)
|
||||
|
||||
if len(updateStmt.Tables) != 1 {
|
||||
t.Fatalf("expected 1 table, got %d", len(updateStmt.Tables))
|
||||
}
|
||||
if updateStmt.Tables[0].Schema != "TEST" {
|
||||
t.Errorf("expected schema='TEST', got '%s'", updateStmt.Tables[0].Schema)
|
||||
}
|
||||
if updateStmt.Tables[0].Name != "t_db" {
|
||||
t.Errorf("expected table='t_db', got '%s'", updateStmt.Tables[0].Name)
|
||||
}
|
||||
if updateStmt.Where == nil {
|
||||
t.Errorf("expected WHERE")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmDeleteWithSchema(t *testing.T) {
|
||||
sql := `DELETE FROM "TEST"."logs" WHERE "created_at" < '2024-01-01'`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
if len(deleteStmt.Tables) != 1 {
|
||||
t.Fatalf("expected 1 table, got %d", len(deleteStmt.Tables))
|
||||
}
|
||||
if deleteStmt.Tables[0].Schema != "TEST" {
|
||||
t.Errorf("expected schema='TEST', got '%s'", deleteStmt.Tables[0].Schema)
|
||||
}
|
||||
if deleteStmt.Tables[0].Name != "logs" {
|
||||
t.Errorf("expected table='logs', got '%s'", deleteStmt.Tables[0].Name)
|
||||
}
|
||||
if deleteStmt.Where == nil {
|
||||
t.Errorf("expected WHERE")
|
||||
}
|
||||
}
|
||||
472
server/internal/db/dbm/sqlparser/dm/parser_pagination_test.go
Normal file
472
server/internal/db/dbm/sqlparser/dm/parser_pagination_test.go
Normal file
@@ -0,0 +1,472 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== 简单分页测试 ==========
|
||||
|
||||
func TestDmPaginationSimple(t *testing.T) {
|
||||
sql := "SELECT * FROM users LIMIT 0, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
|
||||
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
|
||||
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmPaginationWithOffset(t *testing.T) {
|
||||
sql := "SELECT id, name FROM users LIMIT 20, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 20 {
|
||||
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmPaginationKeywordOffset(t *testing.T) {
|
||||
// DM 也支持 LIMIT count OFFSET offset
|
||||
sql := "SELECT * FROM products LIMIT 15 OFFSET 30"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
|
||||
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
|
||||
|
||||
if selectStmt.Limit.Offset != 30 {
|
||||
t.Errorf("expected offset=30, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 15 {
|
||||
t.Errorf("expected count=15, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmPaginationOnlyLimit(t *testing.T) {
|
||||
sql := "SELECT * FROM orders LIMIT 50"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 50 {
|
||||
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== SELECT TOP 分页测试 ==========
|
||||
|
||||
func TestDmPaginationTopN(t *testing.T) {
|
||||
// DM 支持 SELECT TOP n
|
||||
sql := "SELECT TOP 10 id, name FROM users ORDER BY id"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Items) != 2 {
|
||||
t.Fatalf("expected 2 items")
|
||||
}
|
||||
t.Logf("SELECT TOP parsed successfully")
|
||||
}
|
||||
|
||||
func TestDmPaginationTopWithWhere(t *testing.T) {
|
||||
sql := "SELECT TOP 20 id, name FROM users WHERE status = 1 ORDER BY created_at DESC"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
t.Logf("SELECT TOP with WHERE parsed successfully")
|
||||
}
|
||||
|
||||
// ========== 复杂分页测试 ==========
|
||||
|
||||
func TestDmPaginationWithWhere(t *testing.T) {
|
||||
sql := "SELECT id, name, email FROM users WHERE status = 1 AND age > 18 LIMIT 10, 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmPaginationWithOrderBy(t *testing.T) {
|
||||
sql := "SELECT * FROM users ORDER BY created_at DESC, id ASC LIMIT 0, 100"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.OrderBy) != 2 {
|
||||
t.Fatalf("expected 2 ORDER BY clauses")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 100 {
|
||||
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmPaginationWithWhereOrderBy(t *testing.T) {
|
||||
sql := "SELECT id, name FROM users WHERE status = 'active' ORDER BY score DESC LIMIT 50, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 'active'" {
|
||||
t.Errorf("expected WHERE='status = 'active''")
|
||||
}
|
||||
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "score" {
|
||||
t.Errorf("expected ORDER BY score")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 50 {
|
||||
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmPaginationWithJoin(t *testing.T) {
|
||||
sql := "SELECT u.id, u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1 ORDER BY o.created_at DESC LIMIT 100, 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatal("expected 1 JOIN")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 100 {
|
||||
t.Errorf("expected offset=100, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmPaginationWithGroupBy(t *testing.T) {
|
||||
sql := "SELECT user_id, COUNT(*) as order_count FROM orders GROUP BY user_id ORDER BY order_count DESC LIMIT 0, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== UNION 分页测试 ==========
|
||||
|
||||
func TestDmPaginationWithUnion(t *testing.T) {
|
||||
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC LIMIT 0, 50"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Unions) != 1 {
|
||||
t.Fatal("expected 1 UNION")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause for UNION")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 50 {
|
||||
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmPaginationUnionAll(t *testing.T) {
|
||||
sql := "SELECT name FROM products UNION ALL SELECT name FROM services LIMIT 10, 30"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Unions) != 1 {
|
||||
t.Fatal("expected 1 UNION ALL")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 30 {
|
||||
t.Errorf("expected count=30, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmMultipleUnionsPagination(t *testing.T) {
|
||||
sql := "SELECT id FROM t1 UNION SELECT id FROM t2 UNION SELECT id FROM t3 ORDER BY 1 LIMIT 50, 100"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Unions) != 2 {
|
||||
t.Fatalf("expected 2 UNIONs, got %d", len(selectStmt.Unions))
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 50 {
|
||||
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 100 {
|
||||
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 子查询分页测试 ==========
|
||||
|
||||
func TestDmPaginationWithSubquery(t *testing.T) {
|
||||
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1 ORDER BY id LIMIT 0, 100) AS tmp LIMIT 20, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatal("expected 1 FROM table")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected outer LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 20 {
|
||||
t.Errorf("expected outer offset=20, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected outer count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmNestedPagination(t *testing.T) {
|
||||
// 嵌套分页查询
|
||||
sql := "SELECT * FROM (SELECT * FROM users LIMIT 0, 100) AS tmp LIMIT 10, 5"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected outer LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected outer offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 5 {
|
||||
t.Errorf("expected outer count=5, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmPaginationWithExists(t *testing.T) {
|
||||
sql := "SELECT * FROM users u WHERE EXISTS (SELECT 1 FROM orders o WHERE o.user_id = u.id LIMIT 1) LIMIT 20, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 20 {
|
||||
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 大数据量分页测试 ==========
|
||||
|
||||
func TestDmLargeOffsetPagination(t *testing.T) {
|
||||
sql := "SELECT * FROM logs ORDER BY id LIMIT 1000000, 100"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 1000000 {
|
||||
t.Errorf("expected offset=1000000, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 100 {
|
||||
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== FOR UPDATE 分页测试 ==========
|
||||
|
||||
func TestDmPaginationForUpdate(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE status = 0 ORDER BY id LIMIT 10 OFFSET 0 FOR UPDATE"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 双引号标识符分页测试 ==========
|
||||
|
||||
func TestDmPaginationDoubleQuote(t *testing.T) {
|
||||
sql := `SELECT "id", "name" FROM "users" WHERE "status" = 1 ORDER BY "id" DESC LIMIT 20, 10`
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 20 {
|
||||
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
276
server/internal/db/dbm/sqlparser/dm/parser_select_test.go
Normal file
276
server/internal/db/dbm/sqlparser/dm/parser_select_test.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== SELECT 测试 ==========
|
||||
|
||||
func TestDmSelectBasic(t *testing.T) {
|
||||
sql := "SELECT id, name FROM users WHERE status = 1 ORDER BY id DESC LIMIT 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 1" {
|
||||
t.Errorf("expected WHERE='status = 1'")
|
||||
}
|
||||
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "id" || !selectStmt.OrderBy[0].Desc {
|
||||
t.Errorf("expected ORDER BY id DESC")
|
||||
}
|
||||
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT 10" || selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected LIMIT 10")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmSelectTop(t *testing.T) {
|
||||
sql := "SELECT TOP 10 id, name FROM users ORDER BY id"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if len(selectStmt.Items) != 2 {
|
||||
t.Fatalf("expected 2 items")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmSelectDistinct(t *testing.T) {
|
||||
sql := "SELECT DISTINCT department_id FROM users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if !selectStmt.Distinct {
|
||||
t.Error("expected Distinct=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmSelectStar(t *testing.T) {
|
||||
sql := "SELECT * FROM users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if len(selectStmt.Items) != 1 || selectStmt.Items[0].Text != "*" {
|
||||
t.Errorf("expected SELECT *")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmSelectWithJoin(t *testing.T) {
|
||||
sql := "SELECT u.id, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if len(selectStmt.Joins) != 1 || selectStmt.Joins[0].Table.Name != "orders" {
|
||||
t.Errorf("expected JOIN orders")
|
||||
}
|
||||
if selectStmt.Where == nil || selectStmt.Where.Text != "u.status = 1" {
|
||||
t.Errorf("expected WHERE='u.status = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmSelectMultipleJoins(t *testing.T) {
|
||||
sql := "SELECT u.id, o.amount, p.name FROM users u LEFT JOIN orders o ON u.id = o.user_id INNER JOIN products p ON o.product_id = p.id WHERE u.status = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if len(selectStmt.Joins) != 2 {
|
||||
t.Fatalf("expected 2 joins")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmSelectComplexWhere(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE status = 1 AND (age > 18 OR role = 'admin') ORDER BY name"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
expectedWhere := "status = 1 AND (age > 18 OR role = 'admin')"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE='%s'", expectedWhere)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== LIMIT/OFFSET 测试 ==========
|
||||
|
||||
func TestDmLimitOffset(t *testing.T) {
|
||||
sql := "SELECT id FROM users ORDER BY id LIMIT 20, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT 20, 10" {
|
||||
t.Errorf("expected LIMIT text='LIMIT 20, 10'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmLimitOnly(t *testing.T) {
|
||||
sql := "SELECT * FROM users LIMIT 5"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT 5" || selectStmt.Limit.Count != 5 {
|
||||
t.Errorf("expected LIMIT 5")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== UNION 测试 ==========
|
||||
|
||||
func TestDmUnion(t *testing.T) {
|
||||
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if len(selectStmt.Unions) != 1 {
|
||||
t.Fatalf("expected 1 union")
|
||||
}
|
||||
if len(selectStmt.OrderBy) != 1 {
|
||||
t.Fatalf("expected ORDER BY")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmUnionWithOrderBy(t *testing.T) {
|
||||
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if len(selectStmt.Unions) != 1 {
|
||||
t.Fatalf("expected 1 union")
|
||||
}
|
||||
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "1" || !selectStmt.OrderBy[0].Desc {
|
||||
t.Errorf("expected ORDER BY 1 DESC")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmUnionWithLimit(t *testing.T) {
|
||||
sql := "SELECT id FROM users UNION ALL SELECT id FROM admins LIMIT 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT 20" || selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected LIMIT 20")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== FOR UPDATE 测试 ==========
|
||||
|
||||
func TestDmForUpdate(t *testing.T) {
|
||||
sql := "SELECT id, name FROM users WHERE id = 1 FOR UPDATE"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
if selectStmt.Where == nil || selectStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== WITH (CTE) 测试 ==========
|
||||
|
||||
func TestDmWithClause(t *testing.T) {
|
||||
sql := "WITH temp_users AS (SELECT * FROM users WHERE status = 1) SELECT * FROM temp_users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmMultipleWith(t *testing.T) {
|
||||
sql := `WITH
|
||||
active_users AS (SELECT * FROM users WHERE status = 1),
|
||||
recent_orders AS (SELECT * FROM orders WHERE created_at > NOW())
|
||||
SELECT * FROM active_users u JOIN recent_orders o ON u.id = o.user_id`
|
||||
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== DDL 测试 ==========
|
||||
|
||||
func TestDmDDL(t *testing.T) {
|
||||
sqls := []string{
|
||||
"CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(50))",
|
||||
"DROP TABLE users",
|
||||
"ALTER TABLE users ADD COLUMN email VARCHAR(255)",
|
||||
}
|
||||
|
||||
for _, sql := range sqls {
|
||||
t.Run(sql[:10], func(t *testing.T) {
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
297
server/internal/db/dbm/sqlparser/dm/parser_subquery_test.go
Normal file
297
server/internal/db/dbm/sqlparser/dm/parser_subquery_test.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== 子查询测试 ==========
|
||||
|
||||
func TestDmSubqueryInFrom(t *testing.T) {
|
||||
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1) AS u WHERE u.id > 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("FROM 表数量: %d", len(selectStmt.From))
|
||||
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 table in FROM, got %d", len(selectStmt.From))
|
||||
}
|
||||
|
||||
t.Logf("FROM[0] Name: %s", selectStmt.From[0].Name)
|
||||
t.Logf("FROM[0] Alias: %s", selectStmt.From[0].Alias)
|
||||
|
||||
if selectStmt.From[0].Alias != "u" {
|
||||
t.Errorf("expected alias='u', got '%s'", selectStmt.From[0].Alias)
|
||||
}
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected outer WHERE clause")
|
||||
}
|
||||
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
|
||||
if selectStmt.Where.Text != "u.id > 10" {
|
||||
t.Errorf("expected outer WHERE text='u.id > 10', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmSubqueryInWhere(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders WHERE amount > 100)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "id IN (SELECT user_id FROM orders WHERE amount > 100)"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmSubqueryInSelect(t *testing.T) {
|
||||
sql := "SELECT id, name, (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) AS order_count FROM users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
|
||||
|
||||
for i, item := range selectStmt.Items {
|
||||
t.Logf(" Item[%d]: Text='%s', ColumnName='%s', Alias='%s'", i, item.Text, item.ColumnName, item.Alias)
|
||||
}
|
||||
|
||||
if len(selectStmt.Items) != 3 {
|
||||
t.Fatalf("expected 3 items, got %d", len(selectStmt.Items))
|
||||
}
|
||||
|
||||
if selectStmt.Items[2].Alias != "order_count" {
|
||||
t.Errorf("expected item[2] alias='order_count', got '%s'", selectStmt.Items[2].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmNestedSubquery(t *testing.T) {
|
||||
sql := "SELECT * FROM (SELECT * FROM (SELECT id FROM users WHERE status = 1) AS inner_u) AS outer_u WHERE outer_u.id > 5"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected outer WHERE")
|
||||
}
|
||||
if selectStmt.Where.Text != "outer_u.id > 5" {
|
||||
t.Errorf("expected WHERE text='outer_u.id > 5', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmCorrelatedSubquery(t *testing.T) {
|
||||
sql := "SELECT u.id, u.name FROM users u WHERE u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmSubqueryWithExists(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmMultipleSubqueries(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmSubqueryWithLimit(t *testing.T) {
|
||||
sql := "SELECT * FROM (SELECT id, name FROM users ORDER BY created_at DESC LIMIT 10) AS top_users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("FROM 数量: %d", len(selectStmt.From))
|
||||
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 table, got %d", len(selectStmt.From))
|
||||
}
|
||||
|
||||
if selectStmt.From[0].Alias != "top_users" {
|
||||
t.Errorf("expected alias='top_users', got '%s'", selectStmt.From[0].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmComplexSubquery(t *testing.T) {
|
||||
sql := `SELECT
|
||||
u.id,
|
||||
u.name,
|
||||
(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS order_count,
|
||||
(SELECT SUM(amount) FROM orders o WHERE o.user_id = u.id AND o.status = 'completed') AS total_amount
|
||||
FROM users u
|
||||
WHERE u.status = 1
|
||||
AND u.id IN (SELECT user_id FROM user_groups WHERE group_id = 5)
|
||||
ORDER BY u.name`
|
||||
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
|
||||
|
||||
if len(selectStmt.Items) != 4 {
|
||||
t.Fatalf("expected 4 items, got %d", len(selectStmt.Items))
|
||||
}
|
||||
|
||||
t.Logf("Item[0]: %s", selectStmt.Items[0].Text)
|
||||
t.Logf("Item[1]: %s", selectStmt.Items[1].Text)
|
||||
t.Logf("Item[2]: %s (alias: %s)", selectStmt.Items[2].Text, selectStmt.Items[2].Alias)
|
||||
t.Logf("Item[3]: %s (alias: %s)", selectStmt.Items[3].Text, selectStmt.Items[3].Alias)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if len(selectStmt.OrderBy) != 1 {
|
||||
t.Fatalf("expected 1 order by, got %d", len(selectStmt.OrderBy))
|
||||
}
|
||||
if selectStmt.OrderBy[0].Text != "u.name" {
|
||||
t.Errorf("expected order by text='u.name', got '%s'", selectStmt.OrderBy[0].Text)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== JOIN 测试 ==========
|
||||
|
||||
func TestDmMultipleJoins(t *testing.T) {
|
||||
sql := "SELECT u.id, o.amount, p.name FROM users u LEFT JOIN orders o ON u.id = o.user_id INNER JOIN products p ON o.product_id = p.id WHERE u.status = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("JOIN 数量: %d", len(selectStmt.Joins))
|
||||
|
||||
if len(selectStmt.Joins) != 2 {
|
||||
t.Fatalf("expected 2 joins, got %d", len(selectStmt.Joins))
|
||||
}
|
||||
|
||||
if selectStmt.Joins[0].Table.Name != "orders" {
|
||||
t.Errorf("expected first join table='orders', got '%s'", selectStmt.Joins[0].Table.Name)
|
||||
}
|
||||
if selectStmt.Joins[1].Table.Name != "products" {
|
||||
t.Errorf("expected second join table='products', got '%s'", selectStmt.Joins[1].Table.Name)
|
||||
}
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
if selectStmt.Where.Text != "u.status = 1" {
|
||||
t.Errorf("expected WHERE text='u.status = 1', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmRightJoin(t *testing.T) {
|
||||
sql := "SELECT * FROM users u RIGHT JOIN orders o ON u.id = o.user_id"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDmCrossJoin(t *testing.T) {
|
||||
sql := "SELECT * FROM users CROSS JOIN roles"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatalf("expected 1 join, got %d", len(selectStmt.Joins))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user