Merge pull request #69 from kanzihuang/feat-progress-notify-pullrequest

feat: 显示 SQL 文件执行进度
This commit is contained in:
may-fly
2023-10-10 20:52:24 -05:00
committed by GitHub
13 changed files with 535 additions and 134 deletions

View File

@@ -1,7 +1,10 @@
import Config from './config'; import Config from './config';
import { ElNotification } from 'element-plus'; import { ElNotification, NotificationHandle } from 'element-plus';
import SocketBuilder from './SocketBuilder'; import SocketBuilder from './SocketBuilder';
import { getToken } from '@/common/utils/storage'; import { getToken } from '@/common/utils/storage';
import { createVNode, reactive } from "vue";
import { buildProgressProps } from "@/components/progress-notify/progress-notify";
import ProgressNotify from '/src/components/progress-notify/progress-notify.vue';
export default { export default {
/** /**
@@ -12,32 +15,63 @@ export default {
if (!token) { if (!token) {
return null; return null;
} }
const messageTypes = {
0: "error",
1: "success",
2: "info",
}
const notifyMap: Map<Number, any> = new Map()
return SocketBuilder.builder(`${Config.baseWsUrl}/sysmsg?token=${token}`) return SocketBuilder.builder(`${Config.baseWsUrl}/sysmsg?token=${token}`)
.message((event: { data: string }) => { .message((event: { data: string }) => {
const message = JSON.parse(event.data); const message = JSON.parse(event.data);
let mtype: string; const type = messageTypes[message.type]
switch (message.type) { switch (message.category) {
case 0: case "execSqlFileProgress":
mtype = 'error'; const content = JSON.parse(message.msg)
break; const id = content.id
case 2: let progress = notifyMap.get(id)
mtype = 'info'; if (content.terminated) {
break; if (progress != undefined) {
case 1: progress.notification?.close()
mtype = 'success'; notifyMap.delete(id)
progress = undefined
}
return
}
if (progress == undefined) {
progress = {
props: reactive(buildProgressProps()),
notification: undefined,
}
}
progress.props.progress.sqlFileName = content.sqlFileName
progress.props.progress.executedStatements = content.executedStatements
if (!notifyMap.has(id)) {
const vNodeMessage = createVNode(
ProgressNotify,
progress.props,
null,
)
progress.notification = ElNotification({
duration: 0,
title: message.title,
message: vNodeMessage,
type: type,
showClose: false,
});
notifyMap.set(id, progress)
}
break; break;
default: default:
mtype = 'info'; ElNotification({
duration: 0,
title: message.title,
message: message.msg,
type: type,
});
break;
} }
if (mtype == undefined) {
return;
}
ElNotification({
duration: 0,
title: message.title,
message: message.msg,
type: mtype as any,
});
}) })
.open((event: any) => console.log(event)) .open((event: any) => console.log(event))
.build(); .build();

View File

@@ -0,0 +1,14 @@
export const buildProgressProps = () : any => {
return {
progress: {
sqlFileName: {
type: String
},
executedStatements: {
type: Number
},
},
};
}

View File

@@ -0,0 +1,41 @@
<template>
<el-descriptions
border
size="small"
:title="`${progress.sqlFileName}`"
>
<el-descriptions-item label="时间">{{ state.elapsedTime }}</el-descriptions-item>
<el-descriptions-item label="已处理">{{ progress.executedStatements }}</el-descriptions-item>
</el-descriptions>
</template>
<script lang="ts" setup>
import {onMounted, onUnmounted, reactive} from "vue";
import {formatTime} from 'element-plus/es/components/countdown/src/utils';
import {buildProgressProps} from "./progress-notify";
const props = defineProps(buildProgressProps());
const state = reactive({
elapsedTime: "00:00:00"
});
let timer = undefined;
const startTime = Date.now()
onMounted(async () => {
timer = setInterval(() => {
const elapsed = Date.now() - startTime;
state.elapsedTime = formatTime(elapsed, 'HH:mm:ss')
}, 1000);
});
onUnmounted(async () => {
if (timer != undefined) {
clearInterval(timer); // 在Vue实例销毁前清除我们的定时器
timer = undefined;
}
});
</script>

View File

@@ -206,7 +206,7 @@ router.beforeEach(async (to, from, next) => {
if (SysWs) { if (SysWs) {
SysWs.close(); SysWs.close();
SysWs = null; SysWs = undefined;
} }
return; return;
} }

View File

@@ -13,6 +13,7 @@ require (
github.com/go-sql-driver/mysql v1.7.1 github.com/go-sql-driver/mysql v1.7.1
github.com/golang-jwt/jwt/v5 v5.0.0 github.com/golang-jwt/jwt/v5 v5.0.0
github.com/gorilla/websocket v1.5.0 github.com/gorilla/websocket v1.5.0
github.com/kanzihuang/vitess/go/vt/sqlparser v0.0.0-20231007020222-b91ee5ef3b31
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d github.com/lionsoul2014/ip2region/binding/golang v0.0.0-20230712084735-068dc2aee82d
github.com/mojocn/base64Captcha v1.3.5 // github.com/mojocn/base64Captcha v1.3.5 //
@@ -21,45 +22,52 @@ require (
github.com/pquerna/otp v1.4.0 github.com/pquerna/otp v1.4.0
github.com/redis/go-redis/v9 v9.2.1 github.com/redis/go-redis/v9 v9.2.1
github.com/robfig/cron/v3 v3.0.1 // github.com/robfig/cron/v3 v3.0.1 //
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2 github.com/stretchr/testify v1.8.4
go.mongodb.org/mongo-driver v1.12.1 // mongo go.mongodb.org/mongo-driver v1.12.1 // mongo
golang.org/x/crypto v0.14.0 // ssh golang.org/x/crypto v0.14.0 // ssh
golang.org/x/oauth2 v0.13.0 golang.org/x/oauth2 v0.13.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
// gorm // gorm
gorm.io/driver/mysql v1.5.1 gorm.io/driver/mysql v1.5.1
gorm.io/driver/sqlite v1.5.4
gorm.io/gorm v1.25.4 gorm.io/gorm v1.25.4
) )
require (
github.com/xwb1989/sqlparser v0.0.0-20180606152119-120387863bf2
gorm.io/driver/sqlite v1.5.1
)
require ( require (
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect github.com/boombuler/barcode v1.0.1-0.20190219062509-6c824513bacc // indirect
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect github.com/go-asn1-ber/asn1-ber v1.5.4 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/glog v1.0.0 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.1 // indirect github.com/golang/snappy v0.0.4 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect github.com/jinzhu/now v1.1.5 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/compress v1.13.6 // indirect github.com/klauspost/compress v1.16.5 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/kr/fs v0.1.0 // indirect github.com/kr/fs v0.1.0 // indirect
github.com/kr/pretty v0.3.0 // indirect
github.com/leodido/go-urn v1.2.4 // indirect github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-isatty v0.0.19 // indirect
github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/mattn/go-sqlite3 v1.14.16 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect github.com/montanaflynn/stats v0.7.0 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect github.com/ugorji/go/codec v1.2.11 // indirect
github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect
@@ -67,12 +75,15 @@ require (
github.com/xdg-go/stringprep v1.0.4 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect
golang.org/x/arch v0.3.0 // indirect golang.org/x/arch v0.3.0 // indirect
golang.org/x/exp v0.0.0-20230519143937-03e91628a987 // indirect
golang.org/x/image v0.0.0-20220302094943-723b81ca9867 // indirect golang.org/x/image v0.0.0-20220302094943-723b81ca9867 // indirect
golang.org/x/net v0.16.0 // indirect golang.org/x/net v0.16.0 // indirect
golang.org/x/sync v0.1.0 // indirect golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.13.0 // indirect golang.org/x/sys v0.13.0 // indirect
golang.org/x/text v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect
google.golang.org/appengine v1.6.7 // indirect google.golang.org/appengine v1.6.7 // indirect
google.golang.org/genproto v0.0.0-20230131230820-1c016267d619 // indirect
google.golang.org/grpc v1.52.3 // indirect
google.golang.org/protobuf v1.31.0 // indirect google.golang.org/protobuf v1.31.0 // indirect
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect vitess.io/vitess v0.17.3 // indirect
) )

View File

@@ -1,9 +1,12 @@
package api package api
import ( import (
"bufio"
"fmt" "fmt"
"github.com/lib/pq"
"io" "io"
"mayfly-go/pkg/utils/uniqueid"
"mayfly-go/pkg/ws"
"mayfly-go/internal/db/api/form" "mayfly-go/internal/db/api/form"
"mayfly-go/internal/db/api/vo" "mayfly-go/internal/db/api/vo"
"mayfly-go/internal/db/application" "mayfly-go/internal/db/application"
@@ -16,14 +19,13 @@ import (
"mayfly-go/pkg/gormx" "mayfly-go/pkg/gormx"
"mayfly-go/pkg/model" "mayfly-go/pkg/model"
"mayfly-go/pkg/req" "mayfly-go/pkg/req"
"mayfly-go/pkg/sqlparser"
"mayfly-go/pkg/utils/stringx" "mayfly-go/pkg/utils/stringx"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/xwb1989/sqlparser"
) )
type Db struct { type Db struct {
@@ -79,9 +81,7 @@ func (d *Db) DeleteDb(rc *req.Ctx) {
} }
func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection { func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection {
dbName := g.Query("db") return d.DbApp.GetDbConnection(getDbId(g), getDbName(g))
biz.NotEmpty(dbName, "db不能为空")
return d.DbApp.GetDbConnection(getDbId(g), dbName)
} }
func (d *Db) TableInfos(rc *req.Ctx) { func (d *Db) TableInfos(rc *req.Ctx) {
@@ -152,67 +152,119 @@ func (d *Db) ExecSql(rc *req.Ctx) {
rc.ResData = colAndRes rc.ResData = colAndRes
} }
// progressCategory sql文件执行进度消息类型
const progressCategory = "execSqlFileProgress"
// progressMsg sql文件执行进度消息
type progressMsg struct {
Id uint64 `json:"id"`
SqlFileName string `json:"sqlFileName"`
ExecutedStatements int `json:"executedStatements"`
Terminated bool `json:"terminated"`
}
// 执行sql文件 // 执行sql文件
func (d *Db) ExecSqlFile(rc *req.Ctx) { func (d *Db) ExecSqlFile(rc *req.Ctx) {
g := rc.GinCtx g := rc.GinCtx
fileheader, err := g.FormFile("file") multipart, err := g.Request.MultipartReader()
biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s") biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
file, err := multipart.NextPart()
file, _ := fileheader.Open() biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
filename := fileheader.Filename defer file.Close()
filename := file.FileName()
dbId := getDbId(g) dbId := getDbId(g)
dbName := getDbName(g) dbName := getDbName(g)
dbConn := d.getDbConnection(rc.GinCtx) dbConn := d.DbApp.GetDbConnection(dbId, dbName)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename) rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
logExecRecord := true defer func() {
// 如果执行sql文件大于该值则不记录sql执行记录 var errInfo string
if fileheader.Size > 50*1024 { switch t := recover().(type) {
logExecRecord = false case biz.BizError:
errInfo = t.Error()
case *biz.BizError:
errInfo = t.Error()
case string:
errInfo = t
}
if len(errInfo) > 0 {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
}
}()
execReq := &application.DbSqlExecReq{
DbId: dbId,
Db: dbName,
Remark: filename,
DbConn: dbConn,
LoginAccount: rc.LoginAccount,
} }
go func() { progressId := uniqueid.IncrementID()
defer func() { executedStatements := 0
if err := recover(); err != nil { defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
var errInfo string Id: progressId,
switch t := err.(type) { SqlFileName: filename,
case error: ExecutedStatements: executedStatements,
errInfo = t.Error() Terminated: true,
} }).WithCategory(progressCategory))
if len(errInfo) > 0 {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
}
}
}()
execReq := &application.DbSqlExecReq{ var parser sqlparser.Parser
DbId: dbId, if dbConn.Info.Type == entity.DbTypeMysql {
Db: dbName, parser = sqlparser.NewMysqlParser(file)
Remark: fileheader.Filename, } else {
DbConn: dbConn, parser = sqlparser.NewPostgresParser(file)
LoginAccount: rc.LoginAccount, }
ticker := time.NewTicker(time.Second * 1)
defer ticker.Stop()
for {
select {
case <-ticker.C:
ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId,
SqlFileName: filename,
ExecutedStatements: executedStatements,
Terminated: false,
}).WithCategory(progressCategory))
default:
} }
err = parser.Next()
sqlScanner := SplitSqls(file) if err == io.EOF {
for sqlScanner.Scan() { break
sql := sqlScanner.Text() }
if err != nil {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] 解析SQL失败: [%s]", filename, dbConn.Info.GetLogDesc(), err.Error())))
return
}
sql := parser.Current()
const prefixUse = "use "
if strings.HasPrefix(sql, prefixUse) {
dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
if len(dbNameExec) > 0 {
dbConn = d.DbApp.GetDbConnection(dbId, dbNameExec)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
execReq.DbConn = dbConn
}
}
// 需要记录执行记录
const maxRecordStatements = 64
if executedStatements < maxRecordStatements {
execReq.Sql = sql execReq.Sql = sql
// 需要记录执行记录 _, err = d.DbSqlExecApp.Exec(execReq)
if logExecRecord { } else {
_, err = d.DbSqlExecApp.Exec(execReq) _, err = dbConn.Exec(sql)
} else {
_, err = dbConn.Exec(sql)
}
if err != nil {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
return
}
} }
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
}() if err != nil {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
return
}
executedStatements++
}
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
} }
// 数据库dump // 数据库dump
@@ -276,23 +328,44 @@ func (d *Db) DumpSql(rc *req.Ctx) {
rc.ReqParam = fmt.Sprintf("DB[id=%d, tag=%s, name=%s, databases=%s, tables=%s, dumpType=%s]", db.Id, db.TagPath, db.Name, dbNamesStr, tablesStr, dumpType) rc.ReqParam = fmt.Sprintf("DB[id=%d, tag=%s, name=%s, databases=%s, tables=%s, dumpType=%s]", db.Id, db.TagPath, db.Name, dbNamesStr, tablesStr, dumpType)
} }
func escapeSql(dbType string, sql string) string {
if dbType == entity.DbTypePostgres {
return pq.QuoteLiteral(sql)
} else {
sql = strings.ReplaceAll(sql, `\`, `\\`)
sql = strings.ReplaceAll(sql, `'`, `''`)
return "'" + sql + "'"
}
}
func quoteTable(dbType string, table string) string {
if dbType == entity.DbTypePostgres {
return "\"" + table + "\""
} else {
return "`" + table + "`"
}
}
func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) { func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) {
dbConn := d.DbApp.GetDbConnection(dbId, dbName) dbConn := d.DbApp.GetDbConnection(dbId, dbName)
writer.WriteString("-- ----------------------------") writer.WriteString("\n-- ----------------------------")
writer.WriteString("\n-- 导出平台: mayfly-go") writer.WriteString("\n-- 导出平台: mayfly-go")
writer.WriteString(fmt.Sprintf("\n-- 导出时间: %s ", time.Now().Format("2006-01-02 15:04:05"))) writer.WriteString(fmt.Sprintf("\n-- 导出时间: %s ", time.Now().Format("2006-01-02 15:04:05")))
writer.WriteString(fmt.Sprintf("\n-- 导出数据库: %s ", dbName)) writer.WriteString(fmt.Sprintf("\n-- 导出数据库: %s ", dbName))
writer.WriteString("\n-- ----------------------------\n") writer.WriteString("\n-- ----------------------------\n")
writer.TryFlush()
if switchDb { if switchDb {
switch dbConn.Info.Type { switch dbConn.Info.Type {
case entity.DbTypeMysql: case entity.DbTypeMysql:
writer.WriteString(fmt.Sprintf("use `%s`;\n", dbName)) writer.WriteString(fmt.Sprintf("USE `%s`;\n", dbName))
default: default:
biz.IsTrue(false, "数据库类型必须为 %s", entity.DbTypeMysql) biz.IsTrue(false, "同时导出多个数据库,数据库类型必须为 %s", entity.DbTypeMysql)
} }
} }
if dbConn.Info.Type == entity.DbTypeMysql {
writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 0;\n")
}
dbMeta := dbConn.GetMeta() dbMeta := dbConn.GetMeta()
if len(tables) == 0 { if len(tables) == 0 {
ti := dbMeta.GetTableInfos() ti := dbMeta.GetTableInfos()
@@ -303,23 +376,22 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
} }
for _, table := range tables { for _, table := range tables {
writer.TryFlush()
quotedTable := quoteTable(dbConn.Info.Type, table)
if needStruct { if needStruct {
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table)) writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table))
writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS `%s`;\n", table)) writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable))
writer.WriteString(dbMeta.GetCreateTableDdl(table) + ";\n") writer.WriteString(dbMeta.GetCreateTableDdl(table) + ";\n")
} }
if !needData { if !needData {
continue continue
} }
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table)) writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table))
writer.WriteString("BEGIN;\n") writer.WriteString("BEGIN;\n")
insertSql := "INSERT INTO %s VALUES (%s);\n"
insertSql := "INSERT INTO `%s` VALUES (%s);\n"
dbMeta.WalkTableRecord(table, func(record map[string]any, columns []string) { dbMeta.WalkTableRecord(table, func(record map[string]any, columns []string) {
var values []string var values []string
writer.TryFlush()
for _, column := range columns { for _, column := range columns {
value := record[column] value := record[column]
if value == nil { if value == nil {
@@ -328,17 +400,18 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
} }
strValue, ok := value.(string) strValue, ok := value.(string)
if ok { if ok {
values = append(values, fmt.Sprintf("%#v", strValue)) strValue = escapeSql(dbConn.Info.Type, strValue)
values = append(values, strValue)
} else { } else {
values = append(values, stringx.AnyToStr(value)) values = append(values, stringx.AnyToStr(value))
} }
} }
writer.WriteString(fmt.Sprintf(insertSql, table, strings.Join(values, ", "))) writer.WriteString(fmt.Sprintf(insertSql, quotedTable, strings.Join(values, ", ")))
writer.TryFlush()
}) })
writer.WriteString("COMMIT;\n") writer.WriteString("COMMIT;\n")
writer.TryFlush() }
if dbConn.Info.Type == entity.DbTypeMysql {
writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 1;\n")
} }
} }
@@ -477,35 +550,3 @@ func getDbName(g *gin.Context) string {
biz.NotEmpty(db, "db不能为空") biz.NotEmpty(db, "db不能为空")
return db return db
} }
// 根据;\n切割sql
func SplitSqls(r io.Reader) *bufio.Scanner {
scanner := bufio.NewScanner(r)
re := regexp.MustCompile(`\s*;\s*\n`)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, io.EOF
}
match := re.FindIndex(data)
if match != nil {
// 如果找到了";\n",判断是否为最后一行
if match[1] == len(data) {
// 如果是最后一行,则返回完整的切片
return len(data), data, nil
}
// 否则,返回到";\n"之后,并且包括";\n"本身
return match[1], data[:match[1]], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
return scanner
}

View File

@@ -0,0 +1,55 @@
package api
import (
"github.com/stretchr/testify/require"
"mayfly-go/internal/db/domain/entity"
"testing"
)
func Test_escapeSql(t *testing.T) {
tests := []struct {
name string
dbType string
sql string
want string
}{
{
dbType: entity.DbTypeMysql,
sql: "\\a\\b",
want: "'\\\\a\\\\b'",
},
{
dbType: entity.DbTypeMysql,
sql: "'a'",
want: "'''a'''",
},
{
name: "不间断空格",
dbType: entity.DbTypeMysql,
sql: "a\u00A0b",
want: "'a\u00A0b'",
},
{
dbType: entity.DbTypePostgres,
sql: "\\a\\b",
want: " E'\\\\a\\\\b'",
},
{
dbType: entity.DbTypePostgres,
sql: "'a'",
want: "'''a'''",
},
{
name: "不间断空格",
dbType: entity.DbTypePostgres,
sql: "a\u00A0b",
want: "'a\u00A0b'",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := escapeSql(tt.dbType, tt.sql)
require.Equal(t, tt.want, got)
})
}
}

View File

@@ -3,6 +3,7 @@ package application
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"mayfly-go/internal/db/config" "mayfly-go/internal/db/config"
"mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository" "mayfly-go/internal/db/domain/repository"
@@ -10,8 +11,6 @@ import (
"mayfly-go/pkg/model" "mayfly-go/pkg/model"
"strconv" "strconv"
"strings" "strings"
"github.com/xwb1989/sqlparser"
) )
type DbSqlExecReq struct { type DbSqlExecReq struct {

View File

@@ -11,7 +11,7 @@ const InfoSysMsgType = 2
// websocket消息 // websocket消息
type SysMsg struct { type SysMsg struct {
Type int `json:"type"` // 消息类型 Type int `json:"type"` // 消息类型
Category int `json:"category"` // 消息类别 Category string `json:"category"` // 消息类别
Title string `json:"title"` // 消息标题 Title string `json:"title"` // 消息标题
Msg string `json:"msg"` // 消息内容 Msg string `json:"msg"` // 消息内容
} }
@@ -21,7 +21,7 @@ func (sm *SysMsg) WithTitle(title string) *SysMsg {
return sm return sm
} }
func (sm *SysMsg) WithCategory(category int) *SysMsg { func (sm *SysMsg) WithCategory(category string) *SysMsg {
sm.Category = category sm.Category = category
return sm return sm
} }
@@ -32,7 +32,7 @@ func (sm *SysMsg) WithMsg(msg any) *SysMsg {
} }
// 普通消息 // 普通消息
func NewSysMsg(title string, msg any) *SysMsg { func InfoSysMsg(title string, msg any) *SysMsg {
return &SysMsg{Type: InfoSysMsgType, Title: title, Msg: stringx.AnyToStr(msg)} return &SysMsg{Type: InfoSysMsgType, Title: title, Msg: stringx.AnyToStr(msg)}
} }

View File

@@ -0,0 +1,99 @@
package sqlparser
import (
"bufio"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"io"
"regexp"
)
type Parser interface {
Next() error
Current() string
}
var _ Parser = &MysqlParser{}
var _ Parser = &PostgresParser{}
type MysqlParser struct {
tokenizer *sqlparser.Tokenizer
statement string
}
func NewMysqlParser(reader io.Reader) *MysqlParser {
return &MysqlParser{
tokenizer: sqlparser.NewReaderTokenizer(reader),
}
}
func (parser *MysqlParser) Next() error {
statement, err := sqlparser.ParseNext(parser.tokenizer)
if err != nil {
parser.statement = ""
return err
}
parser.statement = sqlparser.String(statement)
return nil
}
func (parser *MysqlParser) Current() string {
return parser.statement
}
type PostgresParser struct {
scanner *bufio.Scanner
statement string
}
func NewPostgresParser(reader io.Reader) *PostgresParser {
return &PostgresParser{
scanner: splitSqls(reader),
}
}
func (parser *PostgresParser) Next() error {
if !parser.scanner.Scan() {
return io.EOF
}
return nil
}
func (parser *PostgresParser) Current() string {
return parser.scanner.Text()
}
// 根据;\n切割sql
func splitSqls(r io.Reader) *bufio.Scanner {
scanner := bufio.NewScanner(r)
re := regexp.MustCompile(`\s*;\s*\n`)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, io.EOF
}
match := re.FindIndex(data)
if match != nil {
// 如果找到了";\n",判断是否为最后一行
if match[1] == len(data) {
// 如果是最后一行,则返回完整的切片
return len(data), data, nil
}
// 否则,返回到";\n"之后,并且包括";\n"本身
return match[1], data[:match[1]], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
return scanner
}
func SplitStatementToPieces(sql string) ([]string, error) {
return sqlparser.SplitStatementToPieces(sql)
}

View File

@@ -0,0 +1,98 @@
package sqlparser
import (
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"github.com/stretchr/testify/require"
sqlparser_xwb1989 "github.com/xwb1989/sqlparser"
"strings"
"testing"
)
func Test_ParseNext_WithCurrentDate(t *testing.T) {
tests := []struct {
name string
input string
want string
wantXwb1989 string
err string
}{
{
name: "create table with current_timestamp",
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
// xwb1989/sqlparser 不支持 current_timestamp()
wantXwb1989: "create table tbl",
},
{
name: "create table with current_date",
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
// xwb1989/sqlparser 不支持 current_date()
wantXwb1989: "create table tbl",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
token := sqlparser.NewReaderTokenizer(strings.NewReader(test.input))
tree, err := sqlparser.ParseNext(token)
if len(test.err) > 0 {
require.Error(t, err)
require.Contains(t, err.Error(), test.err)
return
}
require.NoError(t, err)
if len(test.want) == 0 {
test.want = test.input
}
require.Equal(t, test.want, sqlparser.String(tree))
})
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
token := sqlparser_xwb1989.NewTokenizer(strings.NewReader(test.input))
tree, err := sqlparser_xwb1989.ParseNext(token)
if len(test.err) > 0 {
require.Error(t, err)
require.Contains(t, err.Error(), test.err)
return
}
require.NoError(t, err)
if len(test.want) == 0 {
test.want = test.input
}
require.Equal(t, test.wantXwb1989, sqlparser_xwb1989.String(tree))
})
}
}
func Test_SplitSqls(t *testing.T) {
tests := []struct {
name string
input string
want string
}{
{
name: "create table with current_timestamp",
input: "create table tbl (\n\tcreate_at datetime default current_timestamp()\n)",
},
{
name: "create table with current_date",
input: "create table tbl (\n\tcreate_at date default current_date()\n)",
},
{
name: "select with ';\n'",
input: "select 'the first line;\nthe second line;\n'",
// SplitSqls split statements by ';\n'
want: "select 'the first line;\n",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
scanner := splitSqls(strings.NewReader(test.input))
require.True(t, scanner.Scan())
got := scanner.Text()
if len(test.want) == 0 {
test.want = test.input
}
require.Equal(t, test.want, got)
})
}
}

View File

@@ -0,0 +1,9 @@
package uniqueid
import "sync/atomic"
var id uint64 = 0
func IncrementID() uint64 {
return atomic.AddUint64(&id, 1)
}

View File

@@ -7,7 +7,7 @@ import (
) )
// 心跳间隔 // 心跳间隔
var heartbeatInterval = 25 * time.Second const heartbeatInterval = 25 * time.Second
// 连接管理 // 连接管理
type ClientManager struct { type ClientManager struct {