mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 16:30:25 +08:00 
			
		
		
		
	feat: 数据库sql执行支持取消执行操作
This commit is contained in:
		@@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/domain/repository"
 | 
			
		||||
	"mayfly-go/pkg/contextx"
 | 
			
		||||
	"mayfly-go/pkg/errorx"
 | 
			
		||||
	"mayfly-go/pkg/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -17,12 +18,11 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DbSqlExecReq struct {
 | 
			
		||||
	DbId         uint64
 | 
			
		||||
	Db           string
 | 
			
		||||
	Sql          string
 | 
			
		||||
	Remark       string
 | 
			
		||||
	LoginAccount *model.LoginAccount
 | 
			
		||||
	DbConn       *dbm.DbConn
 | 
			
		||||
	DbId   uint64
 | 
			
		||||
	Db     string
 | 
			
		||||
	Sql    string
 | 
			
		||||
	Remark string
 | 
			
		||||
	DbConn *dbm.DbConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DbSqlExecRes struct {
 | 
			
		||||
@@ -47,7 +47,7 @@ func (d *DbSqlExecRes) Merge(execRes *DbSqlExecRes) {
 | 
			
		||||
 | 
			
		||||
type DbSqlExec interface {
 | 
			
		||||
	// 执行sql
 | 
			
		||||
	Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
 | 
			
		||||
	Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
 | 
			
		||||
 | 
			
		||||
	// 根据条件删除sql执行记录
 | 
			
		||||
	DeleteBy(ctx context.Context, condition *entity.DbSqlExec)
 | 
			
		||||
@@ -66,19 +66,19 @@ type dbSqlExecAppImpl struct {
 | 
			
		||||
	dbSqlExecRepo repository.DbSqlExec
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func createSqlExecRecord(execSqlReq *DbSqlExecReq) *entity.DbSqlExec {
 | 
			
		||||
func createSqlExecRecord(ctx context.Context, execSqlReq *DbSqlExecReq) *entity.DbSqlExec {
 | 
			
		||||
	dbSqlExecRecord := new(entity.DbSqlExec)
 | 
			
		||||
	dbSqlExecRecord.DbId = execSqlReq.DbId
 | 
			
		||||
	dbSqlExecRecord.Db = execSqlReq.Db
 | 
			
		||||
	dbSqlExecRecord.Sql = execSqlReq.Sql
 | 
			
		||||
	dbSqlExecRecord.Remark = execSqlReq.Remark
 | 
			
		||||
	dbSqlExecRecord.SetBaseInfo(execSqlReq.LoginAccount)
 | 
			
		||||
	dbSqlExecRecord.SetBaseInfo(contextx.GetLoginAccount(ctx))
 | 
			
		||||
	return dbSqlExecRecord
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
	sql := execSqlReq.Sql
 | 
			
		||||
	dbSqlExecRecord := createSqlExecRecord(execSqlReq)
 | 
			
		||||
	dbSqlExecRecord := createSqlExecRecord(ctx, execSqlReq)
 | 
			
		||||
	dbSqlExecRecord.Type = entity.DbSqlExecTypeOther
 | 
			
		||||
	var execRes *DbSqlExecRes
 | 
			
		||||
	isSelect := false
 | 
			
		||||
@@ -100,9 +100,9 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
 | 
			
		||||
		}
 | 
			
		||||
		var execErr error
 | 
			
		||||
		if isSelect || strings.HasPrefix(lowerSql, "show") {
 | 
			
		||||
			execRes, execErr = doRead(execSqlReq)
 | 
			
		||||
			execRes, execErr = doRead(ctx, execSqlReq)
 | 
			
		||||
		} else {
 | 
			
		||||
			execRes, execErr = doExec(execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
			execRes, execErr = doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
		}
 | 
			
		||||
		if execErr != nil {
 | 
			
		||||
			return nil, execErr
 | 
			
		||||
@@ -114,21 +114,21 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
 | 
			
		||||
	switch stmt := stmt.(type) {
 | 
			
		||||
	case *sqlparser.Select:
 | 
			
		||||
		isSelect = true
 | 
			
		||||
		execRes, err = doSelect(stmt, execSqlReq)
 | 
			
		||||
		execRes, err = doSelect(ctx, stmt, execSqlReq)
 | 
			
		||||
	case *sqlparser.Show:
 | 
			
		||||
		isSelect = true
 | 
			
		||||
		execRes, err = doRead(execSqlReq)
 | 
			
		||||
		execRes, err = doRead(ctx, execSqlReq)
 | 
			
		||||
	case *sqlparser.OtherRead:
 | 
			
		||||
		isSelect = true
 | 
			
		||||
		execRes, err = doRead(execSqlReq)
 | 
			
		||||
		execRes, err = doRead(ctx, execSqlReq)
 | 
			
		||||
	case *sqlparser.Update:
 | 
			
		||||
		execRes, err = doUpdate(stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
		execRes, err = doUpdate(ctx, stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
	case *sqlparser.Delete:
 | 
			
		||||
		execRes, err = doDelete(stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
		execRes, err = doDelete(ctx, stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
	case *sqlparser.Insert:
 | 
			
		||||
		execRes, err = doInsert(stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
		execRes, err = doInsert(ctx, stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
	default:
 | 
			
		||||
		execRes, err = doExec(execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
		execRes, err = doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -159,7 +159,7 @@ func (d *dbSqlExecAppImpl) GetPageList(condition *entity.DbSqlExecQuery, pagePar
 | 
			
		||||
	return d.dbSqlExecRepo.GetPageList(condition, pageParam, toEntity, orderBy...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doSelect(selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
func doSelect(ctx context.Context, selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
	selectExprsStr := sqlparser.String(selectStmt.SelectExprs)
 | 
			
		||||
	if selectExprsStr == "*" || strings.Contains(selectExprsStr, ".*") ||
 | 
			
		||||
		len(strings.Split(selectExprsStr, ",")) > 1 {
 | 
			
		||||
@@ -182,13 +182,13 @@ func doSelect(selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExe
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return doRead(execSqlReq)
 | 
			
		||||
	return doRead(ctx, execSqlReq)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
func doRead(ctx context.Context, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
	dbConn := execSqlReq.DbConn
 | 
			
		||||
	sql := execSqlReq.Sql
 | 
			
		||||
	colNames, res, err := dbConn.Query(sql)
 | 
			
		||||
	colNames, res, err := dbConn.QueryContext(ctx, sql)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -198,7 +198,7 @@ func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
func doUpdate(ctx context.Context, update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
	dbConn := execSqlReq.DbConn
 | 
			
		||||
 | 
			
		||||
	tableStr := sqlparser.String(update.TableExprs)
 | 
			
		||||
@@ -224,7 +224,7 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent
 | 
			
		||||
	updateColumnsAndPrimaryKey := strings.Join(updateColumns, ",") + "," + primaryKey
 | 
			
		||||
	// 查询要更新字段数据的旧值,以及主键值
 | 
			
		||||
	selectSql := fmt.Sprintf("SELECT %s FROM %s %s LIMIT 200", updateColumnsAndPrimaryKey, tableStr, where)
 | 
			
		||||
	_, res, err := dbConn.Query(selectSql)
 | 
			
		||||
	_, res, err := dbConn.QueryContext(ctx, selectSql)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		bytes, _ := json.Marshal(res)
 | 
			
		||||
		dbSqlExec.OldValue = string(bytes)
 | 
			
		||||
@@ -235,10 +235,10 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent
 | 
			
		||||
	dbSqlExec.Table = tableName
 | 
			
		||||
	dbSqlExec.Type = entity.DbSqlExecTypeUpdate
 | 
			
		||||
 | 
			
		||||
	return doExec(execSqlReq.Sql, dbConn)
 | 
			
		||||
	return doExec(ctx, execSqlReq.Sql, dbConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
func doDelete(ctx context.Context, delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
	dbConn := execSqlReq.DbConn
 | 
			
		||||
 | 
			
		||||
	tableStr := sqlparser.String(delete.TableExprs)
 | 
			
		||||
@@ -251,28 +251,28 @@ func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *ent
 | 
			
		||||
 | 
			
		||||
	// 查询删除数据
 | 
			
		||||
	selectSql := fmt.Sprintf("SELECT * FROM %s %s LIMIT 200", tableStr, where)
 | 
			
		||||
	_, res, _ := dbConn.Query(selectSql)
 | 
			
		||||
	_, res, _ := dbConn.QueryContext(ctx, selectSql)
 | 
			
		||||
 | 
			
		||||
	bytes, _ := json.Marshal(res)
 | 
			
		||||
	dbSqlExec.OldValue = string(bytes)
 | 
			
		||||
	dbSqlExec.Table = table
 | 
			
		||||
	dbSqlExec.Type = entity.DbSqlExecTypeDelete
 | 
			
		||||
 | 
			
		||||
	return doExec(execSqlReq.Sql, dbConn)
 | 
			
		||||
	return doExec(ctx, execSqlReq.Sql, dbConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
func doInsert(ctx context.Context, insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
	tableStr := sqlparser.String(insert.Table)
 | 
			
		||||
	// 可能使用别名,故空格切割
 | 
			
		||||
	table := strings.Split(tableStr, " ")[0]
 | 
			
		||||
	dbSqlExec.Table = table
 | 
			
		||||
	dbSqlExec.Type = entity.DbSqlExecTypeInsert
 | 
			
		||||
 | 
			
		||||
	return doExec(execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
	return doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doExec(sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
 | 
			
		||||
	rowsAffected, err := dbConn.Exec(sql)
 | 
			
		||||
func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
 | 
			
		||||
	rowsAffected, err := dbConn.ExecContext(ctx, sql)
 | 
			
		||||
	execRes := "success"
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		execRes = err.Error()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user