mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 23:40:24 +08:00
431 lines
12 KiB
Go
431 lines
12 KiB
Go
package api
|
||
|
||
import (
|
||
"context"
|
||
"encoding/base64"
|
||
"fmt"
|
||
"io"
|
||
"mayfly-go/internal/common/consts"
|
||
"mayfly-go/internal/db/api/form"
|
||
"mayfly-go/internal/db/api/vo"
|
||
"mayfly-go/internal/db/application"
|
||
"mayfly-go/internal/db/config"
|
||
"mayfly-go/internal/db/dbm/dbi"
|
||
"mayfly-go/internal/db/domain/entity"
|
||
"mayfly-go/internal/event"
|
||
msgapp "mayfly-go/internal/msg/application"
|
||
msgdto "mayfly-go/internal/msg/application/dto"
|
||
tagapp "mayfly-go/internal/tag/application"
|
||
tagentity "mayfly-go/internal/tag/domain/entity"
|
||
"mayfly-go/pkg/biz"
|
||
"mayfly-go/pkg/global"
|
||
"mayfly-go/pkg/logx"
|
||
"mayfly-go/pkg/model"
|
||
"mayfly-go/pkg/req"
|
||
"mayfly-go/pkg/utils/anyx"
|
||
"mayfly-go/pkg/utils/collx"
|
||
"mayfly-go/pkg/utils/stringx"
|
||
"mayfly-go/pkg/ws"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/kanzihuang/vitess/go/vt/sqlparser"
|
||
"github.com/may-fly/cast"
|
||
)
|
||
|
||
type Db struct {
|
||
InstanceApp application.Instance `inject:"DbInstanceApp"`
|
||
DbApp application.Db `inject:""`
|
||
DbSqlExecApp application.DbSqlExec `inject:""`
|
||
MsgApp msgapp.Msg `inject:""`
|
||
TagApp tagapp.TagTree `inject:"TagTreeApp"`
|
||
}
|
||
|
||
// @router /api/dbs [get]
|
||
func (d *Db) Dbs(rc *req.Ctx) {
|
||
queryCond, page := req.BindQueryAndPage[*entity.DbQuery](rc, new(entity.DbQuery))
|
||
|
||
// 不存在可访问标签id,即没有可操作数据
|
||
codes := d.TagApp.GetAccountTagCodes(rc.GetLoginAccount().Id, int8(tagentity.TagTypeDbName), queryCond.TagPath)
|
||
if len(codes) == 0 {
|
||
rc.ResData = model.EmptyPageResult[any]()
|
||
return
|
||
}
|
||
queryCond.Codes = codes
|
||
|
||
var dbvos []*vo.DbListVO
|
||
res, err := d.DbApp.GetPageList(queryCond, page, &dbvos)
|
||
biz.ErrIsNil(err)
|
||
|
||
// 填充标签信息
|
||
d.TagApp.FillTagInfo(tagentity.TagTypeDbName, collx.ArrayMap(dbvos, func(dbvo *vo.DbListVO) tagentity.ITagResource {
|
||
return dbvo
|
||
})...)
|
||
|
||
rc.ResData = res
|
||
}
|
||
|
||
func (d *Db) Save(rc *req.Ctx) {
|
||
form := &form.DbForm{}
|
||
db := req.BindJsonAndCopyTo[*entity.Db](rc, form, new(entity.Db))
|
||
|
||
rc.ReqParam = form
|
||
|
||
biz.ErrIsNil(d.DbApp.SaveDb(rc.MetaCtx, db))
|
||
}
|
||
|
||
func (d *Db) DeleteDb(rc *req.Ctx) {
|
||
idsStr := rc.PathParam("dbId")
|
||
rc.ReqParam = idsStr
|
||
ids := strings.Split(idsStr, ",")
|
||
|
||
ctx := rc.MetaCtx
|
||
for _, v := range ids {
|
||
dbId := cast.ToUint64(v)
|
||
biz.NotBlank(dbId, "存在错误dbId")
|
||
biz.ErrIsNil(d.DbApp.Delete(ctx, dbId))
|
||
}
|
||
}
|
||
|
||
/** 数据库操作相关、执行sql等 ***/
|
||
|
||
func (d *Db) ExecSql(rc *req.Ctx) {
|
||
form := req.BindJsonAndValid(rc, new(form.DbSqlExecForm))
|
||
|
||
dbId := getDbId(rc)
|
||
dbConn, err := d.DbApp.GetDbConn(dbId, form.Db)
|
||
biz.ErrIsNil(err)
|
||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.GetLoginAccount().Id, dbConn.Info.CodePath...), "%s")
|
||
|
||
global.EventBus.Publish(rc.MetaCtx, event.EventTopicResourceOp, dbConn.Info.CodePath[0])
|
||
|
||
sqlBytes, err := base64.StdEncoding.DecodeString(form.Sql)
|
||
biz.ErrIsNilAppendErr(err, "sql解码失败: %s")
|
||
// 去除前后空格及换行符
|
||
sql := stringx.TrimSpaceAndBr(string(sqlBytes))
|
||
|
||
rc.ReqParam = fmt.Sprintf("%s %s\n-> %s", dbConn.Info.GetLogDesc(), form.ExecId, sql)
|
||
biz.NotEmpty(form.Sql, "sql不能为空")
|
||
|
||
execReq := &application.DbSqlExecReq{
|
||
DbId: dbId,
|
||
Db: form.Db,
|
||
Remark: form.Remark,
|
||
DbConn: dbConn,
|
||
}
|
||
|
||
ctx, cancel := context.WithTimeout(rc.MetaCtx, time.Duration(config.GetDbms().SqlExecTl)*time.Second)
|
||
defer cancel()
|
||
|
||
sqls, err := sqlparser.SplitStatementToPieces(sql, sqlparser.WithDialect(dbConn.GetMetaData().GetSqlParserDialect()))
|
||
biz.ErrIsNil(err, "SQL解析错误,请检查您的执行SQL")
|
||
isMulti := len(sqls) > 1
|
||
var execResAll *application.DbSqlExecRes
|
||
|
||
for _, s := range sqls {
|
||
s = stringx.TrimSpaceAndBr(s)
|
||
// 多条执行,暂不支持查询语句
|
||
if isMulti && len(s) > 10 {
|
||
biz.IsTrue(!strings.HasPrefix(strings.ToLower(s[:10]), "select"), "多条语句执行暂不不支持select语句")
|
||
}
|
||
|
||
execReq.Sql = s
|
||
execRes, err := d.DbSqlExecApp.Exec(ctx, execReq)
|
||
biz.ErrIsNilAppendErr(err, fmt.Sprintf("[%s] -> 执行失败: ", s)+"%s")
|
||
|
||
if execResAll == nil {
|
||
execResAll = execRes
|
||
} else {
|
||
execResAll.Merge(execRes)
|
||
}
|
||
}
|
||
|
||
colAndRes := make(map[string]any)
|
||
if execResAll != nil {
|
||
colAndRes["columns"] = execResAll.Columns
|
||
colAndRes["res"] = execResAll.Res
|
||
}
|
||
rc.ResData = colAndRes
|
||
}
|
||
|
||
// progressCategory is the message category attached to SQL-file execution
// progress messages.
const progressCategory = "execSqlFileProgress"
|
||
|
||
// progressMsg is the payload of a SQL-file execution progress message.
type progressMsg struct {
	// Id identifies one script execution; every progress message of the same
	// run carries the same value.
	Id string `json:"id"`
	// Title is the display title; ExecSqlFile fills it with the file name.
	Title string `json:"title"`
	// ExecutedStatements is the statement counter at the time of sending.
	ExecutedStatements int `json:"executedStatements"`
	// Terminated is true on the final message of a run.
	Terminated bool `json:"terminated"`
}
|
||
|
||
// 执行sql文件
|
||
func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||
multipart, err := rc.GetRequest().MultipartReader()
|
||
biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
|
||
file, err := multipart.NextPart()
|
||
biz.ErrIsNilAppendErr(err, "读取sql文件失败: %s")
|
||
defer file.Close()
|
||
filename := file.FileName()
|
||
dbId := getDbId(rc)
|
||
dbName := getDbName(rc)
|
||
clientId := rc.Query("clientId")
|
||
|
||
dbConn, err := d.DbApp.GetDbConn(dbId, dbName)
|
||
biz.ErrIsNil(err)
|
||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.GetLoginAccount().Id, dbConn.Info.CodePath...), "%s")
|
||
rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc())
|
||
|
||
defer func() {
|
||
if err := recover(); err != nil {
|
||
errInfo := anyx.ToString(err)
|
||
if len(errInfo) > 300 {
|
||
errInfo = errInfo[:300] + "..."
|
||
}
|
||
d.MsgApp.CreateAndSend(rc.GetLoginAccount(), msgdto.ErrSysMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s]执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), errInfo)).WithClientId(clientId))
|
||
}
|
||
}()
|
||
|
||
execReq := &application.DbSqlExecReq{
|
||
DbId: dbId,
|
||
Db: dbName,
|
||
Remark: filename,
|
||
DbConn: dbConn,
|
||
}
|
||
|
||
var sql string
|
||
|
||
tokenizer := sqlparser.NewReaderTokenizer(file,
|
||
sqlparser.WithCacheInBuffer(), sqlparser.WithDialect(dbConn.GetMetaData().GetSqlParserDialect()))
|
||
|
||
executedStatements := 0
|
||
progressId := stringx.Rand(32)
|
||
laId := rc.GetLoginAccount().Id
|
||
defer ws.SendJsonMsg(ws.UserId(laId), clientId, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||
Id: progressId,
|
||
Title: filename,
|
||
ExecutedStatements: executedStatements,
|
||
Terminated: true,
|
||
}).WithCategory(progressCategory))
|
||
ticker := time.NewTicker(time.Second * 1)
|
||
defer ticker.Stop()
|
||
for {
|
||
select {
|
||
case <-ticker.C:
|
||
ws.SendJsonMsg(ws.UserId(laId), clientId, msgdto.InfoSysMsg("sql脚本执行进度", &progressMsg{
|
||
Id: progressId,
|
||
Title: filename,
|
||
ExecutedStatements: executedStatements,
|
||
Terminated: false,
|
||
}).WithCategory(progressCategory))
|
||
default:
|
||
}
|
||
executedStatements++
|
||
sql, err = sqlparser.SplitNext(tokenizer)
|
||
if err == io.EOF {
|
||
break
|
||
}
|
||
biz.ErrIsNilAppendErr(err, "%s")
|
||
|
||
const prefixUse = "use "
|
||
const prefixUSE = "USE "
|
||
if strings.HasPrefix(sql, prefixUSE) || strings.HasPrefix(sql, prefixUse) {
|
||
var stmt sqlparser.Statement
|
||
stmt, err = sqlparser.Parse(sql)
|
||
biz.ErrIsNilAppendErr(err, "%s")
|
||
|
||
stmtUse, ok := stmt.(*sqlparser.Use)
|
||
// 最终执行结果以数据库直接结果为准
|
||
if !ok {
|
||
logx.Warnf("sql解析失败: %s", sql)
|
||
}
|
||
dbConn, err = d.DbApp.GetDbConn(dbId, stmtUse.DBName.String())
|
||
biz.ErrIsNil(err)
|
||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(laId, dbConn.Info.CodePath...), "%s")
|
||
execReq.DbConn = dbConn
|
||
}
|
||
// 需要记录执行记录
|
||
const maxRecordStatements = 64
|
||
if executedStatements < maxRecordStatements {
|
||
execReq.Sql = sql
|
||
_, err = d.DbSqlExecApp.Exec(rc.MetaCtx, execReq)
|
||
} else {
|
||
_, err = dbConn.Exec(sql)
|
||
}
|
||
|
||
biz.ErrIsNilAppendErr(err, "%s")
|
||
}
|
||
d.MsgApp.CreateAndSend(rc.GetLoginAccount(), msgdto.SuccessSysMsg("sql脚本执行成功", fmt.Sprintf("sql脚本执行完成:%s", rc.ReqParam)).WithClientId(clientId))
|
||
}
|
||
|
||
// 数据库dump
|
||
func (d *Db) DumpSql(rc *req.Ctx) {
|
||
dbId := getDbId(rc)
|
||
dbName := rc.Query("db")
|
||
dumpType := rc.Query("type")
|
||
tablesStr := rc.Query("tables")
|
||
extName := rc.Query("extName")
|
||
switch extName {
|
||
case ".gz", ".gzip", "gz", "gzip":
|
||
extName = ".gz"
|
||
default:
|
||
extName = ""
|
||
}
|
||
|
||
// 是否需要导出表结构
|
||
needStruct := dumpType == "1" || dumpType == "3"
|
||
// 是否需要导出数据
|
||
needData := dumpType == "2" || dumpType == "3"
|
||
|
||
la := rc.GetLoginAccount()
|
||
db, err := d.DbApp.GetById(dbId)
|
||
biz.ErrIsNil(err, "该数据库不存在")
|
||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(la.Id, d.TagApp.ListTagPathByTypeAndCode(consts.ResourceTypeDb, db.Code)...), "%s")
|
||
|
||
now := time.Now()
|
||
filename := fmt.Sprintf("%s-%s.%s.sql%s", db.Name, dbName, now.Format("20060102150405"), extName)
|
||
rc.Header("Content-Type", "application/octet-stream")
|
||
rc.Header("Content-Disposition", "attachment; filename="+filename)
|
||
if extName != ".gz" {
|
||
rc.Header("Content-Encoding", "gzip")
|
||
}
|
||
|
||
var tables []string
|
||
if len(tablesStr) > 0 {
|
||
tables = strings.Split(tablesStr, ",")
|
||
}
|
||
|
||
defer func() {
|
||
msg := anyx.ToString(recover())
|
||
if len(msg) > 0 {
|
||
msg = "数据库导出失败: " + msg
|
||
rc.GetWriter().Write([]byte(msg))
|
||
d.MsgApp.CreateAndSend(la, msgdto.ErrSysMsg("数据库导出失败", msg))
|
||
}
|
||
}()
|
||
|
||
biz.ErrIsNil(d.DbApp.DumpDb(rc.MetaCtx, &application.DumpDbReq{
|
||
DbId: dbId,
|
||
DbName: dbName,
|
||
Tables: tables,
|
||
DumpDDL: needStruct,
|
||
DumpData: needData,
|
||
Writer: rc.GetWriter(),
|
||
}))
|
||
|
||
rc.ReqParam = collx.Kvs("db", db, "database", dbName, "tables", tablesStr, "dumpType", dumpType)
|
||
}
|
||
|
||
func (d *Db) TableInfos(rc *req.Ctx) {
|
||
res, err := d.getDbConn(rc).GetMetaData().GetTables()
|
||
biz.ErrIsNilAppendErr(err, "获取表信息失败: %s")
|
||
rc.ResData = res
|
||
}
|
||
|
||
func (d *Db) TableIndex(rc *req.Ctx) {
|
||
tn := rc.Query("tableName")
|
||
biz.NotEmpty(tn, "tableName不能为空")
|
||
res, err := d.getDbConn(rc).GetMetaData().GetTableIndex(tn)
|
||
biz.ErrIsNilAppendErr(err, "获取表索引信息失败: %s")
|
||
rc.ResData = res
|
||
}
|
||
|
||
// @router /api/db/:dbId/c-metadata [get]
|
||
func (d *Db) ColumnMA(rc *req.Ctx) {
|
||
tn := rc.Query("tableName")
|
||
biz.NotEmpty(tn, "tableName不能为空")
|
||
|
||
dbi := d.getDbConn(rc)
|
||
res, err := dbi.GetMetaData().GetColumns(tn)
|
||
biz.ErrIsNilAppendErr(err, "获取数据库列信息失败: %s")
|
||
rc.ResData = res
|
||
}
|
||
|
||
// @router /api/db/:dbId/hint-tables [get]
|
||
func (d *Db) HintTables(rc *req.Ctx) {
|
||
dbi := d.getDbConn(rc)
|
||
|
||
metadata := dbi.GetMetaData()
|
||
// 获取所有表
|
||
tables, err := metadata.GetTables()
|
||
biz.ErrIsNil(err)
|
||
tableNames := make([]string, 0)
|
||
for _, v := range tables {
|
||
tableNames = append(tableNames, v.TableName)
|
||
}
|
||
// key = 表名,value = 列名数组
|
||
res := make(map[string][]string)
|
||
|
||
// 表为空,则直接返回
|
||
if len(tableNames) == 0 {
|
||
rc.ResData = res
|
||
return
|
||
}
|
||
|
||
// 获取所有表下的所有列信息
|
||
columnMds, err := metadata.GetColumns(tableNames...)
|
||
biz.ErrIsNil(err)
|
||
for _, v := range columnMds {
|
||
tName := v.TableName
|
||
if res[tName] == nil {
|
||
res[tName] = make([]string, 0)
|
||
}
|
||
|
||
columnName := fmt.Sprintf("%s [%s]", v.ColumnName, v.GetColumnType())
|
||
comment := v.ColumnComment
|
||
// 如果字段备注不为空,则加上备注信息
|
||
if comment != "" {
|
||
columnName = fmt.Sprintf("%s[%s]", columnName, comment)
|
||
}
|
||
|
||
res[tName] = append(res[tName], columnName)
|
||
}
|
||
rc.ResData = res
|
||
}
|
||
|
||
func (d *Db) GetTableDDL(rc *req.Ctx) {
|
||
tn := rc.Query("tableName")
|
||
biz.NotEmpty(tn, "tableName不能为空")
|
||
res, err := d.getDbConn(rc).GetMetaData().GetTableDDL(tn, false)
|
||
biz.ErrIsNilAppendErr(err, "获取表ddl失败: %s")
|
||
rc.ResData = res
|
||
}
|
||
|
||
func (d *Db) GetSchemas(rc *req.Ctx) {
|
||
res, err := d.getDbConn(rc).GetMetaData().GetSchemas()
|
||
biz.ErrIsNilAppendErr(err, "获取schemas失败: %s")
|
||
rc.ResData = res
|
||
}
|
||
|
||
func (d *Db) CopyTable(rc *req.Ctx) {
|
||
form := &form.DbCopyTableForm{}
|
||
copy := req.BindJsonAndCopyTo[*dbi.DbCopyTable](rc, form, new(dbi.DbCopyTable))
|
||
|
||
conn, err := d.DbApp.GetDbConn(form.Id, form.Db)
|
||
biz.ErrIsNilAppendErr(err, "拷贝表失败: %s")
|
||
|
||
err = conn.GetDialect().CopyTable(copy)
|
||
if err != nil {
|
||
logx.Errorf("拷贝表失败: %s", err.Error())
|
||
}
|
||
biz.ErrIsNilAppendErr(err, "拷贝表失败: %s")
|
||
}
|
||
|
||
func getDbId(rc *req.Ctx) uint64 {
|
||
dbId := rc.PathParamInt("dbId")
|
||
biz.IsTrue(dbId > 0, "dbId错误")
|
||
return uint64(dbId)
|
||
}
|
||
|
||
func getDbName(rc *req.Ctx) string {
|
||
db := rc.Query("db")
|
||
biz.NotEmpty(db, "db不能为空")
|
||
return db
|
||
}
|
||
|
||
func (d *Db) getDbConn(rc *req.Ctx) *dbi.DbConn {
|
||
dc, err := d.DbApp.GetDbConn(getDbId(rc), getDbName(rc))
|
||
biz.ErrIsNil(err)
|
||
return dc
|
||
}
|