Files
mayfly-go/server/internal/db/dbm/sqlparser/pgsql/visitor.go

494 lines
15 KiB
Go
Raw Normal View History

package pgsql
import (
"strings"
pgparser "mayfly-go/internal/db/dbm/sqlparser/pgsql/antlr4"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"github.com/may-fly/cast"
)
// PgsqlVisitor walks an ANTLR PostgreSQL parse tree and converts the
// statements it recognizes into sqlstmt values. Any Visit method that is not
// defined on PgsqlVisitor is promoted from the embedded
// *pgparser.BasePostgreSQLParserVisitor.
type PgsqlVisitor struct {
*pgparser.BasePostgreSQLParserVisitor
}
// VisitRoot is the entry point for a parsed script. It forwards to the
// statement block when one exists; otherwise it wraps the root context in a
// generic node.
func (v *PgsqlVisitor) VisitRoot(ctx *pgparser.RootContext) interface{} {
	block := ctx.Stmtblock()
	if block == nil {
		return sqlstmt.NewNode(ctx.GetParser(), ctx)
	}
	return block.Accept(v)
}
// VisitPlsqlroot has no custom conversion and simply delegates to
// VisitChildren (promoted from the embedded base visitor unless PgsqlVisitor
// defines its own elsewhere in the package).
// NOTE(review): the ANTLR Go runtime's BaseParseTreeVisitor.VisitChildren is
// believed to return nil without walking the children — confirm callers do
// not expect a converted value here.
func (v *PgsqlVisitor) VisitPlsqlroot(ctx *pgparser.PlsqlrootContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitStmtblock forwards to the multi-statement list when present and
// otherwise returns a generic node for the block itself.
func (v *PgsqlVisitor) VisitStmtblock(ctx *pgparser.StmtblockContext) interface{} {
	multi := ctx.Stmtmulti()
	if multi == nil {
		return sqlstmt.NewNode(ctx.GetParser(), ctx)
	}
	return multi.Accept(v)
}
// VisitStmtmulti converts each statement of a multi-statement list and
// returns them, in source order, as a []sqlstmt.Stmt. Every child's visit
// result must be a sqlstmt.Stmt; the type assertion panics otherwise.
func (v *PgsqlVisitor) VisitStmtmulti(ctx *pgparser.StmtmultiContext) interface{} {
	children := ctx.AllStmt()
	result := make([]sqlstmt.Stmt, 0, len(children))
	for i := range children {
		result = append(result, children[i].Accept(v).(sqlstmt.Stmt))
	}
	return result
}
// VisitStmt classifies a single statement.
//
// SELECT, UPDATE, DELETE and INSERT are converted in detail by their own
// visitors. A set of DDL and read-only statements is wrapped in the matching
// sqlstmt type without further inspection. Anything else falls back to a
// generic node.
func (v *PgsqlVisitor) VisitStmt(ctx *pgparser.StmtContext) interface{} {
	if c := ctx.Selectstmt(); c != nil {
		return c.Accept(v)
	}
	if c := ctx.Updatestmt(); c != nil {
		return c.Accept(v)
	}
	if c := ctx.Deletestmt(); c != nil {
		return c.Accept(v)
	}
	if c := ctx.Insertstmt(); c != nil {
		return c.Accept(v)
	}

	if c := ctx.Createdbstmt(); c != nil {
		cds := new(sqlstmt.CreateDatabase)
		cds.Node = sqlstmt.NewNode(c.GetParser(), c)
		return cds
	}
	// CREATE TABLE is parsed as createstmt. Previously only
	// createtablespacestmt was mapped, so CREATE TABLE fell through to a
	// generic node.
	if c := ctx.Createstmt(); c != nil {
		ct := new(sqlstmt.CreateTable)
		ct.Node = sqlstmt.NewNode(c.GetParser(), c)
		return ct
	}
	// NOTE(review): CREATE TABLESPACE is classified as CreateTable; this
	// mapping is kept from the original code — confirm it is intended.
	if c := ctx.Createtablespacestmt(); c != nil {
		ct := new(sqlstmt.CreateTable)
		ct.Node = sqlstmt.NewNode(c.GetParser(), c)
		return ct
	}
	if c := ctx.Altertablestmt(); c != nil {
		at := new(sqlstmt.AlterTable)
		at.Node = sqlstmt.NewNode(c.GetParser(), c)
		return at
	}
	if c := ctx.Dropdbstmt(); c != nil {
		dd := new(sqlstmt.DropDatabase)
		dd.Node = sqlstmt.NewNode(c.GetParser(), c)
		return dd
	}
	// DROP TABLE is parsed as dropstmt, which presumably also covers other
	// "DROP <object>" forms. They are classified as DropTable so they are no
	// longer treated as generic statements.
	if c := ctx.Dropstmt(); c != nil {
		dt := new(sqlstmt.DropTable)
		dt.Node = sqlstmt.NewNode(c.GetParser(), c)
		return dt
	}
	// NOTE(review): DROP TABLESPACE is classified as DropTable (original
	// mapping, kept unchanged).
	if c := ctx.Droptablespacestmt(); c != nil {
		dt := new(sqlstmt.DropTable)
		dt.Node = sqlstmt.NewNode(c.GetParser(), c)
		return dt
	}

	// EXPLAIN and SHOW are read-only statements without a detailed model.
	if c := ctx.Explainstmt(); c != nil {
		otherRead := new(sqlstmt.OtherReadStmt)
		otherRead.Node = sqlstmt.NewNode(c.GetParser(), c)
		return otherRead
	}
	if c := ctx.Variableshowstmt(); c != nil {
		otherRead := new(sqlstmt.OtherReadStmt)
		otherRead.Node = sqlstmt.NewNode(c.GetParser(), c)
		return otherRead
	}
	return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
// VisitSelectstmt converts a SELECT. An un-parenthesised select is handed to
// VisitSelect_no_parens. Any other form (presumably a parenthesised SELECT)
// only gets a bare SelectStmt carrying the raw node.
func (v *PgsqlVisitor) VisitSelectstmt(ctx *pgparser.SelectstmtContext) interface{} {
	noParens := ctx.Select_no_parens()
	if noParens == nil {
		stmt := new(sqlstmt.SelectStmt)
		stmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
		return stmt
	}
	return noParens.Accept(v)
}
// VisitSelect_with_parens has no custom conversion and delegates to
// VisitChildren.
// NOTE(review): if VisitChildren is the ANTLR Go base implementation, this
// returns nil — confirm no caller type-asserts the result.
func (v *PgsqlVisitor) VisitSelect_with_parens(ctx *pgparser.Select_with_parensContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitSelect_no_parens converts an un-parenthesised SELECT.
//
// Shape of the result:
//   - a select clause with exactly one simple_select_intersect becomes a
//     SimpleSelectStmt;
//   - anything else becomes a UnionSelectStmt.
//
// A LIMIT/OFFSET found via Select_limit or Opt_select_limit is attached to the
// result. Without a select clause, a generic node is returned.
func (v *PgsqlVisitor) VisitSelect_no_parens(ctx *pgparser.Select_no_parensContext) interface{} {
	selectClause := ctx.Select_clause()
	if selectClause == nil {
		return sqlstmt.NewNode(ctx.GetParser(), ctx)
	}

	// Comma-ok assertions: VisitOpt_select_limit returns an untyped nil when the
	// optional limit is empty, and a plain .(*sqlstmt.Limit) on that panics.
	// A non-nil Opt_select_limit result overrides a Select_limit result, as before.
	var limit *sqlstmt.Limit
	if lc := ctx.Select_limit(); lc != nil {
		if l, ok := lc.Accept(v).(*sqlstmt.Limit); ok {
			limit = l
		}
	}
	if lc := ctx.Opt_select_limit(); lc != nil {
		if l, ok := lc.Accept(v).(*sqlstmt.Limit); ok {
			limit = l
		}
	}

	// Simple query: a single query specification.
	if len(selectClause.AllSimple_select_intersect()) == 1 {
		sss := new(sqlstmt.SimpleSelectStmt)
		sss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
		sss.QuerySpecification = selectClause.Accept(v).([]*sqlstmt.QuerySpecification)[0]
		sss.QuerySpecification.Limit = limit
		return sss
	}

	uss := new(sqlstmt.UnionSelectStmt)
	uss.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
	// TODO: populate the union's member queries. For now only the text of the
	// last UNION keyword is recorded.
	if unions := selectClause.AllUNION(); len(unions) > 0 {
		uss.UnionType = unions[len(unions)-1].GetText()
	}
	uss.Limit = limit
	return uss
}
// VisitSelect_clause converts every simple_select_intersect operand and
// returns them, in order, as a []*sqlstmt.QuerySpecification. Each operand
// must convert to a *sqlstmt.QuerySpecification; the assertion panics otherwise.
func (v *PgsqlVisitor) VisitSelect_clause(ctx *pgparser.Select_clauseContext) interface{} {
	operands := ctx.AllSimple_select_intersect()
	specs := make([]*sqlstmt.QuerySpecification, 0, len(operands))
	for i := range operands {
		specs = append(specs, operands[i].Accept(v).(*sqlstmt.QuerySpecification))
	}
	return specs
}
// VisitSimple_select_intersect converts only the first simple_select_pramary
// operand; INTERSECT is not supported yet. If there is no operand, a generic
// node is returned.
func (v *PgsqlVisitor) VisitSimple_select_intersect(ctx *pgparser.Simple_select_intersectContext) interface{} {
	// Check the length, not nil-ness: an empty but non-nil slice would make
	// primaries[0] panic.
	if primaries := ctx.AllSimple_select_pramary(); len(primaries) > 0 {
		return primaries[0].Accept(v)
	}
	return sqlstmt.NewNode(ctx.GetParser(), ctx)
}
// VisitSimple_select_pramary builds a QuerySpecification from a primary
// SELECT. It fills in:
//   - From, from the FROM clause;
//   - SelectElements, from the target list (Target_list wins over
//     Opt_target_list if both are present);
//   - Where, from the WHERE expression.
//
// Limit is set by the caller (VisitSelect_no_parens).
func (v *PgsqlVisitor) VisitSimple_select_pramary(ctx *pgparser.Simple_select_pramaryContext) interface{} {
	qs := new(sqlstmt.QuerySpecification)
	qs.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
	if c := ctx.From_clause(); c != nil {
		qs.From = c.Accept(v).(*sqlstmt.TableSources)
	}
	if c := ctx.Opt_target_list(); c != nil {
		qs.SelectElements = c.Accept(v).(*sqlstmt.SelectElements)
	}
	if c := ctx.Target_list(); c != nil {
		qs.SelectElements = c.Accept(v).(*sqlstmt.SelectElements)
	}
	// Defensive: only visit the WHERE expression if the clause really has one,
	// instead of calling Accept on a possibly-nil A_expr.
	if wc := ctx.Where_clause(); wc != nil {
		if expr := wc.A_expr(); expr != nil {
			qs.Where = expr.Accept(v).(sqlstmt.IExpr)
		}
	}
	return qs
}
// VisitSelect_limit builds a Limit. RowCount comes from the limit clause's
// value and Offset from the offset clause's value. Both are converted with
// cast.ToInt from the raw token text. A field is left untouched when its
// clause or value is absent.
func (v *PgsqlVisitor) VisitSelect_limit(ctx *pgparser.Select_limitContext) interface{} {
	limit := new(sqlstmt.Limit)
	limit.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)

	limitClause := ctx.Limit_clause()
	if limitClause != nil && limitClause.Select_limit_value() != nil {
		limit.RowCount = cast.ToInt(limitClause.Select_limit_value().GetText())
	}

	offsetClause := ctx.Offset_clause()
	if offsetClause != nil && offsetClause.Select_offset_value() != nil {
		limit.Offset = cast.ToInt(offsetClause.Select_offset_value().GetText())
	}
	return limit
}
// VisitOpt_select_limit forwards to the inner select_limit when present. When
// the optional limit is empty it returns an untyped nil, so callers must use
// a comma-ok assertion on the result.
func (v *PgsqlVisitor) VisitOpt_select_limit(ctx *pgparser.Opt_select_limitContext) interface{} {
	inner := ctx.Select_limit()
	if inner == nil {
		return nil
	}
	return inner.Accept(v)
}
// VisitFrom_clause forwards to the FROM list when present. Otherwise it
// returns a generic node for the clause.
func (v *PgsqlVisitor) VisitFrom_clause(ctx *pgparser.From_clauseContext) interface{} {
	list := ctx.From_list()
	if list == nil {
		return sqlstmt.NewNode(ctx.GetParser(), ctx)
	}
	return list.Accept(v)
}
// VisitFrom_list collects one ITableSource per table_ref, in source order,
// into a TableSources value whose node spans the whole FROM list.
func (v *PgsqlVisitor) VisitFrom_list(ctx *pgparser.From_listContext) interface{} {
	result := new(sqlstmt.TableSources)
	result.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)

	refs := ctx.AllTable_ref()
	sources := make([]sqlstmt.ITableSource, 0, len(refs))
	for i := range refs {
		sources = append(sources, refs[i].Accept(v).(sqlstmt.ITableSource))
	}
	result.TableSources = sources
	return result
}
// VisitNon_ansi_join has no custom conversion and delegates to VisitChildren.
// NOTE(review): VisitFrom_list only visits table_ref children, so it is
// unclear whether this method is reached — confirm before relying on it.
func (v *PgsqlVisitor) VisitNon_ansi_join(ctx *pgparser.Non_ansi_joinContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitTable_ref converts a FROM-list entry into a TableSourceBase wrapping an
// AtomTableItem. Only the relation_expr (optionally schema-qualified name) and
// the opt_alias_clause children are read. Any other form of table_ref yields
// an AtomTableItem with no TableName.
func (v *PgsqlVisitor) VisitTable_ref(ctx *pgparser.Table_refContext) interface{} {
	tableSourceBase := new(sqlstmt.TableSourceBase)
	tableSourceBase.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)

	atomTable := new(sqlstmt.AtomTableItem)
	// Set the node as GetTableSourcesByrelation_expr_opt_alias does.
	atomTable.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)

	if re := ctx.Relation_expr(); re != nil {
		tableName := new(sqlstmt.TableName)
		tableName.Node = sqlstmt.NewNode(ctx.GetParser(), re)
		if qn := re.Qualified_name(); qn != nil {
			if colid := qn.Colid(); colid != nil {
				if ind := qn.Indirection(); ind != nil {
					// schema.table. The indirection text begins with the "."
					// separator, which is not part of the table name.
					tableName.Owner = colid.GetText()
					tableName.Identifier = sqlstmt.NewIdentifierValue(strings.TrimPrefix(ind.GetText(), "."))
				} else {
					// Use the whole colid text. colid can match a non-reserved
					// keyword instead of an identifier, in which case
					// colid.Identifier() is nil and calling GetText on it panics.
					tableName.Identifier = sqlstmt.NewIdentifierValue(colid.GetText())
				}
			}
		}
		atomTable.TableName = tableName
	}

	if c := ctx.Opt_alias_clause(); c != nil {
		if aliasC := c.Table_alias_clause(); aliasC != nil {
			atomTable.Alias = aliasC.Table_alias().GetText()
		}
	}

	tableSourceBase.TableSourceItem = atomTable
	return tableSourceBase
}
// VisitOpt_target_list forwards to the target list when present. Otherwise it
// returns an empty SelectElements carrying this context's node.
func (v *PgsqlVisitor) VisitOpt_target_list(ctx *pgparser.Opt_target_listContext) interface{} {
	list := ctx.Target_list()
	if list != nil {
		return list.Accept(v)
	}
	empty := new(sqlstmt.SelectElements)
	empty.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
	return empty
}
// VisitTarget_list converts each select-list entry into an ISelectElement. A
// list consisting of exactly one element whose text is "*" is additionally
// flagged via Star.
func (v *PgsqlVisitor) VisitTarget_list(ctx *pgparser.Target_listContext) interface{} {
	selectElements := new(sqlstmt.SelectElements)
	selectElements.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)

	targets := ctx.AllTarget_el()
	if targets != nil {
		converted := make([]sqlstmt.ISelectElement, 0, len(targets))
		for _, target := range targets {
			converted = append(converted, target.Accept(v).(sqlstmt.ISelectElement))
		}
		selectElements.Elements = converted
	}

	isLoneStar := len(selectElements.Elements) == 1 && selectElements.Elements[0].GetText() == "*"
	if isLoneStar {
		selectElements.Star = "*"
	}
	return selectElements
}
// VisitTarget_label converts a labelled select-list entry
// (PostgreSQLParser#target_label) into a SelectColumnElement.
//
// Alias handling: an alias after AS (collabel) or a bare alias (identifier) is
// stored as Alias.
//
// Column name handling: when the expression text contains exactly one ".", it
// is split into Owner and column Identifier; otherwise the whole text is the
// Identifier.
func (v *PgsqlVisitor) VisitTarget_label(ctx *pgparser.Target_labelContext) interface{} {
	element := new(sqlstmt.SelectColumnElement)
	element.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)

	if label := ctx.Collabel(); label != nil {
		element.Alias = label.GetText()
	}
	if ident := ctx.Identifier(); ident != nil {
		element.Alias = ident.GetText()
	}

	column := new(sqlstmt.ColumnName)
	expr := ctx.A_expr()
	if expr == nil {
		column.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
		element.ColumnName = column
		return element
	}

	column.Node = sqlstmt.NewNode(ctx.GetParser(), expr)
	if qual := expr.A_expr_qual(); qual != nil {
		text := qual.GetText()
		if strings.Count(text, ".") == 1 {
			dot := strings.Index(text, ".")
			column.Owner = text[:dot]
			column.Identifier = sqlstmt.NewIdentifierValue(text[dot+1:])
		} else {
			column.Identifier = sqlstmt.NewIdentifierValue(text)
		}
	}
	element.ColumnName = column
	return element
}
// VisitTarget_star converts a "*" select-list entry
// (PostgreSQLParser#target_star) into a SelectStarElement whose FullId is the
// STAR token's text.
func (v *PgsqlVisitor) VisitTarget_star(ctx *pgparser.Target_starContext) interface{} {
	star := &sqlstmt.SelectStarElement{}
	star.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
	star.FullId = ctx.STAR().GetText()
	return star
}
// VisitAlias_clause has no custom conversion and delegates to VisitChildren.
func (v *PgsqlVisitor) VisitAlias_clause(ctx *pgparser.Alias_clauseContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitOpt_alias_clause delegates to VisitChildren. VisitTable_ref reads the
// alias straight from this context instead of visiting it.
func (v *PgsqlVisitor) VisitOpt_alias_clause(ctx *pgparser.Opt_alias_clauseContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitTable_alias_clause delegates to VisitChildren. VisitTable_ref reads
// Table_alias() from this context directly.
func (v *PgsqlVisitor) VisitTable_alias_clause(ctx *pgparser.Table_alias_clauseContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitFunc_alias_clause has no custom conversion and delegates to
// VisitChildren.
func (v *PgsqlVisitor) VisitFunc_alias_clause(ctx *pgparser.Func_alias_clauseContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitRelation_expr delegates to VisitChildren. VisitTable_ref and
// GetTableSourcesByrelation_expr_opt_alias extract the table name from the
// relation_expr context directly rather than through this method.
func (v *PgsqlVisitor) VisitRelation_expr(ctx *pgparser.Relation_exprContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitRelation_expr_list has no custom conversion and delegates to
// VisitChildren.
func (v *PgsqlVisitor) VisitRelation_expr_list(ctx *pgparser.Relation_expr_listContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitUpdatestmt converts an UPDATE. It fills in:
//   - TableSources, from the target relation and its alias;
//   - UpdatedElements, from the SET list;
//   - Where, from the WHERE expression, when there is one.
func (v *PgsqlVisitor) VisitUpdatestmt(ctx *pgparser.UpdatestmtContext) interface{} {
	updateStmt := new(sqlstmt.UpdateStmt)
	updateStmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
	updateStmt.TableSources = v.GetTableSourcesByrelation_expr_opt_alias(ctx.Relation_expr_opt_alias())
	updateStmt.UpdatedElements = ctx.Set_clause_list().Accept(v).([]*sqlstmt.UpdatedElement)

	// An UPDATE without WHERE may have no where_or_current_clause at all.
	// Calling A_expr() on that nil clause would panic. "WHERE CURRENT OF"
	// carries no a_expr. In both cases Where stays nil.
	if wc := ctx.Where_or_current_clause(); wc != nil {
		if ec := wc.A_expr(); ec != nil {
			updateStmt.Where = ec.Accept(v).(sqlstmt.IExpr)
		}
	}
	return updateStmt
}
// VisitSet_clause_list converts each SET assignment, in order, and returns
// them as a []*sqlstmt.UpdatedElement.
func (v *PgsqlVisitor) VisitSet_clause_list(ctx *pgparser.Set_clause_listContext) interface{} {
	clauses := ctx.AllSet_clause()
	elements := make([]*sqlstmt.UpdatedElement, 0, len(clauses))
	for i := range clauses {
		elements = append(elements, clauses[i].Accept(v).(*sqlstmt.UpdatedElement))
	}
	return elements
}
// VisitSet_clause converts one SET assignment into an UpdatedElement, with
// the assigned column in ColumnName and the right-hand expression in Value.
func (v *PgsqlVisitor) VisitSet_clause(ctx *pgparser.Set_clauseContext) interface{} {
	updateEle := new(sqlstmt.UpdatedElement)
	updateEle.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)

	// Only the single-column form "col = expr" has a set_target. The
	// parenthesised multi-column form presumably carries a set_target_list
	// instead. Calling Accept on the nil set_target panicked, so in that case
	// ColumnName is left nil.
	// NOTE(review): consumers must tolerate a nil ColumnName.
	if st := ctx.Set_target(); st != nil {
		updateEle.ColumnName = st.Accept(v).(*sqlstmt.ColumnName)
	}
	if ac := ctx.A_expr(); ac != nil {
		updateEle.Value = ac.Accept(v).(sqlstmt.IExpr)
	}
	return updateEle
}
// VisitSet_target converts the left-hand side of a SET assignment into a
// ColumnName. A bare colid becomes the Identifier. With a non-empty
// indirection, the colid becomes Owner and the raw indirection text becomes
// the Identifier.
func (v *PgsqlVisitor) VisitSet_target(ctx *pgparser.Set_targetContext) interface{} {
	columnName := new(sqlstmt.ColumnName)
	columnName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)

	indirection := ""
	if ind := ctx.Opt_indirection(); ind != nil {
		indirection = ind.GetText()
	}
	base := ctx.Colid().GetText()

	if indirection == "" {
		columnName.Identifier = sqlstmt.NewIdentifierValue(base)
		return columnName
	}
	columnName.Owner = base
	columnName.Identifier = sqlstmt.NewIdentifierValue(indirection)
	return columnName
}
// VisitSet_target_list has no custom conversion and delegates to
// VisitChildren.
// NOTE(review): if VisitChildren is the ANTLR Go base implementation this
// returns nil, so the multi-column SET form is not modelled — confirm.
func (v *PgsqlVisitor) VisitSet_target_list(ctx *pgparser.Set_target_listContext) interface{} {
return v.VisitChildren(ctx)
}
// VisitA_expr wraps an expression as an opaque sqlstmt.Expr. Only the node
// (and therefore its source text) is kept; the expression is not analysed.
func (v *PgsqlVisitor) VisitA_expr(ctx *pgparser.A_exprContext) interface{} {
	e := &sqlstmt.Expr{}
	e.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
	return e
}
// VisitA_expr_qual wraps a qualified expression as an opaque sqlstmt.Expr
// that only carries its node.
func (v *PgsqlVisitor) VisitA_expr_qual(ctx *pgparser.A_expr_qualContext) interface{} {
	e := &sqlstmt.Expr{}
	e.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
	return e
}
// VisitDeletestmt converts a DELETE. The target relation and its alias go
// into TableSources, and the WHERE expression, when there is one, into Where.
func (v *PgsqlVisitor) VisitDeletestmt(ctx *pgparser.DeletestmtContext) interface{} {
	deletestmt := new(sqlstmt.DeleteStmt)
	deletestmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
	deletestmt.TableSources = v.GetTableSourcesByrelation_expr_opt_alias(ctx.Relation_expr_opt_alias())

	// A DELETE without WHERE may have no where_or_current_clause. Calling
	// A_expr() on that nil clause would panic. "WHERE CURRENT OF" has no
	// a_expr. In both cases Where stays nil.
	if wc := ctx.Where_or_current_clause(); wc != nil {
		if ec := wc.A_expr(); ec != nil {
			deletestmt.Where = ec.Accept(v).(sqlstmt.IExpr)
		}
	}
	return deletestmt
}
// VisitInsertstmt converts an INSERT. Only the target table is extracted; the
// inserted columns and values are not modelled.
func (v *PgsqlVisitor) VisitInsertstmt(ctx *pgparser.InsertstmtContext) interface{} {
	stmt := new(sqlstmt.InsertStmt)
	stmt.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)
	target := ctx.Insert_target().Accept(v).(*sqlstmt.TableName)
	stmt.TableName = target
	return stmt
}
// VisitInsert_target extracts the table an INSERT writes to. For a
// dot-qualified name, the last segment is the table Identifier and the
// segment before it is the Owner (schema). An "AS alias" suffix on the target
// is not part of the name.
func (v *PgsqlVisitor) VisitInsert_target(ctx *pgparser.Insert_targetContext) interface{} {
	tableName := new(sqlstmt.TableName)
	tableName.Node = sqlstmt.NewNode(ctx.GetParser(), ctx)

	// ctx.GetText() concatenates every token without separators, so
	// "INSERT INTO t AS x" produced "tASx". Read only the qualified name, and
	// fall back to the whole text if it is unexpectedly missing.
	table := ctx.GetText()
	if qn := ctx.Qualified_name(); qn != nil {
		table = qn.GetText()
	}

	// NOTE(review): a quoted identifier containing "." is still split here.
	parts := strings.Split(table, ".")
	if n := len(parts); n >= 2 {
		tableName.Owner = parts[n-2]
		tableName.Identifier = sqlstmt.NewIdentifierValue(parts[n-1])
	} else {
		tableName.Identifier = sqlstmt.NewIdentifierValue(table)
	}
	return tableName
}
// VisitInsert_rest delegates to VisitChildren. VisitInsertstmt does not visit
// insert_rest; only the target table of an INSERT is extracted.
func (v *PgsqlVisitor) VisitInsert_rest(ctx *pgparser.Insert_restContext) interface{} {
return v.VisitChildren(ctx)
}
// GetTableSourcesByrelation_expr_opt_alias builds the single-entry
// TableSources used by UPDATE and DELETE from a relation_expr_opt_alias. The
// table name (optionally schema-qualified) comes from the relation_expr and
// the alias from the optional colid.
func (v *PgsqlVisitor) GetTableSourcesByrelation_expr_opt_alias(ctx pgparser.IRelation_expr_opt_aliasContext) *sqlstmt.TableSources {
	parser := ctx.GetParser()

	atomTable := new(sqlstmt.AtomTableItem)
	atomTable.Node = sqlstmt.NewNode(parser, ctx)

	if re := ctx.Relation_expr(); re != nil {
		tableName := new(sqlstmt.TableName)
		tableName.Node = sqlstmt.NewNode(parser, re)
		if qn := re.Qualified_name(); qn != nil {
			if colid := qn.Colid(); colid != nil {
				if ind := qn.Indirection(); ind != nil {
					// schema.table. The indirection text begins with the "."
					// separator, which is not part of the table name.
					tableName.Owner = colid.GetText()
					tableName.Identifier = sqlstmt.NewIdentifierValue(strings.TrimPrefix(ind.GetText(), "."))
				} else {
					// Use the whole colid text. colid can match a non-reserved
					// keyword instead of an identifier, in which case
					// colid.Identifier() is nil and calling GetText on it panics.
					tableName.Identifier = sqlstmt.NewIdentifierValue(colid.GetText())
				}
			}
		}
		atomTable.TableName = tableName
	}
	if alias := ctx.Colid(); alias != nil {
		atomTable.Alias = alias.GetText()
	}

	tableSourceBase := new(sqlstmt.TableSourceBase)
	tableSourceBase.Node = sqlstmt.NewNode(parser, ctx)
	tableSourceBase.TableSourceItem = atomTable

	tableSources := new(sqlstmt.TableSources)
	tableSources.Node = sqlstmt.NewNode(parser, ctx)
	tableSources.TableSources = []sqlstmt.ITableSource{tableSourceBase}
	return tableSources
}