refactor: 移除antlr4减小包体积&ai助手优化

This commit is contained in:
meilin.huang
2026-05-08 20:45:13 +08:00
parent 3768cef62d
commit f23b243fc5
154 changed files with 13054 additions and 396804 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -1 +0,0 @@
java -jar antlr-4.13.1-complete.jar -Dlanguage=Go -package parser -visitor *.g4

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

View File

@@ -1,59 +1,17 @@
package mysql
import (
"mayfly-go/internal/db/dbm/sqlparser/base"
mysqlparser "mayfly-go/internal/db/dbm/sqlparser/mysql/antlr4"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"mayfly-go/pkg/gox"
"mayfly-go/pkg/logx"
"github.com/antlr4-go/antlr/v4"
)
func GetMysqlParserTree(baseLine int, statement string) (antlr.ParseTree, *antlr.CommonTokenStream, error) {
lexer := mysqlparser.NewMySqlLexer(antlr.NewInputStream(statement))
stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)
parser := mysqlparser.NewMySqlParser(stream)
lexerErrorListener := &base.ParseErrorListener{
BaseLine: baseLine,
}
lexer.RemoveErrorListeners()
lexer.AddErrorListener(lexerErrorListener)
parserErrorListener := &base.ParseErrorListener{
BaseLine: baseLine,
}
parser.RemoveErrorListeners()
parser.AddErrorListener(parserErrorListener)
parser.BuildParseTrees = true
tree := parser.Root()
if lexerErrorListener.Err != nil {
return nil, nil, lexerErrorListener.Err
}
if parserErrorListener.Err != nil {
return nil, nil, parserErrorListener.Err
}
return tree, stream, nil
}
type MysqlParser struct {
}
func (*MysqlParser) Parse(stmt string) (stmts []sqlstmt.Stmt, err error) {
defer func() {
if e := recover(); e != nil {
logx.ErrorTrace("mysql sql parser err: ", e)
err = e.(error)
}
}()
tree, _, err := GetMysqlParserTree(1, stmt)
if err != nil {
return nil, err
}
return tree.Accept(new(MysqlVisitor)).([]sqlstmt.Stmt), nil
func (*MysqlParser) Parse(stmt string) (sqlstmt.Stmt, error) {
defer gox.Recover(func(e error) {
logx.ErrorTrace("mysql sql parser err: ", e)
})
return NewParser(stmt).Parse()
}

View File

@@ -1,122 +0,0 @@
package mysql
import (
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"testing"
)
func TestParserSimpleSelect(t *testing.T) {
parser := new(MysqlParser)
// sql := "select sum(t.age), t.`id` tid, t1.id2, t1.* from T_DB t left join t_db_ins as t1 on t.id = t1.id2 where t.id = 1 AND t1.status=0 and t.id2='9' and t.name in ('name2', 'name3') order by t.id desc limit 0, 100"
sql := "SELECT t.* FROM `t_sys_resource` t WHERE t.`id` > 0"
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
t.Log(stmts)
}
func TestParserUnionSelect(t *testing.T) {
parser := new(MysqlParser)
sql := "(select sum(t.age), t.id tid, t1.id2, t1.* from T_DB t join t_db_ins as t1 on t.id = t1.id2 where t.id = 1 AND t1.status=0 and t.id2='9' and t.name in ('name2', 'name3') order by t.id desc limit 0, 100) union all (select * from t_db2) limit 10"
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
t.Log(stmts)
}
func TestParserSingleUpdate(t *testing.T) {
parser := new(MysqlParser)
sql := `UPDATE t_sys_msg t
SET
t.recipient_id = 13,
t.creator = 'admin4'
WHERE
t.id = 1;`
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
t.Log(stmts)
}
func TestParserInsert(t *testing.T) {
parser := new(MysqlParser)
sql := `INSERT INTO
mayfly_go.t_sys_msg (
type,
msg,
recipient_id,
creator_id,
create_time,
is_deleted
)
VALUES
(1, 'hahaha', 2, 1, '2024-08-26 15:36:27', 0);`
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
t.Log(stmts)
}
func TestParserSql(t *testing.T) {
parser := new(MysqlParser)
// sql := `INSERT INTO
// mayfly_go.t_sys_msg (
// type,
// msg,
// recipient_id,
// creator_id,
// create_time,
// is_deleted
// )
// VALUES
// (1, 'hahaha', 2, 1, '2024-08-26 15:36:27', 0);`
sql := `UPDATE t_sys_msg
SET
recipient_id = 13,
creator = 'admin4'
WHERE
id = 1;`
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
switch stmt := stmts[0].(type) {
case *sqlstmt.InsertStmt:
t.Log("insert")
t.Log(stmt.TableName.Identifier.Value)
case *sqlstmt.UpdateStmt:
t.Log("update")
case *sqlstmt.DeleteStmt:
t.Log("delete")
case *sqlstmt.SelectStmt:
t.Log("select")
default:
t.Log("other")
}
t.Log(stmts)
}
func TestParserDelete(t *testing.T) {
parser := new(MysqlParser)
sql := `DELETE FROM t_sys_log
WHERE
id IN (59);`
stmts, err := parser.Parse(sql)
if err != nil {
t.Fatal(err)
}
t.Log(stmts)
}

View File

@@ -0,0 +1,767 @@
package mysql
import (
"strings"
"mayfly-go/internal/db/dbm/sqlparser/base"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"mayfly-go/internal/db/dbm/sqlparser/tokenizer"
)
// Parser MySQL 方言 SQL 解析器
type Parser struct {
*base.Lexer
}
// NewParser 创建 MySQL 解析器
func NewParser(sql string) *Parser {
return &Parser{
Lexer: base.NewLexer(sql, tokenizer.DialectConfig{
BacktickAsIdentifier: true,
HashLineComment: true,
}),
}
}
// Parse 解析单条 SQL 语句
func (p *Parser) Parse() (sqlstmt.Stmt, error) {
p.SkipSemicolons()
if p.Current().IsEOF() {
return nil, nil
}
stmt := p.parseStatement()
return stmt, nil
}
func (p *Parser) parseStatement() sqlstmt.Stmt {
tok := p.Current()
switch {
case tok.IsKeyword("SELECT") || tok.Value == "(":
return p.parseSelect()
case tok.IsKeyword("INSERT"):
return p.parseInsert()
case tok.IsKeyword("UPDATE"):
return p.parseUpdate()
case tok.IsKeyword("DELETE"):
return p.parseDelete()
case tok.IsKeyword("CREATE"):
return p.parseCreate()
case tok.IsKeyword("DROP"):
return p.parseDrop()
case tok.IsKeyword("ALTER"):
return p.parseAlter()
case tok.IsKeyword("WITH"):
return p.parseWith()
case tok.IsKeyword("SHOW"):
return p.parseShow()
case tok.IsKeyword("TRUNCATE"):
return p.parseGenericDdl()
default:
return p.parseGenericStmt()
}
}
// ---------- SELECT 解析 ----------
func (p *Parser) parseSelect() sqlstmt.Stmt {
start := p.Pos
var selectStmt *sqlstmt.SelectStmt
if p.Current().Value == "(" {
p.Consume()
if p.Current().IsKeyword("SELECT") {
selectStmt = p.parseSelectBody()
} else {
p.SkipToNextStatement()
return &sqlstmt.OtherStmt{Base: sqlstmt.Base{Text: p.TextFrom(start)}}
}
p.ExpectValue(")")
} else if p.Current().IsKeyword("SELECT") {
selectStmt = p.parseSelectBody()
} else {
return &sqlstmt.OtherStmt{Base: sqlstmt.Base{Text: p.TextFrom(start)}}
}
if selectStmt == nil {
return &sqlstmt.OtherStmt{Base: sqlstmt.Base{Text: p.TextFrom(start)}}
}
// UNION
for p.Current().IsKeyword("UNION") {
p.Consume()
isAll := false
if p.Current().IsKeyword("ALL") {
p.Consume()
isAll = true
} else if p.Current().IsKeyword("DISTINCT") {
p.Consume()
}
unionStart := p.Pos
var unionSelect *sqlstmt.SelectStmt
if p.Current().Value == "(" {
p.Consume()
if p.Current().IsKeyword("SELECT") {
unionSelect = p.parseSelectBody()
}
p.ExpectValue(")")
} else if p.Current().IsKeyword("SELECT") {
unionSelect = p.parseSelectBody()
}
if unionSelect != nil {
// 如果 unionSelect 有 LIMIT 或 ORDER BY移动到外层 selectStmtUNION 的 LIMIT/ORDER BY 属于整个语句)
if unionSelect.Limit != nil {
selectStmt.Limit = unionSelect.Limit
unionSelect.Limit = nil
}
if len(unionSelect.OrderBy) > 0 {
selectStmt.OrderBy = unionSelect.OrderBy
unionSelect.OrderBy = nil
}
unionSelect.Text = p.TextFrom(unionStart)
}
selectStmt.Unions = append(selectStmt.Unions, sqlstmt.UnionClause{
Select: unionSelect,
All: isAll,
Text: p.TextFrom(unionStart),
})
}
// UNION 之后可能有 ORDER BY
if p.Current().IsKeyword("ORDER") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
p.parseOrderBy(selectStmt)
}
// LIMIT (UNION 之后的 LIMIT 覆盖子查询内部的 LIMIT)
// 注意:只在确实解析到新 LIMIT 时才覆盖,避免 nil 覆盖有效值
if p.Current().IsKeyword("LIMIT") {
if limit := p.parseLimit(); limit != nil {
selectStmt.Limit = limit
}
}
selectStmt.Text = p.TextFrom(start)
return selectStmt
}
func (p *Parser) parseSelectBody() *sqlstmt.SelectStmt {
if !p.Current().IsKeyword("SELECT") {
return nil
}
start := p.Pos
p.Consume()
distinct := false
if p.Current().IsKeyword("DISTINCT") {
p.Consume()
distinct = true
} else if p.Current().IsKeyword("ALL") {
p.Consume()
}
items := p.parseSelectItems()
stmt := &sqlstmt.SelectStmt{
Distinct: distinct,
Items: items,
}
if p.Current().IsKeyword("FROM") {
p.Consume()
p.parseFromClause(stmt)
}
if p.Current().IsKeyword("WHERE") {
p.Consume()
whereStart := p.Pos
p.SkipExpr()
stmt.Where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
}
if p.Current().IsKeyword("GROUP") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
p.SkipGroupByExpr()
}
if p.Current().IsKeyword("HAVING") {
p.Consume()
p.SkipExpr()
}
if p.Current().IsKeyword("ORDER") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
p.parseOrderBy(stmt)
}
// LIMIT (子查询或简单 SELECT 的 LIMIT)
if p.Current().IsKeyword("LIMIT") {
stmt.Limit = p.parseLimit()
}
stmt.Text = p.TextFrom(start)
return stmt
}
func (p *Parser) parseSelectItems() []sqlstmt.SelectItem {
items := make([]sqlstmt.SelectItem, 0)
for !p.Current().IsEOF() {
tok := p.Current()
if tok.Value == "," {
p.Consume()
continue
}
if p.IsSelectClauseEnd() {
break
}
if tok.Value == "*" {
p.Consume()
items = append(items, sqlstmt.SelectItem{
Kind: sqlstmt.SelectItemStar,
Text: "*",
})
} else if tok.Type == tokenizer.TokenIdentifier && p.Peek(1).Value == "." && p.Peek(2).Value == "*" {
tableAlias := p.Unquote(p.Consume().Value)
p.Consume()
p.Consume()
items = append(items, sqlstmt.SelectItem{
Kind: sqlstmt.SelectItemStar,
Text: tableAlias + ".*",
TableAlias: tableAlias,
})
} else {
colStart := p.Pos
p.skipSelectElement()
colText := base.TrimTrailingComma(p.TextFromExclusive(colStart))
colName, alias := p.ExtractColumnAndAlias(colText)
kind := sqlstmt.SelectItemColumn
if strings.Contains(colText, "(") {
kind = sqlstmt.SelectItemExpr
}
items = append(items, sqlstmt.SelectItem{
Kind: kind,
Text: colText,
Alias: alias,
ColumnName: colName,
})
}
}
return items
}
func (p *Parser) skipSelectElement() {
for !p.Current().IsEOF() {
tok := p.Current()
if tok.Value == "," || p.IsSelectClauseEnd() {
break
}
if tok.Value == "(" {
p.SkipParentheses()
continue
}
p.Consume()
}
}
// ---------- FROM / TableRef 解析 ----------
func (p *Parser) parseFromClause(stmt *sqlstmt.SelectStmt) {
for !p.Current().IsEOF() {
if p.IsFromClauseEnd() {
break
}
if p.Current().Value == "," {
p.Consume()
continue
}
if p.Current().IsKeyword("JOIN") || p.IsJoinStart() {
join := p.parseJoinClause()
if join != nil {
stmt.Joins = append(stmt.Joins, *join)
}
continue
}
tableRef := p.parseTableRef()
if tableRef.Name != "" {
stmt.From = append(stmt.From, tableRef)
} else {
break
}
}
}
func (p *Parser) parseTableRef() sqlstmt.TableRef {
if p.Current().Value == "(" {
start := p.Pos
p.SkipParentheses()
alias := ""
if p.Current().IsKeyword("AS") {
p.Consume()
}
if p.Current().Type == tokenizer.TokenIdentifier {
alias = p.Unquote(p.Consume().Value)
}
return sqlstmt.TableRef{
Name: p.TextFrom(start),
Alias: alias,
}
}
if p.Current().Type != tokenizer.TokenIdentifier && p.Current().Type != tokenizer.TokenString {
return sqlstmt.TableRef{}
}
ref := sqlstmt.TableRef{}
part1 := p.Consume().Value
if p.Current().Value == "." {
p.Consume()
if p.Current().Type == tokenizer.TokenIdentifier || p.Current().Type == tokenizer.TokenString {
part2 := p.Consume().Value
ref.Schema = p.Unquote(part1)
ref.Name = p.Unquote(part2)
} else {
ref.Name = p.Unquote(part1)
}
} else {
ref.Name = p.Unquote(part1)
}
if p.Current().IsKeyword("AS") {
p.Consume()
}
if p.Current().Type == tokenizer.TokenIdentifier {
ref.Alias = p.Unquote(p.Consume().Value)
}
return ref
}
func (p *Parser) parseJoinClause() *sqlstmt.JoinClause {
start := p.Pos
joinType := sqlstmt.JoinKindInner
if p.Current().IsKeyword("LEFT") {
p.Consume()
if p.Current().IsKeyword("OUTER") {
p.Consume()
}
joinType = sqlstmt.JoinKindLeft
} else if p.Current().IsKeyword("RIGHT") {
p.Consume()
if p.Current().IsKeyword("OUTER") {
p.Consume()
}
joinType = sqlstmt.JoinKindRight
} else if p.Current().IsKeyword("FULL") {
p.Consume()
if p.Current().IsKeyword("OUTER") {
p.Consume()
}
joinType = sqlstmt.JoinKindFull
} else if p.Current().IsKeyword("CROSS") {
p.Consume()
joinType = sqlstmt.JoinKindCross
} else if p.Current().IsKeyword("NATURAL") {
p.Consume()
joinType = sqlstmt.JoinKindNatural
} else if p.Current().IsKeyword("INNER") {
p.Consume()
}
if !p.Current().IsKeyword("JOIN") {
p.Pos = start
return nil
}
p.Consume()
tableRef := p.parseTableRef()
if tableRef.Name == "" {
p.Pos = start
return nil
}
var onExpr *sqlstmt.Expr
if p.Current().IsKeyword("ON") {
p.Consume()
onStart := p.Pos
p.SkipExpr()
onExpr = &sqlstmt.Expr{Text: p.TextFromExclusive(onStart)}
} else if p.Current().IsKeyword("USING") {
p.Consume()
if p.Current().Value == "(" {
p.SkipParentheses()
}
}
return &sqlstmt.JoinClause{
Kind: joinType,
Table: tableRef,
On: onExpr,
Text: p.TextFrom(start),
}
}
func (p *Parser) parseOrderBy(stmt *sqlstmt.SelectStmt) {
for !p.Current().IsEOF() {
if p.Current().Value == "," {
p.Consume()
continue
}
if p.IsExprEnd() || p.Current().Value == ";" {
break
}
itemStart := p.Pos
p.skipOrderByItem()
itemText := base.TrimTrailingComma(p.TextFromExclusive(itemStart))
desc := false
upper := strings.ToUpper(itemText)
if strings.HasSuffix(upper, " DESC") {
desc = true
itemText = strings.TrimSpace(itemText[:len(itemText)-5])
} else if strings.HasSuffix(upper, " ASC") {
itemText = strings.TrimSpace(itemText[:len(itemText)-4])
}
stmt.OrderBy = append(stmt.OrderBy, sqlstmt.OrderByItem{
Text: itemText,
Desc: desc,
})
}
}
func (p *Parser) skipOrderByItem() {
for !p.Current().IsEOF() {
if p.Current().Value == "," {
break
}
if p.IsExprEnd() || p.Current().Value == ";" {
break
}
if p.Current().Value == "(" {
p.SkipParentheses()
continue
}
p.Consume()
}
}
// ---------- LIMIT 解析 ----------
func (p *Parser) parseLimit() *sqlstmt.Limit {
if !p.Current().IsKeyword("LIMIT") {
return nil
}
start := p.Pos
p.Consume()
rowCount := 0
offset := 0
if p.Current().Type == tokenizer.TokenNumber {
rowCount = p.ParseInt(p.Consume().Value)
}
if p.Current().Value == "," {
p.Consume()
offset = rowCount
if p.Current().Type == tokenizer.TokenNumber {
rowCount = p.ParseInt(p.Consume().Value)
}
}
if p.Current().IsKeyword("OFFSET") {
p.Consume()
if p.Current().Type == tokenizer.TokenNumber {
offset = p.ParseInt(p.Consume().Value)
}
}
return &sqlstmt.Limit{
Text: p.TextFrom(start),
Count: rowCount,
Offset: offset,
}
}
// ---------- INSERT 解析 ----------
func (p *Parser) parseInsert() sqlstmt.Stmt {
start := p.Pos
p.Consume()
if p.Current().IsKeyword("INTO") {
p.Consume()
}
tableRef := p.parseTableRef()
columns := []string{}
if p.Current().Value == "(" {
p.Consume()
for !p.Current().IsEOF() && p.Current().Value != ")" {
if p.Current().Value == "," {
p.Consume()
continue
}
if p.Current().Type == tokenizer.TokenIdentifier {
columns = append(columns, p.Unquote(p.Consume().Value))
} else {
p.Consume()
}
}
if p.Current().Value == ")" {
p.Consume()
}
}
valuesStart := p.Pos
if p.Current().IsKeyword("VALUES") {
p.Consume()
for !p.Current().IsEOF() && !p.IsExprEnd() {
if p.Current().Value == "(" {
p.SkipParentheses()
continue
}
p.Consume()
}
} else if p.Current().IsKeyword("SELECT") {
p.parseSelect()
}
return &sqlstmt.InsertStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Table: tableRef,
Columns: columns,
Values: p.TextFrom(valuesStart),
}
}
// ---------- UPDATE 解析 ----------
func (p *Parser) parseUpdate() sqlstmt.Stmt {
start := p.Pos
p.Consume()
tables := make([]sqlstmt.TableRef, 0)
for !p.Current().IsEOF() {
if p.Current().IsKeyword("SET") || p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
break
}
if p.Current().Value == "," {
p.Consume()
continue
}
if p.Current().IsKeyword("JOIN") || p.IsJoinStart() {
p.parseJoinClause()
continue
}
tableRef := p.parseTableRef()
if tableRef.Name != "" {
tables = append(tables, tableRef)
} else {
break
}
}
assignments := make([]sqlstmt.Assignment, 0)
if p.Current().IsKeyword("SET") {
p.Consume()
for !p.Current().IsEOF() {
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
break
}
assign := p.parseAssignment()
if assign != nil {
assignments = append(assignments, *assign)
}
if p.Current().Value == "," {
p.Consume()
continue
}
if p.Current().IsKeyword("WHERE") || p.Current().Value == ";" {
break
}
}
}
var where *sqlstmt.Expr
if p.Current().IsKeyword("WHERE") {
p.Consume()
whereStart := p.Pos
p.SkipExpr()
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
}
// MySQL UPDATE 支持 ORDER BY, LIMIT
if p.Current().IsKeyword("ORDER") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
p.SkipOrderByExpr()
}
if p.Current().IsKeyword("LIMIT") {
p.parseLimit()
}
return &sqlstmt.UpdateStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Tables: tables,
Set: assignments,
Where: where,
}
}
func (p *Parser) parseAssignment() *sqlstmt.Assignment {
start := p.Pos
colText := ""
for !p.Current().IsEOF() && p.Current().Value != "=" && p.Current().Value != "," &&
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
colText += p.Consume().Value
}
if p.Current().Value != "=" {
p.Pos = start
return nil
}
p.Consume()
valStart := p.Pos
for !p.Current().IsEOF() && p.Current().Value != "," &&
!p.Current().IsKeyword("WHERE") && p.Current().Value != ";" {
if p.Current().Value == "(" {
p.SkipParentheses()
continue
}
p.Consume()
}
return &sqlstmt.Assignment{
Column: p.Unquote(strings.TrimSpace(colText)),
Value: &sqlstmt.Expr{Text: p.TextFromExclusive(valStart)},
Text: p.TextFrom(start),
}
}
// ---------- DELETE 解析 ----------
func (p *Parser) parseDelete() sqlstmt.Stmt {
start := p.Pos
p.Consume()
if p.Current().IsKeyword("FROM") {
p.Consume()
}
tables := make([]sqlstmt.TableRef, 0)
for !p.Current().IsEOF() {
if p.Current().IsKeyword("WHERE") || p.Current().IsKeyword("USING") ||
p.Current().IsKeyword("ORDER") || p.Current().IsKeyword("LIMIT") ||
p.Current().Value == ";" {
break
}
if p.Current().Value == "," {
p.Consume()
continue
}
tableRef := p.parseTableRef()
if tableRef.Name != "" {
tables = append(tables, tableRef)
} else {
break
}
}
var where *sqlstmt.Expr
if p.Current().IsKeyword("WHERE") {
p.Consume()
whereStart := p.Pos
p.SkipExpr()
where = &sqlstmt.Expr{Text: p.TextFromExclusive(whereStart)}
}
// MySQL DELETE 支持 ORDER BY, LIMIT
if p.Current().IsKeyword("ORDER") {
p.Consume()
if p.Current().IsKeyword("BY") {
p.Consume()
}
p.SkipOrderByExpr()
}
if p.Current().IsKeyword("LIMIT") {
p.parseLimit()
}
return &sqlstmt.DeleteStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
Tables: tables,
Where: where,
}
}
// ---------- DDL 解析 ----------
func (p *Parser) parseCreate() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "CREATE",
}
}
func (p *Parser) parseDrop() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "DROP",
}
}
func (p *Parser) parseAlter() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "ALTER",
}
}
func (p *Parser) parseGenericDdl() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.DdlStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
DdlKind: "DDL",
}
}
// ---------- SHOW 解析MySQL 特有)----------
func (p *Parser) parseShow() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.SelectStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
}
}
// ---------- WITH 解析 ----------
func (p *Parser) parseWith() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.WithStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
}
}
// ---------- 通用语句解析 ----------
func (p *Parser) parseGenericStmt() sqlstmt.Stmt {
start := p.Pos
p.SkipToNextStatement()
return &sqlstmt.OtherStmt{
Base: sqlstmt.Base{Text: p.TextFrom(start)},
}
}

View File

@@ -0,0 +1,211 @@
package mysql
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== INSERT 测试 ==========
func TestInsertBasic(t *testing.T) {
sql := "INSERT INTO users (name, age) VALUES ('John', 30)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt, ok := stmt.(*sqlstmt.InsertStmt)
if !ok {
t.Fatal("expected InsertStmt")
}
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, insertStmt.GetText())
}
}
func TestInsertMultipleRows(t *testing.T) {
sql := "INSERT INTO users (name, age) VALUES ('John', 30), ('Jane', 25), ('Bob', 35)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
insertStmt := stmt.(*sqlstmt.InsertStmt)
if insertStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
// ========== UPDATE 测试 ==========
func TestUpdateBasic(t *testing.T) {
sql := "UPDATE users SET name = 'John' WHERE id = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, updateStmt.GetText())
}
if len(updateStmt.Tables) != 1 || updateStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users'")
}
if updateStmt.Where == nil || updateStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1'")
}
}
func TestUpdateMultipleColumns(t *testing.T) {
sql := "UPDATE users SET name = 'John', age = 30, email = 'john@example.com' WHERE id = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.Where == nil || updateStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1'")
}
}
func TestUpdateWithOrderByLimit(t *testing.T) {
sql := "UPDATE users SET status = 0 WHERE status = 1 ORDER BY id LIMIT 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
updateStmt := stmt.(*sqlstmt.UpdateStmt)
if updateStmt.Where == nil || updateStmt.Where.Text != "status = 1" {
t.Errorf("expected WHERE='status = 1'")
}
}
// ========== DELETE 测试 ==========
func TestDeleteBasic(t *testing.T) {
sql := "DELETE FROM users WHERE id = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, deleteStmt.GetText())
}
if len(deleteStmt.Tables) != 1 || deleteStmt.Tables[0].Name != "users" {
t.Errorf("expected table='users'")
}
if deleteStmt.Where == nil || deleteStmt.Where.Text != "id = 1" {
t.Errorf("expected WHERE='id = 1'")
}
}
func TestDeleteMultipleTables(t *testing.T) {
sql := "DELETE t1, t2 FROM users t1 JOIN orders t2 ON t1.id = t2.user_id WHERE t1.status = 0"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if len(deleteStmt.Tables) < 1 {
t.Errorf("expected at least 1 table")
}
}
func TestDeleteWithOrderByLimit(t *testing.T) {
sql := "DELETE FROM users WHERE status = 0 ORDER BY id LIMIT 100"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
deleteStmt := stmt.(*sqlstmt.DeleteStmt)
if deleteStmt.Where == nil || deleteStmt.Where.Text != "status = 0" {
t.Errorf("expected WHERE='status = 0'")
}
}
// ========== DDL 测试 ==========
func TestDDLCreate(t *testing.T) {
sql := "CREATE TABLE users (id INT PRIMARY KEY, name VARCHAR(50), age INT)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "CREATE" {
t.Errorf("expected DdlKind='CREATE', got '%s'", ddlStmt.DdlKind)
}
if ddlStmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
func TestDDLDrop(t *testing.T) {
sql := "DROP TABLE users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "DROP" {
t.Errorf("expected DdlKind='DROP', got '%s'", ddlStmt.DdlKind)
}
}
func TestDDLAlter(t *testing.T) {
sql := "ALTER TABLE users ADD COLUMN email VARCHAR(255)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
ddlStmt := stmt.(*sqlstmt.DdlStmt)
if ddlStmt.DdlKind != "ALTER" {
t.Errorf("expected DdlKind='ALTER', got '%s'", ddlStmt.DdlKind)
}
}
func TestDDLTruncate(t *testing.T) {
sql := "TRUNCATE TABLE users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
// TRUNCATE 可能返回 DdlStmt 或其他类型
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
t.Logf("TRUNCATE stmt type: %T", stmt)
t.Logf("TRUNCATE text: %s", stmt.GetText())
}

View File

@@ -0,0 +1,322 @@
package mysql
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== 简单分页测试 ==========
func TestMysqlPaginationSimple(t *testing.T) {
sql := "SELECT * FROM users LIMIT 0, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
func TestMysqlPaginationWithOffset(t *testing.T) {
sql := "SELECT id, name FROM users LIMIT 20, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 20 {
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
func TestMysqlPaginationKeywordOffset(t *testing.T) {
// MySQL 8.0+ 支持 LIMIT count OFFSET offset
sql := "SELECT * FROM products LIMIT 15 OFFSET 30"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
t.Logf("LIMIT text: %s", selectStmt.Limit.Text)
t.Logf("Offset: %d, Count: %d", selectStmt.Limit.Offset, selectStmt.Limit.Count)
if selectStmt.Limit.Offset != 30 {
t.Errorf("expected offset=30, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 15 {
t.Errorf("expected count=15, got %d", selectStmt.Limit.Count)
}
}
func TestMysqlPaginationOnlyLimit(t *testing.T) {
sql := "SELECT * FROM orders LIMIT 50"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 50 {
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
}
}
// ========== 复杂分页测试 ==========
func TestMysqlPaginationWithWhere(t *testing.T) {
sql := "SELECT id, name, email FROM users WHERE status = 1 AND age > 18 LIMIT 10, 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 20 {
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
}
}
func TestMysqlPaginationWithOrderBy(t *testing.T) {
sql := "SELECT * FROM users ORDER BY created_at DESC, id ASC LIMIT 0, 100"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.OrderBy) != 2 {
t.Fatalf("expected 2 ORDER BY clauses")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 100 {
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
}
}
func TestMysqlPaginationWithWhereOrderBy(t *testing.T) {
sql := "SELECT id, name FROM users WHERE status = 'active' ORDER BY score DESC LIMIT 50, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 'active'" {
t.Errorf("expected WHERE='status = 'active''")
}
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "score" {
t.Errorf("expected ORDER BY score")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 50 {
t.Errorf("expected offset=50, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
func TestMysqlPaginationWithJoin(t *testing.T) {
sql := "SELECT u.id, u.name, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1 ORDER BY o.created_at DESC LIMIT 100, 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatal("expected 1 JOIN")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 100 {
t.Errorf("expected offset=100, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 20 {
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
}
}
// ========== UNION 分页测试 ==========
func TestMysqlPaginationWithUnion(t *testing.T) {
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC LIMIT 0, 50"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 1 {
t.Fatal("expected 1 UNION")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause for UNION")
}
if selectStmt.Limit.Offset != 0 {
t.Errorf("expected offset=0, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 50 {
t.Errorf("expected count=50, got %d", selectStmt.Limit.Count)
}
}
func TestMysqlPaginationUnionAll(t *testing.T) {
sql := "SELECT name FROM products UNION ALL SELECT name FROM services LIMIT 10, 30"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 1 {
t.Fatal("expected 1 UNION ALL")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 30 {
t.Errorf("expected count=30, got %d", selectStmt.Limit.Count)
}
}
// ========== 子查询分页测试 ==========
func TestMysqlPaginationWithSubquery(t *testing.T) {
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1 ORDER BY id) AS tmp LIMIT 20, 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.From) != 1 {
t.Fatal("expected 1 FROM table")
}
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 20 {
t.Errorf("expected offset=20, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 10 {
t.Errorf("expected count=10, got %d", selectStmt.Limit.Count)
}
}
func TestMysqlNestedPagination(t *testing.T) {
// 嵌套分页查询
sql := "SELECT * FROM (SELECT * FROM users LIMIT 0, 100) AS tmp LIMIT 10, 5"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected outer LIMIT clause")
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected outer offset=10, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 5 {
t.Errorf("expected outer count=5, got %d", selectStmt.Limit.Count)
}
}
// ========== 大数据量分页测试 ==========
func TestMysqlLargeOffsetPagination(t *testing.T) {
sql := "SELECT * FROM logs ORDER BY id LIMIT 1000000, 100"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT clause")
}
if selectStmt.Limit.Offset != 1000000 {
t.Errorf("expected offset=1000000, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 100 {
t.Errorf("expected count=100, got %d", selectStmt.Limit.Count)
}
}

View File

@@ -0,0 +1,345 @@
package mysql
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
// ========== 基础 SELECT 测试 ==========
func TestSelectBasic(t *testing.T) {
sql := "-- 测试查询 sql \n SELECT `id`, name FROM users WHERE status = 1 ORDER BY id DESC LIMIT 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.GetText() != sql {
t.Errorf("expected text='%s', got '%s'", sql, selectStmt.GetText())
}
if len(selectStmt.Items) != 2 {
t.Fatalf("expected 2 items, got %d", len(selectStmt.Items))
}
if selectStmt.Items[0].Text != "id" {
t.Errorf("expected item[0]='id', got '%s'", selectStmt.Items[0].Text)
}
if selectStmt.Items[1].Text != "name" {
t.Errorf("expected item[1]='name', got '%s'", selectStmt.Items[1].Text)
}
if len(selectStmt.From) != 1 || selectStmt.From[0].Name != "users" {
t.Errorf("expected table='users'")
}
if selectStmt.Where == nil || selectStmt.Where.Text != "status = 1" {
t.Errorf("expected WHERE='status = 1'")
}
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "id" || !selectStmt.OrderBy[0].Desc {
t.Errorf("expected ORDER BY id DESC")
}
if selectStmt.Limit == nil || selectStmt.Limit.Text != "LIMIT 10" || selectStmt.Limit.Count != 10 {
t.Errorf("expected LIMIT 10")
}
}
func TestSelectDistinct(t *testing.T) {
sql := "SELECT DISTINCT id, name FROM users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if !selectStmt.Distinct {
t.Error("expected Distinct=true")
}
if len(selectStmt.Items) != 2 {
t.Fatalf("expected 2 items")
}
}
func TestSelectStar(t *testing.T) {
sql := "SELECT * FROM users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Items) != 1 || selectStmt.Items[0].Text != "*" {
t.Errorf("expected SELECT *")
}
}
func TestSelectTableStar(t *testing.T) {
sql := "SELECT u.*, o.amount FROM users u, orders o"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Items) != 2 {
t.Fatalf("expected 2 items")
}
if selectStmt.Items[0].Text != "u.*" {
t.Errorf("expected item[0]='u.*', got '%s'", selectStmt.Items[0].Text)
}
}
func TestSelectWithAlias(t *testing.T) {
sql := "SELECT id AS user_id, name AS user_name FROM users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Items) != 2 {
t.Fatalf("expected 2 items")
}
if selectStmt.Items[0].Alias != "user_id" {
t.Errorf("expected alias='user_id', got '%s'", selectStmt.Items[0].Alias)
}
if selectStmt.Items[1].Alias != "user_name" {
t.Errorf("expected alias='user_name', got '%s'", selectStmt.Items[1].Alias)
}
}
func TestSelectMultipleTables(t *testing.T) {
sql := "SELECT * FROM users, orders, products"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.From) != 3 {
t.Fatalf("expected 3 tables, got %d", len(selectStmt.From))
}
if selectStmt.From[0].Name != "users" || selectStmt.From[1].Name != "orders" || selectStmt.From[2].Name != "products" {
t.Errorf("expected tables: users, orders, products")
}
}
// ========== JOIN 测试 ==========
func TestSelectWithJoin(t *testing.T) {
sql := "SELECT u.id, o.amount FROM users u LEFT JOIN orders o ON u.id = o.user_id WHERE u.status = 1"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatalf("expected 1 join")
}
if selectStmt.Joins[0].Table.Name != "orders" {
t.Errorf("expected join table='orders'")
}
if selectStmt.Joins[0].On == nil || selectStmt.Joins[0].On.Text != "u.id = o.user_id" {
t.Errorf("expected ON='u.id = o.user_id'")
}
if selectStmt.Where == nil || selectStmt.Where.Text != "u.status = 1" {
t.Errorf("expected WHERE='u.status = 1'")
}
}
func TestSelectInnerJoin(t *testing.T) {
sql := "SELECT * FROM users u INNER JOIN orders o ON u.id = o.user_id"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatalf("expected 1 join")
}
}
func TestSelectRightJoin(t *testing.T) {
sql := "SELECT * FROM users u RIGHT JOIN orders o ON u.id = o.user_id"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatalf("expected 1 join")
}
}
func TestSelectCrossJoin(t *testing.T) {
sql := "SELECT * FROM users CROSS JOIN roles"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Joins) != 1 {
t.Fatalf("expected 1 join")
}
}
// ========== WHERE 测试 ==========
func TestSelectComplexWhere(t *testing.T) {
sql := "SELECT * FROM users WHERE status = 1 AND (age > 18 OR role = 'admin') ORDER BY name"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Where == nil {
t.Fatal("expected WHERE")
}
expectedWhere := "status = 1 AND (age > 18 OR role = 'admin')"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
// ========== LIMIT/OFFSET 测试 ==========
func TestSelectLimitOffset(t *testing.T) {
sql := "SELECT * FROM users LIMIT 10, 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if selectStmt.Limit == nil {
t.Fatal("expected LIMIT")
}
if selectStmt.Limit.Text != "LIMIT 10, 20" {
t.Errorf("expected LIMIT text='LIMIT 10, 20'")
}
if selectStmt.Limit.Offset != 10 {
t.Errorf("expected offset=10, got %d", selectStmt.Limit.Offset)
}
if selectStmt.Limit.Count != 20 {
t.Errorf("expected count=20, got %d", selectStmt.Limit.Count)
}
}
// ========== UNION 测试 ==========
func TestUnionSelect(t *testing.T) {
sql := "SELECT 1 UNION SELECT 2 UNION ALL SELECT 3"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 2 {
t.Fatalf("expected 2 unions, got %d", len(selectStmt.Unions))
}
if selectStmt.Unions[0].All {
t.Error("expected first union not ALL")
}
if !selectStmt.Unions[1].All {
t.Error("expected second union ALL")
}
}
func TestUnionWithOrderByLimit(t *testing.T) {
sql := "SELECT id FROM users UNION SELECT id FROM admins ORDER BY 1 DESC LIMIT 20"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.OrderBy) != 1 || selectStmt.OrderBy[0].Text != "1" || !selectStmt.OrderBy[0].Desc {
t.Errorf("expected ORDER BY 1 DESC")
}
if selectStmt.Limit == nil || selectStmt.Limit.Count != 20 {
t.Errorf("expected LIMIT 20")
}
}
func TestUnionMultiple(t *testing.T) {
sql := "SELECT id FROM t1 UNION SELECT id FROM t2 UNION ALL SELECT id FROM t3 UNION SELECT id FROM t4"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
if len(selectStmt.Unions) != 3 {
t.Fatalf("expected 3 unions, got %d", len(selectStmt.Unions))
}
}
// ========== SHOW 测试 ==========
func TestShow(t *testing.T) {
sql := "SHOW VARIABLES LIKE 'max_connections'"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
if stmt.GetText() != sql {
t.Errorf("expected text='%s'", sql)
}
}
func TestShowTables(t *testing.T) {
sql := "SHOW TABLES"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
}

View File

@@ -0,0 +1,292 @@
package mysql
import (
"testing"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
)
func TestSubqueryInFrom(t *testing.T) {
// FROM 子句中的子查询
sql := "SELECT * FROM (SELECT id, name FROM users WHERE status = 1) AS u WHERE u.id > 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
if stmt == nil {
t.Fatalf("expected stmt not nil")
}
selectStmt, ok := stmt.(*sqlstmt.SelectStmt)
if !ok {
t.Fatal("expected SelectStmt")
}
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("FROM 表数量: %d", len(selectStmt.From))
// 验证 FROM 子句
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 table in FROM, got %d", len(selectStmt.From))
}
// FROM 应该是子查询
fromText := selectStmt.From[0].Name
t.Logf("FROM[0] Name: %s", fromText)
t.Logf("FROM[0] Alias: %s", selectStmt.From[0].Alias)
if selectStmt.From[0].Alias != "u" {
t.Errorf("expected alias='u', got '%s'", selectStmt.From[0].Alias)
}
// 验证外层 WHERE
if selectStmt.Where == nil {
t.Fatal("expected outer WHERE clause")
}
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where.Text != "u.id > 10" {
t.Errorf("expected outer WHERE text='u.id > 10', got '%s'", selectStmt.Where.Text)
}
}
func TestSubqueryInWhere(t *testing.T) {
// WHERE 子句中的子查询IN
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders WHERE amount > 100)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
// WHERE 应该包含整个 IN 子句
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "id IN (SELECT user_id FROM orders WHERE amount > 100)"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestSubqueryInSelect(t *testing.T) {
// SELECT 列中的子查询
sql := "SELECT id, name, (SELECT COUNT(*) FROM orders WHERE orders.user_id = users.id) AS order_count FROM users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
for i, item := range selectStmt.Items {
t.Logf(" Item[%d]: Text='%s', ColumnName='%s', Alias='%s'", i, item.Text, item.ColumnName, item.Alias)
}
// 应该有 3 个 SELECT 项
if len(selectStmt.Items) != 3 {
t.Fatalf("expected 3 items, got %d", len(selectStmt.Items))
}
// 第三个项应该是子查询
if selectStmt.Items[2].Alias != "order_count" {
t.Errorf("expected item[2] alias='order_count', got '%s'", selectStmt.Items[2].Alias)
}
}
func TestNestedSubquery(t *testing.T) {
// 嵌套子查询
sql := "SELECT * FROM (SELECT * FROM (SELECT id FROM users WHERE status = 1) AS inner_u) AS outer_u WHERE outer_u.id > 5"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("外层 WHERE: %s", selectStmt.Where.Text)
// 验证外层 WHERE
if selectStmt.Where == nil {
t.Fatal("expected outer WHERE")
}
if selectStmt.Where.Text != "outer_u.id > 5" {
t.Errorf("expected WHERE text='outer_u.id > 5', got '%s'", selectStmt.Where.Text)
}
}
func TestSubqueryWithUnion(t *testing.T) {
// 子查询包含 UNION
sql := "SELECT * FROM (SELECT id FROM users UNION SELECT id FROM admins) AS all_users WHERE all_users.id > 10"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("FROM 数量: %d", len(selectStmt.From))
// 验证外层 WHERE
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
if selectStmt.Where.Text != "all_users.id > 10" {
t.Errorf("expected WHERE text='all_users.id > 10', got '%s'", selectStmt.Where.Text)
}
}
func TestCorrelatedSubquery(t *testing.T) {
// 相关子查询
sql := "SELECT u.id, u.name FROM users u WHERE u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
// WHERE 应该包含相关子查询
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "u.id IN (SELECT o.user_id FROM orders o WHERE o.user_id = u.id AND o.amount > 50)"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestSubqueryWithExists(t *testing.T) {
// EXISTS 子查询
sql := "SELECT * FROM users WHERE EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "EXISTS (SELECT 1 FROM orders WHERE orders.user_id = users.id)"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestMultipleSubqueries(t *testing.T) {
// 多个子查询
sql := "SELECT * FROM users WHERE id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("WHERE: %s", selectStmt.Where.Text)
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
expectedWhere := "id IN (SELECT user_id FROM orders) AND department_id IN (SELECT id FROM departments WHERE name = 'IT')"
if selectStmt.Where.Text != expectedWhere {
t.Errorf("expected WHERE text='%s', got '%s'", expectedWhere, selectStmt.Where.Text)
}
}
func TestSubqueryWithLimit(t *testing.T) {
// 子查询包含 LIMIT
sql := "SELECT * FROM (SELECT id, name FROM users ORDER BY created_at DESC LIMIT 10) AS top_users"
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("FROM 数量: %d", len(selectStmt.From))
if len(selectStmt.From) != 1 {
t.Fatalf("expected 1 table, got %d", len(selectStmt.From))
}
if selectStmt.From[0].Alias != "top_users" {
t.Errorf("expected alias='top_users', got '%s'", selectStmt.From[0].Alias)
}
}
func TestComplexSubquery(t *testing.T) {
// 复杂子查询场景
sql := `SELECT
u.id,
u.name,
(SELECT COUNT(*) FROM orders o WHERE o.user_id = u.id) AS order_count,
(SELECT SUM(amount) FROM orders o WHERE o.user_id = u.id AND o.status = 'completed') AS total_amount
FROM users u
WHERE u.status = 1
AND u.id IN (SELECT user_id FROM user_groups WHERE group_id = 5)
ORDER BY u.name`
parser := NewParser(sql)
stmt, err := parser.Parse()
if err != nil {
t.Fatalf("parse error: %v", err)
}
selectStmt := stmt.(*sqlstmt.SelectStmt)
t.Logf("完整文本: %s", selectStmt.GetText())
t.Logf("SELECT 项数量: %d", len(selectStmt.Items))
// 验证 SELECT 项
if len(selectStmt.Items) != 4 {
t.Fatalf("expected 4 items, got %d", len(selectStmt.Items))
}
// 验证各个项
t.Logf("Item[0]: %s", selectStmt.Items[0].Text)
t.Logf("Item[1]: %s", selectStmt.Items[1].Text)
t.Logf("Item[2]: %s (alias: %s)", selectStmt.Items[2].Text, selectStmt.Items[2].Alias)
t.Logf("Item[3]: %s (alias: %s)", selectStmt.Items[3].Text, selectStmt.Items[3].Alias)
// 验证 WHERE
if selectStmt.Where == nil {
t.Fatal("expected WHERE clause")
}
t.Logf("WHERE: %s", selectStmt.Where.Text)
// 验证 ORDER BY
if len(selectStmt.OrderBy) != 1 {
t.Fatalf("expected 1 order by, got %d", len(selectStmt.OrderBy))
}
if selectStmt.OrderBy[0].Text != "u.name" {
t.Errorf("expected order by text='u.name', got '%s'", selectStmt.OrderBy[0].Text)
}
}

View File

@@ -1,984 +0,0 @@
package mysql
import (
"strings"
mysqlparser "mayfly-go/internal/db/dbm/sqlparser/mysql/antlr4"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"github.com/spf13/cast"
)
type MysqlVisitor struct {
*mysqlparser.BaseMySqlParserVisitor
}
func (v *MysqlVisitor) VisitRoot(ctx *mysqlparser.RootContext) interface{} {
stms := ctx.SqlStatements()
if stms != nil {
return stms.Accept(v)
}
return nil
}
func (v *MysqlVisitor) VisitSqlStatements(ctx *mysqlparser.SqlStatementsContext) interface{} {
allSqlStatement := ctx.AllSqlStatement()
stmts := make([]sqlstmt.Stmt, 0)
for _, sqlStatement := range allSqlStatement {
stmts = append(stmts, sqlStatement.Accept(v).(sqlstmt.Stmt))
}
return stmts
}
func (v *MysqlVisitor) VisitSqlStatement(ctx *mysqlparser.SqlStatementContext) interface{} {
if c := ctx.DmlStatement(); c != nil {
return ctx.DmlStatement().Accept(v)
}
if c := ctx.DdlStatement(); c != nil {
return ctx.DdlStatement().Accept(v)
}
if c := ctx.AdministrationStatement(); c != nil {
return c.Accept(v)
}
if c := ctx.UtilityStatement(); c != nil {
return c.Accept(v)
}
return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
func (v *MysqlVisitor) VisitEmptyStatement_(ctx *mysqlparser.EmptyStatement_Context) interface{} {
return ""
}
func (v *MysqlVisitor) VisitDdlStatement(ctx *mysqlparser.DdlStatementContext) interface{} {
ddlStmt := &sqlstmt.DdlStmt{}
ddlStmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ddlStmt
}
func (v *MysqlVisitor) VisitDmlStatement(ctx *mysqlparser.DmlStatementContext) interface{} {
if ssc := ctx.SelectStatement(); ssc != nil {
return ssc.Accept(v)
}
if withStmt := ctx.WithStatement(); withStmt != nil {
return withStmt.Accept(v)
}
if usc := ctx.UpdateStatement(); usc != nil {
return usc.Accept(v)
}
if dsc := ctx.DeleteStatement(); dsc != nil {
return dsc.Accept(v)
}
if isc := ctx.InsertStatement(); isc != nil {
return isc.Accept(v)
}
dmlStmt := sqlstmt.DmlStmt{}
dmlStmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return dmlStmt
}
func (v *MysqlVisitor) VisitAdministrationStatement(ctx *mysqlparser.AdministrationStatementContext) interface{} {
if ssc := ctx.ShowStatement(); ssc != nil {
return ssc.Accept(v)
}
return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
func (v *MysqlVisitor) VisitUtilityStatement(ctx *mysqlparser.UtilityStatementContext) interface{} {
if c := ctx.SimpleDescribeStatement(); c != nil {
return c.Accept(v)
}
if c := ctx.FullDescribeStatement(); c != nil {
return c.Accept(v)
}
return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
func (v *MysqlVisitor) VisitWithStatement(ctx *mysqlparser.WithStatementContext) interface{} {
ort := new(sqlstmt.WithStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitSimpleSelect(ctx *mysqlparser.SimpleSelectContext) interface{} {
sss := new(sqlstmt.SimpleSelectStmt)
sss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
sss.QuerySpecification = ctx.QuerySpecification().Accept(v).(*sqlstmt.QuerySpecification)
return sss
}
func (v *MysqlVisitor) VisitUnionSelect(ctx *mysqlparser.UnionSelectContext) interface{} {
uss := new(sqlstmt.UnionSelectStmt)
uss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if lc := ctx.LimitClause(); lc != nil {
uss.Limit = lc.Accept(v).(*sqlstmt.Limit)
}
if ausc := ctx.AllUnionStatement(); ausc != nil {
unionStmts := make([]*sqlstmt.UnionStmt, 0)
for _, usc := range ausc {
unionStmts = append(unionStmts, usc.Accept(v).(*sqlstmt.UnionStmt))
}
uss.UnionStmts = unionStmts
}
if qsc := ctx.QuerySpecification(); qsc != nil {
uss.QuerySpecification = qsc.Accept(v).(*sqlstmt.QuerySpecification)
}
if qscn := ctx.QuerySpecificationNointo(); qscn != nil {
uss.QuerySpecification = qscn.Accept(v).(*sqlstmt.QuerySpecification)
}
if qec := ctx.QueryExpression(); qec != nil {
uss.QueryExpr = qec.Accept(v).(*sqlstmt.QueryExpr)
}
if qenc := ctx.QueryExpressionNointo(); qenc != nil {
uss.QueryExpr = qenc.Accept(v).(*sqlstmt.QueryExpr)
}
if ui := ctx.UNION(); ui != nil {
uss.UnionType = ui.GetText()
}
return uss
}
func (v *MysqlVisitor) VisitParenthesisSelect(ctx *mysqlparser.ParenthesisSelectContext) interface{} {
ps := new(sqlstmt.ParenthesisSelect)
ps.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if qec := ctx.QueryExpression(); qec != nil {
ps.QueryExpr = qec.Accept(v).(*sqlstmt.QueryExpr)
}
return ps
}
func (v *MysqlVisitor) VisitUnionParenthesisSelect(ctx *mysqlparser.UnionParenthesisSelectContext) interface{} {
ss := new(sqlstmt.SelectStmt)
ss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ss
}
func (v *MysqlVisitor) VisitWithLateralStatement(ctx *mysqlparser.WithLateralStatementContext) interface{} {
ss := new(sqlstmt.SelectStmt)
ss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ss
}
func (v *MysqlVisitor) VisitUnionStatement(ctx *mysqlparser.UnionStatementContext) interface{} {
us := new(sqlstmt.UnionStmt)
us.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if qs := ctx.QuerySpecificationNointo(); qs != nil {
us.QuerySpecification = qs.Accept(v).(*sqlstmt.QuerySpecification)
}
if qec := ctx.QueryExpressionNointo(); qec != nil {
us.QueryExpr = qec.Accept(v).(*sqlstmt.QueryExpr)
}
if ui := ctx.UNION(); ui != nil {
us.UnionType = ui.GetText()
}
return us
}
func (v *MysqlVisitor) VisitQuerySpecification(ctx *mysqlparser.QuerySpecificationContext) interface{} {
qs := new(sqlstmt.QuerySpecification)
qs.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
qs.SelectElements = ctx.SelectElements().Accept(v).(*sqlstmt.SelectElements)
if fromClause := ctx.FromClause(); fromClause != nil {
where := fromClause.GetWhereExpr()
if where != nil {
qs.Where = v.GetExpr(where)
}
tableSourcesCtx := fromClause.TableSources()
if tableSourcesCtx != nil {
tss := new(sqlstmt.TableSources)
tss.Node = sqlstmt.NewNode(tableSourcesCtx.GetParser(), tableSourcesCtx)
tss.TableSources = tableSourcesCtx.Accept(v).([]sqlstmt.ITableSource)
qs.From = tss
}
}
if limitClause := ctx.LimitClause(); limitClause != nil {
qs.Limit = limitClause.Accept(v).(*sqlstmt.Limit)
}
return qs
}
func (v *MysqlVisitor) VisitQuerySpecificationNointo(ctx *mysqlparser.QuerySpecificationNointoContext) interface{} {
qs := new(sqlstmt.QuerySpecification)
qs.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
qs.SelectElements = ctx.SelectElements().Accept(v).(*sqlstmt.SelectElements)
if fromClause := ctx.FromClause(); fromClause != nil {
where := fromClause.GetWhereExpr()
if where != nil {
qs.Where = v.GetExpr(where)
}
tableSourcesCtx := fromClause.TableSources()
if tableSourcesCtx != nil {
tss := new(sqlstmt.TableSources)
tss.Node = sqlstmt.NewNode(tableSourcesCtx.GetParser(), tableSourcesCtx)
tss.TableSources = tableSourcesCtx.Accept(v).([]sqlstmt.ITableSource)
qs.From = tss
}
}
if limitClause := ctx.LimitClause(); limitClause != nil {
qs.Limit = limitClause.Accept(v).(*sqlstmt.Limit)
}
return qs
}
func (v *MysqlVisitor) VisitQueryExpression(ctx *mysqlparser.QueryExpressionContext) interface{} {
qe := new(sqlstmt.QueryExpr)
qe.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if qec := ctx.QueryExpression(); qec != nil {
qe.QueryExpr = qec.Accept(v).(*sqlstmt.QueryExpr)
}
if qsc := ctx.QuerySpecification(); qsc != nil {
qe.QuerySpecification = qsc.Accept(v).(*sqlstmt.QuerySpecification)
}
return qe
}
func (v *MysqlVisitor) VisitQueryExpressionNointo(ctx *mysqlparser.QueryExpressionNointoContext) interface{} {
qe := new(sqlstmt.QueryExpr)
qe.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if qec := ctx.QueryExpressionNointo(); qec != nil {
qe.QueryExpr = qec.Accept(v).(*sqlstmt.QueryExpr)
}
if qsc := ctx.QuerySpecificationNointo(); qsc != nil {
qe.QuerySpecification = qsc.Accept(v).(*sqlstmt.QuerySpecification)
}
return qe
}
func (v *MysqlVisitor) VisitSelectElements(ctx *mysqlparser.SelectElementsContext) interface{} {
ses := new(sqlstmt.SelectElements)
ses.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if ctx.STAR() != nil {
ses.Star = ctx.STAR().GetText()
}
eles := make([]sqlstmt.ISelectElement, 0)
ase := ctx.AllSelectElement()
for _, selectElement := range ase {
eles = append(eles, selectElement.Accept(v).(sqlstmt.ISelectElement))
}
ses.Elements = eles
return ses
}
func (v *MysqlVisitor) VisitSelectStarElement(ctx *mysqlparser.SelectStarElementContext) interface{} {
sse := new(sqlstmt.SelectStarElement)
sse.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
sse.FullId = ctx.FullId().GetText()
return sse
}
func (v *MysqlVisitor) VisitSelectColumnElement(ctx *mysqlparser.SelectColumnElementContext) interface{} {
sce := new(sqlstmt.SelectColumnElement)
sce.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
sce.ColumnName = ctx.FullColumnName().Accept(v).(*sqlstmt.ColumnName)
if uid := ctx.Uid(); uid != nil {
sce.Alias = uid.GetText()
}
return sce
}
func (v *MysqlVisitor) VisitSelectFunctionElement(ctx *mysqlparser.SelectFunctionElementContext) interface{} {
node := sqlstmt.NewNode(ctx.GetParser(), ctx)
return node
}
func (v *MysqlVisitor) VisitSelectExpressionElement(ctx *mysqlparser.SelectExpressionElementContext) interface{} {
node := sqlstmt.NewNode(ctx.GetParser(), ctx)
return node
}
func (v *MysqlVisitor) VisitTableSources(ctx *mysqlparser.TableSourcesContext) interface{} {
tableSourcesCtx := ctx.AllTableSource()
tableSources := make([]sqlstmt.ITableSource, 0)
for _, tableSourceCtx := range tableSourcesCtx {
tableSources = append(tableSources, tableSourceCtx.Accept(v).(sqlstmt.ITableSource))
}
return tableSources
}
func (v *MysqlVisitor) VisitTableSourceBase(ctx *mysqlparser.TableSourceBaseContext) interface{} {
tsb := new(sqlstmt.TableSourceBase)
tsb.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
tsb.TableSourceItem = v.GetTableSourceItem(ctx.TableSourceItem())
tsb.JoinParts = v.GetJoinParts(ctx.AllJoinPart())
return tsb
}
func (v *MysqlVisitor) VisitAtomTableItem(ctx *mysqlparser.AtomTableItemContext) interface{} {
tableSourceItem := new(sqlstmt.AtomTableItem)
tableSourceItem.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
tableSourceItem.TableName = ctx.TableName().Accept(v).(*sqlstmt.TableName)
if alias := ctx.GetAlias(); alias != nil {
tableSourceItem.Alias = alias.GetText()
}
return tableSourceItem
}
func (v *MysqlVisitor) VisitSubqueryTableItem(ctx *mysqlparser.SubqueryTableItemContext) interface{} {
sti := new(sqlstmt.SubqueryTableItem)
sti.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
// 解析子查询
if ss := ctx.SelectStatement(); ss != nil {
sti.SubQuery = ss.Accept(v).(sqlstmt.ISelectStmt)
}
// 获取别名
if alias := ctx.GetAlias(); alias != nil {
sti.Alias = alias.GetText()
} else if uid := ctx.Uid(); uid != nil {
sti.Alias = uid.GetText()
}
return sti
}
func (v *MysqlVisitor) VisitInnerJoin(ctx *mysqlparser.InnerJoinContext) interface{} {
ij := new(sqlstmt.InnerJoin)
ij.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
ij.TableSourceItem = v.GetTableSourceItem(ctx.TableSourceItem())
return ij
}
func (v *MysqlVisitor) VisitStraightJoin(ctx *mysqlparser.StraightJoinContext) interface{} {
jp := new(sqlstmt.JoinPart)
jp.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
jp.TableSourceItem = v.GetTableSourceItem(ctx.TableSourceItem())
return jp
}
func (v *MysqlVisitor) VisitOuterJoin(ctx *mysqlparser.OuterJoinContext) interface{} {
oj := new(sqlstmt.OuterJoin)
oj.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
oj.TableSourceItem = v.GetTableSourceItem(ctx.TableSourceItem())
return oj
}
func (v *MysqlVisitor) VisitNaturalJoin(ctx *mysqlparser.NaturalJoinContext) interface{} {
nj := new(sqlstmt.NaturalJoin)
nj.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
nj.TableSourceItem = v.GetTableSourceItem(ctx.TableSourceItem())
return nj
}
func (v *MysqlVisitor) VisitJoinSpec(ctx *mysqlparser.JoinSpecContext) interface{} {
node := sqlstmt.NewNode(ctx.GetParser(), ctx)
return node
}
func (v *MysqlVisitor) VisitIsExpression(ctx *mysqlparser.IsExpressionContext) interface{} {
e := new(sqlstmt.Expr)
e.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return e
}
func (v *MysqlVisitor) VisitNotExpression(ctx *mysqlparser.NotExpressionContext) interface{} {
e := new(sqlstmt.Expr)
e.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return e
}
func (v *MysqlVisitor) VisitLogicalExpression(ctx *mysqlparser.LogicalExpressionContext) interface{} {
le := new(sqlstmt.ExprLogical)
le.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
le.Operator = ctx.LogicalOperator().GetText()
le.Exprs = v.GetExprs(ctx.AllExpression())
return le
}
func (v *MysqlVisitor) VisitPredicateExpression(ctx *mysqlparser.PredicateExpressionContext) interface{} {
ep := new(sqlstmt.ExprPredicate)
ep.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
ep.Predicate = ctx.Predicate().Accept(v).(sqlstmt.IPredicate)
return ep
}
func (v *MysqlVisitor) VisitSoundsLikePredicate(ctx *mysqlparser.SoundsLikePredicateContext) interface{} {
e := new(sqlstmt.Expr)
e.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return e
}
func (v *MysqlVisitor) VisitExpressionAtomPredicate(ctx *mysqlparser.ExpressionAtomPredicateContext) interface{} {
pea := new(sqlstmt.PredicateExprAtom)
pea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
pea.ExprAtom = ctx.ExpressionAtom().Accept(v).(sqlstmt.IExprAtom)
return pea
}
func (v *MysqlVisitor) VisitLogicalOperator(ctx *mysqlparser.LogicalOperatorContext) interface{} {
return ctx.GetText()
}
func (v *MysqlVisitor) VisitBinaryComparisonPredicate(ctx *mysqlparser.BinaryComparisonPredicateContext) interface{} {
bcp := new(sqlstmt.PredicateBinaryComparison)
bcp.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
bcp.Left = ctx.GetLeft().Accept(v).(sqlstmt.IPredicate)
bcp.Right = ctx.GetRight().Accept(v).(sqlstmt.IPredicate)
bcp.ComparisonOperator = ctx.ComparisonOperator().Accept(v).(string)
return bcp
}
func (v *MysqlVisitor) VisitInPredicate(ctx *mysqlparser.InPredicateContext) interface{} {
inPredicate := new(sqlstmt.PredicateIn)
inPredicate.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if pc := ctx.Predicate(); pc != nil {
inPredicate.InPredicate = pc.Accept(v).(sqlstmt.IPredicate)
}
if ssc := ctx.SelectStatement(); ssc != nil {
inPredicate.SelectStmt = ssc.Accept(v).(sqlstmt.ISelectStmt)
}
if ec := ctx.Expressions(); ec != nil {
inPredicate.Exprs = v.GetExprs(ec.AllExpression())
}
return inPredicate
}
func (v *MysqlVisitor) VisitBetweenPredicate(ctx *mysqlparser.BetweenPredicateContext) interface{} {
predicate := new(sqlstmt.PredicateLike)
predicate.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return predicate
}
func (v *MysqlVisitor) VisitIsNullPredicate(ctx *mysqlparser.IsNullPredicateContext) interface{} {
predicate := new(sqlstmt.PredicateLike)
predicate.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return predicate
}
func (v *MysqlVisitor) VisitLikePredicate(ctx *mysqlparser.LikePredicateContext) interface{} {
likePredicate := new(sqlstmt.PredicateLike)
likePredicate.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return likePredicate
}
func (v *MysqlVisitor) VisitRegexpPredicate(ctx *mysqlparser.RegexpPredicateContext) interface{} {
predicate := new(sqlstmt.PredicateLike)
predicate.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return predicate
}
func (v *MysqlVisitor) VisitUnaryExpressionAtom(ctx *mysqlparser.UnaryExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitCollateExpressionAtom(ctx *mysqlparser.CollateExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitVariableAssignExpressionAtom(ctx *mysqlparser.VariableAssignExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitMysqlVariableExpressionAtom(ctx *mysqlparser.MysqlVariableExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitNestedExpressionAtom(ctx *mysqlparser.NestedExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitNestedRowExpressionAtom(ctx *mysqlparser.NestedRowExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitMathExpressionAtom(ctx *mysqlparser.MathExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitExistsExpressionAtom(ctx *mysqlparser.ExistsExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitIntervalExpressionAtom(ctx *mysqlparser.IntervalExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitJsonExpressionAtom(ctx *mysqlparser.JsonExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitSubqueryExpressionAtom(ctx *mysqlparser.SubqueryExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitConstantExpressionAtom(ctx *mysqlparser.ConstantExpressionAtomContext) interface{} {
constExprAtom := new(sqlstmt.ExprAtomConstant)
constExprAtom.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
constExprAtom.Constant = ctx.Constant().Accept(v).(*sqlstmt.Constant)
return constExprAtom
}
func (v *MysqlVisitor) VisitFunctionCallExpressionAtom(ctx *mysqlparser.FunctionCallExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitBinaryExpressionAtom(ctx *mysqlparser.BinaryExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitFullColumnNameExpressionAtom(ctx *mysqlparser.FullColumnNameExpressionAtomContext) interface{} {
eacn := new(sqlstmt.ExprAtomColumnName)
eacn.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
eacn.ColumnName = ctx.FullColumnName().Accept(v).(*sqlstmt.ColumnName)
return eacn
}
func (v *MysqlVisitor) VisitBitExpressionAtom(ctx *mysqlparser.BitExpressionAtomContext) interface{} {
ea := new(sqlstmt.ExprAtom)
ea.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ea
}
func (v *MysqlVisitor) VisitComparisonOperator(ctx *mysqlparser.ComparisonOperatorContext) interface{} {
return ctx.GetText()
}
func (v *MysqlVisitor) VisitTableName(ctx *mysqlparser.TableNameContext) interface{} {
tableName := new(sqlstmt.TableName)
tableName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
fullId := ctx.FullId().Accept(v).(*sqlstmt.FullId)
if uids := fullId.Uids; len(uids) == 1 {
tableName.Identifier = sqlstmt.NewIdentifierValue(uids[0])
} else {
tableName.Owner = uids[0]
tableName.Identifier = sqlstmt.NewIdentifierValue(uids[1])
}
return tableName
}
func (v *MysqlVisitor) VisitFullId(ctx *mysqlparser.FullIdContext) interface{} {
fid := new(sqlstmt.FullId)
fid.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
uids := make([]string, 0)
for _, uid := range ctx.AllUid() {
uids = append(uids, uid.GetText())
}
if did := ctx.DOT_ID(); did != nil {
uids = append(uids, strings.TrimPrefix(did.GetText(), "."))
}
fid.Uids = uids
return fid
}
func (v *MysqlVisitor) VisitRoleName(ctx *mysqlparser.RoleNameContext) interface{} {
node := sqlstmt.NewNode(ctx.GetParser(), ctx)
return node
}
func (v *MysqlVisitor) VisitFullColumnName(ctx *mysqlparser.FullColumnNameContext) interface{} {
fullColumnName := new(sqlstmt.ColumnName)
fullColumnName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
adis := ctx.AllDottedId()
// 不存在.则直接取标识符
if len(adis) == 0 {
fullColumnName.Identifier = sqlstmt.NewIdentifierValue(ctx.Uid().GetText())
} else {
fullColumnName.Owner = ctx.Uid().GetText()
fullColumnName.Identifier = sqlstmt.NewIdentifierValue(adis[0].GetText())
}
return fullColumnName
}
func (v *MysqlVisitor) VisitIndexColumnName(ctx *mysqlparser.IndexColumnNameContext) interface{} {
node := sqlstmt.NewNode(ctx.GetParser(), ctx)
return node
}
func (v *MysqlVisitor) VisitConstant(ctx *mysqlparser.ConstantContext) interface{} {
constant := new(sqlstmt.Constant)
constant.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
constant.Value = ctx.GetText()
return constant
}
func (v *MysqlVisitor) VisitLimitClause(ctx *mysqlparser.LimitClauseContext) interface{} {
limit := new(sqlstmt.Limit)
limit.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if lc := ctx.GetLimit(); lc != nil {
limit.RowCount = cast.ToInt(lc.GetText())
}
if oc := ctx.GetOffset(); oc != nil {
limit.Offset = cast.ToInt(oc.GetText())
}
return limit
}
func (v *MysqlVisitor) VisitInsertStatement(ctx *mysqlparser.InsertStatementContext) interface{} {
is := new(sqlstmt.InsertStmt)
is.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
is.TableName = ctx.TableName().Accept(v).(*sqlstmt.TableName)
return is
}
func (v *MysqlVisitor) VisitUpdateStatement(ctx *mysqlparser.UpdateStatementContext) interface{} {
if sus := ctx.SingleUpdateStatement(); sus != nil {
return sus.Accept(v)
}
if mus := ctx.MultipleUpdateStatement(); mus != nil {
return mus.Accept(v)
}
us := new(sqlstmt.UpdateStmt)
us.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return us
}
func (v *MysqlVisitor) VisitSingleUpdateStatement(ctx *mysqlparser.SingleUpdateStatementContext) interface{} {
sus := new(sqlstmt.UpdateStmt)
sus.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
tss := new(sqlstmt.TableSources)
tss.Node = sqlstmt.NewNode(ctx.TableName().GetParser(), ctx.TableName())
atomTable := new(sqlstmt.AtomTableItem)
atomTable.TableName = ctx.TableName().Accept(v).(*sqlstmt.TableName)
if uid := ctx.Uid(); uid != nil {
atomTable.Alias = uid.GetText()
}
tableSourceBase := new(sqlstmt.TableSourceBase)
tableSourceBase.Node = tss.Node
tableSourceBase.TableSourceItem = atomTable
tss.TableSources = []sqlstmt.ITableSource{tableSourceBase}
sus.TableSources = tss
if aucs := ctx.AllUpdatedElement(); aucs != nil {
ues := make([]*sqlstmt.UpdatedElement, 0)
for _, auc := range aucs {
ues = append(ues, auc.Accept(v).(*sqlstmt.UpdatedElement))
}
sus.UpdatedElements = ues
}
if ec := ctx.Expression(); ec != nil {
sus.Where = v.GetExpr(ec)
}
return sus
}
func (v *MysqlVisitor) VisitMultipleUpdateStatement(ctx *mysqlparser.MultipleUpdateStatementContext) interface{} {
mus := new(sqlstmt.UpdateStmt)
mus.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if tssc := ctx.TableSources(); tssc != nil {
tss := new(sqlstmt.TableSources)
tss.Node = sqlstmt.NewNode(tssc.GetParser(), tssc)
tss.TableSources = tssc.Accept(v).([]sqlstmt.ITableSource)
mus.TableSources = tss
}
if aucs := ctx.AllUpdatedElement(); aucs != nil {
ues := make([]*sqlstmt.UpdatedElement, 0)
for _, auc := range aucs {
ues = append(ues, auc.Accept(v).(*sqlstmt.UpdatedElement))
}
mus.UpdatedElements = ues
}
if ec := ctx.Expression(); ec != nil {
mus.Where = v.GetExpr(ec)
}
return mus
}
func (v *MysqlVisitor) VisitUpdatedElement(ctx *mysqlparser.UpdatedElementContext) interface{} {
ue := new(sqlstmt.UpdatedElement)
ue.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
ue.ColumnName = ctx.FullColumnName().Accept(v).(*sqlstmt.ColumnName)
ue.Value = v.GetExpr(ctx.Expression())
return ue
}
func (v *MysqlVisitor) VisitDeleteStatement(ctx *mysqlparser.DeleteStatementContext) interface{} {
if sus := ctx.SingleDeleteStatement(); sus != nil {
return sus.Accept(v)
}
if mus := ctx.MultipleDeleteStatement(); mus != nil {
return mus.Accept(v)
}
ds := new(sqlstmt.DeleteStmt)
ds.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ds
}
func (v *MysqlVisitor) VisitSingleDeleteStatement(ctx *mysqlparser.SingleDeleteStatementContext) interface{} {
ds := new(sqlstmt.DeleteStmt)
ds.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
tss := new(sqlstmt.TableSources)
tss.Node = sqlstmt.NewNode(ctx.TableName().GetParser(), ctx.TableName())
atomTable := new(sqlstmt.AtomTableItem)
atomTable.TableName = ctx.TableName().Accept(v).(*sqlstmt.TableName)
if uid := ctx.Uid(); uid != nil {
atomTable.Alias = uid.GetText()
}
tableSourceBase := new(sqlstmt.TableSourceBase)
tableSourceBase.Node = tss.Node
tableSourceBase.TableSourceItem = atomTable
tss.TableSources = []sqlstmt.ITableSource{tableSourceBase}
ds.TableSources = tss
if ec := ctx.Expression(); ec != nil {
ds.Where = v.GetExpr(ec)
}
return ds
}
func (v *MysqlVisitor) VisitMultipleDeleteStatement(ctx *mysqlparser.MultipleDeleteStatementContext) interface{} {
ds := new(sqlstmt.DeleteStmt)
ds.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
if tssc := ctx.TableSources(); tssc != nil {
tss := new(sqlstmt.TableSources)
tss.Node = sqlstmt.NewNode(tssc.GetParser(), tssc)
tss.TableSources = tssc.Accept(v).([]sqlstmt.ITableSource)
ds.TableSources = tss
}
if ec := ctx.Expression(); ec != nil {
ds.Where = v.GetExpr(ec)
}
return ds
}
func (v *MysqlVisitor) VisitSimpleDescribeStatement(ctx *mysqlparser.SimpleDescribeStatementContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitFullDescribeStatement(ctx *mysqlparser.FullDescribeStatementContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowMasterLogs(ctx *mysqlparser.ShowMasterLogsContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowLogEvents(ctx *mysqlparser.ShowLogEventsContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowObjectFilter(ctx *mysqlparser.ShowObjectFilterContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowColumns(ctx *mysqlparser.ShowColumnsContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowCreateDb(ctx *mysqlparser.ShowCreateDbContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowCreateFullIdObject(ctx *mysqlparser.ShowCreateFullIdObjectContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowCreateUser(ctx *mysqlparser.ShowCreateUserContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowEngine(ctx *mysqlparser.ShowEngineContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowGlobalInfo(ctx *mysqlparser.ShowGlobalInfoContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowErrors(ctx *mysqlparser.ShowErrorsContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowCountErrors(ctx *mysqlparser.ShowCountErrorsContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowSchemaFilter(ctx *mysqlparser.ShowSchemaFilterContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowRoutine(ctx *mysqlparser.ShowRoutineContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowGrants(ctx *mysqlparser.ShowGrantsContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowIndexes(ctx *mysqlparser.ShowIndexesContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowOpenTables(ctx *mysqlparser.ShowOpenTablesContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowProfile(ctx *mysqlparser.ShowProfileContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitShowSlaveStatus(ctx *mysqlparser.ShowSlaveStatusContext) interface{} {
ort := new(sqlstmt.OtherReadStmt)
ort.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return ort
}
func (v *MysqlVisitor) VisitCreateDatabase(ctx *mysqlparser.CreateDatabaseContext) interface{} {
cds := new(sqlstmt.CreateDatabase)
cds.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
return cds
}
func (v *MysqlVisitor) GetTableSourceItem(ctx mysqlparser.ITableSourceItemContext) sqlstmt.ITableSourceItem {
if ctx == nil {
return nil
}
return ctx.Accept(v).(sqlstmt.ITableSourceItem)
}
func (v *MysqlVisitor) GetExpr(ctx mysqlparser.IExpressionContext) sqlstmt.IExpr {
if ctx == nil {
return nil
}
return ctx.Accept(v).(sqlstmt.IExpr)
}
func (v *MysqlVisitor) GetExprs(ctxs []mysqlparser.IExpressionContext) []sqlstmt.IExpr {
if ctxs == nil {
return nil
}
exprs := make([]sqlstmt.IExpr, 0)
for _, exprCtx := range ctxs {
exprs = append(exprs, exprCtx.Accept(v).(sqlstmt.IExpr))
}
return exprs
}
func (v *MysqlVisitor) GetJoinParts(ctxs []mysqlparser.IJoinPartContext) []sqlstmt.IJoinPart {
if ctxs == nil {
return nil
}
joinPorts := make([]sqlstmt.IJoinPart, 0)
for _, joinPartCtx := range ctxs {
joinPorts = append(joinPorts, joinPartCtx.Accept(v).(sqlstmt.IJoinPart))
}
return joinPorts
}