Files
mayfly-go/server/internal/db/dbm/sqlparser/mysql/mysql.go

60 lines
1.4 KiB
Go

package mysql
import (
	"fmt"

	"mayfly-go/internal/db/dbm/sqlparser/base"
	mysqlparser "mayfly-go/internal/db/dbm/sqlparser/mysql/antlr4"
	"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
	"mayfly-go/pkg/logx"

	"github.com/antlr4-go/antlr/v4"
)
// GetMysqlParserTree runs the generated MySQL lexer and parser over statement
// and returns the tree produced by the grammar's Root rule together with the
// token stream the parser consumed.
//
// baseLine is copied into the BaseLine field of both error listeners
// (presumably used to offset reported line numbers — see
// base.ParseErrorListener). ANTLR's default error listeners are removed from
// both the lexer and the parser and replaced by these listeners. After
// parsing, a lexer-listener error takes precedence over a parser-listener
// error; when either is set, the tree and stream results are nil.
func GetMysqlParserTree(baseLine int, statement string) (antlr.ParseTree, *antlr.CommonTokenStream, error) {
	lexErrs := &base.ParseErrorListener{BaseLine: baseLine}
	parseErrs := &base.ParseErrorListener{BaseLine: baseLine}

	lex := mysqlparser.NewMySqlLexer(antlr.NewInputStream(statement))
	tokens := antlr.NewCommonTokenStream(lex, antlr.TokenDefaultChannel)
	p := mysqlparser.NewMySqlParser(tokens)

	// Swap out the default listeners so syntax errors are captured rather
	// than handled by ANTLR's defaults.
	lex.RemoveErrorListeners()
	lex.AddErrorListener(lexErrs)
	p.RemoveErrorListeners()
	p.AddErrorListener(parseErrs)

	p.BuildParseTrees = true
	root := p.Root()

	switch {
	case lexErrs.Err != nil:
		return nil, nil, lexErrs.Err
	case parseErrs.Err != nil:
		return nil, nil, parseErrs.Err
	}
	return root, tokens, nil
}
// MysqlParser is a stateless SQL parser for the MySQL dialect; its Parse
// method converts SQL text into sqlstmt.Stmt values.
type MysqlParser struct{}
// Parse parses stmt as MySQL SQL and converts the resulting parse tree into
// sqlstmt.Stmt values using MysqlVisitor.
//
// Errors reported by GetMysqlParserTree are returned unchanged. If the visitor
// yields something other than []sqlstmt.Stmt, a descriptive error is
// returned. Any panic raised while parsing or visiting is recovered, logged
// via logx.ErrorTrace, and returned as err with a nil stmts slice.
func (*MysqlParser) Parse(stmt string) (stmts []sqlstmt.Stmt, err error) {
	defer func() {
		if e := recover(); e != nil {
			logx.ErrorTrace("mysql sql parser err: ", e)
			stmts = nil
			// A panic value is not necessarily an error (e.g. panic("...") with a
			// string); an unchecked e.(error) would panic again inside this defer
			// and defeat the recovery.
			if pe, ok := e.(error); ok {
				err = pe
			} else {
				err = fmt.Errorf("mysql sql parser err: %v", e)
			}
		}
	}()

	tree, _, err := GetMysqlParserTree(1, stmt)
	if err != nil {
		return nil, err
	}

	result := tree.Accept(new(MysqlVisitor))
	parsed, ok := result.([]sqlstmt.Stmt)
	if !ok {
		// Report an unexpected visitor result explicitly instead of relying on
		// a runtime type-assertion panic being recovered above.
		return nil, fmt.Errorf("mysql sql parser: unexpected visitor result type %T", result)
	}
	return parsed, nil
}