refactor: db/redis/mongo连接代码包独立

This commit is contained in:
meilin.huang
2023-10-27 17:41:45 +08:00
parent a1303b52eb
commit 12f63ef3dd
45 changed files with 1112 additions and 950 deletions

View File

@@ -11,7 +11,7 @@
"dependencies": {
"@element-plus/icons-vue": "^2.1.0",
"asciinema-player": "^3.6.2",
"axios": "^1.5.1",
"axios": "^1.6.0",
"countup.js": "^2.7.0",
"cropperjs": "^1.5.11",
"echarts": "^5.4.3",

View File

@@ -19,7 +19,7 @@
:disabled="form.id !== undefined"
remote
:remote-method="getInstances"
@change="getAllDatabase"
@change="changeInstance"
v-model="form.instanceId"
placeholder="请输入实例名称搜索并选择实例"
filterable
@@ -163,6 +163,11 @@ watch(props, (newValue: any) => {
}
});
const changeInstance = () => {
state.databaseList = [];
getAllDatabase();
};
/**
* 改变表单中的数据库字段,方便表单错误提示。如全部删光,可提示请添加数据库
*/
@@ -171,8 +176,6 @@ const changeDatabase = () => {
};
const getAllDatabase = async () => {
// 清空数据库列表,可能已经有选择库了
state.databaseList = [];
if (state.form.instanceId > 0) {
state.allDatabases = await dbApi.getAllDatabase.request({ instanceId: state.form.instanceId });
}

View File

@@ -628,10 +628,10 @@ asynckit@^0.4.0:
resolved "https://registry.npmmirror.com/asynckit/-/asynckit-0.4.0.tgz"
integrity sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==
axios@^1.5.1:
version "1.5.1"
resolved "https://registry.npmmirror.com/axios/-/axios-1.5.1.tgz#11fbaa11fc35f431193a9564109c88c1f27b585f"
integrity sha512-Q28iYCWzNHjAm+yEAot5QaAMxhMghWLFVf7rRdwhUI+c2jix2DUXjAHXVi+s1ibs3mjPO/cCgbA++3BjD0vP/A==
axios@^1.6.0:
version "1.6.0"
resolved "https://registry.npmmirror.com/axios/-/axios-1.6.0.tgz#f1e5292f26b2fd5c2e66876adc5b06cdbd7d2102"
integrity sha512-EZ1DYihju9pwVB+jg67ogm+Tmqc6JmhamRN6I4Zt8DfZu5lbcQGw3ozH9lFejSJgs/ibaef3A9PMXPLeefFGJg==
dependencies:
follow-redirects "^1.15.0"
form-data "^4.0.0"

View File

@@ -6,7 +6,7 @@ const (
AdminId = 1
MachineConnExpireTime = 60 * time.Minute
DbConnExpireTime = 45 * time.Minute
DbConnExpireTime = 120 * time.Minute
RedisConnExpireTime = 30 * time.Minute
MongoConnExpireTime = 30 * time.Minute

View File

@@ -6,6 +6,7 @@ import (
"mayfly-go/internal/db/api/form"
"mayfly-go/internal/db/api/vo"
"mayfly-go/internal/db/application"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity"
msgapp "mayfly-go/internal/msg/application"
msgdto "mayfly-go/internal/msg/application/dto"
@@ -82,14 +83,14 @@ func (d *Db) DeleteDb(rc *req.Ctx) {
}
}
func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection {
dc, err := d.DbApp.GetDbConnection(getDbId(g), getDbName(g))
func (d *Db) getDbConn(g *gin.Context) *dbm.DbConn {
dc, err := d.DbApp.GetDbConn(getDbId(g), getDbName(g))
biz.ErrIsNil(err)
return dc
}
func (d *Db) TableInfos(rc *req.Ctx) {
res, err := d.getDbConnection(rc.GinCtx).GetMeta().GetTableInfos()
res, err := d.getDbConn(rc.GinCtx).GetMeta().GetTableInfos()
biz.ErrIsNilAppendErr(err, "获取表信息失败: %s")
rc.ResData = res
}
@@ -97,7 +98,7 @@ func (d *Db) TableInfos(rc *req.Ctx) {
func (d *Db) TableIndex(rc *req.Ctx) {
tn := rc.GinCtx.Query("tableName")
biz.NotEmpty(tn, "tableName不能为空")
res, err := d.getDbConnection(rc.GinCtx).GetMeta().GetTableIndex(tn)
res, err := d.getDbConn(rc.GinCtx).GetMeta().GetTableIndex(tn)
biz.ErrIsNilAppendErr(err, "获取表索引信息失败: %s")
rc.ResData = res
}
@@ -105,7 +106,7 @@ func (d *Db) TableIndex(rc *req.Ctx) {
func (d *Db) GetCreateTableDdl(rc *req.Ctx) {
tn := rc.GinCtx.Query("tableName")
biz.NotEmpty(tn, "tableName不能为空")
res, err := d.getDbConnection(rc.GinCtx).GetMeta().GetCreateTableDdl(tn)
res, err := d.getDbConn(rc.GinCtx).GetMeta().GetCreateTableDdl(tn)
biz.ErrIsNilAppendErr(err, "获取表ddl失败: %s")
rc.ResData = res
}
@@ -116,7 +117,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
ginx.BindJsonAndValid(g, form)
dbId := getDbId(g)
dbConn, err := d.DbApp.GetDbConnection(dbId, form.Db)
dbConn, err := d.DbApp.GetDbConn(dbId, form.Db)
biz.ErrIsNil(err)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
@@ -188,7 +189,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
dbName := getDbName(g)
clientId := g.Query("clientId")
dbConn, err := d.DbApp.GetDbConnection(dbId, dbName)
dbConn, err := d.DbApp.GetDbConn(dbId, dbName)
biz.ErrIsNil(err)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc())
@@ -256,7 +257,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
if !ok {
logx.Warnf("sql解析失败: %s", sql)
}
dbConn, err = d.DbApp.GetDbConnection(dbId, stmtUse.DBName.String())
dbConn, err = d.DbApp.GetDbConn(dbId, stmtUse.DBName.String())
biz.ErrIsNil(err)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
execReq.DbConn = dbConn
@@ -334,7 +335,7 @@ func (d *Db) DumpSql(rc *req.Ctx) {
}
func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool) {
dbConn, err := d.DbApp.GetDbConnection(dbId, dbName)
dbConn, err := d.DbApp.GetDbConn(dbId, dbName)
biz.ErrIsNil(err)
writer.WriteString("\n-- ----------------------------")
writer.WriteString("\n-- 导出平台: mayfly-go")
@@ -397,7 +398,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
// @router /api/db/:dbId/t-metadata [get]
func (d *Db) TableMA(rc *req.Ctx) {
dbi := d.getDbConnection(rc.GinCtx)
dbi := d.getDbConn(rc.GinCtx)
res, err := dbi.GetMeta().GetTables()
biz.ErrIsNilAppendErr(err, "获取表基础信息失败: %s")
rc.ResData = res
@@ -409,7 +410,7 @@ func (d *Db) ColumnMA(rc *req.Ctx) {
tn := g.Query("tableName")
biz.NotEmpty(tn, "tableName不能为空")
dbi := d.getDbConnection(rc.GinCtx)
dbi := d.getDbConn(rc.GinCtx)
res, err := dbi.GetMeta().GetColumns(tn)
biz.ErrIsNilAppendErr(err, "获取数据库列信息失败: %s")
rc.ResData = res
@@ -417,7 +418,7 @@ func (d *Db) ColumnMA(rc *req.Ctx) {
// @router /api/db/:dbId/hint-tables [get]
func (d *Db) HintTables(rc *req.Ctx) {
dbi := d.getDbConnection(rc.GinCtx)
dbi := d.getDbConn(rc.GinCtx)
dm := dbi.GetMeta()
// 获取所有表

View File

@@ -33,7 +33,7 @@ func (d *Instance) Instances(rc *req.Ctx) {
// @router /api/instances [post]
func (d *Instance) SaveInstance(rc *req.Ctx) {
form := &form.InstanceForm{}
instance := ginx.BindJsonAndCopyTo[*entity.Instance](rc.GinCtx, form, new(entity.Instance))
instance := ginx.BindJsonAndCopyTo[*entity.DbInstance](rc.GinCtx, form, new(entity.DbInstance))
// 密码解密,并使用解密后的赋值
originPwd, err := cryptox.DefaultRsaDecrypt(form.Password, true)
@@ -52,7 +52,7 @@ func (d *Instance) SaveInstance(rc *req.Ctx) {
// @router /api/instances/:instance [GET]
func (d *Instance) GetInstance(rc *req.Ctx) {
dbId := getInstanceId(rc.GinCtx)
dbEntity, err := d.InstanceApp.GetById(new(entity.Instance), dbId)
dbEntity, err := d.InstanceApp.GetById(new(entity.DbInstance), dbId)
biz.ErrIsNil(err, "获取数据库实例错误")
dbEntity.Password = ""
rc.ResData = dbEntity
@@ -62,7 +62,7 @@ func (d *Instance) GetInstance(rc *req.Ctx) {
// @router /api/instances/:instance/pwd [GET]
func (d *Instance) GetInstancePwd(rc *req.Ctx) {
instanceId := getInstanceId(rc.GinCtx)
instanceEntity, err := d.InstanceApp.GetById(new(entity.Instance), instanceId, "Password")
instanceEntity, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "Password")
biz.ErrIsNil(err, "获取数据库实例错误")
instanceEntity.PwdDecrypt()
rc.ResData = instanceEntity.Password
@@ -80,7 +80,7 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) {
biz.ErrIsNilAppendErr(err, "string类型转换为int异常: %s")
instanceId := uint64(value)
if d.DbApp.Count(&entity.DbQuery{InstanceId: instanceId}) != 0 {
instance, err := d.InstanceApp.GetById(new(entity.Instance), instanceId, "name")
instance, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "name")
biz.ErrIsNil(err, "获取数据库实例错误数据库实例ID为: %d", instance.Id)
biz.IsTrue(false, "不能删除数据库实例【%s】请先删除其关联的数据库资源。", instance.Name)
}
@@ -91,7 +91,7 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) {
// 获取数据库实例的所有数据库名
func (d *Instance) GetDatabaseNames(rc *req.Ctx) {
instanceId := getInstanceId(rc.GinCtx)
instance, err := d.InstanceApp.GetById(new(entity.Instance), instanceId, "Password")
instance, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "Password")
biz.ErrIsNil(err, "获取数据库实例错误")
instance.PwdDecrypt()
res, err := d.InstanceApp.GetDatabases(instance)

View File

@@ -1,23 +1,15 @@
package application
import (
"database/sql"
"fmt"
"mayfly-go/internal/common/consts"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/internal/machine/infrastructure/machine"
"mayfly-go/pkg/base"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/collx"
"reflect"
"strconv"
"mayfly-go/pkg/utils/structx"
"strings"
"sync"
"time"
)
type Db interface {
@@ -34,9 +26,9 @@ type Db interface {
Delete(id uint64) error
// 获取数据库连接实例
// @param id 数据库实例id
// @param db 数据库
GetDbConnection(dbId uint64, dbName string) (*DbConnection, error)
// @param id 数据库id
// @param dbName 数据库
GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error)
}
func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql, dbInstanceApp Instance) Db {
@@ -96,7 +88,7 @@ func (d *dbAppImpl) Save(dbEntity *entity.Db) error {
for _, v := range delDb {
// 关闭数据库连接
CloseDb(dbEntity.Id, v)
dbm.CloseDb(dbEntity.Id, v)
// 删除该库关联的所有sql记录
d.dbSqlRepo.DeleteByCond(&entity.DbSql{DbId: dbId, Db: v})
}
@@ -112,316 +104,39 @@ func (d *dbAppImpl) Delete(id uint64) error {
dbs := strings.Split(db.Database, " ")
for _, v := range dbs {
// 关闭连接
CloseDb(id, v)
dbm.CloseDb(id, v)
}
// 删除该库下用户保存的所有sql信息
d.dbSqlRepo.DeleteByCond(&entity.DbSql{DbId: id})
return d.DeleteById(id)
}
var mutex sync.Mutex
func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) (*DbConnection, error) {
cacheKey := GetDbCacheKey(dbId, dbName)
// Id不为0则为需要缓存
needCache := dbId != 0
if needCache {
load, ok := dbCache.Get(cacheKey)
if ok {
return load.(*DbConnection), nil
func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
return dbm.GetDbConn(dbId, dbName, func() (*dbm.DbInfo, error) {
db, err := d.GetById(new(entity.Db), dbId)
if err != nil {
return nil, errorx.NewBiz("数据库信息不存在")
}
}
mutex.Lock()
defer mutex.Unlock()
db, err := d.GetById(new(entity.Db), dbId)
if err != nil {
return nil, errorx.NewBiz("数据库信息不存在")
}
if !strings.Contains(" "+db.Database+" ", " "+dbName+" ") {
return nil, errorx.NewBiz("未配置数据库【%s】的操作权限", dbName)
}
instance, err := d.dbInstanceApp.GetById(new(entity.Instance), db.InstanceId)
if err != nil {
return nil, errorx.NewBiz("数据库实例不存在")
}
// 密码解密
instance.PwdDecrypt()
dbInfo := NewDbInfo(db, instance)
dbInfo.Database = dbName
dbi := &DbConnection{Id: cacheKey, Info: dbInfo}
conn, err := getInstanceConn(instance, dbName)
if err != nil {
dbi.Close()
logx.Errorf("连接db失败: %s:%d/%s", dbInfo.Host, dbInfo.Port, dbName)
return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
}
// 最大连接周期超过时间的连接就close
// conn.SetConnMaxLifetime(100 * time.Second)
// 设置最大连接数
conn.SetMaxOpenConns(5)
// 设置闲置连接数
conn.SetMaxIdleConns(1)
dbi.db = conn
logx.Infof("连接db: %s:%d/%s", dbInfo.Host, dbInfo.Port, dbName)
if needCache {
dbCache.Put(cacheKey, dbi)
}
return dbi, nil
}
//---------------------------------------- db instance ------------------------------------
type DbInfo struct {
Id uint64
Name string
Type entity.DbType // 类型mysql oracle等
Host string
Port int
Network string
Username string
TagPath string
Database string
SshTunnelMachineId int
}
func NewDbInfo(db *entity.Db, instance *entity.Instance) *DbInfo {
return &DbInfo{
Id: db.Id,
Name: db.Name,
Type: instance.Type,
Host: instance.Host,
Port: instance.Port,
Username: instance.Username,
TagPath: db.TagPath,
SshTunnelMachineId: instance.SshTunnelMachineId,
}
}
// 获取记录日志的描述
func (d *DbInfo) GetLogDesc() string {
return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.TagPath, d.Name, d.Host, d.Port, d.Database)
}
// db实例连接信息
type DbConnection struct {
Id string
Info *DbInfo
db *sql.DB
}
// 执行查询语句
// 依次返回 列名数组结果map错误
func (d *DbConnection) SelectData(execSql string) ([]string, []map[string]any, error) {
return selectDataByDb(d.db, execSql)
}
// 将查询结果映射至struct可具体参考sqlx库
func (d *DbConnection) SelectData2Struct(execSql string, dest any) error {
return select2StructByDb(d.db, execSql, dest)
}
// WalkTableRecord 遍历表记录
func (d *DbConnection) WalkTableRecord(selectSql string, walk func(record map[string]any, columns []string)) error {
return walkTableRecord(d.db, selectSql, walk)
}
// 执行 update, insert, delete建表等sql
// 返回影响条数和错误
func (d *DbConnection) Exec(sql string) (int64, error) {
res, err := d.db.Exec(sql)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
// 获取数据库元信息实现接口
func (d *DbConnection) GetMeta() DbMetadata {
switch d.Info.Type {
case entity.DbTypeMysql:
return &MysqlMetadata{di: d}
case entity.DbTypePostgres:
return &PgsqlMetadata{di: d}
default:
panic(fmt.Sprintf("invalid database type: %s", d.Info.Type))
}
}
// 关闭连接
func (d *DbConnection) Close() {
if d.db != nil {
if err := d.db.Close(); err != nil {
logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
if !strings.Contains(" "+db.Database+" ", " "+dbName+" ") {
return nil, errorx.NewBiz("未配置数据库【%s】的操作权限", dbName)
}
d.db = nil
}
}
//------------------------------------------------------------------------------
// 客户端连接缓存,指定时间内没有访问则会被关闭, key为数据库实例id:数据库
var dbCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second).
WithUpdateAccessTime(true).
OnEvicted(func(key any, value any) {
logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key))
value.(*DbConnection).Close()
})
func init() {
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
// 遍历所有db连接实例若存在db实例使用该ssh隧道机器则返回true表示还在使用中...
items := dbCache.Items()
for _, v := range items {
if v.Value.(*DbConnection).Info.SshTunnelMachineId == machineId {
return true
}
instance, err := d.dbInstanceApp.GetById(new(entity.DbInstance), db.InstanceId)
if err != nil {
return nil, errorx.NewBiz("数据库实例不存在")
}
return false
// 密码解密
instance.PwdDecrypt()
return toDbInfo(instance, dbId, dbName, db.TagPath), nil
})
}
func GetDbCacheKey(dbId uint64, db string) string {
return fmt.Sprintf("%d:%s", dbId, db)
}
func selectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]any, error) {
// 列名用于前端表头名称按照数据库与查询字段顺序显示
var colNames []string
result := make([]map[string]any, 0, 16)
err := walkTableRecord(db, selectSql, func(record map[string]any, columns []string) {
result = append(result, record)
if colNames == nil {
colNames = make([]string, len(columns))
copy(colNames, columns)
}
})
if err != nil {
return nil, nil, err
}
return colNames, result, nil
}
func walkTableRecord(db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) error {
rows, err := db.Query(selectSql)
if err != nil {
return err
}
// rows对象一定要close掉如果出错不关掉则会很迅速的达到设置最大连接数
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
defer func() {
if rows != nil {
rows.Close()
}
}()
colTypes, err := rows.ColumnTypes()
if err != nil {
return err
}
lenCols := len(colTypes)
// 列名用于前端表头名称按照数据库与查询字段顺序显示
colNames := make([]string, lenCols)
// 这里表示一行填充数据
scans := make([]any, lenCols)
// 这里表示一行所有列的值,用[]byte表示
values := make([][]byte, lenCols)
for k, colType := range colTypes {
colNames[k] = colType.Name()
// 这里scans引用values把数据填充到[]byte里
scans[k] = &values[k]
}
for rows.Next() {
// 不Scan也会导致等待该链接实际处于未工作的状态然后也会导致连接数迅速达到最大
if err := rows.Scan(scans...); err != nil {
return err
}
// 每行数据
rowData := make(map[string]any, lenCols)
// 把values中的数据复制到row中
for i, v := range values {
rowData[colTypes[i].Name()] = valueConvert(v, colTypes[i])
}
walk(rowData, colNames)
}
return nil
}
// 将查询的值转为对应列类型的实际值,不全部转为字符串
func valueConvert(data []byte, colType *sql.ColumnType) any {
if data == nil {
return nil
}
// 列的数据库类型名
colDatabaseTypeName := strings.ToLower(colType.DatabaseTypeName())
// 如果类型是bit则直接返回第一个字节即可
if strings.Contains(colDatabaseTypeName, "bit") {
return data[0]
}
// 这里把[]byte数据转成string
stringV := string(data)
if stringV == "" {
return ""
}
colScanType := strings.ToLower(colType.ScanType().Name())
if strings.Contains(colScanType, "int") {
// 如果长度超过16位则返回字符串因为前端js长度大于16会丢失精度
if len(stringV) > 16 {
return stringV
}
intV, _ := strconv.Atoi(stringV)
switch colType.ScanType().Kind() {
case reflect.Int8:
return int8(intV)
case reflect.Uint8:
return uint8(intV)
case reflect.Int64:
return int64(intV)
case reflect.Uint64:
return uint64(intV)
case reflect.Uint:
return uint(intV)
default:
return intV
}
}
if strings.Contains(colScanType, "float") || strings.Contains(colDatabaseTypeName, "decimal") {
floatV, _ := strconv.ParseFloat(stringV, 64)
return floatV
}
return stringV
}
// 查询数据结果映射至struct。可参考sqlx库
func select2StructByDb(db *sql.DB, selectSql string, dest any) error {
rows, err := db.Query(selectSql)
if err != nil {
return err
}
// rows对象一定要close掉如果出错不关掉则会很迅速的达到设置最大连接数
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
defer func() {
if rows != nil {
rows.Close()
}
}()
return scanAll(rows, dest, false)
}
// 删除db缓存并关闭该数据库所有连接
func CloseDb(dbId uint64, db string) {
dbCache.Delete(GetDbCacheKey(dbId, db))
func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath string) *dbm.DbInfo {
di := new(dbm.DbInfo)
di.Id = dbId
di.Database = database
di.TagPath = tagPath
structx.Copy(di, instance)
return di
}

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/errorx"
@@ -20,7 +21,7 @@ type DbSqlExecReq struct {
Sql string
Remark string
LoginAccount *model.LoginAccount
DbConn *DbConnection
DbConn *dbm.DbConn
}
type DbSqlExecRes struct {
@@ -269,7 +270,7 @@ func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *ent
return doExec(execSqlReq.Sql, execSqlReq.DbConn)
}
func doExec(sql string, dbConn *DbConnection) (*DbSqlExecRes, error) {
func doExec(sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
rowsAffected, err := dbConn.Exec(sql)
execRes := "success"
if err != nil {

View File

@@ -1,8 +1,6 @@
package application
import (
"database/sql"
"fmt"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
@@ -11,20 +9,20 @@ import (
)
type Instance interface {
base.App[*entity.Instance]
base.App[*entity.DbInstance]
// GetPageList 分页获取数据库实例
GetPageList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)
Count(condition *entity.InstanceQuery) int64
Save(instanceEntity *entity.Instance) error
Save(instanceEntity *entity.DbInstance) error
// Delete 删除数据库信息
Delete(id uint64) error
// GetDatabases 获取数据库实例的所有数据库列表
GetDatabases(entity *entity.Instance) ([]string, error)
GetDatabases(entity *entity.DbInstance) ([]string, error)
}
func newInstanceApp(instanceRepo repository.Instance) Instance {
@@ -34,7 +32,7 @@ func newInstanceApp(instanceRepo repository.Instance) Instance {
}
type instanceAppImpl struct {
base.AppImpl[*entity.Instance, repository.Instance]
base.AppImpl[*entity.DbInstance, repository.Instance]
}
// GetPageList 分页获取数据库实例
@@ -46,19 +44,21 @@ func (app *instanceAppImpl) Count(condition *entity.InstanceQuery) int64 {
return app.CountByCond(condition)
}
func (app *instanceAppImpl) Save(instanceEntity *entity.Instance) error {
func (app *instanceAppImpl) Save(instanceEntity *entity.DbInstance) error {
// 默认tcp连接
instanceEntity.Network = instanceEntity.GetNetwork()
// 测试连接
if instanceEntity.Password != "" {
if err := testConnection(instanceEntity); err != nil {
return errorx.NewBiz("数据库连接失败: %s", err.Error())
dbConn, err := toDbInfo(instanceEntity, 0, "", "").Conn()
if err != nil {
return err
}
defer dbConn.Close()
}
// 查找是否存在该库
oldInstance := &entity.Instance{Host: instanceEntity.Host, Port: instanceEntity.Port, Username: instanceEntity.Username}
oldInstance := &entity.DbInstance{Host: instanceEntity.Host, Port: instanceEntity.Port, Username: instanceEntity.Username}
if instanceEntity.SshTunnelMachineId > 0 {
oldInstance.SshTunnelMachineId = instanceEntity.SshTunnelMachineId
}
@@ -87,55 +87,19 @@ func (app *instanceAppImpl) Delete(id uint64) error {
return app.DeleteById(id)
}
// getInstanceConn 获取数据库连接数据库实例
func getInstanceConn(instance *entity.Instance, db string) (*sql.DB, error) {
var conn *sql.DB
var err error
switch instance.Type {
case entity.DbTypeMysql:
conn, err = getMysqlDB(instance, db)
case entity.DbTypePostgres:
conn, err = getPgsqlDB(instance, db)
default:
panic(fmt.Sprintf("invalid database type: %s", instance.Type))
}
if err != nil {
return nil, err
}
err = conn.Ping()
if err != nil {
conn.Close()
return nil, err
}
return conn, nil
}
func testConnection(d *entity.Instance) error {
// 不指定数据库名称
conn, err := getInstanceConn(d, "")
if err != nil {
return err
}
defer conn.Close()
return nil
}
func (app *instanceAppImpl) GetDatabases(ed *entity.Instance) ([]string, error) {
func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance) ([]string, error) {
ed.Network = ed.GetNetwork()
databases := make([]string, 0)
var dbConn *sql.DB
metaDb := ed.Type.MetaDbName()
getDatabasesSql := ed.Type.StmtSelectDbName()
dbConn, err := getInstanceConn(ed, metaDb)
dbConn, err := toDbInfo(ed, 0, metaDb, "").Conn()
if err != nil {
return nil, errorx.NewBiz("数据库连接失败: %s", err.Error())
return nil, err
}
defer dbConn.Close()
_, res, err := selectDataByDb(dbConn, getDatabasesSql)
_, res, err := dbConn.SelectData(getDatabasesSql)
if err != nil {
return nil, err
}

View File

@@ -0,0 +1,195 @@
package dbm
import (
"database/sql"
"fmt"
"mayfly-go/pkg/logx"
"reflect"
"strconv"
"strings"
)
// db实例连接信息
type DbConn struct {
Id string
Info *DbInfo
db *sql.DB
}
// 执行查询语句
// 依次返回 列名数组结果map错误
func (d *DbConn) SelectData(execSql string) ([]string, []map[string]any, error) {
return selectDataByDb(d.db, execSql)
}
// 将查询结果映射至struct可具体参考sqlx库
func (d *DbConn) SelectData2Struct(execSql string, dest any) error {
return select2StructByDb(d.db, execSql, dest)
}
// WalkTableRecord 遍历表记录
func (d *DbConn) WalkTableRecord(selectSql string, walk func(record map[string]any, columns []string)) error {
return walkTableRecord(d.db, selectSql, walk)
}
// 执行 update, insert, delete建表等sql
// 返回影响条数和错误
func (d *DbConn) Exec(sql string) (int64, error) {
res, err := d.db.Exec(sql)
if err != nil {
return 0, err
}
return res.RowsAffected()
}
// 获取数据库元信息实现接口
func (d *DbConn) GetMeta() DbMetadata {
switch d.Info.Type {
case DbTypeMysql:
return &MysqlMetadata{dc: d}
case DbTypePostgres:
return &PgsqlMetadata{dc: d}
default:
panic(fmt.Sprintf("invalid database type: %s", d.Info.Type))
}
}
// 关闭连接
func (d *DbConn) Close() {
if d.db != nil {
if err := d.db.Close(); err != nil {
logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
}
d.db = nil
}
}
func selectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]any, error) {
// 列名用于前端表头名称按照数据库与查询字段顺序显示
var colNames []string
result := make([]map[string]any, 0, 16)
err := walkTableRecord(db, selectSql, func(record map[string]any, columns []string) {
result = append(result, record)
if colNames == nil {
colNames = make([]string, len(columns))
copy(colNames, columns)
}
})
if err != nil {
return nil, nil, err
}
return colNames, result, nil
}
func walkTableRecord(db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) error {
rows, err := db.Query(selectSql)
if err != nil {
return err
}
// rows对象一定要close掉如果出错不关掉则会很迅速的达到设置最大连接数
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
defer func() {
if rows != nil {
rows.Close()
}
}()
colTypes, err := rows.ColumnTypes()
if err != nil {
return err
}
lenCols := len(colTypes)
// 列名用于前端表头名称按照数据库与查询字段顺序显示
colNames := make([]string, lenCols)
// 这里表示一行填充数据
scans := make([]any, lenCols)
// 这里表示一行所有列的值,用[]byte表示
values := make([][]byte, lenCols)
for k, colType := range colTypes {
colNames[k] = colType.Name()
// 这里scans引用values把数据填充到[]byte里
scans[k] = &values[k]
}
for rows.Next() {
// 不Scan也会导致等待该链接实际处于未工作的状态然后也会导致连接数迅速达到最大
if err := rows.Scan(scans...); err != nil {
return err
}
// 每行数据
rowData := make(map[string]any, lenCols)
// 把values中的数据复制到row中
for i, v := range values {
rowData[colTypes[i].Name()] = valueConvert(v, colTypes[i])
}
walk(rowData, colNames)
}
return nil
}
// 将查询的值转为对应列类型的实际值,不全部转为字符串
func valueConvert(data []byte, colType *sql.ColumnType) any {
if data == nil {
return nil
}
// 列的数据库类型名
colDatabaseTypeName := strings.ToLower(colType.DatabaseTypeName())
// 如果类型是bit则直接返回第一个字节即可
if strings.Contains(colDatabaseTypeName, "bit") {
return data[0]
}
// 这里把[]byte数据转成string
stringV := string(data)
if stringV == "" {
return ""
}
colScanType := strings.ToLower(colType.ScanType().Name())
if strings.Contains(colScanType, "int") {
// 如果长度超过16位则返回字符串因为前端js长度大于16会丢失精度
if len(stringV) > 16 {
return stringV
}
intV, _ := strconv.Atoi(stringV)
switch colType.ScanType().Kind() {
case reflect.Int8:
return int8(intV)
case reflect.Uint8:
return uint8(intV)
case reflect.Int64:
return int64(intV)
case reflect.Uint64:
return uint64(intV)
case reflect.Uint:
return uint(intV)
default:
return intV
}
}
if strings.Contains(colScanType, "float") || strings.Contains(colDatabaseTypeName, "decimal") {
floatV, _ := strconv.ParseFloat(stringV, 64)
return floatV
}
return stringV
}
// 查询数据结果映射至struct。可参考sqlx库
func select2StructByDb(db *sql.DB, selectSql string, dest any) error {
rows, err := db.Query(selectSql)
if err != nil {
return err
}
// rows对象一定要close掉如果出错不关掉则会很迅速的达到设置最大连接数
// 后面的链接过来直接报错或拒绝,实际上也没有起效果
defer func() {
if rows != nil {
rows.Close()
}
}()
return scanAll(rows, dest, false)
}

View File

@@ -0,0 +1,73 @@
package dbm
import (
"fmt"
"mayfly-go/internal/common/consts"
"mayfly-go/internal/machine/infrastructure/machine"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/logx"
"sync"
"time"
)
// 客户端连接缓存,指定时间内没有访问则会被关闭, key为数据库连接id
var connCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second).
WithUpdateAccessTime(true).
OnEvicted(func(key any, value any) {
logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key))
value.(*DbConn).Close()
})
func init() {
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
// 遍历所有db连接实例若存在db实例使用该ssh隧道机器则返回true表示还在使用中...
items := connCache.Items()
for _, v := range items {
if v.Value.(*DbConn).Info.SshTunnelMachineId == machineId {
return true
}
}
return false
})
}
var mutex sync.Mutex
// 从缓存中获取数据库连接信息若缓存中不存在则会使用回调函数获取dbInfo进行连接并缓存
func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error)) (*DbConn, error) {
connId := GetDbConnId(dbId, database)
// connId不为空则为需要缓存
needCache := connId != ""
if needCache {
load, ok := connCache.Get(connId)
if ok {
return load.(*DbConn), nil
}
}
mutex.Lock()
defer mutex.Unlock()
// 若缓存中不存在则从回调函数中获取DbInfo
dbInfo, err := getDbInfo()
if err != nil {
return nil, err
}
// 连接数据库
dbConn, err := dbInfo.Conn()
if err != nil {
return nil, err
}
if needCache {
connCache.Put(connId, dbConn)
}
return dbConn, nil
}
// 删除db缓存并关闭该数据库所有连接
func CloseDb(dbId uint64, db string) {
connCache.Delete(GetDbConnId(dbId, db))
}

View File

@@ -1,10 +1,11 @@
package entity
package dbm
import (
"fmt"
"strings"
"github.com/kanzihuang/vitess/go/vt/sqlparser"
"github.com/lib/pq"
"strings"
)
type DbType string

View File

@@ -1,8 +1,9 @@
package entity
package dbm
import (
"github.com/stretchr/testify/require"
"testing"
"github.com/stretchr/testify/require"
)
func Test_QuoteLiteral(t *testing.T) {

View File

@@ -0,0 +1,79 @@
package dbm
import (
"database/sql"
"fmt"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
)
type DbInfo struct {
Id uint64
Name string
Type DbType // 类型mysql postgres等
Host string
Port int
Network string
Username string
Password string
Params string
Database string
TagPath string
SshTunnelMachineId int
}
// 获取记录日志的描述
func (d *DbInfo) GetLogDesc() string {
return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.TagPath, d.Name, d.Host, d.Port, d.Database)
}
// 连接数据库
func (dbInfo *DbInfo) Conn() (*DbConn, error) {
var conn *sql.DB
var err error
database := dbInfo.Database
switch dbInfo.Type {
case DbTypeMysql:
conn, err = getMysqlDB(dbInfo)
case DbTypePostgres:
conn, err = getPgsqlDB(dbInfo)
default:
return nil, errorx.NewBiz("invalid database type: %s", dbInfo.Type)
}
if err != nil {
logx.Errorf("连接db失败: %s:%d/%s", dbInfo.Host, dbInfo.Port, database)
return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
}
err = conn.Ping()
if err != nil {
logx.Errorf("db ping失败: %s:%d/%s", dbInfo.Host, dbInfo.Port, database)
return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
}
dbc := &DbConn{Id: GetDbConnId(dbInfo.Id, database), Info: dbInfo}
// 最大连接周期超过时间的连接就close
// conn.SetConnMaxLifetime(100 * time.Second)
// 设置最大连接数
conn.SetMaxOpenConns(5)
// 设置闲置连接数
conn.SetMaxIdleConns(1)
dbc.db = conn
logx.Infof("连接db: %s:%d/%s", dbInfo.Host, dbInfo.Port, database)
return dbc, nil
}
// 获取连接id
func GetDbConnId(dbId uint64, db string) string {
if dbId == 0 {
return ""
}
return fmt.Sprintf("%d:%s", dbId, db)
}

View File

@@ -1,4 +1,4 @@
package application
package dbm
import (
"embed"

View File

@@ -1,10 +1,9 @@
package application
package dbm
import (
"context"
"database/sql"
"fmt"
"mayfly-go/internal/db/domain/entity"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/utils/anyx"
@@ -13,7 +12,7 @@ import (
"github.com/go-sql-driver/mysql"
)
func getMysqlDB(d *entity.Instance, db string) (*sql.DB, error) {
func getMysqlDB(d *DbInfo) (*sql.DB, error) {
// SSH Conect
if d.SshTunnelMachineId > 0 {
sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
@@ -25,7 +24,7 @@ func getMysqlDB(d *entity.Instance, db string) (*sql.DB, error) {
})
}
// 设置dataSourceName -> 更多参数参考https://github.com/go-sql-driver/mysql#dsn-data-source-name
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db)
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
if d.Params != "" {
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
}
@@ -42,12 +41,12 @@ const (
)
type MysqlMetadata struct {
di *DbConnection
dc *DbConn
}
// 获取表基础元信息, 如表名等
func (mm *MysqlMetadata) GetTables() ([]Table, error) {
_, res, err := mm.di.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_MA_KEY))
_, res, err := mm.dc.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_MA_KEY))
if err != nil {
return nil, err
}
@@ -72,7 +71,7 @@ func (mm *MysqlMetadata) GetColumns(tableNames ...string) ([]Column, error) {
tableName = tableName + "'" + tableNames[i] + "'"
}
_, res, err := mm.di.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
_, res, err := mm.dc.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
if err != nil {
return nil, err
}
@@ -113,7 +112,7 @@ func (mm *MysqlMetadata) GetPrimaryKey(tablename string) (string, error) {
// 获取表信息比GetTableMetedatas获取更详细的表信息
func (mm *MysqlMetadata) GetTableInfos() ([]Table, error) {
_, res, err := mm.di.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
_, res, err := mm.dc.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
if err != nil {
return nil, err
}
@@ -134,7 +133,7 @@ func (mm *MysqlMetadata) GetTableInfos() ([]Table, error) {
// 获取表索引信息
func (mm *MysqlMetadata) GetTableIndex(tableName string) ([]Index, error) {
_, res, err := mm.di.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName))
_, res, err := mm.dc.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName))
if err != nil {
return nil, err
}
@@ -171,7 +170,7 @@ func (mm *MysqlMetadata) GetTableIndex(tableName string) ([]Index, error) {
// 获取建表ddl
func (mm *MysqlMetadata) GetCreateTableDdl(tableName string) (string, error) {
_, res, err := mm.di.SelectData(fmt.Sprintf("show create table `%s` ", tableName))
_, res, err := mm.dc.SelectData(fmt.Sprintf("show create table `%s` ", tableName))
if err != nil {
return "", err
}
@@ -179,9 +178,9 @@ func (mm *MysqlMetadata) GetCreateTableDdl(tableName string) (string, error) {
}
func (mm *MysqlMetadata) GetTableRecord(tableName string, pageNum, pageSize int) ([]string, []map[string]any, error) {
return mm.di.SelectData(fmt.Sprintf("SELECT * FROM %s LIMIT %d, %d", tableName, (pageNum-1)*pageSize, pageSize))
return mm.dc.SelectData(fmt.Sprintf("SELECT * FROM %s LIMIT %d, %d", tableName, (pageNum-1)*pageSize, pageSize))
}
func (mm *MysqlMetadata) WalkTableRecord(tableName string, walk func(record map[string]any, columns []string)) error {
return mm.di.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
return mm.dc.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
}

View File

@@ -1,10 +1,9 @@
package application
package dbm
import (
"database/sql"
"database/sql/driver"
"fmt"
"mayfly-go/internal/db/domain/entity"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/utils/anyx"
@@ -16,7 +15,7 @@ import (
"github.com/lib/pq"
)
func getPgsqlDB(d *entity.Instance, db string) (*sql.DB, error) {
func getPgsqlDB(d *DbInfo) (*sql.DB, error) {
driverName := string(d.Type)
// SSH Conect
if d.SshTunnelMachineId > 0 {
@@ -28,6 +27,7 @@ func getPgsqlDB(d *entity.Instance, db string) (*sql.DB, error) {
sql.Drivers()
}
db := d.Database
var dbParam string
if db != "" {
dbParam = "dbname=" + db
@@ -75,12 +75,12 @@ const (
)
type PgsqlMetadata struct {
di *DbConnection
dc *DbConn
}
// 获取表基础元信息, 如表名等
func (pm *PgsqlMetadata) GetTables() ([]Table, error) {
_, res, err := pm.di.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_MA_KEY))
_, res, err := pm.dc.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_MA_KEY))
if err != nil {
return nil, err
}
@@ -105,7 +105,7 @@ func (pm *PgsqlMetadata) GetColumns(tableNames ...string) ([]Column, error) {
tableName = tableName + "'" + tableNames[i] + "'"
}
_, res, err := pm.di.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
_, res, err := pm.dc.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
if err != nil {
return nil, err
}
@@ -144,7 +144,7 @@ func (pm *PgsqlMetadata) GetPrimaryKey(tablename string) (string, error) {
// 获取表信息比GetTables获取更详细的表信息
func (pm *PgsqlMetadata) GetTableInfos() ([]Table, error) {
_, res, err := pm.di.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
_, res, err := pm.dc.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
if err != nil {
return nil, err
}
@@ -165,7 +165,7 @@ func (pm *PgsqlMetadata) GetTableInfos() ([]Table, error) {
// 获取表索引信息
func (pm *PgsqlMetadata) GetTableIndex(tableName string) ([]Index, error) {
_, res, err := pm.di.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
_, res, err := pm.dc.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
if err != nil {
return nil, err
}
@@ -186,16 +186,16 @@ func (pm *PgsqlMetadata) GetTableIndex(tableName string) ([]Index, error) {
// 获取建表ddl
func (pm *PgsqlMetadata) GetCreateTableDdl(tableName string) (string, error) {
_, err := pm.di.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
_, err := pm.dc.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
if err != nil {
return "", err
}
_, schemaRes, _ := pm.di.SelectData("select current_schema() as schema")
_, schemaRes, _ := pm.dc.SelectData("select current_schema() as schema")
schemaName := schemaRes[0]["schema"].(string)
ddlSql := fmt.Sprintf("select showcreatetable('%s','%s') as sql", schemaName, tableName)
_, res, err := pm.di.SelectData(ddlSql)
_, res, err := pm.dc.SelectData(ddlSql)
if err != nil {
return "", err
}
@@ -204,9 +204,9 @@ func (pm *PgsqlMetadata) GetCreateTableDdl(tableName string) (string, error) {
}
func (pm *PgsqlMetadata) GetTableRecord(tableName string, pageNum, pageSize int) ([]string, []map[string]any, error) {
return pm.di.SelectData(fmt.Sprintf("SELECT * FROM %s OFFSET %d LIMIT %d", tableName, (pageNum-1)*pageSize, pageSize))
return pm.dc.SelectData(fmt.Sprintf("SELECT * FROM %s OFFSET %d LIMIT %d", tableName, (pageNum-1)*pageSize, pageSize))
}
func (pm *PgsqlMetadata) WalkTableRecord(tableName string, walk func(record map[string]any, columns []string)) error {
return pm.di.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
return pm.dc.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk)
}

View File

@@ -1,4 +1,4 @@
package application
package dbm
import (
"database/sql"

View File

@@ -0,0 +1,50 @@
package entity
import (
"fmt"
"mayfly-go/internal/common/utils"
"mayfly-go/internal/db/dbm"
"mayfly-go/pkg/model"
)
type DbInstance struct {
model.Model
Name string `orm:"column(name)" json:"name"`
Type dbm.DbType `orm:"column(type)" json:"type"` // 类型mysql oracle等
Host string `orm:"column(host)" json:"host"`
Port int `orm:"column(port)" json:"port"`
Network string `orm:"column(network)" json:"network"`
Username string `orm:"column(username)" json:"username"`
Password string `orm:"column(password)" json:"-"`
Params string `orm:"column(params)" json:"params"`
Remark string `orm:"column(remark)" json:"remark"`
SshTunnelMachineId int `orm:"column(ssh_tunnel_machine_id)" json:"sshTunnelMachineId"` // ssh隧道机器id
}
func (d *DbInstance) TableName() string {
return "t_db_instance"
}
// 获取数据库连接网络, 若没有使用ssh隧道则直接返回。否则返回拼接的网络需要注册至指定dial
func (d *DbInstance) GetNetwork() string {
network := d.Network
if d.SshTunnelMachineId <= 0 {
if network == "" {
return "tcp"
} else {
return network
}
}
return fmt.Sprintf("%s+ssh:%d", d.Type, d.SshTunnelMachineId)
}
func (d *DbInstance) PwdEncrypt() {
// 密码替换为加密后的密码
d.Password = utils.PwdAesEncrypt(d.Password)
}
func (d *DbInstance) PwdDecrypt() {
// 密码替换为解密后的密码
d.Password = utils.PwdAesDecrypt(d.Password)
}

View File

@@ -1,49 +0,0 @@
package entity
import (
"fmt"
"mayfly-go/internal/common/utils"
"mayfly-go/pkg/model"
)
type Instance struct {
model.Model
Name string `orm:"column(name)" json:"name"`
Type DbType `orm:"column(type)" json:"type"` // 类型mysql oracle等
Host string `orm:"column(host)" json:"host"`
Port int `orm:"column(port)" json:"port"`
Network string `orm:"column(network)" json:"network"`
Username string `orm:"column(username)" json:"username"`
Password string `orm:"column(password)" json:"-"`
Params string `orm:"column(params)" json:"params"`
Remark string `orm:"column(remark)" json:"remark"`
SshTunnelMachineId int `orm:"column(ssh_tunnel_machine_id)" json:"sshTunnelMachineId"` // ssh隧道机器id
}
func (d *Instance) TableName() string {
return "t_db_instance"
}
// 获取数据库连接网络, 若没有使用ssh隧道则直接返回。否则返回拼接的网络需要注册至指定dial
func (d *Instance) GetNetwork() string {
network := d.Network
if d.SshTunnelMachineId <= 0 {
if network == "" {
return "tcp"
} else {
return network
}
}
return fmt.Sprintf("%s+ssh:%d", d.Type, d.SshTunnelMachineId)
}
func (d *Instance) PwdEncrypt() {
// 密码替换为加密后的密码
d.Password = utils.PwdAesEncrypt(d.Password)
}
func (d *Instance) PwdDecrypt() {
// 密码替换为解密后的密码
d.Password = utils.PwdAesDecrypt(d.Password)
}

View File

@@ -7,7 +7,7 @@ import (
)
type Instance interface {
base.Repo[*entity.Instance]
base.Repo[*entity.DbInstance]
// 分页获取数据库实例信息列表
GetInstanceList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error)

View File

@@ -9,16 +9,16 @@ import (
)
type instanceRepoImpl struct {
base.RepoImpl[*entity.Instance]
base.RepoImpl[*entity.DbInstance]
}
func newInstanceRepo() repository.Instance {
return &instanceRepoImpl{base.RepoImpl[*entity.Instance]{M: new(entity.Instance)}}
return &instanceRepoImpl{base.RepoImpl[*entity.DbInstance]{M: new(entity.DbInstance)}}
}
// 分页获取数据库信息列表
func (d *instanceRepoImpl) GetInstanceList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
qd := gormx.NewQuery(new(entity.Instance)).
qd := gormx.NewQuery(new(entity.DbInstance)).
Eq("id", condition.Id).
Eq("host", condition.Host).
Like("name", condition.Name)

View File

@@ -74,18 +74,20 @@ func (m *Mongo) DeleteMongo(rc *req.Ctx) {
}
func (m *Mongo) Databases(rc *req.Ctx) {
cli := m.MongoApp.GetMongoInst(m.GetMongoId(rc.GinCtx)).Cli
res, err := cli.ListDatabases(context.TODO(), bson.D{})
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(rc.GinCtx))
biz.ErrIsNil(err)
res, err := conn.Cli.ListDatabases(context.TODO(), bson.D{})
biz.ErrIsNilAppendErr(err, "获取mongo所有库信息失败: %s")
rc.ResData = res
}
func (m *Mongo) Collections(rc *req.Ctx) {
cli := m.MongoApp.GetMongoInst(m.GetMongoId(rc.GinCtx)).Cli
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(rc.GinCtx))
biz.ErrIsNil(err)
db := rc.GinCtx.Query("database")
biz.NotEmpty(db, "database不能为空")
ctx := context.TODO()
res, err := cli.Database(db).ListCollectionNames(ctx, bson.D{})
res, err := conn.Cli.Database(db).ListCollectionNames(ctx, bson.D{})
biz.ErrIsNilAppendErr(err, "获取库集合信息失败: %s")
rc.ResData = res
}
@@ -94,8 +96,9 @@ func (m *Mongo) RunCommand(rc *req.Ctx) {
commandForm := new(form.MongoRunCommand)
ginx.BindJsonAndValid(rc.GinCtx, commandForm)
inst := m.MongoApp.GetMongoInst(m.GetMongoId(rc.GinCtx))
rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm)
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(rc.GinCtx))
biz.ErrIsNil(err)
rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
// 顺序执行
commands := bson.D{}
@@ -110,7 +113,7 @@ func (m *Mongo) RunCommand(rc *req.Ctx) {
ctx := context.TODO()
var bm bson.M
err := inst.Cli.Database(commandForm.Database).RunCommand(
err = conn.Cli.Database(commandForm.Database).RunCommand(
ctx,
commands,
).Decode(&bm)
@@ -121,10 +124,13 @@ func (m *Mongo) RunCommand(rc *req.Ctx) {
func (m *Mongo) FindCommand(rc *req.Ctx) {
g := rc.GinCtx
cli := m.MongoApp.GetMongoInst(m.GetMongoId(g)).Cli
commandForm := new(form.MongoFindCommand)
ginx.BindJsonAndValid(g, commandForm)
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g))
biz.ErrIsNil(err)
cli := conn.Cli
limit := commandForm.Limit
if limit != 0 {
biz.IsTrue(limit <= 100, "limit不能超过100")
@@ -157,8 +163,9 @@ func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) {
commandForm := new(form.MongoUpdateByIdCommand)
ginx.BindJsonAndValid(g, commandForm)
inst := m.MongoApp.GetMongoInst(m.GetMongoId(g))
rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm)
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g))
biz.ErrIsNil(err)
rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
// 解析docId文档id如果为string类型则使用ObjectId解析解析失败则为普通字符串
docId := commandForm.DocId
@@ -170,7 +177,7 @@ func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) {
}
}
res, err := inst.Cli.Database(commandForm.Database).Collection(commandForm.Collection).UpdateByID(context.TODO(), docId, commandForm.Update)
res, err := conn.Cli.Database(commandForm.Database).Collection(commandForm.Collection).UpdateByID(context.TODO(), docId, commandForm.Update)
biz.ErrIsNilAppendErr(err, "命令执行失败: %s")
rc.ResData = res
@@ -181,8 +188,9 @@ func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) {
commandForm := new(form.MongoUpdateByIdCommand)
ginx.BindJsonAndValid(g, commandForm)
inst := m.MongoApp.GetMongoInst(m.GetMongoId(g))
rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm)
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g))
biz.ErrIsNil(err)
rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
// 解析docId文档id如果为string类型则使用ObjectId解析解析失败则为普通字符串
docId := commandForm.DocId
@@ -194,7 +202,7 @@ func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) {
}
}
res, err := inst.Cli.Database(commandForm.Database).Collection(commandForm.Collection).DeleteOne(context.TODO(), bson.D{{Key: "_id", Value: docId}})
res, err := conn.Cli.Database(commandForm.Database).Collection(commandForm.Collection).DeleteOne(context.TODO(), bson.D{{Key: "_id", Value: docId}})
biz.ErrIsNilAppendErr(err, "命令执行失败: %s")
rc.ResData = res
}
@@ -204,10 +212,11 @@ func (m *Mongo) InsertOneCommand(rc *req.Ctx) {
commandForm := new(form.MongoInsertCommand)
ginx.BindJsonAndValid(g, commandForm)
inst := m.MongoApp.GetMongoInst(m.GetMongoId(g))
rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm)
conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g))
biz.ErrIsNil(err)
rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
res, err := inst.Cli.Database(commandForm.Database).Collection(commandForm.Collection).InsertOne(context.TODO(), commandForm.Doc)
res, err := conn.Cli.Database(commandForm.Database).Collection(commandForm.Collection).InsertOne(context.TODO(), commandForm.Doc)
biz.ErrIsNilAppendErr(err, "命令执行失败: %s")
rc.ResData = res
}

View File

@@ -1,26 +1,12 @@
package application
import (
"context"
"fmt"
"mayfly-go/internal/common/consts"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/internal/machine/infrastructure/machine"
"mayfly-go/internal/mongo/domain/entity"
"mayfly-go/internal/mongo/domain/repository"
"mayfly-go/internal/mongo/mgm"
"mayfly-go/pkg/base"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/netx"
"mayfly-go/pkg/utils/structx"
"net"
"regexp"
"time"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type Mongo interface {
@@ -38,7 +24,7 @@ type Mongo interface {
// 获取mongo连接实例
// @param id mongo id
GetMongoInst(id uint64) *MongoInstance
GetMongoConn(id uint64) (*mgm.MongoConn, error)
}
func newMongoAppImpl(mongoRepo repository.Mongo) Mongo {
@@ -61,7 +47,7 @@ func (d *mongoAppImpl) Count(condition *entity.MongoQuery) int64 {
}
func (d *mongoAppImpl) Delete(id uint64) error {
DeleteMongoCache(id)
mgm.CloseConn(id)
return d.GetRepo().DeleteById(id)
}
@@ -71,148 +57,16 @@ func (d *mongoAppImpl) Save(m *entity.Mongo) error {
}
// 先关闭连接
DeleteMongoCache(m.Id)
mgm.CloseConn(m.Id)
return d.GetRepo().UpdateById(m)
}
func (d *mongoAppImpl) GetMongoInst(id uint64) *MongoInstance {
mongoInstance, err := GetMongoInstance(id, func(u uint64) (*entity.Mongo, error) {
mongo, err := d.GetById(new(entity.Mongo), u)
func (d *mongoAppImpl) GetMongoConn(id uint64) (*mgm.MongoConn, error) {
return mgm.GetMongoConn(id, func() (*mgm.MongoInfo, error) {
mongo, err := d.GetById(new(entity.Mongo), id)
if err != nil {
return nil, err
return nil, errorx.NewBiz("mongo信息不存在")
}
return mongo, nil
})
biz.ErrIsNilAppendErr(err, "连接mongo失败: %s")
return mongoInstance
}
// -----------------------------------------------------------
// mongo客户端连接缓存指定时间内没有访问则会被关闭
var mongoCliCache = cache.NewTimedCache(consts.MongoConnExpireTime, 5*time.Second).
WithUpdateAccessTime(true).
OnEvicted(func(key any, value any) {
logx.Infof("删除mongo连接缓存: id = %v", key)
value.(*MongoInstance).Close()
})
func init() {
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
// 遍历所有mongo连接实例若存在redis实例使用该ssh隧道机器则返回true表示还在使用中...
items := mongoCliCache.Items()
for _, v := range items {
if v.Value.(*MongoInstance).Info.SshTunnelMachineId == machineId {
return true
}
}
return false
return mongo.ToMongoInfo(), nil
})
}
// 获取mongo的连接实例
func GetMongoInstance(mongoId uint64, getMongoEntity func(uint64) (*entity.Mongo, error)) (*MongoInstance, error) {
mi, err := mongoCliCache.ComputeIfAbsent(mongoId, func(_ any) (any, error) {
mongoEntity, err := getMongoEntity(mongoId)
if err != nil {
return nil, err
}
c, err := connect(mongoEntity)
if err != nil {
return nil, err
}
return c, nil
})
if mi != nil {
return mi.(*MongoInstance), err
}
return nil, err
}
func DeleteMongoCache(mongoId uint64) {
mongoCliCache.Delete(mongoId)
}
type MongoInfo struct {
Id uint64 `json:"id"`
Name string `json:"name"`
TagPath string `json:"tagPath"`
SshTunnelMachineId int `json:"-"` // ssh隧道机器id
}
func (m *MongoInfo) GetLogDesc() string {
return fmt.Sprintf("Mongo[id=%d, tag=%s, name=%s]", m.Id, m.TagPath, m.Name)
}
type MongoInstance struct {
Id uint64
Info *MongoInfo
Cli *mongo.Client
}
func (mi *MongoInstance) Close() {
if mi.Cli != nil {
if err := mi.Cli.Disconnect(context.Background()); err != nil {
logx.Errorf("关闭mongo实例[%d]连接失败: %s", mi.Id, err)
}
mi.Cli = nil
}
}
// 连接mongo并返回client
func connect(me *entity.Mongo) (*MongoInstance, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
mongoInstance := &MongoInstance{Id: me.Id, Info: toMongoInfo(me)}
mongoOptions := options.Client().ApplyURI(me.Uri).
SetMaxPoolSize(1)
// 启用ssh隧道则连接隧道机器
if me.SshTunnelMachineId > 0 {
mongoOptions.SetDialer(&MongoSshDialer{machineId: me.SshTunnelMachineId})
}
client, err := mongo.Connect(ctx, mongoOptions)
if err != nil {
mongoInstance.Close()
return nil, err
}
if err = client.Ping(context.TODO(), nil); err != nil {
mongoInstance.Close()
return nil, err
}
logx.Infof("连接mongo: %s", func(str string) string {
reg := regexp.MustCompile(`(^mongodb://.+?:)(.+)(@.+$)`)
return reg.ReplaceAllString(str, `${1}****${3}`)
}(me.Uri))
mongoInstance.Cli = client
return mongoInstance, err
}
func toMongoInfo(me *entity.Mongo) *MongoInfo {
mi := new(MongoInfo)
structx.Copy(mi, me)
return mi
}
type MongoSshDialer struct {
machineId int
}
func (sd *MongoSshDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
stm, err := machineapp.GetMachineApp().GetSshTunnelMachine(sd.machineId)
if err != nil {
return nil, err
}
if sshConn, err := stm.GetDialConn(network, address); err == nil {
// 将ssh conn包装否则内部部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
return &netx.WrapSshConn{Conn: sshConn}, nil
} else {
return nil, err
}
}

View File

@@ -1,6 +1,10 @@
package entity
import "mayfly-go/pkg/model"
import (
"mayfly-go/internal/mongo/mgm"
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/structx"
)
type Mongo struct {
model.Model
@@ -11,3 +15,10 @@ type Mongo struct {
TagId uint64 `json:"tagId"`
TagPath string `json:"tagPath"`
}
// 转换为mongoInfo进行连接
func (me *Mongo) ToMongoInfo() *mgm.MongoInfo {
mongoInfo := new(mgm.MongoInfo)
structx.Copy(mongoInfo, me)
return mongoInfo
}

View File

@@ -0,0 +1,24 @@
package mgm
import (
"context"
"mayfly-go/pkg/logx"
"go.mongodb.org/mongo-driver/mongo"
)
type MongoConn struct {
Id string
Info *MongoInfo
Cli *mongo.Client
}
func (mc *MongoConn) Close() {
if mc.Cli != nil {
if err := mc.Cli.Disconnect(context.Background()); err != nil {
logx.Errorf("关闭mongo实例[%s]连接失败: %s", mc.Id, err)
}
mc.Cli = nil
}
}

View File

@@ -0,0 +1,72 @@
package mgm
import (
"mayfly-go/internal/common/consts"
"mayfly-go/internal/machine/infrastructure/machine"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/logx"
"sync"
"time"
)
// mongo客户端连接缓存指定时间内没有访问则会被关闭
var connCache = cache.NewTimedCache(consts.MongoConnExpireTime, 5*time.Second).
WithUpdateAccessTime(true).
OnEvicted(func(key any, value any) {
logx.Infof("删除mongo连接缓存: id = %v", key)
value.(*MongoConn).Close()
})
func init() {
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
// 遍历所有mongo连接实例若存在redis实例使用该ssh隧道机器则返回true表示还在使用中...
items := connCache.Items()
for _, v := range items {
if v.Value.(*MongoConn).Info.SshTunnelMachineId == machineId {
return true
}
}
return false
})
}
var mutex sync.Mutex
// 从缓存中获取mongo连接信息, 若缓存中不存在则会使用回调函数获取mongoInfo进行连接并缓存
func GetMongoConn(mongoId uint64, getMongoInfo func() (*MongoInfo, error)) (*MongoConn, error) {
connId := getConnId(mongoId)
// connId不为空则为需要缓存
needCache := connId != ""
if needCache {
load, ok := connCache.Get(connId)
if ok {
return load.(*MongoConn), nil
}
}
mutex.Lock()
defer mutex.Unlock()
// 若缓存中不存在则从回调函数中获取MongoInfo
mi, err := getMongoInfo()
if err != nil {
return nil, err
}
// 连接mongo
mc, err := mi.Conn()
if err != nil {
return nil, err
}
if needCache {
connCache.Put(connId, mc)
}
return mc, nil
}
// 关闭连接,并移除缓存连接
func CloseConn(mongoId uint64) {
connCache.Delete(mongoId)
}

View File

@@ -0,0 +1,79 @@
package mgm
import (
"context"
"fmt"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/netx"
"net"
"regexp"
"time"
machineapp "mayfly-go/internal/machine/application"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
)
type MongoInfo struct {
Id uint64 `json:"id"`
Name string `json:"name"`
Uri string `json:"-"`
TagPath string `json:"tagPath"`
SshTunnelMachineId int `json:"-"` // ssh隧道机器id
}
func (mi *MongoInfo) Conn() (*MongoConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
mongoOptions := options.Client().ApplyURI(mi.Uri).
SetMaxPoolSize(1)
// 启用ssh隧道则连接隧道机器
if mi.SshTunnelMachineId > 0 {
mongoOptions.SetDialer(&MongoSshDialer{machineId: mi.SshTunnelMachineId})
}
client, err := mongo.Connect(ctx, mongoOptions)
if err != nil {
return nil, err
}
if err = client.Ping(context.TODO(), nil); err != nil {
client.Disconnect(ctx)
return nil, err
}
logx.Infof("连接mongo: %s", func(str string) string {
reg := regexp.MustCompile(`(^mongodb://.+?:)(.+)(@.+$)`)
return reg.ReplaceAllString(str, `${1}****${3}`)
}(mi.Uri))
return &MongoConn{Id: getConnId(mi.Id), Info: mi, Cli: client}, nil
}
type MongoSshDialer struct {
machineId int
}
func (sd *MongoSshDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
stm, err := machineapp.GetMachineApp().GetSshTunnelMachine(sd.machineId)
if err != nil {
return nil, err
}
if sshConn, err := stm.GetDialConn(network, address); err == nil {
// 将ssh conn包装否则内部部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
return &netx.WrapSshConn{Conn: sshConn}, nil
} else {
return nil, err
}
}
// 生成mongo连接id
func getConnId(id uint64) string {
if id == 0 {
return ""
}
return fmt.Sprintf("%d", id)
}

View File

@@ -11,7 +11,7 @@ import (
)
func (r *Redis) Hscan(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
g := rc.GinCtx
count := ginx.QueryInt(g, "count", 10)
match := g.Query("match")
@@ -32,7 +32,7 @@ func (r *Redis) Hscan(rc *req.Ctx) {
}
func (r *Redis) Hdel(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
field := rc.GinCtx.Query("field")
rc.ReqParam = collx.Kvs("redis", ri.Info, "key", key, "field", field)
@@ -42,7 +42,7 @@ func (r *Redis) Hdel(rc *req.Ctx) {
}
func (r *Redis) Hget(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
field := rc.GinCtx.Query("field")
res, err := ri.GetCmdable().HGet(context.TODO(), key, field).Result()
@@ -56,7 +56,7 @@ func (r *Redis) Hset(rc *req.Ctx) {
ginx.BindJsonAndValid(g, hashValue)
hv := hashValue.Value[0]
ri := r.getRedisIns(rc)
ri := r.getRedisConn(rc)
rc.ReqParam = collx.Kvs("redis", ri.Info, "hash", hv)
res, err := ri.GetCmdable().HSet(context.TODO(), hashValue.Key, hv["field"].(string), hv["value"]).Result()
@@ -69,7 +69,7 @@ func (r *Redis) SetHashValue(rc *req.Ctx) {
hashValue := new(form.HashValue)
ginx.BindJsonAndValid(g, hashValue)
ri := r.getRedisIns(rc)
ri := r.getRedisConn(rc)
rc.ReqParam = collx.Kvs("redis", ri.Info, "hash", hashValue)
cmd := ri.GetCmdable()

View File

@@ -4,7 +4,7 @@ import (
"context"
"mayfly-go/internal/redis/api/form"
"mayfly-go/internal/redis/api/vo"
"mayfly-go/internal/redis/domain/entity"
"mayfly-go/internal/redis/rdm"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/ginx"
"mayfly-go/pkg/req"
@@ -18,7 +18,7 @@ import (
// scan获取redis的key列表信息
func (r *Redis) ScanKeys(rc *req.Ctx) {
ri := r.getRedisIns(rc)
ri := r.getRedisConn(rc)
form := &form.RedisScanForm{}
ginx.BindJsonAndValid(rc.GinCtx, form)
@@ -44,7 +44,7 @@ func (r *Redis) ScanKeys(rc *req.Ctx) {
// 通配符或全匹配
mode := ri.Info.Mode
if mode == "" || mode == entity.RedisModeStandalone || mode == entity.RedisModeSentinel {
if mode == "" || mode == rdm.StandaloneMode || mode == rdm.SentinelMode {
redisAddr := ri.Cli.Options().Addr
cursorRes[redisAddr] = form.Cursor[redisAddr]
for {
@@ -64,7 +64,7 @@ func (r *Redis) ScanKeys(rc *req.Ctx) {
break
}
}
} else if mode == entity.RedisModeCluster {
} else if mode == rdm.ClusterMode {
mu := &sync.Mutex{}
// 遍历所有master节点并执行scan命令合并keys
ri.ClusterCli.ForEachMaster(ctx, func(ctx context.Context, client *redis.Client) error {
@@ -98,7 +98,7 @@ func (r *Redis) ScanKeys(rc *req.Ctx) {
}
func (r *Redis) KeyInfo(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
cmd := ri.GetCmdable()
ctx := context.Background()
ttl, err := cmd.TTL(ctx, key).Result()
@@ -120,7 +120,7 @@ func (r *Redis) KeyInfo(rc *req.Ctx) {
}
func (r *Redis) TtlKey(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
ttl, err := ri.GetCmdable().TTL(context.Background(), key).Result()
biz.ErrIsNilAppendErr(err, "ttl失败: %s")
@@ -132,7 +132,7 @@ func (r *Redis) TtlKey(rc *req.Ctx) {
}
func (r *Redis) DeleteKey(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
rc.ReqParam = collx.Kvs("redis", ri.Info, "key", key)
ri.GetCmdable().Del(context.Background(), key)
}
@@ -141,7 +141,7 @@ func (r *Redis) RenameKey(rc *req.Ctx) {
form := &form.Rename{}
ginx.BindJsonAndValid(rc.GinCtx, form)
ri := r.getRedisIns(rc)
ri := r.getRedisConn(rc)
rc.ReqParam = collx.Kvs("redis", ri.Info, "rename", form)
ri.GetCmdable().Rename(context.Background(), form.Key, form.NewKey)
}
@@ -150,21 +150,21 @@ func (r *Redis) ExpireKey(rc *req.Ctx) {
form := &form.Expire{}
ginx.BindJsonAndValid(rc.GinCtx, form)
ri := r.getRedisIns(rc)
ri := r.getRedisConn(rc)
rc.ReqParam = collx.Kvs("redis", ri.Info, "expire", form)
ri.GetCmdable().Expire(context.Background(), form.Key, time.Duration(form.Seconds)*time.Second)
}
// 移除过期时间
func (r *Redis) PersistKey(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
rc.ReqParam = collx.Kvs("redis", ri.Info, "key", key)
ri.GetCmdable().Persist(context.Background(), key)
}
// 清空库
func (r *Redis) FlushDb(rc *req.Ctx) {
ri := r.getRedisIns(rc)
ri := r.getRedisConn(rc)
rc.ReqParam = ri.Info
ri.GetCmdable().FlushDB(context.Background())
}

View File

@@ -10,7 +10,7 @@ import (
)
func (r *Redis) GetListValue(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
ctx := context.TODO()
cmdable := ri.GetCmdable()
@@ -34,7 +34,7 @@ func (r *Redis) Lrem(rc *req.Ctx) {
option := new(form.LRemOption)
ginx.BindJsonAndValid(g, option)
cmd := r.getRedisIns(rc).GetCmdable()
cmd := r.getRedisConn(rc).GetCmdable()
res, err := cmd.LRem(context.TODO(), option.Key, int64(option.Count), option.Member).Result()
biz.ErrIsNilAppendErr(err, "lrem失败: %s")
rc.ResData = res
@@ -45,7 +45,7 @@ func (r *Redis) SaveListValue(rc *req.Ctx) {
listValue := new(form.ListValue)
ginx.BindJsonAndValid(g, listValue)
cmd := r.getRedisIns(rc).GetCmdable()
cmd := r.getRedisConn(rc).GetCmdable()
key := listValue.Key
ctx := context.TODO()
@@ -59,7 +59,7 @@ func (r *Redis) SetListValue(rc *req.Ctx) {
listSetValue := new(form.ListSetValue)
ginx.BindJsonAndValid(g, listSetValue)
ri := r.getRedisIns(rc)
ri := r.getRedisConn(rc)
_, err := ri.GetCmdable().LSet(context.TODO(), listSetValue.Key, listSetValue.Index, listSetValue.Value).Result()
biz.ErrIsNilAppendErr(err, "list set失败: %s")

View File

@@ -6,6 +6,7 @@ import (
"mayfly-go/internal/redis/api/vo"
"mayfly-go/internal/redis/application"
"mayfly-go/internal/redis/domain/entity"
"mayfly-go/internal/redis/rdm"
tagapp "mayfly-go/internal/tag/application"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/ginx"
@@ -86,7 +87,7 @@ func (r *Redis) DeleteRedis(rc *req.Ctx) {
func (r *Redis) RedisInfo(rc *req.Ctx) {
g := rc.GinCtx
ri, err := r.RedisApp.GetRedisInstance(uint64(ginx.PathParamInt(g, "id")), 0)
ri, err := r.RedisApp.GetRedisConn(uint64(ginx.PathParamInt(g, "id")), 0)
biz.ErrIsNil(err)
section := rc.GinCtx.Query("section")
@@ -94,9 +95,9 @@ func (r *Redis) RedisInfo(rc *req.Ctx) {
ctx := context.Background()
var redisCli *redis.Client
if mode == "" || mode == entity.RedisModeStandalone || mode == entity.RedisModeSentinel {
if mode == "" || mode == rdm.StandaloneMode || mode == rdm.SentinelMode {
redisCli = ri.Cli
} else if mode == entity.RedisModeCluster {
} else if mode == rdm.ClusterMode {
host := rc.GinCtx.Query("host")
biz.NotEmpty(host, "集群模式host信息不能为空")
clusterClient := ri.ClusterCli
@@ -164,9 +165,9 @@ func (r *Redis) RedisInfo(rc *req.Ctx) {
func (r *Redis) ClusterInfo(rc *req.Ctx) {
g := rc.GinCtx
ri, err := r.RedisApp.GetRedisInstance(uint64(ginx.PathParamInt(g, "id")), 0)
ri, err := r.RedisApp.GetRedisConn(uint64(ginx.PathParamInt(g, "id")), 0)
biz.ErrIsNil(err)
biz.IsEquals(ri.Info.Mode, entity.RedisModeCluster, "非集群模式")
biz.IsEquals(ri.Info.Mode, rdm.ClusterMode, "非集群模式")
info, _ := ri.ClusterCli.ClusterInfo(context.Background()).Result()
nodesStr, _ := ri.ClusterCli.ClusterNodes(context.Background()).Result()
@@ -208,14 +209,14 @@ func (r *Redis) ClusterInfo(rc *req.Ctx) {
}
// 校验查询参数中的key为必填项并返回redis实例
func (r *Redis) checkKeyAndGetRedisIns(rc *req.Ctx) (*application.RedisInstance, string) {
func (r *Redis) checkKeyAndGetRedisConn(rc *req.Ctx) (*rdm.RedisConn, string) {
key := rc.GinCtx.Query("key")
biz.NotEmpty(key, "key不能为空")
return r.getRedisIns(rc), key
return r.getRedisConn(rc), key
}
func (r *Redis) getRedisIns(rc *req.Ctx) *application.RedisInstance {
ri, err := r.RedisApp.GetRedisInstance(getIdAndDbNum(rc.GinCtx))
func (r *Redis) getRedisConn(rc *req.Ctx) *rdm.RedisConn {
ri, err := r.RedisApp.GetRedisConn(getIdAndDbNum(rc.GinCtx))
biz.ErrIsNil(err)
biz.ErrIsNilAppendErr(r.TagApp.CanAccess(rc.LoginAccount.Id, ri.Info.TagPath), "%s")
return ri

View File

@@ -11,7 +11,7 @@ import (
)
func (r *Redis) GetSetValue(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
res, err := ri.GetCmdable().SMembers(context.TODO(), key).Result()
biz.ErrIsNilAppendErr(err, "获取set值失败: %s")
rc.ResData = res
@@ -22,7 +22,7 @@ func (r *Redis) SetSetValue(rc *req.Ctx) {
keyvalue := new(form.SetValue)
ginx.BindJsonAndValid(g, keyvalue)
cmd := r.getRedisIns(rc).GetCmdable()
cmd := r.getRedisConn(rc).GetCmdable()
key := keyvalue.Key
// 简单处理->先删除,后新增
@@ -35,7 +35,7 @@ func (r *Redis) SetSetValue(rc *req.Ctx) {
}
func (r *Redis) Scard(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
total, err := ri.GetCmdable().SCard(context.TODO(), key).Result()
biz.ErrIsNilAppendErr(err, "scard失败: %s")
@@ -47,7 +47,7 @@ func (r *Redis) Sscan(rc *req.Ctx) {
scan := new(form.ScanForm)
ginx.BindJsonAndValid(g, scan)
cmd := r.getRedisIns(rc).GetCmdable()
cmd := r.getRedisConn(rc).GetCmdable()
keys, cursor, err := cmd.SScan(context.TODO(), scan.Key, scan.Cursor, scan.Match, scan.Count).Result()
biz.ErrIsNilAppendErr(err, "sscan失败: %s")
rc.ResData = collx.M{
@@ -60,7 +60,7 @@ func (r *Redis) Sadd(rc *req.Ctx) {
g := rc.GinCtx
option := new(form.SmemberOption)
ginx.BindJsonAndValid(g, option)
cmd := r.getRedisIns(rc).GetCmdable()
cmd := r.getRedisConn(rc).GetCmdable()
res, err := cmd.SAdd(context.TODO(), option.Key, option.Member).Result()
biz.ErrIsNilAppendErr(err, "sadd失败: %s")
@@ -72,7 +72,7 @@ func (r *Redis) Srem(rc *req.Ctx) {
option := new(form.SmemberOption)
ginx.BindJsonAndValid(g, option)
cmd := r.getRedisIns(rc).GetCmdable()
cmd := r.getRedisConn(rc).GetCmdable()
res, err := cmd.SRem(context.TODO(), option.Key, option.Member).Result()
biz.ErrIsNilAppendErr(err, "srem失败: %s")
rc.ResData = res

View File

@@ -11,7 +11,7 @@ import (
)
func (r *Redis) GetStringValue(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
str, err := ri.GetCmdable().Get(context.TODO(), key).Result()
biz.ErrIsNilAppendErr(err, "获取字符串值失败: %s")
rc.ResData = str
@@ -22,7 +22,7 @@ func (r *Redis) SetStringValue(rc *req.Ctx) {
keyValue := new(form.StringValue)
ginx.BindJsonAndValid(g, keyValue)
ri := r.getRedisIns(rc)
ri := r.getRedisConn(rc)
cmd := ri.GetCmdable()
rc.ReqParam = collx.Kvs("redis", ri.Info, "string", keyValue)

View File

@@ -12,7 +12,7 @@ import (
)
func (r *Redis) ZCard(rc *req.Ctx) {
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
total, err := ri.GetCmdable().ZCard(context.TODO(), key).Result()
biz.ErrIsNilAppendErr(err, "zcard失败: %s")
@@ -21,7 +21,7 @@ func (r *Redis) ZCard(rc *req.Ctx) {
func (r *Redis) ZScan(rc *req.Ctx) {
g := rc.GinCtx
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
cursor := uint64(ginx.QueryInt(g, "cursor", 0))
match := ginx.Query(g, "match", "*")
@@ -37,7 +37,7 @@ func (r *Redis) ZScan(rc *req.Ctx) {
func (r *Redis) ZRevRange(rc *req.Ctx) {
g := rc.GinCtx
ri, key := r.checkKeyAndGetRedisIns(rc)
ri, key := r.checkKeyAndGetRedisConn(rc)
start := ginx.QueryInt(g, "start", 0)
stop := ginx.QueryInt(g, "stop", 50)
@@ -51,7 +51,7 @@ func (r *Redis) ZRem(rc *req.Ctx) {
option := new(form.SmemberOption)
ginx.BindJsonAndValid(g, option)
cmd := r.getRedisIns(rc).GetCmdable()
cmd := r.getRedisConn(rc).GetCmdable()
res, err := cmd.ZRem(context.TODO(), option.Key, option.Member).Result()
biz.ErrIsNilAppendErr(err, "zrem失败: %s")
rc.ResData = res
@@ -62,7 +62,7 @@ func (r *Redis) ZAdd(rc *req.Ctx) {
option := new(form.ZAddOption)
ginx.BindJsonAndValid(g, option)
cmd := r.getRedisIns(rc).GetCmdable()
cmd := r.getRedisConn(rc).GetCmdable()
zm := redis.Z{
Score: option.Score,
Member: option.Member,

View File

@@ -1,26 +1,14 @@
package application
import (
"context"
"fmt"
"mayfly-go/internal/common/consts"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/internal/machine/infrastructure/machine"
"mayfly-go/internal/redis/domain/entity"
"mayfly-go/internal/redis/domain/repository"
"mayfly-go/internal/redis/rdm"
"mayfly-go/pkg/base"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/netx"
"mayfly-go/pkg/utils/structx"
"net"
"strconv"
"strings"
"time"
"github.com/redis/go-redis/v9"
)
type Redis interface {
@@ -31,7 +19,7 @@ type Redis interface {
Count(condition *entity.RedisQuery) int64
Save(entity *entity.Redis) error
Save(re *entity.Redis) error
// 删除数据库信息
Delete(id uint64) error
@@ -39,7 +27,10 @@ type Redis interface {
// 获取数据库连接实例
// id: 数据库实例id
// db: 库号
GetRedisInstance(id uint64, db int) (*RedisInstance, error)
GetRedisConn(id uint64, db int) (*rdm.RedisConn, error)
// 测试连接
TestConn(re *entity.Redis) error
}
func newRedisApp(redisRepo repository.Redis) Redis {
@@ -52,7 +43,7 @@ type redisAppImpl struct {
base.AppImpl[*entity.Redis, repository.Redis]
}
// 分页获取机器脚本信息列表
// 分页获取redis列表
func (r *redisAppImpl) GetPageList(condition *entity.RedisQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) {
return r.GetRepo().GetRedisList(condition, pageParam, toEntity, orderBy...)
}
@@ -64,7 +55,7 @@ func (r *redisAppImpl) Count(condition *entity.RedisQuery) int64 {
func (r *redisAppImpl) Save(re *entity.Redis) error {
// ’修改信息且密码不为空‘ or ‘新增’需要测试是否可连接
if (re.Id != 0 && re.Password != "") || re.Id == 0 {
if err := TestRedisConnection(re); err != nil {
if err := r.TestConn(re); err != nil {
return errorx.NewBiz("Redis连接失败: %s", err.Error())
}
}
@@ -92,7 +83,7 @@ func (r *redisAppImpl) Save(re *entity.Redis) error {
if oldRedis.Db != re.Db || oldRedis.SshTunnelMachineId != re.SshTunnelMachineId {
for _, dbStr := range strings.Split(oldRedis.Db, ",") {
db, _ := strconv.Atoi(dbStr)
CloseRedis(re.Id, db)
rdm.CloseConn(re.Id, db)
}
}
re.PwdEncrypt()
@@ -108,253 +99,35 @@ func (r *redisAppImpl) Delete(id uint64) error {
// 如果存在连接,则关闭所有库连接信息
for _, dbStr := range strings.Split(re.Db, ",") {
db, _ := strconv.Atoi(dbStr)
CloseRedis(re.Id, db)
rdm.CloseConn(re.Id, db)
}
return r.DeleteById(id)
}
// 获取数据库连接实例
func (r *redisAppImpl) GetRedisInstance(id uint64, db int) (*RedisInstance, error) {
// Id不为0则为需要缓存
needCache := id != 0
if needCache {
load, ok := redisCache.Get(getRedisCacheKey(id, db))
if ok {
return load.(*RedisInstance), nil
}
}
// 缓存不存在则回调获取redis信息
re, err := r.GetById(new(entity.Redis), id)
if err != nil {
return nil, errorx.NewBiz("redis信息不存在")
}
re.PwdDecrypt()
redisMode := re.Mode
var ri *RedisInstance
if redisMode == "" || redisMode == entity.RedisModeStandalone {
ri = getRedisCient(re, db)
// 测试连接
_, e := ri.Cli.Ping(context.Background()).Result()
if e != nil {
ri.Close()
return nil, errorx.NewBiz("redis连接失败: %s", e.Error())
}
} else if redisMode == entity.RedisModeCluster {
ri = getRedisClusterClient(re)
// 测试连接
_, e := ri.ClusterCli.Ping(context.Background()).Result()
if e != nil {
ri.Close()
return nil, errorx.NewBiz("redis集群连接失败: %s", e.Error())
}
} else if redisMode == entity.RedisModeSentinel {
ri = getRedisSentinelCient(re, db)
// 测试连接
_, e := ri.Cli.Ping(context.Background()).Result()
if e != nil {
ri.Close()
return nil, errorx.NewBiz("redis sentinel连接失败: %s", e.Error())
}
}
logx.Infof("连接redis: %s/%d", re.Host, db)
if needCache {
redisCache.Put(getRedisCacheKey(id, db), ri)
}
return ri, nil
}
// 生成redis连接缓存key
func getRedisCacheKey(id uint64, db int) string {
return fmt.Sprintf("%d/%d", id, db)
}
func toRedisInfo(re *entity.Redis, db int) *RedisInfo {
redisInfo := new(RedisInfo)
structx.Copy(redisInfo, re)
redisInfo.Db = db
return redisInfo
}
func getRedisCient(re *entity.Redis, db int) *RedisInstance {
ri := &RedisInstance{Id: getRedisCacheKey(re.Id, db), Info: toRedisInfo(re, db)}
redisOptions := &redis.Options{
Addr: re.Host,
Username: re.Username,
Password: re.Password, // no password set
DB: db, // use default DB
DialTimeout: 8 * time.Second,
ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines.
WriteTimeout: -1,
}
if re.SshTunnelMachineId > 0 {
redisOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
}
ri.Cli = redis.NewClient(redisOptions)
return ri
}
func getRedisClusterClient(re *entity.Redis) *RedisInstance {
ri := &RedisInstance{Id: getRedisCacheKey(re.Id, 0), Info: toRedisInfo(re, 0)}
redisClusterOptions := &redis.ClusterOptions{
Addrs: strings.Split(re.Host, ","),
Username: re.Username,
Password: re.Password,
DialTimeout: 8 * time.Second,
}
if re.SshTunnelMachineId > 0 {
redisClusterOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
}
ri.ClusterCli = redis.NewClusterClient(redisClusterOptions)
return ri
}
func getRedisSentinelCient(re *entity.Redis, db int) *RedisInstance {
ri := &RedisInstance{Id: getRedisCacheKey(re.Id, db), Info: toRedisInfo(re, db)}
// sentinel模式host为 masterName=host:port,host:port
masterNameAndHosts := strings.Split(re.Host, "=")
sentinelOptions := &redis.FailoverOptions{
MasterName: masterNameAndHosts[0],
SentinelAddrs: strings.Split(masterNameAndHosts[1], ","),
Username: re.Username,
Password: re.Password, // no password set
SentinelPassword: re.Password, // 哨兵节点密码需与redis节点密码一致
DB: db, // use default DB
DialTimeout: 8 * time.Second,
ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines.
WriteTimeout: -1,
}
if re.SshTunnelMachineId > 0 {
sentinelOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
}
ri.Cli = redis.NewFailoverClient(sentinelOptions)
return ri
}
func getRedisDialer(machineId int) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(_ context.Context, network, addr string) (net.Conn, error) {
sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(machineId)
func (r *redisAppImpl) GetRedisConn(id uint64, db int) (*rdm.RedisConn, error) {
return rdm.GetRedisConn(id, db, func() (*rdm.RedisInfo, error) {
// 缓存不存在则回调获取redis信息
re, err := r.GetById(new(entity.Redis), id)
if err != nil {
return nil, err
return nil, errorx.NewBiz("redis信息不存在")
}
re.PwdDecrypt()
if sshConn, err := sshTunnel.GetDialConn(network, addr); err == nil {
// 将ssh conn包装否则redis内部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
return &netx.WrapSshConn{Conn: sshConn}, nil
} else {
return nil, err
}
}
}
//------------------------------------------------------------------------------
// redis客户端连接缓存指定时间内没有访问则会被关闭
var redisCache = cache.NewTimedCache(consts.RedisConnExpireTime, 5*time.Second).
WithUpdateAccessTime(true).
OnEvicted(func(key any, value any) {
logx.Info(fmt.Sprintf("删除redis连接缓存 id = %s", key))
value.(*RedisInstance).Close()
})
// 移除redis连接缓存并关闭redis连接
func CloseRedis(id uint64, db int) {
redisCache.Delete(getRedisCacheKey(id, db))
}
func init() {
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
// 遍历所有redis连接实例若存在redis实例使用该ssh隧道机器则返回true表示还在使用中...
items := redisCache.Items()
for _, v := range items {
if v.Value.(*RedisInstance).Info.SshTunnelMachineId == machineId {
return true
}
}
return false
return re.ToRedisInfo(db), nil
})
}
func TestRedisConnection(re *entity.Redis) error {
var cmd redis.Cmdable
// 取第一个库测试连接即可
dbStr := strings.Split(re.Db, ",")[0]
db, _ := strconv.Atoi(dbStr)
if re.Mode == "" || re.Mode == entity.RedisModeStandalone {
rcli := getRedisCient(re, db)
defer rcli.Close()
cmd = rcli.Cli
} else if re.Mode == entity.RedisModeCluster {
ccli := getRedisClusterClient(re)
defer ccli.Close()
cmd = ccli.ClusterCli
} else if re.Mode == entity.RedisModeSentinel {
rcli := getRedisSentinelCient(re, db)
defer rcli.Close()
cmd = rcli.Cli
func (r *redisAppImpl) TestConn(re *entity.Redis) error {
db := 0
if re.Db != "" {
db, _ = strconv.Atoi(strings.Split(re.Db, ",")[0])
}
// 测试连接
_, e := cmd.Ping(context.Background()).Result()
return e
}
type RedisInfo struct {
Id uint64 `json:"id"`
Host string `json:"host"`
Db int `json:"db"` // 库号
TagPath string `json:"tagPath"`
Mode string `json:"-"`
Name string `json:"-"`
SshTunnelMachineId int `json:"-"`
}
// 获取记录日志的描述
func (r *RedisInfo) GetLogDesc() string {
return fmt.Sprintf("Redis[id=%d, tag=%s, host=%s, db=%d]", r.Id, r.TagPath, r.Host, r.Db)
}
// redis实例
type RedisInstance struct {
Id string
Info *RedisInfo
Cli *redis.Client
ClusterCli *redis.ClusterClient
}
// 获取命令执行接口的具体实现
func (r *RedisInstance) GetCmdable() redis.Cmdable {
redisMode := r.Info.Mode
if redisMode == "" || redisMode == entity.RedisModeStandalone || r.Info.Mode == entity.RedisModeSentinel {
return r.Cli
}
if redisMode == entity.RedisModeCluster {
return r.ClusterCli
rc, err := re.ToRedisInfo(db).Conn()
if err != nil {
return err
}
rc.Close()
return nil
}
func (r *RedisInstance) Scan(cursor uint64, match string, count int64) ([]string, uint64, error) {
return r.GetCmdable().Scan(context.Background(), cursor, match, count).Result()
}
func (r *RedisInstance) Close() {
mode := r.Info.Mode
if mode == entity.RedisModeStandalone || mode == entity.RedisModeSentinel {
if err := r.Cli.Close(); err != nil {
logx.Errorf("关闭redis单机实例[%s]连接失败: %s", r.Id, err.Error())
}
r.Cli = nil
}
if mode == entity.RedisModeCluster {
if err := r.ClusterCli.Close(); err != nil {
logx.Errorf("关闭redis集群实例[%s]连接失败: %s", r.Id, err.Error())
}
r.ClusterCli = nil
}
}

View File

@@ -2,7 +2,9 @@ package entity
import (
"mayfly-go/internal/common/utils"
"mayfly-go/internal/redis/rdm"
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils/structx"
)
type Redis struct {
@@ -20,12 +22,6 @@ type Redis struct {
TagPath string
}
const (
RedisModeStandalone = "standalone"
RedisModeCluster = "cluster"
RedisModeSentinel = "sentinel"
)
func (r *Redis) PwdEncrypt() {
// 密码替换为加密后的密码
r.Password = utils.PwdAesEncrypt(r.Password)
@@ -35,3 +31,11 @@ func (r *Redis) PwdDecrypt() {
// 密码替换为解密后的密码
r.Password = utils.PwdAesDecrypt(r.Password)
}
// 转换为redisInfo进行连接
func (re *Redis) ToRedisInfo(db int) *rdm.RedisInfo {
redisInfo := new(rdm.RedisInfo)
structx.Copy(redisInfo, re)
redisInfo.Db = db
return redisInfo
}

View File

@@ -0,0 +1,51 @@
package rdm
import (
"context"
"mayfly-go/pkg/logx"
"github.com/redis/go-redis/v9"
)
// redis连接信息
type RedisConn struct {
Id string
Info *RedisInfo
Cli *redis.Client
ClusterCli *redis.ClusterClient
}
// 获取命令执行接口的具体实现
func (r *RedisConn) GetCmdable() redis.Cmdable {
redisMode := r.Info.Mode
if redisMode == "" || redisMode == StandaloneMode || r.Info.Mode == SentinelMode {
return r.Cli
}
if redisMode == ClusterMode {
return r.ClusterCli
}
return nil
}
func (r *RedisConn) Scan(cursor uint64, match string, count int64) ([]string, uint64, error) {
return r.GetCmdable().Scan(context.Background(), cursor, match, count).Result()
}
func (r *RedisConn) Close() {
mode := r.Info.Mode
if mode == StandaloneMode || mode == SentinelMode {
if err := r.Cli.Close(); err != nil {
logx.Errorf("关闭redis单机实例[%s]连接失败: %s", r.Id, err.Error())
}
r.Cli = nil
return
}
if mode == ClusterMode {
if err := r.ClusterCli.Close(); err != nil {
logx.Errorf("关闭redis集群实例[%s]连接失败: %s", r.Id, err.Error())
}
r.ClusterCli = nil
}
}

View File

@@ -0,0 +1,73 @@
package rdm
import (
"fmt"
"mayfly-go/internal/common/consts"
"mayfly-go/internal/machine/infrastructure/machine"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/logx"
"sync"
"time"
)
// redis客户端连接缓存指定时间内没有访问则会被关闭
var connCache = cache.NewTimedCache(consts.RedisConnExpireTime, 5*time.Second).
WithUpdateAccessTime(true).
OnEvicted(func(key any, value any) {
logx.Info(fmt.Sprintf("删除redis连接缓存 id = %s", key))
value.(*RedisConn).Close()
})
func init() {
machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool {
// 遍历所有redis连接实例若存在redis实例使用该ssh隧道机器则返回true表示还在使用中...
items := connCache.Items()
for _, v := range items {
if v.Value.(*RedisConn).Info.SshTunnelMachineId == machineId {
return true
}
}
return false
})
}
var mutex sync.Mutex
// 从缓存中获取redis连接信息, 若缓存中不存在则会使用回调函数获取redisInfo进行连接并缓存
func GetRedisConn(redisId uint64, db int, getRedisInfo func() (*RedisInfo, error)) (*RedisConn, error) {
connId := getConnId(redisId, db)
// connId不为空则为需要缓存
needCache := connId != ""
if needCache {
load, ok := connCache.Get(connId)
if ok {
return load.(*RedisConn), nil
}
}
mutex.Lock()
defer mutex.Unlock()
// 若缓存中不存在则从回调函数中获取RedisInfo
ri, err := getRedisInfo()
if err != nil {
return nil, err
}
// 连接数据库
rc, err := ri.Conn()
if err != nil {
return nil, err
}
if needCache {
connCache.Put(connId, rc)
}
return rc, nil
}
// 移除redis连接缓存并关闭redis连接
func CloseConn(id uint64, db int) {
connCache.Delete(getConnId(id, db))
}

View File

@@ -0,0 +1,161 @@
package rdm
import (
"context"
"fmt"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/netx"
"net"
"strings"
"time"
"github.com/redis/go-redis/v9"
)
type RedisMode string
const (
StandaloneMode RedisMode = "standalone"
ClusterMode RedisMode = "cluster"
SentinelMode RedisMode = "sentinel"
)
type RedisInfo struct {
Id uint64 `json:"id"`
Host string `json:"host"`
Db int `json:"db"` // 库号
Mode RedisMode `json:"-"`
Username string `json:"-"`
Password string `json:"-"`
Name string `json:"-"`
TagPath string `json:"tagPath"`
SshTunnelMachineId int `json:"-"`
}
func (r *RedisInfo) Conn() (*RedisConn, error) {
redisMode := r.Mode
if redisMode == StandaloneMode {
return r.connStandalone()
}
if redisMode == ClusterMode {
return r.connCluster()
}
if redisMode == SentinelMode {
return r.connSentinel()
}
return nil, errorx.NewBiz("redis mode error")
}
func (re *RedisInfo) connStandalone() (*RedisConn, error) {
redisOptions := &redis.Options{
Addr: re.Host,
Username: re.Username,
Password: re.Password, // no password set
DB: re.Db, // use default DB
DialTimeout: 8 * time.Second,
ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines.
WriteTimeout: -1,
}
if re.SshTunnelMachineId > 0 {
redisOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
}
cli := redis.NewClient(redisOptions)
_, e := cli.Ping(context.Background()).Result()
if e != nil {
cli.Close()
return nil, errorx.NewBiz("redis连接失败: %s", e.Error())
}
logx.Infof("连接redis standalone: %s/%d", re.Host, re.Db)
rc := &RedisConn{Id: getConnId(re.Id, re.Db), Info: re}
rc.Cli = cli
return rc, nil
}
func (re *RedisInfo) connCluster() (*RedisConn, error) {
redisClusterOptions := &redis.ClusterOptions{
Addrs: strings.Split(re.Host, ","),
Username: re.Username,
Password: re.Password,
DialTimeout: 8 * time.Second,
}
if re.SshTunnelMachineId > 0 {
redisClusterOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
}
cli := redis.NewClusterClient(redisClusterOptions)
// 测试连接
_, e := cli.Ping(context.Background()).Result()
if e != nil {
cli.Close()
return nil, errorx.NewBiz("redis集群连接失败: %s", e.Error())
}
logx.Infof("连接redis cluster: %s/%d", re.Host, re.Db)
rc := &RedisConn{Id: getConnId(re.Id, re.Db), Info: re}
rc.ClusterCli = cli
return rc, nil
}
func (re *RedisInfo) connSentinel() (*RedisConn, error) {
// sentinel模式host为 masterName=host:port,host:port
masterNameAndHosts := strings.Split(re.Host, "=")
sentinelOptions := &redis.FailoverOptions{
MasterName: masterNameAndHosts[0],
SentinelAddrs: strings.Split(masterNameAndHosts[1], ","),
Username: re.Username,
Password: re.Password, // no password set
SentinelPassword: re.Password, // 哨兵节点密码需与redis节点密码一致
DB: re.Db, // use default DB
DialTimeout: 8 * time.Second,
ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines.
WriteTimeout: -1,
}
if re.SshTunnelMachineId > 0 {
sentinelOptions.Dialer = getRedisDialer(re.SshTunnelMachineId)
}
cli := redis.NewFailoverClient(sentinelOptions)
_, e := cli.Ping(context.Background()).Result()
if e != nil {
cli.Close()
return nil, errorx.NewBiz("redis sentinel连接失败: %s", e.Error())
}
logx.Infof("连接redis sentinel: %s/%d", re.Host, re.Db)
rc := &RedisConn{Id: getConnId(re.Id, re.Db), Info: re}
rc.Cli = cli
return rc, nil
}
func getRedisDialer(machineId int) func(ctx context.Context, network, addr string) (net.Conn, error) {
return func(_ context.Context, network, addr string) (net.Conn, error) {
sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(machineId)
if err != nil {
return nil, err
}
if sshConn, err := sshTunnel.GetDialConn(network, addr); err == nil {
// 将ssh conn包装否则redis内部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
return &netx.WrapSshConn{Conn: sshConn}, nil
} else {
return nil, err
}
}
}
// 生成redis连接id
func getConnId(id uint64, db int) string {
if id == 0 {
return ""
}
return fmt.Sprintf("%d/%d", id, db)
}

View File

@@ -43,7 +43,7 @@ func T2022() *gormigrate.Migration {
return err
}
if err := tx.AutoMigrate(&entity2.Instance{}); err != nil {
if err := tx.AutoMigrate(&entity2.DbInstance{}); err != nil {
return err
}
if err := tx.AutoMigrate(&entity2.Db{}); err != nil {

View File

@@ -10,9 +10,15 @@ type App[T any] interface {
// 新增一个实体
Insert(e T) error
// 使用指定gorm db执行主要用于事务执行
InsertWithDb(db *gorm.DB, e T) error
// 批量新增实体
BatchInsert(models []T) error
// 使用指定gorm db执行主要用于事务执行
BatchInsertWithDb(db *gorm.DB, es []T) error
// 根据实体id更新实体信息
UpdateById(e T) error
@@ -70,11 +76,21 @@ func (ai *AppImpl[T, R]) Insert(e T) error {
return ai.GetRepo().Insert(e)
}
// 使用指定gorm db执行主要用于事务执行
func (ai *AppImpl[T, R]) InsertWithDb(db *gorm.DB, e T) error {
return ai.GetRepo().InsertWithDb(db, e)
}
// 批量新增实体 (单纯新增,不做其他业务逻辑处理)
func (ai *AppImpl[T, R]) BatchInsert(es []T) error {
return ai.GetRepo().BatchInsert(es)
}
// 使用指定gorm db执行主要用于事务执行
func (ai *AppImpl[T, R]) BatchInsertWithDb(db *gorm.DB, es []T) error {
return ai.GetRepo().BatchInsertWithDb(db, es)
}
// 根据实体id更新实体信息 (单纯更新,不做其他业务逻辑处理)
func (ai *AppImpl[T, R]) UpdateById(e T) error {
return ai.GetRepo().UpdateById(e)

View File

@@ -109,6 +109,7 @@ func ErrorRes(g *gin.Context, err any) {
g.JSON(http.StatusOK, model.ServerError())
default:
logx.Errorf("未知错误: %v", t)
g.JSON(http.StatusOK, model.ServerError())
}
}