mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 00:10:25 +08:00 
			
		
		
		
	Merge pull request #69 from kanzihuang/feat-progress-notify-pullrequest
feat: 显示 SQL 文件执行进度
This commit is contained in:
		@@ -1,7 +1,10 @@
 | 
			
		||||
import Config from './config';
 | 
			
		||||
import { ElNotification } from 'element-plus';
 | 
			
		||||
import { ElNotification, NotificationHandle } from 'element-plus';
 | 
			
		||||
import SocketBuilder from './SocketBuilder';
 | 
			
		||||
import { getToken } from '@/common/utils/storage';
 | 
			
		||||
import { createVNode, reactive } from "vue";
 | 
			
		||||
import { buildProgressProps } from "@/components/progress-notify/progress-notify";
 | 
			
		||||
import ProgressNotify from '/src/components/progress-notify/progress-notify.vue';
 | 
			
		||||
 | 
			
		||||
export default {
 | 
			
		||||
    /**
 | 
			
		||||
@@ -12,32 +15,63 @@ export default {
 | 
			
		||||
        if (!token) {
 | 
			
		||||
            return null;
 | 
			
		||||
        }
 | 
			
		||||
        const messageTypes = {
 | 
			
		||||
            0: "error",
 | 
			
		||||
            1: "success",
 | 
			
		||||
            2: "info",
 | 
			
		||||
        }
 | 
			
		||||
        const notifyMap: Map<Number, any> = new Map()
 | 
			
		||||
 | 
			
		||||
        return SocketBuilder.builder(`${Config.baseWsUrl}/sysmsg?token=${token}`)
 | 
			
		||||
            .message((event: { data: string }) => {
 | 
			
		||||
                const message = JSON.parse(event.data);
 | 
			
		||||
                let mtype: string;
 | 
			
		||||
                switch (message.type) {
 | 
			
		||||
                    case 0:
 | 
			
		||||
                        mtype = 'error';
 | 
			
		||||
                        break;
 | 
			
		||||
                    case 2:
 | 
			
		||||
                        mtype = 'info';
 | 
			
		||||
                        break;
 | 
			
		||||
                    case 1:
 | 
			
		||||
                        mtype = 'success';
 | 
			
		||||
                const type = messageTypes[message.type]
 | 
			
		||||
                switch (message.category) {
 | 
			
		||||
                    case "execSqlFileProgress":
 | 
			
		||||
                        const content = JSON.parse(message.msg)
 | 
			
		||||
                        const id = content.id
 | 
			
		||||
                        let progress = notifyMap.get(id)
 | 
			
		||||
                        if (content.terminated) {
 | 
			
		||||
                            if (progress != undefined) {
 | 
			
		||||
                                progress.notification?.close()
 | 
			
		||||
                                notifyMap.delete(id)
 | 
			
		||||
                                progress = undefined
 | 
			
		||||
                            }
 | 
			
		||||
                            return
 | 
			
		||||
                        }
 | 
			
		||||
                        if (progress == undefined) {
 | 
			
		||||
                            progress = {
 | 
			
		||||
                                props: reactive(buildProgressProps()),
 | 
			
		||||
                                notification: undefined,
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
                        progress.props.progress.sqlFileName = content.sqlFileName
 | 
			
		||||
                        progress.props.progress.executedStatements = content.executedStatements
 | 
			
		||||
                        if (!notifyMap.has(id)) {
 | 
			
		||||
                            const vNodeMessage = createVNode(
 | 
			
		||||
                                ProgressNotify,
 | 
			
		||||
                                progress.props,
 | 
			
		||||
                                null,
 | 
			
		||||
                            )
 | 
			
		||||
                            progress.notification = ElNotification({
 | 
			
		||||
                                duration: 0,
 | 
			
		||||
                                title: message.title,
 | 
			
		||||
                                message: vNodeMessage,
 | 
			
		||||
                                type: type,
 | 
			
		||||
                                showClose: false,
 | 
			
		||||
                            });
 | 
			
		||||
                            notifyMap.set(id, progress)
 | 
			
		||||
                        }
 | 
			
		||||
                        break;
 | 
			
		||||
                    default:
 | 
			
		||||
                        mtype = 'info';
 | 
			
		||||
                        ElNotification({
 | 
			
		||||
                            duration: 0,
 | 
			
		||||
                            title: message.title,
 | 
			
		||||
                            message: message.msg,
 | 
			
		||||
                            type: type,
 | 
			
		||||
                        });
 | 
			
		||||
                        break;
 | 
			
		||||
                }
 | 
			
		||||
                if (mtype == undefined) {
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                ElNotification({
 | 
			
		||||
                    duration: 0,
 | 
			
		||||
                    title: message.title,
 | 
			
		||||
                    message: message.msg,
 | 
			
		||||
                    type: mtype as any,
 | 
			
		||||
                });
 | 
			
		||||
            })
 | 
			
		||||
            .open((event: any) => console.log(event))
 | 
			
		||||
            .build();
 | 
			
		||||
 
 | 
			
		||||
@@ -0,0 +1,14 @@
 | 
			
		||||
export const buildProgressProps = () : any =>  {
 | 
			
		||||
    return {
 | 
			
		||||
        progress: {
 | 
			
		||||
            sqlFileName: {
 | 
			
		||||
                type: String
 | 
			
		||||
            },
 | 
			
		||||
            executedStatements: {
 | 
			
		||||
                type: Number
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
    };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -0,0 +1,41 @@
 | 
			
		||||
<template>
 | 
			
		||||
    <el-descriptions
 | 
			
		||||
        border
 | 
			
		||||
        size="small"
 | 
			
		||||
        :title="`${progress.sqlFileName}`"
 | 
			
		||||
    >
 | 
			
		||||
        <el-descriptions-item label="时间">{{ state.elapsedTime }}</el-descriptions-item>
 | 
			
		||||
        <el-descriptions-item label="已处理">{{ progress.executedStatements }}</el-descriptions-item>
 | 
			
		||||
    </el-descriptions>
 | 
			
		||||
</template>
 | 
			
		||||
<script lang="ts" setup>
 | 
			
		||||
 | 
			
		||||
import {onMounted, onUnmounted, reactive} from "vue";
 | 
			
		||||
import {formatTime} from 'element-plus/es/components/countdown/src/utils';
 | 
			
		||||
import {buildProgressProps} from "./progress-notify";
 | 
			
		||||
 | 
			
		||||
const props = defineProps(buildProgressProps());
 | 
			
		||||
 | 
			
		||||
const state = reactive({
 | 
			
		||||
    elapsedTime: "00:00:00"
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
let timer = undefined;
 | 
			
		||||
const startTime = Date.now()
 | 
			
		||||
 | 
			
		||||
onMounted(async () => {
 | 
			
		||||
  timer = setInterval(() => {
 | 
			
		||||
    const elapsed = Date.now() - startTime;
 | 
			
		||||
    state.elapsedTime = formatTime(elapsed, 'HH:mm:ss')
 | 
			
		||||
  }, 1000);
 | 
			
		||||
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
onUnmounted(async () => {
 | 
			
		||||
  if (timer != undefined) {
 | 
			
		||||
    clearInterval(timer); // 在Vue实例销毁前,清除我们的定时器
 | 
			
		||||
    timer = undefined;
 | 
			
		||||
  }
 | 
			
		||||
});
 | 
			
		||||
 | 
			
		||||
</script>
 | 
			
		||||
@@ -206,7 +206,7 @@ router.beforeEach(async (to, from, next) => {
 | 
			
		||||
 | 
			
		||||
        if (SysWs) {
 | 
			
		||||
            SysWs.close();
 | 
			
		||||
            SysWs = null;
 | 
			
		||||
            SysWs = undefined;
 | 
			
		||||
        }
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 
 | 
			
		||||
@@ -13,6 +13,7 @@ require (
 | 
			
		||||
	github.com/go-sql-driver/mysql v1.7.1
 | 
			
		||||
	github.com/golang-jwt/jwt/v5 v5.0.0
 | 
			
		||||
	github.com/gorilla/websocket v1.5.0
 | 
			
		||||
	github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231007020222-b91ee5ef3b31
 | 
			
		||||
	github.com/lib/pq v1.10.9
 | 
			
		||||
	github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d
 | 
			
		||||
	github.com/mojocn/base64Captcha v1.3.5 // 验证码
 | 
			
		||||
@@ -21,45 +22,52 @@ require (
 | 
			
		||||
	github.com/pquerna/otp v1.4.0
 | 
			
		||||
	github.com/redis/go-redis/v9 v9.2.1
 | 
			
		||||
	github.com/robfig/cron/v3 v3.0.1 // 定时任务
 | 
			
		||||
	github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
 | 
			
		||||
	github.com/stretchr/testify v1.8.4
 | 
			
		||||
	go.mongodb.org/mongo-driver v1.12.1 // mongo
 | 
			
		||||
	golang.org/x/crypto v0.14.0 // ssh
 | 
			
		||||
	golang.org/x/oauth2 v0.13.0
 | 
			
		||||
	gopkg.in/yaml.v3 v3.0.1
 | 
			
		||||
	// gorm
 | 
			
		||||
	gorm.io/driver/mysql v1.5.1
 | 
			
		||||
	gorm.io/driver/sqlite v1.5.4
 | 
			
		||||
	gorm.io/gorm v1.25.4
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
 | 
			
		||||
	gorm.io/driver/sqlite v1.5.1
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
 | 
			
		||||
	github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
 | 
			
		||||
	github.com/bytedance/sonic v1.9.1 // indirect
 | 
			
		||||
	github.com/cespare/xxhash/v2 v2.2.0 // indirect
 | 
			
		||||
	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
 | 
			
		||||
	github.com/davecgh/go-spew v1.1.1 // indirect
 | 
			
		||||
	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
 | 
			
		||||
	github.com/gabriel-vasile/mimetype v1.4.2 // indirect
 | 
			
		||||
	github.com/gin-contrib/sse v0.1.0 // indirect
 | 
			
		||||
	github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect
 | 
			
		||||
	github.com/goccy/go-json v0.10.2 // indirect
 | 
			
		||||
	github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
 | 
			
		||||
	github.com/golang/glog v1.0.0 // indirect
 | 
			
		||||
	github.com/golang/protobuf v1.5.3 // indirect
 | 
			
		||||
	github.com/golang/snappy v0.0.1 // indirect
 | 
			
		||||
	github.com/golang/snappy v0.0.4 // indirect
 | 
			
		||||
	github.com/jinzhu/inflection v1.0.0 // indirect
 | 
			
		||||
	github.com/jinzhu/now v1.1.5 // indirect
 | 
			
		||||
	github.com/json-iterator/go v1.1.12 // indirect
 | 
			
		||||
	github.com/klauspost/compress v1.13.6 // indirect
 | 
			
		||||
	github.com/klauspost/compress v1.16.5 // indirect
 | 
			
		||||
	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
 | 
			
		||||
	github.com/kr/fs v0.1.0 // indirect
 | 
			
		||||
	github.com/kr/pretty v0.3.0 // indirect
 | 
			
		||||
	github.com/leodido/go-urn v1.2.4 // indirect
 | 
			
		||||
	github.com/mattn/go-isatty v0.0.19 // indirect
 | 
			
		||||
	github.com/mattn/go-sqlite3 v1.14.17 // indirect
 | 
			
		||||
	github.com/mattn/go-sqlite3 v1.14.16 // indirect
 | 
			
		||||
	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
 | 
			
		||||
	github.com/modern-go/reflect2 v1.0.2 // indirect
 | 
			
		||||
	github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect
 | 
			
		||||
	github.com/montanaflynn/stats v0.7.0 // indirect
 | 
			
		||||
	github.com/pelletier/go-toml/v2 v2.0.8 // indirect
 | 
			
		||||
	github.com/pmezard/go-difflib v1.0.0 // indirect
 | 
			
		||||
	github.com/spf13/pflag v1.0.5 // indirect
 | 
			
		||||
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
 | 
			
		||||
	github.com/ugorji/go/codec v1.2.11 // indirect
 | 
			
		||||
	github.com/xdg-go/pbkdf2 v1.0.0 // indirect
 | 
			
		||||
@@ -67,12 +75,15 @@ require (
 | 
			
		||||
	github.com/xdg-go/stringprep v1.0.4 // indirect
 | 
			
		||||
	github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
 | 
			
		||||
	golang.org/x/arch v0.3.0 // indirect
 | 
			
		||||
	golang.org/x/exp v0.0.0-20230519143937-03e91628a987 // indirect
 | 
			
		||||
	golang.org/x/image v0.0.0-20220302094943-723b81ca9867 // indirect
 | 
			
		||||
	golang.org/x/net v0.16.0 // indirect
 | 
			
		||||
	golang.org/x/sync v0.1.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.13.0 // indirect
 | 
			
		||||
	golang.org/x/text v0.13.0 // indirect
 | 
			
		||||
	google.golang.org/appengine v1.6.7 // indirect
 | 
			
		||||
	google.golang.org/genproto v0.0.0-20230131230820-1c016267d619 // indirect
 | 
			
		||||
	google.golang.org/grpc v1.52.3 // indirect
 | 
			
		||||
	google.golang.org/protobuf v1.31.0 // indirect
 | 
			
		||||
	gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect
 | 
			
		||||
	vitess.io/vitess v0.17.3 // indirect
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,12 @@
 | 
			
		||||
package api
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/lib/pq"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mayfly-go/pkg/utils/uniqueid"
 | 
			
		||||
	"mayfly-go/pkg/ws"
 | 
			
		||||
 | 
			
		||||
	"mayfly-go/internal/db/api/form"
 | 
			
		||||
	"mayfly-go/internal/db/api/vo"
 | 
			
		||||
	"mayfly-go/internal/db/application"
 | 
			
		||||
@@ -16,14 +19,13 @@ import (
 | 
			
		||||
	"mayfly-go/pkg/gormx"
 | 
			
		||||
	"mayfly-go/pkg/model"
 | 
			
		||||
	"mayfly-go/pkg/req"
 | 
			
		||||
	"mayfly-go/pkg/sqlparser"
 | 
			
		||||
	"mayfly-go/pkg/utils/stringx"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/xwb1989/sqlparser"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Db struct {
 | 
			
		||||
@@ -79,9 +81,7 @@ func (d *Db) DeleteDb(rc *req.Ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection {
 | 
			
		||||
	dbName := g.Query("db")
 | 
			
		||||
	biz.NotEmpty(dbName, "db不能为空")
 | 
			
		||||
	return d.DbApp.GetDbConnection(getDbId(g), dbName)
 | 
			
		||||
	return d.DbApp.GetDbConnection(getDbId(g), getDbName(g))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Db) TableInfos(rc *req.Ctx) {
 | 
			
		||||
@@ -152,67 +152,119 @@ func (d *Db) ExecSql(rc *req.Ctx) {
 | 
			
		||||
	rc.ResData = colAndRes
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// progressCategory sql文件执行进度消息类型
 | 
			
		||||
const progressCategory = "execSqlFileProgress"
 | 
			
		||||
 | 
			
		||||
// progressMsg sql文件执行进度消息
 | 
			
		||||
type progressMsg struct {
 | 
			
		||||
	Id                 uint64 `json:"id"`
 | 
			
		||||
	SqlFileName        string `json:"sqlFileName"`
 | 
			
		||||
	ExecutedStatements int    `json:"executedStatements"`
 | 
			
		||||
	Terminated         bool   `json:"terminated"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行sql文件
 | 
			
		||||
func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
	g := rc.GinCtx
 | 
			
		||||
	fileheader, err := g.FormFile("file")
 | 
			
		||||
	multipart, err := g.Request.MultipartReader()
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
 | 
			
		||||
 | 
			
		||||
	file, _ := fileheader.Open()
 | 
			
		||||
	filename := fileheader.Filename
 | 
			
		||||
	file, err := multipart.NextPart()
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
 | 
			
		||||
	defer file.Close()
 | 
			
		||||
	filename := file.FileName()
 | 
			
		||||
	dbId := getDbId(g)
 | 
			
		||||
	dbName := getDbName(g)
 | 
			
		||||
 | 
			
		||||
	dbConn := d.getDbConnection(rc.GinCtx)
 | 
			
		||||
	dbConn := d.DbApp.GetDbConnection(dbId, dbName)
 | 
			
		||||
	biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
 | 
			
		||||
 | 
			
		||||
	logExecRecord := true
 | 
			
		||||
	// 如果执行sql文件大于该值则不记录sql执行记录
 | 
			
		||||
	if fileheader.Size > 50*1024 {
 | 
			
		||||
		logExecRecord = false
 | 
			
		||||
	defer func() {
 | 
			
		||||
		var errInfo string
 | 
			
		||||
		switch t := recover().(type) {
 | 
			
		||||
		case biz.BizError:
 | 
			
		||||
			errInfo = t.Error()
 | 
			
		||||
		case *biz.BizError:
 | 
			
		||||
			errInfo = t.Error()
 | 
			
		||||
		case string:
 | 
			
		||||
			errInfo = t
 | 
			
		||||
		}
 | 
			
		||||
		if len(errInfo) > 0 {
 | 
			
		||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	execReq := &application.DbSqlExecReq{
 | 
			
		||||
		DbId:         dbId,
 | 
			
		||||
		Db:           dbName,
 | 
			
		||||
		Remark:       filename,
 | 
			
		||||
		DbConn:       dbConn,
 | 
			
		||||
		LoginAccount: rc.LoginAccount,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if err := recover(); err != nil {
 | 
			
		||||
				var errInfo string
 | 
			
		||||
				switch t := err.(type) {
 | 
			
		||||
				case error:
 | 
			
		||||
					errInfo = t.Error()
 | 
			
		||||
				}
 | 
			
		||||
				if len(errInfo) > 0 {
 | 
			
		||||
					d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	progressId := uniqueid.IncrementID()
 | 
			
		||||
	executedStatements := 0
 | 
			
		||||
	defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
 | 
			
		||||
		Id:                 progressId,
 | 
			
		||||
		SqlFileName:        filename,
 | 
			
		||||
		ExecutedStatements: executedStatements,
 | 
			
		||||
		Terminated:         true,
 | 
			
		||||
	}).WithCategory(progressCategory))
 | 
			
		||||
 | 
			
		||||
		execReq := &application.DbSqlExecReq{
 | 
			
		||||
			DbId:         dbId,
 | 
			
		||||
			Db:           dbName,
 | 
			
		||||
			Remark:       fileheader.Filename,
 | 
			
		||||
			DbConn:       dbConn,
 | 
			
		||||
			LoginAccount: rc.LoginAccount,
 | 
			
		||||
	var parser sqlparser.Parser
 | 
			
		||||
	if dbConn.Info.Type == entity.DbTypeMysql {
 | 
			
		||||
		parser = sqlparser.NewMysqlParser(file)
 | 
			
		||||
	} else {
 | 
			
		||||
		parser = sqlparser.NewPostgresParser(file)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ticker := time.NewTicker(time.Second * 1)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ticker.C:
 | 
			
		||||
			ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
 | 
			
		||||
				Id:                 progressId,
 | 
			
		||||
				SqlFileName:        filename,
 | 
			
		||||
				ExecutedStatements: executedStatements,
 | 
			
		||||
				Terminated:         false,
 | 
			
		||||
			}).WithCategory(progressCategory))
 | 
			
		||||
		default:
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sqlScanner := SplitSqls(file)
 | 
			
		||||
		for sqlScanner.Scan() {
 | 
			
		||||
			sql := sqlScanner.Text()
 | 
			
		||||
		err = parser.Next()
 | 
			
		||||
		if err == io.EOF {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		sql := parser.Current()
 | 
			
		||||
		const prefixUse = "use "
 | 
			
		||||
		if strings.HasPrefix(sql, prefixUse) {
 | 
			
		||||
			dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
 | 
			
		||||
			if len(dbNameExec) > 0 {
 | 
			
		||||
				dbConn = d.DbApp.GetDbConnection(dbId, dbNameExec)
 | 
			
		||||
				biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
 | 
			
		||||
				execReq.DbConn = dbConn
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		// 需要记录执行记录
 | 
			
		||||
		const maxRecordStatements = 64
 | 
			
		||||
		if executedStatements < maxRecordStatements {
 | 
			
		||||
			execReq.Sql = sql
 | 
			
		||||
			// 需要记录执行记录
 | 
			
		||||
			if logExecRecord {
 | 
			
		||||
				_, err = d.DbSqlExecApp.Exec(execReq)
 | 
			
		||||
			} else {
 | 
			
		||||
				_, err = dbConn.Exec(sql)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			_, err = d.DbSqlExecApp.Exec(execReq)
 | 
			
		||||
		} else {
 | 
			
		||||
			_, err = dbConn.Exec(sql)
 | 
			
		||||
		}
 | 
			
		||||
		d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		executedStatements++
 | 
			
		||||
	}
 | 
			
		||||
	d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 数据库dump
 | 
			
		||||
@@ -276,23 +328,44 @@ func (d *Db) DumpSql(rc *req.Ctx) {
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("DB[id=%d, tag=%s, name=%s, databases=%s, tables=%s, dumpType=%s]", db.Id, db.TagPath, db.Name, dbNamesStr, tablesStr, dumpType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func escapeSql(dbType string, sql string) string {
 | 
			
		||||
	if dbType == entity.DbTypePostgres {
 | 
			
		||||
		return pq.QuoteLiteral(sql)
 | 
			
		||||
	} else {
 | 
			
		||||
		sql = strings.ReplaceAll(sql, `\`, `\\`)
 | 
			
		||||
		sql = strings.ReplaceAll(sql, `'`, `''`)
 | 
			
		||||
		return "'" + sql + "'"
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func quoteTable(dbType string, table string) string {
 | 
			
		||||
	if dbType == entity.DbTypePostgres {
 | 
			
		||||
		return "\"" + table + "\""
 | 
			
		||||
	} else {
 | 
			
		||||
		return "`" + table + "`"
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) {
 | 
			
		||||
	dbConn := d.DbApp.GetDbConnection(dbId, dbName)
 | 
			
		||||
	writer.WriteString("-- ----------------------------")
 | 
			
		||||
	writer.WriteString("\n-- ----------------------------")
 | 
			
		||||
	writer.WriteString("\n-- 导出平台: mayfly-go")
 | 
			
		||||
	writer.WriteString(fmt.Sprintf("\n-- 导出时间: %s ", time.Now().Format("2006-01-02 15:04:05")))
 | 
			
		||||
	writer.WriteString(fmt.Sprintf("\n-- 导出数据库: %s ", dbName))
 | 
			
		||||
	writer.WriteString("\n-- ----------------------------\n")
 | 
			
		||||
	writer.TryFlush()
 | 
			
		||||
 | 
			
		||||
	if switchDb {
 | 
			
		||||
		switch dbConn.Info.Type {
 | 
			
		||||
		case entity.DbTypeMysql:
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("use `%s`;\n", dbName))
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("USE `%s`;\n", dbName))
 | 
			
		||||
		default:
 | 
			
		||||
			biz.IsTrue(false, "数据库类型必须为 %s", entity.DbTypeMysql)
 | 
			
		||||
			biz.IsTrue(false, "同时导出多个数据库,数据库类型必须为 %s", entity.DbTypeMysql)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if dbConn.Info.Type == entity.DbTypeMysql {
 | 
			
		||||
		writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 0;\n")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbMeta := dbConn.GetMeta()
 | 
			
		||||
	if len(tables) == 0 {
 | 
			
		||||
		ti := dbMeta.GetTableInfos()
 | 
			
		||||
@@ -303,23 +376,22 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, table := range tables {
 | 
			
		||||
		writer.TryFlush()
 | 
			
		||||
		quotedTable := quoteTable(dbConn.Info.Type, table)
 | 
			
		||||
		if needStruct {
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table))
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS `%s`;\n", table))
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable))
 | 
			
		||||
			writer.WriteString(dbMeta.GetCreateTableDdl(table) + ";\n")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !needData {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table))
 | 
			
		||||
		writer.WriteString("BEGIN;\n")
 | 
			
		||||
 | 
			
		||||
		insertSql := "INSERT INTO `%s` VALUES (%s);\n"
 | 
			
		||||
 | 
			
		||||
		insertSql := "INSERT INTO %s VALUES (%s);\n"
 | 
			
		||||
		dbMeta.WalkTableRecord(table, func(record map[string]any, columns []string) {
 | 
			
		||||
			var values []string
 | 
			
		||||
			writer.TryFlush()
 | 
			
		||||
			for _, column := range columns {
 | 
			
		||||
				value := record[column]
 | 
			
		||||
				if value == nil {
 | 
			
		||||
@@ -328,17 +400,18 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
 | 
			
		||||
				}
 | 
			
		||||
				strValue, ok := value.(string)
 | 
			
		||||
				if ok {
 | 
			
		||||
					values = append(values, fmt.Sprintf("%#v", strValue))
 | 
			
		||||
					strValue = escapeSql(dbConn.Info.Type, strValue)
 | 
			
		||||
					values = append(values, strValue)
 | 
			
		||||
				} else {
 | 
			
		||||
					values = append(values, stringx.AnyToStr(value))
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			writer.WriteString(fmt.Sprintf(insertSql, table, strings.Join(values, ", ")))
 | 
			
		||||
			writer.TryFlush()
 | 
			
		||||
			writer.WriteString(fmt.Sprintf(insertSql, quotedTable, strings.Join(values, ", ")))
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		writer.WriteString("COMMIT;\n")
 | 
			
		||||
		writer.TryFlush()
 | 
			
		||||
	}
 | 
			
		||||
	if dbConn.Info.Type == entity.DbTypeMysql {
 | 
			
		||||
		writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 1;\n")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -477,35 +550,3 @@ func getDbName(g *gin.Context) string {
 | 
			
		||||
	biz.NotEmpty(db, "db不能为空")
 | 
			
		||||
	return db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据;\n切割sql
 | 
			
		||||
func SplitSqls(r io.Reader) *bufio.Scanner {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	re := regexp.MustCompile(`\s*;\s*\n`)
 | 
			
		||||
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, io.EOF
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		match := re.FindIndex(data)
 | 
			
		||||
 | 
			
		||||
		if match != nil {
 | 
			
		||||
			// 如果找到了";\n",判断是否为最后一行
 | 
			
		||||
			if match[1] == len(data) {
 | 
			
		||||
				// 如果是最后一行,则返回完整的切片
 | 
			
		||||
				return len(data), data, nil
 | 
			
		||||
			}
 | 
			
		||||
			// 否则,返回到";\n"之后,并且包括";\n"本身
 | 
			
		||||
			return match[1], data[:match[1]], nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return scanner
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										55
									
								
								server/internal/db/api/db_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								server/internal/db/api/db_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,55 @@
 | 
			
		||||
package api
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Test_escapeSql(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name   string
 | 
			
		||||
		dbType string
 | 
			
		||||
		sql    string
 | 
			
		||||
		want   string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			dbType: entity.DbTypeMysql,
 | 
			
		||||
			sql:    "\\a\\b",
 | 
			
		||||
			want:   "'\\\\a\\\\b'",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			dbType: entity.DbTypeMysql,
 | 
			
		||||
			sql:    "'a'",
 | 
			
		||||
			want:   "'''a'''",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:   "不间断空格",
 | 
			
		||||
			dbType: entity.DbTypeMysql,
 | 
			
		||||
			sql:    "a\u00A0b",
 | 
			
		||||
			want:   "'a\u00A0b'",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			dbType: entity.DbTypePostgres,
 | 
			
		||||
			sql:    "\\a\\b",
 | 
			
		||||
			want:   " E'\\\\a\\\\b'",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			dbType: entity.DbTypePostgres,
 | 
			
		||||
			sql:    "'a'",
 | 
			
		||||
			want:   "'''a'''",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:   "不间断空格",
 | 
			
		||||
			dbType: entity.DbTypePostgres,
 | 
			
		||||
			sql:    "a\u00A0b",
 | 
			
		||||
			want:   "'a\u00A0b'",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	for _, tt := range tests {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			got := escapeSql(tt.dbType, tt.sql)
 | 
			
		||||
			require.Equal(t, tt.want, got)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -3,6 +3,7 @@ package application
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/kanzihuang/vitess/go/vt/sqlparser"
 | 
			
		||||
	"mayfly-go/internal/db/config"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/domain/repository"
 | 
			
		||||
@@ -10,8 +11,6 @@ import (
 | 
			
		||||
	"mayfly-go/pkg/model"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/xwb1989/sqlparser"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DbSqlExecReq struct {
 | 
			
		||||
 
 | 
			
		||||
@@ -11,7 +11,7 @@ const InfoSysMsgType = 2
 | 
			
		||||
// websocket消息
 | 
			
		||||
type SysMsg struct {
 | 
			
		||||
	Type     int    `json:"type"`     // 消息类型
 | 
			
		||||
	Category int    `json:"category"` // 消息类别
 | 
			
		||||
	Category string `json:"category"` // 消息类别
 | 
			
		||||
	Title    string `json:"title"`    // 消息标题
 | 
			
		||||
	Msg      string `json:"msg"`      // 消息内容
 | 
			
		||||
}
 | 
			
		||||
@@ -21,7 +21,7 @@ func (sm *SysMsg) WithTitle(title string) *SysMsg {
 | 
			
		||||
	return sm
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sm *SysMsg) WithCategory(category int) *SysMsg {
 | 
			
		||||
func (sm *SysMsg) WithCategory(category string) *SysMsg {
 | 
			
		||||
	sm.Category = category
 | 
			
		||||
	return sm
 | 
			
		||||
}
 | 
			
		||||
@@ -32,7 +32,7 @@ func (sm *SysMsg) WithMsg(msg any) *SysMsg {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 普通消息
 | 
			
		||||
func NewSysMsg(title string, msg any) *SysMsg {
 | 
			
		||||
func InfoSysMsg(title string, msg any) *SysMsg {
 | 
			
		||||
	return &SysMsg{Type: InfoSysMsgType, Title: title, Msg: stringx.AnyToStr(msg)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										99
									
								
								server/pkg/sqlparser/sqlparser.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								server/pkg/sqlparser/sqlparser.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,99 @@
 | 
			
		||||
package sqlparser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"github.com/kanzihuang/vitess/go/vt/sqlparser"
 | 
			
		||||
	"io"
 | 
			
		||||
	"regexp"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Parser interface {
 | 
			
		||||
	Next() error
 | 
			
		||||
	Current() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Parser = &MysqlParser{}
 | 
			
		||||
var _ Parser = &PostgresParser{}
 | 
			
		||||
 | 
			
		||||
type MysqlParser struct {
 | 
			
		||||
	tokenizer *sqlparser.Tokenizer
 | 
			
		||||
	statement string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMysqlParser(reader io.Reader) *MysqlParser {
 | 
			
		||||
	return &MysqlParser{
 | 
			
		||||
		tokenizer: sqlparser.NewReaderTokenizer(reader),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (parser *MysqlParser) Next() error {
 | 
			
		||||
	statement, err := sqlparser.ParseNext(parser.tokenizer)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		parser.statement = ""
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	parser.statement = sqlparser.String(statement)
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (parser *MysqlParser) Current() string {
 | 
			
		||||
	return parser.statement
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PostgresParser struct {
 | 
			
		||||
	scanner   *bufio.Scanner
 | 
			
		||||
	statement string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPostgresParser(reader io.Reader) *PostgresParser {
 | 
			
		||||
	return &PostgresParser{
 | 
			
		||||
		scanner: splitSqls(reader),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (parser *PostgresParser) Next() error {
 | 
			
		||||
	if !parser.scanner.Scan() {
 | 
			
		||||
		return io.EOF
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (parser *PostgresParser) Current() string {
 | 
			
		||||
	return parser.scanner.Text()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据;\n切割sql
 | 
			
		||||
func splitSqls(r io.Reader) *bufio.Scanner {
 | 
			
		||||
	scanner := bufio.NewScanner(r)
 | 
			
		||||
	re := regexp.MustCompile(`\s*;\s*\n`)
 | 
			
		||||
 | 
			
		||||
	scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
 | 
			
		||||
		if atEOF && len(data) == 0 {
 | 
			
		||||
			return 0, nil, io.EOF
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		match := re.FindIndex(data)
 | 
			
		||||
 | 
			
		||||
		if match != nil {
 | 
			
		||||
			// 如果找到了";\n",判断是否为最后一行
 | 
			
		||||
			if match[1] == len(data) {
 | 
			
		||||
				// 如果是最后一行,则返回完整的切片
 | 
			
		||||
				return len(data), data, nil
 | 
			
		||||
			}
 | 
			
		||||
			// 否则,返回到";\n"之后,并且包括";\n"本身
 | 
			
		||||
			return match[1], data[:match[1]], nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if atEOF {
 | 
			
		||||
			return len(data), data, nil
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return 0, nil, nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return scanner
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func SplitStatementToPieces(sql string) ([]string, error) {
 | 
			
		||||
	return sqlparser.SplitStatementToPieces(sql)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										98
									
								
								server/pkg/sqlparser/sqlparser_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								server/pkg/sqlparser/sqlparser_test.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,98 @@
 | 
			
		||||
package sqlparser
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/kanzihuang/vitess/go/vt/sqlparser"
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
	sqlparser_xwb1989 "github.com/xwb1989/sqlparser"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Test_ParseNext_WithCurrentDate(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name        string
 | 
			
		||||
		input       string
 | 
			
		||||
		want        string
 | 
			
		||||
		wantXwb1989 string
 | 
			
		||||
		err         string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:  "create table with current_timestamp",
 | 
			
		||||
			input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
 | 
			
		||||
			// xwb1989/sqlparser 不支持 current_timestamp()
 | 
			
		||||
			wantXwb1989: "create table tbl",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:  "create table with current_date",
 | 
			
		||||
			input: "create table tbl (\n\tcreate_at date default current_date()\n)",
 | 
			
		||||
			// xwb1989/sqlparser 不支持 current_date()
 | 
			
		||||
			wantXwb1989: "create table tbl",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.name, func(t *testing.T) {
 | 
			
		||||
			token := sqlparser.NewReaderTokenizer(strings.NewReader(test.input))
 | 
			
		||||
			tree, err := sqlparser.ParseNext(token)
 | 
			
		||||
			if len(test.err) > 0 {
 | 
			
		||||
				require.Error(t, err)
 | 
			
		||||
				require.Contains(t, err.Error(), test.err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			require.NoError(t, err)
 | 
			
		||||
			if len(test.want) == 0 {
 | 
			
		||||
				test.want = test.input
 | 
			
		||||
			}
 | 
			
		||||
			require.Equal(t, test.want, sqlparser.String(tree))
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.name, func(t *testing.T) {
 | 
			
		||||
			token := sqlparser_xwb1989.NewTokenizer(strings.NewReader(test.input))
 | 
			
		||||
			tree, err := sqlparser_xwb1989.ParseNext(token)
 | 
			
		||||
			if len(test.err) > 0 {
 | 
			
		||||
				require.Error(t, err)
 | 
			
		||||
				require.Contains(t, err.Error(), test.err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			require.NoError(t, err)
 | 
			
		||||
			if len(test.want) == 0 {
 | 
			
		||||
				test.want = test.input
 | 
			
		||||
			}
 | 
			
		||||
			require.Equal(t, test.wantXwb1989, sqlparser_xwb1989.String(tree))
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func Test_SplitSqls(t *testing.T) {
 | 
			
		||||
	tests := []struct {
 | 
			
		||||
		name  string
 | 
			
		||||
		input string
 | 
			
		||||
		want  string
 | 
			
		||||
	}{
 | 
			
		||||
		{
 | 
			
		||||
			name:  "create table with current_timestamp",
 | 
			
		||||
			input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:  "create table with current_date",
 | 
			
		||||
			input: "create table tbl (\n\tcreate_at date default current_date()\n)",
 | 
			
		||||
		},
 | 
			
		||||
		{
 | 
			
		||||
			name:  "select with ';\n'",
 | 
			
		||||
			input: "select 'the first line;\nthe second line;\n'",
 | 
			
		||||
			// SplitSqls split statements by ';\n'
 | 
			
		||||
			want: "select 'the first line;\n",
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	for _, test := range tests {
 | 
			
		||||
		t.Run(test.name, func(t *testing.T) {
 | 
			
		||||
			scanner := splitSqls(strings.NewReader(test.input))
 | 
			
		||||
			require.True(t, scanner.Scan())
 | 
			
		||||
			got := scanner.Text()
 | 
			
		||||
			if len(test.want) == 0 {
 | 
			
		||||
				test.want = test.input
 | 
			
		||||
			}
 | 
			
		||||
			require.Equal(t, test.want, got)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										9
									
								
								server/pkg/utils/uniqueid/uniqueid.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								server/pkg/utils/uniqueid/uniqueid.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,9 @@
 | 
			
		||||
package uniqueid
 | 
			
		||||
 | 
			
		||||
import "sync/atomic"
 | 
			
		||||
 | 
			
		||||
var id uint64 = 0
 | 
			
		||||
 | 
			
		||||
func IncrementID() uint64 {
 | 
			
		||||
	return atomic.AddUint64(&id, 1)
 | 
			
		||||
}
 | 
			
		||||
@@ -7,7 +7,7 @@ import (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 心跳间隔
 | 
			
		||||
var heartbeatInterval = 25 * time.Second
 | 
			
		||||
const heartbeatInterval = 25 * time.Second
 | 
			
		||||
 | 
			
		||||
// 连接管理
 | 
			
		||||
type ClientManager struct {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user