mirror of
https://gitee.com/dromara/mayfly-go
synced 2026-05-18 17:05:21 +08:00
refactor: 移除antlr4减小包体积&ai助手优化
This commit is contained in:
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@@ -1 +0,0 @@
|
||||
java -jar antlr-4.13.1-complete.jar -Dlanguage=Go -package parser -visitor *.g4
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,59 +1,17 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/db/dbm/sqlparser/base"
|
||||
mysqlparser "mayfly-go/internal/db/dbm/sqlparser/mysql/antlr4"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
"mayfly-go/pkg/gox"
|
||||
"mayfly-go/pkg/logx"
|
||||
|
||||
"github.com/antlr4-go/antlr/v4"
|
||||
)
|
||||
|
||||
func GetMysqlParserTree(baseLine int, statement string) (antlr.ParseTree, *antlr.CommonTokenStream, error) {
|
||||
lexer := mysqlparser.NewMySqlLexer(antlr.NewInputStream(statement))
|
||||
stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)
|
||||
parser := mysqlparser.NewMySqlParser(stream)
|
||||
|
||||
lexerErrorListener := &base.ParseErrorListener{
|
||||
BaseLine: baseLine,
|
||||
}
|
||||
lexer.RemoveErrorListeners()
|
||||
lexer.AddErrorListener(lexerErrorListener)
|
||||
|
||||
parserErrorListener := &base.ParseErrorListener{
|
||||
BaseLine: baseLine,
|
||||
}
|
||||
parser.RemoveErrorListeners()
|
||||
parser.AddErrorListener(parserErrorListener)
|
||||
parser.BuildParseTrees = true
|
||||
|
||||
tree := parser.Root()
|
||||
|
||||
if lexerErrorListener.Err != nil {
|
||||
return nil, nil, lexerErrorListener.Err
|
||||
}
|
||||
|
||||
if parserErrorListener.Err != nil {
|
||||
return nil, nil, parserErrorListener.Err
|
||||
}
|
||||
|
||||
return tree, stream, nil
|
||||
}
|
||||
|
||||
type MysqlParser struct {
|
||||
}
|
||||
|
||||
func (*MysqlParser) Parse(stmt string) (stmts []sqlstmt.Stmt, err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
logx.ErrorTrace("mysql sql parser err: ", e)
|
||||
err = e.(error)
|
||||
}
|
||||
}()
|
||||
tree, _, err := GetMysqlParserTree(1, stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tree.Accept(new(MysqlVisitor)).([]sqlstmt.Stmt), nil
|
||||
func (*MysqlParser) Parse(stmt string) (sqlstmt.Stmt, error) {
|
||||
defer gox.Recover(func(e error) {
|
||||
logx.ErrorTrace("mysql sql parser err: ", e)
|
||||
})
|
||||
return NewParser(stmt).Parse()
|
||||
}
|
||||
|
||||
@@ -1,122 +0,0 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParserSimpleSelect(t *testing.T) {
|
||||
parser := new(MysqlParser)
|
||||
|
||||
// sql := "select sum(t.age), t.`id` tid, t1.id2, t1.* from T_DB t left join t_db_ins as t1 on t.id = t1.id2 where t.id = 1 AND t1.status=0 and t.id2='9' and t.name in ('name2', 'name3') order by t.id desc limit 0, 100"
|
||||
sql := "SELECT t.* FROM `t_sys_resource` t WHERE t.`id` > 0"
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(stmts)
|
||||
}
|
||||
|
||||
func TestParserUnionSelect(t *testing.T) {
|
||||
parser := new(MysqlParser)
|
||||
|
||||
sql := "(select sum(t.age), t.id tid, t1.id2, t1.* from T_DB t join t_db_ins as t1 on t.id = t1.id2 where t.id = 1 AND t1.status=0 and t.id2='9' and t.name in ('name2', 'name3') order by t.id desc limit 0, 100) union all (select * from t_db2) limit 10"
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(stmts)
|
||||
}
|
||||
|
||||
func TestParserSingleUpdate(t *testing.T) {
|
||||
parser := new(MysqlParser)
|
||||
|
||||
sql := `UPDATE t_sys_msg t
|
||||
SET
|
||||
t.recipient_id = 13,
|
||||
t.creator = 'admin4'
|
||||
WHERE
|
||||
t.id = 1;`
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(stmts)
|
||||
}
|
||||
|
||||
func TestParserInsert(t *testing.T) {
|
||||
parser := new(MysqlParser)
|
||||
|
||||
sql := `INSERT INTO
|
||||
mayfly_go.t_sys_msg (
|
||||
type,
|
||||
msg,
|
||||
recipient_id,
|
||||
creator_id,
|
||||
create_time,
|
||||
is_deleted
|
||||
)
|
||||
VALUES
|
||||
(1, 'hahaha', 2, 1, '2024-08-26 15:36:27', 0);`
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(stmts)
|
||||
}
|
||||
|
||||
func TestParserSql(t *testing.T) {
|
||||
parser := new(MysqlParser)
|
||||
|
||||
// sql := `INSERT INTO
|
||||
// mayfly_go.t_sys_msg (
|
||||
// type,
|
||||
// msg,
|
||||
// recipient_id,
|
||||
// creator_id,
|
||||
// create_time,
|
||||
// is_deleted
|
||||
// )
|
||||
// VALUES
|
||||
// (1, 'hahaha', 2, 1, '2024-08-26 15:36:27', 0);`
|
||||
|
||||
sql := `UPDATE t_sys_msg
|
||||
SET
|
||||
recipient_id = 13,
|
||||
creator = 'admin4'
|
||||
WHERE
|
||||
id = 1;`
|
||||
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
switch stmt := stmts[0].(type) {
|
||||
case *sqlstmt.InsertStmt:
|
||||
t.Log("insert")
|
||||
t.Log(stmt.TableName.Identifier.Value)
|
||||
case *sqlstmt.UpdateStmt:
|
||||
t.Log("update")
|
||||
case *sqlstmt.DeleteStmt:
|
||||
t.Log("delete")
|
||||
case *sqlstmt.SelectStmt:
|
||||
t.Log("select")
|
||||
default:
|
||||
t.Log("other")
|
||||
}
|
||||
t.Log(stmts)
|
||||
}
|
||||
|
||||
func TestParserDelete(t *testing.T) {
|
||||
parser := new(MysqlParser)
|
||||
|
||||
sql := `DELETE FROM t_sys_log
|
||||
WHERE
|
||||
id IN (59);`
|
||||
stmts, err := parser.Parse(sql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log(stmts)
|
||||
}
|
||||
767
server/internal/db/dbm/sqlparser/mysql/parser.go
Normal file
767
server/internal/db/dbm/sqlparser/mysql/parser.go
Normal file
@@ -0,0 +1,767 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/base"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/tokenizer"
|
||||
)
|
||||
|
||||
// Parser MySQL 方言 SQL 解析器
|
||||
type Parser struct {
|
||||
*base.Lexer
|
||||
}
|
||||
|
||||
// NewParser 创建 MySQL 解析器
|
||||
func NewParser(sql string) *Parser {
|
||||
return &Parser{
|
||||
Lexer: base.NewLexer(sql, tokenizer.DialectConfig{
|
||||
BacktickAsIdentifier: true,
|
||||
HashLineComment: true,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
// Parse 解析单条 SQL 语句
|
||||
func (p *Parser) Parse() (sqlstmt.Stmt, error) {
|
||||
p.SkipSemicolons()
|
||||
if p.Current().IsEOF() {
|
||||
return nil, nil
|
||||
}
|
||||
stmt := p.parseStatement()
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseStatement() sqlstmt.Stmt {
|
||||
tok := p.Current()
|
||||
switch {
|
||||
case tok.IsKeyword("SELECT") || tok.Value == "(":
|
||||
return p.parseSelect()
|
||||
case tok.IsKeyword("INSERT"):
|
||||
return p.parseInsert()
|
||||
case tok.IsKeyword("UPDATE"):
|
||||
return p.parseUpdate()
|
||||
case tok.IsKeyword("DELETE"):
|
||||
return p.parseDelete()
|
||||
case tok.IsKeyword("CREATE"):
|
||||
return p.parseCreate()
|
||||
case tok.IsKeyword("DROP"):
|
||||
return p.parseDrop()
|
||||
case tok.IsKeyword("ALTER"):
|
||||
return p.parseAlter()
|
||||
case tok.IsKeyword("WITH"):
|
||||
return p.parseWith()
|
||||
case tok.IsKeyword("SHOW"):
|
||||
return p.parseShow()
|
||||
case tok.IsKeyword("TRUNCATE"):
|
||||
return p.parseGenericDdl()
|
||||
default:
|
||||
return p.parseGenericStmt()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- SELECT 解析 ----------
|
||||
|
||||
func (p *Parser) parseSelect() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
var selectStmt *sqlstmt.SelectStmt
|
||||
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("SELECT") {
|
||||
selectStmt = p.parseSelectBody()
|
||||
} else {
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.OtherStmt{Base: sqlstmt.Base{Text: p.TextFrom(start)}}
|
||||
}
|
||||
p.ExpectValue(")")
|
||||
} else if p.Current().IsKeyword("SELECT") {
|
||||
selectStmt = p.parseSelectBody()
|
||||
} else {
|
||||
return &sqlstmt.OtherStmt{Base: sqlstmt.Base{Text: p.TextFrom(start)}}
|
||||
}
|
||||
|
||||
if selectStmt == nil {
|
||||
return &sqlstmt.OtherStmt{Base: sqlstmt.Base{Text: p.TextFrom(start)}}
|
||||
}
|
||||
|
||||
// UNION
|
||||
for p.Current().IsKeyword("UNION") {
|
||||
p.Consume()
|
||||
isAll := false
|
||||
if p.Current().IsKeyword("ALL") {
|
||||
p.Consume()
|
||||
isAll = true
|
||||
} else if p.Current().IsKeyword("DISTINCT") {
|
||||
p.Consume()
|
||||
}
|
||||
unionStart := p.Pos
|
||||
var unionSelect *sqlstmt.SelectStmt
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("SELECT") {
|
||||
unionSelect = p.parseSelectBody()
|
||||
}
|
||||
p.ExpectValue(")")
|
||||
} else if p.Current().IsKeyword("SELECT") {
|
||||
unionSelect = p.parseSelectBody()
|
||||
}
|
||||
if unionSelect != nil {
|
||||
// 如果 unionSelect 有 LIMIT 或 ORDER BY,移动到外层 selectStmt(UNION 的 LIMIT/ORDER BY 属于整个语句)
|
||||
if unionSelect.Limit != nil {
|
||||
selectStmt.Limit = unionSelect.Limit
|
||||
unionSelect.Limit = nil
|
||||
}
|
||||
if len(unionSelect.OrderBy) > 0 {
|
||||
selectStmt.OrderBy = unionSelect.OrderBy
|
||||
unionSelect.OrderBy = nil
|
||||
}
|
||||
unionSelect.Text = p.TextFrom(unionStart)
|
||||
}
|
||||
selectStmt.Unions = append(selectStmt.Unions, sqlstmt.UnionClause{
|
||||
Select: unionSelect,
|
||||
All: isAll,
|
||||
Text: p.TextFrom(unionStart),
|
||||
})
|
||||
}
|
||||
|
||||
// UNION 之后可能有 ORDER BY
|
||||
if p.Current().IsKeyword("ORDER") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
p.parseOrderBy(selectStmt)
|
||||
}
|
||||
|
||||
// LIMIT (UNION 之后的 LIMIT 覆盖子查询内部的 LIMIT)
|
||||
// 注意:只在确实解析到新 LIMIT 时才覆盖,避免 nil 覆盖有效值
|
||||
if p.Current().IsKeyword("LIMIT") {
|
||||
if limit := p.parseLimit(); limit != nil {
|
||||
selectStmt.Limit = limit
|
||||
}
|
||||
}
|
||||
|
||||
selectStmt.Text = p.TextFrom(start)
|
||||
return selectStmt
|
||||
}
|
||||
|
||||
func (p *Parser) parseSelectBody() *sqlstmt.SelectStmt {
|
||||
if !p.Current().IsKeyword("SELECT") {
|
||||
return nil
|
||||
}
|
||||
start := p.Pos
|
||||
p.Consume()
|
||||
|
||||
distinct := false
|
||||
if p.Current().IsKeyword("DISTINCT") {
|
||||
p.Consume()
|
||||
distinct = true
|
||||
} else if p.Current().IsKeyword("ALL") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
items := p.parseSelectItems()
|
||||
|
||||
stmt := &sqlstmt.SelectStmt{
|
||||
Distinct: distinct,
|
||||
Items: items,
|
||||
}
|
||||
|
||||
if p.Current().IsKeyword("FROM") {
|
||||
p.Consume()
|
||||
p.parseFromClause(stmt)
|
||||
}
|
||||
|
||||
if p.Current().IsKeyword("WHERE") {
|
||||
p.Consume()
|
||||
whereStart := p.Pos
|
||||
p.SkipExpr()
|
||||
stmt.Where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
|
||||
}
|
||||
|
||||
if p.Current().IsKeyword("GROUP") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
p.SkipGroupByExpr()
|
||||
}
|
||||
|
||||
if p.Current().IsKeyword("HAVING") {
|
||||
p.Consume()
|
||||
p.SkipExpr()
|
||||
}
|
||||
|
||||
if p.Current().IsKeyword("ORDER") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
p.parseOrderBy(stmt)
|
||||
}
|
||||
|
||||
// LIMIT (子查询或简单 SELECT 的 LIMIT)
|
||||
if p.Current().IsKeyword("LIMIT") {
|
||||
stmt.Limit = p.parseLimit()
|
||||
}
|
||||
|
||||
stmt.Text = p.TextFrom(start)
|
||||
return stmt
|
||||
}
|
||||
|
||||
func (p *Parser) parseSelectItems() []sqlstmt.SelectItem {
|
||||
items := make([]sqlstmt.SelectItem, 0)
|
||||
for !p.Current().IsEOF() {
|
||||
tok := p.Current()
|
||||
if tok.Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.IsSelectClauseEnd() {
|
||||
break
|
||||
}
|
||||
if tok.Value == "*" {
|
||||
p.Consume()
|
||||
items = append(items, sqlstmt.SelectItem{
|
||||
Kind: sqlstmt.SelectItemStar,
|
||||
Text: "*",
|
||||
})
|
||||
} else if tok.Type == tokenizer.TokenIdentifier && p.Peek(1).Value == "." && p.Peek(2).Value == "*" {
|
||||
tableAlias := p.Unquote(p.Consume().Value)
|
||||
p.Consume()
|
||||
p.Consume()
|
||||
items = append(items, sqlstmt.SelectItem{
|
||||
Kind: sqlstmt.SelectItemStar,
|
||||
Text: tableAlias + ".*",
|
||||
TableAlias: tableAlias,
|
||||
})
|
||||
} else {
|
||||
colStart := p.Pos
|
||||
p.skipSelectElement()
|
||||
colText := base.TrimTrailingComma(p.TextFromExclusive(colStart))
|
||||
colName, alias := p.ExtractColumnAndAlias(colText)
|
||||
kind := sqlstmt.SelectItemColumn
|
||||
if strings.Contains(colText, "(") {
|
||||
kind = sqlstmt.SelectItemExpr
|
||||
}
|
||||
items = append(items, sqlstmt.SelectItem{
|
||||
Kind: kind,
|
||||
Text: colText,
|
||||
Alias: alias,
|
||||
ColumnName: colName,
|
||||
})
|
||||
}
|
||||
}
|
||||
return items
|
||||
}
|
||||
|
||||
func (p *Parser) skipSelectElement() {
|
||||
for !p.Current().IsEOF() {
|
||||
tok := p.Current()
|
||||
if tok.Value == "," || p.IsSelectClauseEnd() {
|
||||
break
|
||||
}
|
||||
if tok.Value == "(" {
|
||||
p.SkipParentheses()
|
||||
continue
|
||||
}
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- FROM / TableRef 解析 ----------
|
||||
|
||||
func (p *Parser) parseFromClause(stmt *sqlstmt.SelectStmt) {
|
||||
for !p.Current().IsEOF() {
|
||||
if p.IsFromClauseEnd() {
|
||||
break
|
||||
}
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.Current().IsKeyword("JOIN") || p.IsJoinStart() {
|
||||
join := p.parseJoinClause()
|
||||
if join != nil {
|
||||
stmt.Joins = append(stmt.Joins, *join)
|
||||
}
|
||||
continue
|
||||
}
|
||||
tableRef := p.parseTableRef()
|
||||
if tableRef.Name != "" {
|
||||
stmt.From = append(stmt.From, tableRef)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseTableRef() sqlstmt.TableRef {
|
||||
if p.Current().Value == "(" {
|
||||
start := p.Pos
|
||||
p.SkipParentheses()
|
||||
alias := ""
|
||||
if p.Current().IsKeyword("AS") {
|
||||
p.Consume()
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
alias = p.Unquote(p.Consume().Value)
|
||||
}
|
||||
return sqlstmt.TableRef{
|
||||
Name: p.TextFrom(start),
|
||||
Alias: alias,
|
||||
}
|
||||
}
|
||||
|
||||
if p.Current().Type != tokenizer.TokenIdentifier && p.Current().Type != tokenizer.TokenString {
|
||||
return sqlstmt.TableRef{}
|
||||
}
|
||||
|
||||
ref := sqlstmt.TableRef{}
|
||||
part1 := p.Consume().Value
|
||||
|
||||
if p.Current().Value == "." {
|
||||
p.Consume()
|
||||
if p.Current().Type == tokenizer.TokenIdentifier || p.Current().Type == tokenizer.TokenString {
|
||||
part2 := p.Consume().Value
|
||||
ref.Schema = p.Unquote(part1)
|
||||
ref.Name = p.Unquote(part2)
|
||||
} else {
|
||||
ref.Name = p.Unquote(part1)
|
||||
}
|
||||
} else {
|
||||
ref.Name = p.Unquote(part1)
|
||||
}
|
||||
|
||||
if p.Current().IsKeyword("AS") {
|
||||
p.Consume()
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
ref.Alias = p.Unquote(p.Consume().Value)
|
||||
}
|
||||
|
||||
return ref
|
||||
}
|
||||
|
||||
func (p *Parser) parseJoinClause() *sqlstmt.JoinClause {
|
||||
start := p.Pos
|
||||
joinType := sqlstmt.JoinKindInner
|
||||
if p.Current().IsKeyword("LEFT") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("OUTER") {
|
||||
p.Consume()
|
||||
}
|
||||
joinType = sqlstmt.JoinKindLeft
|
||||
} else if p.Current().IsKeyword("RIGHT") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("OUTER") {
|
||||
p.Consume()
|
||||
}
|
||||
joinType = sqlstmt.JoinKindRight
|
||||
} else if p.Current().IsKeyword("FULL") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("OUTER") {
|
||||
p.Consume()
|
||||
}
|
||||
joinType = sqlstmt.JoinKindFull
|
||||
} else if p.Current().IsKeyword("CROSS") {
|
||||
p.Consume()
|
||||
joinType = sqlstmt.JoinKindCross
|
||||
} else if p.Current().IsKeyword("NATURAL") {
|
||||
p.Consume()
|
||||
joinType = sqlstmt.JoinKindNatural
|
||||
} else if p.Current().IsKeyword("INNER") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
if !p.Current().IsKeyword("JOIN") {
|
||||
p.Pos = start
|
||||
return nil
|
||||
}
|
||||
p.Consume()
|
||||
|
||||
tableRef := p.parseTableRef()
|
||||
if tableRef.Name == "" {
|
||||
p.Pos = start
|
||||
return nil
|
||||
}
|
||||
|
||||
var onExpr *sqlstmt.Expr
|
||||
if p.Current().IsKeyword("ON") {
|
||||
p.Consume()
|
||||
onStart := p.Pos
|
||||
p.SkipExpr()
|
||||
onExpr = &sqlstmt.Expr{Text: p.TextFromExclusive(onStart)}
|
||||
} else if p.Current().IsKeyword("USING") {
|
||||
p.Consume()
|
||||
if p.Current().Value == "(" {
|
||||
p.SkipParentheses()
|
||||
}
|
||||
}
|
||||
|
||||
return &sqlstmt.JoinClause{
|
||||
Kind: joinType,
|
||||
Table: tableRef,
|
||||
On: onExpr,
|
||||
Text: p.TextFrom(start),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseOrderBy(stmt *sqlstmt.SelectStmt) {
|
||||
for !p.Current().IsEOF() {
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.IsExprEnd() || p.Current().Value == ";" {
|
||||
break
|
||||
}
|
||||
itemStart := p.Pos
|
||||
p.skipOrderByItem()
|
||||
itemText := base.TrimTrailingComma(p.TextFromExclusive(itemStart))
|
||||
desc := false
|
||||
upper := strings.ToUpper(itemText)
|
||||
if strings.HasSuffix(upper, " DESC") {
|
||||
desc = true
|
||||
itemText = strings.TrimSpace(itemText[:len(itemText)-5])
|
||||
} else if strings.HasSuffix(upper, " ASC") {
|
||||
itemText = strings.TrimSpace(itemText[:len(itemText)-4])
|
||||
}
|
||||
stmt.OrderBy = append(stmt.OrderBy, sqlstmt.OrderByItem{
|
||||
Text: itemText,
|
||||
Desc: desc,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) skipOrderByItem() {
|
||||
for !p.Current().IsEOF() {
|
||||
if p.Current().Value == "," {
|
||||
break
|
||||
}
|
||||
if p.IsExprEnd() || p.Current().Value == ";" {
|
||||
break
|
||||
}
|
||||
if p.Current().Value == "(" {
|
||||
p.SkipParentheses()
|
||||
continue
|
||||
}
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- LIMIT 解析 ----------
|
||||
|
||||
func (p *Parser) parseLimit() *sqlstmt.Limit {
|
||||
if !p.Current().IsKeyword("LIMIT") {
|
||||
return nil
|
||||
}
|
||||
start := p.Pos
|
||||
p.Consume()
|
||||
rowCount := 0
|
||||
offset := 0
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
rowCount = p.ParseInt(p.Consume().Value)
|
||||
}
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
offset = rowCount
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
rowCount = p.ParseInt(p.Consume().Value)
|
||||
}
|
||||
}
|
||||
if p.Current().IsKeyword("OFFSET") {
|
||||
p.Consume()
|
||||
if p.Current().Type == tokenizer.TokenNumber {
|
||||
offset = p.ParseInt(p.Consume().Value)
|
||||
}
|
||||
}
|
||||
return &sqlstmt.Limit{
|
||||
Text: p.TextFrom(start),
|
||||
Count: rowCount,
|
||||
Offset: offset,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- INSERT 解析 ----------
|
||||
|
||||
func (p *Parser) parseInsert() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("INTO") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
tableRef := p.parseTableRef()
|
||||
columns := []string{}
|
||||
|
||||
if p.Current().Value == "(" {
|
||||
p.Consume()
|
||||
for !p.Current().IsEOF() && p.Current().Value != ")" {
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.Current().Type == tokenizer.TokenIdentifier {
|
||||
columns = append(columns, p.Unquote(p.Consume().Value))
|
||||
} else {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
if p.Current().Value == ")" {
|
||||
p.Consume()
|
||||
}
|
||||
}
|
||||
|
||||
valuesStart := p.Pos
|
||||
if p.Current().IsKeyword("VALUES") {
|
||||
p.Consume()
|
||||
for !p.Current().IsEOF() && !p.IsExprEnd() {
|
||||
if p.Current().Value == "(" {
|
||||
p.SkipParentheses()
|
||||
continue
|
||||
}
|
||||
p.Consume()
|
||||
}
|
||||
} else if p.Current().IsKeyword("SELECT") {
|
||||
p.parseSelect()
|
||||
}
|
||||
|
||||
return &sqlstmt.InsertStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Table: tableRef,
|
||||
Columns: columns,
|
||||
Values: p.TextFrom(valuesStart),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- UPDATE 解析 ----------
|
||||
|
||||
func (p *Parser) parseUpdate() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.Consume()
|
||||
|
||||
tables := make([]sqlstmt.TableRef, 0)
|
||||
for !p.Current().IsEOF() {
|
||||
if p.Current().IsKeyword("SET") || p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
|
||||
break
|
||||
}
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.Current().IsKeyword("JOIN") || p.IsJoinStart() {
|
||||
p.parseJoinClause()
|
||||
continue
|
||||
}
|
||||
tableRef := p.parseTableRef()
|
||||
if tableRef.Name != "" {
|
||||
tables = append(tables, tableRef)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assignments := make([]sqlstmt.Assignment, 0)
|
||||
if p.Current().IsKeyword("SET") {
|
||||
p.Consume()
|
||||
for !p.Current().IsEOF() {
|
||||
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
|
||||
break
|
||||
}
|
||||
assign := p.parseAssignment()
|
||||
if assign != nil {
|
||||
assignments = append(assignments, *assign)
|
||||
}
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var where *sqlstmt.Expr
|
||||
if p.Current().IsKeyword("WHERE") {
|
||||
p.Consume()
|
||||
whereStart := p.Pos
|
||||
p.SkipExpr()
|
||||
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
|
||||
}
|
||||
|
||||
// MySQL UPDATE 支持 ORDER BY, LIMIT
|
||||
if p.Current().IsKeyword("ORDER") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
p.SkipOrderByExpr()
|
||||
}
|
||||
if p.Current().IsKeyword("LIMIT") {
|
||||
p.parseLimit()
|
||||
}
|
||||
|
||||
return &sqlstmt.UpdateStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Tables: tables,
|
||||
Set: assignments,
|
||||
Where: where,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseAssignment() *sqlstmt.Assignment {
|
||||
start := p.Pos
|
||||
colText := ""
|
||||
for !p.Current().IsEOF() && p.Current().Value != "=" && p.Current().Value != "," &&
|
||||
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
|
||||
colText += p.Consume().Value
|
||||
}
|
||||
if p.Current().Value != "=" {
|
||||
p.Pos = start
|
||||
return nil
|
||||
}
|
||||
p.Consume()
|
||||
|
||||
valStart := p.Pos
|
||||
for !p.Current().IsEOF() && p.Current().Value != "," &&
|
||||
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
|
||||
if p.Current().Value == "(" {
|
||||
p.SkipParentheses()
|
||||
continue
|
||||
}
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
return &sqlstmt.Assignment{
|
||||
Column: p.Unquote(strings.TrimSpace(colText)),
|
||||
Value: &sqlstmt.Expr{Text: p.TextFromExclusive(valStart)},
|
||||
Text: p.TextFrom(start),
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- DELETE 解析 ----------
|
||||
|
||||
func (p *Parser) parseDelete() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("FROM") {
|
||||
p.Consume()
|
||||
}
|
||||
|
||||
tables := make([]sqlstmt.TableRef, 0)
|
||||
for !p.Current().IsEOF() {
|
||||
if p.Current().IsKeyword("WHERE") || p.Current().IsKeyword("USING") ||
|
||||
p.Current().IsKeyword("ORDER") || p.Current().IsKeyword("LIMIT") ||
|
||||
p.Current().Value == ";" {
|
||||
break
|
||||
}
|
||||
if p.Current().Value == "," {
|
||||
p.Consume()
|
||||
continue
|
||||
}
|
||||
tableRef := p.parseTableRef()
|
||||
if tableRef.Name != "" {
|
||||
tables = append(tables, tableRef)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
var where *sqlstmt.Expr
|
||||
if p.Current().IsKeyword("WHERE") {
|
||||
p.Consume()
|
||||
whereStart := p.Pos
|
||||
p.SkipExpr()
|
||||
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
|
||||
}
|
||||
|
||||
// MySQL DELETE 支持 ORDER BY, LIMIT
|
||||
if p.Current().IsKeyword("ORDER") {
|
||||
p.Consume()
|
||||
if p.Current().IsKeyword("BY") {
|
||||
p.Consume()
|
||||
}
|
||||
p.SkipOrderByExpr()
|
||||
}
|
||||
if p.Current().IsKeyword("LIMIT") {
|
||||
p.parseLimit()
|
||||
}
|
||||
|
||||
return &sqlstmt.DeleteStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
Tables: tables,
|
||||
Where: where,
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- DDL 解析 ----------
|
||||
|
||||
func (p *Parser) parseCreate() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "CREATE",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseDrop() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "DROP",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseAlter() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "ALTER",
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Parser) parseGenericDdl() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.DdlStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
DdlKind: "DDL",
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- SHOW 解析(MySQL 特有)----------
|
||||
|
||||
func (p *Parser) parseShow() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.SelectStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- WITH 解析 ----------
|
||||
|
||||
func (p *Parser) parseWith() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.WithStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- 通用语句解析 ----------
|
||||
|
||||
func (p *Parser) parseGenericStmt() sqlstmt.Stmt {
|
||||
start := p.Pos
|
||||
p.SkipToNextStatement()
|
||||
return &sqlstmt.OtherStmt{
|
||||
Base: sqlstmt.Base{Text: p.TextFrom(start)},
|
||||
}
|
||||
}
|
||||
211
server/internal/db/dbm/sqlparser/mysql/parser_dml_test.go
Normal file
211
server/internal/db/dbm/sqlparser/mysql/parser_dml_test.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== INSERT 测试 ==========
|
||||
|
||||
func TestInsertBasic(t *testing.T) {
|
||||
sql := "INSERT INTO users (name, age) VALUES ('John', 30)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt, ok := stmt.(*sqlstmt.InsertStmt)
|
||||
if !ok {
|
||||
t.Fatal("expected InsertStmt")
|
||||
}
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, insertStmt.GetText())
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertMultipleRows(t *testing.T) {
|
||||
sql := "INSERT INTO users (name, age) VALUES ('John', 30), ('Jane', 25), ('Bob', 35)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
insertStmt := stmt.(*sqlstmt.InsertStmt)
|
||||
if insertStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== UPDATE 测试 ==========
|
||||
|
||||
func TestUpdateBasic(t *testing.T) {
|
||||
sql := "UPDATE users SET name = 'John' WHERE id = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
|
||||
if updateStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, updateStmt.GetText())
|
||||
}
|
||||
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users'")
|
||||
}
|
||||
if updateStmt.Where == nil || updateStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateMultipleColumns(t *testing.T) {
|
||||
sql := "UPDATE users SET name = 'John', age = 30, email = 'john@example.com' WHERE id = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
if updateStmt.Where == nil || updateStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateWithOrderByLimit(t *testing.T) {
|
||||
sql := "UPDATE users SET status = 0 WHERE status = 1 ORDER BY id LIMIT 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
updateStmt := stmt.(*sqlstmt.UpdateStmt)
|
||||
|
||||
if updateStmt.Where == nil || updateStmt.Where.Text != "status = 1" {
|
||||
t.Errorf("expected WHERE='status = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== DELETE 测试 ==========
|
||||
|
||||
func TestDeleteBasic(t *testing.T) {
|
||||
sql := "DELETE FROM users WHERE id = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
|
||||
if deleteStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, deleteStmt.GetText())
|
||||
}
|
||||
if len(deleteStmt.Tables) != 1 || deleteStmt.Tables[0].Name != "users" {
|
||||
t.Errorf("expected table='users'")
|
||||
}
|
||||
if deleteStmt.Where == nil || deleteStmt.Where.Text != "id = 1" {
|
||||
t.Errorf("expected WHERE='id = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteMultipleTables(t *testing.T) {
|
||||
sql := "DELETE t1, t2 FROM users t1 JOIN orders t2 ON t1.id = t2.user_id WHERE t1.status = 0"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
|
||||
if len(deleteStmt.Tables) < 1 {
|
||||
t.Errorf("expected at least 1 table")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteWithOrderByLimit(t *testing.T) {
|
||||
sql := "DELETE FROM users WHERE status = 0 ORDER BY id LIMIT 100"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
|
||||
|
||||
if deleteStmt.Where == nil || deleteStmt.Where.Text != "status = 0" {
|
||||
t.Errorf("expected WHERE='status = 0'")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== DDL 测试 ==========
|
||||
|
||||
func TestDDLCreate(t *testing.T) {
|
||||
sql := "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(50), age INT)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
|
||||
if ddlStmt.DdlKind != "CREATE" {
|
||||
t.Errorf("expected DdlKind='CREATE', got '%s'", ddlStmt.DdlKind)
|
||||
}
|
||||
if ddlStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDDLDrop(t *testing.T) {
|
||||
sql := "DROP TABLE users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
|
||||
if ddlStmt.DdlKind != "DROP" {
|
||||
t.Errorf("expected DdlKind='DROP', got '%s'", ddlStmt.DdlKind)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDDLAlter(t *testing.T) {
|
||||
sql := "ALTER TABLE users ADD COLUMN email VARCHAR(255)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
ddlStmt := stmt.(*sqlstmt.DdlStmt)
|
||||
|
||||
if ddlStmt.DdlKind != "ALTER" {
|
||||
t.Errorf("expected DdlKind='ALTER', got '%s'", ddlStmt.DdlKind)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDDLTruncate(t *testing.T) {
|
||||
sql := "TRUNCATE TABLE users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
// TRUNCATE 可能返回 DdlStmt 或其他类型
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
t.Logf("TRUNCATE stmt type: %T", stmt)
|
||||
t.Logf("TRUNCATE text: %s", stmt.GetText())
|
||||
}
|
||||
322
server/internal/db/dbm/sqlparser/mysql/parser_pagination_test.go
Normal file
322
server/internal/db/dbm/sqlparser/mysql/parser_pagination_test.go
Normal file
@@ -0,0 +1,322 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== 简单分页测试 ==========
|
||||
|
||||
func TestMysqlPaginationSimple(t *testing.T) {
|
||||
sql := "SELECT * FROM users LIMIT 0, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
|
||||
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
|
||||
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMysqlPaginationWithOffset(t *testing.T) {
|
||||
sql := "SELECT id, name FROM users LIMIT 20, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 20 {
|
||||
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMysqlPaginationKeywordOffset(t *testing.T) {
|
||||
// MySQL 8.0+ 支持 LIMIT count OFFSET offset
|
||||
sql := "SELECT * FROM products LIMIT 15 OFFSET 30"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
|
||||
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
|
||||
|
||||
if selectStmt.Limit.Offset != 30 {
|
||||
t.Errorf("expected offset=30, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 15 {
|
||||
t.Errorf("expected count=15, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMysqlPaginationOnlyLimit(t *testing.T) {
|
||||
sql := "SELECT * FROM orders LIMIT 50"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 50 {
|
||||
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 复杂分页测试 ==========
|
||||
|
||||
func TestMysqlPaginationWithWhere(t *testing.T) {
|
||||
sql := "SELECT id, name, email FROM users WHERE status = 1 AND age > 18 LIMIT 10, 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMysqlPaginationWithOrderBy(t *testing.T) {
|
||||
sql := "SELECT * FROM users ORDER BY created_at DESC, id ASC LIMIT 0, 100"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.OrderBy) != 2 {
|
||||
t.Fatalf("expected 2 ORDER BY clauses")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 100 {
|
||||
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMysqlPaginationWithWhereOrderBy(t *testing.T) {
|
||||
sql := "SELECT id, name FROM users WHERE status = 'active' ORDER BY score DESC LIMIT 50, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 'active'" {
|
||||
t.Errorf("expected WHERE='status = 'active''")
|
||||
}
|
||||
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "score" {
|
||||
t.Errorf("expected ORDER BY score")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 50 {
|
||||
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMysqlPaginationWithJoin(t *testing.T) {
|
||||
sql := "SELECT u.id, u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1 ORDER BY o.created_at DESC LIMIT 100, 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatal("expected 1 JOIN")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 100 {
|
||||
t.Errorf("expected offset=100, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== UNION 分页测试 ==========
|
||||
|
||||
func TestMysqlPaginationWithUnion(t *testing.T) {
|
||||
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC LIMIT 0, 50"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Unions) != 1 {
|
||||
t.Fatal("expected 1 UNION")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause for UNION")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 0 {
|
||||
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 50 {
|
||||
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMysqlPaginationUnionAll(t *testing.T) {
|
||||
sql := "SELECT name FROM products UNION ALL SELECT name FROM services LIMIT 10, 30"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Unions) != 1 {
|
||||
t.Fatal("expected 1 UNION ALL")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 30 {
|
||||
t.Errorf("expected count=30, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 子查询分页测试 ==========
|
||||
|
||||
func TestMysqlPaginationWithSubquery(t *testing.T) {
|
||||
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1 ORDER BY id) AS tmp LIMIT 20, 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatal("expected 1 FROM table")
|
||||
}
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 20 {
|
||||
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMysqlNestedPagination(t *testing.T) {
|
||||
// 嵌套分页查询
|
||||
sql := "SELECT * FROM (SELECT * FROM users LIMIT 0, 100) AS tmp LIMIT 10, 5"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected outer LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected outer offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 5 {
|
||||
t.Errorf("expected outer count=5, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== 大数据量分页测试 ==========
|
||||
|
||||
func TestMysqlLargeOffsetPagination(t *testing.T) {
|
||||
sql := "SELECT * FROM logs ORDER BY id LIMIT 1000000, 100"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT clause")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 1000000 {
|
||||
t.Errorf("expected offset=1000000, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 100 {
|
||||
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
345
server/internal/db/dbm/sqlparser/mysql/parser_select_test.go
Normal file
345
server/internal/db/dbm/sqlparser/mysql/parser_select_test.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
// ========== 基础 SELECT 测试 ==========
|
||||
|
||||
func TestSelectBasic(t *testing.T) {
|
||||
sql := "-- 测试查询 sql \n SELECT `id`, name FROM users WHERE status = 1 ORDER BY id DESC LIMIT 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
|
||||
}
|
||||
|
||||
if len(selectStmt.Items) != 2 {
|
||||
t.Fatalf("expected 2 items, got %d", len(selectStmt.Items))
|
||||
}
|
||||
if selectStmt.Items[0].Text != "id" {
|
||||
t.Errorf("expected item[0]='id', got '%s'", selectStmt.Items[0].Text)
|
||||
}
|
||||
if selectStmt.Items[1].Text != "name" {
|
||||
t.Errorf("expected item[1]='name', got '%s'", selectStmt.Items[1].Text)
|
||||
}
|
||||
|
||||
if len(selectStmt.From) != 1 || selectStmt.From[0].Name != "users" {
|
||||
t.Errorf("expected table='users'")
|
||||
}
|
||||
|
||||
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 1" {
|
||||
t.Errorf("expected WHERE='status = 1'")
|
||||
}
|
||||
|
||||
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "id" || !selectStmt.OrderBy[0].Desc {
|
||||
t.Errorf("expected ORDER BY id DESC")
|
||||
}
|
||||
|
||||
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT 10" || selectStmt.Limit.Count != 10 {
|
||||
t.Errorf("expected LIMIT 10")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectDistinct(t *testing.T) {
|
||||
sql := "SELECT DISTINCT id, name FROM users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if !selectStmt.Distinct {
|
||||
t.Error("expected Distinct=true")
|
||||
}
|
||||
if len(selectStmt.Items) != 2 {
|
||||
t.Fatalf("expected 2 items")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectStar(t *testing.T) {
|
||||
sql := "SELECT * FROM users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Items) != 1 || selectStmt.Items[0].Text != "*" {
|
||||
t.Errorf("expected SELECT *")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectTableStar(t *testing.T) {
|
||||
sql := "SELECT u.*, o.amount FROM users u, orders o"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Items) != 2 {
|
||||
t.Fatalf("expected 2 items")
|
||||
}
|
||||
if selectStmt.Items[0].Text != "u.*" {
|
||||
t.Errorf("expected item[0]='u.*', got '%s'", selectStmt.Items[0].Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectWithAlias(t *testing.T) {
|
||||
sql := "SELECT id AS user_id, name AS user_name FROM users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Items) != 2 {
|
||||
t.Fatalf("expected 2 items")
|
||||
}
|
||||
if selectStmt.Items[0].Alias != "user_id" {
|
||||
t.Errorf("expected alias='user_id', got '%s'", selectStmt.Items[0].Alias)
|
||||
}
|
||||
if selectStmt.Items[1].Alias != "user_name" {
|
||||
t.Errorf("expected alias='user_name', got '%s'", selectStmt.Items[1].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectMultipleTables(t *testing.T) {
|
||||
sql := "SELECT * FROM users, orders, products"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.From) != 3 {
|
||||
t.Fatalf("expected 3 tables, got %d", len(selectStmt.From))
|
||||
}
|
||||
if selectStmt.From[0].Name != "users" || selectStmt.From[1].Name != "orders" || selectStmt.From[2].Name != "products" {
|
||||
t.Errorf("expected tables: users, orders, products")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== JOIN 测试 ==========
|
||||
|
||||
func TestSelectWithJoin(t *testing.T) {
|
||||
sql := "SELECT u.id, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatalf("expected 1 join")
|
||||
}
|
||||
if selectStmt.Joins[0].Table.Name != "orders" {
|
||||
t.Errorf("expected join table='orders'")
|
||||
}
|
||||
if selectStmt.Joins[0].On == nil || selectStmt.Joins[0].On.Text != "u.id = o.user_id" {
|
||||
t.Errorf("expected ON='u.id = o.user_id'")
|
||||
}
|
||||
if selectStmt.Where == nil || selectStmt.Where.Text != "u.status = 1" {
|
||||
t.Errorf("expected WHERE='u.status = 1'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectInnerJoin(t *testing.T) {
|
||||
sql := "SELECT * FROM users u INNER JOIN orders o ON u.id = o.user_id"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatalf("expected 1 join")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectRightJoin(t *testing.T) {
|
||||
sql := "SELECT * FROM users u RIGHT JOIN orders o ON u.id = o.user_id"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatalf("expected 1 join")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectCrossJoin(t *testing.T) {
|
||||
sql := "SELECT * FROM users CROSS JOIN roles"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Joins) != 1 {
|
||||
t.Fatalf("expected 1 join")
|
||||
}
|
||||
}
|
||||
|
||||
// ========== WHERE 测试 ==========
|
||||
|
||||
func TestSelectComplexWhere(t *testing.T) {
|
||||
sql := "SELECT * FROM users WHERE status = 1 AND (age > 18 OR role = 'admin') ORDER BY name"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE")
|
||||
}
|
||||
expectedWhere := "status = 1 AND (age > 18 OR role = 'admin')"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== LIMIT/OFFSET 测试 ==========
|
||||
|
||||
func TestSelectLimitOffset(t *testing.T) {
|
||||
sql := "SELECT * FROM users LIMIT 10, 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if selectStmt.Limit == nil {
|
||||
t.Fatal("expected LIMIT")
|
||||
}
|
||||
if selectStmt.Limit.Text != "LIMIT 10, 20" {
|
||||
t.Errorf("expected LIMIT text='LIMIT 10, 20'")
|
||||
}
|
||||
if selectStmt.Limit.Offset != 10 {
|
||||
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
|
||||
}
|
||||
if selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
|
||||
}
|
||||
}
|
||||
|
||||
// ========== UNION 测试 ==========
|
||||
|
||||
func TestUnionSelect(t *testing.T) {
|
||||
sql := "SELECT 1 UNION SELECT 2 UNION ALL SELECT 3"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Unions) != 2 {
|
||||
t.Fatalf("expected 2 unions, got %d", len(selectStmt.Unions))
|
||||
}
|
||||
if selectStmt.Unions[0].All {
|
||||
t.Error("expected first union not ALL")
|
||||
}
|
||||
if !selectStmt.Unions[1].All {
|
||||
t.Error("expected second union ALL")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnionWithOrderByLimit(t *testing.T) {
|
||||
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC LIMIT 20"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "1" || !selectStmt.OrderBy[0].Desc {
|
||||
t.Errorf("expected ORDER BY 1 DESC")
|
||||
}
|
||||
if selectStmt.Limit == nil || selectStmt.Limit.Count != 20 {
|
||||
t.Errorf("expected LIMIT 20")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnionMultiple(t *testing.T) {
|
||||
sql := "SELECT id FROM t1 UNION SELECT id FROM t2 UNION ALL SELECT id FROM t3 UNION SELECT id FROM t4"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
if len(selectStmt.Unions) != 3 {
|
||||
t.Fatalf("expected 3 unions, got %d", len(selectStmt.Unions))
|
||||
}
|
||||
}
|
||||
|
||||
// ========== SHOW 测试 ==========
|
||||
|
||||
func TestShow(t *testing.T) {
|
||||
sql := "SHOW VARIABLES LIKE 'max_connections'"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
if stmt.GetText() != sql {
|
||||
t.Errorf("expected text='%s'", sql)
|
||||
}
|
||||
}
|
||||
|
||||
func TestShowTables(t *testing.T) {
|
||||
sql := "SHOW TABLES"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
}
|
||||
292
server/internal/db/dbm/sqlparser/mysql/parser_subquery_test.go
Normal file
292
server/internal/db/dbm/sqlparser/mysql/parser_subquery_test.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
)
|
||||
|
||||
func TestSubqueryInFrom(t *testing.T) {
|
||||
// FROM 子句中的子查询
|
||||
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1) AS u WHERE u.id > 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
if stmt == nil {
|
||||
t.Fatalf("expected stmt not nil")
|
||||
}
|
||||
|
||||
selectStmt, ok := stmt.(*sqlstmt.SelectStmt)
|
||||
if !ok {
|
||||
t.Fatal("expected SelectStmt")
|
||||
}
|
||||
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("FROM 表数量: %d", len(selectStmt.From))
|
||||
|
||||
// 验证 FROM 子句
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 table in FROM, got %d", len(selectStmt.From))
|
||||
}
|
||||
|
||||
// FROM 应该是子查询
|
||||
fromText := selectStmt.From[0].Name
|
||||
t.Logf("FROM[0] Name: %s", fromText)
|
||||
t.Logf("FROM[0] Alias: %s", selectStmt.From[0].Alias)
|
||||
|
||||
if selectStmt.From[0].Alias != "u" {
|
||||
t.Errorf("expected alias='u', got '%s'", selectStmt.From[0].Alias)
|
||||
}
|
||||
|
||||
// 验证外层 WHERE
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected outer WHERE clause")
|
||||
}
|
||||
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
|
||||
if selectStmt.Where.Text != "u.id > 10" {
|
||||
t.Errorf("expected outer WHERE text='u.id > 10', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubqueryInWhere(t *testing.T) {
|
||||
// WHERE 子句中的子查询(IN)
|
||||
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders WHERE amount > 100)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
// WHERE 应该包含整个 IN 子句
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "id IN (SELECT user_id FROM orders WHERE amount > 100)"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubqueryInSelect(t *testing.T) {
|
||||
// SELECT 列中的子查询
|
||||
sql := "SELECT id, name, (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) AS order_count FROM users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
|
||||
|
||||
for i, item := range selectStmt.Items {
|
||||
t.Logf(" Item[%d]: Text='%s', ColumnName='%s', Alias='%s'", i, item.Text, item.ColumnName, item.Alias)
|
||||
}
|
||||
|
||||
// 应该有 3 个 SELECT 项
|
||||
if len(selectStmt.Items) != 3 {
|
||||
t.Fatalf("expected 3 items, got %d", len(selectStmt.Items))
|
||||
}
|
||||
|
||||
// 第三个项应该是子查询
|
||||
if selectStmt.Items[2].Alias != "order_count" {
|
||||
t.Errorf("expected item[2] alias='order_count', got '%s'", selectStmt.Items[2].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedSubquery(t *testing.T) {
|
||||
// 嵌套子查询
|
||||
sql := "SELECT * FROM (SELECT * FROM (SELECT id FROM users WHERE status = 1) AS inner_u) AS outer_u WHERE outer_u.id > 5"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
// 验证外层 WHERE
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected outer WHERE")
|
||||
}
|
||||
if selectStmt.Where.Text != "outer_u.id > 5" {
|
||||
t.Errorf("expected WHERE text='outer_u.id > 5', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubqueryWithUnion(t *testing.T) {
|
||||
// 子查询包含 UNION
|
||||
sql := "SELECT * FROM (SELECT id FROM users UNION SELECT id FROM admins) AS all_users WHERE all_users.id > 10"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("FROM 数量: %d", len(selectStmt.From))
|
||||
|
||||
// 验证外层 WHERE
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
if selectStmt.Where.Text != "all_users.id > 10" {
|
||||
t.Errorf("expected WHERE text='all_users.id > 10', got '%s'", selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorrelatedSubquery(t *testing.T) {
|
||||
// 相关子查询
|
||||
sql := "SELECT u.id, u.name FROM users u WHERE u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
// WHERE 应该包含相关子查询
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubqueryWithExists(t *testing.T) {
|
||||
// EXISTS 子查询
|
||||
sql := "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleSubqueries(t *testing.T) {
|
||||
// 多个子查询
|
||||
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
expectedWhere := "id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
|
||||
if selectStmt.Where.Text != expectedWhere {
|
||||
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSubqueryWithLimit(t *testing.T) {
|
||||
// 子查询包含 LIMIT
|
||||
sql := "SELECT * FROM (SELECT id, name FROM users ORDER BY created_at DESC LIMIT 10) AS top_users"
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("FROM 数量: %d", len(selectStmt.From))
|
||||
|
||||
if len(selectStmt.From) != 1 {
|
||||
t.Fatalf("expected 1 table, got %d", len(selectStmt.From))
|
||||
}
|
||||
|
||||
if selectStmt.From[0].Alias != "top_users" {
|
||||
t.Errorf("expected alias='top_users', got '%s'", selectStmt.From[0].Alias)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComplexSubquery(t *testing.T) {
|
||||
// 复杂子查询场景
|
||||
sql := `SELECT
|
||||
u.id,
|
||||
u.name,
|
||||
(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS order_count,
|
||||
(SELECT SUM(amount) FROM orders o WHERE o.user_id = u.id AND o.status = 'completed') AS total_amount
|
||||
FROM users u
|
||||
WHERE u.status = 1
|
||||
AND u.id IN (SELECT user_id FROM user_groups WHERE group_id = 5)
|
||||
ORDER BY u.name`
|
||||
|
||||
parser := NewParser(sql)
|
||||
stmt, err := parser.Parse()
|
||||
if err != nil {
|
||||
t.Fatalf("parse error: %v", err)
|
||||
}
|
||||
|
||||
selectStmt := stmt.(*sqlstmt.SelectStmt)
|
||||
|
||||
t.Logf("完整文本: %s", selectStmt.GetText())
|
||||
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
|
||||
|
||||
// 验证 SELECT 项
|
||||
if len(selectStmt.Items) != 4 {
|
||||
t.Fatalf("expected 4 items, got %d", len(selectStmt.Items))
|
||||
}
|
||||
|
||||
// 验证各个项
|
||||
t.Logf("Item[0]: %s", selectStmt.Items[0].Text)
|
||||
t.Logf("Item[1]: %s", selectStmt.Items[1].Text)
|
||||
t.Logf("Item[2]: %s (alias: %s)", selectStmt.Items[2].Text, selectStmt.Items[2].Alias)
|
||||
t.Logf("Item[3]: %s (alias: %s)", selectStmt.Items[3].Text, selectStmt.Items[3].Alias)
|
||||
|
||||
// 验证 WHERE
|
||||
if selectStmt.Where == nil {
|
||||
t.Fatal("expected WHERE clause")
|
||||
}
|
||||
t.Logf("WHERE: %s", selectStmt.Where.Text)
|
||||
|
||||
// 验证 ORDER BY
|
||||
if len(selectStmt.OrderBy) != 1 {
|
||||
t.Fatalf("expected 1 order by, got %d", len(selectStmt.OrderBy))
|
||||
}
|
||||
if selectStmt.OrderBy[0].Text != "u.name" {
|
||||
t.Errorf("expected order by text='u.name', got '%s'", selectStmt.OrderBy[0].Text)
|
||||
}
|
||||
}
|
||||
@@ -1,984 +0,0 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
mysqlparser "mayfly-go/internal/db/dbm/sqlparser/mysql/antlr4"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
type MysqlVisitor struct {
|
||||
*mysqlparser.BaseMySqlParserVisitor
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitRoot(ctx *mysqlparser.RootContext) interface{} {
|
||||
stms := ctx.SqlStatements()
|
||||
if stms != nil {
|
||||
return stms.Accept(v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSqlStatements(ctx *mysqlparser.SqlStatementsContext) interface{} {
|
||||
allSqlStatement := ctx.AllSqlStatement()
|
||||
stmts := make([]sqlstmt.Stmt, 0)
|
||||
for _, sqlStatement := range allSqlStatement {
|
||||
stmts = append(stmts, sqlStatement.Accept(v).(sqlstmt.Stmt))
|
||||
}
|
||||
return stmts
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSqlStatement(ctx *mysqlparser.SqlStatementContext) interface{} {
|
||||
if c := ctx.DmlStatement(); c != nil {
|
||||
return ctx.DmlStatement().Accept(v)
|
||||
}
|
||||
if c := ctx.DdlStatement(); c != nil {
|
||||
return ctx.DdlStatement().Accept(v)
|
||||
}
|
||||
if c := ctx.AdministrationStatement(); c != nil {
|
||||
return c.Accept(v)
|
||||
}
|
||||
if c := ctx.UtilityStatement(); c != nil {
|
||||
return c.Accept(v)
|
||||
}
|
||||
|
||||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitEmptyStatement_(ctx *mysqlparser.EmptyStatement_Context) interface{} {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitDdlStatement(ctx *mysqlparser.DdlStatementContext) interface{} {
|
||||
ddlStmt := &sqlstmt.DdlStmt{}
|
||||
ddlStmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ddlStmt
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitDmlStatement(ctx *mysqlparser.DmlStatementContext) interface{} {
|
||||
if ssc := ctx.SelectStatement(); ssc != nil {
|
||||
return ssc.Accept(v)
|
||||
}
|
||||
if withStmt := ctx.WithStatement(); withStmt != nil {
|
||||
return withStmt.Accept(v)
|
||||
}
|
||||
if usc := ctx.UpdateStatement(); usc != nil {
|
||||
return usc.Accept(v)
|
||||
}
|
||||
if dsc := ctx.DeleteStatement(); dsc != nil {
|
||||
return dsc.Accept(v)
|
||||
}
|
||||
if isc := ctx.InsertStatement(); isc != nil {
|
||||
return isc.Accept(v)
|
||||
}
|
||||
|
||||
dmlStmt := sqlstmt.DmlStmt{}
|
||||
dmlStmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return dmlStmt
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitAdministrationStatement(ctx *mysqlparser.AdministrationStatementContext) interface{} {
|
||||
if ssc := ctx.ShowStatement(); ssc != nil {
|
||||
return ssc.Accept(v)
|
||||
}
|
||||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitUtilityStatement(ctx *mysqlparser.UtilityStatementContext) interface{} {
|
||||
if c := ctx.SimpleDescribeStatement(); c != nil {
|
||||
return c.Accept(v)
|
||||
}
|
||||
if c := ctx.FullDescribeStatement(); c != nil {
|
||||
return c.Accept(v)
|
||||
}
|
||||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitWithStatement(ctx *mysqlparser.WithStatementContext) interface{} {
|
||||
ort := new(sqlstmt.WithStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSimpleSelect(ctx *mysqlparser.SimpleSelectContext) interface{} {
|
||||
sss := new(sqlstmt.SimpleSelectStmt)
|
||||
sss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
sss.QuerySpecification = ctx.QuerySpecification().Accept(v).(*sqlstmt.QuerySpecification)
|
||||
return sss
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitUnionSelect(ctx *mysqlparser.UnionSelectContext) interface{} {
|
||||
uss := new(sqlstmt.UnionSelectStmt)
|
||||
uss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if lc := ctx.LimitClause(); lc != nil {
|
||||
uss.Limit = lc.Accept(v).(*sqlstmt.Limit)
|
||||
}
|
||||
|
||||
if ausc := ctx.AllUnionStatement(); ausc != nil {
|
||||
unionStmts := make([]*sqlstmt.UnionStmt, 0)
|
||||
for _, usc := range ausc {
|
||||
unionStmts = append(unionStmts, usc.Accept(v).(*sqlstmt.UnionStmt))
|
||||
}
|
||||
uss.UnionStmts = unionStmts
|
||||
}
|
||||
|
||||
if qsc := ctx.QuerySpecification(); qsc != nil {
|
||||
uss.QuerySpecification = qsc.Accept(v).(*sqlstmt.QuerySpecification)
|
||||
}
|
||||
if qscn := ctx.QuerySpecificationNointo(); qscn != nil {
|
||||
uss.QuerySpecification = qscn.Accept(v).(*sqlstmt.QuerySpecification)
|
||||
}
|
||||
if qec := ctx.QueryExpression(); qec != nil {
|
||||
uss.QueryExpr = qec.Accept(v).(*sqlstmt.QueryExpr)
|
||||
}
|
||||
if qenc := ctx.QueryExpressionNointo(); qenc != nil {
|
||||
uss.QueryExpr = qenc.Accept(v).(*sqlstmt.QueryExpr)
|
||||
}
|
||||
|
||||
if ui := ctx.UNION(); ui != nil {
|
||||
uss.UnionType = ui.GetText()
|
||||
}
|
||||
|
||||
return uss
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitParenthesisSelect(ctx *mysqlparser.ParenthesisSelectContext) interface{} {
|
||||
ps := new(sqlstmt.ParenthesisSelect)
|
||||
ps.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if qec := ctx.QueryExpression(); qec != nil {
|
||||
ps.QueryExpr = qec.Accept(v).(*sqlstmt.QueryExpr)
|
||||
}
|
||||
|
||||
return ps
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitUnionParenthesisSelect(ctx *mysqlparser.UnionParenthesisSelectContext) interface{} {
|
||||
ss := new(sqlstmt.SelectStmt)
|
||||
ss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ss
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitWithLateralStatement(ctx *mysqlparser.WithLateralStatementContext) interface{} {
|
||||
ss := new(sqlstmt.SelectStmt)
|
||||
ss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ss
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitUnionStatement(ctx *mysqlparser.UnionStatementContext) interface{} {
|
||||
us := new(sqlstmt.UnionStmt)
|
||||
us.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if qs := ctx.QuerySpecificationNointo(); qs != nil {
|
||||
us.QuerySpecification = qs.Accept(v).(*sqlstmt.QuerySpecification)
|
||||
}
|
||||
if qec := ctx.QueryExpressionNointo(); qec != nil {
|
||||
us.QueryExpr = qec.Accept(v).(*sqlstmt.QueryExpr)
|
||||
}
|
||||
|
||||
if ui := ctx.UNION(); ui != nil {
|
||||
us.UnionType = ui.GetText()
|
||||
}
|
||||
|
||||
return us
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitQuerySpecification(ctx *mysqlparser.QuerySpecificationContext) interface{} {
|
||||
qs := new(sqlstmt.QuerySpecification)
|
||||
qs.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
qs.SelectElements = ctx.SelectElements().Accept(v).(*sqlstmt.SelectElements)
|
||||
|
||||
if fromClause := ctx.FromClause(); fromClause != nil {
|
||||
where := fromClause.GetWhereExpr()
|
||||
if where != nil {
|
||||
qs.Where = v.GetExpr(where)
|
||||
}
|
||||
|
||||
tableSourcesCtx := fromClause.TableSources()
|
||||
if tableSourcesCtx != nil {
|
||||
tss := new(sqlstmt.TableSources)
|
||||
tss.Node = sqlstmt.NewNode(tableSourcesCtx.GetParser(), tableSourcesCtx)
|
||||
tss.TableSources = tableSourcesCtx.Accept(v).([]sqlstmt.ITableSource)
|
||||
qs.From = tss
|
||||
}
|
||||
}
|
||||
|
||||
if limitClause := ctx.LimitClause(); limitClause != nil {
|
||||
qs.Limit = limitClause.Accept(v).(*sqlstmt.Limit)
|
||||
}
|
||||
|
||||
return qs
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitQuerySpecificationNointo(ctx *mysqlparser.QuerySpecificationNointoContext) interface{} {
|
||||
qs := new(sqlstmt.QuerySpecification)
|
||||
qs.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
qs.SelectElements = ctx.SelectElements().Accept(v).(*sqlstmt.SelectElements)
|
||||
|
||||
if fromClause := ctx.FromClause(); fromClause != nil {
|
||||
where := fromClause.GetWhereExpr()
|
||||
if where != nil {
|
||||
qs.Where = v.GetExpr(where)
|
||||
}
|
||||
|
||||
tableSourcesCtx := fromClause.TableSources()
|
||||
if tableSourcesCtx != nil {
|
||||
tss := new(sqlstmt.TableSources)
|
||||
tss.Node = sqlstmt.NewNode(tableSourcesCtx.GetParser(), tableSourcesCtx)
|
||||
tss.TableSources = tableSourcesCtx.Accept(v).([]sqlstmt.ITableSource)
|
||||
qs.From = tss
|
||||
}
|
||||
}
|
||||
|
||||
if limitClause := ctx.LimitClause(); limitClause != nil {
|
||||
qs.Limit = limitClause.Accept(v).(*sqlstmt.Limit)
|
||||
}
|
||||
|
||||
return qs
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitQueryExpression(ctx *mysqlparser.QueryExpressionContext) interface{} {
|
||||
qe := new(sqlstmt.QueryExpr)
|
||||
qe.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if qec := ctx.QueryExpression(); qec != nil {
|
||||
qe.QueryExpr = qec.Accept(v).(*sqlstmt.QueryExpr)
|
||||
}
|
||||
|
||||
if qsc := ctx.QuerySpecification(); qsc != nil {
|
||||
qe.QuerySpecification = qsc.Accept(v).(*sqlstmt.QuerySpecification)
|
||||
}
|
||||
|
||||
return qe
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitQueryExpressionNointo(ctx *mysqlparser.QueryExpressionNointoContext) interface{} {
|
||||
qe := new(sqlstmt.QueryExpr)
|
||||
qe.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if qec := ctx.QueryExpressionNointo(); qec != nil {
|
||||
qe.QueryExpr = qec.Accept(v).(*sqlstmt.QueryExpr)
|
||||
}
|
||||
|
||||
if qsc := ctx.QuerySpecificationNointo(); qsc != nil {
|
||||
qe.QuerySpecification = qsc.Accept(v).(*sqlstmt.QuerySpecification)
|
||||
}
|
||||
|
||||
return qe
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSelectElements(ctx *mysqlparser.SelectElementsContext) interface{} {
|
||||
ses := new(sqlstmt.SelectElements)
|
||||
ses.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if ctx.STAR() != nil {
|
||||
ses.Star = ctx.STAR().GetText()
|
||||
}
|
||||
|
||||
eles := make([]sqlstmt.ISelectElement, 0)
|
||||
ase := ctx.AllSelectElement()
|
||||
for _, selectElement := range ase {
|
||||
eles = append(eles, selectElement.Accept(v).(sqlstmt.ISelectElement))
|
||||
}
|
||||
ses.Elements = eles
|
||||
|
||||
return ses
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSelectStarElement(ctx *mysqlparser.SelectStarElementContext) interface{} {
|
||||
sse := new(sqlstmt.SelectStarElement)
|
||||
sse.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
sse.FullId = ctx.FullId().GetText()
|
||||
return sse
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSelectColumnElement(ctx *mysqlparser.SelectColumnElementContext) interface{} {
|
||||
sce := new(sqlstmt.SelectColumnElement)
|
||||
sce.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
sce.ColumnName = ctx.FullColumnName().Accept(v).(*sqlstmt.ColumnName)
|
||||
if uid := ctx.Uid(); uid != nil {
|
||||
sce.Alias = uid.GetText()
|
||||
}
|
||||
return sce
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSelectFunctionElement(ctx *mysqlparser.SelectFunctionElementContext) interface{} {
|
||||
node := sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return node
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSelectExpressionElement(ctx *mysqlparser.SelectExpressionElementContext) interface{} {
|
||||
node := sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return node
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitTableSources(ctx *mysqlparser.TableSourcesContext) interface{} {
|
||||
tableSourcesCtx := ctx.AllTableSource()
|
||||
tableSources := make([]sqlstmt.ITableSource, 0)
|
||||
for _, tableSourceCtx := range tableSourcesCtx {
|
||||
tableSources = append(tableSources, tableSourceCtx.Accept(v).(sqlstmt.ITableSource))
|
||||
}
|
||||
return tableSources
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitTableSourceBase(ctx *mysqlparser.TableSourceBaseContext) interface{} {
|
||||
tsb := new(sqlstmt.TableSourceBase)
|
||||
tsb.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
tsb.TableSourceItem = v.GetTableSourceItem(ctx.TableSourceItem())
|
||||
tsb.JoinParts = v.GetJoinParts(ctx.AllJoinPart())
|
||||
return tsb
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitAtomTableItem(ctx *mysqlparser.AtomTableItemContext) interface{} {
|
||||
tableSourceItem := new(sqlstmt.AtomTableItem)
|
||||
tableSourceItem.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
tableSourceItem.TableName = ctx.TableName().Accept(v).(*sqlstmt.TableName)
|
||||
|
||||
if alias := ctx.GetAlias(); alias != nil {
|
||||
tableSourceItem.Alias = alias.GetText()
|
||||
}
|
||||
|
||||
return tableSourceItem
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSubqueryTableItem(ctx *mysqlparser.SubqueryTableItemContext) interface{} {
|
||||
sti := new(sqlstmt.SubqueryTableItem)
|
||||
sti.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
// 解析子查询
|
||||
if ss := ctx.SelectStatement(); ss != nil {
|
||||
sti.SubQuery = ss.Accept(v).(sqlstmt.ISelectStmt)
|
||||
}
|
||||
|
||||
// 获取别名
|
||||
if alias := ctx.GetAlias(); alias != nil {
|
||||
sti.Alias = alias.GetText()
|
||||
} else if uid := ctx.Uid(); uid != nil {
|
||||
sti.Alias = uid.GetText()
|
||||
}
|
||||
|
||||
return sti
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitInnerJoin(ctx *mysqlparser.InnerJoinContext) interface{} {
|
||||
ij := new(sqlstmt.InnerJoin)
|
||||
ij.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
ij.TableSourceItem = v.GetTableSourceItem(ctx.TableSourceItem())
|
||||
return ij
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitStraightJoin(ctx *mysqlparser.StraightJoinContext) interface{} {
|
||||
jp := new(sqlstmt.JoinPart)
|
||||
jp.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
jp.TableSourceItem = v.GetTableSourceItem(ctx.TableSourceItem())
|
||||
return jp
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitOuterJoin(ctx *mysqlparser.OuterJoinContext) interface{} {
|
||||
oj := new(sqlstmt.OuterJoin)
|
||||
oj.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
oj.TableSourceItem = v.GetTableSourceItem(ctx.TableSourceItem())
|
||||
return oj
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitNaturalJoin(ctx *mysqlparser.NaturalJoinContext) interface{} {
|
||||
nj := new(sqlstmt.NaturalJoin)
|
||||
nj.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
nj.TableSourceItem = v.GetTableSourceItem(ctx.TableSourceItem())
|
||||
return nj
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitJoinSpec(ctx *mysqlparser.JoinSpecContext) interface{} {
|
||||
node := sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return node
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitIsExpression(ctx *mysqlparser.IsExpressionContext) interface{} {
|
||||
e := new(sqlstmt.Expr)
|
||||
e.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return e
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitNotExpression(ctx *mysqlparser.NotExpressionContext) interface{} {
|
||||
e := new(sqlstmt.Expr)
|
||||
e.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return e
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitLogicalExpression(ctx *mysqlparser.LogicalExpressionContext) interface{} {
|
||||
le := new(sqlstmt.ExprLogical)
|
||||
le.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
le.Operator = ctx.LogicalOperator().GetText()
|
||||
le.Exprs = v.GetExprs(ctx.AllExpression())
|
||||
return le
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitPredicateExpression(ctx *mysqlparser.PredicateExpressionContext) interface{} {
|
||||
ep := new(sqlstmt.ExprPredicate)
|
||||
ep.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
ep.Predicate = ctx.Predicate().Accept(v).(sqlstmt.IPredicate)
|
||||
return ep
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSoundsLikePredicate(ctx *mysqlparser.SoundsLikePredicateContext) interface{} {
|
||||
e := new(sqlstmt.Expr)
|
||||
e.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return e
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitExpressionAtomPredicate(ctx *mysqlparser.ExpressionAtomPredicateContext) interface{} {
|
||||
pea := new(sqlstmt.PredicateExprAtom)
|
||||
pea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
pea.ExprAtom = ctx.ExpressionAtom().Accept(v).(sqlstmt.IExprAtom)
|
||||
return pea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitLogicalOperator(ctx *mysqlparser.LogicalOperatorContext) interface{} {
|
||||
return ctx.GetText()
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitBinaryComparisonPredicate(ctx *mysqlparser.BinaryComparisonPredicateContext) interface{} {
|
||||
bcp := new(sqlstmt.PredicateBinaryComparison)
|
||||
bcp.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
bcp.Left = ctx.GetLeft().Accept(v).(sqlstmt.IPredicate)
|
||||
bcp.Right = ctx.GetRight().Accept(v).(sqlstmt.IPredicate)
|
||||
bcp.ComparisonOperator = ctx.ComparisonOperator().Accept(v).(string)
|
||||
return bcp
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitInPredicate(ctx *mysqlparser.InPredicateContext) interface{} {
|
||||
inPredicate := new(sqlstmt.PredicateIn)
|
||||
inPredicate.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
if pc := ctx.Predicate(); pc != nil {
|
||||
inPredicate.InPredicate = pc.Accept(v).(sqlstmt.IPredicate)
|
||||
}
|
||||
if ssc := ctx.SelectStatement(); ssc != nil {
|
||||
inPredicate.SelectStmt = ssc.Accept(v).(sqlstmt.ISelectStmt)
|
||||
}
|
||||
if ec := ctx.Expressions(); ec != nil {
|
||||
inPredicate.Exprs = v.GetExprs(ec.AllExpression())
|
||||
}
|
||||
|
||||
return inPredicate
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitBetweenPredicate(ctx *mysqlparser.BetweenPredicateContext) interface{} {
|
||||
predicate := new(sqlstmt.PredicateLike)
|
||||
predicate.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return predicate
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitIsNullPredicate(ctx *mysqlparser.IsNullPredicateContext) interface{} {
|
||||
predicate := new(sqlstmt.PredicateLike)
|
||||
predicate.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return predicate
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitLikePredicate(ctx *mysqlparser.LikePredicateContext) interface{} {
|
||||
likePredicate := new(sqlstmt.PredicateLike)
|
||||
likePredicate.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return likePredicate
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitRegexpPredicate(ctx *mysqlparser.RegexpPredicateContext) interface{} {
|
||||
predicate := new(sqlstmt.PredicateLike)
|
||||
predicate.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return predicate
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitUnaryExpressionAtom(ctx *mysqlparser.UnaryExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitCollateExpressionAtom(ctx *mysqlparser.CollateExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitVariableAssignExpressionAtom(ctx *mysqlparser.VariableAssignExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitMysqlVariableExpressionAtom(ctx *mysqlparser.MysqlVariableExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitNestedExpressionAtom(ctx *mysqlparser.NestedExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitNestedRowExpressionAtom(ctx *mysqlparser.NestedRowExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitMathExpressionAtom(ctx *mysqlparser.MathExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitExistsExpressionAtom(ctx *mysqlparser.ExistsExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitIntervalExpressionAtom(ctx *mysqlparser.IntervalExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitJsonExpressionAtom(ctx *mysqlparser.JsonExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSubqueryExpressionAtom(ctx *mysqlparser.SubqueryExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitConstantExpressionAtom(ctx *mysqlparser.ConstantExpressionAtomContext) interface{} {
|
||||
constExprAtom := new(sqlstmt.ExprAtomConstant)
|
||||
constExprAtom.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
constExprAtom.Constant = ctx.Constant().Accept(v).(*sqlstmt.Constant)
|
||||
return constExprAtom
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitFunctionCallExpressionAtom(ctx *mysqlparser.FunctionCallExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitBinaryExpressionAtom(ctx *mysqlparser.BinaryExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitFullColumnNameExpressionAtom(ctx *mysqlparser.FullColumnNameExpressionAtomContext) interface{} {
|
||||
eacn := new(sqlstmt.ExprAtomColumnName)
|
||||
eacn.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
eacn.ColumnName = ctx.FullColumnName().Accept(v).(*sqlstmt.ColumnName)
|
||||
return eacn
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitBitExpressionAtom(ctx *mysqlparser.BitExpressionAtomContext) interface{} {
|
||||
ea := new(sqlstmt.ExprAtom)
|
||||
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ea
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitComparisonOperator(ctx *mysqlparser.ComparisonOperatorContext) interface{} {
|
||||
return ctx.GetText()
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitTableName(ctx *mysqlparser.TableNameContext) interface{} {
|
||||
tableName := new(sqlstmt.TableName)
|
||||
tableName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
fullId := ctx.FullId().Accept(v).(*sqlstmt.FullId)
|
||||
|
||||
if uids := fullId.Uids; len(uids) == 1 {
|
||||
tableName.Identifier = sqlstmt.NewIdentifierValue(uids[0])
|
||||
} else {
|
||||
tableName.Owner = uids[0]
|
||||
tableName.Identifier = sqlstmt.NewIdentifierValue(uids[1])
|
||||
}
|
||||
|
||||
return tableName
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitFullId(ctx *mysqlparser.FullIdContext) interface{} {
|
||||
fid := new(sqlstmt.FullId)
|
||||
fid.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
uids := make([]string, 0)
|
||||
for _, uid := range ctx.AllUid() {
|
||||
uids = append(uids, uid.GetText())
|
||||
}
|
||||
|
||||
if did := ctx.DOT_ID(); did != nil {
|
||||
uids = append(uids, strings.TrimPrefix(did.GetText(), "."))
|
||||
}
|
||||
|
||||
fid.Uids = uids
|
||||
|
||||
return fid
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitRoleName(ctx *mysqlparser.RoleNameContext) interface{} {
|
||||
node := sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return node
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitFullColumnName(ctx *mysqlparser.FullColumnNameContext) interface{} {
|
||||
fullColumnName := new(sqlstmt.ColumnName)
|
||||
fullColumnName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
adis := ctx.AllDottedId()
|
||||
// 不存在.则直接取标识符
|
||||
if len(adis) == 0 {
|
||||
fullColumnName.Identifier = sqlstmt.NewIdentifierValue(ctx.Uid().GetText())
|
||||
} else {
|
||||
fullColumnName.Owner = ctx.Uid().GetText()
|
||||
fullColumnName.Identifier = sqlstmt.NewIdentifierValue(adis[0].GetText())
|
||||
}
|
||||
|
||||
return fullColumnName
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitIndexColumnName(ctx *mysqlparser.IndexColumnNameContext) interface{} {
|
||||
node := sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return node
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitConstant(ctx *mysqlparser.ConstantContext) interface{} {
|
||||
constant := new(sqlstmt.Constant)
|
||||
constant.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
constant.Value = ctx.GetText()
|
||||
return constant
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitLimitClause(ctx *mysqlparser.LimitClauseContext) interface{} {
|
||||
limit := new(sqlstmt.Limit)
|
||||
limit.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
if lc := ctx.GetLimit(); lc != nil {
|
||||
limit.RowCount = cast.ToInt(lc.GetText())
|
||||
}
|
||||
if oc := ctx.GetOffset(); oc != nil {
|
||||
limit.Offset = cast.ToInt(oc.GetText())
|
||||
}
|
||||
|
||||
return limit
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitInsertStatement(ctx *mysqlparser.InsertStatementContext) interface{} {
|
||||
is := new(sqlstmt.InsertStmt)
|
||||
is.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
is.TableName = ctx.TableName().Accept(v).(*sqlstmt.TableName)
|
||||
return is
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitUpdateStatement(ctx *mysqlparser.UpdateStatementContext) interface{} {
|
||||
if sus := ctx.SingleUpdateStatement(); sus != nil {
|
||||
return sus.Accept(v)
|
||||
}
|
||||
if mus := ctx.MultipleUpdateStatement(); mus != nil {
|
||||
return mus.Accept(v)
|
||||
}
|
||||
|
||||
us := new(sqlstmt.UpdateStmt)
|
||||
us.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return us
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSingleUpdateStatement(ctx *mysqlparser.SingleUpdateStatementContext) interface{} {
|
||||
sus := new(sqlstmt.UpdateStmt)
|
||||
sus.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
tss := new(sqlstmt.TableSources)
|
||||
tss.Node = sqlstmt.NewNode(ctx.TableName().GetParser(), ctx.TableName())
|
||||
atomTable := new(sqlstmt.AtomTableItem)
|
||||
atomTable.TableName = ctx.TableName().Accept(v).(*sqlstmt.TableName)
|
||||
if uid := ctx.Uid(); uid != nil {
|
||||
atomTable.Alias = uid.GetText()
|
||||
}
|
||||
|
||||
tableSourceBase := new(sqlstmt.TableSourceBase)
|
||||
tableSourceBase.Node = tss.Node
|
||||
tableSourceBase.TableSourceItem = atomTable
|
||||
tss.TableSources = []sqlstmt.ITableSource{tableSourceBase}
|
||||
|
||||
sus.TableSources = tss
|
||||
|
||||
if aucs := ctx.AllUpdatedElement(); aucs != nil {
|
||||
ues := make([]*sqlstmt.UpdatedElement, 0)
|
||||
for _, auc := range aucs {
|
||||
ues = append(ues, auc.Accept(v).(*sqlstmt.UpdatedElement))
|
||||
}
|
||||
sus.UpdatedElements = ues
|
||||
}
|
||||
|
||||
if ec := ctx.Expression(); ec != nil {
|
||||
sus.Where = v.GetExpr(ec)
|
||||
}
|
||||
|
||||
return sus
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitMultipleUpdateStatement(ctx *mysqlparser.MultipleUpdateStatementContext) interface{} {
|
||||
mus := new(sqlstmt.UpdateStmt)
|
||||
mus.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if tssc := ctx.TableSources(); tssc != nil {
|
||||
tss := new(sqlstmt.TableSources)
|
||||
tss.Node = sqlstmt.NewNode(tssc.GetParser(), tssc)
|
||||
tss.TableSources = tssc.Accept(v).([]sqlstmt.ITableSource)
|
||||
mus.TableSources = tss
|
||||
}
|
||||
|
||||
if aucs := ctx.AllUpdatedElement(); aucs != nil {
|
||||
ues := make([]*sqlstmt.UpdatedElement, 0)
|
||||
for _, auc := range aucs {
|
||||
ues = append(ues, auc.Accept(v).(*sqlstmt.UpdatedElement))
|
||||
}
|
||||
mus.UpdatedElements = ues
|
||||
}
|
||||
|
||||
if ec := ctx.Expression(); ec != nil {
|
||||
mus.Where = v.GetExpr(ec)
|
||||
}
|
||||
|
||||
return mus
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitUpdatedElement(ctx *mysqlparser.UpdatedElementContext) interface{} {
|
||||
ue := new(sqlstmt.UpdatedElement)
|
||||
ue.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
ue.ColumnName = ctx.FullColumnName().Accept(v).(*sqlstmt.ColumnName)
|
||||
ue.Value = v.GetExpr(ctx.Expression())
|
||||
return ue
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitDeleteStatement(ctx *mysqlparser.DeleteStatementContext) interface{} {
|
||||
if sus := ctx.SingleDeleteStatement(); sus != nil {
|
||||
return sus.Accept(v)
|
||||
}
|
||||
if mus := ctx.MultipleDeleteStatement(); mus != nil {
|
||||
return mus.Accept(v)
|
||||
}
|
||||
|
||||
ds := new(sqlstmt.DeleteStmt)
|
||||
ds.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ds
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSingleDeleteStatement(ctx *mysqlparser.SingleDeleteStatementContext) interface{} {
|
||||
ds := new(sqlstmt.DeleteStmt)
|
||||
ds.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
tss := new(sqlstmt.TableSources)
|
||||
tss.Node = sqlstmt.NewNode(ctx.TableName().GetParser(), ctx.TableName())
|
||||
atomTable := new(sqlstmt.AtomTableItem)
|
||||
atomTable.TableName = ctx.TableName().Accept(v).(*sqlstmt.TableName)
|
||||
if uid := ctx.Uid(); uid != nil {
|
||||
atomTable.Alias = uid.GetText()
|
||||
}
|
||||
|
||||
tableSourceBase := new(sqlstmt.TableSourceBase)
|
||||
tableSourceBase.Node = tss.Node
|
||||
tableSourceBase.TableSourceItem = atomTable
|
||||
tss.TableSources = []sqlstmt.ITableSource{tableSourceBase}
|
||||
|
||||
ds.TableSources = tss
|
||||
|
||||
if ec := ctx.Expression(); ec != nil {
|
||||
ds.Where = v.GetExpr(ec)
|
||||
}
|
||||
|
||||
return ds
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitMultipleDeleteStatement(ctx *mysqlparser.MultipleDeleteStatementContext) interface{} {
|
||||
ds := new(sqlstmt.DeleteStmt)
|
||||
ds.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
|
||||
if tssc := ctx.TableSources(); tssc != nil {
|
||||
tss := new(sqlstmt.TableSources)
|
||||
tss.Node = sqlstmt.NewNode(tssc.GetParser(), tssc)
|
||||
tss.TableSources = tssc.Accept(v).([]sqlstmt.ITableSource)
|
||||
ds.TableSources = tss
|
||||
}
|
||||
|
||||
if ec := ctx.Expression(); ec != nil {
|
||||
ds.Where = v.GetExpr(ec)
|
||||
}
|
||||
|
||||
return ds
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitSimpleDescribeStatement(ctx *mysqlparser.SimpleDescribeStatementContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitFullDescribeStatement(ctx *mysqlparser.FullDescribeStatementContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowMasterLogs(ctx *mysqlparser.ShowMasterLogsContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowLogEvents(ctx *mysqlparser.ShowLogEventsContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowObjectFilter(ctx *mysqlparser.ShowObjectFilterContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowColumns(ctx *mysqlparser.ShowColumnsContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowCreateDb(ctx *mysqlparser.ShowCreateDbContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowCreateFullIdObject(ctx *mysqlparser.ShowCreateFullIdObjectContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowCreateUser(ctx *mysqlparser.ShowCreateUserContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowEngine(ctx *mysqlparser.ShowEngineContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowGlobalInfo(ctx *mysqlparser.ShowGlobalInfoContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowErrors(ctx *mysqlparser.ShowErrorsContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowCountErrors(ctx *mysqlparser.ShowCountErrorsContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowSchemaFilter(ctx *mysqlparser.ShowSchemaFilterContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowRoutine(ctx *mysqlparser.ShowRoutineContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowGrants(ctx *mysqlparser.ShowGrantsContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowIndexes(ctx *mysqlparser.ShowIndexesContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowOpenTables(ctx *mysqlparser.ShowOpenTablesContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowProfile(ctx *mysqlparser.ShowProfileContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitShowSlaveStatus(ctx *mysqlparser.ShowSlaveStatusContext) interface{} {
|
||||
ort := new(sqlstmt.OtherReadStmt)
|
||||
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return ort
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) VisitCreateDatabase(ctx *mysqlparser.CreateDatabaseContext) interface{} {
|
||||
cds := new(sqlstmt.CreateDatabase)
|
||||
cds.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||||
return cds
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) GetTableSourceItem(ctx mysqlparser.ITableSourceItemContext) sqlstmt.ITableSourceItem {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
return ctx.Accept(v).(sqlstmt.ITableSourceItem)
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) GetExpr(ctx mysqlparser.IExpressionContext) sqlstmt.IExpr {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return ctx.Accept(v).(sqlstmt.IExpr)
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) GetExprs(ctxs []mysqlparser.IExpressionContext) []sqlstmt.IExpr {
|
||||
if ctxs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
exprs := make([]sqlstmt.IExpr, 0)
|
||||
for _, exprCtx := range ctxs {
|
||||
exprs = append(exprs, exprCtx.Accept(v).(sqlstmt.IExpr))
|
||||
}
|
||||
return exprs
|
||||
}
|
||||
|
||||
func (v *MysqlVisitor) GetJoinParts(ctxs []mysqlparser.IJoinPartContext) []sqlstmt.IJoinPart {
|
||||
if ctxs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
joinPorts := make([]sqlstmt.IJoinPart, 0)
|
||||
for _, joinPartCtx := range ctxs {
|
||||
joinPorts = append(joinPorts, joinPartCtx.Accept(v).(sqlstmt.IJoinPart))
|
||||
}
|
||||
return joinPorts
|
||||
}
|
||||
Reference in New Issue
Block a user