mirror of
https://gitee.com/dromara/mayfly-go
synced 2026-01-01 04:06:37 +08:00
feat: 数据库sql执行支持取消执行操作
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -24,6 +25,7 @@ import (
|
||||
"mayfly-go/pkg/ws"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -80,6 +82,11 @@ func (d *Db) DeleteDb(rc *req.Ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
/** 数据库操作相关、执行sql等 ***/
|
||||
|
||||
// 取消执行sql函数map; key -> execId ; value -> cancelFunc
|
||||
var cancelExecSqlMap = sync.Map{}
|
||||
|
||||
func (d *Db) ExecSql(rc *req.Ctx) {
|
||||
g := rc.GinCtx
|
||||
form := &form.DbSqlExecForm{}
|
||||
@@ -95,15 +102,23 @@ func (d *Db) ExecSql(rc *req.Ctx) {
|
||||
// 去除前后空格及换行符
|
||||
sql := stringx.TrimSpaceAndBr(string(sqlBytes))
|
||||
|
||||
rc.ReqParam = fmt.Sprintf("%s\n-> %s", dbConn.Info.GetLogDesc(), sql)
|
||||
rc.ReqParam = fmt.Sprintf("%s %s\n-> %s", dbConn.Info.GetLogDesc(), form.ExecId, sql)
|
||||
biz.NotEmpty(form.Sql, "sql不能为空")
|
||||
|
||||
execReq := &application.DbSqlExecReq{
|
||||
DbId: dbId,
|
||||
Db: form.Db,
|
||||
Remark: form.Remark,
|
||||
DbConn: dbConn,
|
||||
LoginAccount: rc.GetLoginAccount(),
|
||||
DbId: dbId,
|
||||
Db: form.Db,
|
||||
Remark: form.Remark,
|
||||
DbConn: dbConn,
|
||||
}
|
||||
|
||||
ctx := rc.MetaCtx
|
||||
// 如果存在执行id,则保存取消函数,用于后续可能的取消操作
|
||||
if form.ExecId != "" {
|
||||
cancelCtx, cancel := context.WithCancel(rc.MetaCtx)
|
||||
ctx = cancelCtx
|
||||
cancelExecSqlMap.Store(form.ExecId, cancel)
|
||||
defer cancelExecSqlMap.Delete(form.ExecId)
|
||||
}
|
||||
|
||||
sqls, err := sqlparser.SplitStatementToPieces(sql, sqlparser.WithDialect(dbConn.Info.Type.Dialect()))
|
||||
@@ -119,7 +134,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
execReq.Sql = s
|
||||
execRes, err := d.DbSqlExecApp.Exec(execReq)
|
||||
execRes, err := d.DbSqlExecApp.Exec(ctx, execReq)
|
||||
biz.ErrIsNilAppendErr(err, fmt.Sprintf("[%s] -> 执行失败: ", s)+"%s")
|
||||
|
||||
if execResAll == nil {
|
||||
@@ -135,6 +150,14 @@ func (d *Db) ExecSql(rc *req.Ctx) {
|
||||
rc.ResData = colAndRes
|
||||
}
|
||||
|
||||
func (d *Db) CancelExecSql(rc *req.Ctx) {
|
||||
execId := ginx.PathParam(rc.GinCtx, "execId")
|
||||
if cancelFunc, ok := cancelExecSqlMap.LoadAndDelete(execId); ok {
|
||||
rc.ReqParam = execId
|
||||
cancelFunc.(context.CancelFunc)()
|
||||
}
|
||||
}
|
||||
|
||||
// progressCategory sql文件执行进度消息类型
|
||||
const progressCategory = "execSqlFileProgress"
|
||||
|
||||
@@ -175,11 +198,10 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
}()
|
||||
|
||||
execReq := &application.DbSqlExecReq{
|
||||
DbId: dbId,
|
||||
Db: dbName,
|
||||
Remark: filename,
|
||||
DbConn: dbConn,
|
||||
LoginAccount: rc.GetLoginAccount(),
|
||||
DbId: dbId,
|
||||
Db: dbName,
|
||||
Remark: filename,
|
||||
DbConn: dbConn,
|
||||
}
|
||||
|
||||
var sql string
|
||||
@@ -237,7 +259,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
const maxRecordStatements = 64
|
||||
if executedStatements < maxRecordStatements {
|
||||
execReq.Sql = sql
|
||||
_, err = d.DbSqlExecApp.Exec(execReq)
|
||||
_, err = d.DbSqlExecApp.Exec(rc.MetaCtx, execReq)
|
||||
} else {
|
||||
_, err = dbConn.Exec(sql)
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ type DbSqlSaveForm struct {
|
||||
|
||||
// 数据库SQL执行表单
|
||||
type DbSqlExecForm struct {
|
||||
ExecId string `json:"execId"` // 执行id(用于取消执行使用)
|
||||
Db string `binding:"required" json:"db"` //数据库名
|
||||
Sql string `binding:"required" json:"sql"` // 执行sql
|
||||
Remark string `json:"remark"` // 执行备注
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
"mayfly-go/pkg/contextx"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/model"
|
||||
"strconv"
|
||||
@@ -17,12 +18,11 @@ import (
|
||||
)
|
||||
|
||||
type DbSqlExecReq struct {
|
||||
DbId uint64
|
||||
Db string
|
||||
Sql string
|
||||
Remark string
|
||||
LoginAccount *model.LoginAccount
|
||||
DbConn *dbm.DbConn
|
||||
DbId uint64
|
||||
Db string
|
||||
Sql string
|
||||
Remark string
|
||||
DbConn *dbm.DbConn
|
||||
}
|
||||
|
||||
type DbSqlExecRes struct {
|
||||
@@ -47,7 +47,7 @@ func (d *DbSqlExecRes) Merge(execRes *DbSqlExecRes) {
|
||||
|
||||
type DbSqlExec interface {
|
||||
// 执行sql
|
||||
Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
|
||||
Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
|
||||
|
||||
// 根据条件删除sql执行记录
|
||||
DeleteBy(ctx context.Context, condition *entity.DbSqlExec)
|
||||
@@ -66,19 +66,19 @@ type dbSqlExecAppImpl struct {
|
||||
dbSqlExecRepo repository.DbSqlExec
|
||||
}
|
||||
|
||||
func createSqlExecRecord(execSqlReq *DbSqlExecReq) *entity.DbSqlExec {
|
||||
func createSqlExecRecord(ctx context.Context, execSqlReq *DbSqlExecReq) *entity.DbSqlExec {
|
||||
dbSqlExecRecord := new(entity.DbSqlExec)
|
||||
dbSqlExecRecord.DbId = execSqlReq.DbId
|
||||
dbSqlExecRecord.Db = execSqlReq.Db
|
||||
dbSqlExecRecord.Sql = execSqlReq.Sql
|
||||
dbSqlExecRecord.Remark = execSqlReq.Remark
|
||||
dbSqlExecRecord.SetBaseInfo(execSqlReq.LoginAccount)
|
||||
dbSqlExecRecord.SetBaseInfo(contextx.GetLoginAccount(ctx))
|
||||
return dbSqlExecRecord
|
||||
}
|
||||
|
||||
func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
|
||||
func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
|
||||
sql := execSqlReq.Sql
|
||||
dbSqlExecRecord := createSqlExecRecord(execSqlReq)
|
||||
dbSqlExecRecord := createSqlExecRecord(ctx, execSqlReq)
|
||||
dbSqlExecRecord.Type = entity.DbSqlExecTypeOther
|
||||
var execRes *DbSqlExecRes
|
||||
isSelect := false
|
||||
@@ -100,9 +100,9 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
|
||||
}
|
||||
var execErr error
|
||||
if isSelect || strings.HasPrefix(lowerSql, "show") {
|
||||
execRes, execErr = doRead(execSqlReq)
|
||||
execRes, execErr = doRead(ctx, execSqlReq)
|
||||
} else {
|
||||
execRes, execErr = doExec(execSqlReq.Sql, execSqlReq.DbConn)
|
||||
execRes, execErr = doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
|
||||
}
|
||||
if execErr != nil {
|
||||
return nil, execErr
|
||||
@@ -114,21 +114,21 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
|
||||
switch stmt := stmt.(type) {
|
||||
case *sqlparser.Select:
|
||||
isSelect = true
|
||||
execRes, err = doSelect(stmt, execSqlReq)
|
||||
execRes, err = doSelect(ctx, stmt, execSqlReq)
|
||||
case *sqlparser.Show:
|
||||
isSelect = true
|
||||
execRes, err = doRead(execSqlReq)
|
||||
execRes, err = doRead(ctx, execSqlReq)
|
||||
case *sqlparser.OtherRead:
|
||||
isSelect = true
|
||||
execRes, err = doRead(execSqlReq)
|
||||
execRes, err = doRead(ctx, execSqlReq)
|
||||
case *sqlparser.Update:
|
||||
execRes, err = doUpdate(stmt, execSqlReq, dbSqlExecRecord)
|
||||
execRes, err = doUpdate(ctx, stmt, execSqlReq, dbSqlExecRecord)
|
||||
case *sqlparser.Delete:
|
||||
execRes, err = doDelete(stmt, execSqlReq, dbSqlExecRecord)
|
||||
execRes, err = doDelete(ctx, stmt, execSqlReq, dbSqlExecRecord)
|
||||
case *sqlparser.Insert:
|
||||
execRes, err = doInsert(stmt, execSqlReq, dbSqlExecRecord)
|
||||
execRes, err = doInsert(ctx, stmt, execSqlReq, dbSqlExecRecord)
|
||||
default:
|
||||
execRes, err = doExec(execSqlReq.Sql, execSqlReq.DbConn)
|
||||
execRes, err = doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -159,7 +159,7 @@ func (d *dbSqlExecAppImpl) GetPageList(condition *entity.DbSqlExecQuery, pagePar
|
||||
return d.dbSqlExecRepo.GetPageList(condition, pageParam, toEntity, orderBy...)
|
||||
}
|
||||
|
||||
func doSelect(selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
|
||||
func doSelect(ctx context.Context, selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
|
||||
selectExprsStr := sqlparser.String(selectStmt.SelectExprs)
|
||||
if selectExprsStr == "*" || strings.Contains(selectExprsStr, ".*") ||
|
||||
len(strings.Split(selectExprsStr, ",")) > 1 {
|
||||
@@ -182,13 +182,13 @@ func doSelect(selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExe
|
||||
}
|
||||
}
|
||||
|
||||
return doRead(execSqlReq)
|
||||
return doRead(ctx, execSqlReq)
|
||||
}
|
||||
|
||||
func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
|
||||
func doRead(ctx context.Context, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
|
||||
dbConn := execSqlReq.DbConn
|
||||
sql := execSqlReq.Sql
|
||||
colNames, res, err := dbConn.Query(sql)
|
||||
colNames, res, err := dbConn.QueryContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -198,7 +198,7 @@ func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
|
||||
func doUpdate(ctx context.Context, update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
|
||||
dbConn := execSqlReq.DbConn
|
||||
|
||||
tableStr := sqlparser.String(update.TableExprs)
|
||||
@@ -224,7 +224,7 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent
|
||||
updateColumnsAndPrimaryKey := strings.Join(updateColumns, ",") + "," + primaryKey
|
||||
// 查询要更新字段数据的旧值,以及主键值
|
||||
selectSql := fmt.Sprintf("SELECT %s FROM %s %s LIMIT 200", updateColumnsAndPrimaryKey, tableStr, where)
|
||||
_, res, err := dbConn.Query(selectSql)
|
||||
_, res, err := dbConn.QueryContext(ctx, selectSql)
|
||||
if err == nil {
|
||||
bytes, _ := json.Marshal(res)
|
||||
dbSqlExec.OldValue = string(bytes)
|
||||
@@ -235,10 +235,10 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent
|
||||
dbSqlExec.Table = tableName
|
||||
dbSqlExec.Type = entity.DbSqlExecTypeUpdate
|
||||
|
||||
return doExec(execSqlReq.Sql, dbConn)
|
||||
return doExec(ctx, execSqlReq.Sql, dbConn)
|
||||
}
|
||||
|
||||
func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
|
||||
func doDelete(ctx context.Context, delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
|
||||
dbConn := execSqlReq.DbConn
|
||||
|
||||
tableStr := sqlparser.String(delete.TableExprs)
|
||||
@@ -251,28 +251,28 @@ func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *ent
|
||||
|
||||
// 查询删除数据
|
||||
selectSql := fmt.Sprintf("SELECT * FROM %s %s LIMIT 200", tableStr, where)
|
||||
_, res, _ := dbConn.Query(selectSql)
|
||||
_, res, _ := dbConn.QueryContext(ctx, selectSql)
|
||||
|
||||
bytes, _ := json.Marshal(res)
|
||||
dbSqlExec.OldValue = string(bytes)
|
||||
dbSqlExec.Table = table
|
||||
dbSqlExec.Type = entity.DbSqlExecTypeDelete
|
||||
|
||||
return doExec(execSqlReq.Sql, dbConn)
|
||||
return doExec(ctx, execSqlReq.Sql, dbConn)
|
||||
}
|
||||
|
||||
func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
|
||||
func doInsert(ctx context.Context, insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
|
||||
tableStr := sqlparser.String(insert.Table)
|
||||
// 可能使用别名,故空格切割
|
||||
table := strings.Split(tableStr, " ")[0]
|
||||
dbSqlExec.Table = table
|
||||
dbSqlExec.Type = entity.DbSqlExecTypeInsert
|
||||
|
||||
return doExec(execSqlReq.Sql, execSqlReq.DbConn)
|
||||
return doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
|
||||
}
|
||||
|
||||
func doExec(sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
|
||||
rowsAffected, err := dbConn.Exec(sql)
|
||||
func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
|
||||
rowsAffected, err := dbConn.ExecContext(ctx, sql)
|
||||
execRes := "success"
|
||||
if err != nil {
|
||||
execRes = err.Error()
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"reflect"
|
||||
"strconv"
|
||||
@@ -20,11 +22,20 @@ type DbConn struct {
|
||||
// 执行查询语句
|
||||
// 依次返回 列名数组(顺序),结果map,错误
|
||||
func (d *DbConn) Query(querySql string) ([]string, []map[string]any, error) {
|
||||
return d.QueryContext(context.Background(), querySql)
|
||||
}
|
||||
|
||||
// 执行查询语句
|
||||
// 依次返回 列名数组(顺序),结果map,错误
|
||||
func (d *DbConn) QueryContext(ctx context.Context, querySql string) ([]string, []map[string]any, error) {
|
||||
result := make([]map[string]any, 0, 16)
|
||||
columns, err := walkTableRecord(d.db, querySql, func(record map[string]any, columns []string) {
|
||||
columns, err := walkTableRecord(ctx, d.db, querySql, func(record map[string]any, columns []string) {
|
||||
result = append(result, record)
|
||||
})
|
||||
if err != nil {
|
||||
if err == context.Canceled {
|
||||
return nil, nil, errorx.NewBiz("取消执行")
|
||||
}
|
||||
return nil, nil, err
|
||||
}
|
||||
return columns, result, nil
|
||||
@@ -47,16 +58,25 @@ func (d *DbConn) Query2Struct(execSql string, dest any) error {
|
||||
}
|
||||
|
||||
// WalkTableRecord 遍历表记录
|
||||
func (d *DbConn) WalkTableRecord(selectSql string, walk func(record map[string]any, columns []string)) error {
|
||||
_, err := walkTableRecord(d.db, selectSql, walk)
|
||||
func (d *DbConn) WalkTableRecord(ctx context.Context, selectSql string, walk func(record map[string]any, columns []string)) error {
|
||||
_, err := walkTableRecord(ctx, d.db, selectSql, walk)
|
||||
return err
|
||||
}
|
||||
|
||||
// 执行 update, insert, delete,建表等sql
|
||||
// 返回影响条数和错误
|
||||
func (d *DbConn) Exec(sql string) (int64, error) {
|
||||
res, err := d.db.Exec(sql)
|
||||
return d.ExecContext(context.Background(), sql)
|
||||
}
|
||||
|
||||
// 执行 update, insert, delete,建表等sql
|
||||
// 返回影响条数和错误
|
||||
func (d *DbConn) ExecContext(ctx context.Context, sql string) (int64, error) {
|
||||
res, err := d.db.ExecContext(ctx, sql)
|
||||
if err != nil {
|
||||
if err == context.Canceled {
|
||||
return 0, errorx.NewBiz("取消执行")
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
@@ -84,8 +104,9 @@ func (d *DbConn) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func walkTableRecord(db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) ([]string, error) {
|
||||
rows, err := db.Query(selectSql)
|
||||
func walkTableRecord(ctx context.Context, db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) ([]string, error) {
|
||||
rows, err := db.QueryContext(ctx, selectSql)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -179,5 +179,5 @@ func (md *MysqlDialect) GetTableRecord(tableName string, pageNum, pageSize int)
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) WalkTableRecord(tableName string, walk func(record map[string]any, columns []string)) error {
|
||||
return md.dc.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
|
||||
return md.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
@@ -249,7 +250,7 @@ func (pd *PgsqlDialect) GetTableRecord(tableName string, pageNum, pageSize int)
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) WalkTableRecord(tableName string, walk func(record map[string]any, columns []string)) error {
|
||||
return pd.dc.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
|
||||
return pd.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
|
||||
}
|
||||
|
||||
// 获取pgsql当前连接的库可访问的schemaNames
|
||||
|
||||
@@ -35,6 +35,8 @@ func InitDbRouter(router *gin.RouterGroup) {
|
||||
|
||||
req.NewPost(":dbId/exec-sql", d.ExecSql).Log(req.NewLog("db-执行Sql")),
|
||||
|
||||
req.NewPost(":dbId/exec-sql/cancel/:execId", d.CancelExecSql).Log(req.NewLog("db-取消执行Sql")),
|
||||
|
||||
req.NewPost(":dbId/exec-sql-file", d.ExecSqlFile).Log(req.NewLogSave("db-执行Sql文件")),
|
||||
|
||||
req.NewGet(":dbId/dump", d.DumpSql).Log(req.NewLogSave("db-导出sql文件")).NoRes(),
|
||||
|
||||
Reference in New Issue
Block a user