mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 23:40:24 +08:00
refactor: db/redis/mongo连接代码包独立
This commit is contained in:
@@ -11,7 +11,7 @@
|
||||
"dependencies": {
|
||||
"@element-plus/icons-vue": "^2.1.0",
|
||||
"asciinema-player": "^3.6.2",
|
||||
"axios": "^1.5.1",
|
||||
"axios": "^1.6.0",
|
||||
"countup.js": "^2.7.0",
|
||||
"cropperjs": "^1.5.11",
|
||||
"echarts": "^5.4.3",
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
:disabled="form.id !== undefined"
|
||||
remote
|
||||
:remote-method="getInstances"
|
||||
@change="getAllDatabase"
|
||||
@change="changeInstance"
|
||||
v-model="form.instanceId"
|
||||
placeholder="请输入实例名称搜索并选择实例"
|
||||
filterable
|
||||
@@ -163,6 +163,11 @@ watch(props, (newValue: any) => {
|
||||
}
|
||||
});
|
||||
|
||||
const changeInstance = () => {
|
||||
state.databaseList = [];
|
||||
getAllDatabase();
|
||||
};
|
||||
|
||||
/**
|
||||
* 改变表单中的数据库字段,方便表单错误提示。如全部删光,可提示请添加数据库
|
||||
*/
|
||||
@@ -171,8 +176,6 @@ const changeDatabase = () => {
|
||||
};
|
||||
|
||||
const getAllDatabase = async () => {
|
||||
// 清空数据库列表,可能已经有选择库了
|
||||
state.databaseList = [];
|
||||
if (state.form.instanceId > 0) {
|
||||
state.allDatabases = await dbApi.getAllDatabase.request({ instanceId: state.form.instanceId });
|
||||
}
|
||||
|
||||
@@ -628,10 +628,10 @@ asynckit@^0.4.0:
|
||||
resolved "https://registry.npmmirror.com/asynckit/-/asynckit-0.4.0.tgz"
|
||||
integrity sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==
|
||||
|
||||
axios@^1.5.1:
|
||||
version "1.5.1"
|
||||
resolved "https://registry.npmmirror.com/axios/-/axios-1.5.1.tgz#11fbaa11fc35f431193a9564109c88c1f27b585f"
|
||||
integrity sha512-Q28iYCWzNHjAm+yEAot5QaAMxhMghWLFVf7rRdwhUI+c2jix2DUXjAHXVi+s1ibs3mjPO/cCgbA++3BjD0vP/A==
|
||||
axios@^1.6.0:
|
||||
version "1.6.0"
|
||||
resolved "https://registry.npmmirror.com/axios/-/axios-1.6.0.tgz#f1e5292f26b2fd5c2e66876adc5b06cdbd7d2102"
|
||||
integrity sha512-EZ1DYihju9pwVB+jg67ogm+Tmqc6JmhamRN6I4Zt8DfZu5lbcQGw3ozH9lFejSJgs/ibaef3A9PMXPLeefFGJg==
|
||||
dependencies:
|
||||
follow-redirects "^1.15.0"
|
||||
form-data "^4.0.0"
|
||||
|
||||
@@ -6,7 +6,7 @@ const (
|
||||
AdminId = 1
|
||||
|
||||
MachineConnExpireTime = 60 * time.Minute
|
||||
DbConnExpireTime = 45 * time.Minute
|
||||
DbConnExpireTime = 120 * time.Minute
|
||||
RedisConnExpireTime = 30 * time.Minute
|
||||
MongoConnExpireTime = 30 * time.Minute
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"mayfly-go/internal/db/api/form"
|
||||
"mayfly-go/internal/db/api/vo"
|
||||
"mayfly-go/internal/db/application"
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
msgapp "mayfly-go/internal/msg/application"
|
||||
msgdto "mayfly-go/internal/msg/application/dto"
|
||||
@@ -82,14 +83,14 @@ func (d *Db) DeleteDb(rc *req.Ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection {
|
||||
dc, err := d.DbApp.GetDbConnection(getDbId(g), getDbName(g))
|
||||
func (d *Db) getDbConn(g *gin.Context) *dbm.DbConn {
|
||||
dc, err := d.DbApp.GetDbConn(getDbId(g), getDbName(g))
|
||||
biz.ErrIsNil(err)
|
||||
return dc
|
||||
}
|
||||
|
||||
func (d *Db) TableInfos(rc *req.Ctx) {
|
||||
res, err := d.getDbConnection(rc.GinCtx).GetMeta().GetTableInfos()
|
||||
res, err := d.getDbConn(rc.GinCtx).GetMeta().GetTableInfos()
|
||||
biz.ErrIsNilAppendErr(err, "获取表信息失败: %s")
|
||||
rc.ResData = res
|
||||
}
|
||||
@@ -97,7 +98,7 @@ func (d *Db) TableInfos(rc *req.Ctx) {
|
||||
func (d *Db) TableIndex(rc *req.Ctx) {
|
||||
tn := rc.GinCtx.Query("tableName")
|
||||
biz.NotEmpty(tn, "tableName不能为空")
|
||||
res, err := d.getDbConnection(rc.GinCtx).GetMeta().GetTableIndex(tn)
|
||||
res, err := d.getDbConn(rc.GinCtx).GetMeta().GetTableIndex(tn)
|
||||
biz.ErrIsNilAppendErr(err, "获取表索引信息失败: %s")
|
||||
rc.ResData = res
|
||||
}
|
||||
@@ -105,7 +106,7 @@ func (d *Db) TableIndex(rc *req.Ctx) {
|
||||
func (d *Db) GetCreateTableDdl(rc *req.Ctx) {
|
||||
tn := rc.GinCtx.Query("tableName")
|
||||
biz.NotEmpty(tn, "tableName不能为空")
|
||||
res, err := d.getDbConnection(rc.GinCtx).GetMeta().GetCreateTableDdl(tn)
|
||||
res, err := d.getDbConn(rc.GinCtx).GetMeta().GetCreateTableDdl(tn)
|
||||
biz.ErrIsNilAppendErr(err, "获取表ddl失败: %s")
|
||||
rc.ResData = res
|
||||
}
|
||||
@@ -116,7 +117,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
|
||||
ginx.BindJsonAndValid(g, form)
|
||||
|
||||
dbId := getDbId(g)
|
||||
dbConn, err := d.DbApp.GetDbConnection(dbId, form.Db)
|
||||
dbConn, err := d.DbApp.GetDbConn(dbId, form.Db)
|
||||
biz.ErrIsNil(err)
|
||||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
|
||||
|
||||
@@ -188,7 +189,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
dbName := getDbName(g)
|
||||
clientId := g.Query("clientId")
|
||||
|
||||
dbConn, err := d.DbApp.GetDbConnection(dbId, dbName)
|
||||
dbConn, err := d.DbApp.GetDbConn(dbId, dbName)
|
||||
biz.ErrIsNil(err)
|
||||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
|
||||
rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc())
|
||||
@@ -256,7 +257,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
|
||||
if !ok {
|
||||
logx.Warnf("sql解析失败: %s", sql)
|
||||
}
|
||||
dbConn, err = d.DbApp.GetDbConnection(dbId, stmtUse.DBName.String())
|
||||
dbConn, err = d.DbApp.GetDbConn(dbId, stmtUse.DBName.String())
|
||||
biz.ErrIsNil(err)
|
||||
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
|
||||
execReq.DbConn = dbConn
|
||||
@@ -334,7 +335,7 @@ func (d *Db) DumpSql(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool) {
|
||||
dbConn, err := d.DbApp.GetDbConnection(dbId, dbName)
|
||||
dbConn, err := d.DbApp.GetDbConn(dbId, dbName)
|
||||
biz.ErrIsNil(err)
|
||||
writer.WriteString("\n-- ----------------------------")
|
||||
writer.WriteString("\n-- 导出平台: mayfly-go")
|
||||
@@ -397,7 +398,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
|
||||
|
||||
// @router /api/db/:dbId/t-metadata [get]
|
||||
func (d *Db) TableMA(rc *req.Ctx) {
|
||||
dbi := d.getDbConnection(rc.GinCtx)
|
||||
dbi := d.getDbConn(rc.GinCtx)
|
||||
res, err := dbi.GetMeta().GetTables()
|
||||
biz.ErrIsNilAppendErr(err, "获取表基础信息失败: %s")
|
||||
rc.ResData = res
|
||||
@@ -409,7 +410,7 @@ func (d *Db) ColumnMA(rc *req.Ctx) {
|
||||
tn := g.Query("tableName")
|
||||
biz.NotEmpty(tn, "tableName不能为空")
|
||||
|
||||
dbi := d.getDbConnection(rc.GinCtx)
|
||||
dbi := d.getDbConn(rc.GinCtx)
|
||||
res, err := dbi.GetMeta().GetColumns(tn)
|
||||
biz.ErrIsNilAppendErr(err, "获取数据库列信息失败: %s")
|
||||
rc.ResData = res
|
||||
@@ -417,7 +418,7 @@ func (d *Db) ColumnMA(rc *req.Ctx) {
|
||||
|
||||
// @router /api/db/:dbId/hint-tables [get]
|
||||
func (d *Db) HintTables(rc *req.Ctx) {
|
||||
dbi := d.getDbConnection(rc.GinCtx)
|
||||
dbi := d.getDbConn(rc.GinCtx)
|
||||
|
||||
dm := dbi.GetMeta()
|
||||
// 获取所有表
|
||||
|
||||
@@ -33,7 +33,7 @@ func (d *Instance) Instances(rc *req.Ctx) {
|
||||
// @router /api/instances [post]
|
||||
func (d *Instance) SaveInstance(rc *req.Ctx) {
|
||||
form := &form.InstanceForm{}
|
||||
instance := ginx.BindJsonAndCopyTo[*entity.Instance](rc.GinCtx, form, new(entity.Instance))
|
||||
instance := ginx.BindJsonAndCopyTo[*entity.DbInstance](rc.GinCtx, form, new(entity.DbInstance))
|
||||
|
||||
// 密码解密,并使用解密后的赋值
|
||||
originPwd, err := cryptox.DefaultRsaDecrypt(form.Password, true)
|
||||
@@ -52,7 +52,7 @@ func (d *Instance) SaveInstance(rc *req.Ctx) {
|
||||
// @router /api/instances/:instance [GET]
|
||||
func (d *Instance) GetInstance(rc *req.Ctx) {
|
||||
dbId := getInstanceId(rc.GinCtx)
|
||||
dbEntity, err := d.InstanceApp.GetById(new(entity.Instance), dbId)
|
||||
dbEntity, err := d.InstanceApp.GetById(new(entity.DbInstance), dbId)
|
||||
biz.ErrIsNil(err, "获取数据库实例错误")
|
||||
dbEntity.Password = ""
|
||||
rc.ResData = dbEntity
|
||||
@@ -62,7 +62,7 @@ func (d *Instance) GetInstance(rc *req.Ctx) {
|
||||
// @router /api/instances/:instance/pwd [GET]
|
||||
func (d *Instance) GetInstancePwd(rc *req.Ctx) {
|
||||
instanceId := getInstanceId(rc.GinCtx)
|
||||
instanceEntity, err := d.InstanceApp.GetById(new(entity.Instance), instanceId, "Password")
|
||||
instanceEntity, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "Password")
|
||||
biz.ErrIsNil(err, "获取数据库实例错误")
|
||||
instanceEntity.PwdDecrypt()
|
||||
rc.ResData = instanceEntity.Password
|
||||
@@ -80,7 +80,7 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) {
|
||||
biz.ErrIsNilAppendErr(err, "string类型转换为int异常: %s")
|
||||
instanceId := uint64(value)
|
||||
if d.DbApp.Count(&entity.DbQuery{InstanceId: instanceId}) != 0 {
|
||||
instance, err := d.InstanceApp.GetById(new(entity.Instance), instanceId, "name")
|
||||
instance, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "name")
|
||||
biz.ErrIsNil(err, "获取数据库实例错误,数据库实例ID为: %d", instance.Id)
|
||||
biz.IsTrue(false, "不能删除数据库实例【%s】,请先删除其关联的数据库资源。", instance.Name)
|
||||
}
|
||||
@@ -91,7 +91,7 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) {
|
||||
// 获取数据库实例的所有数据库名
|
||||
func (d *Instance) GetDatabaseNames(rc *req.Ctx) {
|
||||
instanceId := getInstanceId(rc.GinCtx)
|
||||
instance, err := d.InstanceApp.GetById(new(entity.Instance), instanceId, "Password")
|
||||
instance, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "Password")
|
||||
biz.ErrIsNil(err, "获取数据库实例错误")
|
||||
instance.PwdDecrypt()
|
||||
res, err := d.InstanceApp.GetDatabases(instance)
|
||||
|
||||
@@ -1,23 +1,15 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/common/consts"
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
"mayfly-go/internal/machine/infrastructure/machine"
|
||||
"mayfly-go/pkg/base"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"mayfly-go/pkg/utils/structx"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Db interface {
|
||||
@@ -34,9 +26,9 @@ type Db interface {
|
||||
Delete(id uint64) error
|
||||
|
||||
// 获取数据库连接实例
|
||||
// @param id 数据库实例id
|
||||
// @param db 数据库
|
||||
GetDbConnection(dbId uint64, dbName string) (*DbConnection, error)
|
||||
// @param id 数据库id
|
||||
// @param dbName 数据库
|
||||
GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error)
|
||||
}
|
||||
|
||||
func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql, dbInstanceApp Instance) Db {
|
||||
@@ -96,7 +88,7 @@ func (d *dbAppImpl) Save(dbEntity *entity.Db) error {
|
||||
|
||||
for _, v := range delDb {
|
||||
// 关闭数据库连接
|
||||
CloseDb(dbEntity.Id, v)
|
||||
dbm.CloseDb(dbEntity.Id, v)
|
||||
// 删除该库关联的所有sql记录
|
||||
d.dbSqlRepo.DeleteByCond(&entity.DbSql{DbId: dbId, Db: v})
|
||||
}
|
||||
@@ -112,316 +104,39 @@ func (d *dbAppImpl) Delete(id uint64) error {
|
||||
dbs := strings.Split(db.Database, " ")
|
||||
for _, v := range dbs {
|
||||
// 关闭连接
|
||||
CloseDb(id, v)
|
||||
dbm.CloseDb(id, v)
|
||||
}
|
||||
// 删除该库下用户保存的所有sql信息
|
||||
d.dbSqlRepo.DeleteByCond(&entity.DbSql{DbId: id})
|
||||
return d.DeleteById(id)
|
||||
}
|
||||
|
||||
var mutex sync.Mutex
|
||||
|
||||
func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) (*DbConnection, error) {
|
||||
cacheKey := GetDbCacheKey(dbId, dbName)
|
||||
|
||||
// Id不为0,则为需要缓存
|
||||
needCache := dbId != 0
|
||||
if needCache {
|
||||
load, ok := dbCache.Get(cacheKey)
|
||||
if ok {
|
||||
return load.(*DbConnection), nil
|
||||
func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
|
||||
return dbm.GetDbConn(dbId, dbName, func() (*dbm.DbInfo, error) {
|
||||
db, err := d.GetById(new(entity.Db), dbId)
|
||||
if err != nil {
|
||||
return nil, errorx.NewBiz("数据库信息不存在")
|
||||
}
|
||||
}
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
db, err := d.GetById(new(entity.Db), dbId)
|
||||
if err != nil {
|
||||
return nil, errorx.NewBiz("数据库信息不存在")
|
||||
}
|
||||
if !strings.Contains(" "+db.Database+" ", " "+dbName+" ") {
|
||||
return nil, errorx.NewBiz("未配置数据库【%s】的操作权限", dbName)
|
||||
}
|
||||
|
||||
instance, err := d.dbInstanceApp.GetById(new(entity.Instance), db.InstanceId)
|
||||
if err != nil {
|
||||
return nil, errorx.NewBiz("数据库实例不存在")
|
||||
}
|
||||
// 密码解密
|
||||
instance.PwdDecrypt()
|
||||
|
||||
dbInfo := NewDbInfo(db, instance)
|
||||
dbInfo.Database = dbName
|
||||
dbi := &DbConnection{Id: cacheKey, Info: dbInfo}
|
||||
|
||||
conn, err := getInstanceConn(instance, dbName)
|
||||
if err != nil {
|
||||
dbi.Close()
|
||||
logx.Errorf("连接db失败: %s:%d/%s", dbInfo.Host, dbInfo.Port, dbName)
|
||||
return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
|
||||
}
|
||||
|
||||
// 最大连接周期,超过时间的连接就close
|
||||
// conn.SetConnMaxLifetime(100 * time.Second)
|
||||
// 设置最大连接数
|
||||
conn.SetMaxOpenConns(5)
|
||||
// 设置闲置连接数
|
||||
conn.SetMaxIdleConns(1)
|
||||
|
||||
dbi.db = conn
|
||||
logx.Infof("连接db: %s:%d/%s", dbInfo.Host, dbInfo.Port, dbName)
|
||||
if needCache {
|
||||
dbCache.Put(cacheKey, dbi)
|
||||
}
|
||||
return dbi, nil
|
||||
}
|
||||
|
||||
//---------------------------------------- db instance ------------------------------------
|
||||
|
||||
type DbInfo struct {
|
||||
Id uint64
|
||||
Name string
|
||||
Type entity.DbType // 类型,mysql oracle等
|
||||
Host string
|
||||
Port int
|
||||
Network string
|
||||
Username string
|
||||
TagPath string
|
||||
Database string
|
||||
SshTunnelMachineId int
|
||||
}
|
||||
|
||||
func NewDbInfo(db *entity.Db, instance *entity.Instance) *DbInfo {
|
||||
return &DbInfo{
|
||||
Id: db.Id,
|
||||
Name: db.Name,
|
||||
Type: instance.Type,
|
||||
Host: instance.Host,
|
||||
Port: instance.Port,
|
||||
Username: instance.Username,
|
||||
TagPath: db.TagPath,
|
||||
SshTunnelMachineId: instance.SshTunnelMachineId,
|
||||
}
|
||||
}
|
||||
|
||||
// 获取记录日志的描述
|
||||
func (d *DbInfo) GetLogDesc() string {
|
||||
return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.TagPath, d.Name, d.Host, d.Port, d.Database)
|
||||
}
|
||||
|
||||
// db实例连接信息
|
||||
type DbConnection struct {
|
||||
Id string
|
||||
Info *DbInfo
|
||||
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// 执行查询语句
|
||||
// 依次返回 列名数组,结果map,错误
|
||||
func (d *DbConnection) SelectData(execSql string) ([]string, []map[string]any, error) {
|
||||
return selectDataByDb(d.db, execSql)
|
||||
}
|
||||
|
||||
// 将查询结果映射至struct,可具体参考sqlx库
|
||||
func (d *DbConnection) SelectData2Struct(execSql string, dest any) error {
|
||||
return select2StructByDb(d.db, execSql, dest)
|
||||
}
|
||||
|
||||
// WalkTableRecord 遍历表记录
|
||||
func (d *DbConnection) WalkTableRecord(selectSql string, walk func(record map[string]any, columns []string)) error {
|
||||
return walkTableRecord(d.db, selectSql, walk)
|
||||
}
|
||||
|
||||
// 执行 update, insert, delete,建表等sql
|
||||
// 返回影响条数和错误
|
||||
func (d *DbConnection) Exec(sql string) (int64, error) {
|
||||
res, err := d.db.Exec(sql)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
// 获取数据库元信息实现接口
|
||||
func (d *DbConnection) GetMeta() DbMetadata {
|
||||
switch d.Info.Type {
|
||||
case entity.DbTypeMysql:
|
||||
return &MysqlMetadata{di: d}
|
||||
case entity.DbTypePostgres:
|
||||
return &PgsqlMetadata{di: d}
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", d.Info.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭连接
|
||||
func (d *DbConnection) Close() {
|
||||
if d.db != nil {
|
||||
if err := d.db.Close(); err != nil {
|
||||
logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
|
||||
if !strings.Contains(" "+db.Database+" ", " "+dbName+" ") {
|
||||
return nil, errorx.NewBiz("未配置数据库【%s】的操作权限", dbName)
|
||||
}
|
||||
d.db = nil
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// 客户端连接缓存,指定时间内没有访问则会被关闭, key为数据库实例id:数据库
|
||||
var dbCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second).
|
||||
WithUpdateAccessTime(true).
|
||||
OnEvicted(func(key any, value any) {
|
||||
logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key))
|
||||
value.(*DbConnection).Close()
|
||||
})
|
||||
|
||||
func init() {
|
||||
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
|
||||
// 遍历所有db连接实例,若存在db实例使用该ssh隧道机器,则返回true,表示还在使用中...
|
||||
items := dbCache.Items()
|
||||
for _, v := range items {
|
||||
if v.Value.(*DbConnection).Info.SshTunnelMachineId == machineId {
|
||||
return true
|
||||
}
|
||||
instance, err := d.dbInstanceApp.GetById(new(entity.DbInstance), db.InstanceId)
|
||||
if err != nil {
|
||||
return nil, errorx.NewBiz("数据库实例不存在")
|
||||
}
|
||||
return false
|
||||
// 密码解密
|
||||
instance.PwdDecrypt()
|
||||
return toDbInfo(instance, dbId, dbName, db.TagPath), nil
|
||||
})
|
||||
}
|
||||
|
||||
func GetDbCacheKey(dbId uint64, db string) string {
|
||||
return fmt.Sprintf("%d:%s", dbId, db)
|
||||
}
|
||||
|
||||
func selectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]any, error) {
|
||||
// 列名用于前端表头名称按照数据库与查询字段顺序显示
|
||||
var colNames []string
|
||||
result := make([]map[string]any, 0, 16)
|
||||
err := walkTableRecord(db, selectSql, func(record map[string]any, columns []string) {
|
||||
result = append(result, record)
|
||||
if colNames == nil {
|
||||
colNames = make([]string, len(columns))
|
||||
copy(colNames, columns)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return colNames, result, nil
|
||||
}
|
||||
|
||||
func walkTableRecord(db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) error {
|
||||
rows, err := db.Query(selectSql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数,
|
||||
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
|
||||
defer func() {
|
||||
if rows != nil {
|
||||
rows.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
colTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lenCols := len(colTypes)
|
||||
// 列名用于前端表头名称按照数据库与查询字段顺序显示
|
||||
colNames := make([]string, lenCols)
|
||||
// 这里表示一行填充数据
|
||||
scans := make([]any, lenCols)
|
||||
// 这里表示一行所有列的值,用[]byte表示
|
||||
values := make([][]byte, lenCols)
|
||||
for k, colType := range colTypes {
|
||||
colNames[k] = colType.Name()
|
||||
// 这里scans引用values,把数据填充到[]byte里
|
||||
scans[k] = &values[k]
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
// 不Scan也会导致等待,该链接实际处于未工作的状态,然后也会导致连接数迅速达到最大
|
||||
if err := rows.Scan(scans...); err != nil {
|
||||
return err
|
||||
}
|
||||
// 每行数据
|
||||
rowData := make(map[string]any, lenCols)
|
||||
// 把values中的数据复制到row中
|
||||
for i, v := range values {
|
||||
rowData[colTypes[i].Name()] = valueConvert(v, colTypes[i])
|
||||
}
|
||||
walk(rowData, colNames)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 将查询的值转为对应列类型的实际值,不全部转为字符串
|
||||
func valueConvert(data []byte, colType *sql.ColumnType) any {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
// 列的数据库类型名
|
||||
colDatabaseTypeName := strings.ToLower(colType.DatabaseTypeName())
|
||||
|
||||
// 如果类型是bit,则直接返回第一个字节即可
|
||||
if strings.Contains(colDatabaseTypeName, "bit") {
|
||||
return data[0]
|
||||
}
|
||||
|
||||
// 这里把[]byte数据转成string
|
||||
stringV := string(data)
|
||||
if stringV == "" {
|
||||
return ""
|
||||
}
|
||||
colScanType := strings.ToLower(colType.ScanType().Name())
|
||||
|
||||
if strings.Contains(colScanType, "int") {
|
||||
// 如果长度超过16位,则返回字符串,因为前端js长度大于16会丢失精度
|
||||
if len(stringV) > 16 {
|
||||
return stringV
|
||||
}
|
||||
intV, _ := strconv.Atoi(stringV)
|
||||
switch colType.ScanType().Kind() {
|
||||
case reflect.Int8:
|
||||
return int8(intV)
|
||||
case reflect.Uint8:
|
||||
return uint8(intV)
|
||||
case reflect.Int64:
|
||||
return int64(intV)
|
||||
case reflect.Uint64:
|
||||
return uint64(intV)
|
||||
case reflect.Uint:
|
||||
return uint(intV)
|
||||
default:
|
||||
return intV
|
||||
}
|
||||
}
|
||||
if strings.Contains(colScanType, "float") || strings.Contains(colDatabaseTypeName, "decimal") {
|
||||
floatV, _ := strconv.ParseFloat(stringV, 64)
|
||||
return floatV
|
||||
}
|
||||
|
||||
return stringV
|
||||
}
|
||||
|
||||
// 查询数据结果映射至struct。可参考sqlx库
|
||||
func select2StructByDb(db *sql.DB, selectSql string, dest any) error {
|
||||
rows, err := db.Query(selectSql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数,
|
||||
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
|
||||
defer func() {
|
||||
if rows != nil {
|
||||
rows.Close()
|
||||
}
|
||||
}()
|
||||
return scanAll(rows, dest, false)
|
||||
}
|
||||
|
||||
// 删除db缓存并关闭该数据库所有连接
|
||||
func CloseDb(dbId uint64, db string) {
|
||||
dbCache.Delete(GetDbCacheKey(dbId, db))
|
||||
func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath string) *dbm.DbInfo {
|
||||
di := new(dbm.DbInfo)
|
||||
di.Id = dbId
|
||||
di.Database = database
|
||||
di.TagPath = tagPath
|
||||
|
||||
structx.Copy(di, instance)
|
||||
return di
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/config"
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
"mayfly-go/pkg/errorx"
|
||||
@@ -20,7 +21,7 @@ type DbSqlExecReq struct {
|
||||
Sql string
|
||||
Remark string
|
||||
LoginAccount *model.LoginAccount
|
||||
DbConn *DbConnection
|
||||
DbConn *dbm.DbConn
|
||||
}
|
||||
|
||||
type DbSqlExecRes struct {
|
||||
@@ -269,7 +270,7 @@ func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *ent
|
||||
return doExec(execSqlReq.Sql, execSqlReq.DbConn)
|
||||
}
|
||||
|
||||
func doExec(sql string, dbConn *DbConnection) (*DbSqlExecRes, error) {
|
||||
func doExec(sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
|
||||
rowsAffected, err := dbConn.Exec(sql)
|
||||
execRes := "success"
|
||||
if err != nil {
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
"mayfly-go/pkg/base"
|
||||
@@ -11,20 +9,20 @@ import (
|
||||
)
|
||||
|
||||
type Instance interface {
|
||||
base.App[*entity.Instance]
|
||||
base.App[*entity.DbInstance]
|
||||
|
||||
// GetPageList 分页获取数据库实例
|
||||
GetPageList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
|
||||
|
||||
Count(condition *entity.InstanceQuery) int64
|
||||
|
||||
Save(instanceEntity *entity.Instance) error
|
||||
Save(instanceEntity *entity.DbInstance) error
|
||||
|
||||
// Delete 删除数据库信息
|
||||
Delete(id uint64) error
|
||||
|
||||
// GetDatabases 获取数据库实例的所有数据库列表
|
||||
GetDatabases(entity *entity.Instance) ([]string, error)
|
||||
GetDatabases(entity *entity.DbInstance) ([]string, error)
|
||||
}
|
||||
|
||||
func newInstanceApp(instanceRepo repository.Instance) Instance {
|
||||
@@ -34,7 +32,7 @@ func newInstanceApp(instanceRepo repository.Instance) Instance {
|
||||
}
|
||||
|
||||
type instanceAppImpl struct {
|
||||
base.AppImpl[*entity.Instance, repository.Instance]
|
||||
base.AppImpl[*entity.DbInstance, repository.Instance]
|
||||
}
|
||||
|
||||
// GetPageList 分页获取数据库实例
|
||||
@@ -46,19 +44,21 @@ func (app *instanceAppImpl) Count(condition *entity.InstanceQuery) int64 {
|
||||
return app.CountByCond(condition)
|
||||
}
|
||||
|
||||
func (app *instanceAppImpl) Save(instanceEntity *entity.Instance) error {
|
||||
func (app *instanceAppImpl) Save(instanceEntity *entity.DbInstance) error {
|
||||
// 默认tcp连接
|
||||
instanceEntity.Network = instanceEntity.GetNetwork()
|
||||
|
||||
// 测试连接
|
||||
if instanceEntity.Password != "" {
|
||||
if err := testConnection(instanceEntity); err != nil {
|
||||
return errorx.NewBiz("数据库连接失败: %s", err.Error())
|
||||
dbConn, err := toDbInfo(instanceEntity, 0, "", "").Conn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer dbConn.Close()
|
||||
}
|
||||
|
||||
// 查找是否存在该库
|
||||
oldInstance := &entity.Instance{Host: instanceEntity.Host, Port: instanceEntity.Port, Username: instanceEntity.Username}
|
||||
oldInstance := &entity.DbInstance{Host: instanceEntity.Host, Port: instanceEntity.Port, Username: instanceEntity.Username}
|
||||
if instanceEntity.SshTunnelMachineId > 0 {
|
||||
oldInstance.SshTunnelMachineId = instanceEntity.SshTunnelMachineId
|
||||
}
|
||||
@@ -87,55 +87,19 @@ func (app *instanceAppImpl) Delete(id uint64) error {
|
||||
return app.DeleteById(id)
|
||||
}
|
||||
|
||||
// getInstanceConn 获取数据库连接数据库实例
|
||||
func getInstanceConn(instance *entity.Instance, db string) (*sql.DB, error) {
|
||||
var conn *sql.DB
|
||||
var err error
|
||||
switch instance.Type {
|
||||
case entity.DbTypeMysql:
|
||||
conn, err = getMysqlDB(instance, db)
|
||||
case entity.DbTypePostgres:
|
||||
conn, err = getPgsqlDB(instance, db)
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", instance.Type))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = conn.Ping()
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func testConnection(d *entity.Instance) error {
|
||||
// 不指定数据库名称
|
||||
conn, err := getInstanceConn(d, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *instanceAppImpl) GetDatabases(ed *entity.Instance) ([]string, error) {
|
||||
func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance) ([]string, error) {
|
||||
ed.Network = ed.GetNetwork()
|
||||
databases := make([]string, 0)
|
||||
var dbConn *sql.DB
|
||||
metaDb := ed.Type.MetaDbName()
|
||||
getDatabasesSql := ed.Type.StmtSelectDbName()
|
||||
|
||||
dbConn, err := getInstanceConn(ed, metaDb)
|
||||
dbConn, err := toDbInfo(ed, 0, metaDb, "").Conn()
|
||||
if err != nil {
|
||||
return nil, errorx.NewBiz("数据库连接失败: %s", err.Error())
|
||||
return nil, err
|
||||
}
|
||||
defer dbConn.Close()
|
||||
|
||||
_, res, err := selectDataByDb(dbConn, getDatabasesSql)
|
||||
_, res, err := dbConn.SelectData(getDatabasesSql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
195
server/internal/db/dbm/conn.go
Normal file
195
server/internal/db/dbm/conn.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/pkg/logx"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// db实例连接信息
|
||||
type DbConn struct {
|
||||
Id string
|
||||
Info *DbInfo
|
||||
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// 执行查询语句
|
||||
// 依次返回 列名数组,结果map,错误
|
||||
func (d *DbConn) SelectData(execSql string) ([]string, []map[string]any, error) {
|
||||
return selectDataByDb(d.db, execSql)
|
||||
}
|
||||
|
||||
// 将查询结果映射至struct,可具体参考sqlx库
|
||||
func (d *DbConn) SelectData2Struct(execSql string, dest any) error {
|
||||
return select2StructByDb(d.db, execSql, dest)
|
||||
}
|
||||
|
||||
// WalkTableRecord 遍历表记录
|
||||
func (d *DbConn) WalkTableRecord(selectSql string, walk func(record map[string]any, columns []string)) error {
|
||||
return walkTableRecord(d.db, selectSql, walk)
|
||||
}
|
||||
|
||||
// 执行 update, insert, delete,建表等sql
|
||||
// 返回影响条数和错误
|
||||
func (d *DbConn) Exec(sql string) (int64, error) {
|
||||
res, err := d.db.Exec(sql)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
// 获取数据库元信息实现接口
|
||||
func (d *DbConn) GetMeta() DbMetadata {
|
||||
switch d.Info.Type {
|
||||
case DbTypeMysql:
|
||||
return &MysqlMetadata{dc: d}
|
||||
case DbTypePostgres:
|
||||
return &PgsqlMetadata{dc: d}
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", d.Info.Type))
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭连接
|
||||
func (d *DbConn) Close() {
|
||||
if d.db != nil {
|
||||
if err := d.db.Close(); err != nil {
|
||||
logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
|
||||
}
|
||||
d.db = nil
|
||||
}
|
||||
}
|
||||
|
||||
func selectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]any, error) {
|
||||
// 列名用于前端表头名称按照数据库与查询字段顺序显示
|
||||
var colNames []string
|
||||
result := make([]map[string]any, 0, 16)
|
||||
err := walkTableRecord(db, selectSql, func(record map[string]any, columns []string) {
|
||||
result = append(result, record)
|
||||
if colNames == nil {
|
||||
colNames = make([]string, len(columns))
|
||||
copy(colNames, columns)
|
||||
}
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return colNames, result, nil
|
||||
}
|
||||
|
||||
func walkTableRecord(db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) error {
|
||||
rows, err := db.Query(selectSql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数,
|
||||
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
|
||||
defer func() {
|
||||
if rows != nil {
|
||||
rows.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
colTypes, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lenCols := len(colTypes)
|
||||
// 列名用于前端表头名称按照数据库与查询字段顺序显示
|
||||
colNames := make([]string, lenCols)
|
||||
// 这里表示一行填充数据
|
||||
scans := make([]any, lenCols)
|
||||
// 这里表示一行所有列的值,用[]byte表示
|
||||
values := make([][]byte, lenCols)
|
||||
for k, colType := range colTypes {
|
||||
colNames[k] = colType.Name()
|
||||
// 这里scans引用values,把数据填充到[]byte里
|
||||
scans[k] = &values[k]
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
// 不Scan也会导致等待,该链接实际处于未工作的状态,然后也会导致连接数迅速达到最大
|
||||
if err := rows.Scan(scans...); err != nil {
|
||||
return err
|
||||
}
|
||||
// 每行数据
|
||||
rowData := make(map[string]any, lenCols)
|
||||
// 把values中的数据复制到row中
|
||||
for i, v := range values {
|
||||
rowData[colTypes[i].Name()] = valueConvert(v, colTypes[i])
|
||||
}
|
||||
walk(rowData, colNames)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 将查询的值转为对应列类型的实际值,不全部转为字符串
|
||||
func valueConvert(data []byte, colType *sql.ColumnType) any {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
// 列的数据库类型名
|
||||
colDatabaseTypeName := strings.ToLower(colType.DatabaseTypeName())
|
||||
|
||||
// 如果类型是bit,则直接返回第一个字节即可
|
||||
if strings.Contains(colDatabaseTypeName, "bit") {
|
||||
return data[0]
|
||||
}
|
||||
|
||||
// 这里把[]byte数据转成string
|
||||
stringV := string(data)
|
||||
if stringV == "" {
|
||||
return ""
|
||||
}
|
||||
colScanType := strings.ToLower(colType.ScanType().Name())
|
||||
|
||||
if strings.Contains(colScanType, "int") {
|
||||
// 如果长度超过16位,则返回字符串,因为前端js长度大于16会丢失精度
|
||||
if len(stringV) > 16 {
|
||||
return stringV
|
||||
}
|
||||
intV, _ := strconv.Atoi(stringV)
|
||||
switch colType.ScanType().Kind() {
|
||||
case reflect.Int8:
|
||||
return int8(intV)
|
||||
case reflect.Uint8:
|
||||
return uint8(intV)
|
||||
case reflect.Int64:
|
||||
return int64(intV)
|
||||
case reflect.Uint64:
|
||||
return uint64(intV)
|
||||
case reflect.Uint:
|
||||
return uint(intV)
|
||||
default:
|
||||
return intV
|
||||
}
|
||||
}
|
||||
if strings.Contains(colScanType, "float") || strings.Contains(colDatabaseTypeName, "decimal") {
|
||||
floatV, _ := strconv.ParseFloat(stringV, 64)
|
||||
return floatV
|
||||
}
|
||||
|
||||
return stringV
|
||||
}
|
||||
|
||||
// 查询数据结果映射至struct。可参考sqlx库
|
||||
func select2StructByDb(db *sql.DB, selectSql string, dest any) error {
|
||||
rows, err := db.Query(selectSql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数,
|
||||
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
|
||||
defer func() {
|
||||
if rows != nil {
|
||||
rows.Close()
|
||||
}
|
||||
}()
|
||||
return scanAll(rows, dest, false)
|
||||
}
|
||||
73
server/internal/db/dbm/conn_cache.go
Normal file
73
server/internal/db/dbm/conn_cache.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/common/consts"
|
||||
"mayfly-go/internal/machine/infrastructure/machine"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/logx"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 客户端连接缓存,指定时间内没有访问则会被关闭, key为数据库连接id
|
||||
var connCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second).
|
||||
WithUpdateAccessTime(true).
|
||||
OnEvicted(func(key any, value any) {
|
||||
logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key))
|
||||
value.(*DbConn).Close()
|
||||
})
|
||||
|
||||
func init() {
|
||||
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
|
||||
// 遍历所有db连接实例,若存在db实例使用该ssh隧道机器,则返回true,表示还在使用中...
|
||||
items := connCache.Items()
|
||||
for _, v := range items {
|
||||
if v.Value.(*DbConn).Info.SshTunnelMachineId == machineId {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
var mutex sync.Mutex
|
||||
|
||||
// 从缓存中获取数据库连接信息,若缓存中不存在则会使用回调函数获取dbInfo进行连接并缓存
|
||||
func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error)) (*DbConn, error) {
|
||||
connId := GetDbConnId(dbId, database)
|
||||
|
||||
// connId不为空,则为需要缓存
|
||||
needCache := connId != ""
|
||||
if needCache {
|
||||
load, ok := connCache.Get(connId)
|
||||
if ok {
|
||||
return load.(*DbConn), nil
|
||||
}
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
// 若缓存中不存在,则从回调函数中获取DbInfo
|
||||
dbInfo, err := getDbInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
dbConn, err := dbInfo.Conn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if needCache {
|
||||
connCache.Put(connId, dbConn)
|
||||
}
|
||||
return dbConn, nil
|
||||
}
|
||||
|
||||
// 删除db缓存并关闭该数据库所有连接
|
||||
func CloseDb(dbId uint64, db string) {
|
||||
connCache.Delete(GetDbConnId(dbId, db))
|
||||
}
|
||||
@@ -1,10 +1,11 @@
|
||||
package entity
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/kanzihuang/vitess/go/vt/sqlparser"
|
||||
"github.com/lib/pq"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DbType string
|
||||
@@ -1,8 +1,9 @@
|
||||
package entity
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_QuoteLiteral(t *testing.T) {
|
||||
79
server/internal/db/dbm/info.go
Normal file
79
server/internal/db/dbm/info.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
)
|
||||
|
||||
type DbInfo struct {
|
||||
Id uint64
|
||||
Name string
|
||||
|
||||
Type DbType // 类型,mysql postgres等
|
||||
Host string
|
||||
Port int
|
||||
Network string
|
||||
Username string
|
||||
Password string
|
||||
Params string
|
||||
Database string
|
||||
|
||||
TagPath string
|
||||
SshTunnelMachineId int
|
||||
}
|
||||
|
||||
// 获取记录日志的描述
|
||||
func (d *DbInfo) GetLogDesc() string {
|
||||
return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.TagPath, d.Name, d.Host, d.Port, d.Database)
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
func (dbInfo *DbInfo) Conn() (*DbConn, error) {
|
||||
var conn *sql.DB
|
||||
var err error
|
||||
database := dbInfo.Database
|
||||
|
||||
switch dbInfo.Type {
|
||||
case DbTypeMysql:
|
||||
conn, err = getMysqlDB(dbInfo)
|
||||
case DbTypePostgres:
|
||||
conn, err = getPgsqlDB(dbInfo)
|
||||
default:
|
||||
return nil, errorx.NewBiz("invalid database type: %s", dbInfo.Type)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logx.Errorf("连接db失败: %s:%d/%s", dbInfo.Host, dbInfo.Port, database)
|
||||
return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
|
||||
}
|
||||
|
||||
err = conn.Ping()
|
||||
if err != nil {
|
||||
logx.Errorf("db ping失败: %s:%d/%s", dbInfo.Host, dbInfo.Port, database)
|
||||
return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
|
||||
}
|
||||
|
||||
dbc := &DbConn{Id: GetDbConnId(dbInfo.Id, database), Info: dbInfo}
|
||||
|
||||
// 最大连接周期,超过时间的连接就close
|
||||
// conn.SetConnMaxLifetime(100 * time.Second)
|
||||
// 设置最大连接数
|
||||
conn.SetMaxOpenConns(5)
|
||||
// 设置闲置连接数
|
||||
conn.SetMaxIdleConns(1)
|
||||
dbc.db = conn
|
||||
logx.Infof("连接db: %s:%d/%s", dbInfo.Host, dbInfo.Port, database)
|
||||
|
||||
return dbc, nil
|
||||
}
|
||||
|
||||
// 获取连接id
|
||||
func GetDbConnId(dbId uint64, db string) string {
|
||||
if dbId == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%d:%s", dbId, db)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package application
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
@@ -1,10 +1,9 @@
|
||||
package application
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
@@ -13,7 +12,7 @@ import (
|
||||
"github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
func getMysqlDB(d *entity.Instance, db string) (*sql.DB, error) {
|
||||
func getMysqlDB(d *DbInfo) (*sql.DB, error) {
|
||||
// SSH Conect
|
||||
if d.SshTunnelMachineId > 0 {
|
||||
sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
|
||||
@@ -25,7 +24,7 @@ func getMysqlDB(d *entity.Instance, db string) (*sql.DB, error) {
|
||||
})
|
||||
}
|
||||
// 设置dataSourceName -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db)
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
|
||||
if d.Params != "" {
|
||||
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
|
||||
}
|
||||
@@ -42,12 +41,12 @@ const (
|
||||
)
|
||||
|
||||
type MysqlMetadata struct {
|
||||
di *DbConnection
|
||||
dc *DbConn
|
||||
}
|
||||
|
||||
// 获取表基础元信息, 如表名等
|
||||
func (mm *MysqlMetadata) GetTables() ([]Table, error) {
|
||||
_, res, err := mm.di.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_MA_KEY))
|
||||
_, res, err := mm.dc.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_MA_KEY))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -72,7 +71,7 @@ func (mm *MysqlMetadata) GetColumns(tableNames ...string) ([]Column, error) {
|
||||
tableName = tableName + "'" + tableNames[i] + "'"
|
||||
}
|
||||
|
||||
_, res, err := mm.di.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
|
||||
_, res, err := mm.dc.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -113,7 +112,7 @@ func (mm *MysqlMetadata) GetPrimaryKey(tablename string) (string, error) {
|
||||
|
||||
// 获取表信息,比GetTableMetedatas获取更详细的表信息
|
||||
func (mm *MysqlMetadata) GetTableInfos() ([]Table, error) {
|
||||
_, res, err := mm.di.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
|
||||
_, res, err := mm.dc.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -134,7 +133,7 @@ func (mm *MysqlMetadata) GetTableInfos() ([]Table, error) {
|
||||
|
||||
// 获取表索引信息
|
||||
func (mm *MysqlMetadata) GetTableIndex(tableName string) ([]Index, error) {
|
||||
_, res, err := mm.di.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName))
|
||||
_, res, err := mm.dc.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -171,7 +170,7 @@ func (mm *MysqlMetadata) GetTableIndex(tableName string) ([]Index, error) {
|
||||
|
||||
// 获取建表ddl
|
||||
func (mm *MysqlMetadata) GetCreateTableDdl(tableName string) (string, error) {
|
||||
_, res, err := mm.di.SelectData(fmt.Sprintf("show create table `%s` ", tableName))
|
||||
_, res, err := mm.dc.SelectData(fmt.Sprintf("show create table `%s` ", tableName))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -179,9 +178,9 @@ func (mm *MysqlMetadata) GetCreateTableDdl(tableName string) (string, error) {
|
||||
}
|
||||
|
||||
func (mm *MysqlMetadata) GetTableRecord(tableName string, pageNum, pageSize int) ([]string, []map[string]any, error) {
|
||||
return mm.di.SelectData(fmt.Sprintf("SELECT * FROM %s LIMIT %d, %d", tableName, (pageNum-1)*pageSize, pageSize))
|
||||
return mm.dc.SelectData(fmt.Sprintf("SELECT * FROM %s LIMIT %d, %d", tableName, (pageNum-1)*pageSize, pageSize))
|
||||
}
|
||||
|
||||
func (mm *MysqlMetadata) WalkTableRecord(tableName string, walk func(record map[string]any, columns []string)) error {
|
||||
return mm.di.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
|
||||
return mm.dc.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
|
||||
}
|
||||
@@ -1,10 +1,9 @@
|
||||
package application
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
@@ -16,7 +15,7 @@ import (
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func getPgsqlDB(d *entity.Instance, db string) (*sql.DB, error) {
|
||||
func getPgsqlDB(d *DbInfo) (*sql.DB, error) {
|
||||
driverName := string(d.Type)
|
||||
// SSH Conect
|
||||
if d.SshTunnelMachineId > 0 {
|
||||
@@ -28,6 +27,7 @@ func getPgsqlDB(d *entity.Instance, db string) (*sql.DB, error) {
|
||||
sql.Drivers()
|
||||
}
|
||||
|
||||
db := d.Database
|
||||
var dbParam string
|
||||
if db != "" {
|
||||
dbParam = "dbname=" + db
|
||||
@@ -75,12 +75,12 @@ const (
|
||||
)
|
||||
|
||||
type PgsqlMetadata struct {
|
||||
di *DbConnection
|
||||
dc *DbConn
|
||||
}
|
||||
|
||||
// 获取表基础元信息, 如表名等
|
||||
func (pm *PgsqlMetadata) GetTables() ([]Table, error) {
|
||||
_, res, err := pm.di.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_MA_KEY))
|
||||
_, res, err := pm.dc.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_MA_KEY))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -105,7 +105,7 @@ func (pm *PgsqlMetadata) GetColumns(tableNames ...string) ([]Column, error) {
|
||||
tableName = tableName + "'" + tableNames[i] + "'"
|
||||
}
|
||||
|
||||
_, res, err := pm.di.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
|
||||
_, res, err := pm.dc.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -144,7 +144,7 @@ func (pm *PgsqlMetadata) GetPrimaryKey(tablename string) (string, error) {
|
||||
|
||||
// 获取表信息,比GetTables获取更详细的表信息
|
||||
func (pm *PgsqlMetadata) GetTableInfos() ([]Table, error) {
|
||||
_, res, err := pm.di.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
|
||||
_, res, err := pm.dc.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -165,7 +165,7 @@ func (pm *PgsqlMetadata) GetTableInfos() ([]Table, error) {
|
||||
|
||||
// 获取表索引信息
|
||||
func (pm *PgsqlMetadata) GetTableIndex(tableName string) ([]Index, error) {
|
||||
_, res, err := pm.di.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
|
||||
_, res, err := pm.dc.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -186,16 +186,16 @@ func (pm *PgsqlMetadata) GetTableIndex(tableName string) ([]Index, error) {
|
||||
|
||||
// 获取建表ddl
|
||||
func (pm *PgsqlMetadata) GetCreateTableDdl(tableName string) (string, error) {
|
||||
_, err := pm.di.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
|
||||
_, err := pm.dc.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
_, schemaRes, _ := pm.di.SelectData("select current_schema() as schema")
|
||||
_, schemaRes, _ := pm.dc.SelectData("select current_schema() as schema")
|
||||
schemaName := schemaRes[0]["schema"].(string)
|
||||
|
||||
ddlSql := fmt.Sprintf("select showcreatetable('%s','%s') as sql", schemaName, tableName)
|
||||
_, res, err := pm.di.SelectData(ddlSql)
|
||||
_, res, err := pm.dc.SelectData(ddlSql)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -204,9 +204,9 @@ func (pm *PgsqlMetadata) GetCreateTableDdl(tableName string) (string, error) {
|
||||
}
|
||||
|
||||
func (pm *PgsqlMetadata) GetTableRecord(tableName string, pageNum, pageSize int) ([]string, []map[string]any, error) {
|
||||
return pm.di.SelectData(fmt.Sprintf("SELECT * FROM %s OFFSET %d LIMIT %d", tableName, (pageNum-1)*pageSize, pageSize))
|
||||
return pm.dc.SelectData(fmt.Sprintf("SELECT * FROM %s OFFSET %d LIMIT %d", tableName, (pageNum-1)*pageSize, pageSize))
|
||||
}
|
||||
|
||||
func (pm *PgsqlMetadata) WalkTableRecord(tableName string, walk func(record map[string]any, columns []string)) error {
|
||||
return pm.di.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
|
||||
return pm.dc.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package application
|
||||
package dbm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
50
server/internal/db/domain/entity/db_instance.go
Normal file
50
server/internal/db/domain/entity/db_instance.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/common/utils"
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/pkg/model"
|
||||
)
|
||||
|
||||
type DbInstance struct {
|
||||
model.Model
|
||||
|
||||
Name string `orm:"column(name)" json:"name"`
|
||||
Type dbm.DbType `orm:"column(type)" json:"type"` // 类型,mysql oracle等
|
||||
Host string `orm:"column(host)" json:"host"`
|
||||
Port int `orm:"column(port)" json:"port"`
|
||||
Network string `orm:"column(network)" json:"network"`
|
||||
Username string `orm:"column(username)" json:"username"`
|
||||
Password string `orm:"column(password)" json:"-"`
|
||||
Params string `orm:"column(params)" json:"params"`
|
||||
Remark string `orm:"column(remark)" json:"remark"`
|
||||
SshTunnelMachineId int `orm:"column(ssh_tunnel_machine_id)" json:"sshTunnelMachineId"` // ssh隧道机器id
|
||||
}
|
||||
|
||||
func (d *DbInstance) TableName() string {
|
||||
return "t_db_instance"
|
||||
}
|
||||
|
||||
// 获取数据库连接网络, 若没有使用ssh隧道,则直接返回。否则返回拼接的网络需要注册至指定dial
|
||||
func (d *DbInstance) GetNetwork() string {
|
||||
network := d.Network
|
||||
if d.SshTunnelMachineId <= 0 {
|
||||
if network == "" {
|
||||
return "tcp"
|
||||
} else {
|
||||
return network
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%s+ssh:%d", d.Type, d.SshTunnelMachineId)
|
||||
}
|
||||
|
||||
func (d *DbInstance) PwdEncrypt() {
|
||||
// 密码替换为加密后的密码
|
||||
d.Password = utils.PwdAesEncrypt(d.Password)
|
||||
}
|
||||
|
||||
func (d *DbInstance) PwdDecrypt() {
|
||||
// 密码替换为解密后的密码
|
||||
d.Password = utils.PwdAesDecrypt(d.Password)
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package entity
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/common/utils"
|
||||
"mayfly-go/pkg/model"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
model.Model
|
||||
|
||||
Name string `orm:"column(name)" json:"name"`
|
||||
Type DbType `orm:"column(type)" json:"type"` // 类型,mysql oracle等
|
||||
Host string `orm:"column(host)" json:"host"`
|
||||
Port int `orm:"column(port)" json:"port"`
|
||||
Network string `orm:"column(network)" json:"network"`
|
||||
Username string `orm:"column(username)" json:"username"`
|
||||
Password string `orm:"column(password)" json:"-"`
|
||||
Params string `orm:"column(params)" json:"params"`
|
||||
Remark string `orm:"column(remark)" json:"remark"`
|
||||
SshTunnelMachineId int `orm:"column(ssh_tunnel_machine_id)" json:"sshTunnelMachineId"` // ssh隧道机器id
|
||||
}
|
||||
|
||||
func (d *Instance) TableName() string {
|
||||
return "t_db_instance"
|
||||
}
|
||||
|
||||
// 获取数据库连接网络, 若没有使用ssh隧道,则直接返回。否则返回拼接的网络需要注册至指定dial
|
||||
func (d *Instance) GetNetwork() string {
|
||||
network := d.Network
|
||||
if d.SshTunnelMachineId <= 0 {
|
||||
if network == "" {
|
||||
return "tcp"
|
||||
} else {
|
||||
return network
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%s+ssh:%d", d.Type, d.SshTunnelMachineId)
|
||||
}
|
||||
|
||||
func (d *Instance) PwdEncrypt() {
|
||||
// 密码替换为加密后的密码
|
||||
d.Password = utils.PwdAesEncrypt(d.Password)
|
||||
}
|
||||
|
||||
func (d *Instance) PwdDecrypt() {
|
||||
// 密码替换为解密后的密码
|
||||
d.Password = utils.PwdAesDecrypt(d.Password)
|
||||
}
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
)
|
||||
|
||||
type Instance interface {
|
||||
base.Repo[*entity.Instance]
|
||||
base.Repo[*entity.DbInstance]
|
||||
|
||||
// 分页获取数据库实例信息列表
|
||||
GetInstanceList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
|
||||
|
||||
@@ -9,16 +9,16 @@ import (
|
||||
)
|
||||
|
||||
type instanceRepoImpl struct {
|
||||
base.RepoImpl[*entity.Instance]
|
||||
base.RepoImpl[*entity.DbInstance]
|
||||
}
|
||||
|
||||
func newInstanceRepo() repository.Instance {
|
||||
return &instanceRepoImpl{base.RepoImpl[*entity.Instance]{M: new(entity.Instance)}}
|
||||
return &instanceRepoImpl{base.RepoImpl[*entity.DbInstance]{M: new(entity.DbInstance)}}
|
||||
}
|
||||
|
||||
// 分页获取数据库信息列表
|
||||
func (d *instanceRepoImpl) GetInstanceList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
|
||||
qd := gormx.NewQuery(new(entity.Instance)).
|
||||
qd := gormx.NewQuery(new(entity.DbInstance)).
|
||||
Eq("id", condition.Id).
|
||||
Eq("host", condition.Host).
|
||||
Like("name", condition.Name)
|
||||
|
||||
@@ -74,18 +74,20 @@ func (m *Mongo) DeleteMongo(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
func (m *Mongo) Databases(rc *req.Ctx) {
|
||||
cli := m.MongoApp.GetMongoInst(m.GetMongoId(rc.GinCtx)).Cli
|
||||
res, err := cli.ListDatabases(context.TODO(), bson.D{})
|
||||
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(rc.GinCtx))
|
||||
biz.ErrIsNil(err)
|
||||
res, err := conn.Cli.ListDatabases(context.TODO(), bson.D{})
|
||||
biz.ErrIsNilAppendErr(err, "获取mongo所有库信息失败: %s")
|
||||
rc.ResData = res
|
||||
}
|
||||
|
||||
func (m *Mongo) Collections(rc *req.Ctx) {
|
||||
cli := m.MongoApp.GetMongoInst(m.GetMongoId(rc.GinCtx)).Cli
|
||||
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(rc.GinCtx))
|
||||
biz.ErrIsNil(err)
|
||||
db := rc.GinCtx.Query("database")
|
||||
biz.NotEmpty(db, "database不能为空")
|
||||
ctx := context.TODO()
|
||||
res, err := cli.Database(db).ListCollectionNames(ctx, bson.D{})
|
||||
res, err := conn.Cli.Database(db).ListCollectionNames(ctx, bson.D{})
|
||||
biz.ErrIsNilAppendErr(err, "获取库集合信息失败: %s")
|
||||
rc.ResData = res
|
||||
}
|
||||
@@ -94,8 +96,9 @@ func (m *Mongo) RunCommand(rc *req.Ctx) {
|
||||
commandForm := new(form.MongoRunCommand)
|
||||
ginx.BindJsonAndValid(rc.GinCtx, commandForm)
|
||||
|
||||
inst := m.MongoApp.GetMongoInst(m.GetMongoId(rc.GinCtx))
|
||||
rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm)
|
||||
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(rc.GinCtx))
|
||||
biz.ErrIsNil(err)
|
||||
rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
|
||||
|
||||
// 顺序执行
|
||||
commands := bson.D{}
|
||||
@@ -110,7 +113,7 @@ func (m *Mongo) RunCommand(rc *req.Ctx) {
|
||||
|
||||
ctx := context.TODO()
|
||||
var bm bson.M
|
||||
err := inst.Cli.Database(commandForm.Database).RunCommand(
|
||||
err = conn.Cli.Database(commandForm.Database).RunCommand(
|
||||
ctx,
|
||||
commands,
|
||||
).Decode(&bm)
|
||||
@@ -121,10 +124,13 @@ func (m *Mongo) RunCommand(rc *req.Ctx) {
|
||||
|
||||
func (m *Mongo) FindCommand(rc *req.Ctx) {
|
||||
g := rc.GinCtx
|
||||
cli := m.MongoApp.GetMongoInst(m.GetMongoId(g)).Cli
|
||||
commandForm := new(form.MongoFindCommand)
|
||||
ginx.BindJsonAndValid(g, commandForm)
|
||||
|
||||
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g))
|
||||
biz.ErrIsNil(err)
|
||||
cli := conn.Cli
|
||||
|
||||
limit := commandForm.Limit
|
||||
if limit != 0 {
|
||||
biz.IsTrue(limit <= 100, "limit不能超过100")
|
||||
@@ -157,8 +163,9 @@ func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) {
|
||||
commandForm := new(form.MongoUpdateByIdCommand)
|
||||
ginx.BindJsonAndValid(g, commandForm)
|
||||
|
||||
inst := m.MongoApp.GetMongoInst(m.GetMongoId(g))
|
||||
rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm)
|
||||
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g))
|
||||
biz.ErrIsNil(err)
|
||||
rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
|
||||
|
||||
// 解析docId文档id,如果为string类型则使用ObjectId解析,解析失败则为普通字符串
|
||||
docId := commandForm.DocId
|
||||
@@ -170,7 +177,7 @@ func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := inst.Cli.Database(commandForm.Database).Collection(commandForm.Collection).UpdateByID(context.TODO(), docId, commandForm.Update)
|
||||
res, err := conn.Cli.Database(commandForm.Database).Collection(commandForm.Collection).UpdateByID(context.TODO(), docId, commandForm.Update)
|
||||
biz.ErrIsNilAppendErr(err, "命令执行失败: %s")
|
||||
|
||||
rc.ResData = res
|
||||
@@ -181,8 +188,9 @@ func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) {
|
||||
commandForm := new(form.MongoUpdateByIdCommand)
|
||||
ginx.BindJsonAndValid(g, commandForm)
|
||||
|
||||
inst := m.MongoApp.GetMongoInst(m.GetMongoId(g))
|
||||
rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm)
|
||||
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g))
|
||||
biz.ErrIsNil(err)
|
||||
rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
|
||||
|
||||
// 解析docId文档id,如果为string类型则使用ObjectId解析,解析失败则为普通字符串
|
||||
docId := commandForm.DocId
|
||||
@@ -194,7 +202,7 @@ func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) {
|
||||
}
|
||||
}
|
||||
|
||||
res, err := inst.Cli.Database(commandForm.Database).Collection(commandForm.Collection).DeleteOne(context.TODO(), bson.D{{Key: "_id", Value: docId}})
|
||||
res, err := conn.Cli.Database(commandForm.Database).Collection(commandForm.Collection).DeleteOne(context.TODO(), bson.D{{Key: "_id", Value: docId}})
|
||||
biz.ErrIsNilAppendErr(err, "命令执行失败: %s")
|
||||
rc.ResData = res
|
||||
}
|
||||
@@ -204,10 +212,11 @@ func (m *Mongo) InsertOneCommand(rc *req.Ctx) {
|
||||
commandForm := new(form.MongoInsertCommand)
|
||||
ginx.BindJsonAndValid(g, commandForm)
|
||||
|
||||
inst := m.MongoApp.GetMongoInst(m.GetMongoId(g))
|
||||
rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm)
|
||||
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g))
|
||||
biz.ErrIsNil(err)
|
||||
rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
|
||||
|
||||
res, err := inst.Cli.Database(commandForm.Database).Collection(commandForm.Collection).InsertOne(context.TODO(), commandForm.Doc)
|
||||
res, err := conn.Cli.Database(commandForm.Database).Collection(commandForm.Collection).InsertOne(context.TODO(), commandForm.Doc)
|
||||
biz.ErrIsNilAppendErr(err, "命令执行失败: %s")
|
||||
rc.ResData = res
|
||||
}
|
||||
|
||||
@@ -1,26 +1,12 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"mayfly-go/internal/common/consts"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/internal/machine/infrastructure/machine"
|
||||
"mayfly-go/internal/mongo/domain/entity"
|
||||
"mayfly-go/internal/mongo/domain/repository"
|
||||
"mayfly-go/internal/mongo/mgm"
|
||||
"mayfly-go/pkg/base"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/utils/netx"
|
||||
"mayfly-go/pkg/utils/structx"
|
||||
"net"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type Mongo interface {
|
||||
@@ -38,7 +24,7 @@ type Mongo interface {
|
||||
|
||||
// 获取mongo连接实例
|
||||
// @param id mongo id
|
||||
GetMongoInst(id uint64) *MongoInstance
|
||||
GetMongoConn(id uint64) (*mgm.MongoConn, error)
|
||||
}
|
||||
|
||||
func newMongoAppImpl(mongoRepo repository.Mongo) Mongo {
|
||||
@@ -61,7 +47,7 @@ func (d *mongoAppImpl) Count(condition *entity.MongoQuery) int64 {
|
||||
}
|
||||
|
||||
func (d *mongoAppImpl) Delete(id uint64) error {
|
||||
DeleteMongoCache(id)
|
||||
mgm.CloseConn(id)
|
||||
return d.GetRepo().DeleteById(id)
|
||||
}
|
||||
|
||||
@@ -71,148 +57,16 @@ func (d *mongoAppImpl) Save(m *entity.Mongo) error {
|
||||
}
|
||||
|
||||
// 先关闭连接
|
||||
DeleteMongoCache(m.Id)
|
||||
mgm.CloseConn(m.Id)
|
||||
return d.GetRepo().UpdateById(m)
|
||||
}
|
||||
|
||||
func (d *mongoAppImpl) GetMongoInst(id uint64) *MongoInstance {
|
||||
mongoInstance, err := GetMongoInstance(id, func(u uint64) (*entity.Mongo, error) {
|
||||
mongo, err := d.GetById(new(entity.Mongo), u)
|
||||
func (d *mongoAppImpl) GetMongoConn(id uint64) (*mgm.MongoConn, error) {
|
||||
return mgm.GetMongoConn(id, func() (*mgm.MongoInfo, error) {
|
||||
mongo, err := d.GetById(new(entity.Mongo), id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errorx.NewBiz("mongo信息不存在")
|
||||
}
|
||||
return mongo, nil
|
||||
})
|
||||
biz.ErrIsNilAppendErr(err, "连接mongo失败: %s")
|
||||
return mongoInstance
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------
|
||||
|
||||
// mongo客户端连接缓存,指定时间内没有访问则会被关闭
|
||||
var mongoCliCache = cache.NewTimedCache(consts.MongoConnExpireTime, 5*time.Second).
|
||||
WithUpdateAccessTime(true).
|
||||
OnEvicted(func(key any, value any) {
|
||||
logx.Infof("删除mongo连接缓存: id = %v", key)
|
||||
value.(*MongoInstance).Close()
|
||||
})
|
||||
|
||||
func init() {
|
||||
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
|
||||
// 遍历所有mongo连接实例,若存在redis实例使用该ssh隧道机器,则返回true,表示还在使用中...
|
||||
items := mongoCliCache.Items()
|
||||
for _, v := range items {
|
||||
if v.Value.(*MongoInstance).Info.SshTunnelMachineId == machineId {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return mongo.ToMongoInfo(), nil
|
||||
})
|
||||
}
|
||||
|
||||
// 获取mongo的连接实例
|
||||
func GetMongoInstance(mongoId uint64, getMongoEntity func(uint64) (*entity.Mongo, error)) (*MongoInstance, error) {
|
||||
mi, err := mongoCliCache.ComputeIfAbsent(mongoId, func(_ any) (any, error) {
|
||||
mongoEntity, err := getMongoEntity(mongoId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c, err := connect(mongoEntity)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
})
|
||||
|
||||
if mi != nil {
|
||||
return mi.(*MongoInstance), err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func DeleteMongoCache(mongoId uint64) {
|
||||
mongoCliCache.Delete(mongoId)
|
||||
}
|
||||
|
||||
type MongoInfo struct {
|
||||
Id uint64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
TagPath string `json:"tagPath"`
|
||||
SshTunnelMachineId int `json:"-"` // ssh隧道机器id
|
||||
}
|
||||
|
||||
func (m *MongoInfo) GetLogDesc() string {
|
||||
return fmt.Sprintf("Mongo[id=%d, tag=%s, name=%s]", m.Id, m.TagPath, m.Name)
|
||||
}
|
||||
|
||||
type MongoInstance struct {
|
||||
Id uint64
|
||||
Info *MongoInfo
|
||||
|
||||
Cli *mongo.Client
|
||||
}
|
||||
|
||||
func (mi *MongoInstance) Close() {
|
||||
if mi.Cli != nil {
|
||||
if err := mi.Cli.Disconnect(context.Background()); err != nil {
|
||||
logx.Errorf("关闭mongo实例[%d]连接失败: %s", mi.Id, err)
|
||||
}
|
||||
mi.Cli = nil
|
||||
}
|
||||
}
|
||||
|
||||
// 连接mongo,并返回client
|
||||
func connect(me *entity.Mongo) (*MongoInstance, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mongoInstance := &MongoInstance{Id: me.Id, Info: toMongoInfo(me)}
|
||||
|
||||
mongoOptions := options.Client().ApplyURI(me.Uri).
|
||||
SetMaxPoolSize(1)
|
||||
// 启用ssh隧道则连接隧道机器
|
||||
if me.SshTunnelMachineId > 0 {
|
||||
mongoOptions.SetDialer(&MongoSshDialer{machineId: me.SshTunnelMachineId})
|
||||
}
|
||||
|
||||
client, err := mongo.Connect(ctx, mongoOptions)
|
||||
if err != nil {
|
||||
mongoInstance.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err = client.Ping(context.TODO(), nil); err != nil {
|
||||
mongoInstance.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logx.Infof("连接mongo: %s", func(str string) string {
|
||||
reg := regexp.MustCompile(`(^mongodb://.+?:)(.+)(@.+$)`)
|
||||
return reg.ReplaceAllString(str, `${1}****${3}`)
|
||||
}(me.Uri))
|
||||
mongoInstance.Cli = client
|
||||
return mongoInstance, err
|
||||
}
|
||||
|
||||
func toMongoInfo(me *entity.Mongo) *MongoInfo {
|
||||
mi := new(MongoInfo)
|
||||
structx.Copy(mi, me)
|
||||
return mi
|
||||
}
|
||||
|
||||
type MongoSshDialer struct {
|
||||
machineId int
|
||||
}
|
||||
|
||||
func (sd *MongoSshDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
stm, err := machineapp.GetMachineApp().GetSshTunnelMachine(sd.machineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sshConn, err := stm.GetDialConn(network, address); err == nil {
|
||||
// 将ssh conn包装,否则内部部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
|
||||
return &netx.WrapSshConn{Conn: sshConn}, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
package entity
|
||||
|
||||
import "mayfly-go/pkg/model"
|
||||
import (
|
||||
"mayfly-go/internal/mongo/mgm"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/utils/structx"
|
||||
)
|
||||
|
||||
type Mongo struct {
|
||||
model.Model
|
||||
@@ -11,3 +15,10 @@ type Mongo struct {
|
||||
TagId uint64 `json:"tagId"`
|
||||
TagPath string `json:"tagPath"`
|
||||
}
|
||||
|
||||
// 转换为mongoInfo进行连接
|
||||
func (me *Mongo) ToMongoInfo() *mgm.MongoInfo {
|
||||
mongoInfo := new(mgm.MongoInfo)
|
||||
structx.Copy(mongoInfo, me)
|
||||
return mongoInfo
|
||||
}
|
||||
|
||||
24
server/internal/mongo/mgm/conn.go
Normal file
24
server/internal/mongo/mgm/conn.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package mgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"mayfly-go/pkg/logx"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
)
|
||||
|
||||
type MongoConn struct {
|
||||
Id string
|
||||
Info *MongoInfo
|
||||
|
||||
Cli *mongo.Client
|
||||
}
|
||||
|
||||
func (mc *MongoConn) Close() {
|
||||
if mc.Cli != nil {
|
||||
if err := mc.Cli.Disconnect(context.Background()); err != nil {
|
||||
logx.Errorf("关闭mongo实例[%s]连接失败: %s", mc.Id, err)
|
||||
}
|
||||
mc.Cli = nil
|
||||
}
|
||||
}
|
||||
72
server/internal/mongo/mgm/conn_cache.go
Normal file
72
server/internal/mongo/mgm/conn_cache.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package mgm
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/common/consts"
|
||||
"mayfly-go/internal/machine/infrastructure/machine"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/logx"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mongo客户端连接缓存,指定时间内没有访问则会被关闭
|
||||
var connCache = cache.NewTimedCache(consts.MongoConnExpireTime, 5*time.Second).
|
||||
WithUpdateAccessTime(true).
|
||||
OnEvicted(func(key any, value any) {
|
||||
logx.Infof("删除mongo连接缓存: id = %v", key)
|
||||
value.(*MongoConn).Close()
|
||||
})
|
||||
|
||||
func init() {
|
||||
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
|
||||
// 遍历所有mongo连接实例,若存在redis实例使用该ssh隧道机器,则返回true,表示还在使用中...
|
||||
items := connCache.Items()
|
||||
for _, v := range items {
|
||||
if v.Value.(*MongoConn).Info.SshTunnelMachineId == machineId {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
var mutex sync.Mutex
|
||||
|
||||
// 从缓存中获取mongo连接信息, 若缓存中不存在则会使用回调函数获取mongoInfo进行连接并缓存
|
||||
func GetMongoConn(mongoId uint64, getMongoInfo func() (*MongoInfo, error)) (*MongoConn, error) {
|
||||
connId := getConnId(mongoId)
|
||||
|
||||
// connId不为空,则为需要缓存
|
||||
needCache := connId != ""
|
||||
if needCache {
|
||||
load, ok := connCache.Get(connId)
|
||||
if ok {
|
||||
return load.(*MongoConn), nil
|
||||
}
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
// 若缓存中不存在,则从回调函数中获取MongoInfo
|
||||
mi, err := getMongoInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 连接mongo
|
||||
mc, err := mi.Conn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if needCache {
|
||||
connCache.Put(connId, mc)
|
||||
}
|
||||
return mc, nil
|
||||
}
|
||||
|
||||
// 关闭连接,并移除缓存连接
|
||||
func CloseConn(mongoId uint64) {
|
||||
connCache.Delete(mongoId)
|
||||
}
|
||||
79
server/internal/mongo/mgm/info.go
Normal file
79
server/internal/mongo/mgm/info.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package mgm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/netx"
|
||||
"net"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
|
||||
"go.mongodb.org/mongo-driver/mongo"
|
||||
"go.mongodb.org/mongo-driver/mongo/options"
|
||||
)
|
||||
|
||||
type MongoInfo struct {
|
||||
Id uint64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
|
||||
Uri string `json:"-"`
|
||||
|
||||
TagPath string `json:"tagPath"`
|
||||
SshTunnelMachineId int `json:"-"` // ssh隧道机器id
|
||||
}
|
||||
|
||||
func (mi *MongoInfo) Conn() (*MongoConn, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
mongoOptions := options.Client().ApplyURI(mi.Uri).
|
||||
SetMaxPoolSize(1)
|
||||
// 启用ssh隧道则连接隧道机器
|
||||
if mi.SshTunnelMachineId > 0 {
|
||||
mongoOptions.SetDialer(&MongoSshDialer{machineId: mi.SshTunnelMachineId})
|
||||
}
|
||||
|
||||
client, err := mongo.Connect(ctx, mongoOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = client.Ping(context.TODO(), nil); err != nil {
|
||||
client.Disconnect(ctx)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logx.Infof("连接mongo: %s", func(str string) string {
|
||||
reg := regexp.MustCompile(`(^mongodb://.+?:)(.+)(@.+$)`)
|
||||
return reg.ReplaceAllString(str, `${1}****${3}`)
|
||||
}(mi.Uri))
|
||||
|
||||
return &MongoConn{Id: getConnId(mi.Id), Info: mi, Cli: client}, nil
|
||||
}
|
||||
|
||||
type MongoSshDialer struct {
|
||||
machineId int
|
||||
}
|
||||
|
||||
func (sd *MongoSshDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
stm, err := machineapp.GetMachineApp().GetSshTunnelMachine(sd.machineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sshConn, err := stm.GetDialConn(network, address); err == nil {
|
||||
// 将ssh conn包装,否则内部部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
|
||||
return &netx.WrapSshConn{Conn: sshConn}, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 生成mongo连接id
|
||||
func getConnId(id uint64) string {
|
||||
if id == 0 {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%d", id)
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func (r *Redis) Hscan(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
g := rc.GinCtx
|
||||
count := ginx.QueryInt(g, "count", 10)
|
||||
match := g.Query("match")
|
||||
@@ -32,7 +32,7 @@ func (r *Redis) Hscan(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
func (r *Redis) Hdel(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
field := rc.GinCtx.Query("field")
|
||||
|
||||
rc.ReqParam = collx.Kvs("redis", ri.Info, "key", key, "field", field)
|
||||
@@ -42,7 +42,7 @@ func (r *Redis) Hdel(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
func (r *Redis) Hget(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
field := rc.GinCtx.Query("field")
|
||||
|
||||
res, err := ri.GetCmdable().HGet(context.TODO(), key, field).Result()
|
||||
@@ -56,7 +56,7 @@ func (r *Redis) Hset(rc *req.Ctx) {
|
||||
ginx.BindJsonAndValid(g, hashValue)
|
||||
|
||||
hv := hashValue.Value[0]
|
||||
ri := r.getRedisIns(rc)
|
||||
ri := r.getRedisConn(rc)
|
||||
rc.ReqParam = collx.Kvs("redis", ri.Info, "hash", hv)
|
||||
|
||||
res, err := ri.GetCmdable().HSet(context.TODO(), hashValue.Key, hv["field"].(string), hv["value"]).Result()
|
||||
@@ -69,7 +69,7 @@ func (r *Redis) SetHashValue(rc *req.Ctx) {
|
||||
hashValue := new(form.HashValue)
|
||||
ginx.BindJsonAndValid(g, hashValue)
|
||||
|
||||
ri := r.getRedisIns(rc)
|
||||
ri := r.getRedisConn(rc)
|
||||
rc.ReqParam = collx.Kvs("redis", ri.Info, "hash", hashValue)
|
||||
cmd := ri.GetCmdable()
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"mayfly-go/internal/redis/api/form"
|
||||
"mayfly-go/internal/redis/api/vo"
|
||||
"mayfly-go/internal/redis/domain/entity"
|
||||
"mayfly-go/internal/redis/rdm"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/ginx"
|
||||
"mayfly-go/pkg/req"
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
|
||||
// scan获取redis的key列表信息
|
||||
func (r *Redis) ScanKeys(rc *req.Ctx) {
|
||||
ri := r.getRedisIns(rc)
|
||||
ri := r.getRedisConn(rc)
|
||||
|
||||
form := &form.RedisScanForm{}
|
||||
ginx.BindJsonAndValid(rc.GinCtx, form)
|
||||
@@ -44,7 +44,7 @@ func (r *Redis) ScanKeys(rc *req.Ctx) {
|
||||
|
||||
// 通配符或全匹配
|
||||
mode := ri.Info.Mode
|
||||
if mode == "" || mode == entity.RedisModeStandalone || mode == entity.RedisModeSentinel {
|
||||
if mode == "" || mode == rdm.StandaloneMode || mode == rdm.SentinelMode {
|
||||
redisAddr := ri.Cli.Options().Addr
|
||||
cursorRes[redisAddr] = form.Cursor[redisAddr]
|
||||
for {
|
||||
@@ -64,7 +64,7 @@ func (r *Redis) ScanKeys(rc *req.Ctx) {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else if mode == entity.RedisModeCluster {
|
||||
} else if mode == rdm.ClusterMode {
|
||||
mu := &sync.Mutex{}
|
||||
// 遍历所有master节点,并执行scan命令,合并keys
|
||||
ri.ClusterCli.ForEachMaster(ctx, func(ctx context.Context, client *redis.Client) error {
|
||||
@@ -98,7 +98,7 @@ func (r *Redis) ScanKeys(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
func (r *Redis) KeyInfo(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
cmd := ri.GetCmdable()
|
||||
ctx := context.Background()
|
||||
ttl, err := cmd.TTL(ctx, key).Result()
|
||||
@@ -120,7 +120,7 @@ func (r *Redis) KeyInfo(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
func (r *Redis) TtlKey(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
ttl, err := ri.GetCmdable().TTL(context.Background(), key).Result()
|
||||
biz.ErrIsNilAppendErr(err, "ttl失败: %s")
|
||||
|
||||
@@ -132,7 +132,7 @@ func (r *Redis) TtlKey(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
func (r *Redis) DeleteKey(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
rc.ReqParam = collx.Kvs("redis", ri.Info, "key", key)
|
||||
ri.GetCmdable().Del(context.Background(), key)
|
||||
}
|
||||
@@ -141,7 +141,7 @@ func (r *Redis) RenameKey(rc *req.Ctx) {
|
||||
form := &form.Rename{}
|
||||
ginx.BindJsonAndValid(rc.GinCtx, form)
|
||||
|
||||
ri := r.getRedisIns(rc)
|
||||
ri := r.getRedisConn(rc)
|
||||
rc.ReqParam = collx.Kvs("redis", ri.Info, "rename", form)
|
||||
ri.GetCmdable().Rename(context.Background(), form.Key, form.NewKey)
|
||||
}
|
||||
@@ -150,21 +150,21 @@ func (r *Redis) ExpireKey(rc *req.Ctx) {
|
||||
form := &form.Expire{}
|
||||
ginx.BindJsonAndValid(rc.GinCtx, form)
|
||||
|
||||
ri := r.getRedisIns(rc)
|
||||
ri := r.getRedisConn(rc)
|
||||
rc.ReqParam = collx.Kvs("redis", ri.Info, "expire", form)
|
||||
ri.GetCmdable().Expire(context.Background(), form.Key, time.Duration(form.Seconds)*time.Second)
|
||||
}
|
||||
|
||||
// 移除过期时间
|
||||
func (r *Redis) PersistKey(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
rc.ReqParam = collx.Kvs("redis", ri.Info, "key", key)
|
||||
ri.GetCmdable().Persist(context.Background(), key)
|
||||
}
|
||||
|
||||
// 清空库
|
||||
func (r *Redis) FlushDb(rc *req.Ctx) {
|
||||
ri := r.getRedisIns(rc)
|
||||
ri := r.getRedisConn(rc)
|
||||
rc.ReqParam = ri.Info
|
||||
ri.GetCmdable().FlushDB(context.Background())
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ import (
|
||||
)
|
||||
|
||||
func (r *Redis) GetListValue(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
ctx := context.TODO()
|
||||
cmdable := ri.GetCmdable()
|
||||
|
||||
@@ -34,7 +34,7 @@ func (r *Redis) Lrem(rc *req.Ctx) {
|
||||
option := new(form.LRemOption)
|
||||
ginx.BindJsonAndValid(g, option)
|
||||
|
||||
cmd := r.getRedisIns(rc).GetCmdable()
|
||||
cmd := r.getRedisConn(rc).GetCmdable()
|
||||
res, err := cmd.LRem(context.TODO(), option.Key, int64(option.Count), option.Member).Result()
|
||||
biz.ErrIsNilAppendErr(err, "lrem失败: %s")
|
||||
rc.ResData = res
|
||||
@@ -45,7 +45,7 @@ func (r *Redis) SaveListValue(rc *req.Ctx) {
|
||||
listValue := new(form.ListValue)
|
||||
ginx.BindJsonAndValid(g, listValue)
|
||||
|
||||
cmd := r.getRedisIns(rc).GetCmdable()
|
||||
cmd := r.getRedisConn(rc).GetCmdable()
|
||||
|
||||
key := listValue.Key
|
||||
ctx := context.TODO()
|
||||
@@ -59,7 +59,7 @@ func (r *Redis) SetListValue(rc *req.Ctx) {
|
||||
listSetValue := new(form.ListSetValue)
|
||||
ginx.BindJsonAndValid(g, listSetValue)
|
||||
|
||||
ri := r.getRedisIns(rc)
|
||||
ri := r.getRedisConn(rc)
|
||||
|
||||
_, err := ri.GetCmdable().LSet(context.TODO(), listSetValue.Key, listSetValue.Index, listSetValue.Value).Result()
|
||||
biz.ErrIsNilAppendErr(err, "list set失败: %s")
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"mayfly-go/internal/redis/api/vo"
|
||||
"mayfly-go/internal/redis/application"
|
||||
"mayfly-go/internal/redis/domain/entity"
|
||||
"mayfly-go/internal/redis/rdm"
|
||||
tagapp "mayfly-go/internal/tag/application"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/ginx"
|
||||
@@ -86,7 +87,7 @@ func (r *Redis) DeleteRedis(rc *req.Ctx) {
|
||||
|
||||
func (r *Redis) RedisInfo(rc *req.Ctx) {
|
||||
g := rc.GinCtx
|
||||
ri, err := r.RedisApp.GetRedisInstance(uint64(ginx.PathParamInt(g, "id")), 0)
|
||||
ri, err := r.RedisApp.GetRedisConn(uint64(ginx.PathParamInt(g, "id")), 0)
|
||||
biz.ErrIsNil(err)
|
||||
|
||||
section := rc.GinCtx.Query("section")
|
||||
@@ -94,9 +95,9 @@ func (r *Redis) RedisInfo(rc *req.Ctx) {
|
||||
ctx := context.Background()
|
||||
var redisCli *redis.Client
|
||||
|
||||
if mode == "" || mode == entity.RedisModeStandalone || mode == entity.RedisModeSentinel {
|
||||
if mode == "" || mode == rdm.StandaloneMode || mode == rdm.SentinelMode {
|
||||
redisCli = ri.Cli
|
||||
} else if mode == entity.RedisModeCluster {
|
||||
} else if mode == rdm.ClusterMode {
|
||||
host := rc.GinCtx.Query("host")
|
||||
biz.NotEmpty(host, "集群模式host信息不能为空")
|
||||
clusterClient := ri.ClusterCli
|
||||
@@ -164,9 +165,9 @@ func (r *Redis) RedisInfo(rc *req.Ctx) {
|
||||
|
||||
func (r *Redis) ClusterInfo(rc *req.Ctx) {
|
||||
g := rc.GinCtx
|
||||
ri, err := r.RedisApp.GetRedisInstance(uint64(ginx.PathParamInt(g, "id")), 0)
|
||||
ri, err := r.RedisApp.GetRedisConn(uint64(ginx.PathParamInt(g, "id")), 0)
|
||||
biz.ErrIsNil(err)
|
||||
biz.IsEquals(ri.Info.Mode, entity.RedisModeCluster, "非集群模式")
|
||||
biz.IsEquals(ri.Info.Mode, rdm.ClusterMode, "非集群模式")
|
||||
info, _ := ri.ClusterCli.ClusterInfo(context.Background()).Result()
|
||||
nodesStr, _ := ri.ClusterCli.ClusterNodes(context.Background()).Result()
|
||||
|
||||
@@ -208,14 +209,14 @@ func (r *Redis) ClusterInfo(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
// 校验查询参数中的key为必填项,并返回redis实例
|
||||
func (r *Redis) checkKeyAndGetRedisIns(rc *req.Ctx) (*application.RedisInstance, string) {
|
||||
func (r *Redis) checkKeyAndGetRedisConn(rc *req.Ctx) (*rdm.RedisConn, string) {
|
||||
key := rc.GinCtx.Query("key")
|
||||
biz.NotEmpty(key, "key不能为空")
|
||||
return r.getRedisIns(rc), key
|
||||
return r.getRedisConn(rc), key
|
||||
}
|
||||
|
||||
func (r *Redis) getRedisIns(rc *req.Ctx) *application.RedisInstance {
|
||||
ri, err := r.RedisApp.GetRedisInstance(getIdAndDbNum(rc.GinCtx))
|
||||
func (r *Redis) getRedisConn(rc *req.Ctx) *rdm.RedisConn {
|
||||
ri, err := r.RedisApp.GetRedisConn(getIdAndDbNum(rc.GinCtx))
|
||||
biz.ErrIsNil(err)
|
||||
biz.ErrIsNilAppendErr(r.TagApp.CanAccess(rc.LoginAccount.Id, ri.Info.TagPath), "%s")
|
||||
return ri
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func (r *Redis) GetSetValue(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
res, err := ri.GetCmdable().SMembers(context.TODO(), key).Result()
|
||||
biz.ErrIsNilAppendErr(err, "获取set值失败: %s")
|
||||
rc.ResData = res
|
||||
@@ -22,7 +22,7 @@ func (r *Redis) SetSetValue(rc *req.Ctx) {
|
||||
keyvalue := new(form.SetValue)
|
||||
ginx.BindJsonAndValid(g, keyvalue)
|
||||
|
||||
cmd := r.getRedisIns(rc).GetCmdable()
|
||||
cmd := r.getRedisConn(rc).GetCmdable()
|
||||
|
||||
key := keyvalue.Key
|
||||
// 简单处理->先删除,后新增
|
||||
@@ -35,7 +35,7 @@ func (r *Redis) SetSetValue(rc *req.Ctx) {
|
||||
}
|
||||
|
||||
func (r *Redis) Scard(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
|
||||
total, err := ri.GetCmdable().SCard(context.TODO(), key).Result()
|
||||
biz.ErrIsNilAppendErr(err, "scard失败: %s")
|
||||
@@ -47,7 +47,7 @@ func (r *Redis) Sscan(rc *req.Ctx) {
|
||||
scan := new(form.ScanForm)
|
||||
ginx.BindJsonAndValid(g, scan)
|
||||
|
||||
cmd := r.getRedisIns(rc).GetCmdable()
|
||||
cmd := r.getRedisConn(rc).GetCmdable()
|
||||
keys, cursor, err := cmd.SScan(context.TODO(), scan.Key, scan.Cursor, scan.Match, scan.Count).Result()
|
||||
biz.ErrIsNilAppendErr(err, "sscan失败: %s")
|
||||
rc.ResData = collx.M{
|
||||
@@ -60,7 +60,7 @@ func (r *Redis) Sadd(rc *req.Ctx) {
|
||||
g := rc.GinCtx
|
||||
option := new(form.SmemberOption)
|
||||
ginx.BindJsonAndValid(g, option)
|
||||
cmd := r.getRedisIns(rc).GetCmdable()
|
||||
cmd := r.getRedisConn(rc).GetCmdable()
|
||||
|
||||
res, err := cmd.SAdd(context.TODO(), option.Key, option.Member).Result()
|
||||
biz.ErrIsNilAppendErr(err, "sadd失败: %s")
|
||||
@@ -72,7 +72,7 @@ func (r *Redis) Srem(rc *req.Ctx) {
|
||||
option := new(form.SmemberOption)
|
||||
ginx.BindJsonAndValid(g, option)
|
||||
|
||||
cmd := r.getRedisIns(rc).GetCmdable()
|
||||
cmd := r.getRedisConn(rc).GetCmdable()
|
||||
res, err := cmd.SRem(context.TODO(), option.Key, option.Member).Result()
|
||||
biz.ErrIsNilAppendErr(err, "srem失败: %s")
|
||||
rc.ResData = res
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func (r *Redis) GetStringValue(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
str, err := ri.GetCmdable().Get(context.TODO(), key).Result()
|
||||
biz.ErrIsNilAppendErr(err, "获取字符串值失败: %s")
|
||||
rc.ResData = str
|
||||
@@ -22,7 +22,7 @@ func (r *Redis) SetStringValue(rc *req.Ctx) {
|
||||
keyValue := new(form.StringValue)
|
||||
ginx.BindJsonAndValid(g, keyValue)
|
||||
|
||||
ri := r.getRedisIns(rc)
|
||||
ri := r.getRedisConn(rc)
|
||||
cmd := ri.GetCmdable()
|
||||
rc.ReqParam = collx.Kvs("redis", ri.Info, "string", keyValue)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
)
|
||||
|
||||
func (r *Redis) ZCard(rc *req.Ctx) {
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
|
||||
total, err := ri.GetCmdable().ZCard(context.TODO(), key).Result()
|
||||
biz.ErrIsNilAppendErr(err, "zcard失败: %s")
|
||||
@@ -21,7 +21,7 @@ func (r *Redis) ZCard(rc *req.Ctx) {
|
||||
|
||||
func (r *Redis) ZScan(rc *req.Ctx) {
|
||||
g := rc.GinCtx
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
|
||||
cursor := uint64(ginx.QueryInt(g, "cursor", 0))
|
||||
match := ginx.Query(g, "match", "*")
|
||||
@@ -37,7 +37,7 @@ func (r *Redis) ZScan(rc *req.Ctx) {
|
||||
|
||||
func (r *Redis) ZRevRange(rc *req.Ctx) {
|
||||
g := rc.GinCtx
|
||||
ri, key := r.checkKeyAndGetRedisIns(rc)
|
||||
ri, key := r.checkKeyAndGetRedisConn(rc)
|
||||
start := ginx.QueryInt(g, "start", 0)
|
||||
stop := ginx.QueryInt(g, "stop", 50)
|
||||
|
||||
@@ -51,7 +51,7 @@ func (r *Redis) ZRem(rc *req.Ctx) {
|
||||
option := new(form.SmemberOption)
|
||||
ginx.BindJsonAndValid(g, option)
|
||||
|
||||
cmd := r.getRedisIns(rc).GetCmdable()
|
||||
cmd := r.getRedisConn(rc).GetCmdable()
|
||||
res, err := cmd.ZRem(context.TODO(), option.Key, option.Member).Result()
|
||||
biz.ErrIsNilAppendErr(err, "zrem失败: %s")
|
||||
rc.ResData = res
|
||||
@@ -62,7 +62,7 @@ func (r *Redis) ZAdd(rc *req.Ctx) {
|
||||
option := new(form.ZAddOption)
|
||||
ginx.BindJsonAndValid(g, option)
|
||||
|
||||
cmd := r.getRedisIns(rc).GetCmdable()
|
||||
cmd := r.getRedisConn(rc).GetCmdable()
|
||||
zm := redis.Z{
|
||||
Score: option.Score,
|
||||
Member: option.Member,
|
||||
|
||||
@@ -1,26 +1,14 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"mayfly-go/internal/common/consts"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/internal/machine/infrastructure/machine"
|
||||
"mayfly-go/internal/redis/domain/entity"
|
||||
"mayfly-go/internal/redis/domain/repository"
|
||||
"mayfly-go/internal/redis/rdm"
|
||||
"mayfly-go/pkg/base"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/utils/netx"
|
||||
"mayfly-go/pkg/utils/structx"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type Redis interface {
|
||||
@@ -31,7 +19,7 @@ type Redis interface {
|
||||
|
||||
Count(condition *entity.RedisQuery) int64
|
||||
|
||||
Save(entity *entity.Redis) error
|
||||
Save(re *entity.Redis) error
|
||||
|
||||
// 删除数据库信息
|
||||
Delete(id uint64) error
|
||||
@@ -39,7 +27,10 @@ type Redis interface {
|
||||
// 获取数据库连接实例
|
||||
// id: 数据库实例id
|
||||
// db: 库号
|
||||
GetRedisInstance(id uint64, db int) (*RedisInstance, error)
|
||||
GetRedisConn(id uint64, db int) (*rdm.RedisConn, error)
|
||||
|
||||
// 测试连接
|
||||
TestConn(re *entity.Redis) error
|
||||
}
|
||||
|
||||
func newRedisApp(redisRepo repository.Redis) Redis {
|
||||
@@ -52,7 +43,7 @@ type redisAppImpl struct {
|
||||
base.AppImpl[*entity.Redis, repository.Redis]
|
||||
}
|
||||
|
||||
// 分页获取机器脚本信息列表
|
||||
// 分页获取redis列表
|
||||
func (r *redisAppImpl) GetPageList(condition *entity.RedisQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
|
||||
return r.GetRepo().GetRedisList(condition, pageParam, toEntity, orderBy...)
|
||||
}
|
||||
@@ -64,7 +55,7 @@ func (r *redisAppImpl) Count(condition *entity.RedisQuery) int64 {
|
||||
func (r *redisAppImpl) Save(re *entity.Redis) error {
|
||||
// ’修改信息且密码不为空‘ or ‘新增’需要测试是否可连接
|
||||
if (re.Id != 0 && re.Password != "") || re.Id == 0 {
|
||||
if err := TestRedisConnection(re); err != nil {
|
||||
if err := r.TestConn(re); err != nil {
|
||||
return errorx.NewBiz("Redis连接失败: %s", err.Error())
|
||||
}
|
||||
}
|
||||
@@ -92,7 +83,7 @@ func (r *redisAppImpl) Save(re *entity.Redis) error {
|
||||
if oldRedis.Db != re.Db || oldRedis.SshTunnelMachineId != re.SshTunnelMachineId {
|
||||
for _, dbStr := range strings.Split(oldRedis.Db, ",") {
|
||||
db, _ := strconv.Atoi(dbStr)
|
||||
CloseRedis(re.Id, db)
|
||||
rdm.CloseConn(re.Id, db)
|
||||
}
|
||||
}
|
||||
re.PwdEncrypt()
|
||||
@@ -108,253 +99,35 @@ func (r *redisAppImpl) Delete(id uint64) error {
|
||||
// 如果存在连接,则关闭所有库连接信息
|
||||
for _, dbStr := range strings.Split(re.Db, ",") {
|
||||
db, _ := strconv.Atoi(dbStr)
|
||||
CloseRedis(re.Id, db)
|
||||
rdm.CloseConn(re.Id, db)
|
||||
}
|
||||
return r.DeleteById(id)
|
||||
}
|
||||
|
||||
// 获取数据库连接实例
|
||||
func (r *redisAppImpl) GetRedisInstance(id uint64, db int) (*RedisInstance, error) {
|
||||
// Id不为0,则为需要缓存
|
||||
needCache := id != 0
|
||||
if needCache {
|
||||
load, ok := redisCache.Get(getRedisCacheKey(id, db))
|
||||
if ok {
|
||||
return load.(*RedisInstance), nil
|
||||
}
|
||||
}
|
||||
// 缓存不存在,则回调获取redis信息
|
||||
re, err := r.GetById(new(entity.Redis), id)
|
||||
if err != nil {
|
||||
return nil, errorx.NewBiz("redis信息不存在")
|
||||
}
|
||||
re.PwdDecrypt()
|
||||
|
||||
redisMode := re.Mode
|
||||
var ri *RedisInstance
|
||||
if redisMode == "" || redisMode == entity.RedisModeStandalone {
|
||||
ri = getRedisCient(re, db)
|
||||
// 测试连接
|
||||
_, e := ri.Cli.Ping(context.Background()).Result()
|
||||
if e != nil {
|
||||
ri.Close()
|
||||
return nil, errorx.NewBiz("redis连接失败: %s", e.Error())
|
||||
}
|
||||
} else if redisMode == entity.RedisModeCluster {
|
||||
ri = getRedisClusterClient(re)
|
||||
// 测试连接
|
||||
_, e := ri.ClusterCli.Ping(context.Background()).Result()
|
||||
if e != nil {
|
||||
ri.Close()
|
||||
return nil, errorx.NewBiz("redis集群连接失败: %s", e.Error())
|
||||
}
|
||||
} else if redisMode == entity.RedisModeSentinel {
|
||||
ri = getRedisSentinelCient(re, db)
|
||||
// 测试连接
|
||||
_, e := ri.Cli.Ping(context.Background()).Result()
|
||||
if e != nil {
|
||||
ri.Close()
|
||||
return nil, errorx.NewBiz("redis sentinel连接失败: %s", e.Error())
|
||||
}
|
||||
}
|
||||
|
||||
logx.Infof("连接redis: %s/%d", re.Host, db)
|
||||
if needCache {
|
||||
redisCache.Put(getRedisCacheKey(id, db), ri)
|
||||
}
|
||||
return ri, nil
|
||||
}
|
||||
|
||||
// 生成redis连接缓存key
|
||||
func getRedisCacheKey(id uint64, db int) string {
|
||||
return fmt.Sprintf("%d/%d", id, db)
|
||||
}
|
||||
|
||||
func toRedisInfo(re *entity.Redis, db int) *RedisInfo {
|
||||
redisInfo := new(RedisInfo)
|
||||
structx.Copy(redisInfo, re)
|
||||
redisInfo.Db = db
|
||||
return redisInfo
|
||||
}
|
||||
|
||||
func getRedisCient(re *entity.Redis, db int) *RedisInstance {
|
||||
ri := &RedisInstance{Id: getRedisCacheKey(re.Id, db), Info: toRedisInfo(re, db)}
|
||||
|
||||
redisOptions := &redis.Options{
|
||||
Addr: re.Host,
|
||||
Username: re.Username,
|
||||
Password: re.Password, // no password set
|
||||
DB: db, // use default DB
|
||||
DialTimeout: 8 * time.Second,
|
||||
ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines.
|
||||
WriteTimeout: -1,
|
||||
}
|
||||
if re.SshTunnelMachineId > 0 {
|
||||
redisOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
|
||||
}
|
||||
ri.Cli = redis.NewClient(redisOptions)
|
||||
return ri
|
||||
}
|
||||
|
||||
func getRedisClusterClient(re *entity.Redis) *RedisInstance {
|
||||
ri := &RedisInstance{Id: getRedisCacheKey(re.Id, 0), Info: toRedisInfo(re, 0)}
|
||||
|
||||
redisClusterOptions := &redis.ClusterOptions{
|
||||
Addrs: strings.Split(re.Host, ","),
|
||||
Username: re.Username,
|
||||
Password: re.Password,
|
||||
DialTimeout: 8 * time.Second,
|
||||
}
|
||||
if re.SshTunnelMachineId > 0 {
|
||||
redisClusterOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
|
||||
}
|
||||
ri.ClusterCli = redis.NewClusterClient(redisClusterOptions)
|
||||
return ri
|
||||
}
|
||||
|
||||
func getRedisSentinelCient(re *entity.Redis, db int) *RedisInstance {
|
||||
ri := &RedisInstance{Id: getRedisCacheKey(re.Id, db), Info: toRedisInfo(re, db)}
|
||||
// sentinel模式host为 masterName=host:port,host:port
|
||||
masterNameAndHosts := strings.Split(re.Host, "=")
|
||||
sentinelOptions := &redis.FailoverOptions{
|
||||
MasterName: masterNameAndHosts[0],
|
||||
SentinelAddrs: strings.Split(masterNameAndHosts[1], ","),
|
||||
Username: re.Username,
|
||||
Password: re.Password, // no password set
|
||||
SentinelPassword: re.Password, // 哨兵节点密码需与redis节点密码一致
|
||||
DB: db, // use default DB
|
||||
DialTimeout: 8 * time.Second,
|
||||
ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines.
|
||||
WriteTimeout: -1,
|
||||
}
|
||||
if re.SshTunnelMachineId > 0 {
|
||||
sentinelOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
|
||||
}
|
||||
ri.Cli = redis.NewFailoverClient(sentinelOptions)
|
||||
return ri
|
||||
}
|
||||
|
||||
func getRedisDialer(machineId int) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||
sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(machineId)
|
||||
func (r *redisAppImpl) GetRedisConn(id uint64, db int) (*rdm.RedisConn, error) {
|
||||
return rdm.GetRedisConn(id, db, func() (*rdm.RedisInfo, error) {
|
||||
// 缓存不存在,则回调获取redis信息
|
||||
re, err := r.GetById(new(entity.Redis), id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errorx.NewBiz("redis信息不存在")
|
||||
}
|
||||
re.PwdDecrypt()
|
||||
|
||||
if sshConn, err := sshTunnel.GetDialConn(network, addr); err == nil {
|
||||
// 将ssh conn包装,否则redis内部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
|
||||
return &netx.WrapSshConn{Conn: sshConn}, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// redis客户端连接缓存,指定时间内没有访问则会被关闭
|
||||
var redisCache = cache.NewTimedCache(consts.RedisConnExpireTime, 5*time.Second).
|
||||
WithUpdateAccessTime(true).
|
||||
OnEvicted(func(key any, value any) {
|
||||
logx.Info(fmt.Sprintf("删除redis连接缓存 id = %s", key))
|
||||
value.(*RedisInstance).Close()
|
||||
})
|
||||
|
||||
// 移除redis连接缓存并关闭redis连接
|
||||
func CloseRedis(id uint64, db int) {
|
||||
redisCache.Delete(getRedisCacheKey(id, db))
|
||||
}
|
||||
|
||||
func init() {
|
||||
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
|
||||
// 遍历所有redis连接实例,若存在redis实例使用该ssh隧道机器,则返回true,表示还在使用中...
|
||||
items := redisCache.Items()
|
||||
for _, v := range items {
|
||||
if v.Value.(*RedisInstance).Info.SshTunnelMachineId == machineId {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
return re.ToRedisInfo(db), nil
|
||||
})
|
||||
}
|
||||
|
||||
func TestRedisConnection(re *entity.Redis) error {
|
||||
var cmd redis.Cmdable
|
||||
// 取第一个库测试连接即可
|
||||
dbStr := strings.Split(re.Db, ",")[0]
|
||||
db, _ := strconv.Atoi(dbStr)
|
||||
if re.Mode == "" || re.Mode == entity.RedisModeStandalone {
|
||||
rcli := getRedisCient(re, db)
|
||||
defer rcli.Close()
|
||||
cmd = rcli.Cli
|
||||
} else if re.Mode == entity.RedisModeCluster {
|
||||
ccli := getRedisClusterClient(re)
|
||||
defer ccli.Close()
|
||||
cmd = ccli.ClusterCli
|
||||
} else if re.Mode == entity.RedisModeSentinel {
|
||||
rcli := getRedisSentinelCient(re, db)
|
||||
defer rcli.Close()
|
||||
cmd = rcli.Cli
|
||||
func (r *redisAppImpl) TestConn(re *entity.Redis) error {
|
||||
db := 0
|
||||
if re.Db != "" {
|
||||
db, _ = strconv.Atoi(strings.Split(re.Db, ",")[0])
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
_, e := cmd.Ping(context.Background()).Result()
|
||||
return e
|
||||
}
|
||||
|
||||
type RedisInfo struct {
|
||||
Id uint64 `json:"id"`
|
||||
Host string `json:"host"`
|
||||
Db int `json:"db"` // 库号
|
||||
TagPath string `json:"tagPath"`
|
||||
Mode string `json:"-"`
|
||||
Name string `json:"-"`
|
||||
|
||||
SshTunnelMachineId int `json:"-"`
|
||||
}
|
||||
|
||||
// 获取记录日志的描述
|
||||
func (r *RedisInfo) GetLogDesc() string {
|
||||
return fmt.Sprintf("Redis[id=%d, tag=%s, host=%s, db=%d]", r.Id, r.TagPath, r.Host, r.Db)
|
||||
}
|
||||
|
||||
// redis实例
|
||||
type RedisInstance struct {
|
||||
Id string
|
||||
Info *RedisInfo
|
||||
|
||||
Cli *redis.Client
|
||||
ClusterCli *redis.ClusterClient
|
||||
}
|
||||
|
||||
// 获取命令执行接口的具体实现
|
||||
func (r *RedisInstance) GetCmdable() redis.Cmdable {
|
||||
redisMode := r.Info.Mode
|
||||
if redisMode == "" || redisMode == entity.RedisModeStandalone || r.Info.Mode == entity.RedisModeSentinel {
|
||||
return r.Cli
|
||||
}
|
||||
if redisMode == entity.RedisModeCluster {
|
||||
return r.ClusterCli
|
||||
rc, err := re.ToRedisInfo(db).Conn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rc.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisInstance) Scan(cursor uint64, match string, count int64) ([]string, uint64, error) {
|
||||
return r.GetCmdable().Scan(context.Background(), cursor, match, count).Result()
|
||||
}
|
||||
|
||||
func (r *RedisInstance) Close() {
|
||||
mode := r.Info.Mode
|
||||
if mode == entity.RedisModeStandalone || mode == entity.RedisModeSentinel {
|
||||
if err := r.Cli.Close(); err != nil {
|
||||
logx.Errorf("关闭redis单机实例[%s]连接失败: %s", r.Id, err.Error())
|
||||
}
|
||||
r.Cli = nil
|
||||
}
|
||||
if mode == entity.RedisModeCluster {
|
||||
if err := r.ClusterCli.Close(); err != nil {
|
||||
logx.Errorf("关闭redis集群实例[%s]连接失败: %s", r.Id, err.Error())
|
||||
}
|
||||
r.ClusterCli = nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,9 @@ package entity
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/common/utils"
|
||||
"mayfly-go/internal/redis/rdm"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/utils/structx"
|
||||
)
|
||||
|
||||
type Redis struct {
|
||||
@@ -20,12 +22,6 @@ type Redis struct {
|
||||
TagPath string
|
||||
}
|
||||
|
||||
const (
|
||||
RedisModeStandalone = "standalone"
|
||||
RedisModeCluster = "cluster"
|
||||
RedisModeSentinel = "sentinel"
|
||||
)
|
||||
|
||||
func (r *Redis) PwdEncrypt() {
|
||||
// 密码替换为加密后的密码
|
||||
r.Password = utils.PwdAesEncrypt(r.Password)
|
||||
@@ -35,3 +31,11 @@ func (r *Redis) PwdDecrypt() {
|
||||
// 密码替换为解密后的密码
|
||||
r.Password = utils.PwdAesDecrypt(r.Password)
|
||||
}
|
||||
|
||||
// 转换为redisInfo进行连接
|
||||
func (re *Redis) ToRedisInfo(db int) *rdm.RedisInfo {
|
||||
redisInfo := new(rdm.RedisInfo)
|
||||
structx.Copy(redisInfo, re)
|
||||
redisInfo.Db = db
|
||||
return redisInfo
|
||||
}
|
||||
|
||||
51
server/internal/redis/rdm/conn.go
Normal file
51
server/internal/redis/rdm/conn.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package rdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"mayfly-go/pkg/logx"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// redis连接信息
|
||||
type RedisConn struct {
|
||||
Id string
|
||||
Info *RedisInfo
|
||||
|
||||
Cli *redis.Client
|
||||
ClusterCli *redis.ClusterClient
|
||||
}
|
||||
|
||||
// 获取命令执行接口的具体实现
|
||||
func (r *RedisConn) GetCmdable() redis.Cmdable {
|
||||
redisMode := r.Info.Mode
|
||||
if redisMode == "" || redisMode == StandaloneMode || r.Info.Mode == SentinelMode {
|
||||
return r.Cli
|
||||
}
|
||||
if redisMode == ClusterMode {
|
||||
return r.ClusterCli
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RedisConn) Scan(cursor uint64, match string, count int64) ([]string, uint64, error) {
|
||||
return r.GetCmdable().Scan(context.Background(), cursor, match, count).Result()
|
||||
}
|
||||
|
||||
func (r *RedisConn) Close() {
|
||||
mode := r.Info.Mode
|
||||
if mode == StandaloneMode || mode == SentinelMode {
|
||||
if err := r.Cli.Close(); err != nil {
|
||||
logx.Errorf("关闭redis单机实例[%s]连接失败: %s", r.Id, err.Error())
|
||||
}
|
||||
r.Cli = nil
|
||||
return
|
||||
}
|
||||
|
||||
if mode == ClusterMode {
|
||||
if err := r.ClusterCli.Close(); err != nil {
|
||||
logx.Errorf("关闭redis集群实例[%s]连接失败: %s", r.Id, err.Error())
|
||||
}
|
||||
r.ClusterCli = nil
|
||||
}
|
||||
}
|
||||
73
server/internal/redis/rdm/conn_cache.go
Normal file
73
server/internal/redis/rdm/conn_cache.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package rdm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/common/consts"
|
||||
"mayfly-go/internal/machine/infrastructure/machine"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/logx"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// redis客户端连接缓存,指定时间内没有访问则会被关闭
|
||||
var connCache = cache.NewTimedCache(consts.RedisConnExpireTime, 5*time.Second).
|
||||
WithUpdateAccessTime(true).
|
||||
OnEvicted(func(key any, value any) {
|
||||
logx.Info(fmt.Sprintf("删除redis连接缓存 id = %s", key))
|
||||
value.(*RedisConn).Close()
|
||||
})
|
||||
|
||||
func init() {
|
||||
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
|
||||
// 遍历所有redis连接实例,若存在redis实例使用该ssh隧道机器,则返回true,表示还在使用中...
|
||||
items := connCache.Items()
|
||||
for _, v := range items {
|
||||
if v.Value.(*RedisConn).Info.SshTunnelMachineId == machineId {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
})
|
||||
}
|
||||
|
||||
var mutex sync.Mutex
|
||||
|
||||
// 从缓存中获取redis连接信息, 若缓存中不存在则会使用回调函数获取redisInfo进行连接并缓存
|
||||
func GetRedisConn(redisId uint64, db int, getRedisInfo func() (*RedisInfo, error)) (*RedisConn, error) {
|
||||
connId := getConnId(redisId, db)
|
||||
|
||||
// connId不为空,则为需要缓存
|
||||
needCache := connId != ""
|
||||
if needCache {
|
||||
load, ok := connCache.Get(connId)
|
||||
if ok {
|
||||
return load.(*RedisConn), nil
|
||||
}
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
// 若缓存中不存在,则从回调函数中获取RedisInfo
|
||||
ri, err := getRedisInfo()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
rc, err := ri.Conn()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if needCache {
|
||||
connCache.Put(connId, rc)
|
||||
}
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
// 移除redis连接缓存并关闭redis连接
|
||||
func CloseConn(id uint64, db int) {
|
||||
connCache.Delete(getConnId(id, db))
|
||||
}
|
||||
161
server/internal/redis/rdm/info.go
Normal file
161
server/internal/redis/rdm/info.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package rdm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/netx"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
type RedisMode string
|
||||
|
||||
const (
|
||||
StandaloneMode RedisMode = "standalone"
|
||||
ClusterMode RedisMode = "cluster"
|
||||
SentinelMode RedisMode = "sentinel"
|
||||
)
|
||||
|
||||
type RedisInfo struct {
|
||||
Id uint64 `json:"id"`
|
||||
|
||||
Host string `json:"host"`
|
||||
Db int `json:"db"` // 库号
|
||||
Mode RedisMode `json:"-"`
|
||||
Username string `json:"-"`
|
||||
Password string `json:"-"`
|
||||
|
||||
Name string `json:"-"`
|
||||
TagPath string `json:"tagPath"`
|
||||
SshTunnelMachineId int `json:"-"`
|
||||
}
|
||||
|
||||
func (r *RedisInfo) Conn() (*RedisConn, error) {
|
||||
redisMode := r.Mode
|
||||
if redisMode == StandaloneMode {
|
||||
return r.connStandalone()
|
||||
}
|
||||
if redisMode == ClusterMode {
|
||||
return r.connCluster()
|
||||
}
|
||||
if redisMode == SentinelMode {
|
||||
return r.connSentinel()
|
||||
}
|
||||
|
||||
return nil, errorx.NewBiz("redis mode error")
|
||||
}
|
||||
|
||||
func (re *RedisInfo) connStandalone() (*RedisConn, error) {
|
||||
redisOptions := &redis.Options{
|
||||
Addr: re.Host,
|
||||
Username: re.Username,
|
||||
Password: re.Password, // no password set
|
||||
DB: re.Db, // use default DB
|
||||
DialTimeout: 8 * time.Second,
|
||||
ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines.
|
||||
WriteTimeout: -1,
|
||||
}
|
||||
if re.SshTunnelMachineId > 0 {
|
||||
redisOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
|
||||
}
|
||||
|
||||
cli := redis.NewClient(redisOptions)
|
||||
_, e := cli.Ping(context.Background()).Result()
|
||||
if e != nil {
|
||||
cli.Close()
|
||||
return nil, errorx.NewBiz("redis连接失败: %s", e.Error())
|
||||
}
|
||||
|
||||
logx.Infof("连接redis standalone: %s/%d", re.Host, re.Db)
|
||||
|
||||
rc := &RedisConn{Id: getConnId(re.Id, re.Db), Info: re}
|
||||
rc.Cli = cli
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
func (re *RedisInfo) connCluster() (*RedisConn, error) {
|
||||
redisClusterOptions := &redis.ClusterOptions{
|
||||
Addrs: strings.Split(re.Host, ","),
|
||||
Username: re.Username,
|
||||
Password: re.Password,
|
||||
DialTimeout: 8 * time.Second,
|
||||
}
|
||||
if re.SshTunnelMachineId > 0 {
|
||||
redisClusterOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
|
||||
}
|
||||
cli := redis.NewClusterClient(redisClusterOptions)
|
||||
// 测试连接
|
||||
_, e := cli.Ping(context.Background()).Result()
|
||||
if e != nil {
|
||||
cli.Close()
|
||||
return nil, errorx.NewBiz("redis集群连接失败: %s", e.Error())
|
||||
}
|
||||
|
||||
logx.Infof("连接redis cluster: %s/%d", re.Host, re.Db)
|
||||
|
||||
rc := &RedisConn{Id: getConnId(re.Id, re.Db), Info: re}
|
||||
rc.ClusterCli = cli
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
func (re *RedisInfo) connSentinel() (*RedisConn, error) {
|
||||
// sentinel模式host为 masterName=host:port,host:port
|
||||
masterNameAndHosts := strings.Split(re.Host, "=")
|
||||
sentinelOptions := &redis.FailoverOptions{
|
||||
MasterName: masterNameAndHosts[0],
|
||||
SentinelAddrs: strings.Split(masterNameAndHosts[1], ","),
|
||||
Username: re.Username,
|
||||
Password: re.Password, // no password set
|
||||
SentinelPassword: re.Password, // 哨兵节点密码需与redis节点密码一致
|
||||
DB: re.Db, // use default DB
|
||||
DialTimeout: 8 * time.Second,
|
||||
ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines.
|
||||
WriteTimeout: -1,
|
||||
}
|
||||
if re.SshTunnelMachineId > 0 {
|
||||
sentinelOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
|
||||
}
|
||||
cli := redis.NewFailoverClient(sentinelOptions)
|
||||
|
||||
_, e := cli.Ping(context.Background()).Result()
|
||||
if e != nil {
|
||||
cli.Close()
|
||||
return nil, errorx.NewBiz("redis sentinel连接失败: %s", e.Error())
|
||||
}
|
||||
|
||||
logx.Infof("连接redis sentinel: %s/%d", re.Host, re.Db)
|
||||
|
||||
rc := &RedisConn{Id: getConnId(re.Id, re.Db), Info: re}
|
||||
rc.Cli = cli
|
||||
return rc, nil
|
||||
}
|
||||
|
||||
func getRedisDialer(machineId int) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return func(_ context.Context, network, addr string) (net.Conn, error) {
|
||||
sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(machineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if sshConn, err := sshTunnel.GetDialConn(network, addr); err == nil {
|
||||
// 将ssh conn包装,否则redis内部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
|
||||
return &netx.WrapSshConn{Conn: sshConn}, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 生成redis连接id
|
||||
func getConnId(id uint64, db int) string {
|
||||
if id == 0 {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%d/%d", id, db)
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func T2022() *gormigrate.Migration {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.AutoMigrate(&entity2.Instance{}); err != nil {
|
||||
if err := tx.AutoMigrate(&entity2.DbInstance{}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := tx.AutoMigrate(&entity2.Db{}); err != nil {
|
||||
|
||||
@@ -10,9 +10,15 @@ type App[T any] interface {
|
||||
// 新增一个实体
|
||||
Insert(e T) error
|
||||
|
||||
// 使用指定gorm db执行,主要用于事务执行
|
||||
InsertWithDb(db *gorm.DB, e T) error
|
||||
|
||||
// 批量新增实体
|
||||
BatchInsert(models []T) error
|
||||
|
||||
// 使用指定gorm db执行,主要用于事务执行
|
||||
BatchInsertWithDb(db *gorm.DB, es []T) error
|
||||
|
||||
// 根据实体id更新实体信息
|
||||
UpdateById(e T) error
|
||||
|
||||
@@ -70,11 +76,21 @@ func (ai *AppImpl[T, R]) Insert(e T) error {
|
||||
return ai.GetRepo().Insert(e)
|
||||
}
|
||||
|
||||
// 使用指定gorm db执行,主要用于事务执行
|
||||
func (ai *AppImpl[T, R]) InsertWithDb(db *gorm.DB, e T) error {
|
||||
return ai.GetRepo().InsertWithDb(db, e)
|
||||
}
|
||||
|
||||
// 批量新增实体 (单纯新增,不做其他业务逻辑处理)
|
||||
func (ai *AppImpl[T, R]) BatchInsert(es []T) error {
|
||||
return ai.GetRepo().BatchInsert(es)
|
||||
}
|
||||
|
||||
// 使用指定gorm db执行,主要用于事务执行
|
||||
func (ai *AppImpl[T, R]) BatchInsertWithDb(db *gorm.DB, es []T) error {
|
||||
return ai.GetRepo().BatchInsertWithDb(db, es)
|
||||
}
|
||||
|
||||
// 根据实体id更新实体信息 (单纯更新,不做其他业务逻辑处理)
|
||||
func (ai *AppImpl[T, R]) UpdateById(e T) error {
|
||||
return ai.GetRepo().UpdateById(e)
|
||||
|
||||
@@ -109,6 +109,7 @@ func ErrorRes(g *gin.Context, err any) {
|
||||
g.JSON(http.StatusOK, model.ServerError())
|
||||
default:
|
||||
logx.Errorf("未知错误: %v", t)
|
||||
g.JSON(http.StatusOK, model.ServerError())
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user