refactor: 移除antlr4减小包体积&ai助手优化

This commit is contained in:
meilin.huang
2026-05-08 20:45:13 +08:00
parent 3768cef62d
commit f23b243fc5
154 changed files with 13054 additions and 396804 deletions

View File

@@ -6,7 +6,6 @@ import (
"mayfly-go/internal/db/application/dto"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/dbm/sqlparser"
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
@@ -86,7 +85,6 @@ func createSqlExecRecord(ctx context.Context, execSqlReq *dto.DbSqlExecReq, sql
func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *dto.DbSqlExecReq) ([]*dto.DbSqlExecRes, error) {
dbConn := execSqlReq.DbConn
execSql := execSqlReq.Sql
sp := dbConn.GetDialect().GetSQLParser()
var flowProcdef *flowentity.Procdef
if execSqlReq.CheckFlow {
@@ -95,101 +93,84 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *dto.DbSqlExecRe
allExecRes := make([]*dto.DbSqlExecRes, 0)
stmts, err := sp.Parse(execSql)
// sql解析失败则使用默认方式切割
// 先使用方言切割器切割 SQL
splitter := dbConn.GetDialect().GetSQLSplitter()
var sqlList []string
err := splitter.SplitSQL(strings.NewReader(execSql), func(oneSql string) error {
sqlList = append(sqlList, oneSql)
return nil
})
if err != nil {
sqlparser.SQLSplit(strings.NewReader(execSql), ';', func(oneSql string) error {
var execRes *dto.DbSqlExecRes
var err error
dbSqlExecRecord := createSqlExecRecord(ctx, execSqlReq, oneSql)
dbSqlExecRecord.Type = entity.DbSqlExecTypeOther
sqlExec := &sqlExecParam{DbConn: dbConn, Sql: oneSql, Procdef: flowProcdef, SqlExecRecord: dbSqlExecRecord}
if isSelect(oneSql) {
execRes, err = d.doSelect(ctx, sqlExec)
} else if isUpdate(oneSql) {
execRes, err = d.doUpdate(ctx, sqlExec)
} else if isDelete(oneSql) {
execRes, err = d.doDelete(ctx, sqlExec)
} else if isInsert(oneSql) {
execRes, err = d.doInsert(ctx, sqlExec)
} else if isOtherQuery(oneSql) {
execRes, err = d.doOtherRead(ctx, sqlExec)
} else if isDDL(oneSql) {
execRes, err = d.doExecDDL(ctx, sqlExec)
} else {
execRes, err = d.doExec(ctx, dbConn, oneSql)
}
// 执行错误
if err != nil {
if execRes == nil {
execRes = &dto.DbSqlExecRes{Sql: oneSql}
}
execRes.ErrorMsg = err.Error()
} else {
d.saveSqlExecLog(ctx, dbSqlExecRecord, dbSqlExecRecord.Res)
}
allExecRes = append(allExecRes, execRes)
return nil
})
return allExecRes, nil
return nil, fmt.Errorf("SQL 切割失败: %v", err)
}
// mysql parser with语句会分解析为两条故需要特殊处理
currentWithSql := ""
for _, stmt := range stmts {
// 获取解析器
sp := dbConn.GetDialect().GetSQLParser()
// 逐条解析并执行
for _, sql := range sqlList {
var execRes *dto.DbSqlExecRes
var err error
sql := stmt.GetText()
stmt, parseErr := sp.Parse(sql)
dbSqlExecRecord := createSqlExecRecord(ctx, execSqlReq, sql)
dbSqlExecRecord.Type = entity.DbSqlExecTypeOther
sqlExec := &sqlExecParam{DbConn: dbConn, Sql: currentWithSql + sql, Procdef: flowProcdef, Stmt: stmt, SqlExecRecord: dbSqlExecRecord}
currentWithSql = ""
switch stmt.(type) {
case *sqlstmt.SimpleSelectStmt:
execRes, err = d.doSelect(ctx, sqlExec)
case *sqlstmt.UnionSelectStmt:
execRes, err = d.doSelect(ctx, sqlExec)
case *sqlstmt.OtherReadStmt:
execRes, err = d.doOtherRead(ctx, sqlExec)
case *sqlstmt.WithStmt:
currentWithSql = sql
case *sqlstmt.UpdateStmt:
execRes, err = d.doUpdate(ctx, sqlExec)
case *sqlstmt.DeleteStmt:
execRes, err = d.doDelete(ctx, sqlExec)
case *sqlstmt.InsertStmt:
execRes, err = d.doInsert(ctx, sqlExec)
case *sqlstmt.DdlStmt:
execRes, err = d.doExecDDL(ctx, sqlExec)
case *sqlstmt.CreateDatabase:
execRes, err = d.doExecDDL(ctx, sqlExec)
case *sqlstmt.CreateTable:
execRes, err = d.doExecDDL(ctx, sqlExec)
case *sqlstmt.CreateIndex:
execRes, err = d.doExecDDL(ctx, sqlExec)
case *sqlstmt.AlterDatabase:
execRes, err = d.doExecDDL(ctx, sqlExec)
case *sqlstmt.AlterTable:
execRes, err = d.doExecDDL(ctx, sqlExec)
default:
execRes, err = d.doExec(ctx, dbConn, sql)
sqlExec := &sqlExecParam{
DbConn: dbConn,
Sql: sql,
Stmt: stmt,
Procdef: flowProcdef,
SqlExecRecord: dbSqlExecRecord,
}
if currentWithSql != "" {
continue
// 优先使用 Stmt 类型判断,解析失败时使用字符串匹配兜底
if parseErr != nil || stmt == nil {
// 解析失败,使用字符串匹配兜底
if isSelect(sql) {
execRes, err = d.doSelect(ctx, sqlExec)
} else if isUpdate(sql) {
execRes, err = d.doUpdate(ctx, sqlExec)
} else if isDelete(sql) {
execRes, err = d.doDelete(ctx, sqlExec)
} else if isInsert(sql) {
execRes, err = d.doInsert(ctx, sqlExec)
} else if isOtherQuery(sql) {
execRes, err = d.doOtherRead(ctx, sqlExec)
} else if isDDL(sql) {
execRes, err = d.doExecDDL(ctx, sqlExec)
} else {
execRes, err = d.doExec(ctx, dbConn, sql)
}
} else {
// 解析成功,使用 Stmt 类型判断
switch stmt.(type) {
case *sqlstmt.WithStmt:
execRes, err = d.doSelect(ctx, sqlExec)
case *sqlstmt.SelectStmt:
execRes, err = d.doSelect(ctx, sqlExec)
case *sqlstmt.UpdateStmt:
execRes, err = d.doUpdate(ctx, sqlExec)
case *sqlstmt.DeleteStmt:
execRes, err = d.doDelete(ctx, sqlExec)
case *sqlstmt.InsertStmt:
execRes, err = d.doInsert(ctx, sqlExec)
case *sqlstmt.DdlStmt:
execRes, err = d.doExecDDL(ctx, sqlExec)
case *sqlstmt.OtherStmt:
execRes, err = d.doOtherRead(ctx, sqlExec)
default:
execRes, err = d.doExec(ctx, dbConn, sql)
}
}
// 执行错误
if err != nil {
if execRes == nil {
execRes = &dto.DbSqlExecRes{Sql: sqlExec.Sql}
execRes = &dto.DbSqlExecRes{Sql: sql}
}
execRes.ErrorMsg = err.Error()
} else {
d.saveSqlExecLog(ctx, dbSqlExecRecord, execRes.Res)
d.saveSqlExecLog(ctx, dbSqlExecRecord, dbSqlExecRecord.Res)
}
allExecRes = append(allExecRes, execRes)
}
@@ -249,7 +230,9 @@ func (d *dbSqlExecAppImpl) ExecReader(ctx context.Context, execReader *dto.SqlRe
}()
tx, _ := dbConn.Begin()
err := sqlparser.SQLSplit(execReader.Reader, ';', func(sql string) error {
// 使用方言切割器进行 SQL 切割
splitter := dbConn.GetDialect().GetSQLSplitter()
err := splitter.SplitSQL(execReader.Reader, func(sql string) error {
if executedStatements%50 == 0 {
if needSendMsg {
progressMsgEvent.Params["executedStatements"] = executedStatements
@@ -416,21 +399,14 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
tableSources := updatestmt.TableSources.TableSources
// 不支持多表更新记录旧值
if len(tableSources) != 1 {
if len(updatestmt.Tables) != 1 {
logx.ErrorContext(ctx, "update SQL - logging old values only supports single-table updates")
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
tableName := ""
tableAlias := ""
if tableSourceBase, ok := tableSources[0].(*sqlstmt.TableSourceBase); ok {
if atmoTableItem, ok := tableSourceBase.TableSourceItem.(*sqlstmt.AtomTableItem); ok {
tableName = atmoTableItem.TableName.Identifier.Value
tableAlias = atmoTableItem.Alias
}
}
tableName := updatestmt.Tables[0].Name
tableAlias := updatestmt.Tables[0].Alias
if tableName == "" {
logx.ErrorContext(ctx, "update SQL - failed to get table name")
@@ -442,7 +418,7 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
logx.ErrorContext(ctx, "update SQL - there is no where condition")
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
whereStr := updatestmt.Where.GetText()
whereStr := updatestmt.Where.Text
// 获取表主键列名,排除使用别名
primaryKey, err := dbConn.GetMetadata().GetPrimaryKey(tableName)
@@ -451,8 +427,8 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
updateColumns := collx.ArrayMap[*sqlstmt.UpdatedElement, string](updatestmt.UpdatedElements, func(ue *sqlstmt.UpdatedElement) string {
return ue.ColumnName.GetText()
updateColumns := collx.ArrayMap[sqlstmt.Assignment, string](updatestmt.Set, func(a sqlstmt.Assignment) string {
return a.Column
})
primaryKeyColumn := primaryKey
@@ -505,21 +481,14 @@ func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecPa
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
tableSources := deletestmt.TableSources.TableSources
// 不支持多表删除记录旧值
if len(tableSources) != 1 {
if len(deletestmt.Tables) != 1 {
logx.ErrorContext(ctx, "delete SQL - logging old values only supports single-table deletion")
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
tableName := ""
tableAlias := ""
if tableSourceBase, ok := tableSources[0].(*sqlstmt.TableSourceBase); ok {
if atmoTableItem, ok := tableSourceBase.TableSourceItem.(*sqlstmt.AtomTableItem); ok {
tableName = atmoTableItem.TableName.Identifier.Value
tableAlias = atmoTableItem.Alias
}
}
tableName := deletestmt.Tables[0].Name
tableAlias := deletestmt.Tables[0].Alias
if tableName == "" {
logx.ErrorContext(ctx, "delete SQL - failed to get table name")
@@ -527,13 +496,12 @@ func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecPa
}
execRecord.Table = tableName
deleteWhere := deletestmt.Where
if deleteWhere == nil {
if deletestmt.Where == nil {
logx.ErrorContext(ctx, "delete SQL - there is no where condition")
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
whereStr := deleteWhere.GetText()
whereStr := deletestmt.Where.Text
// 查询删除数据
selectSql := fmt.Sprintf("SELECT * FROM %s where %s LIMIT 200", tableName+" "+tableAlias, whereStr)
_, res, _ := dbConn.QueryContext(ctx, selectSql)
@@ -563,7 +531,7 @@ func (d *dbSqlExecAppImpl) doInsert(ctx context.Context, sqlExecParam *sqlExecPa
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
}
execRecord.Table = insertstmt.TableName.Identifier.Value
execRecord.Table = insertstmt.Table.Name
return d.doExec(ctx, sqlExecParam.DbConn, sqlExecParam.Sql)
}
@@ -603,7 +571,7 @@ func (d *dbSqlExecAppImpl) doExec(ctx context.Context, dbConn *dbi.DbConn, sql s
return &dto.DbSqlExecRes{
Columns: []*dbi.QueryColumn{
{Name: "rowsAffected", Key:"rowsAffected", Type: "number"},
{Name: "rowsAffected", Key: "rowsAffected", Type: "number"},
},
Res: res,
Sql: sql,