mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 08:20:25 +08:00 
			
		
		
		
	feat: 前端显示 SQL 文件执行进度
This commit is contained in:
		@@ -3,7 +3,11 @@ package api
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/lib/pq"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mayfly-go/pkg/utils/uniqueid"
 | 
			
		||||
	"mayfly-go/pkg/ws"
 | 
			
		||||
 | 
			
		||||
	"mayfly-go/internal/db/api/form"
 | 
			
		||||
	"mayfly-go/internal/db/api/vo"
 | 
			
		||||
	"mayfly-go/internal/db/application"
 | 
			
		||||
@@ -79,9 +83,7 @@ func (d *Db) DeleteDb(rc *req.Ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection {
 | 
			
		||||
	dbName := g.Query("db")
 | 
			
		||||
	biz.NotEmpty(dbName, "db不能为空")
 | 
			
		||||
	return d.DbApp.GetDbConnection(getDbId(g), dbName)
 | 
			
		||||
	return d.DbApp.GetDbConnection(getDbId(g), getDbName(g))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Db) TableInfos(rc *req.Ctx) {
 | 
			
		||||
@@ -152,67 +154,120 @@ func (d *Db) ExecSql(rc *req.Ctx) {
 | 
			
		||||
	rc.ResData = colAndRes
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// progressCategory sql文件执行进度消息类型
 | 
			
		||||
const progressCategory = "execSqlFileProgress"
 | 
			
		||||
 | 
			
		||||
// progressMsg sql文件执行进度消息
 | 
			
		||||
type progressMsg struct {
 | 
			
		||||
	Id                 uint64 `json:"id"`
 | 
			
		||||
	SqlFileName        string `json:"sqlFileName"`
 | 
			
		||||
	ExecutedStatements int    `json:"executedStatements"`
 | 
			
		||||
	Terminated         bool   `json:"terminated"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行sql文件
 | 
			
		||||
func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
	g := rc.GinCtx
 | 
			
		||||
	fileheader, err := g.FormFile("file")
 | 
			
		||||
	multipart, err := g.Request.MultipartReader()
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
 | 
			
		||||
 | 
			
		||||
	file, _ := fileheader.Open()
 | 
			
		||||
	filename := fileheader.Filename
 | 
			
		||||
	file, err := multipart.NextPart()
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
 | 
			
		||||
	defer file.Close()
 | 
			
		||||
	filename := file.FileName()
 | 
			
		||||
	dbId := getDbId(g)
 | 
			
		||||
	dbName := getDbName(g)
 | 
			
		||||
 | 
			
		||||
	dbConn := d.getDbConnection(rc.GinCtx)
 | 
			
		||||
	dbConn := d.DbApp.GetDbConnection(dbId, dbName)
 | 
			
		||||
	biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
 | 
			
		||||
 | 
			
		||||
	logExecRecord := true
 | 
			
		||||
	// 如果执行sql文件大于该值则不记录sql执行记录
 | 
			
		||||
	if fileheader.Size > 50*1024 {
 | 
			
		||||
		logExecRecord = false
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err := recover(); err != nil {
 | 
			
		||||
			var errInfo string
 | 
			
		||||
			switch t := err.(type) {
 | 
			
		||||
			case biz.BizError:
 | 
			
		||||
				errInfo = t.Error()
 | 
			
		||||
			case *biz.BizError:
 | 
			
		||||
				errInfo = t.Error()
 | 
			
		||||
			case string:
 | 
			
		||||
				errInfo = t
 | 
			
		||||
			}
 | 
			
		||||
			if len(errInfo) > 0 {
 | 
			
		||||
				d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	execReq := &application.DbSqlExecReq{
 | 
			
		||||
		DbId:         dbId,
 | 
			
		||||
		Db:           dbName,
 | 
			
		||||
		Remark:       filename,
 | 
			
		||||
		DbConn:       dbConn,
 | 
			
		||||
		LoginAccount: rc.LoginAccount,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if err := recover(); err != nil {
 | 
			
		||||
				var errInfo string
 | 
			
		||||
				switch t := err.(type) {
 | 
			
		||||
				case error:
 | 
			
		||||
					errInfo = t.Error()
 | 
			
		||||
				}
 | 
			
		||||
				if len(errInfo) > 0 {
 | 
			
		||||
					d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
			
		||||
				}
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err := recover(); err != nil {
 | 
			
		||||
			var errInfo string
 | 
			
		||||
			switch t := err.(type) {
 | 
			
		||||
			case error:
 | 
			
		||||
				errInfo = t.Error()
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		execReq := &application.DbSqlExecReq{
 | 
			
		||||
			DbId:         dbId,
 | 
			
		||||
			Db:           dbName,
 | 
			
		||||
			Remark:       fileheader.Filename,
 | 
			
		||||
			DbConn:       dbConn,
 | 
			
		||||
			LoginAccount: rc.LoginAccount,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sqlScanner := SplitSqls(file)
 | 
			
		||||
		for sqlScanner.Scan() {
 | 
			
		||||
			sql := sqlScanner.Text()
 | 
			
		||||
			execReq.Sql = sql
 | 
			
		||||
			// 需要记录执行记录
 | 
			
		||||
			if logExecRecord {
 | 
			
		||||
				_, err = d.DbSqlExecApp.Exec(execReq)
 | 
			
		||||
			} else {
 | 
			
		||||
				_, err = dbConn.Exec(sql)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
 | 
			
		||||
				return
 | 
			
		||||
			if len(errInfo) > 0 {
 | 
			
		||||
				d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	progressId := uniqueid.IncrementID()
 | 
			
		||||
	executedStatements := 0
 | 
			
		||||
	defer ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
 | 
			
		||||
		Id:                 progressId,
 | 
			
		||||
		SqlFileName:        filename,
 | 
			
		||||
		ExecutedStatements: executedStatements,
 | 
			
		||||
		Terminated:         true,
 | 
			
		||||
	}).WithCategory(progressCategory))
 | 
			
		||||
 | 
			
		||||
	ticker := time.NewTicker(time.Second * 1)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
	sqlScanner := SplitSqls(file)
 | 
			
		||||
	for sqlScanner.Scan() {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ticker.C:
 | 
			
		||||
			ws.SendJsonMsg(rc.LoginAccount.Id, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
 | 
			
		||||
				Id:                 progressId,
 | 
			
		||||
				SqlFileName:        filename,
 | 
			
		||||
				ExecutedStatements: executedStatements,
 | 
			
		||||
				Terminated:         false,
 | 
			
		||||
			}).WithCategory(progressCategory))
 | 
			
		||||
		default:
 | 
			
		||||
		}
 | 
			
		||||
		sql := sqlScanner.Text()
 | 
			
		||||
		const prefixUse = "use "
 | 
			
		||||
		if strings.HasPrefix(sql, prefixUse) {
 | 
			
		||||
			dbNameExec := strings.Trim(sql[len(prefixUse):], " `;\n")
 | 
			
		||||
			if len(dbNameExec) > 0 {
 | 
			
		||||
				dbConn = d.DbApp.GetDbConnection(dbId, dbNameExec)
 | 
			
		||||
				biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
 | 
			
		||||
				execReq.DbConn = dbConn
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		// 需要记录执行记录
 | 
			
		||||
		const maxRecordStatements = 64
 | 
			
		||||
		if executedStatements < maxRecordStatements {
 | 
			
		||||
			execReq.Sql = sql
 | 
			
		||||
			_, err = d.DbSqlExecApp.Exec(execReq)
 | 
			
		||||
		} else {
 | 
			
		||||
			_, err = dbConn.Exec(sql)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		executedStatements++
 | 
			
		||||
	}
 | 
			
		||||
	d.MsgApp.CreateAndSend(rc.LoginAccount, msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 数据库dump
 | 
			
		||||
@@ -276,23 +331,44 @@ func (d *Db) DumpSql(rc *req.Ctx) {
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("DB[id=%d, tag=%s, name=%s, databases=%s, tables=%s, dumpType=%s]", db.Id, db.TagPath, db.Name, dbNamesStr, tablesStr, dumpType)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func escapeSql(dbType string, sql string) string {
 | 
			
		||||
	if dbType == entity.DbTypePostgres {
 | 
			
		||||
		return pq.QuoteLiteral(sql)
 | 
			
		||||
	} else {
 | 
			
		||||
		sql = strings.ReplaceAll(sql, `\`, `\\`)
 | 
			
		||||
		sql = strings.ReplaceAll(sql, `'`, `''`)
 | 
			
		||||
		return "'" + sql + "'"
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func quoteTable(dbType string, table string) string {
 | 
			
		||||
	if dbType == entity.DbTypePostgres {
 | 
			
		||||
		return "\"" + table + "\""
 | 
			
		||||
	} else {
 | 
			
		||||
		return "`" + table + "`"
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool, switchDb bool) {
 | 
			
		||||
	dbConn := d.DbApp.GetDbConnection(dbId, dbName)
 | 
			
		||||
	writer.WriteString("-- ----------------------------")
 | 
			
		||||
	writer.WriteString("\n-- ----------------------------")
 | 
			
		||||
	writer.WriteString("\n-- 导出平台: mayfly-go")
 | 
			
		||||
	writer.WriteString(fmt.Sprintf("\n-- 导出时间: %s ", time.Now().Format("2006-01-02 15:04:05")))
 | 
			
		||||
	writer.WriteString(fmt.Sprintf("\n-- 导出数据库: %s ", dbName))
 | 
			
		||||
	writer.WriteString("\n-- ----------------------------\n")
 | 
			
		||||
	writer.TryFlush()
 | 
			
		||||
 | 
			
		||||
	if switchDb {
 | 
			
		||||
		switch dbConn.Info.Type {
 | 
			
		||||
		case entity.DbTypeMysql:
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("use `%s`;\n", dbName))
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("USE `%s`;\n", dbName))
 | 
			
		||||
		default:
 | 
			
		||||
			biz.IsTrue(false, "数据库类型必须为 %s", entity.DbTypeMysql)
 | 
			
		||||
			biz.IsTrue(false, "同时导出多个数据库,数据库类型必须为 %s", entity.DbTypeMysql)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	if dbConn.Info.Type == entity.DbTypeMysql {
 | 
			
		||||
		writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 0;\n")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbMeta := dbConn.GetMeta()
 | 
			
		||||
	if len(tables) == 0 {
 | 
			
		||||
		ti := dbMeta.GetTableInfos()
 | 
			
		||||
@@ -303,23 +379,22 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, table := range tables {
 | 
			
		||||
		writer.TryFlush()
 | 
			
		||||
		quotedTable := quoteTable(dbConn.Info.Type, table)
 | 
			
		||||
		if needStruct {
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table))
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS `%s`;\n", table))
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS %s;\n", quotedTable))
 | 
			
		||||
			writer.WriteString(dbMeta.GetCreateTableDdl(table) + ";\n")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !needData {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table))
 | 
			
		||||
		writer.WriteString("BEGIN;\n")
 | 
			
		||||
 | 
			
		||||
		insertSql := "INSERT INTO `%s` VALUES (%s);\n"
 | 
			
		||||
 | 
			
		||||
		insertSql := "INSERT INTO %s VALUES (%s);\n"
 | 
			
		||||
		dbMeta.WalkTableRecord(table, func(record map[string]any, columns []string) {
 | 
			
		||||
			var values []string
 | 
			
		||||
			writer.TryFlush()
 | 
			
		||||
			for _, column := range columns {
 | 
			
		||||
				value := record[column]
 | 
			
		||||
				if value == nil {
 | 
			
		||||
@@ -328,17 +403,18 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
 | 
			
		||||
				}
 | 
			
		||||
				strValue, ok := value.(string)
 | 
			
		||||
				if ok {
 | 
			
		||||
					values = append(values, fmt.Sprintf("%#v", strValue))
 | 
			
		||||
					strValue = escapeSql(dbConn.Info.Type, strValue)
 | 
			
		||||
					values = append(values, strValue)
 | 
			
		||||
				} else {
 | 
			
		||||
					values = append(values, stringx.AnyToStr(value))
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			writer.WriteString(fmt.Sprintf(insertSql, table, strings.Join(values, ", ")))
 | 
			
		||||
			writer.TryFlush()
 | 
			
		||||
			writer.WriteString(fmt.Sprintf(insertSql, quotedTable, strings.Join(values, ", ")))
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		writer.WriteString("COMMIT;\n")
 | 
			
		||||
		writer.TryFlush()
 | 
			
		||||
	}
 | 
			
		||||
	if dbConn.Info.Type == entity.DbTypeMysql {
 | 
			
		||||
		writer.WriteString("\nSET FOREIGN_KEY_CHECKS = 1;\n")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user