feat: 前端显示 SQL 文件执行进度

This commit is contained in:
wanli
2023-10-10 09:24:49 +08:00
committed by kanzihuang
parent 41443dccc0
commit 7544288451
12 changed files with 407 additions and 96 deletions

View File

@@ -3,7 +3,11 @@ package api
import (
"bufio"
"fmt"
"github.com/lib/pq"
"io"
"mayfly-go/pkg/utils/uniqueid"
"mayfly-go/pkg/ws"
"mayfly-go/internal/db/api/form"
"mayfly-go/internal/db/api/vo"
"mayfly-go/internal/db/application"
@@ -79,9 +83,7 @@ func (d *Db) DeleteDb(rc *req.Ctx) {
}
func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection {
dbName := g.Query("db")
biz.NotEmpty(dbName, "db不能为空")
return d.DbApp.GetDbConnection(getDbId(g), dbName)
return d.DbApp.GetDbConnection(getDbId(g), getDbName(g))
}
func (d *Db) TableInfos(rc *req.Ctx) {
@@ -152,67 +154,120 @@ func (d *Db) ExecSql(rc *req.Ctx) {
rc.ResData = colAndRes
}
// progressCategory sql文件执行进度消息类型
const progressCategory = "execSqlFileProgress"
// progressMsg sql文件执行进度消息
type progressMsg struct {
Id uint64 `json:"id"`
SqlFileName string `json:"sqlFileName"`
ExecutedStatements int `json:"executedStatements"`
Terminated bool `json:"terminated"`
}
// 执行sql文件
func (d *Db) ExecSqlFile(rc *req.Ctx) {
g := rc.GinCtx
fileheader, err := g.FormFile("file")
multipart, err := g.Request.MultipartReader()
biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
file, _ := fileheader.Open()
filename := fileheader.Filename
file, err := multipart.NextPart()
biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
defer file.Close()
filename := file.FileName()
dbId := getDbId(g)
dbName := getDbName(g)
dbConn := d.getDbConnection(rc.GinCtx)
dbConn := d.DbApp.GetDbConnection(dbId, dbName)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
logExecRecord := true
// 如果执行sql文件大于该值则不记录sql执行记录
if fileheader.Size > 50*1024 {
logExecRecord = false
defer func() {
if err := recover(); err != nil {
var errInfo string
switch t := err.(type) {
case biz.BizError:
errInfo = t.Error()
case *biz.BizError:
errInfo = t.Error()
case string:
errInfo = t
}
if len(errInfo) > 0 {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
}
}
}()
execReq := &application.DbSqlExecReq{
DbId: dbId,
Db: dbName,
Remark: filename,
DbConn: dbConn,
LoginAccount: rc.LoginAccount,
}
go func() {
defer func() {
if err := recover(); err != nil {
var errInfo string
switch t := err.(type) {
case error:
errInfo = t.Error()
}
if len(errInfo) > 0 {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
}
defer func() {
if err := recover(); err != nil {
var errInfo string
switch t := err.(type) {
case error:
errInfo = t.Error()
}
}()
execReq := &application.DbSqlExecReq{
DbId: dbId,
Db: dbName,
Remark: fileheader.Filename,
DbConn: dbConn,
LoginAccount: rc.LoginAccount,
}
sqlScanner := SplitSqls(file)
for sqlScanner.Scan() {
sql := sqlScanner.Text()
execReq.Sql = sql
// 需要记录执行记录
if logExecRecord {
_, err = d.DbSqlExecApp.Exec(execReq)
} else {
_, err = dbConn.Exec(sql)
}
if err != nil {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
return
if len(errInfo) > 0 {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
}
}
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
}()
progressId := uniqueid.IncrementID()
executedStatements := 0
defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId,
SqlFileName: filename,
ExecutedStatements: executedStatements,
Terminated: true,
}).WithCategory(progressCategory))
ticker := time.NewTicker(time.Second * 1)
defer ticker.Stop()
sqlScanner := SplitSqls(file)
for sqlScanner.Scan() {
select {
case <-ticker.C:
ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
Id: progressId,
SqlFileName: filename,
ExecutedStatements: executedStatements,
Terminated: false,
}).WithCategory(progressCategory))
default:
}
sql := sqlScanner.Text()
const prefixUse = "use "
if strings.HasPrefix(sql, prefixUse) {
dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
if len(dbNameExec) > 0 {
dbConn = d.DbApp.GetDbConnection(dbId, dbNameExec)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
execReq.DbConn = dbConn
}
}
// 需要记录执行记录
const maxRecordStatements = 64
if executedStatements < maxRecordStatements {
execReq.Sql = sql
_, err = d.DbSqlExecApp.Exec(execReq)
} else {
_, err = dbConn.Exec(sql)
}
if err != nil {
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
return
}
executedStatements++
}
d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
}
// 数据库dump
@@ -276,23 +331,44 @@ func (d *Db) DumpSql(rc *req.Ctx) {
rc.ReqParam = fmt.Sprintf("DB[id=%d, tag=%s, name=%s, databases=%s, tables=%s, dumpType=%s]", db.Id, db.TagPath, db.Name, dbNamesStr, tablesStr, dumpType)
}
func escapeSql(dbType string, sql string) string {
if dbType == entity.DbTypePostgres {
return pq.QuoteLiteral(sql)
} else {
sql = strings.ReplaceAll(sql, `\`, `\\`)
sql = strings.ReplaceAll(sql, `'`, `''`)
return "'" + sql + "'"
}
}
func quoteTable(dbType string, table string) string {
if dbType == entity.DbTypePostgres {
return "\"" + table + "\""
} else {
return "`" + table + "`"
}
}
func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) {
dbConn := d.DbApp.GetDbConnection(dbId, dbName)
writer.WriteString("-- ----------------------------")
writer.WriteString("\n-- ----------------------------")
writer.WriteString("\n-- 导出平台: mayfly-go")
writer.WriteString(fmt.Sprintf("\n-- 导出时间: %s ", time.Now().Format("2006-01-02 15:04:05")))
writer.WriteString(fmt.Sprintf("\n-- 导出数据库: %s ", dbName))
writer.WriteString("\n-- ----------------------------\n")
writer.TryFlush()
if switchDb {
switch dbConn.Info.Type {
case entity.DbTypeMysql:
writer.WriteString(fmt.Sprintf("use `%s`;\n", dbName))
writer.WriteString(fmt.Sprintf("USE `%s`;\n", dbName))
default:
biz.IsTrue(false, "数据库类型必须为 %s", entity.DbTypeMysql)
biz.IsTrue(false, "同时导出多个数据库,数据库类型必须为 %s", entity.DbTypeMysql)
}
}
if dbConn.Info.Type == entity.DbTypeMysql {
writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 0;\n")
}
dbMeta := dbConn.GetMeta()
if len(tables) == 0 {
ti := dbMeta.GetTableInfos()
@@ -303,23 +379,22 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
}
for _, table := range tables {
writer.TryFlush()
quotedTable := quoteTable(dbConn.Info.Type, table)
if needStruct {
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table))
writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS `%s`;\n", table))
writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable))
writer.WriteString(dbMeta.GetCreateTableDdl(table) + ";\n")
}
if !needData {
continue
}
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table))
writer.WriteString("BEGIN;\n")
insertSql := "INSERT INTO `%s` VALUES (%s);\n"
insertSql := "INSERT INTO %s VALUES (%s);\n"
dbMeta.WalkTableRecord(table, func(record map[string]any, columns []string) {
var values []string
writer.TryFlush()
for _, column := range columns {
value := record[column]
if value == nil {
@@ -328,17 +403,18 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
}
strValue, ok := value.(string)
if ok {
values = append(values, fmt.Sprintf("%#v", strValue))
strValue = escapeSql(dbConn.Info.Type, strValue)
values = append(values, strValue)
} else {
values = append(values, stringx.AnyToStr(value))
}
}
writer.WriteString(fmt.Sprintf(insertSql, table, strings.Join(values, ", ")))
writer.TryFlush()
writer.WriteString(fmt.Sprintf(insertSql, quotedTable, strings.Join(values, ", ")))
})
writer.WriteString("COMMIT;\n")
writer.TryFlush()
}
if dbConn.Info.Type == entity.DbTypeMysql {
writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 1;\n")
}
}