mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 16:30:25 +08:00 
			
		
		
		
	feat: 数据库sql执行支持取消执行操作
This commit is contained in:
		@@ -1,4 +1,5 @@
 | 
			
		||||
import request from './request';
 | 
			
		||||
import { randomUuid } from './utils/string';
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * 可用于各模块定义各自api请求
 | 
			
		||||
@@ -20,6 +21,8 @@ class Api {
 | 
			
		||||
     */
 | 
			
		||||
    beforeHandler: Function;
 | 
			
		||||
 | 
			
		||||
    static abortControllers: Map<string, AbortController> = new Map();
 | 
			
		||||
 | 
			
		||||
    constructor(url: string, method: string) {
 | 
			
		||||
        this.url = url;
 | 
			
		||||
        this.method = method;
 | 
			
		||||
@@ -53,8 +56,63 @@ class Api {
 | 
			
		||||
        return request.request(this.method, this.url, param, headers, options);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 请求对应的该api
 | 
			
		||||
     * @param {Object} param 请求该api的参数
 | 
			
		||||
     */
 | 
			
		||||
    requestCanCancel(key: string, param: any = null, options: any = null, headers: any = null): Promise<any> {
 | 
			
		||||
        let controller = Api.abortControllers.get(key);
 | 
			
		||||
        if (!controller) {
 | 
			
		||||
            controller = new AbortController();
 | 
			
		||||
            Api.abortControllers.set(key, controller);
 | 
			
		||||
        }
 | 
			
		||||
        if (options) {
 | 
			
		||||
            options.signal = controller.signal;
 | 
			
		||||
        } else {
 | 
			
		||||
            options = {
 | 
			
		||||
                signal: controller.signal,
 | 
			
		||||
            };
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return this.request(param, options, headers);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**    静态方法     **/
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 取消请求
 | 
			
		||||
     * @param key 请求key
 | 
			
		||||
     */
 | 
			
		||||
    static cancelReq(key: string) {
 | 
			
		||||
        let controller = Api.abortControllers.get(key);
 | 
			
		||||
        if (controller) {
 | 
			
		||||
            controller.abort();
 | 
			
		||||
            Api.removeAbortKey(key);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    static removeAbortKey(key: string) {
 | 
			
		||||
        if (key) {
 | 
			
		||||
            console.log('remove abort key: ', key);
 | 
			
		||||
            Api.abortControllers.delete(key);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 根据旧key生成新的abort key,可能旧key未取消,造成多余无用对象
 | 
			
		||||
     * @param oldKey 旧key
 | 
			
		||||
     * @returns key
 | 
			
		||||
     */
 | 
			
		||||
    static genAbortKey(oldKey: string) {
 | 
			
		||||
        if (!oldKey) {
 | 
			
		||||
            return randomUuid();
 | 
			
		||||
        }
 | 
			
		||||
        if (Api.abortControllers.get(oldKey)) {
 | 
			
		||||
            return oldKey;
 | 
			
		||||
        }
 | 
			
		||||
        return randomUuid();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * 静态工厂,返回Api对象,并设置url与method属性
 | 
			
		||||
     * @param url url
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import config from './config';
 | 
			
		||||
import { getClientId, getToken } from './utils/storage';
 | 
			
		||||
import { templateResolve } from './utils/string';
 | 
			
		||||
import { ElMessage } from 'element-plus';
 | 
			
		||||
import axios from 'axios';
 | 
			
		||||
 | 
			
		||||
export interface Result {
 | 
			
		||||
    /**
 | 
			
		||||
@@ -67,21 +68,26 @@ service.interceptors.request.use(
 | 
			
		||||
service.interceptors.response.use(
 | 
			
		||||
    (response) => {
 | 
			
		||||
        // 获取请求返回结果
 | 
			
		||||
        const data: Result = response.data;
 | 
			
		||||
        if (data.code === ResultEnum.SUCCESS) {
 | 
			
		||||
            return data.data;
 | 
			
		||||
        const res: Result = response.data;
 | 
			
		||||
        if (res.code === ResultEnum.SUCCESS) {
 | 
			
		||||
            return res.data;
 | 
			
		||||
        }
 | 
			
		||||
        // 如果提示没有权限,则移除token,使其重新登录
 | 
			
		||||
        if (data.code === ResultEnum.NO_PERMISSION) {
 | 
			
		||||
        if (res.code === ResultEnum.NO_PERMISSION) {
 | 
			
		||||
            router.push({
 | 
			
		||||
                path: '/401',
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
        return Promise.reject(data);
 | 
			
		||||
        return Promise.reject(res);
 | 
			
		||||
    },
 | 
			
		||||
    (e: any) => {
 | 
			
		||||
        const rejectPromise = Promise.reject(e);
 | 
			
		||||
 | 
			
		||||
        if (axios.isCancel(e)) {
 | 
			
		||||
            console.log('请求已取消');
 | 
			
		||||
            return rejectPromise;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        const statusCode = e.response?.status;
 | 
			
		||||
        if (statusCode == 500) {
 | 
			
		||||
            notifyErrorMsg('服务器未知异常');
 | 
			
		||||
 
 | 
			
		||||
@@ -21,6 +21,7 @@ export const dbApi = {
 | 
			
		||||
            param.sql = Base64.encode(param.sql);
 | 
			
		||||
        }
 | 
			
		||||
    }),
 | 
			
		||||
    sqlExecCancel: Api.newPost('/dbs/{id}/exec-sql/cancel/{execId}'),
 | 
			
		||||
    // 保存sql
 | 
			
		||||
    saveSql: Api.newPost('/dbs/{id}/sql'),
 | 
			
		||||
    // 获取保存的sql
 | 
			
		||||
 
 | 
			
		||||
@@ -103,6 +103,7 @@
 | 
			
		||||
                        :table="dt.table"
 | 
			
		||||
                        :columns="dt.tableColumn"
 | 
			
		||||
                        :loading="dt.loading"
 | 
			
		||||
                        :loading-key="dt.loadingKey"
 | 
			
		||||
                        :height="tableDataHeight"
 | 
			
		||||
                        empty-text="tips: select *开头的单表查询或点击表名默认查询的数据,可双击数据在线修改"
 | 
			
		||||
                        @change-updated-field="changeUpdatedField($event, dt)"
 | 
			
		||||
@@ -139,6 +140,7 @@ import { ElNotification } from 'element-plus';
 | 
			
		||||
import syssocket from '@/common/syssocket';
 | 
			
		||||
import SvgIcon from '@/components/svgIcon/index.vue';
 | 
			
		||||
import { getDbDialect } from '../../dialect';
 | 
			
		||||
import { randomUuid } from '@/common/utils/string';
 | 
			
		||||
 | 
			
		||||
const emits = defineEmits(['saveSqlSuccess']);
 | 
			
		||||
 | 
			
		||||
@@ -171,6 +173,8 @@ class ExecResTab {
 | 
			
		||||
 | 
			
		||||
    loading: boolean;
 | 
			
		||||
 | 
			
		||||
    loadingKey: string;
 | 
			
		||||
 | 
			
		||||
    dbTableRef: any;
 | 
			
		||||
 | 
			
		||||
    tableColumn: any[] = [];
 | 
			
		||||
@@ -341,7 +345,10 @@ const onRunSql = async (newTab = false) => {
 | 
			
		||||
        execRes.errorMsg = '';
 | 
			
		||||
        execRes.sql = '';
 | 
			
		||||
 | 
			
		||||
        const colAndData: any = await getNowDbInst().runSql(props.dbName, sql, execRemark);
 | 
			
		||||
        const loadingKey = randomUuid();
 | 
			
		||||
        execRes.loadingKey = loadingKey;
 | 
			
		||||
 | 
			
		||||
        const colAndData: any = await getNowDbInst().runSql(props.dbName, sql, execRemark, loadingKey);
 | 
			
		||||
        if (!colAndData.res || colAndData.res.length === 0) {
 | 
			
		||||
            ElMessage.warning('未查询到结果集');
 | 
			
		||||
        }
 | 
			
		||||
 
 | 
			
		||||
@@ -90,10 +90,15 @@
 | 
			
		||||
                        </div>
 | 
			
		||||
                    </template>
 | 
			
		||||
 | 
			
		||||
                    <template v-if="loading" #overlay>
 | 
			
		||||
                        <div class="el-loading-mask" style="display: flex; align-items: center; justify-content: center">
 | 
			
		||||
                    <template v-if="state.loading" #overlay>
 | 
			
		||||
                        <div class="el-loading-mask" style="display: flex; flex-direction: column; align-items: center; justify-content: center">
 | 
			
		||||
                            <div>
 | 
			
		||||
                                <SvgIcon class="is-loading" name="loading" color="var(--el-color-primary)" :size="42" />
 | 
			
		||||
                            </div>
 | 
			
		||||
                            <div v-if="loadingKey" class="mt10">
 | 
			
		||||
                                <el-button @click="cancelLoading" type="info" size="small" plain>取 消</el-button>
 | 
			
		||||
                            </div>
 | 
			
		||||
                        </div>
 | 
			
		||||
                    </template>
 | 
			
		||||
 | 
			
		||||
                    <template #empty>
 | 
			
		||||
@@ -127,10 +132,14 @@ import { ContextmenuItem, Contextmenu } from '@/components/contextmenu';
 | 
			
		||||
import SvgIcon from '@/components/svgIcon/index.vue';
 | 
			
		||||
import { exportCsv, exportFile } from '@/common/utils/export';
 | 
			
		||||
import { dateStrFormat } from '@/common/utils/date';
 | 
			
		||||
import { dbApi } from '../../api';
 | 
			
		||||
 | 
			
		||||
const emits = defineEmits(['dataDelete', 'sortChange', 'deleteData', 'selectionChange', 'changeUpdatedField']);
 | 
			
		||||
 | 
			
		||||
const props = defineProps({
 | 
			
		||||
    loadingKey: {
 | 
			
		||||
        type: String,
 | 
			
		||||
    },
 | 
			
		||||
    dbId: {
 | 
			
		||||
        type: Number,
 | 
			
		||||
        required: true,
 | 
			
		||||
@@ -372,6 +381,12 @@ onMounted(async () => {
 | 
			
		||||
    setTableData(props.data);
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
const cancelLoading = async () => {
 | 
			
		||||
    if (props.loadingKey) {
 | 
			
		||||
        await dbApi.sqlExecCancel.request({ id: state.dbId, execId: props.loadingKey });
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
const setTableData = (datas: any) => {
 | 
			
		||||
    tableRef.value.scrollTo({ scrollLeft: 0, scrollTop: 0 });
 | 
			
		||||
    selectionRowsMap.clear();
 | 
			
		||||
 
 | 
			
		||||
@@ -3,7 +3,7 @@
 | 
			
		||||
        <el-row>
 | 
			
		||||
            <el-col :span="8">
 | 
			
		||||
                <div class="mt5">
 | 
			
		||||
                    <el-link @click="onRefresh()" icon="refresh" :underline="false" class="ml5"> </el-link>
 | 
			
		||||
                    <el-link :disabled="state.loading" @click="onRefresh()" icon="refresh" :underline="false" class="ml5"> </el-link>
 | 
			
		||||
                    <el-divider direction="vertical" border-style="dashed" />
 | 
			
		||||
 | 
			
		||||
                    <el-popover
 | 
			
		||||
 
 | 
			
		||||
@@ -196,9 +196,10 @@ export class DbInst {
 | 
			
		||||
     * @param sql sql
 | 
			
		||||
     * @param remark 执行备注
 | 
			
		||||
     */
 | 
			
		||||
    async runSql(dbName: string, sql: string, remark: string = '') {
 | 
			
		||||
    async runSql(dbName: string, sql: string, remark: string = '', key: string = '') {
 | 
			
		||||
        return await dbApi.sqlExec.request({
 | 
			
		||||
            id: this.id,
 | 
			
		||||
            execId: key,
 | 
			
		||||
            db: dbName,
 | 
			
		||||
            sql: sql.trim(),
 | 
			
		||||
            remark,
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package api
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
@@ -24,6 +25,7 @@ import (
 | 
			
		||||
	"mayfly-go/pkg/ws"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
@@ -80,6 +82,11 @@ func (d *Db) DeleteDb(rc *req.Ctx) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**  数据库操作相关、执行sql等   ***/
 | 
			
		||||
 | 
			
		||||
// 取消执行sql函数map; key -> execId ; value -> cancelFunc
 | 
			
		||||
var cancelExecSqlMap = sync.Map{}
 | 
			
		||||
 | 
			
		||||
func (d *Db) ExecSql(rc *req.Ctx) {
 | 
			
		||||
	g := rc.GinCtx
 | 
			
		||||
	form := &form.DbSqlExecForm{}
 | 
			
		||||
@@ -95,7 +102,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
 | 
			
		||||
	// 去除前后空格及换行符
 | 
			
		||||
	sql := stringx.TrimSpaceAndBr(string(sqlBytes))
 | 
			
		||||
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("%s\n-> %s", dbConn.Info.GetLogDesc(), sql)
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("%s %s\n-> %s", dbConn.Info.GetLogDesc(), form.ExecId, sql)
 | 
			
		||||
	biz.NotEmpty(form.Sql, "sql不能为空")
 | 
			
		||||
 | 
			
		||||
	execReq := &application.DbSqlExecReq{
 | 
			
		||||
@@ -103,7 +110,15 @@ func (d *Db) ExecSql(rc *req.Ctx) {
 | 
			
		||||
		Db:     form.Db,
 | 
			
		||||
		Remark: form.Remark,
 | 
			
		||||
		DbConn: dbConn,
 | 
			
		||||
		LoginAccount: rc.GetLoginAccount(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx := rc.MetaCtx
 | 
			
		||||
	// 如果存在执行id,则保存取消函数,用于后续可能的取消操作
 | 
			
		||||
	if form.ExecId != "" {
 | 
			
		||||
		cancelCtx, cancel := context.WithCancel(rc.MetaCtx)
 | 
			
		||||
		ctx = cancelCtx
 | 
			
		||||
		cancelExecSqlMap.Store(form.ExecId, cancel)
 | 
			
		||||
		defer cancelExecSqlMap.Delete(form.ExecId)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sqls, err := sqlparser.SplitStatementToPieces(sql, sqlparser.WithDialect(dbConn.Info.Type.Dialect()))
 | 
			
		||||
@@ -119,7 +134,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		execReq.Sql = s
 | 
			
		||||
		execRes, err := d.DbSqlExecApp.Exec(execReq)
 | 
			
		||||
		execRes, err := d.DbSqlExecApp.Exec(ctx, execReq)
 | 
			
		||||
		biz.ErrIsNilAppendErr(err, fmt.Sprintf("[%s] -> 执行失败: ", s)+"%s")
 | 
			
		||||
 | 
			
		||||
		if execResAll == nil {
 | 
			
		||||
@@ -135,6 +150,14 @@ func (d *Db) ExecSql(rc *req.Ctx) {
 | 
			
		||||
	rc.ResData = colAndRes
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Db) CancelExecSql(rc *req.Ctx) {
 | 
			
		||||
	execId := ginx.PathParam(rc.GinCtx, "execId")
 | 
			
		||||
	if cancelFunc, ok := cancelExecSqlMap.LoadAndDelete(execId); ok {
 | 
			
		||||
		rc.ReqParam = execId
 | 
			
		||||
		cancelFunc.(context.CancelFunc)()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// progressCategory sql文件执行进度消息类型
 | 
			
		||||
const progressCategory = "execSqlFileProgress"
 | 
			
		||||
 | 
			
		||||
@@ -179,7 +202,6 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
		Db:     dbName,
 | 
			
		||||
		Remark: filename,
 | 
			
		||||
		DbConn: dbConn,
 | 
			
		||||
		LoginAccount: rc.GetLoginAccount(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var sql string
 | 
			
		||||
@@ -237,7 +259,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
		const maxRecordStatements = 64
 | 
			
		||||
		if executedStatements < maxRecordStatements {
 | 
			
		||||
			execReq.Sql = sql
 | 
			
		||||
			_, err = d.DbSqlExecApp.Exec(execReq)
 | 
			
		||||
			_, err = d.DbSqlExecApp.Exec(rc.MetaCtx, execReq)
 | 
			
		||||
		} else {
 | 
			
		||||
			_, err = dbConn.Exec(sql)
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -18,6 +18,7 @@ type DbSqlSaveForm struct {
 | 
			
		||||
 | 
			
		||||
// 数据库SQL执行表单
 | 
			
		||||
type DbSqlExecForm struct {
 | 
			
		||||
	ExecId string `json:"execId"`                 // 执行id(用于取消执行使用)
 | 
			
		||||
	Db     string `binding:"required" json:"db"`  //数据库名
 | 
			
		||||
	Sql    string `binding:"required" json:"sql"` // 执行sql
 | 
			
		||||
	Remark string `json:"remark"`                 // 执行备注
 | 
			
		||||
 
 | 
			
		||||
@@ -8,6 +8,7 @@ import (
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/domain/repository"
 | 
			
		||||
	"mayfly-go/pkg/contextx"
 | 
			
		||||
	"mayfly-go/pkg/errorx"
 | 
			
		||||
	"mayfly-go/pkg/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -21,7 +22,6 @@ type DbSqlExecReq struct {
 | 
			
		||||
	Db     string
 | 
			
		||||
	Sql    string
 | 
			
		||||
	Remark string
 | 
			
		||||
	LoginAccount *model.LoginAccount
 | 
			
		||||
	DbConn *dbm.DbConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -47,7 +47,7 @@ func (d *DbSqlExecRes) Merge(execRes *DbSqlExecRes) {
 | 
			
		||||
 | 
			
		||||
type DbSqlExec interface {
 | 
			
		||||
	// 执行sql
 | 
			
		||||
	Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
 | 
			
		||||
	Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
 | 
			
		||||
 | 
			
		||||
	// 根据条件删除sql执行记录
 | 
			
		||||
	DeleteBy(ctx context.Context, condition *entity.DbSqlExec)
 | 
			
		||||
@@ -66,19 +66,19 @@ type dbSqlExecAppImpl struct {
 | 
			
		||||
	dbSqlExecRepo repository.DbSqlExec
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func createSqlExecRecord(execSqlReq *DbSqlExecReq) *entity.DbSqlExec {
 | 
			
		||||
func createSqlExecRecord(ctx context.Context, execSqlReq *DbSqlExecReq) *entity.DbSqlExec {
 | 
			
		||||
	dbSqlExecRecord := new(entity.DbSqlExec)
 | 
			
		||||
	dbSqlExecRecord.DbId = execSqlReq.DbId
 | 
			
		||||
	dbSqlExecRecord.Db = execSqlReq.Db
 | 
			
		||||
	dbSqlExecRecord.Sql = execSqlReq.Sql
 | 
			
		||||
	dbSqlExecRecord.Remark = execSqlReq.Remark
 | 
			
		||||
	dbSqlExecRecord.SetBaseInfo(execSqlReq.LoginAccount)
 | 
			
		||||
	dbSqlExecRecord.SetBaseInfo(contextx.GetLoginAccount(ctx))
 | 
			
		||||
	return dbSqlExecRecord
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
func (d *dbSqlExecAppImpl) Exec(ctx context.Context, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
	sql := execSqlReq.Sql
 | 
			
		||||
	dbSqlExecRecord := createSqlExecRecord(execSqlReq)
 | 
			
		||||
	dbSqlExecRecord := createSqlExecRecord(ctx, execSqlReq)
 | 
			
		||||
	dbSqlExecRecord.Type = entity.DbSqlExecTypeOther
 | 
			
		||||
	var execRes *DbSqlExecRes
 | 
			
		||||
	isSelect := false
 | 
			
		||||
@@ -100,9 +100,9 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
 | 
			
		||||
		}
 | 
			
		||||
		var execErr error
 | 
			
		||||
		if isSelect || strings.HasPrefix(lowerSql, "show") {
 | 
			
		||||
			execRes, execErr = doRead(execSqlReq)
 | 
			
		||||
			execRes, execErr = doRead(ctx, execSqlReq)
 | 
			
		||||
		} else {
 | 
			
		||||
			execRes, execErr = doExec(execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
			execRes, execErr = doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
		}
 | 
			
		||||
		if execErr != nil {
 | 
			
		||||
			return nil, execErr
 | 
			
		||||
@@ -114,21 +114,21 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
 | 
			
		||||
	switch stmt := stmt.(type) {
 | 
			
		||||
	case *sqlparser.Select:
 | 
			
		||||
		isSelect = true
 | 
			
		||||
		execRes, err = doSelect(stmt, execSqlReq)
 | 
			
		||||
		execRes, err = doSelect(ctx, stmt, execSqlReq)
 | 
			
		||||
	case *sqlparser.Show:
 | 
			
		||||
		isSelect = true
 | 
			
		||||
		execRes, err = doRead(execSqlReq)
 | 
			
		||||
		execRes, err = doRead(ctx, execSqlReq)
 | 
			
		||||
	case *sqlparser.OtherRead:
 | 
			
		||||
		isSelect = true
 | 
			
		||||
		execRes, err = doRead(execSqlReq)
 | 
			
		||||
		execRes, err = doRead(ctx, execSqlReq)
 | 
			
		||||
	case *sqlparser.Update:
 | 
			
		||||
		execRes, err = doUpdate(stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
		execRes, err = doUpdate(ctx, stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
	case *sqlparser.Delete:
 | 
			
		||||
		execRes, err = doDelete(stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
		execRes, err = doDelete(ctx, stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
	case *sqlparser.Insert:
 | 
			
		||||
		execRes, err = doInsert(stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
		execRes, err = doInsert(ctx, stmt, execSqlReq, dbSqlExecRecord)
 | 
			
		||||
	default:
 | 
			
		||||
		execRes, err = doExec(execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
		execRes, err = doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -159,7 +159,7 @@ func (d *dbSqlExecAppImpl) GetPageList(condition *entity.DbSqlExecQuery, pagePar
 | 
			
		||||
	return d.dbSqlExecRepo.GetPageList(condition, pageParam, toEntity, orderBy...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doSelect(selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
func doSelect(ctx context.Context, selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
	selectExprsStr := sqlparser.String(selectStmt.SelectExprs)
 | 
			
		||||
	if selectExprsStr == "*" || strings.Contains(selectExprsStr, ".*") ||
 | 
			
		||||
		len(strings.Split(selectExprsStr, ",")) > 1 {
 | 
			
		||||
@@ -182,13 +182,13 @@ func doSelect(selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExe
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return doRead(execSqlReq)
 | 
			
		||||
	return doRead(ctx, execSqlReq)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
func doRead(ctx context.Context, execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
	dbConn := execSqlReq.DbConn
 | 
			
		||||
	sql := execSqlReq.Sql
 | 
			
		||||
	colNames, res, err := dbConn.Query(sql)
 | 
			
		||||
	colNames, res, err := dbConn.QueryContext(ctx, sql)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -198,7 +198,7 @@ func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
func doUpdate(ctx context.Context, update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
	dbConn := execSqlReq.DbConn
 | 
			
		||||
 | 
			
		||||
	tableStr := sqlparser.String(update.TableExprs)
 | 
			
		||||
@@ -224,7 +224,7 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent
 | 
			
		||||
	updateColumnsAndPrimaryKey := strings.Join(updateColumns, ",") + "," + primaryKey
 | 
			
		||||
	// 查询要更新字段数据的旧值,以及主键值
 | 
			
		||||
	selectSql := fmt.Sprintf("SELECT %s FROM %s %s LIMIT 200", updateColumnsAndPrimaryKey, tableStr, where)
 | 
			
		||||
	_, res, err := dbConn.Query(selectSql)
 | 
			
		||||
	_, res, err := dbConn.QueryContext(ctx, selectSql)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		bytes, _ := json.Marshal(res)
 | 
			
		||||
		dbSqlExec.OldValue = string(bytes)
 | 
			
		||||
@@ -235,10 +235,10 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent
 | 
			
		||||
	dbSqlExec.Table = tableName
 | 
			
		||||
	dbSqlExec.Type = entity.DbSqlExecTypeUpdate
 | 
			
		||||
 | 
			
		||||
	return doExec(execSqlReq.Sql, dbConn)
 | 
			
		||||
	return doExec(ctx, execSqlReq.Sql, dbConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
func doDelete(ctx context.Context, delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
	dbConn := execSqlReq.DbConn
 | 
			
		||||
 | 
			
		||||
	tableStr := sqlparser.String(delete.TableExprs)
 | 
			
		||||
@@ -251,28 +251,28 @@ func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *ent
 | 
			
		||||
 | 
			
		||||
	// 查询删除数据
 | 
			
		||||
	selectSql := fmt.Sprintf("SELECT * FROM %s %s LIMIT 200", tableStr, where)
 | 
			
		||||
	_, res, _ := dbConn.Query(selectSql)
 | 
			
		||||
	_, res, _ := dbConn.QueryContext(ctx, selectSql)
 | 
			
		||||
 | 
			
		||||
	bytes, _ := json.Marshal(res)
 | 
			
		||||
	dbSqlExec.OldValue = string(bytes)
 | 
			
		||||
	dbSqlExec.Table = table
 | 
			
		||||
	dbSqlExec.Type = entity.DbSqlExecTypeDelete
 | 
			
		||||
 | 
			
		||||
	return doExec(execSqlReq.Sql, dbConn)
 | 
			
		||||
	return doExec(ctx, execSqlReq.Sql, dbConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
func doInsert(ctx context.Context, insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
 | 
			
		||||
	tableStr := sqlparser.String(insert.Table)
 | 
			
		||||
	// 可能使用别名,故空格切割
 | 
			
		||||
	table := strings.Split(tableStr, " ")[0]
 | 
			
		||||
	dbSqlExec.Table = table
 | 
			
		||||
	dbSqlExec.Type = entity.DbSqlExecTypeInsert
 | 
			
		||||
 | 
			
		||||
	return doExec(execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
	return doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doExec(sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
 | 
			
		||||
	rowsAffected, err := dbConn.Exec(sql)
 | 
			
		||||
func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
 | 
			
		||||
	rowsAffected, err := dbConn.ExecContext(ctx, sql)
 | 
			
		||||
	execRes := "success"
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		execRes = err.Error()
 | 
			
		||||
 
 | 
			
		||||
@@ -1,8 +1,10 @@
 | 
			
		||||
package dbm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/pkg/errorx"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"strconv"
 | 
			
		||||
@@ -20,11 +22,20 @@ type DbConn struct {
 | 
			
		||||
// 执行查询语句
 | 
			
		||||
// 依次返回 列名数组(顺序),结果map,错误
 | 
			
		||||
func (d *DbConn) Query(querySql string) ([]string, []map[string]any, error) {
 | 
			
		||||
	return d.QueryContext(context.Background(), querySql)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行查询语句
 | 
			
		||||
// 依次返回 列名数组(顺序),结果map,错误
 | 
			
		||||
func (d *DbConn) QueryContext(ctx context.Context, querySql string) ([]string, []map[string]any, error) {
 | 
			
		||||
	result := make([]map[string]any, 0, 16)
 | 
			
		||||
	columns, err := walkTableRecord(d.db, querySql, func(record map[string]any, columns []string) {
 | 
			
		||||
	columns, err := walkTableRecord(ctx, d.db, querySql, func(record map[string]any, columns []string) {
 | 
			
		||||
		result = append(result, record)
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == context.Canceled {
 | 
			
		||||
			return nil, nil, errorx.NewBiz("取消执行")
 | 
			
		||||
		}
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return columns, result, nil
 | 
			
		||||
@@ -47,16 +58,25 @@ func (d *DbConn) Query2Struct(execSql string, dest any) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WalkTableRecord 遍历表记录
 | 
			
		||||
func (d *DbConn) WalkTableRecord(selectSql string, walk func(record map[string]any, columns []string)) error {
 | 
			
		||||
	_, err := walkTableRecord(d.db, selectSql, walk)
 | 
			
		||||
func (d *DbConn) WalkTableRecord(ctx context.Context, selectSql string, walk func(record map[string]any, columns []string)) error {
 | 
			
		||||
	_, err := walkTableRecord(ctx, d.db, selectSql, walk)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行 update, insert, delete,建表等sql
 | 
			
		||||
// 返回影响条数和错误
 | 
			
		||||
func (d *DbConn) Exec(sql string) (int64, error) {
 | 
			
		||||
	res, err := d.db.Exec(sql)
 | 
			
		||||
	return d.ExecContext(context.Background(), sql)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行 update, insert, delete,建表等sql
 | 
			
		||||
// 返回影响条数和错误
 | 
			
		||||
func (d *DbConn) ExecContext(ctx context.Context, sql string) (int64, error) {
 | 
			
		||||
	res, err := d.db.ExecContext(ctx, sql)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if err == context.Canceled {
 | 
			
		||||
			return 0, errorx.NewBiz("取消执行")
 | 
			
		||||
		}
 | 
			
		||||
		return 0, err
 | 
			
		||||
	}
 | 
			
		||||
	return res.RowsAffected()
 | 
			
		||||
@@ -84,8 +104,9 @@ func (d *DbConn) Close() {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func walkTableRecord(db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) ([]string, error) {
 | 
			
		||||
	rows, err := db.Query(selectSql)
 | 
			
		||||
func walkTableRecord(ctx context.Context, db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) ([]string, error) {
 | 
			
		||||
	rows, err := db.QueryContext(ctx, selectSql)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -179,5 +179,5 @@ func (md *MysqlDialect) GetTableRecord(tableName string, pageNum, pageSize int)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *MysqlDialect) WalkTableRecord(tableName string, walk func(record map[string]any, columns []string)) error {
 | 
			
		||||
	return md.dc.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
 | 
			
		||||
	return md.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package dbm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -249,7 +250,7 @@ func (pd *PgsqlDialect) GetTableRecord(tableName string, pageNum, pageSize int)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) WalkTableRecord(tableName string, walk func(record map[string]any, columns []string)) error {
 | 
			
		||||
	return pd.dc.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
 | 
			
		||||
	return pd.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取pgsql当前连接的库可访问的schemaNames
 | 
			
		||||
 
 | 
			
		||||
@@ -35,6 +35,8 @@ func InitDbRouter(router *gin.RouterGroup) {
 | 
			
		||||
 | 
			
		||||
		req.NewPost(":dbId/exec-sql", d.ExecSql).Log(req.NewLog("db-执行Sql")),
 | 
			
		||||
 | 
			
		||||
		req.NewPost(":dbId/exec-sql/cancel/:execId", d.CancelExecSql).Log(req.NewLog("db-取消执行Sql")),
 | 
			
		||||
 | 
			
		||||
		req.NewPost(":dbId/exec-sql-file", d.ExecSqlFile).Log(req.NewLogSave("db-执行Sql文件")),
 | 
			
		||||
 | 
			
		||||
		req.NewGet(":dbId/dump", d.DumpSql).Log(req.NewLogSave("db-导出sql文件")).NoRes(),
 | 
			
		||||
 
 | 
			
		||||
@@ -101,7 +101,7 @@ func (rc *Ctx) GetLogInfo() *LogInfo {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewCtxWithGin(g *gin.Context) *Ctx {
 | 
			
		||||
	return &Ctx{GinCtx: g, MetaCtx: contextx.NewTraceId()}
 | 
			
		||||
	return &Ctx{GinCtx: g, MetaCtx: contextx.WithTraceId(g.Request.Context())}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 处理器拦截器函数
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user