mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 08:20:25 +08:00 
			
		
		
		
	feat: 新增数据库导出功能&其他小优化
This commit is contained in:
		@@ -2,7 +2,7 @@ package api
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io/ioutil"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mayfly-go/internal/devops/api/form"
 | 
			
		||||
	"mayfly-go/internal/devops/api/vo"
 | 
			
		||||
	"mayfly-go/internal/devops/application"
 | 
			
		||||
@@ -16,8 +16,10 @@ import (
 | 
			
		||||
	"mayfly-go/pkg/ws"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/xwb1989/sqlparser"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Db struct {
 | 
			
		||||
@@ -27,6 +29,8 @@ type Db struct {
 | 
			
		||||
	ProjectApp   application.Project
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const DEFAULT_COLUMN_SIZE = 500
 | 
			
		||||
 | 
			
		||||
// @router /api/dbs [get]
 | 
			
		||||
func (d *Db) Dbs(rc *ctx.ReqCtx) {
 | 
			
		||||
	g := rc.GinCtx
 | 
			
		||||
@@ -91,7 +95,7 @@ func (d *Db) ExecSql(rc *ctx.ReqCtx) {
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("db: %d:%s | sql: %s", id, db, sql)
 | 
			
		||||
 | 
			
		||||
	biz.NotEmpty(sql, "sql不能为空")
 | 
			
		||||
	if strings.HasPrefix(sql, "SELECT") || strings.HasPrefix(sql, "select") || strings.HasPrefix(sql, "show") {
 | 
			
		||||
	if strings.HasPrefix(sql, "SELECT") || strings.HasPrefix(sql, "select") || strings.HasPrefix(sql, "show") || strings.HasPrefix(sql, "explain") {
 | 
			
		||||
		colNames, res, err := dbInstance.SelectData(sql)
 | 
			
		||||
		biz.ErrIsNilAppendErr(err, "查询失败: %s")
 | 
			
		||||
		colAndRes := make(map[string]interface{})
 | 
			
		||||
@@ -128,12 +132,8 @@ func (d *Db) ExecSqlFile(rc *ctx.ReqCtx) {
 | 
			
		||||
	fileheader, err := g.FormFile("file")
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
 | 
			
		||||
 | 
			
		||||
	// 读取sql文件并根据;切割sql语句
 | 
			
		||||
	file, _ := fileheader.Open()
 | 
			
		||||
	filename := fileheader.Filename
 | 
			
		||||
	bytes, _ := ioutil.ReadAll(file)
 | 
			
		||||
	sqlContent := string(bytes)
 | 
			
		||||
	sqls := strings.Split(sqlContent, ";")
 | 
			
		||||
	dbId, db := GetIdAndDb(g)
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
@@ -153,12 +153,14 @@ func (d *Db) ExecSqlFile(rc *ctx.ReqCtx) {
 | 
			
		||||
 | 
			
		||||
		biz.ErrIsNilAppendErr(d.ProjectApp.CanAccess(rc.LoginAccount.Id, db.ProjectId), "%s")
 | 
			
		||||
 | 
			
		||||
		for _, sql := range sqls {
 | 
			
		||||
			sql = strings.Trim(sql, " ")
 | 
			
		||||
			if sql == "" || sql == "\n" {
 | 
			
		||||
				continue
 | 
			
		||||
		tokens := sqlparser.NewTokenizer(file)
 | 
			
		||||
		for {
 | 
			
		||||
			stmt, err := sqlparser.ParseNext(tokens)
 | 
			
		||||
			if err == io.EOF {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
			_, err := db.Exec(sql)
 | 
			
		||||
			sql := sqlparser.String(stmt)
 | 
			
		||||
			_, err = db.Exec(sql)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbInfo, err.Error())))
 | 
			
		||||
				return
 | 
			
		||||
@@ -168,6 +170,88 @@ func (d *Db) ExecSqlFile(rc *ctx.ReqCtx) {
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 数据库dump
 | 
			
		||||
func (d *Db) DumpSql(rc *ctx.ReqCtx) {
 | 
			
		||||
	g := rc.GinCtx
 | 
			
		||||
	dbId, db := GetIdAndDb(g)
 | 
			
		||||
	dumpType := g.Query("type")
 | 
			
		||||
	tablesStr := g.Query("tables")
 | 
			
		||||
	biz.NotEmpty(tablesStr, "请选择要导出的表")
 | 
			
		||||
	tables := strings.Split(tablesStr, ",")
 | 
			
		||||
 | 
			
		||||
	// 是否需要导出表结构
 | 
			
		||||
	needStruct := dumpType == "1" || dumpType == "3"
 | 
			
		||||
	// 是否需要导出数据
 | 
			
		||||
	needData := dumpType == "2" || dumpType == "3"
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	filename := fmt.Sprintf("%s.%s.sql", db, now.Format("200601021504"))
 | 
			
		||||
	g.Header("Content-Type", "application/octet-stream")
 | 
			
		||||
	g.Header("Content-Disposition", "attachment; filename="+filename)
 | 
			
		||||
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("数据库id: %d -- %s", dbId, db)
 | 
			
		||||
	dbInstance := d.DbApp.GetDbInstance(dbId, db)
 | 
			
		||||
	writer := g.Writer
 | 
			
		||||
	writer.WriteString("-- ----------------------------")
 | 
			
		||||
	writer.WriteString("\n-- 导出平台: mayfly-go")
 | 
			
		||||
	writer.WriteString(fmt.Sprintf("\n-- 导出时间: %s ", now.Format("2006-01-02 15:04:05")))
 | 
			
		||||
	writer.WriteString(fmt.Sprintf("\n-- 导出数据库: %s ", db))
 | 
			
		||||
	writer.WriteString("\n-- ----------------------------\n")
 | 
			
		||||
 | 
			
		||||
	for _, table := range tables {
 | 
			
		||||
		if needStruct {
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表结构: %s \n-- ----------------------------\n", table))
 | 
			
		||||
			writer.WriteString(fmt.Sprintf("DROP TABLE IF EXISTS `%s`;\n", table))
 | 
			
		||||
			writer.WriteString(dbInstance.GetCreateTableDdl(table)[0]["Create Table"].(string) + ";\n")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if !needData {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table))
 | 
			
		||||
		writer.WriteString("BEGIN;\n")
 | 
			
		||||
 | 
			
		||||
		countSql := fmt.Sprintf("SELECT COUNT(*) count FROM %s", table)
 | 
			
		||||
		_, countRes, _ := dbInstance.SelectData(countSql)
 | 
			
		||||
		// 查询出所有列信息总数,手动分页获取所有数据
 | 
			
		||||
		maCount := int(countRes[0]["count"].(int64))
 | 
			
		||||
		// 计算需要查询的页数
 | 
			
		||||
		pageNum := maCount / DEFAULT_COLUMN_SIZE
 | 
			
		||||
		if maCount%DEFAULT_COLUMN_SIZE > 0 {
 | 
			
		||||
			pageNum++
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sqlTmp := "SELECT * FROM %s LIMIT %d, %d"
 | 
			
		||||
		for index := 0; index < pageNum; index++ {
 | 
			
		||||
			sql := fmt.Sprintf(sqlTmp, table, index*DEFAULT_COLUMN_SIZE, DEFAULT_COLUMN_SIZE)
 | 
			
		||||
			columns, result, _ := dbInstance.SelectData(sql)
 | 
			
		||||
 | 
			
		||||
			insertSql := "INSERT INTO `%s` VALUES (%s);\n"
 | 
			
		||||
			for _, res := range result {
 | 
			
		||||
				var values []string
 | 
			
		||||
				for _, column := range columns {
 | 
			
		||||
					value := res[column]
 | 
			
		||||
					if value == nil {
 | 
			
		||||
						values = append(values, "NULL")
 | 
			
		||||
						continue
 | 
			
		||||
					}
 | 
			
		||||
					strValue, ok := value.(string)
 | 
			
		||||
					if ok {
 | 
			
		||||
						values = append(values, fmt.Sprintf("%#v", strValue))
 | 
			
		||||
					} else {
 | 
			
		||||
						values = append(values, utils.ToString(value))
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
				writer.WriteString(fmt.Sprintf(insertSql, table, strings.Join(values, ", ")))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		writer.WriteString("COMMIT;\n")
 | 
			
		||||
	}
 | 
			
		||||
	rc.NoRes = true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// @router /api/db/:dbId/t-metadata [get]
 | 
			
		||||
func (d *Db) TableMA(rc *ctx.ReqCtx) {
 | 
			
		||||
	dbi := d.DbApp.GetDbInstance(GetIdAndDb(rc.GinCtx))
 | 
			
		||||
 
 | 
			
		||||
@@ -223,8 +223,9 @@ func (d *DbInstance) SelectData(execSql string) ([]string, []map[string]interfac
 | 
			
		||||
	execSql = strings.Trim(execSql, " ")
 | 
			
		||||
	isSelect := strings.HasPrefix(execSql, "SELECT") || strings.HasPrefix(execSql, "select")
 | 
			
		||||
	isShow := strings.HasPrefix(execSql, "show")
 | 
			
		||||
	isExplain := strings.HasPrefix(execSql, "explain")
 | 
			
		||||
 | 
			
		||||
	if !isSelect && !isShow {
 | 
			
		||||
	if !isSelect && !isShow && !isExplain {
 | 
			
		||||
		return nil, nil, errors.New("该sql非查询语句")
 | 
			
		||||
	}
 | 
			
		||||
	// 没加limit,则默认限制50条
 | 
			
		||||
@@ -272,14 +273,13 @@ func (d *DbInstance) SelectData(execSql string) ([]string, []map[string]interfac
 | 
			
		||||
			colName := colType.Name()
 | 
			
		||||
			// 字段类型名
 | 
			
		||||
			colScanType := colType.ScanType().Name()
 | 
			
		||||
 | 
			
		||||
			// 如果是密码字段,则脱敏显示
 | 
			
		||||
			if colName == "password" {
 | 
			
		||||
				v = []byte("******")
 | 
			
		||||
			}
 | 
			
		||||
			if isFirst {
 | 
			
		||||
				colNames = append(colNames, colName)
 | 
			
		||||
			}
 | 
			
		||||
			if v == nil {
 | 
			
		||||
				rowData[colName] = nil
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			// 这里把[]byte数据转成string
 | 
			
		||||
			stringV := string(v)
 | 
			
		||||
			if stringV == "" {
 | 
			
		||||
 
 | 
			
		||||
@@ -61,6 +61,12 @@ func InitDbRouter(router *gin.RouterGroup) {
 | 
			
		||||
			rc.Handle(d.ExecSqlFile)
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		db.GET(":dbId/dump", func(g *gin.Context) {
 | 
			
		||||
			ctx.NewReqCtxWithGin(g).
 | 
			
		||||
				WithLog(ctx.NewLogInfo("Sql文件dump")).
 | 
			
		||||
				Handle(d.DumpSql)
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		db.GET(":dbId/t-metadata", func(c *gin.Context) {
 | 
			
		||||
			ctx.NewReqCtxWithGin(c).Handle(d.TableMA)
 | 
			
		||||
		})
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user