mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-06 17:30:25 +08:00
494 lines
15 KiB
Go
494 lines
15 KiB
Go
package pgsql
|
||
|
||
import (
|
||
"strings"
|
||
|
||
pgparser "mayfly-go/internal/db/dbm/sqlparser/pgsql/antlr4"
|
||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||
|
||
"github.com/may-fly/cast"
|
||
)
|
||
|
||
type PgsqlVisitor struct {
|
||
*pgparser.BasePostgreSQLParserVisitor
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitRoot(ctx *pgparser.RootContext) interface{} {
|
||
if sbc := ctx.Stmtblock(); sbc != nil {
|
||
return sbc.Accept(v)
|
||
}
|
||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitPlsqlroot(ctx *pgparser.PlsqlrootContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitStmtblock(ctx *pgparser.StmtblockContext) interface{} {
|
||
if smc := ctx.Stmtmulti(); smc != nil {
|
||
return smc.Accept(v)
|
||
}
|
||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitStmtmulti(ctx *pgparser.StmtmultiContext) interface{} {
|
||
allSqlStatement := ctx.AllStmt()
|
||
stmts := make([]sqlstmt.Stmt, 0)
|
||
for _, sqlStatement := range allSqlStatement {
|
||
stmts = append(stmts, sqlStatement.Accept(v).(sqlstmt.Stmt))
|
||
}
|
||
return stmts
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitStmt(ctx *pgparser.StmtContext) interface{} {
|
||
if selectstmtCtx := ctx.Selectstmt(); selectstmtCtx != nil {
|
||
return selectstmtCtx.Accept(v)
|
||
}
|
||
if updatestmtCtx := ctx.Updatestmt(); updatestmtCtx != nil {
|
||
return updatestmtCtx.Accept(v)
|
||
}
|
||
if deletestmtCtx := ctx.Deletestmt(); deletestmtCtx != nil {
|
||
return deletestmtCtx.Accept(v)
|
||
}
|
||
if insertstmtC := ctx.Insertstmt(); insertstmtC != nil {
|
||
return insertstmtC.Accept(v)
|
||
}
|
||
if c := ctx.Createdbstmt(); c != nil {
|
||
cds := new(sqlstmt.CreateDatabase)
|
||
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||
return cds
|
||
}
|
||
if c := ctx.Createtablespacestmt(); c != nil {
|
||
cds := new(sqlstmt.CreateTable)
|
||
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||
return cds
|
||
}
|
||
if c := ctx.Altertablestmt(); c != nil {
|
||
cds := new(sqlstmt.AlterTable)
|
||
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||
return cds
|
||
}
|
||
if c := ctx.Dropdbstmt(); c != nil {
|
||
cds := new(sqlstmt.DropDatabase)
|
||
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||
return cds
|
||
}
|
||
if c := ctx.Droptablespacestmt(); c != nil {
|
||
cds := new(sqlstmt.DropTable)
|
||
cds.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||
return cds
|
||
}
|
||
if explain := ctx.Explainstmt(); explain != nil {
|
||
otherRead := new(sqlstmt.OtherReadStmt)
|
||
otherRead.Node = sqlstmt.NewNode(explain.GetParser(), explain)
|
||
return otherRead
|
||
}
|
||
if c := ctx.Variableshowstmt(); c != nil {
|
||
otherRead := new(sqlstmt.OtherReadStmt)
|
||
otherRead.Node = sqlstmt.NewNode(c.GetParser(), c)
|
||
return otherRead
|
||
}
|
||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSelectstmt(ctx *pgparser.SelectstmtContext) interface{} {
|
||
if spnc := ctx.Select_no_parens(); spnc != nil {
|
||
return spnc.Accept(v)
|
||
}
|
||
selectstmt := new(sqlstmt.SelectStmt)
|
||
selectstmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
return selectstmt
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSelect_with_parens(ctx *pgparser.Select_with_parensContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSelect_no_parens(ctx *pgparser.Select_no_parensContext) interface{} {
|
||
if c := ctx.Select_clause(); c == nil {
|
||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
}
|
||
|
||
var limit *sqlstmt.Limit
|
||
if limitC := ctx.Select_limit(); limitC != nil {
|
||
limit = limitC.Accept(v).(*sqlstmt.Limit)
|
||
}
|
||
if limitC := ctx.Opt_select_limit(); limitC != nil {
|
||
limit = limitC.Accept(v).(*sqlstmt.Limit)
|
||
}
|
||
|
||
selectClause := ctx.Select_clause()
|
||
asis := selectClause.AllSimple_select_intersect()
|
||
// 简单查询
|
||
if len(asis) == 1 {
|
||
sss := new(sqlstmt.SimpleSelectStmt)
|
||
sss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
sss.QuerySpecification = ctx.Select_clause().Accept(v).([]*sqlstmt.QuerySpecification)[0]
|
||
sss.QuerySpecification.Limit = limit
|
||
return sss
|
||
}
|
||
|
||
uss := new(sqlstmt.UnionSelectStmt)
|
||
uss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
allUnion := selectClause.AllUNION()
|
||
// todo 赋值union信息
|
||
for _, union := range allUnion {
|
||
uss.UnionType = union.GetText()
|
||
}
|
||
// uss.QuerySpecifications = ctx.Select_clause().Accept(v).([]*sqlstmt.QuerySpecification)
|
||
uss.Limit = limit
|
||
return uss
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSelect_clause(ctx *pgparser.Select_clauseContext) interface{} {
|
||
qs := make([]*sqlstmt.QuerySpecification, 0)
|
||
for _, ssi := range ctx.AllSimple_select_intersect() {
|
||
qs = append(qs, ssi.Accept(v).(*sqlstmt.QuerySpecification))
|
||
}
|
||
|
||
return qs
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSimple_select_intersect(ctx *pgparser.Simple_select_intersectContext) interface{} {
|
||
// 只返回一个查询,INTERSECT(交集)暂不支持
|
||
if spsc := ctx.AllSimple_select_pramary(); spsc != nil {
|
||
return spsc[0].Accept(v)
|
||
}
|
||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSimple_select_pramary(ctx *pgparser.Simple_select_pramaryContext) interface{} {
|
||
qs := new(sqlstmt.QuerySpecification)
|
||
qs.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
if c := ctx.From_clause(); c != nil {
|
||
qs.From = c.Accept(v).(*sqlstmt.TableSources)
|
||
}
|
||
|
||
if c := ctx.Opt_target_list(); c != nil {
|
||
qs.SelectElements = c.Accept(v).(*sqlstmt.SelectElements)
|
||
}
|
||
|
||
if c := ctx.Target_list(); c != nil {
|
||
qs.SelectElements = c.Accept(v).(*sqlstmt.SelectElements)
|
||
}
|
||
|
||
if c := ctx.Where_clause(); c != nil {
|
||
qs.Where = c.A_expr().Accept(v).(sqlstmt.IExpr)
|
||
}
|
||
|
||
return qs
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSelect_limit(ctx *pgparser.Select_limitContext) interface{} {
|
||
limit := new(sqlstmt.Limit)
|
||
limit.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
if lc := ctx.Limit_clause(); lc != nil {
|
||
if lv := lc.Select_limit_value(); lv != nil {
|
||
limit.RowCount = cast.ToInt(lv.GetText())
|
||
}
|
||
}
|
||
if oc := ctx.Offset_clause(); oc != nil {
|
||
if ov := oc.Select_offset_value(); ov != nil {
|
||
limit.Offset = cast.ToInt(ov.GetText())
|
||
}
|
||
}
|
||
return limit
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitOpt_select_limit(ctx *pgparser.Opt_select_limitContext) interface{} {
|
||
if slc := ctx.Select_limit(); slc != nil {
|
||
return slc.Accept(v)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitFrom_clause(ctx *pgparser.From_clauseContext) interface{} {
|
||
if c := ctx.From_list(); c != nil {
|
||
return c.Accept(v)
|
||
}
|
||
return sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitFrom_list(ctx *pgparser.From_listContext) interface{} {
|
||
ts := new(sqlstmt.TableSources)
|
||
ts.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
// ts.StartIndex = ctx.GetStart().GetStart()
|
||
// ts.StopIndex = ctx.GetStop().GetStop()
|
||
|
||
tableSources := make([]sqlstmt.ITableSource, 0)
|
||
allTableRefCtx := ctx.AllTable_ref()
|
||
for _, trc := range allTableRefCtx {
|
||
tableSources = append(tableSources, trc.Accept(v).(sqlstmt.ITableSource))
|
||
}
|
||
|
||
ts.TableSources = tableSources
|
||
return ts
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitNon_ansi_join(ctx *pgparser.Non_ansi_joinContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitTable_ref(ctx *pgparser.Table_refContext) interface{} {
|
||
tableSourceBase := new(sqlstmt.TableSourceBase)
|
||
tableSourceBase.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
atomTable := new(sqlstmt.AtomTableItem)
|
||
|
||
if c := ctx.Relation_expr(); c != nil {
|
||
tableName := new(sqlstmt.TableName)
|
||
if qn := c.Qualified_name(); qn != nil {
|
||
if qc := qn.Colid(); qc != nil {
|
||
if c := qn.Indirection(); c != nil {
|
||
tableName.Owner = qc.GetText()
|
||
tableName.Identifier = sqlstmt.NewIdentifierValue(c.GetText())
|
||
} else {
|
||
tableName.Identifier = sqlstmt.NewIdentifierValue(qc.Identifier().GetText())
|
||
}
|
||
}
|
||
}
|
||
atomTable.TableName = tableName
|
||
}
|
||
if c := ctx.Opt_alias_clause(); c != nil {
|
||
if aliasC := c.Table_alias_clause(); aliasC != nil {
|
||
atomTable.Alias = aliasC.Table_alias().GetText()
|
||
}
|
||
}
|
||
|
||
tableSourceBase.TableSourceItem = atomTable
|
||
return tableSourceBase
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitOpt_target_list(ctx *pgparser.Opt_target_listContext) interface{} {
|
||
if c := ctx.Target_list(); c != nil {
|
||
return c.Accept(v)
|
||
}
|
||
|
||
ses := new(sqlstmt.SelectElements)
|
||
ses.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
return ses
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitTarget_list(ctx *pgparser.Target_listContext) interface{} {
|
||
ses := new(sqlstmt.SelectElements)
|
||
ses.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
if tecs := ctx.AllTarget_el(); tecs != nil {
|
||
eles := make([]sqlstmt.ISelectElement, 0)
|
||
for _, tec := range tecs {
|
||
eles = append(eles, tec.Accept(v).(sqlstmt.ISelectElement))
|
||
}
|
||
ses.Elements = eles
|
||
}
|
||
if len(ses.Elements) == 1 && ses.Elements[0].GetText() == "*" {
|
||
ses.Star = "*"
|
||
}
|
||
|
||
return ses
|
||
}
|
||
|
||
// Visit a parse tree produced by PostgreSQLParser#target_label.
|
||
func (v *PgsqlVisitor) VisitTarget_label(ctx *pgparser.Target_labelContext) interface{} {
|
||
sce := new(sqlstmt.SelectColumnElement)
|
||
sce.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
columnName := new(sqlstmt.ColumnName)
|
||
|
||
if c := ctx.Collabel(); c != nil {
|
||
sce.Alias = c.GetText()
|
||
}
|
||
if c := ctx.Identifier(); c != nil {
|
||
sce.Alias = c.GetText()
|
||
}
|
||
if exprCtx := ctx.A_expr(); exprCtx != nil {
|
||
columnName.Node = sqlstmt.NewNode(ctx.GetParser(), exprCtx)
|
||
if aextrCtx := exprCtx.A_expr_qual(); aextrCtx != nil {
|
||
col := aextrCtx.GetText()
|
||
ownerAndColname := strings.Split(col, ".")
|
||
if len(ownerAndColname) == 2 {
|
||
columnName.Owner = ownerAndColname[0]
|
||
columnName.Identifier = sqlstmt.NewIdentifierValue(ownerAndColname[1])
|
||
} else {
|
||
columnName.Identifier = sqlstmt.NewIdentifierValue(col)
|
||
}
|
||
}
|
||
} else {
|
||
columnName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
}
|
||
sce.ColumnName = columnName
|
||
return sce
|
||
}
|
||
|
||
// Visit a parse tree produced by PostgreSQLParser#target_star.
|
||
func (v *PgsqlVisitor) VisitTarget_star(ctx *pgparser.Target_starContext) interface{} {
|
||
sse := new(sqlstmt.SelectStarElement)
|
||
sse.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
sse.FullId = ctx.STAR().GetText()
|
||
return sse
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitAlias_clause(ctx *pgparser.Alias_clauseContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitOpt_alias_clause(ctx *pgparser.Opt_alias_clauseContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitTable_alias_clause(ctx *pgparser.Table_alias_clauseContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitFunc_alias_clause(ctx *pgparser.Func_alias_clauseContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitRelation_expr(ctx *pgparser.Relation_exprContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitRelation_expr_list(ctx *pgparser.Relation_expr_listContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitUpdatestmt(ctx *pgparser.UpdatestmtContext) interface{} {
|
||
updateStmt := new(sqlstmt.UpdateStmt)
|
||
updateStmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
updateStmt.TableSources = v.GetTableSourcesByrelation_expr_opt_alias(ctx.Relation_expr_opt_alias())
|
||
updateStmt.UpdatedElements = ctx.Set_clause_list().Accept(v).([]*sqlstmt.UpdatedElement)
|
||
if ec := ctx.Where_or_current_clause().A_expr(); ec != nil {
|
||
updateStmt.Where = ec.Accept(v).(sqlstmt.IExpr)
|
||
}
|
||
|
||
return updateStmt
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSet_clause_list(ctx *pgparser.Set_clause_listContext) interface{} {
|
||
ues := make([]*sqlstmt.UpdatedElement, 0)
|
||
aucs := ctx.AllSet_clause()
|
||
for _, auc := range aucs {
|
||
ues = append(ues, auc.Accept(v).(*sqlstmt.UpdatedElement))
|
||
}
|
||
return ues
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSet_clause(ctx *pgparser.Set_clauseContext) interface{} {
|
||
updateEle := new(sqlstmt.UpdatedElement)
|
||
updateEle.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
updateEle.ColumnName = ctx.Set_target().Accept(v).(*sqlstmt.ColumnName)
|
||
if ac := ctx.A_expr(); ac != nil {
|
||
updateEle.Value = ac.Accept(v).(sqlstmt.IExpr)
|
||
}
|
||
return updateEle
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSet_target(ctx *pgparser.Set_targetContext) interface{} {
|
||
columnName := new(sqlstmt.ColumnName)
|
||
columnName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
if ic := ctx.Opt_indirection(); ic != nil {
|
||
if ic.GetText() == "" {
|
||
columnName.Identifier = sqlstmt.NewIdentifierValue(ctx.Colid().GetText())
|
||
} else {
|
||
columnName.Owner = ctx.Colid().GetText()
|
||
columnName.Identifier = sqlstmt.NewIdentifierValue(ic.GetText())
|
||
}
|
||
} else {
|
||
columnName.Identifier = sqlstmt.NewIdentifierValue(ctx.Colid().GetText())
|
||
}
|
||
return columnName
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitSet_target_list(ctx *pgparser.Set_target_listContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitA_expr(ctx *pgparser.A_exprContext) interface{} {
|
||
expr := new(sqlstmt.Expr)
|
||
expr.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
return expr
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitA_expr_qual(ctx *pgparser.A_expr_qualContext) interface{} {
|
||
expr := new(sqlstmt.Expr)
|
||
expr.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
return expr
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitDeletestmt(ctx *pgparser.DeletestmtContext) interface{} {
|
||
deletestmt := new(sqlstmt.DeleteStmt)
|
||
deletestmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
deletestmt.TableSources = v.GetTableSourcesByrelation_expr_opt_alias(ctx.Relation_expr_opt_alias())
|
||
|
||
if ec := ctx.Where_or_current_clause().A_expr(); ec != nil {
|
||
deletestmt.Where = ec.Accept(v).(sqlstmt.IExpr)
|
||
}
|
||
return deletestmt
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitInsertstmt(ctx *pgparser.InsertstmtContext) interface{} {
|
||
insertstmt := new(sqlstmt.InsertStmt)
|
||
insertstmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
insertstmt.TableName = ctx.Insert_target().Accept(v).(*sqlstmt.TableName)
|
||
return insertstmt
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitInsert_target(ctx *pgparser.Insert_targetContext) interface{} {
|
||
tableName := new(sqlstmt.TableName)
|
||
tableName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
table := ctx.GetText()
|
||
if strings.Contains(table, ".") {
|
||
tableAndOwner := strings.Split(table, ".")
|
||
tableName.Identifier = sqlstmt.NewIdentifierValue(tableAndOwner[1])
|
||
tableName.Owner = tableAndOwner[0]
|
||
} else {
|
||
tableName.Identifier = sqlstmt.NewIdentifierValue(table)
|
||
}
|
||
return tableName
|
||
}
|
||
|
||
func (v *PgsqlVisitor) VisitInsert_rest(ctx *pgparser.Insert_restContext) interface{} {
|
||
return v.VisitChildren(ctx)
|
||
}
|
||
|
||
func (v *PgsqlVisitor) GetTableSourcesByrelation_expr_opt_alias(ctx pgparser.IRelation_expr_opt_aliasContext) *sqlstmt.TableSources {
|
||
tableSources := new(sqlstmt.TableSources)
|
||
tableSources.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
atomTable := new(sqlstmt.AtomTableItem)
|
||
atomTable.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
|
||
if c := ctx.Relation_expr(); c != nil {
|
||
tableName := new(sqlstmt.TableName)
|
||
if qn := c.Qualified_name(); qn != nil {
|
||
if qc := qn.Colid(); qc != nil {
|
||
if c := qn.Indirection(); c != nil {
|
||
tableName.Owner = qc.GetText()
|
||
tableName.Identifier = sqlstmt.NewIdentifierValue(c.GetText())
|
||
} else {
|
||
tableName.Identifier = sqlstmt.NewIdentifierValue(qc.Identifier().GetText())
|
||
}
|
||
}
|
||
}
|
||
atomTable.TableName = tableName
|
||
}
|
||
|
||
if c := ctx.Colid(); c != nil {
|
||
atomTable.Alias = c.GetText()
|
||
}
|
||
|
||
tableSourceBase := new(sqlstmt.TableSourceBase)
|
||
tableSourceBase.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
|
||
tableSourceBase.TableSourceItem = atomTable
|
||
|
||
tableSources.TableSources = []sqlstmt.ITableSource{tableSourceBase}
|
||
return tableSources
|
||
}
|