mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 23:40:24 +08:00
feat: 前端显示 SQL 文件执行进度
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
import Config from './config';
|
||||
import { ElNotification } from 'element-plus';
|
||||
import { ElNotification, NotificationHandle } from 'element-plus';
|
||||
import SocketBuilder from './SocketBuilder';
|
||||
import { getToken } from '@/common/utils/storage';
|
||||
import { createVNode, reactive } from "vue";
|
||||
import { buildProgressProps } from "@/components/progress-notify/progress-notify";
|
||||
import ProgressNotify from '/src/components/progress-notify/progress-notify.vue';
|
||||
|
||||
export default {
|
||||
/**
|
||||
@@ -12,32 +15,63 @@ export default {
|
||||
if (!token) {
|
||||
return null;
|
||||
}
|
||||
const messageTypes = {
|
||||
0: "error",
|
||||
1: "success",
|
||||
2: "info",
|
||||
}
|
||||
const notifyMap: Map<Number, any> = new Map()
|
||||
|
||||
return SocketBuilder.builder(`${Config.baseWsUrl}/sysmsg?token=${token}`)
|
||||
.message((event: { data: string }) => {
|
||||
const message = JSON.parse(event.data);
|
||||
let mtype: string;
|
||||
switch (message.type) {
|
||||
case 0:
|
||||
mtype = 'error';
|
||||
break;
|
||||
case 2:
|
||||
mtype = 'info';
|
||||
break;
|
||||
case 1:
|
||||
mtype = 'success';
|
||||
const type = messageTypes[message.type]
|
||||
switch (message.category) {
|
||||
case "execSqlFileProgress":
|
||||
const content = JSON.parse(message.msg)
|
||||
const id = content.id
|
||||
let progress = notifyMap.get(id)
|
||||
if (content.terminated) {
|
||||
if (progress != undefined) {
|
||||
progress.notification?.close()
|
||||
notifyMap.delete(id)
|
||||
progress = undefined
|
||||
}
|
||||
return
|
||||
}
|
||||
if (progress == undefined) {
|
||||
progress = {
|
||||
props: reactive(buildProgressProps()),
|
||||
notification: undefined,
|
||||
}
|
||||
}
|
||||
progress.props.progress.sqlFileName = content.sqlFileName
|
||||
progress.props.progress.executedStatements = content.executedStatements
|
||||
if (!notifyMap.has(id)) {
|
||||
const vNodeMessage = createVNode(
|
||||
ProgressNotify,
|
||||
progress.props,
|
||||
null,
|
||||
)
|
||||
progress.notification = ElNotification({
|
||||
duration: 0,
|
||||
title: message.title,
|
||||
message: vNodeMessage,
|
||||
type: type,
|
||||
showClose: false,
|
||||
});
|
||||
notifyMap.set(id, progress)
|
||||
}
|
||||
break;
|
||||
default:
|
||||
mtype = 'info';
|
||||
ElNotification({
|
||||
duration: 0,
|
||||
title: message.title,
|
||||
message: message.msg,
|
||||
type: type,
|
||||
});
|
||||
break;
|
||||
}
|
||||
if (mtype == undefined) {
|
||||
return;
|
||||
}
|
||||
ElNotification({
|
||||
duration: 0,
|
||||
title: message.title,
|
||||
message: message.msg,
|
||||
type: mtype as any,
|
||||
});
|
||||
})
|
||||
.open((event: any) => console.log(event))
|
||||
.build();
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
export const buildProgressProps = () : any => {
|
||||
return {
|
||||
progress: {
|
||||
sqlFileName: {
|
||||
type: String
|
||||
},
|
||||
executedStatements: {
|
||||
type: Number
|
||||
},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
<template>
|
||||
<el-descriptions
|
||||
border
|
||||
size="small"
|
||||
:title="`${progress.sqlFileName}`"
|
||||
>
|
||||
<el-descriptions-item label="时间">{{ state.elapsedTime }}</el-descriptions-item>
|
||||
<el-descriptions-item label="已处理">{{ progress.executedStatements }}</el-descriptions-item>
|
||||
</el-descriptions>
|
||||
</template>
|
||||
<script lang="ts" setup>
|
||||
|
||||
import {onMounted, onUnmounted, reactive} from "vue";
|
||||
import {formatTime} from 'element-plus/es/components/countdown/src/utils';
|
||||
import {buildProgressProps} from "./progress-notify";
|
||||
|
||||
const props = defineProps(buildProgressProps());
|
||||
|
||||
const state = reactive({
|
||||
elapsedTime: "00:00:00"
|
||||
});
|
||||
|
||||
let timer = undefined;
|
||||
const startTime = Date.now()
|
||||
|
||||
onMounted(async () => {
|
||||
timer = setInterval(() => {
|
||||
const elapsed = Date.now() - startTime;
|
||||
state.elapsedTime = formatTime(elapsed, 'HH:mm:ss')
|
||||
}, 1000);
|
||||
|
||||
});
|
||||
|
||||
onUnmounted(async () => {
|
||||
if (timer != undefined) {
|
||||
clearInterval(timer); // 在Vue实例销毁前,清除我们的定时器
|
||||
timer = undefined;
|
||||
}
|
||||
});
|
||||
|
||||
</script>
|
||||
@@ -206,7 +206,7 @@ router.beforeEach(async (to, from, next) => {
|
||||
|
||||
if (SysWs) {
|
||||
SysWs.close();
|
||||
SysWs = null;
|
||||
SysWs = undefined;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ require (
|
||||
github.com/pquerna/otp v1.4.0
|
||||
github.com/redis/go-redis/v9 v9.2.1
|
||||
github.com/robfig/cron/v3 v3.0.1 // 定时任务
|
||||
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
|
||||
github.com/stretchr/testify v1.8.4
|
||||
go.mongodb.org/mongo-driver v1.12.1 // mongo
|
||||
golang.org/x/crypto v0.14.0 // ssh
|
||||
golang.org/x/oauth2 v0.13.0
|
||||
@@ -32,12 +32,15 @@ require (
|
||||
gorm.io/gorm v1.25.4
|
||||
)
|
||||
|
||||
require github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
|
||||
|
||||
require (
|
||||
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
|
||||
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
@@ -45,21 +48,22 @@ require (
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/golang/snappy v0.0.1 // indirect
|
||||
github.com/golang/snappy v0.0.4 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/compress v1.13.6 // indirect
|
||||
github.com/klauspost/compress v1.16.5 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/kr/fs v0.1.0 // indirect
|
||||
github.com/kr/pretty v0.3.0 // indirect
|
||||
github.com/kr/pretty v0.3.1 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.17 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect
|
||||
github.com/montanaflynn/stats v0.7.0 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
github.com/xdg-go/pbkdf2 v1.0.0 // indirect
|
||||
|
||||
@@ -3,7 +3,11 @@ package api
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"github.com/lib/pq"
|
||||
"io"
|
||||
"mayfly-go/pkg/utils/uniqueid"
|
||||
"mayfly-go/pkg/ws"
|
||||
|
||||
"mayfly-go/internal/db/api/form"
|
||||
"mayfly-go/internal/db/api/vo"
|
||||
"mayfly-go/internal/db/application"
|
||||
@@ -79,9 +83,7 @@ func (d *Db) DeleteDb(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection {
|
||||
dbName := g.Query("db")
|
||||
biz.NotEmpty(dbName, "db不能为空")
|
||||
return d.DbApp.GetDbConnection(getDbId(g), dbName)
|
||||
return d.DbApp.GetDbConnection(getDbId(g), getDbName(g))
|
||||
}
|
||||
|
||||
func (d *Db) TableInfos(rc *req.Ctx) {
|
||||
@@ -152,67 +154,120 @@ func (d *Db) ExecSql(rc *req.Ctx) {
|
||||
rc.ResData = colAndRes
|
||||
}
|
||||
|
||||
// progressCategory sql文件执行进度消息类型
|
||||
const progressCategory = "execSqlFileProgress"
|
||||
|
||||
// progressMsg sql文件执行进度消息
|
||||
type progressMsg struct {
|
||||
Id uint64 `json:"id"`
|
||||
SqlFileName string `json:"sqlFileName"`
|
||||
ExecutedStatements int `json:"executedStatements"`
|
||||
Terminated bool `json:"terminated"`
|
||||
}
|
||||
|
||||
// 执行sql文件
|
||||
func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
g := rc.GinCtx
|
||||
fileheader, err := g.FormFile("file")
|
||||
multipart, err := g.Request.MultipartReader()
|
||||
biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
|
||||
|
||||
file, _ := fileheader.Open()
|
||||
filename := fileheader.Filename
|
||||
file, err := multipart.NextPart()
|
||||
biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
|
||||
defer file.Close()
|
||||
filename := file.FileName()
|
||||
dbId := getDbId(g)
|
||||
dbName := getDbName(g)
|
||||
|
||||
dbConn := d.getDbConnection(rc.GinCtx)
|
||||
dbConn := d.DbApp.GetDbConnection(dbId, dbName)
|
||||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
|
||||
rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
|
||||
|
||||
logExecRecord := true
|
||||
// 如果执行sql文件大于该值则不记录sql执行记录
|
||||
if fileheader.Size > 50*1024 {
|
||||
logExecRecord = false
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
var errInfo string
|
||||
switch t := err.(type) {
|
||||
case biz.BizError:
|
||||
errInfo = t.Error()
|
||||
case *biz.BizError:
|
||||
errInfo = t.Error()
|
||||
case string:
|
||||
errInfo = t
|
||||
}
|
||||
if len(errInfo) > 0 {
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
execReq := &application.DbSqlExecReq{
|
||||
DbId: dbId,
|
||||
Db: dbName,
|
||||
Remark: filename,
|
||||
DbConn: dbConn,
|
||||
LoginAccount: rc.LoginAccount,
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
var errInfo string
|
||||
switch t := err.(type) {
|
||||
case error:
|
||||
errInfo = t.Error()
|
||||
}
|
||||
if len(errInfo) > 0 {
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
|
||||
}
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
var errInfo string
|
||||
switch t := err.(type) {
|
||||
case error:
|
||||
errInfo = t.Error()
|
||||
}
|
||||
}()
|
||||
|
||||
execReq := &application.DbSqlExecReq{
|
||||
DbId: dbId,
|
||||
Db: dbName,
|
||||
Remark: fileheader.Filename,
|
||||
DbConn: dbConn,
|
||||
LoginAccount: rc.LoginAccount,
|
||||
}
|
||||
|
||||
sqlScanner := SplitSqls(file)
|
||||
for sqlScanner.Scan() {
|
||||
sql := sqlScanner.Text()
|
||||
execReq.Sql = sql
|
||||
// 需要记录执行记录
|
||||
if logExecRecord {
|
||||
_, err = d.DbSqlExecApp.Exec(execReq)
|
||||
} else {
|
||||
_, err = dbConn.Exec(sql)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
|
||||
return
|
||||
if len(errInfo) > 0 {
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
|
||||
}
|
||||
}
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
|
||||
}()
|
||||
|
||||
progressId := uniqueid.IncrementID()
|
||||
executedStatements := 0
|
||||
defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||||
Id: progressId,
|
||||
SqlFileName: filename,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: true,
|
||||
}).WithCategory(progressCategory))
|
||||
|
||||
ticker := time.NewTicker(time.Second * 1)
|
||||
defer ticker.Stop()
|
||||
sqlScanner := SplitSqls(file)
|
||||
for sqlScanner.Scan() {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||||
Id: progressId,
|
||||
SqlFileName: filename,
|
||||
ExecutedStatements: executedStatements,
|
||||
Terminated: false,
|
||||
}).WithCategory(progressCategory))
|
||||
default:
|
||||
}
|
||||
sql := sqlScanner.Text()
|
||||
const prefixUse = "use "
|
||||
if strings.HasPrefix(sql, prefixUse) {
|
||||
dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
|
||||
if len(dbNameExec) > 0 {
|
||||
dbConn = d.DbApp.GetDbConnection(dbId, dbNameExec)
|
||||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
|
||||
execReq.DbConn = dbConn
|
||||
}
|
||||
}
|
||||
// 需要记录执行记录
|
||||
const maxRecordStatements = 64
|
||||
if executedStatements < maxRecordStatements {
|
||||
execReq.Sql = sql
|
||||
_, err = d.DbSqlExecApp.Exec(execReq)
|
||||
} else {
|
||||
_, err = dbConn.Exec(sql)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
|
||||
return
|
||||
}
|
||||
executedStatements++
|
||||
}
|
||||
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
|
||||
}
|
||||
|
||||
// 数据库dump
|
||||
@@ -276,23 +331,44 @@ func (d *Db) DumpSql(rc *req.Ctx) {
|
||||
rc.ReqParam = fmt.Sprintf("DB[id=%d, tag=%s, name=%s, databases=%s, tables=%s, dumpType=%s]", db.Id, db.TagPath, db.Name, dbNamesStr, tablesStr, dumpType)
|
||||
}
|
||||
|
||||
func escapeSql(dbType string, sql string) string {
|
||||
if dbType == entity.DbTypePostgres {
|
||||
return pq.QuoteLiteral(sql)
|
||||
} else {
|
||||
sql = strings.ReplaceAll(sql, `\`, `\\`)
|
||||
sql = strings.ReplaceAll(sql, `'`, `''`)
|
||||
return "'" + sql + "'"
|
||||
}
|
||||
}
|
||||
|
||||
func quoteTable(dbType string, table string) string {
|
||||
if dbType == entity.DbTypePostgres {
|
||||
return "\"" + table + "\""
|
||||
} else {
|
||||
return "`" + table + "`"
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) {
|
||||
dbConn := d.DbApp.GetDbConnection(dbId, dbName)
|
||||
writer.WriteString("-- ----------------------------")
|
||||
writer.WriteString("\n-- ----------------------------")
|
||||
writer.WriteString("\n-- 导出平台: mayfly-go")
|
||||
writer.WriteString(fmt.Sprintf("\n-- 导出时间: %s ", time.Now().Format("2006-01-02 15:04:05")))
|
||||
writer.WriteString(fmt.Sprintf("\n-- 导出数据库: %s ", dbName))
|
||||
writer.WriteString("\n-- ----------------------------\n")
|
||||
writer.TryFlush()
|
||||
|
||||
if switchDb {
|
||||
switch dbConn.Info.Type {
|
||||
case entity.DbTypeMysql:
|
||||
writer.WriteString(fmt.Sprintf("use `%s`;\n", dbName))
|
||||
writer.WriteString(fmt.Sprintf("USE `%s`;\n", dbName))
|
||||
default:
|
||||
biz.IsTrue(false, "数据库类型必须为 %s", entity.DbTypeMysql)
|
||||
biz.IsTrue(false, "同时导出多个数据库,数据库类型必须为 %s", entity.DbTypeMysql)
|
||||
}
|
||||
}
|
||||
if dbConn.Info.Type == entity.DbTypeMysql {
|
||||
writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 0;\n")
|
||||
}
|
||||
|
||||
dbMeta := dbConn.GetMeta()
|
||||
if len(tables) == 0 {
|
||||
ti := dbMeta.GetTableInfos()
|
||||
@@ -303,23 +379,22 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
writer.TryFlush()
|
||||
quotedTable := quoteTable(dbConn.Info.Type, table)
|
||||
if needStruct {
|
||||
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table))
|
||||
writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS `%s`;\n", table))
|
||||
writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable))
|
||||
writer.WriteString(dbMeta.GetCreateTableDdl(table) + ";\n")
|
||||
}
|
||||
|
||||
if !needData {
|
||||
continue
|
||||
}
|
||||
|
||||
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table))
|
||||
writer.WriteString("BEGIN;\n")
|
||||
|
||||
insertSql := "INSERT INTO `%s` VALUES (%s);\n"
|
||||
|
||||
insertSql := "INSERT INTO %s VALUES (%s);\n"
|
||||
dbMeta.WalkTableRecord(table, func(record map[string]any, columns []string) {
|
||||
var values []string
|
||||
writer.TryFlush()
|
||||
for _, column := range columns {
|
||||
value := record[column]
|
||||
if value == nil {
|
||||
@@ -328,17 +403,18 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
|
||||
}
|
||||
strValue, ok := value.(string)
|
||||
if ok {
|
||||
values = append(values, fmt.Sprintf("%#v", strValue))
|
||||
strValue = escapeSql(dbConn.Info.Type, strValue)
|
||||
values = append(values, strValue)
|
||||
} else {
|
||||
values = append(values, stringx.AnyToStr(value))
|
||||
}
|
||||
}
|
||||
writer.WriteString(fmt.Sprintf(insertSql, table, strings.Join(values, ", ")))
|
||||
writer.TryFlush()
|
||||
writer.WriteString(fmt.Sprintf(insertSql, quotedTable, strings.Join(values, ", ")))
|
||||
})
|
||||
|
||||
writer.WriteString("COMMIT;\n")
|
||||
writer.TryFlush()
|
||||
}
|
||||
if dbConn.Info.Type == entity.DbTypeMysql {
|
||||
writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 1;\n")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
90
server/internal/db/api/db_test.go
Normal file
90
server/internal/db/api/db_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_escapeSql(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
dbType string
|
||||
sql string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
dbType: entity.DbTypeMysql,
|
||||
sql: "\\a\\b",
|
||||
want: "'\\\\a\\\\b'",
|
||||
},
|
||||
{
|
||||
dbType: entity.DbTypeMysql,
|
||||
sql: "'a'",
|
||||
want: "'''a'''",
|
||||
},
|
||||
{
|
||||
name: "不间断空格",
|
||||
dbType: entity.DbTypeMysql,
|
||||
sql: "a\u00A0b",
|
||||
want: "'a\u00A0b'",
|
||||
},
|
||||
{
|
||||
dbType: entity.DbTypePostgres,
|
||||
sql: "\\a\\b",
|
||||
want: " E'\\\\a\\\\b'",
|
||||
},
|
||||
{
|
||||
dbType: entity.DbTypePostgres,
|
||||
sql: "'a'",
|
||||
want: "'''a'''",
|
||||
},
|
||||
{
|
||||
name: "不间断空格",
|
||||
dbType: entity.DbTypePostgres,
|
||||
sql: "a\u00A0b",
|
||||
want: "'a\u00A0b'",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := escapeSql(tt.dbType, tt.sql)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SplitSqls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "create table with current_timestamp",
|
||||
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
|
||||
},
|
||||
{
|
||||
name: "create table with current_date",
|
||||
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
|
||||
},
|
||||
{
|
||||
name: "select with ';\n'",
|
||||
input: "select 'the first line;\nthe second line;\n'",
|
||||
// SplitSqls split statements by ';\n'
|
||||
want: "select 'the first line;\n",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
scanner := SplitSqls(strings.NewReader(test.input))
|
||||
require.True(t, scanner.Scan())
|
||||
got := scanner.Text()
|
||||
if len(test.want) == 0 {
|
||||
test.want = test.input
|
||||
}
|
||||
require.Equal(t, test.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
44
server/internal/db/api/sqlparser_test.go
Normal file
44
server/internal/db/api/sqlparser_test.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/xwb1989/sqlparser"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_ParseNext_WithCurrentDate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
want string
|
||||
wantXwb1989 string
|
||||
err string
|
||||
}{
|
||||
{
|
||||
name: "create table with current_timestamp",
|
||||
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
|
||||
// xwb1989/sqlparser 不支持 current_timestamp()
|
||||
wantXwb1989: "create table tbl",
|
||||
},
|
||||
{
|
||||
name: "create table with current_date",
|
||||
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
|
||||
// xwb1989/sqlparser 不支持 current_date()
|
||||
wantXwb1989: "create table tbl",
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
token := sqlparser.NewTokenizer(strings.NewReader(test.input))
|
||||
tree, err := sqlparser.ParseNext(token)
|
||||
if len(test.err) > 0 {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), test.err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, test.wantXwb1989, sqlparser.String(tree))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package application
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/xwb1989/sqlparser"
|
||||
"mayfly-go/internal/db/config"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
@@ -10,8 +11,6 @@ import (
|
||||
"mayfly-go/pkg/model"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/xwb1989/sqlparser"
|
||||
)
|
||||
|
||||
type DbSqlExecReq struct {
|
||||
|
||||
@@ -11,7 +11,7 @@ const InfoSysMsgType = 2
|
||||
// websocket消息
|
||||
type SysMsg struct {
|
||||
Type int `json:"type"` // 消息类型
|
||||
Category int `json:"category"` // 消息类别
|
||||
Category string `json:"category"` // 消息类别
|
||||
Title string `json:"title"` // 消息标题
|
||||
Msg string `json:"msg"` // 消息内容
|
||||
}
|
||||
@@ -21,7 +21,7 @@ func (sm *SysMsg) WithTitle(title string) *SysMsg {
|
||||
return sm
|
||||
}
|
||||
|
||||
func (sm *SysMsg) WithCategory(category int) *SysMsg {
|
||||
func (sm *SysMsg) WithCategory(category string) *SysMsg {
|
||||
sm.Category = category
|
||||
return sm
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func (sm *SysMsg) WithMsg(msg any) *SysMsg {
|
||||
}
|
||||
|
||||
// 普通消息
|
||||
func NewSysMsg(title string, msg any) *SysMsg {
|
||||
func InfoSysMsg(title string, msg any) *SysMsg {
|
||||
return &SysMsg{Type: InfoSysMsgType, Title: title, Msg: stringx.AnyToStr(msg)}
|
||||
}
|
||||
|
||||
|
||||
9
server/pkg/utils/uniqueid/uniqueid.go
Normal file
9
server/pkg/utils/uniqueid/uniqueid.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package uniqueid
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
var id uint64 = 0
|
||||
|
||||
func IncrementID() uint64 {
|
||||
return atomic.AddUint64(&id, 1)
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
// 心跳间隔
|
||||
var heartbeatInterval = 25 * time.Second
|
||||
const heartbeatInterval = 25 * time.Second
|
||||
|
||||
// 连接管理
|
||||
type ClientManager struct {
|
||||
|
||||
Reference in New Issue
Block a user