mirror of
https://gitee.com/dromara/mayfly-go
synced 2026-05-18 00:45:37 +08:00
refactor: 移除antlr4减小包体积&ai助手优化
This commit is contained in:
@@ -6,7 +6,6 @@ import (
|
||||
"mayfly-go/internal/db/application/dto"
|
||||
"mayfly-go/internal/db/config"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/dbm/sqlparser"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
@@ -86,7 +85,6 @@ func createSqlExecRecord(ctx context.Context, execSqlReq *dto.DbSqlExecReq, sql
|
||||
func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *dto.DbSqlExecReq) ([]*dto.DbSqlExecRes, error) {
|
||||
dbConn := execSqlReq.DbConn
|
||||
execSql := execSqlReq.Sql
|
||||
sp := dbConn.GetDialect().GetSQLParser()
|
||||
|
||||
var flowProcdef *flowentity.Procdef
|
||||
if execSqlReq.CheckFlow {
|
||||
@@ -95,101 +93,84 @@ func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *dto.DbSqlExecRe
|
||||
|
||||
allExecRes := make([]*dto.DbSqlExecRes, 0)
|
||||
|
||||
stmts, err := sp.Parse(execSql)
|
||||
// sql解析失败,则使用默认方式切割
|
||||
// 先使用方言切割器切割 SQL
|
||||
splitter := dbConn.GetDialect().GetSQLSplitter()
|
||||
var sqlList []string
|
||||
err := splitter.SplitSQL(strings.NewReader(execSql), func(oneSql string) error {
|
||||
sqlList = append(sqlList, oneSql)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
sqlparser.SQLSplit(strings.NewReader(execSql), ';', func(oneSql string) error {
|
||||
var execRes *dto.DbSqlExecRes
|
||||
var err error
|
||||
|
||||
dbSqlExecRecord := createSqlExecRecord(ctx, execSqlReq, oneSql)
|
||||
dbSqlExecRecord.Type = entity.DbSqlExecTypeOther
|
||||
sqlExec := &sqlExecParam{DbConn: dbConn, Sql: oneSql, Procdef: flowProcdef, SqlExecRecord: dbSqlExecRecord}
|
||||
|
||||
if isSelect(oneSql) {
|
||||
execRes, err = d.doSelect(ctx, sqlExec)
|
||||
} else if isUpdate(oneSql) {
|
||||
execRes, err = d.doUpdate(ctx, sqlExec)
|
||||
} else if isDelete(oneSql) {
|
||||
execRes, err = d.doDelete(ctx, sqlExec)
|
||||
} else if isInsert(oneSql) {
|
||||
execRes, err = d.doInsert(ctx, sqlExec)
|
||||
} else if isOtherQuery(oneSql) {
|
||||
execRes, err = d.doOtherRead(ctx, sqlExec)
|
||||
} else if isDDL(oneSql) {
|
||||
execRes, err = d.doExecDDL(ctx, sqlExec)
|
||||
} else {
|
||||
execRes, err = d.doExec(ctx, dbConn, oneSql)
|
||||
}
|
||||
// 执行错误
|
||||
if err != nil {
|
||||
if execRes == nil {
|
||||
execRes = &dto.DbSqlExecRes{Sql: oneSql}
|
||||
}
|
||||
execRes.ErrorMsg = err.Error()
|
||||
} else {
|
||||
d.saveSqlExecLog(ctx, dbSqlExecRecord, dbSqlExecRecord.Res)
|
||||
}
|
||||
allExecRes = append(allExecRes, execRes)
|
||||
return nil
|
||||
})
|
||||
return allExecRes, nil
|
||||
return nil, fmt.Errorf("SQL 切割失败: %v", err)
|
||||
}
|
||||
|
||||
// mysql parser with语句会分解析为两条,故需要特殊处理
|
||||
currentWithSql := ""
|
||||
for _, stmt := range stmts {
|
||||
// 获取解析器
|
||||
sp := dbConn.GetDialect().GetSQLParser()
|
||||
|
||||
// 逐条解析并执行
|
||||
for _, sql := range sqlList {
|
||||
var execRes *dto.DbSqlExecRes
|
||||
var err error
|
||||
|
||||
sql := stmt.GetText()
|
||||
stmt, parseErr := sp.Parse(sql)
|
||||
dbSqlExecRecord := createSqlExecRecord(ctx, execSqlReq, sql)
|
||||
dbSqlExecRecord.Type = entity.DbSqlExecTypeOther
|
||||
sqlExec := &sqlExecParam{DbConn: dbConn, Sql: currentWithSql + sql, Procdef: flowProcdef, Stmt: stmt, SqlExecRecord: dbSqlExecRecord}
|
||||
currentWithSql = ""
|
||||
|
||||
switch stmt.(type) {
|
||||
case *sqlstmt.SimpleSelectStmt:
|
||||
execRes, err = d.doSelect(ctx, sqlExec)
|
||||
case *sqlstmt.UnionSelectStmt:
|
||||
execRes, err = d.doSelect(ctx, sqlExec)
|
||||
case *sqlstmt.OtherReadStmt:
|
||||
execRes, err = d.doOtherRead(ctx, sqlExec)
|
||||
case *sqlstmt.WithStmt:
|
||||
currentWithSql = sql
|
||||
case *sqlstmt.UpdateStmt:
|
||||
execRes, err = d.doUpdate(ctx, sqlExec)
|
||||
case *sqlstmt.DeleteStmt:
|
||||
execRes, err = d.doDelete(ctx, sqlExec)
|
||||
case *sqlstmt.InsertStmt:
|
||||
execRes, err = d.doInsert(ctx, sqlExec)
|
||||
case *sqlstmt.DdlStmt:
|
||||
execRes, err = d.doExecDDL(ctx, sqlExec)
|
||||
case *sqlstmt.CreateDatabase:
|
||||
execRes, err = d.doExecDDL(ctx, sqlExec)
|
||||
case *sqlstmt.CreateTable:
|
||||
execRes, err = d.doExecDDL(ctx, sqlExec)
|
||||
case *sqlstmt.CreateIndex:
|
||||
execRes, err = d.doExecDDL(ctx, sqlExec)
|
||||
case *sqlstmt.AlterDatabase:
|
||||
execRes, err = d.doExecDDL(ctx, sqlExec)
|
||||
case *sqlstmt.AlterTable:
|
||||
execRes, err = d.doExecDDL(ctx, sqlExec)
|
||||
default:
|
||||
execRes, err = d.doExec(ctx, dbConn, sql)
|
||||
sqlExec := &sqlExecParam{
|
||||
DbConn: dbConn,
|
||||
Sql: sql,
|
||||
Stmt: stmt,
|
||||
Procdef: flowProcdef,
|
||||
SqlExecRecord: dbSqlExecRecord,
|
||||
}
|
||||
|
||||
if currentWithSql != "" {
|
||||
continue
|
||||
// 优先使用 Stmt 类型判断,解析失败时使用字符串匹配兜底
|
||||
if parseErr != nil || stmt == nil {
|
||||
// 解析失败,使用字符串匹配兜底
|
||||
if isSelect(sql) {
|
||||
execRes, err = d.doSelect(ctx, sqlExec)
|
||||
} else if isUpdate(sql) {
|
||||
execRes, err = d.doUpdate(ctx, sqlExec)
|
||||
} else if isDelete(sql) {
|
||||
execRes, err = d.doDelete(ctx, sqlExec)
|
||||
} else if isInsert(sql) {
|
||||
execRes, err = d.doInsert(ctx, sqlExec)
|
||||
} else if isOtherQuery(sql) {
|
||||
execRes, err = d.doOtherRead(ctx, sqlExec)
|
||||
} else if isDDL(sql) {
|
||||
execRes, err = d.doExecDDL(ctx, sqlExec)
|
||||
} else {
|
||||
execRes, err = d.doExec(ctx, dbConn, sql)
|
||||
}
|
||||
} else {
|
||||
// 解析成功,使用 Stmt 类型判断
|
||||
switch stmt.(type) {
|
||||
case *sqlstmt.WithStmt:
|
||||
execRes, err = d.doSelect(ctx, sqlExec)
|
||||
case *sqlstmt.SelectStmt:
|
||||
execRes, err = d.doSelect(ctx, sqlExec)
|
||||
case *sqlstmt.UpdateStmt:
|
||||
execRes, err = d.doUpdate(ctx, sqlExec)
|
||||
case *sqlstmt.DeleteStmt:
|
||||
execRes, err = d.doDelete(ctx, sqlExec)
|
||||
case *sqlstmt.InsertStmt:
|
||||
execRes, err = d.doInsert(ctx, sqlExec)
|
||||
case *sqlstmt.DdlStmt:
|
||||
execRes, err = d.doExecDDL(ctx, sqlExec)
|
||||
case *sqlstmt.OtherStmt:
|
||||
execRes, err = d.doOtherRead(ctx, sqlExec)
|
||||
default:
|
||||
execRes, err = d.doExec(ctx, dbConn, sql)
|
||||
}
|
||||
}
|
||||
|
||||
// 执行错误
|
||||
if err != nil {
|
||||
if execRes == nil {
|
||||
execRes = &dto.DbSqlExecRes{Sql: sqlExec.Sql}
|
||||
execRes = &dto.DbSqlExecRes{Sql: sql}
|
||||
}
|
||||
execRes.ErrorMsg = err.Error()
|
||||
} else {
|
||||
d.saveSqlExecLog(ctx, dbSqlExecRecord, execRes.Res)
|
||||
d.saveSqlExecLog(ctx, dbSqlExecRecord, dbSqlExecRecord.Res)
|
||||
}
|
||||
allExecRes = append(allExecRes, execRes)
|
||||
}
|
||||
@@ -249,7 +230,9 @@ func (d *dbSqlExecAppImpl) ExecReader(ctx context.Context, execReader *dto.SqlRe
|
||||
}()
|
||||
|
||||
tx, _ := dbConn.Begin()
|
||||
err := sqlparser.SQLSplit(execReader.Reader, ';', func(sql string) error {
|
||||
// 使用方言切割器进行 SQL 切割
|
||||
splitter := dbConn.GetDialect().GetSQLSplitter()
|
||||
err := splitter.SplitSQL(execReader.Reader, func(sql string) error {
|
||||
if executedStatements%50 == 0 {
|
||||
if needSendMsg {
|
||||
progressMsgEvent.Params["executedStatements"] = executedStatements
|
||||
@@ -416,21 +399,14 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
tableSources := updatestmt.TableSources.TableSources
|
||||
// 不支持多表更新记录旧值
|
||||
if len(tableSources) != 1 {
|
||||
if len(updatestmt.Tables) != 1 {
|
||||
logx.ErrorContext(ctx, "update SQL - logging old values only supports single-table updates")
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
tableName := ""
|
||||
tableAlias := ""
|
||||
if tableSourceBase, ok := tableSources[0].(*sqlstmt.TableSourceBase); ok {
|
||||
if atmoTableItem, ok := tableSourceBase.TableSourceItem.(*sqlstmt.AtomTableItem); ok {
|
||||
tableName = atmoTableItem.TableName.Identifier.Value
|
||||
tableAlias = atmoTableItem.Alias
|
||||
}
|
||||
}
|
||||
tableName := updatestmt.Tables[0].Name
|
||||
tableAlias := updatestmt.Tables[0].Alias
|
||||
|
||||
if tableName == "" {
|
||||
logx.ErrorContext(ctx, "update SQL - failed to get table name")
|
||||
@@ -442,7 +418,7 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
logx.ErrorContext(ctx, "update SQL - there is no where condition")
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
whereStr := updatestmt.Where.GetText()
|
||||
whereStr := updatestmt.Where.Text
|
||||
|
||||
// 获取表主键列名,排除使用别名
|
||||
primaryKey, err := dbConn.GetMetadata().GetPrimaryKey(tableName)
|
||||
@@ -451,8 +427,8 @@ func (d *dbSqlExecAppImpl) doUpdate(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
updateColumns := collx.ArrayMap[*sqlstmt.UpdatedElement, string](updatestmt.UpdatedElements, func(ue *sqlstmt.UpdatedElement) string {
|
||||
return ue.ColumnName.GetText()
|
||||
updateColumns := collx.ArrayMap[sqlstmt.Assignment, string](updatestmt.Set, func(a sqlstmt.Assignment) string {
|
||||
return a.Column
|
||||
})
|
||||
|
||||
primaryKeyColumn := primaryKey
|
||||
@@ -505,21 +481,14 @@ func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
tableSources := deletestmt.TableSources.TableSources
|
||||
// 不支持多表删除记录旧值
|
||||
if len(tableSources) != 1 {
|
||||
if len(deletestmt.Tables) != 1 {
|
||||
logx.ErrorContext(ctx, "delete SQL - logging old values only supports single-table deletion")
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
tableName := ""
|
||||
tableAlias := ""
|
||||
if tableSourceBase, ok := tableSources[0].(*sqlstmt.TableSourceBase); ok {
|
||||
if atmoTableItem, ok := tableSourceBase.TableSourceItem.(*sqlstmt.AtomTableItem); ok {
|
||||
tableName = atmoTableItem.TableName.Identifier.Value
|
||||
tableAlias = atmoTableItem.Alias
|
||||
}
|
||||
}
|
||||
tableName := deletestmt.Tables[0].Name
|
||||
tableAlias := deletestmt.Tables[0].Alias
|
||||
|
||||
if tableName == "" {
|
||||
logx.ErrorContext(ctx, "delete SQL - failed to get table name")
|
||||
@@ -527,13 +496,12 @@ func (d *dbSqlExecAppImpl) doDelete(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
}
|
||||
execRecord.Table = tableName
|
||||
|
||||
deleteWhere := deletestmt.Where
|
||||
if deleteWhere == nil {
|
||||
if deletestmt.Where == nil {
|
||||
logx.ErrorContext(ctx, "delete SQL - there is no where condition")
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
whereStr := deleteWhere.GetText()
|
||||
whereStr := deletestmt.Where.Text
|
||||
// 查询删除数据
|
||||
selectSql := fmt.Sprintf("SELECT * FROM %s where %s LIMIT 200", tableName+" "+tableAlias, whereStr)
|
||||
_, res, _ := dbConn.QueryContext(ctx, selectSql)
|
||||
@@ -563,7 +531,7 @@ func (d *dbSqlExecAppImpl) doInsert(ctx context.Context, sqlExecParam *sqlExecPa
|
||||
return d.doExec(ctx, dbConn, sqlExecParam.Sql)
|
||||
}
|
||||
|
||||
execRecord.Table = insertstmt.TableName.Identifier.Value
|
||||
execRecord.Table = insertstmt.Table.Name
|
||||
|
||||
return d.doExec(ctx, sqlExecParam.DbConn, sqlExecParam.Sql)
|
||||
}
|
||||
@@ -603,7 +571,7 @@ func (d *dbSqlExecAppImpl) doExec(ctx context.Context, dbConn *dbi.DbConn, sql s
|
||||
|
||||
return &dto.DbSqlExecRes{
|
||||
Columns: []*dbi.QueryColumn{
|
||||
{Name: "rowsAffected", Key:"rowsAffected", Type: "number"},
|
||||
{Name: "rowsAffected", Key: "rowsAffected", Type: "number"},
|
||||
},
|
||||
Res: res,
|
||||
Sql: sql,
|
||||
|
||||
Reference in New Issue
Block a user