Files
mayfly-go/server/internal/db/api/db.go

416 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package api
import (
"archive/zip"
"bytes"
"context"
"fmt"
"io"
"mayfly-go/internal/db/api/form"
"mayfly-go/internal/db/api/vo"
"mayfly-go/internal/db/application"
"mayfly-go/internal/db/application/dto"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/imsg"
msgdto "mayfly-go/internal/msg/application/dto"
"mayfly-go/internal/pkg/event"
"mayfly-go/internal/pkg/utils"
tagapp "mayfly-go/internal/tag/application"
tagentity "mayfly-go/internal/tag/domain/entity"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/global"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/req"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/writerx"
"strings"
"time"
"github.com/spf13/cast"
)
// Db groups the HTTP handlers for database resources; its routes are
// registered by ReqConfs under the "/dbs" group.
// NOTE(review): the `inject:"T"` tags presumably let the project's dependency
// injector populate these fields — confirm against the injector.
type Db struct {
// instanceApp resolves database instances by id (used by Dbs).
instanceApp application.Instance `inject:"T"`
// dbApp provides paging, save, delete, connection lookup and dump operations.
dbApp application.Db `inject:"T"`
// dbSqlExecApp executes SQL statements and SQL files.
dbSqlExecApp application.DbSqlExec `inject:"T"`
// tagApp resolves account tags and performs access checks.
tagApp tagapp.TagTree `inject:"T"`
}
// ReqConfs builds the route table for the database handlers. Every route is
// registered relative to the "/dbs" group.
func (d *Db) ReqConfs() *req.Confs {
	routes := []*req.Conf{
		// List databases.
		req.NewGet("", d.Dbs),

		// Create/update and delete.
		req.NewPost("", d.Save).Log(req.NewLogSaveI(imsg.LogDbSave)),
		req.NewDelete(":dbId", d.DeleteDb).Log(req.NewLogSaveI(imsg.LogDbDelete)),

		// Metadata lookups.
		req.NewGet(":dbId/t-create-ddl", d.GetTableDDL),
		req.NewGet(":dbId/version", d.GetVersion),
		req.NewGet(":dbId/pg/schemas", d.GetSchemas),

		// SQL execution and dump.
		req.NewPost(":dbId/exec-sql", d.ExecSql).Log(req.NewLogI(imsg.LogDbRunSql)),
		req.NewPost(":dbId/exec-sql-file", d.ExecSqlFile).Log(req.NewLogSaveI(imsg.LogDbRunSqlFile)).RequiredPermissionCode("db:sqlscript:run"),
		req.NewGet(":dbId/dump", d.DumpSql).Log(req.NewLogSaveI(imsg.LogDbDump)).NoRes(),

		// Table-level information.
		req.NewGet(":dbId/t-infos", d.TableInfos),
		req.NewGet(":dbId/t-index", d.TableIndex),
		req.NewGet(":dbId/c-metadata", d.ColumnMA),
		req.NewGet(":dbId/hint-tables", d.HintTables),
		req.NewPost(":dbId/copy-table", d.CopyTable),
	}
	return req.NewConfs("/dbs", routes...)
}
// Dbs returns one page of databases visible to the logged-in account, with
// each row enriched by the code, type, host and port of its instance.
//
// @router /api/dbs [get]
func (d *Db) Dbs(rc *req.Ctx) {
	cond := req.BindQuery[entity.DbQuery](rc)

	// Without any accessible tag there is no data the account may operate on.
	accessibleTags := d.tagApp.GetAccountTags(rc.GetLoginAccount().Id, &tagentity.TagTreeQuery{
		TypePaths:     collx.AsArray(tagentity.NewTypePaths(tagentity.TagTypeDbInstance, tagentity.TagTypeAuthCert, tagentity.TagTypeDb)),
		CodePathLikes: collx.AsArray(cond.TagPath),
	})
	if len(accessibleTags) == 0 {
		rc.ResData = model.NewEmptyPageResult[any]()
		return
	}
	cond.Codes = accessibleTags.GetCodes()

	page, err := d.dbApp.GetPageList(cond)
	biz.ErrIsNil(err)
	pageVo := model.PageResultConv[*entity.DbListPO, *vo.DbListVO](page)

	// Load the instances referenced by this page in one call. The lookup error
	// is ignored: rows without a matching instance keep empty instance fields.
	instanceIds := collx.ArrayMap(pageVo.List, func(item *vo.DbListVO) uint64 {
		return item.InstanceId
	})
	instances, _ := d.instanceApp.GetByIds(instanceIds)
	instanceById := collx.ArrayToMap(instances, func(inst *entity.DbInstance) uint64 {
		return inst.Id
	})

	for _, item := range pageVo.List {
		inst := instanceById[item.InstanceId]
		if inst == nil {
			continue
		}
		item.InstanceCode = inst.Code
		item.InstanceType = inst.Type
		item.Host = inst.Host
		item.Port = inst.Port
	}
	rc.ResData = pageVo
}
// Save binds the JSON body to a DbForm, copies it into an entity.Db and hands
// the entity to dbApp.SaveDb. The bound form is recorded as the request
// parameter.
func (d *Db) Save(rc *req.Ctx) {
	dbForm, dbEntity := req.BindJsonAndCopyTo[form.DbForm, entity.Db](rc)
	rc.ReqParam = dbForm

	err := d.dbApp.SaveDb(rc.MetaCtx, dbEntity)
	biz.ErrIsNil(err)
}
// DeleteDb deletes one or more databases. The "dbId" path parameter may hold
// several ids separated by commas; the raw value is recorded as the request
// parameter.
//
// All ids are parsed and validated before anything is deleted, so a malformed
// id (for example "1,abc" or a trailing comma) rejects the whole request
// instead of being silently converted to 0 after earlier ids were already
// removed.
func (d *Db) DeleteDb(rc *req.Ctx) {
	idsStr := rc.PathParam("dbId")
	rc.ReqParam = idsStr

	rawIds := strings.Split(idsStr, ",")
	ids := make([]uint64, 0, len(rawIds))
	for _, raw := range rawIds {
		id, err := cast.ToUint64E(raw)
		biz.ErrIsNilAppendErr(err, "dbId error: %s")
		biz.IsTrue(id > 0, "dbId error")
		ids = append(ids, id)
	}

	ctx := rc.MetaCtx
	for _, id := range ids {
		biz.ErrIsNil(d.dbApp.Delete(ctx, id))
	}
}
/** Database operations: executing SQL, etc. ***/

// ExecSql executes the SQL carried in the JSON body against the database
// identified by the "dbId" path parameter and the form's Db name.
//
// The submitted SQL is decrypted with utils.AesDecryptByLa using the login
// account. Execution runs under a context whose timeout is
// config.GetDbms().SqlExecTl seconds. The caller must have access to the
// connection's code paths. The execution result is placed in rc.ResData.
func (d *Db) ExecSql(rc *req.Ctx) {
	execForm := req.BindJson[form.DbSqlExecForm](rc)
	// Reject an empty payload up front, before a connection is acquired or the
	// payload is decrypted.
	biz.NotEmpty(execForm.Sql, "sql cannot be empty")

	ctx, cancel := context.WithTimeout(rc.MetaCtx, time.Duration(config.GetDbms().SqlExecTl)*time.Second)
	defer cancel()

	dbId := getDbId(rc)
	dbConn, err := d.dbApp.GetDbConn(ctx, dbId, execForm.Db)
	biz.ErrIsNil(err)
	biz.ErrIsNilAppendErr(d.tagApp.CanAccess(rc.GetLoginAccount().Id, dbConn.Info.CodePath...), "%s")

	// Publish a resource-operation event for the connection's first code path.
	global.EventBus.Publish(rc.MetaCtx, event.EventTopicResourceOp, dbConn.Info.CodePath[0])

	sqlStr, err := utils.AesDecryptByLa(execForm.Sql, rc.GetLoginAccount())
	biz.ErrIsNilAppendErr(err, "sql decoding failure: %s")

	// Record the decrypted SQL together with the connection description.
	rc.ReqParam = fmt.Sprintf("%s %s\n-> %s", dbConn.Info.GetLogDesc(), execForm.ExecId, sqlStr)

	execReq := &dto.DbSqlExecReq{
		DbId:      dbId,
		Db:        execForm.Db,
		Remark:    execForm.Remark,
		DbConn:    dbConn,
		Sql:       sqlStr,
		CheckFlow: true,
	}
	execRes, err := d.dbSqlExecApp.Exec(ctx, execReq)
	biz.ErrIsNil(err)
	rc.ResData = execRes
}
// ExecSqlFile executes an uploaded SQL file against the database selected by
// the "dbId" path parameter and the "db" query parameter. The request body is
// the file content; when the "filename" query parameter (default
// "sql_file.sql") ends in .zip, the first file inside the archive is used.
func (d *Db) ExecSqlFile(rc *req.Ctx) {
	dbId := getDbId(rc)
	var (
		clientId = rc.Query("clientId")
		database = rc.Query("db")
		uploadId = rc.Query("uploadId")
		filename = rc.QueryDefault("filename", "sql_file.sql")
	)

	conn, err := d.dbApp.GetDbConn(rc.MetaCtx, dbId, database)
	biz.ErrIsNil(err)
	accessErr := d.tagApp.CanAccess(rc.GetLoginAccount().Id, conn.Info.CodePath...)
	biz.ErrIsNilAppendErr(accessErr, "%s")

	rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, conn.Info.GetLogDesc())

	reqBody := rc.GetRequest().Body
	defer reqBody.Close()

	// .zip uploads are unpacked; anything else is read as-is.
	sqlReader, err := d.getSqlReader(reqBody, filename)
	biz.ErrIsNilAppendErr(err, "failed to read sql file: %s")

	execErr := d.dbSqlExecApp.ExecReader(rc.MetaCtx, &dto.SqlReaderExec{
		Reader:   sqlReader,
		Filename: filename,
		DbConn:   conn,
		ClientId: clientId,
		UploadId: uploadId,
	})
	biz.ErrIsNil(execErr)
}
// getSqlReader returns a reader over the SQL text of an uploaded file.
//
// When filename does not end in ".zip" (case-insensitive) body is returned
// unchanged. Otherwise body is treated as a zip archive and the content of its
// first non-directory entry is returned.
//
// Both sizes are bounded because the upload is untrusted input:
//   - the archive itself may be at most maxZipSize bytes; a larger upload is
//     rejected explicitly instead of being truncated and then misreported as
//     an invalid zip file;
//   - the decompressed entry may be at most maxSqlSize bytes, so a small,
//     highly compressed archive cannot exhaust memory.
//
// An error is returned when the archive is too large, unreadable, contains no
// file, or its first file is too large or cannot be read.
func (d *Db) getSqlReader(body io.Reader, filename string) (io.Reader, error) {
	const (
		maxZipSize = 10 * 1024 * 1024  // limit for the compressed upload
		maxSqlSize = 512 * 1024 * 1024 // limit for the decompressed SQL text
	)

	if !strings.HasSuffix(strings.ToLower(filename), ".zip") {
		return body, nil
	}

	// Read one byte past the limit so an oversized upload can be told apart
	// from one that is exactly at the limit.
	data, err := io.ReadAll(io.LimitReader(body, maxZipSize+1))
	if err != nil {
		return nil, fmt.Errorf("read zip file error: %w", err)
	}
	if len(data) > maxZipSize {
		return nil, fmt.Errorf("zip file exceeds %d bytes", maxZipSize)
	}

	zr, err := zip.NewReader(bytes.NewReader(data), int64(len(data)))
	if err != nil {
		return nil, fmt.Errorf("invalid zip file: %w", err)
	}

	for _, f := range zr.File {
		if f.FileInfo().IsDir() {
			continue
		}
		// Cheap early rejection based on the declared size; the LimitReader
		// below enforces the bound on what is actually decompressed.
		if f.UncompressedSize64 > maxSqlSize {
			return nil, fmt.Errorf("zip entry %q exceeds %d bytes when decompressed", f.Name, maxSqlSize)
		}
		entry, err := f.Open()
		if err != nil {
			return nil, fmt.Errorf("open zip entry error: %w", err)
		}
		content, err := io.ReadAll(io.LimitReader(entry, maxSqlSize+1))
		entry.Close()
		if err != nil {
			return nil, fmt.Errorf("read zip entry error: %w", err)
		}
		if len(content) > maxSqlSize {
			return nil, fmt.Errorf("zip entry %q exceeds %d bytes when decompressed", f.Name, maxSqlSize)
		}
		return bytes.NewReader(content), nil
	}
	return nil, fmt.Errorf("zip file is empty")
}
// DumpSql streams a dump of the selected database as a download.
//
// Query parameters:
//   - db:      database name
//   - type:    "1" dumps table structure, "2" dumps data, "3" dumps both
//   - tables:  optional comma-separated list of tables
//   - extName: any of ".gz", ".gzip", "gz", "gzip" yields a ".gz" file
//
// The dump is always written through a gzip writer. When a ".gz" file was not
// requested, the response carries "Content-Encoding: gzip" instead.
//
// A failure while dumping is recovered: the error text is written to the
// response and a dump-failure message event is published to the login
// account. Deferred calls run last-in-first-out, so the gzip writer is closed
// before the recovery handler runs.
func (d *Db) DumpSql(rc *req.Ctx) {
	dbId := getDbId(rc)
	database := rc.Query("db")
	dumpType := rc.Query("type")
	tableList := rc.Query("tables")

	// Normalise the requested extension: gzip spellings become ".gz", anything else "".
	ext := ""
	switch rc.Query("extName") {
	case ".gz", ".gzip", "gz", "gzip":
		ext = ".gz"
	}

	// Whether table structure and/or data must be exported.
	var withDDL, withData bool
	switch dumpType {
	case "1":
		withDDL = true
	case "2":
		withData = true
	case "3":
		withDDL, withData = true, true
	}

	account := rc.GetLoginAccount()
	conn, err := d.dbApp.GetDbConn(rc.MetaCtx, dbId, database)
	biz.ErrIsNil(err)
	biz.ErrIsNilAppendErr(d.tagApp.CanAccess(account.Id, conn.Info.CodePath...), "%s")

	stamp := time.Now().Format("20060102150405")
	filename := fmt.Sprintf("%s-%s.%s.sql%s", conn.Info.Name, database, stamp, ext)
	rc.Header("Content-Type", "application/octet-stream")
	rc.Header("Content-Disposition", "attachment; filename="+filename)
	if ext == "" {
		rc.Header("Content-Encoding", "gzip")
	}

	var tables []string
	if tableList != "" {
		tables = strings.Split(tableList, ",")
	}

	defer func() {
		msg := anyx.ToString(recover())
		if msg == "" {
			return
		}
		msg = "DB dump error: " + msg
		rc.GetWriter().Write([]byte(msg))
		global.EventBus.Publish(rc.MetaCtx, event.EventTopicMsgTmplSend, &msgdto.MsgTmplSendEvent{
			TmplChannel: msgdto.MsgTmplDbDumpFail,
			Params:      collx.M{"dbId": conn.Info.Id, "dbName": conn.Info.Name, "error": msg},
			ReceiverIds: []uint64{account.Id},
		})
	}()

	gz := writerx.NewGzipWriter(rc.GetWriter())
	defer gz.Close()

	dumpErr := d.dbApp.DumpDb(rc.MetaCtx, &dto.DumpDb{
		DbId:     dbId,
		DbName:   database,
		Tables:   tables,
		DumpDDL:  withDDL,
		DumpData: withData,
		Writer:   gz,
	})
	biz.ErrIsNil(dumpErr)

	rc.ReqParam = collx.Kvs("db", conn.Info, "database", database, "tables", tableList, "dumpType", dumpType)
}
// TableInfos responds with the tables reported by the metadata of the
// connection selected by the request.
func (d *Db) TableInfos(rc *req.Ctx) {
	metadata := d.getDbConn(rc).GetMetadata()
	tables, err := metadata.GetTables()
	biz.ErrIsNilAppendErr(err, "get table error: %s")
	rc.ResData = tables
}
// TableIndex responds with the index information of the table named by the
// required "tableName" query parameter.
func (d *Db) TableIndex(rc *req.Ctx) {
	tableName := rc.Query("tableName")
	biz.NotEmpty(tableName, "tableName cannot be empty")

	metadata := d.getDbConn(rc).GetMetadata()
	indexes, err := metadata.GetTableIndex(tableName)
	biz.ErrIsNilAppendErr(err, "get table index error: %s")
	rc.ResData = indexes
}
// ColumnMA responds with the column metadata of the table named by the
// required "tableName" query parameter.
//
// @router /api/dbs/:dbId/c-metadata [get]
func (d *Db) ColumnMA(rc *req.Ctx) {
	tableName := rc.Query("tableName")
	biz.NotEmpty(tableName, "tableName cannot be empty")

	conn := d.getDbConn(rc)
	columns, err := conn.GetMetadata().GetColumns(tableName)
	biz.ErrIsNilAppendErr(err, "get column metadata error: %s")
	rc.ResData = columns
}
// HintTables responds with a map from table name to a list of column
// descriptions. Each description has the form "name [type]", followed by
// "[comment]" when the column has a comment. An empty map is returned when
// the database has no tables.
//
// @router /api/dbs/:dbId/hint-tables [get]
func (d *Db) HintTables(rc *req.Ctx) {
	metadata := d.getDbConn(rc).GetMetadata()

	// Fetch every table.
	tables, err := metadata.GetTables()
	biz.ErrIsNil(err)

	// key = table name, value = column descriptions
	hints := make(map[string][]string)
	if len(tables) == 0 {
		// Nothing to describe.
		rc.ResData = hints
		return
	}

	names := make([]string, 0, len(tables))
	for _, tbl := range tables {
		names = append(names, tbl.TableName)
	}

	// Fetch the columns of all tables in one call.
	columns, err := metadata.GetColumns(names...)
	biz.ErrIsNil(err)

	for _, col := range columns {
		desc := fmt.Sprintf("%s [%s]", col.ColumnName, col.GetColumnType())
		// Append the column comment when there is one.
		if col.ColumnComment != "" {
			desc = fmt.Sprintf("%s[%s]", desc, col.ColumnComment)
		}
		hints[col.TableName] = append(hints[col.TableName], desc)
	}
	rc.ResData = hints
}
// GetTableDDL responds with the DDL of the table named by the required
// "tableName" query parameter.
func (d *Db) GetTableDDL(rc *req.Ctx) {
	tableName := rc.Query("tableName")
	biz.NotEmpty(tableName, "tableName cannot be empty")

	metadata := d.getDbConn(rc).GetMetadata()
	ddl, err := metadata.GetTableDDL(tableName, false)
	biz.ErrIsNilAppendErr(err, "get table DDL error: %s")
	rc.ResData = ddl
}
// GetVersion responds with the compatible database version reported by the
// connection's metadata.
func (d *Db) GetVersion(rc *req.Ctx) {
	metadata := d.getDbConn(rc).GetMetadata()
	rc.ResData = metadata.GetCompatibleDbVersion()
}
// GetSchemas responds with the schemas reported by the connection's metadata.
func (d *Db) GetSchemas(rc *req.Ctx) {
	metadata := d.getDbConn(rc).GetMetadata()
	schemas, err := metadata.GetSchemas()
	biz.ErrIsNilAppendErr(err, "get schemas error: %s")
	rc.ResData = schemas
}
// CopyTable binds the JSON body to a DbCopyTableForm, copies it into a
// dbi.DbCopyTable and asks the connection's dialect to copy the table. The
// connection is selected by the form's Id and Db fields. A copy failure is
// logged and then reported to the caller.
func (d *Db) CopyTable(rc *req.Ctx) {
	copyForm, copyReq := req.BindJsonAndCopyTo[form.DbCopyTableForm, dbi.DbCopyTable](rc)

	conn, err := d.dbApp.GetDbConn(rc.MetaCtx, copyForm.Id, copyForm.Db)
	biz.ErrIsNilAppendErr(err, "copy table error: %s")

	if err = conn.GetDialect().CopyTable(copyReq); err != nil {
		logx.Errorf("copy table error: %s", err.Error())
	}
	biz.ErrIsNilAppendErr(err, "copy table error: %s")
}
// getDbId reads the "dbId" path parameter as an integer and returns it as a
// uint64. The value must be greater than zero, otherwise biz.IsTrue fails with
// "dbId error".
func getDbId(rc *req.Ctx) uint64 {
	id := rc.PathParamInt("dbId")
	positive := id > 0
	biz.IsTrue(positive, "dbId error")
	return uint64(id)
}
// getDbName returns the required "db" query parameter; biz.NotEmpty rejects an
// empty value.
func getDbName(rc *req.Ctx) string {
	name := rc.Query("db")
	biz.NotEmpty(name, "db cannot be empty")
	return name
}
// getDbConn resolves the connection for the request's "dbId" path parameter
// and "db" query parameter. A lookup error is passed to biz.ErrIsNil.
func (d *Db) getDbConn(rc *req.Ctx) *dbi.DbConn {
	dbId := getDbId(rc)
	dbName := getDbName(rc)

	conn, err := d.dbApp.GetDbConn(rc.MetaCtx, dbId, dbName)
	biz.ErrIsNil(err)
	return conn
}