diff --git a/mayfly_go_web/package.json b/mayfly_go_web/package.json index ed967e84..5d117391 100644 --- a/mayfly_go_web/package.json +++ b/mayfly_go_web/package.json @@ -11,7 +11,7 @@ "dependencies": { "@element-plus/icons-vue": "^2.1.0", "asciinema-player": "^3.6.2", - "axios": "^1.5.1", + "axios": "^1.6.0", "countup.js": "^2.7.0", "cropperjs": "^1.5.11", "echarts": "^5.4.3", diff --git a/mayfly_go_web/src/views/ops/db/DbEdit.vue b/mayfly_go_web/src/views/ops/db/DbEdit.vue index dbc6404e..60d4850d 100644 --- a/mayfly_go_web/src/views/ops/db/DbEdit.vue +++ b/mayfly_go_web/src/views/ops/db/DbEdit.vue @@ -19,7 +19,7 @@ :disabled="form.id !== undefined" remote :remote-method="getInstances" - @change="getAllDatabase" + @change="changeInstance" v-model="form.instanceId" placeholder="请输入实例名称搜索并选择实例" filterable @@ -163,6 +163,11 @@ watch(props, (newValue: any) => { } }); +const changeInstance = () => { + state.databaseList = []; + getAllDatabase(); +}; + /** * 改变表单中的数据库字段,方便表单错误提示。如全部删光,可提示请添加数据库 */ @@ -171,8 +176,6 @@ const changeDatabase = () => { }; const getAllDatabase = async () => { - // 清空数据库列表,可能已经有选择库了 - state.databaseList = []; if (state.form.instanceId > 0) { state.allDatabases = await dbApi.getAllDatabase.request({ instanceId: state.form.instanceId }); } diff --git a/mayfly_go_web/yarn.lock b/mayfly_go_web/yarn.lock index 0879ef63..1c443571 100644 --- a/mayfly_go_web/yarn.lock +++ b/mayfly_go_web/yarn.lock @@ -628,10 +628,10 @@ asynckit@^0.4.0: resolved "https://registry.npmmirror.com/asynckit/-/asynckit-0.4.0.tgz" integrity sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q== -axios@^1.5.1: - version "1.5.1" - resolved "https://registry.npmmirror.com/axios/-/axios-1.5.1.tgz#11fbaa11fc35f431193a9564109c88c1f27b585f" - integrity sha512-Q28iYCWzNHjAm+yEAot5QaAMxhMghWLFVf7rRdwhUI+c2jix2DUXjAHXVi+s1ibs3mjPO/cCgbA++3BjD0vP/A== +axios@^1.6.0: + version "1.6.0" + resolved "https://registry.npmmirror.com/axios/-/axios-1.6.0.tgz#f1e5292f26b2fd5c2e66876adc5b06cdbd7d2102" + integrity sha512-EZ1DYihju9pwVB+jg67ogm+Tmqc6JmhamRN6I4Zt8DfZu5lbcQGw3ozH9lFejSJgs/ibaef3A9PMXPLeefFGJg== dependencies: follow-redirects "^1.15.0" form-data "^4.0.0" diff --git a/server/internal/common/consts/consts.go b/server/internal/common/consts/consts.go index e70c88ba..a3521dd8 100644 --- a/server/internal/common/consts/consts.go +++ b/server/internal/common/consts/consts.go @@ -6,7 +6,7 @@ const ( AdminId = 1 MachineConnExpireTime = 60 * time.Minute - DbConnExpireTime = 45 * time.Minute + DbConnExpireTime = 120 * time.Minute RedisConnExpireTime = 30 * time.Minute MongoConnExpireTime = 30 * time.Minute diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index b012e49b..126e0d55 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -6,6 +6,7 @@ import ( "mayfly-go/internal/db/api/form" "mayfly-go/internal/db/api/vo" "mayfly-go/internal/db/application" + "mayfly-go/internal/db/dbm" "mayfly-go/internal/db/domain/entity" msgapp "mayfly-go/internal/msg/application" msgdto "mayfly-go/internal/msg/application/dto" @@ -82,14 +83,14 @@ func (d *Db) DeleteDb(rc *req.Ctx) { } } -func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection { - dc, err := d.DbApp.GetDbConnection(getDbId(g), getDbName(g)) +func (d *Db) getDbConn(g *gin.Context) *dbm.DbConn { + dc, err := d.DbApp.GetDbConn(getDbId(g), getDbName(g)) biz.ErrIsNil(err) return dc } func (d *Db) TableInfos(rc *req.Ctx) { - res, err := d.getDbConnection(rc.GinCtx).GetMeta().GetTableInfos() + res, err := d.getDbConn(rc.GinCtx).GetMeta().GetTableInfos() biz.ErrIsNilAppendErr(err, "获取表信息失败: %s") rc.ResData = res } @@ -97,7 +98,7 @@ func (d *Db) TableInfos(rc *req.Ctx) { func (d *Db) TableIndex(rc *req.Ctx) { tn := rc.GinCtx.Query("tableName") biz.NotEmpty(tn, "tableName不能为空") - res, err := d.getDbConnection(rc.GinCtx).GetMeta().GetTableIndex(tn) + res, err := d.getDbConn(rc.GinCtx).GetMeta().GetTableIndex(tn) biz.ErrIsNilAppendErr(err, "获取表索引信息失败: %s") rc.ResData = res } @@ -105,7 +106,7 @@ func (d *Db) TableIndex(rc *req.Ctx) { func (d *Db) GetCreateTableDdl(rc *req.Ctx) { tn := rc.GinCtx.Query("tableName") biz.NotEmpty(tn, "tableName不能为空") - res, err := d.getDbConnection(rc.GinCtx).GetMeta().GetCreateTableDdl(tn) + res, err := d.getDbConn(rc.GinCtx).GetMeta().GetCreateTableDdl(tn) biz.ErrIsNilAppendErr(err, "获取表ddl失败: %s") rc.ResData = res } @@ -116,7 +117,7 @@ func (d *Db) ExecSql(rc *req.Ctx) { ginx.BindJsonAndValid(g, form) dbId := getDbId(g) - dbConn, err := d.DbApp.GetDbConnection(dbId, form.Db) + dbConn, err := d.DbApp.GetDbConn(dbId, form.Db) biz.ErrIsNil(err) biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") @@ -188,7 +189,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { dbName := getDbName(g) clientId := g.Query("clientId") - dbConn, err := d.DbApp.GetDbConnection(dbId, dbName) + dbConn, err := d.DbApp.GetDbConn(dbId, dbName) biz.ErrIsNil(err) biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc()) @@ -256,7 +257,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { if !ok { logx.Warnf("sql解析失败: %s", sql) } - dbConn, err = d.DbApp.GetDbConnection(dbId, stmtUse.DBName.String()) + dbConn, err = d.DbApp.GetDbConn(dbId, stmtUse.DBName.String()) biz.ErrIsNil(err) biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") execReq.DbConn = dbConn @@ -334,7 +335,7 @@ func (d *Db) DumpSql(rc *req.Ctx) { } func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []string, needStruct bool, needData bool) { - dbConn, err := d.DbApp.GetDbConnection(dbId, dbName) + dbConn, err := d.DbApp.GetDbConn(dbId, dbName) biz.ErrIsNil(err) writer.WriteString("\n-- ----------------------------") writer.WriteString("\n-- 导出平台: mayfly-go") @@ -397,7 +398,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str // @router /api/db/:dbId/t-metadata [get] func (d *Db) TableMA(rc *req.Ctx) { - dbi := d.getDbConnection(rc.GinCtx) + dbi := d.getDbConn(rc.GinCtx) res, err := dbi.GetMeta().GetTables() biz.ErrIsNilAppendErr(err, "获取表基础信息失败: %s") rc.ResData = res @@ -409,7 +410,7 @@ func (d *Db) ColumnMA(rc *req.Ctx) { tn := g.Query("tableName") biz.NotEmpty(tn, "tableName不能为空") - dbi := d.getDbConnection(rc.GinCtx) + dbi := d.getDbConn(rc.GinCtx) res, err := dbi.GetMeta().GetColumns(tn) biz.ErrIsNilAppendErr(err, "获取数据库列信息失败: %s") rc.ResData = res @@ -417,7 +418,7 @@ func (d *Db) ColumnMA(rc *req.Ctx) { // @router /api/db/:dbId/hint-tables [get] func (d *Db) HintTables(rc *req.Ctx) { - dbi := d.getDbConnection(rc.GinCtx) + dbi := d.getDbConn(rc.GinCtx) dm := dbi.GetMeta() // 获取所有表 diff --git a/server/internal/db/api/instance.go b/server/internal/db/api/instance.go index bc4bcc3f..1b90bc4d 100644 --- a/server/internal/db/api/instance.go +++ b/server/internal/db/api/instance.go @@ -33,7 +33,7 @@ func (d *Instance) Instances(rc *req.Ctx) { // @router /api/instances [post] func (d *Instance) SaveInstance(rc *req.Ctx) { form := &form.InstanceForm{} - instance := ginx.BindJsonAndCopyTo[*entity.Instance](rc.GinCtx, form, new(entity.Instance)) + instance := ginx.BindJsonAndCopyTo[*entity.DbInstance](rc.GinCtx, form, new(entity.DbInstance)) // 密码解密,并使用解密后的赋值 originPwd, err := cryptox.DefaultRsaDecrypt(form.Password, true) @@ -52,7 +52,7 @@ func (d *Instance) SaveInstance(rc *req.Ctx) { // @router /api/instances/:instance [GET] func (d *Instance) GetInstance(rc *req.Ctx) { dbId := getInstanceId(rc.GinCtx) - dbEntity, err := d.InstanceApp.GetById(new(entity.Instance), dbId) + dbEntity, err := d.InstanceApp.GetById(new(entity.DbInstance), dbId) biz.ErrIsNil(err, "获取数据库实例错误") dbEntity.Password = "" rc.ResData = dbEntity @@ -62,7 +62,7 @@ func (d *Instance) GetInstance(rc *req.Ctx) { // @router /api/instances/:instance/pwd [GET] func (d *Instance) GetInstancePwd(rc *req.Ctx) { instanceId := getInstanceId(rc.GinCtx) - instanceEntity, err := d.InstanceApp.GetById(new(entity.Instance), instanceId, "Password") + instanceEntity, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "Password") biz.ErrIsNil(err, "获取数据库实例错误") instanceEntity.PwdDecrypt() rc.ResData = instanceEntity.Password @@ -80,7 +80,7 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) { biz.ErrIsNilAppendErr(err, "string类型转换为int异常: %s") instanceId := uint64(value) if d.DbApp.Count(&entity.DbQuery{InstanceId: instanceId}) != 0 { - instance, err := d.InstanceApp.GetById(new(entity.Instance), instanceId, "name") + instance, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "name") biz.ErrIsNil(err, "获取数据库实例错误,数据库实例ID为: %d", instance.Id) biz.IsTrue(false, "不能删除数据库实例【%s】,请先删除其关联的数据库资源。", instance.Name) } @@ -91,7 +91,7 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) { // 获取数据库实例的所有数据库名 func (d *Instance) GetDatabaseNames(rc *req.Ctx) { instanceId := getInstanceId(rc.GinCtx) - instance, err := d.InstanceApp.GetById(new(entity.Instance), instanceId, "Password") + instance, err := d.InstanceApp.GetById(new(entity.DbInstance), instanceId, "Password") biz.ErrIsNil(err, "获取数据库实例错误") instance.PwdDecrypt() res, err := d.InstanceApp.GetDatabases(instance) diff --git a/server/internal/db/application/db.go b/server/internal/db/application/db.go index d2413c27..0e0075b3 100644 --- a/server/internal/db/application/db.go +++ b/server/internal/db/application/db.go @@ -1,23 +1,15 @@ package application import ( - "database/sql" - "fmt" - "mayfly-go/internal/common/consts" + "mayfly-go/internal/db/dbm" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" - "mayfly-go/internal/machine/infrastructure/machine" "mayfly-go/pkg/base" - "mayfly-go/pkg/cache" "mayfly-go/pkg/errorx" - "mayfly-go/pkg/logx" "mayfly-go/pkg/model" "mayfly-go/pkg/utils/collx" - "reflect" - "strconv" + "mayfly-go/pkg/utils/structx" "strings" - "sync" - "time" ) type Db interface { @@ -34,9 +26,9 @@ type Db interface { Delete(id uint64) error // 获取数据库连接实例 - // @param id 数据库实例id - // @param db 数据库 - GetDbConnection(dbId uint64, dbName string) (*DbConnection, error) + // @param id 数据库id + // @param dbName 数据库 + GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) } func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql, dbInstanceApp Instance) Db { @@ -96,7 +88,7 @@ func (d *dbAppImpl) Save(dbEntity *entity.Db) error { for _, v := range delDb { // 关闭数据库连接 - CloseDb(dbEntity.Id, v) + dbm.CloseDb(dbEntity.Id, v) // 删除该库关联的所有sql记录 d.dbSqlRepo.DeleteByCond(&entity.DbSql{DbId: dbId, Db: v}) } @@ -112,316 +104,39 @@ func (d *dbAppImpl) Delete(id uint64) error { dbs := strings.Split(db.Database, " ") for _, v := range dbs { // 关闭连接 - CloseDb(id, v) + dbm.CloseDb(id, v) } // 删除该库下用户保存的所有sql信息 d.dbSqlRepo.DeleteByCond(&entity.DbSql{DbId: id}) return d.DeleteById(id) } -var mutex sync.Mutex - -func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) (*DbConnection, error) { - cacheKey := GetDbCacheKey(dbId, dbName) - - // Id不为0,则为需要缓存 - needCache := dbId != 0 - if needCache { - load, ok := dbCache.Get(cacheKey) - if ok { - return load.(*DbConnection), nil +func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) { + return dbm.GetDbConn(dbId, dbName, func() (*dbm.DbInfo, error) { + db, err := d.GetById(new(entity.Db), dbId) + if err != nil { + return nil, errorx.NewBiz("数据库信息不存在") } - } - mutex.Lock() - defer mutex.Unlock() - - db, err := d.GetById(new(entity.Db), dbId) - if err != nil { - return nil, errorx.NewBiz("数据库信息不存在") - } - if !strings.Contains(" "+db.Database+" ", " "+dbName+" ") { - return nil, errorx.NewBiz("未配置数据库【%s】的操作权限", dbName) - } - - instance, err := d.dbInstanceApp.GetById(new(entity.Instance), db.InstanceId) - if err != nil { - return nil, errorx.NewBiz("数据库实例不存在") - } - // 密码解密 - instance.PwdDecrypt() - - dbInfo := NewDbInfo(db, instance) - dbInfo.Database = dbName - dbi := &DbConnection{Id: cacheKey, Info: dbInfo} - - conn, err := getInstanceConn(instance, dbName) - if err != nil { - dbi.Close() - logx.Errorf("连接db失败: %s:%d/%s", dbInfo.Host, dbInfo.Port, dbName) - return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error())) - } - - // 最大连接周期,超过时间的连接就close - // conn.SetConnMaxLifetime(100 * time.Second) - // 设置最大连接数 - conn.SetMaxOpenConns(5) - // 设置闲置连接数 - conn.SetMaxIdleConns(1) - - dbi.db = conn - logx.Infof("连接db: %s:%d/%s", dbInfo.Host, dbInfo.Port, dbName) - if needCache { - dbCache.Put(cacheKey, dbi) - } - return dbi, nil -} - -//---------------------------------------- db instance ------------------------------------ - -type DbInfo struct { - Id uint64 - Name string - Type entity.DbType // 类型,mysql oracle等 - Host string - Port int - Network string - Username string - TagPath string - Database string - SshTunnelMachineId int -} - -func NewDbInfo(db *entity.Db, instance *entity.Instance) *DbInfo { - return &DbInfo{ - Id: db.Id, - Name: db.Name, - Type: instance.Type, - Host: instance.Host, - Port: instance.Port, - Username: instance.Username, - TagPath: db.TagPath, - SshTunnelMachineId: instance.SshTunnelMachineId, - } -} - -// 获取记录日志的描述 -func (d *DbInfo) GetLogDesc() string { - return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.TagPath, d.Name, d.Host, d.Port, d.Database) -} - -// db实例连接信息 -type DbConnection struct { - Id string - Info *DbInfo - - db *sql.DB -} - -// 执行查询语句 -// 依次返回 列名数组,结果map,错误 -func (d *DbConnection) SelectData(execSql string) ([]string, []map[string]any, error) { - return selectDataByDb(d.db, execSql) -} - -// 将查询结果映射至struct,可具体参考sqlx库 -func (d *DbConnection) SelectData2Struct(execSql string, dest any) error { - return select2StructByDb(d.db, execSql, dest) -} - -// WalkTableRecord 遍历表记录 -func (d *DbConnection) WalkTableRecord(selectSql string, walk func(record map[string]any, columns []string)) error { - return walkTableRecord(d.db, selectSql, walk) -} - -// 执行 update, insert, delete,建表等sql -// 返回影响条数和错误 -func (d *DbConnection) Exec(sql string) (int64, error) { - res, err := d.db.Exec(sql) - if err != nil { - return 0, err - } - return res.RowsAffected() -} - -// 获取数据库元信息实现接口 -func (d *DbConnection) GetMeta() DbMetadata { - switch d.Info.Type { - case entity.DbTypeMysql: - return &MysqlMetadata{di: d} - case entity.DbTypePostgres: - return &PgsqlMetadata{di: d} - default: - panic(fmt.Sprintf("invalid database type: %s", d.Info.Type)) - } -} - -// 关闭连接 -func (d *DbConnection) Close() { - if d.db != nil { - if err := d.db.Close(); err != nil { - logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error()) + if !strings.Contains(" "+db.Database+" ", " "+dbName+" ") { + return nil, errorx.NewBiz("未配置数据库【%s】的操作权限", dbName) } - d.db = nil - } -} -//------------------------------------------------------------------------------ - -// 客户端连接缓存,指定时间内没有访问则会被关闭, key为数据库实例id:数据库 -var dbCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second). - WithUpdateAccessTime(true). - OnEvicted(func(key any, value any) { - logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key)) - value.(*DbConnection).Close() - }) - -func init() { - machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool { - // 遍历所有db连接实例,若存在db实例使用该ssh隧道机器,则返回true,表示还在使用中... - items := dbCache.Items() - for _, v := range items { - if v.Value.(*DbConnection).Info.SshTunnelMachineId == machineId { - return true - } + instance, err := d.dbInstanceApp.GetById(new(entity.DbInstance), db.InstanceId) + if err != nil { + return nil, errorx.NewBiz("数据库实例不存在") } - return false + // 密码解密 + instance.PwdDecrypt() + return toDbInfo(instance, dbId, dbName, db.TagPath), nil }) } -func GetDbCacheKey(dbId uint64, db string) string { - return fmt.Sprintf("%d:%s", dbId, db) -} - -func selectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]any, error) { - // 列名用于前端表头名称按照数据库与查询字段顺序显示 - var colNames []string - result := make([]map[string]any, 0, 16) - err := walkTableRecord(db, selectSql, func(record map[string]any, columns []string) { - result = append(result, record) - if colNames == nil { - colNames = make([]string, len(columns)) - copy(colNames, columns) - } - }) - if err != nil { - return nil, nil, err - } - return colNames, result, nil -} - -func walkTableRecord(db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) error { - rows, err := db.Query(selectSql) - if err != nil { - return err - } - // rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数, - // 后面的链接过来直接报错或拒绝,实际上也没有起效果 - defer func() { - if rows != nil { - rows.Close() - } - }() - - colTypes, err := rows.ColumnTypes() - if err != nil { - return err - } - lenCols := len(colTypes) - // 列名用于前端表头名称按照数据库与查询字段顺序显示 - colNames := make([]string, lenCols) - // 这里表示一行填充数据 - scans := make([]any, lenCols) - // 这里表示一行所有列的值,用[]byte表示 - values := make([][]byte, lenCols) - for k, colType := range colTypes { - colNames[k] = colType.Name() - // 这里scans引用values,把数据填充到[]byte里 - scans[k] = &values[k] - } - - for rows.Next() { - // 不Scan也会导致等待,该链接实际处于未工作的状态,然后也会导致连接数迅速达到最大 - if err := rows.Scan(scans...); err != nil { - return err - } - // 每行数据 - rowData := make(map[string]any, lenCols) - // 把values中的数据复制到row中 - for i, v := range values { - rowData[colTypes[i].Name()] = valueConvert(v, colTypes[i]) - } - walk(rowData, colNames) - } - - return nil -} - -// 将查询的值转为对应列类型的实际值,不全部转为字符串 -func valueConvert(data []byte, colType *sql.ColumnType) any { - if data == nil { - return nil - } - // 列的数据库类型名 - colDatabaseTypeName := strings.ToLower(colType.DatabaseTypeName()) - - // 如果类型是bit,则直接返回第一个字节即可 - if strings.Contains(colDatabaseTypeName, "bit") { - return data[0] - } - - // 这里把[]byte数据转成string - stringV := string(data) - if stringV == "" { - return "" - } - colScanType := strings.ToLower(colType.ScanType().Name()) - - if strings.Contains(colScanType, "int") { - // 如果长度超过16位,则返回字符串,因为前端js长度大于16会丢失精度 - if len(stringV) > 16 { - return stringV - } - intV, _ := strconv.Atoi(stringV) - switch colType.ScanType().Kind() { - case reflect.Int8: - return int8(intV) - case reflect.Uint8: - return uint8(intV) - case reflect.Int64: - return int64(intV) - case reflect.Uint64: - return uint64(intV) - case reflect.Uint: - return uint(intV) - default: - return intV - } - } - if strings.Contains(colScanType, "float") || strings.Contains(colDatabaseTypeName, "decimal") { - floatV, _ := strconv.ParseFloat(stringV, 64) - return floatV - } - - return stringV -} - -// 查询数据结果映射至struct。可参考sqlx库 -func select2StructByDb(db *sql.DB, selectSql string, dest any) error { - rows, err := db.Query(selectSql) - if err != nil { - return err - } - // rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数, - // 后面的链接过来直接报错或拒绝,实际上也没有起效果 - defer func() { - if rows != nil { - rows.Close() - } - }() - return scanAll(rows, dest, false) -} - -// 删除db缓存并关闭该数据库所有连接 -func CloseDb(dbId uint64, db string) { - dbCache.Delete(GetDbCacheKey(dbId, db)) +func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath string) *dbm.DbInfo { + di := new(dbm.DbInfo) + di.Id = dbId + di.Database = database + di.TagPath = tagPath + + structx.Copy(di, instance) + return di } diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index 96d64a90..79f54034 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "mayfly-go/internal/db/config" + "mayfly-go/internal/db/dbm" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" "mayfly-go/pkg/errorx" @@ -20,7 +21,7 @@ type DbSqlExecReq struct { Sql string Remark string LoginAccount *model.LoginAccount - DbConn *DbConnection + DbConn *dbm.DbConn } type DbSqlExecRes struct { @@ -269,7 +270,7 @@ func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *ent return doExec(execSqlReq.Sql, execSqlReq.DbConn) } -func doExec(sql string, dbConn *DbConnection) (*DbSqlExecRes, error) { +func doExec(sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) { rowsAffected, err := dbConn.Exec(sql) execRes := "success" if err != nil { diff --git a/server/internal/db/application/instance.go b/server/internal/db/application/instance.go index 99a34f8c..106b449d 100644 --- a/server/internal/db/application/instance.go +++ b/server/internal/db/application/instance.go @@ -1,8 +1,6 @@ package application import ( - "database/sql" - "fmt" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" "mayfly-go/pkg/base" @@ -11,20 +9,20 @@ import ( ) type Instance interface { - base.App[*entity.Instance] + base.App[*entity.DbInstance] // GetPageList 分页获取数据库实例 GetPageList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) Count(condition *entity.InstanceQuery) int64 - Save(instanceEntity *entity.Instance) error + Save(instanceEntity *entity.DbInstance) error // Delete 删除数据库信息 Delete(id uint64) error // GetDatabases 获取数据库实例的所有数据库列表 - GetDatabases(entity *entity.Instance) ([]string, error) + GetDatabases(entity *entity.DbInstance) ([]string, error) } func newInstanceApp(instanceRepo repository.Instance) Instance { @@ -34,7 +32,7 @@ func newInstanceApp(instanceRepo repository.Instance) Instance { } type instanceAppImpl struct { - base.AppImpl[*entity.Instance, repository.Instance] + base.AppImpl[*entity.DbInstance, repository.Instance] } // GetPageList 分页获取数据库实例 @@ -46,19 +44,21 @@ func (app *instanceAppImpl) Count(condition *entity.InstanceQuery) int64 { return app.CountByCond(condition) } -func (app *instanceAppImpl) Save(instanceEntity *entity.Instance) error { +func (app *instanceAppImpl) Save(instanceEntity *entity.DbInstance) error { // 默认tcp连接 instanceEntity.Network = instanceEntity.GetNetwork() // 测试连接 if instanceEntity.Password != "" { - if err := testConnection(instanceEntity); err != nil { - return errorx.NewBiz("数据库连接失败: %s", err.Error()) + dbConn, err := toDbInfo(instanceEntity, 0, "", "").Conn() + if err != nil { + return err } + defer dbConn.Close() } // 查找是否存在该库 - oldInstance := &entity.Instance{Host: instanceEntity.Host, Port: instanceEntity.Port, Username: instanceEntity.Username} + oldInstance := &entity.DbInstance{Host: instanceEntity.Host, Port: instanceEntity.Port, Username: instanceEntity.Username} if instanceEntity.SshTunnelMachineId > 0 { oldInstance.SshTunnelMachineId = instanceEntity.SshTunnelMachineId } @@ -87,55 +87,19 @@ func (app *instanceAppImpl) Delete(id uint64) error { return app.DeleteById(id) } -// getInstanceConn 获取数据库连接数据库实例 -func getInstanceConn(instance *entity.Instance, db string) (*sql.DB, error) { - var conn *sql.DB - var err error - switch instance.Type { - case entity.DbTypeMysql: - conn, err = getMysqlDB(instance, db) - case entity.DbTypePostgres: - conn, err = getPgsqlDB(instance, db) - default: - panic(fmt.Sprintf("invalid database type: %s", instance.Type)) - } - - if err != nil { - return nil, err - } - err = conn.Ping() - if err != nil { - conn.Close() - return nil, err - } - - return conn, nil -} - -func testConnection(d *entity.Instance) error { - // 不指定数据库名称 - conn, err := getInstanceConn(d, "") - if err != nil { - return err - } - defer conn.Close() - return nil -} - -func (app *instanceAppImpl) GetDatabases(ed *entity.Instance) ([]string, error) { +func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance) ([]string, error) { ed.Network = ed.GetNetwork() databases := make([]string, 0) - var dbConn *sql.DB metaDb := ed.Type.MetaDbName() getDatabasesSql := ed.Type.StmtSelectDbName() - dbConn, err := getInstanceConn(ed, metaDb) + dbConn, err := toDbInfo(ed, 0, metaDb, "").Conn() if err != nil { - return nil, errorx.NewBiz("数据库连接失败: %s", err.Error()) + return nil, err } defer dbConn.Close() - _, res, err := selectDataByDb(dbConn, getDatabasesSql) + _, res, err := dbConn.SelectData(getDatabasesSql) if err != nil { return nil, err } diff --git a/server/internal/db/dbm/conn.go b/server/internal/db/dbm/conn.go new file mode 100644 index 00000000..b04fb6ef --- /dev/null +++ b/server/internal/db/dbm/conn.go @@ -0,0 +1,195 @@ +package dbm + +import ( + "database/sql" + "fmt" + "mayfly-go/pkg/logx" + "reflect" + "strconv" + "strings" +) + +// db实例连接信息 +type DbConn struct { + Id string + Info *DbInfo + + db *sql.DB +} + +// 执行查询语句 +// 依次返回 列名数组,结果map,错误 +func (d *DbConn) SelectData(execSql string) ([]string, []map[string]any, error) { + return selectDataByDb(d.db, execSql) +} + +// 将查询结果映射至struct,可具体参考sqlx库 +func (d *DbConn) SelectData2Struct(execSql string, dest any) error { + return select2StructByDb(d.db, execSql, dest) +} + +// WalkTableRecord 遍历表记录 +func (d *DbConn) WalkTableRecord(selectSql string, walk func(record map[string]any, columns []string)) error { + return walkTableRecord(d.db, selectSql, walk) +} + +// 执行 update, insert, delete,建表等sql +// 返回影响条数和错误 +func (d *DbConn) Exec(sql string) (int64, error) { + res, err := d.db.Exec(sql) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +// 获取数据库元信息实现接口 +func (d *DbConn) GetMeta() DbMetadata { + switch d.Info.Type { + case DbTypeMysql: + return &MysqlMetadata{dc: d} + case DbTypePostgres: + return &PgsqlMetadata{dc: d} + default: + panic(fmt.Sprintf("invalid database type: %s", d.Info.Type)) + } +} + +// 关闭连接 +func (d *DbConn) Close() { + if d.db != nil { + if err := d.db.Close(); err != nil { + logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error()) + } + d.db = nil + } +} + +func selectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]any, error) { + // 列名用于前端表头名称按照数据库与查询字段顺序显示 + var colNames []string + result := make([]map[string]any, 0, 16) + err := walkTableRecord(db, selectSql, func(record map[string]any, columns []string) { + result = append(result, record) + if colNames == nil { + colNames = make([]string, len(columns)) + copy(colNames, columns) + } + }) + if err != nil { + return nil, nil, err + } + return colNames, result, nil +} + +func walkTableRecord(db *sql.DB, selectSql string, walk func(record map[string]any, columns []string)) error { + rows, err := db.Query(selectSql) + if err != nil { + return err + } + // rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数, + // 后面的链接过来直接报错或拒绝,实际上也没有起效果 + defer func() { + if rows != nil { + rows.Close() + } + }() + + colTypes, err := rows.ColumnTypes() + if err != nil { + return err + } + lenCols := len(colTypes) + // 列名用于前端表头名称按照数据库与查询字段顺序显示 + colNames := make([]string, lenCols) + // 这里表示一行填充数据 + scans := make([]any, lenCols) + // 这里表示一行所有列的值,用[]byte表示 + values := make([][]byte, lenCols) + for k, colType := range colTypes { + colNames[k] = colType.Name() + // 这里scans引用values,把数据填充到[]byte里 + scans[k] = &values[k] + } + + for rows.Next() { + // 不Scan也会导致等待,该链接实际处于未工作的状态,然后也会导致连接数迅速达到最大 + if err := rows.Scan(scans...); err != nil { + return err + } + // 每行数据 + rowData := make(map[string]any, lenCols) + // 把values中的数据复制到row中 + for i, v := range values { + rowData[colTypes[i].Name()] = valueConvert(v, colTypes[i]) + } + walk(rowData, colNames) + } + + return nil +} + +// 将查询的值转为对应列类型的实际值,不全部转为字符串 +func valueConvert(data []byte, colType *sql.ColumnType) any { + if data == nil { + return nil + } + // 列的数据库类型名 + colDatabaseTypeName := strings.ToLower(colType.DatabaseTypeName()) + + // 如果类型是bit,则直接返回第一个字节即可 + if strings.Contains(colDatabaseTypeName, "bit") { + return data[0] + } + + // 这里把[]byte数据转成string + stringV := string(data) + if stringV == "" { + return "" + } + colScanType := strings.ToLower(colType.ScanType().Name()) + + if strings.Contains(colScanType, "int") { + // 如果长度超过16位,则返回字符串,因为前端js长度大于16会丢失精度 + if len(stringV) > 16 { + return stringV + } + intV, _ := strconv.Atoi(stringV) + switch colType.ScanType().Kind() { + case reflect.Int8: + return int8(intV) + case reflect.Uint8: + return uint8(intV) + case reflect.Int64: + return int64(intV) + case reflect.Uint64: + return uint64(intV) + case reflect.Uint: + return uint(intV) + default: + return intV + } + } + if strings.Contains(colScanType, "float") || strings.Contains(colDatabaseTypeName, "decimal") { + floatV, _ := strconv.ParseFloat(stringV, 64) + return floatV + } + + return stringV +} + +// 查询数据结果映射至struct。可参考sqlx库 +func select2StructByDb(db *sql.DB, selectSql string, dest any) error { + rows, err := db.Query(selectSql) + if err != nil { + return err + } + // rows对象一定要close掉,如果出错,不关掉则会很迅速的达到设置最大连接数, + // 后面的链接过来直接报错或拒绝,实际上也没有起效果 + defer func() { + if rows != nil { + rows.Close() + } + }() + return scanAll(rows, dest, false) +} diff --git a/server/internal/db/dbm/conn_cache.go b/server/internal/db/dbm/conn_cache.go new file mode 100644 index 00000000..322a309d --- /dev/null +++ b/server/internal/db/dbm/conn_cache.go @@ -0,0 +1,73 @@ +package dbm + +import ( + "fmt" + "mayfly-go/internal/common/consts" + "mayfly-go/internal/machine/infrastructure/machine" + "mayfly-go/pkg/cache" + "mayfly-go/pkg/logx" + "sync" + "time" +) + +// 客户端连接缓存,指定时间内没有访问则会被关闭, key为数据库连接id +var connCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second). + WithUpdateAccessTime(true). + OnEvicted(func(key any, value any) { + logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key)) + value.(*DbConn).Close() + }) + +func init() { + machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool { + // 遍历所有db连接实例,若存在db实例使用该ssh隧道机器,则返回true,表示还在使用中... + items := connCache.Items() + for _, v := range items { + if v.Value.(*DbConn).Info.SshTunnelMachineId == machineId { + return true + } + } + return false + }) +} + +var mutex sync.Mutex + +// 从缓存中获取数据库连接信息,若缓存中不存在则会使用回调函数获取dbInfo进行连接并缓存 +func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error)) (*DbConn, error) { + connId := GetDbConnId(dbId, database) + + // connId不为空,则为需要缓存 + needCache := connId != "" + if needCache { + load, ok := connCache.Get(connId) + if ok { + return load.(*DbConn), nil + } + } + + mutex.Lock() + defer mutex.Unlock() + + // 若缓存中不存在,则从回调函数中获取DbInfo + dbInfo, err := getDbInfo() + if err != nil { + return nil, err + } + + // 连接数据库 + dbConn, err := dbInfo.Conn() + if err != nil { + return nil, err + } + + if needCache { + connCache.Put(connId, dbConn) + } + return dbConn, nil +} + +// 删除db缓存并关闭该数据库所有连接 +func CloseDb(dbId uint64, db string) { + connCache.Delete(GetDbConnId(dbId, db)) +} diff --git a/server/internal/db/domain/entity/db_type.go b/server/internal/db/dbm/db_type.go similarity index 99% rename from server/internal/db/domain/entity/db_type.go rename to server/internal/db/dbm/db_type.go index 1bbc9811..1e376f8c 100644 --- a/server/internal/db/domain/entity/db_type.go +++ b/server/internal/db/dbm/db_type.go @@ -1,10 +1,11 @@ -package entity +package dbm import ( "fmt" + "strings" + "github.com/kanzihuang/vitess/go/vt/sqlparser" "github.com/lib/pq" - "strings" ) type DbType string diff --git a/server/internal/db/domain/entity/db_type_test.go b/server/internal/db/dbm/db_type_test.go similarity index 98% rename from server/internal/db/domain/entity/db_type_test.go rename to server/internal/db/dbm/db_type_test.go index 11640da7..83cdf6b2 100644 --- a/server/internal/db/domain/entity/db_type_test.go +++ b/server/internal/db/dbm/db_type_test.go @@ -1,8 +1,9 @@ -package entity +package dbm import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func Test_QuoteLiteral(t *testing.T) { diff --git a/server/internal/db/dbm/info.go b/server/internal/db/dbm/info.go new file mode 100644 index 00000000..f6e7f799 --- /dev/null +++ b/server/internal/db/dbm/info.go @@ -0,0 +1,79 @@ +package dbm + +import ( + "database/sql" + "fmt" + "mayfly-go/pkg/errorx" + "mayfly-go/pkg/logx" +) + +type DbInfo struct { + Id uint64 + Name string + + Type DbType // 类型,mysql postgres等 + Host string + Port int + Network string + Username string + Password string + Params string + Database string + + TagPath string + SshTunnelMachineId int +} + +// 获取记录日志的描述 +func (d *DbInfo) GetLogDesc() string { + return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.TagPath, d.Name, d.Host, d.Port, d.Database) +} + +// 连接数据库 +func (dbInfo *DbInfo) Conn() (*DbConn, error) { + var conn *sql.DB + var err error + database := dbInfo.Database + + switch dbInfo.Type { + case DbTypeMysql: + conn, err = getMysqlDB(dbInfo) + case DbTypePostgres: + conn, err = getPgsqlDB(dbInfo) + default: + return nil, errorx.NewBiz("invalid database type: %s", dbInfo.Type) + } + + if err != nil { + logx.Errorf("连接db失败: %s:%d/%s", dbInfo.Host, dbInfo.Port, database) + return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error())) + } + + err = conn.Ping() + if err != nil { + logx.Errorf("db ping失败: %s:%d/%s", dbInfo.Host, dbInfo.Port, database) + return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error())) + } + + dbc := &DbConn{Id: GetDbConnId(dbInfo.Id, database), Info: dbInfo} + + // 最大连接周期,超过时间的连接就close + // conn.SetConnMaxLifetime(100 * time.Second) + // 设置最大连接数 + conn.SetMaxOpenConns(5) + // 设置闲置连接数 + conn.SetMaxIdleConns(1) + dbc.db = conn + logx.Infof("连接db: %s:%d/%s", dbInfo.Host, dbInfo.Port, database) + + return dbc, nil +} + +// 获取连接id +func GetDbConnId(dbId uint64, db string) string { + if dbId == 0 { + return "" + } + + return fmt.Sprintf("%d:%s", dbId, db) +} diff --git a/server/internal/db/application/meta.go b/server/internal/db/dbm/meta.go similarity index 99% rename from server/internal/db/application/meta.go rename to server/internal/db/dbm/meta.go index 943e2b0c..7e53ec20 100644 --- a/server/internal/db/application/meta.go +++ b/server/internal/db/dbm/meta.go @@ -1,4 +1,4 @@ -package application +package dbm import ( "embed" diff --git a/server/internal/db/application/metasql/mysql_meta.sql b/server/internal/db/dbm/metasql/mysql_meta.sql similarity index 100% rename from server/internal/db/application/metasql/mysql_meta.sql rename to server/internal/db/dbm/metasql/mysql_meta.sql diff --git a/server/internal/db/application/metasql/pgsql_meta.sql b/server/internal/db/dbm/metasql/pgsql_meta.sql similarity index 100% rename from server/internal/db/application/metasql/pgsql_meta.sql rename to server/internal/db/dbm/metasql/pgsql_meta.sql diff --git a/server/internal/db/application/mysql_meta.go b/server/internal/db/dbm/mysql_meta.go similarity index 88% rename from server/internal/db/application/mysql_meta.go rename to server/internal/db/dbm/mysql_meta.go index 311d97d1..5ade1dc4 100644 --- a/server/internal/db/application/mysql_meta.go +++ b/server/internal/db/dbm/mysql_meta.go @@ -1,10 +1,9 @@ -package application +package dbm import ( "context" "database/sql" "fmt" - "mayfly-go/internal/db/domain/entity" machineapp "mayfly-go/internal/machine/application" "mayfly-go/pkg/errorx" "mayfly-go/pkg/utils/anyx" @@ -13,7 +12,7 @@ import ( "github.com/go-sql-driver/mysql" ) -func getMysqlDB(d *entity.Instance, db string) (*sql.DB, error) { +func getMysqlDB(d *DbInfo) (*sql.DB, error) { // SSH Conect if d.SshTunnelMachineId > 0 { sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId) @@ -25,7 +24,7 @@ func getMysqlDB(d *entity.Instance, db string) (*sql.DB, error) { }) } // 设置dataSourceName -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name - dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db) + dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database) if d.Params != "" { dsn = fmt.Sprintf("%s&%s", dsn, d.Params) } @@ -42,12 +41,12 @@ const ( ) type MysqlMetadata struct { - di *DbConnection + dc *DbConn } // 获取表基础元信息, 如表名等 func (mm *MysqlMetadata) GetTables() ([]Table, error) { - _, res, err := mm.di.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_MA_KEY)) + _, res, err := mm.dc.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_MA_KEY)) if err != nil { return nil, err } @@ -72,7 +71,7 @@ func (mm *MysqlMetadata) GetColumns(tableNames ...string) ([]Column, error) { tableName = tableName + "'" + tableNames[i] + "'" } - _, res, err := mm.di.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName)) + _, res, err := mm.dc.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName)) if err != nil { return nil, err } @@ -113,7 +112,7 @@ func (mm *MysqlMetadata) GetPrimaryKey(tablename string) (string, error) { // 获取表信息,比GetTableMetedatas获取更详细的表信息 func (mm *MysqlMetadata) GetTableInfos() ([]Table, error) { - _, res, err := mm.di.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY)) + _, res, err := mm.dc.SelectData(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY)) if err != nil { return nil, err } @@ -134,7 +133,7 @@ func (mm *MysqlMetadata) GetTableInfos() ([]Table, error) { // 获取表索引信息 func (mm *MysqlMetadata) GetTableIndex(tableName string) ([]Index, error) { - _, res, err := mm.di.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName)) + _, res, err := mm.dc.SelectData(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName)) if err != nil { return nil, err } @@ -171,7 +170,7 @@ func (mm *MysqlMetadata) GetTableIndex(tableName string) ([]Index, error) { // 获取建表ddl func (mm *MysqlMetadata) GetCreateTableDdl(tableName string) (string, error) { - _, res, err := mm.di.SelectData(fmt.Sprintf("show create table `%s` ", tableName)) + _, res, err := mm.dc.SelectData(fmt.Sprintf("show create table `%s` ", tableName)) if err != nil { return "", err } @@ -179,9 +178,9 @@ func (mm *MysqlMetadata) GetCreateTableDdl(tableName string) (string, error) { } func (mm *MysqlMetadata) GetTableRecord(tableName string, pageNum, pageSize int) ([]string, []map[string]any, error) { - return mm.di.SelectData(fmt.Sprintf("SELECT * FROM %s LIMIT %d, %d", tableName, (pageNum-1)*pageSize, pageSize)) + return mm.dc.SelectData(fmt.Sprintf("SELECT * FROM %s LIMIT %d, %d", tableName, (pageNum-1)*pageSize, pageSize)) } func (mm *MysqlMetadata) WalkTableRecord(tableName string, walk func(record map[string]any, columns []string)) error { - return mm.di.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk) + return mm.dc.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk) } diff --git a/server/internal/db/application/pgsql_meta.go b/server/internal/db/dbm/pgsql_meta.go similarity index 88% rename from server/internal/db/application/pgsql_meta.go rename to server/internal/db/dbm/pgsql_meta.go index 1710bfc9..7b5366f2 100644 --- a/server/internal/db/application/pgsql_meta.go +++ b/server/internal/db/dbm/pgsql_meta.go @@ -1,10 +1,9 @@ -package application +package dbm import ( "database/sql" "database/sql/driver" "fmt" - "mayfly-go/internal/db/domain/entity" machineapp "mayfly-go/internal/machine/application" "mayfly-go/pkg/errorx" "mayfly-go/pkg/utils/anyx" @@ -16,7 +15,7 @@ import ( "github.com/lib/pq" ) -func getPgsqlDB(d *entity.Instance, db string) (*sql.DB, error) { +func getPgsqlDB(d *DbInfo) (*sql.DB, error) { driverName := string(d.Type) // SSH Conect if d.SshTunnelMachineId > 0 { @@ -28,6 +27,7 @@ func getPgsqlDB(d *entity.Instance, db string) (*sql.DB, error) { sql.Drivers() } + db := d.Database var dbParam string if db != "" { dbParam = "dbname=" + db @@ -75,12 +75,12 @@ const ( ) type PgsqlMetadata struct { - di *DbConnection + dc *DbConn } // 获取表基础元信息, 如表名等 func (pm *PgsqlMetadata) GetTables() ([]Table, error) { - _, res, err := pm.di.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_MA_KEY)) + _, res, err := pm.dc.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_MA_KEY)) if err != nil { return nil, err } @@ -105,7 +105,7 @@ func (pm *PgsqlMetadata) GetColumns(tableNames ...string) ([]Column, error) { tableName = tableName + "'" + tableNames[i] + "'" } - _, res, err := pm.di.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName)) + _, res, err := pm.dc.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName)) if err != nil { return nil, err } @@ -144,7 +144,7 @@ func (pm *PgsqlMetadata) GetPrimaryKey(tablename string) (string, error) { // 获取表信息,比GetTables获取更详细的表信息 func (pm *PgsqlMetadata) GetTableInfos() ([]Table, error) { - _, res, err := pm.di.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY)) + _, res, err := pm.dc.SelectData(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY)) if err != nil { return nil, err } @@ -165,7 +165,7 @@ func (pm *PgsqlMetadata) GetTableInfos() ([]Table, error) { // 获取表索引信息 func (pm *PgsqlMetadata) GetTableIndex(tableName string) ([]Index, error) { - _, res, err := pm.di.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName)) + _, res, err := pm.dc.SelectData(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName)) if err != nil { return nil, err } @@ -186,16 +186,16 @@ func (pm *PgsqlMetadata) GetTableIndex(tableName string) ([]Index, error) { // 获取建表ddl func (pm *PgsqlMetadata) GetCreateTableDdl(tableName string) (string, error) { - _, err := pm.di.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY)) + _, err := pm.dc.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY)) if err != nil { return "", err } - _, schemaRes, _ := pm.di.SelectData("select current_schema() as schema") + _, schemaRes, _ := pm.dc.SelectData("select current_schema() as schema") schemaName := schemaRes[0]["schema"].(string) ddlSql := fmt.Sprintf("select showcreatetable('%s','%s') as sql", schemaName, tableName) - _, res, err := pm.di.SelectData(ddlSql) + _, res, err := pm.dc.SelectData(ddlSql) if err != nil { return "", err } @@ -204,9 +204,9 @@ func (pm *PgsqlMetadata) GetCreateTableDdl(tableName string) (string, error) { } func (pm *PgsqlMetadata) GetTableRecord(tableName string, pageNum, pageSize int) ([]string, []map[string]any, error) { - return pm.di.SelectData(fmt.Sprintf("SELECT * FROM %s OFFSET %d LIMIT %d", tableName, (pageNum-1)*pageSize, pageSize)) + return pm.dc.SelectData(fmt.Sprintf("SELECT * FROM %s OFFSET %d LIMIT %d", tableName, (pageNum-1)*pageSize, pageSize)) } func (pm *PgsqlMetadata) WalkTableRecord(tableName string, walk func(record map[string]any, columns []string)) error { - return pm.di.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk) + return pm.dc.WalkTableRecord(fmt.Sprintf("SELECT * FROM %s", tableName), walk) } diff --git a/server/internal/db/application/sqlx.go b/server/internal/db/dbm/sqlx.go similarity index 99% rename from server/internal/db/application/sqlx.go rename to server/internal/db/dbm/sqlx.go index ead78f31..66e2154f 100644 --- a/server/internal/db/application/sqlx.go +++ b/server/internal/db/dbm/sqlx.go @@ -1,4 +1,4 @@ -package application +package dbm import ( "database/sql" diff --git a/server/internal/db/domain/entity/db_instance.go b/server/internal/db/domain/entity/db_instance.go new file mode 100644 index 00000000..2f1e03ec --- /dev/null +++ b/server/internal/db/domain/entity/db_instance.go @@ -0,0 +1,50 @@ +package entity + +import ( + "fmt" + "mayfly-go/internal/common/utils" + "mayfly-go/internal/db/dbm" + "mayfly-go/pkg/model" +) + +type DbInstance struct { + model.Model + + Name string `orm:"column(name)" json:"name"` + Type dbm.DbType `orm:"column(type)" json:"type"` // 类型,mysql oracle等 + Host string `orm:"column(host)" json:"host"` + Port int `orm:"column(port)" json:"port"` + Network string `orm:"column(network)" json:"network"` + Username string `orm:"column(username)" json:"username"` + Password string `orm:"column(password)" json:"-"` + Params string `orm:"column(params)" json:"params"` + Remark string `orm:"column(remark)" json:"remark"` + SshTunnelMachineId int `orm:"column(ssh_tunnel_machine_id)" json:"sshTunnelMachineId"` // ssh隧道机器id +} + +func (d *DbInstance) TableName() string { + return "t_db_instance" +} + +// 获取数据库连接网络, 若没有使用ssh隧道,则直接返回。否则返回拼接的网络需要注册至指定dial +func (d *DbInstance) GetNetwork() string { + network := d.Network + if d.SshTunnelMachineId <= 0 { + if network == "" { + return "tcp" + } else { + return network + } + } + return fmt.Sprintf("%s+ssh:%d", d.Type, d.SshTunnelMachineId) +} + +func (d *DbInstance) PwdEncrypt() { + // 密码替换为加密后的密码 + d.Password = utils.PwdAesEncrypt(d.Password) +} + +func (d *DbInstance) PwdDecrypt() { + // 密码替换为解密后的密码 + d.Password = utils.PwdAesDecrypt(d.Password) +} diff --git a/server/internal/db/domain/entity/instance.go b/server/internal/db/domain/entity/instance.go deleted file mode 100644 index ff7c353c..00000000 --- a/server/internal/db/domain/entity/instance.go +++ /dev/null @@ -1,49 +0,0 @@ -package entity - -import ( - "fmt" - "mayfly-go/internal/common/utils" - "mayfly-go/pkg/model" -) - -type Instance struct { - model.Model - - Name string `orm:"column(name)" json:"name"` - Type DbType `orm:"column(type)" json:"type"` // 类型,mysql oracle等 - Host string `orm:"column(host)" json:"host"` - Port int `orm:"column(port)" json:"port"` - Network string `orm:"column(network)" json:"network"` - Username string `orm:"column(username)" json:"username"` - Password string `orm:"column(password)" json:"-"` - Params string `orm:"column(params)" json:"params"` - Remark string `orm:"column(remark)" json:"remark"` - SshTunnelMachineId int `orm:"column(ssh_tunnel_machine_id)" json:"sshTunnelMachineId"` // ssh隧道机器id -} - -func (d *Instance) TableName() string { - return "t_db_instance" -} - -// 获取数据库连接网络, 若没有使用ssh隧道,则直接返回。否则返回拼接的网络需要注册至指定dial -func (d *Instance) GetNetwork() string { - network := d.Network - if d.SshTunnelMachineId <= 0 { - if network == "" { - return "tcp" - } else { - return network - } - } - return fmt.Sprintf("%s+ssh:%d", d.Type, d.SshTunnelMachineId) -} - -func (d *Instance) PwdEncrypt() { - // 密码替换为加密后的密码 - d.Password = utils.PwdAesEncrypt(d.Password) -} - -func (d *Instance) PwdDecrypt() { - // 密码替换为解密后的密码 - d.Password = utils.PwdAesDecrypt(d.Password) -} diff --git a/server/internal/db/domain/repository/instance.go b/server/internal/db/domain/repository/instance.go index 62698eaa..be9dcbdd 100644 --- a/server/internal/db/domain/repository/instance.go +++ b/server/internal/db/domain/repository/instance.go @@ -7,7 +7,7 @@ import ( ) type Instance interface { - base.Repo[*entity.Instance] + base.Repo[*entity.DbInstance] // 分页获取数据库实例信息列表 GetInstanceList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) diff --git a/server/internal/db/infrastructure/persistence/instance.go b/server/internal/db/infrastructure/persistence/instance.go index f4754bd1..cd9e1d67 100644 --- a/server/internal/db/infrastructure/persistence/instance.go +++ b/server/internal/db/infrastructure/persistence/instance.go @@ -9,16 +9,16 @@ import ( ) type instanceRepoImpl struct { - base.RepoImpl[*entity.Instance] + base.RepoImpl[*entity.DbInstance] } func newInstanceRepo() repository.Instance { - return &instanceRepoImpl{base.RepoImpl[*entity.Instance]{M: new(entity.Instance)}} + return &instanceRepoImpl{base.RepoImpl[*entity.DbInstance]{M: new(entity.DbInstance)}} } // 分页获取数据库信息列表 func (d *instanceRepoImpl) GetInstanceList(condition *entity.InstanceQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) { - qd := gormx.NewQuery(new(entity.Instance)). + qd := gormx.NewQuery(new(entity.DbInstance)). Eq("id", condition.Id). Eq("host", condition.Host). Like("name", condition.Name) diff --git a/server/internal/mongo/api/mongo.go b/server/internal/mongo/api/mongo.go index b54669ef..ea6cee0d 100644 --- a/server/internal/mongo/api/mongo.go +++ b/server/internal/mongo/api/mongo.go @@ -74,18 +74,20 @@ func (m *Mongo) DeleteMongo(rc *req.Ctx) { } func (m *Mongo) Databases(rc *req.Ctx) { - cli := m.MongoApp.GetMongoInst(m.GetMongoId(rc.GinCtx)).Cli - res, err := cli.ListDatabases(context.TODO(), bson.D{}) + conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(rc.GinCtx)) + biz.ErrIsNil(err) + res, err := conn.Cli.ListDatabases(context.TODO(), bson.D{}) biz.ErrIsNilAppendErr(err, "获取mongo所有库信息失败: %s") rc.ResData = res } func (m *Mongo) Collections(rc *req.Ctx) { - cli := m.MongoApp.GetMongoInst(m.GetMongoId(rc.GinCtx)).Cli + conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(rc.GinCtx)) + biz.ErrIsNil(err) db := rc.GinCtx.Query("database") biz.NotEmpty(db, "database不能为空") ctx := context.TODO() - res, err := cli.Database(db).ListCollectionNames(ctx, bson.D{}) + res, err := conn.Cli.Database(db).ListCollectionNames(ctx, bson.D{}) biz.ErrIsNilAppendErr(err, "获取库集合信息失败: %s") rc.ResData = res } @@ -94,8 +96,9 @@ func (m *Mongo) RunCommand(rc *req.Ctx) { commandForm := new(form.MongoRunCommand) ginx.BindJsonAndValid(rc.GinCtx, commandForm) - inst := m.MongoApp.GetMongoInst(m.GetMongoId(rc.GinCtx)) - rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm) + conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(rc.GinCtx)) + biz.ErrIsNil(err) + rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm) // 顺序执行 commands := bson.D{} @@ -110,7 +113,7 @@ func (m *Mongo) RunCommand(rc *req.Ctx) { ctx := context.TODO() var bm bson.M - err := inst.Cli.Database(commandForm.Database).RunCommand( + err = conn.Cli.Database(commandForm.Database).RunCommand( ctx, commands, ).Decode(&bm) @@ -121,10 +124,13 @@ func (m *Mongo) RunCommand(rc *req.Ctx) { func (m *Mongo) FindCommand(rc *req.Ctx) { g := rc.GinCtx - cli := m.MongoApp.GetMongoInst(m.GetMongoId(g)).Cli commandForm := new(form.MongoFindCommand) ginx.BindJsonAndValid(g, commandForm) + conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g)) + biz.ErrIsNil(err) + cli := conn.Cli + limit := commandForm.Limit if limit != 0 { biz.IsTrue(limit <= 100, "limit不能超过100") @@ -157,8 +163,9 @@ func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) { commandForm := new(form.MongoUpdateByIdCommand) ginx.BindJsonAndValid(g, commandForm) - inst := m.MongoApp.GetMongoInst(m.GetMongoId(g)) - rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm) + conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g)) + biz.ErrIsNil(err) + rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm) // 解析docId文档id,如果为string类型则使用ObjectId解析,解析失败则为普通字符串 docId := commandForm.DocId @@ -170,7 +177,7 @@ func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) { } } - res, err := inst.Cli.Database(commandForm.Database).Collection(commandForm.Collection).UpdateByID(context.TODO(), docId, commandForm.Update) + res, err := conn.Cli.Database(commandForm.Database).Collection(commandForm.Collection).UpdateByID(context.TODO(), docId, commandForm.Update) biz.ErrIsNilAppendErr(err, "命令执行失败: %s") rc.ResData = res @@ -181,8 +188,9 @@ func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) { commandForm := new(form.MongoUpdateByIdCommand) ginx.BindJsonAndValid(g, commandForm) - inst := m.MongoApp.GetMongoInst(m.GetMongoId(g)) - rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm) + conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g)) + biz.ErrIsNil(err) + rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm) // 解析docId文档id,如果为string类型则使用ObjectId解析,解析失败则为普通字符串 docId := commandForm.DocId @@ -194,7 +202,7 @@ func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) { } } - res, err := inst.Cli.Database(commandForm.Database).Collection(commandForm.Collection).DeleteOne(context.TODO(), bson.D{{Key: "_id", Value: docId}}) + res, err := conn.Cli.Database(commandForm.Database).Collection(commandForm.Collection).DeleteOne(context.TODO(), bson.D{{Key: "_id", Value: docId}}) biz.ErrIsNilAppendErr(err, "命令执行失败: %s") rc.ResData = res } @@ -204,10 +212,11 @@ func (m *Mongo) InsertOneCommand(rc *req.Ctx) { commandForm := new(form.MongoInsertCommand) ginx.BindJsonAndValid(g, commandForm) - inst := m.MongoApp.GetMongoInst(m.GetMongoId(g)) - rc.ReqParam = collx.Kvs("mongo", inst.Info, "cmd", commandForm) + conn, err := m.MongoApp.GetMongoConn(m.GetMongoId(g)) + biz.ErrIsNil(err) + rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm) - res, err := inst.Cli.Database(commandForm.Database).Collection(commandForm.Collection).InsertOne(context.TODO(), commandForm.Doc) + res, err := conn.Cli.Database(commandForm.Database).Collection(commandForm.Collection).InsertOne(context.TODO(), commandForm.Doc) biz.ErrIsNilAppendErr(err, "命令执行失败: %s") rc.ResData = res } diff --git a/server/internal/mongo/application/mongo.go b/server/internal/mongo/application/mongo.go index ec0bf896..11dbcc19 100644 --- a/server/internal/mongo/application/mongo.go +++ b/server/internal/mongo/application/mongo.go @@ -1,26 +1,12 @@ package application import ( - "context" - "fmt" - "mayfly-go/internal/common/consts" - machineapp "mayfly-go/internal/machine/application" - "mayfly-go/internal/machine/infrastructure/machine" "mayfly-go/internal/mongo/domain/entity" "mayfly-go/internal/mongo/domain/repository" + "mayfly-go/internal/mongo/mgm" "mayfly-go/pkg/base" - "mayfly-go/pkg/biz" - "mayfly-go/pkg/cache" - "mayfly-go/pkg/logx" + "mayfly-go/pkg/errorx" "mayfly-go/pkg/model" - "mayfly-go/pkg/utils/netx" - "mayfly-go/pkg/utils/structx" - "net" - "regexp" - "time" - - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" ) type Mongo interface { @@ -38,7 +24,7 @@ type Mongo interface { // 获取mongo连接实例 // @param id mongo id - GetMongoInst(id uint64) *MongoInstance + GetMongoConn(id uint64) (*mgm.MongoConn, error) } func newMongoAppImpl(mongoRepo repository.Mongo) Mongo { @@ -61,7 +47,7 @@ func (d *mongoAppImpl) Count(condition *entity.MongoQuery) int64 { } func (d *mongoAppImpl) Delete(id uint64) error { - DeleteMongoCache(id) + mgm.CloseConn(id) return d.GetRepo().DeleteById(id) } @@ -71,148 +57,16 @@ func (d *mongoAppImpl) Save(m *entity.Mongo) error { } // 先关闭连接 - DeleteMongoCache(m.Id) + mgm.CloseConn(m.Id) return d.GetRepo().UpdateById(m) } -func (d *mongoAppImpl) GetMongoInst(id uint64) *MongoInstance { - mongoInstance, err := GetMongoInstance(id, func(u uint64) (*entity.Mongo, error) { - mongo, err := d.GetById(new(entity.Mongo), u) +func (d *mongoAppImpl) GetMongoConn(id uint64) (*mgm.MongoConn, error) { + return mgm.GetMongoConn(id, func() (*mgm.MongoInfo, error) { + mongo, err := d.GetById(new(entity.Mongo), id) if err != nil { - return nil, err + return nil, errorx.NewBiz("mongo信息不存在") } - return mongo, nil - }) - biz.ErrIsNilAppendErr(err, "连接mongo失败: %s") - return mongoInstance -} - -// ----------------------------------------------------------- - -// mongo客户端连接缓存,指定时间内没有访问则会被关闭 -var mongoCliCache = cache.NewTimedCache(consts.MongoConnExpireTime, 5*time.Second). - WithUpdateAccessTime(true). - OnEvicted(func(key any, value any) { - logx.Infof("删除mongo连接缓存: id = %v", key) - value.(*MongoInstance).Close() - }) - -func init() { - machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool { - // 遍历所有mongo连接实例,若存在redis实例使用该ssh隧道机器,则返回true,表示还在使用中... - items := mongoCliCache.Items() - for _, v := range items { - if v.Value.(*MongoInstance).Info.SshTunnelMachineId == machineId { - return true - } - } - return false + return mongo.ToMongoInfo(), nil }) } - -// 获取mongo的连接实例 -func GetMongoInstance(mongoId uint64, getMongoEntity func(uint64) (*entity.Mongo, error)) (*MongoInstance, error) { - mi, err := mongoCliCache.ComputeIfAbsent(mongoId, func(_ any) (any, error) { - mongoEntity, err := getMongoEntity(mongoId) - if err != nil { - return nil, err - } - - c, err := connect(mongoEntity) - if err != nil { - return nil, err - } - return c, nil - }) - - if mi != nil { - return mi.(*MongoInstance), err - } - return nil, err -} - -func DeleteMongoCache(mongoId uint64) { - mongoCliCache.Delete(mongoId) -} - -type MongoInfo struct { - Id uint64 `json:"id"` - Name string `json:"name"` - TagPath string `json:"tagPath"` - SshTunnelMachineId int `json:"-"` // ssh隧道机器id -} - -func (m *MongoInfo) GetLogDesc() string { - return fmt.Sprintf("Mongo[id=%d, tag=%s, name=%s]", m.Id, m.TagPath, m.Name) -} - -type MongoInstance struct { - Id uint64 - Info *MongoInfo - - Cli *mongo.Client -} - -func (mi *MongoInstance) Close() { - if mi.Cli != nil { - if err := mi.Cli.Disconnect(context.Background()); err != nil { - logx.Errorf("关闭mongo实例[%d]连接失败: %s", mi.Id, err) - } - mi.Cli = nil - } -} - -// 连接mongo,并返回client -func connect(me *entity.Mongo) (*MongoInstance, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - mongoInstance := &MongoInstance{Id: me.Id, Info: toMongoInfo(me)} - - mongoOptions := options.Client().ApplyURI(me.Uri). - SetMaxPoolSize(1) - // 启用ssh隧道则连接隧道机器 - if me.SshTunnelMachineId > 0 { - mongoOptions.SetDialer(&MongoSshDialer{machineId: me.SshTunnelMachineId}) - } - - client, err := mongo.Connect(ctx, mongoOptions) - if err != nil { - mongoInstance.Close() - return nil, err - } - if err = client.Ping(context.TODO(), nil); err != nil { - mongoInstance.Close() - return nil, err - } - - logx.Infof("连接mongo: %s", func(str string) string { - reg := regexp.MustCompile(`(^mongodb://.+?:)(.+)(@.+$)`) - return reg.ReplaceAllString(str, `${1}****${3}`) - }(me.Uri)) - mongoInstance.Cli = client - return mongoInstance, err -} - -func toMongoInfo(me *entity.Mongo) *MongoInfo { - mi := new(MongoInfo) - structx.Copy(mi, me) - return mi -} - -type MongoSshDialer struct { - machineId int -} - -func (sd *MongoSshDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - stm, err := machineapp.GetMachineApp().GetSshTunnelMachine(sd.machineId) - if err != nil { - return nil, err - } - if sshConn, err := stm.GetDialConn(network, address); err == nil { - // 将ssh conn包装,否则内部部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported - return &netx.WrapSshConn{Conn: sshConn}, nil - } else { - return nil, err - } -} diff --git a/server/internal/mongo/domain/entity/mongo.go b/server/internal/mongo/domain/entity/mongo.go index 9d639ae3..332bd174 100644 --- a/server/internal/mongo/domain/entity/mongo.go +++ b/server/internal/mongo/domain/entity/mongo.go @@ -1,6 +1,10 @@ package entity -import "mayfly-go/pkg/model" +import ( + "mayfly-go/internal/mongo/mgm" + "mayfly-go/pkg/model" + "mayfly-go/pkg/utils/structx" +) type Mongo struct { model.Model @@ -11,3 +15,10 @@ type Mongo struct { TagId uint64 `json:"tagId"` TagPath string `json:"tagPath"` } + +// 转换为mongoInfo进行连接 +func (me *Mongo) ToMongoInfo() *mgm.MongoInfo { + mongoInfo := new(mgm.MongoInfo) + structx.Copy(mongoInfo, me) + return mongoInfo +} diff --git a/server/internal/mongo/mgm/conn.go b/server/internal/mongo/mgm/conn.go new file mode 100644 index 00000000..b3d4619d --- /dev/null +++ b/server/internal/mongo/mgm/conn.go @@ -0,0 +1,24 @@ +package mgm + +import ( + "context" + "mayfly-go/pkg/logx" + + "go.mongodb.org/mongo-driver/mongo" +) + +type MongoConn struct { + Id string + Info *MongoInfo + + Cli *mongo.Client +} + +func (mc *MongoConn) Close() { + if mc.Cli != nil { + if err := mc.Cli.Disconnect(context.Background()); err != nil { + logx.Errorf("关闭mongo实例[%s]连接失败: %s", mc.Id, err) + } + mc.Cli = nil + } +} diff --git a/server/internal/mongo/mgm/conn_cache.go b/server/internal/mongo/mgm/conn_cache.go new file mode 100644 index 00000000..5e1bb682 --- /dev/null +++ b/server/internal/mongo/mgm/conn_cache.go @@ -0,0 +1,72 @@ +package mgm + +import ( + "mayfly-go/internal/common/consts" + "mayfly-go/internal/machine/infrastructure/machine" + "mayfly-go/pkg/cache" + "mayfly-go/pkg/logx" + "sync" + "time" +) + +// mongo客户端连接缓存,指定时间内没有访问则会被关闭 +var connCache = cache.NewTimedCache(consts.MongoConnExpireTime, 5*time.Second). + WithUpdateAccessTime(true). + OnEvicted(func(key any, value any) { + logx.Infof("删除mongo连接缓存: id = %v", key) + value.(*MongoConn).Close() + }) + +func init() { + machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool { + // 遍历所有mongo连接实例,若存在redis实例使用该ssh隧道机器,则返回true,表示还在使用中... + items := connCache.Items() + for _, v := range items { + if v.Value.(*MongoConn).Info.SshTunnelMachineId == machineId { + return true + } + } + return false + }) +} + +var mutex sync.Mutex + +// 从缓存中获取mongo连接信息, 若缓存中不存在则会使用回调函数获取mongoInfo进行连接并缓存 +func GetMongoConn(mongoId uint64, getMongoInfo func() (*MongoInfo, error)) (*MongoConn, error) { + connId := getConnId(mongoId) + + // connId不为空,则为需要缓存 + needCache := connId != "" + if needCache { + load, ok := connCache.Get(connId) + if ok { + return load.(*MongoConn), nil + } + } + + mutex.Lock() + defer mutex.Unlock() + + // 若缓存中不存在,则从回调函数中获取MongoInfo + mi, err := getMongoInfo() + if err != nil { + return nil, err + } + + // 连接mongo + mc, err := mi.Conn() + if err != nil { + return nil, err + } + + if needCache { + connCache.Put(connId, mc) + } + return mc, nil +} + +// 关闭连接,并移除缓存连接 +func CloseConn(mongoId uint64) { + connCache.Delete(mongoId) +} diff --git a/server/internal/mongo/mgm/info.go b/server/internal/mongo/mgm/info.go new file mode 100644 index 00000000..31e8b7f4 --- /dev/null +++ b/server/internal/mongo/mgm/info.go @@ -0,0 +1,79 @@ +package mgm + +import ( + "context" + "fmt" + "mayfly-go/pkg/logx" + "mayfly-go/pkg/utils/netx" + "net" + "regexp" + "time" + + machineapp "mayfly-go/internal/machine/application" + + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +type MongoInfo struct { + Id uint64 `json:"id"` + Name string `json:"name"` + + Uri string `json:"-"` + + TagPath string `json:"tagPath"` + SshTunnelMachineId int `json:"-"` // ssh隧道机器id +} + +func (mi *MongoInfo) Conn() (*MongoConn, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + mongoOptions := options.Client().ApplyURI(mi.Uri). + SetMaxPoolSize(1) + // 启用ssh隧道则连接隧道机器 + if mi.SshTunnelMachineId > 0 { + mongoOptions.SetDialer(&MongoSshDialer{machineId: mi.SshTunnelMachineId}) + } + + client, err := mongo.Connect(ctx, mongoOptions) + if err != nil { + return nil, err + } + if err = client.Ping(context.TODO(), nil); err != nil { + client.Disconnect(ctx) + return nil, err + } + + logx.Infof("连接mongo: %s", func(str string) string { + reg := regexp.MustCompile(`(^mongodb://.+?:)(.+)(@.+$)`) + return reg.ReplaceAllString(str, `${1}****${3}`) + }(mi.Uri)) + + return &MongoConn{Id: getConnId(mi.Id), Info: mi, Cli: client}, nil +} + +type MongoSshDialer struct { + machineId int +} + +func (sd *MongoSshDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + stm, err := machineapp.GetMachineApp().GetSshTunnelMachine(sd.machineId) + if err != nil { + return nil, err + } + if sshConn, err := stm.GetDialConn(network, address); err == nil { + // 将ssh conn包装,否则内部部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported + return &netx.WrapSshConn{Conn: sshConn}, nil + } else { + return nil, err + } +} + +// 生成mongo连接id +func getConnId(id uint64) string { + if id == 0 { + return "" + } + return fmt.Sprintf("%d", id) +} diff --git a/server/internal/redis/api/hash.go b/server/internal/redis/api/hash.go index e79f9e62..dc155a73 100644 --- a/server/internal/redis/api/hash.go +++ b/server/internal/redis/api/hash.go @@ -11,7 +11,7 @@ import ( ) func (r *Redis) Hscan(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) g := rc.GinCtx count := ginx.QueryInt(g, "count", 10) match := g.Query("match") @@ -32,7 +32,7 @@ func (r *Redis) Hscan(rc *req.Ctx) { } func (r *Redis) Hdel(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) field := rc.GinCtx.Query("field") rc.ReqParam = collx.Kvs("redis", ri.Info, "key", key, "field", field) @@ -42,7 +42,7 @@ func (r *Redis) Hdel(rc *req.Ctx) { } func (r *Redis) Hget(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) field := rc.GinCtx.Query("field") res, err := ri.GetCmdable().HGet(context.TODO(), key, field).Result() @@ -56,7 +56,7 @@ func (r *Redis) Hset(rc *req.Ctx) { ginx.BindJsonAndValid(g, hashValue) hv := hashValue.Value[0] - ri := r.getRedisIns(rc) + ri := r.getRedisConn(rc) rc.ReqParam = collx.Kvs("redis", ri.Info, "hash", hv) res, err := ri.GetCmdable().HSet(context.TODO(), hashValue.Key, hv["field"].(string), hv["value"]).Result() @@ -69,7 +69,7 @@ func (r *Redis) SetHashValue(rc *req.Ctx) { hashValue := new(form.HashValue) ginx.BindJsonAndValid(g, hashValue) - ri := r.getRedisIns(rc) + ri := r.getRedisConn(rc) rc.ReqParam = collx.Kvs("redis", ri.Info, "hash", hashValue) cmd := ri.GetCmdable() diff --git a/server/internal/redis/api/key.go b/server/internal/redis/api/key.go index 17022abb..55460df4 100644 --- a/server/internal/redis/api/key.go +++ b/server/internal/redis/api/key.go @@ -4,7 +4,7 @@ import ( "context" "mayfly-go/internal/redis/api/form" "mayfly-go/internal/redis/api/vo" - "mayfly-go/internal/redis/domain/entity" + "mayfly-go/internal/redis/rdm" "mayfly-go/pkg/biz" "mayfly-go/pkg/ginx" "mayfly-go/pkg/req" @@ -18,7 +18,7 @@ import ( // scan获取redis的key列表信息 func (r *Redis) ScanKeys(rc *req.Ctx) { - ri := r.getRedisIns(rc) + ri := r.getRedisConn(rc) form := &form.RedisScanForm{} ginx.BindJsonAndValid(rc.GinCtx, form) @@ -44,7 +44,7 @@ func (r *Redis) ScanKeys(rc *req.Ctx) { // 通配符或全匹配 mode := ri.Info.Mode - if mode == "" || mode == entity.RedisModeStandalone || mode == entity.RedisModeSentinel { + if mode == "" || mode == rdm.StandaloneMode || mode == rdm.SentinelMode { redisAddr := ri.Cli.Options().Addr cursorRes[redisAddr] = form.Cursor[redisAddr] for { @@ -64,7 +64,7 @@ func (r *Redis) ScanKeys(rc *req.Ctx) { break } } - } else if mode == entity.RedisModeCluster { + } else if mode == rdm.ClusterMode { mu := &sync.Mutex{} // 遍历所有master节点,并执行scan命令,合并keys ri.ClusterCli.ForEachMaster(ctx, func(ctx context.Context, client *redis.Client) error { @@ -98,7 +98,7 @@ func (r *Redis) ScanKeys(rc *req.Ctx) { } func (r *Redis) KeyInfo(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) cmd := ri.GetCmdable() ctx := context.Background() ttl, err := cmd.TTL(ctx, key).Result() @@ -120,7 +120,7 @@ func (r *Redis) KeyInfo(rc *req.Ctx) { } func (r *Redis) TtlKey(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) ttl, err := ri.GetCmdable().TTL(context.Background(), key).Result() biz.ErrIsNilAppendErr(err, "ttl失败: %s") @@ -132,7 +132,7 @@ func (r *Redis) TtlKey(rc *req.Ctx) { } func (r *Redis) DeleteKey(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) rc.ReqParam = collx.Kvs("redis", ri.Info, "key", key) ri.GetCmdable().Del(context.Background(), key) } @@ -141,7 +141,7 @@ func (r *Redis) RenameKey(rc *req.Ctx) { form := &form.Rename{} ginx.BindJsonAndValid(rc.GinCtx, form) - ri := r.getRedisIns(rc) + ri := r.getRedisConn(rc) rc.ReqParam = collx.Kvs("redis", ri.Info, "rename", form) ri.GetCmdable().Rename(context.Background(), form.Key, form.NewKey) } @@ -150,21 +150,21 @@ func (r *Redis) ExpireKey(rc *req.Ctx) { form := &form.Expire{} ginx.BindJsonAndValid(rc.GinCtx, form) - ri := r.getRedisIns(rc) + ri := r.getRedisConn(rc) rc.ReqParam = collx.Kvs("redis", ri.Info, "expire", form) ri.GetCmdable().Expire(context.Background(), form.Key, time.Duration(form.Seconds)*time.Second) } // 移除过期时间 func (r *Redis) PersistKey(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) rc.ReqParam = collx.Kvs("redis", ri.Info, "key", key) ri.GetCmdable().Persist(context.Background(), key) } // 清空库 func (r *Redis) FlushDb(rc *req.Ctx) { - ri := r.getRedisIns(rc) + ri := r.getRedisConn(rc) rc.ReqParam = ri.Info ri.GetCmdable().FlushDB(context.Background()) } diff --git a/server/internal/redis/api/list.go b/server/internal/redis/api/list.go index 5e5f2064..626a3726 100644 --- a/server/internal/redis/api/list.go +++ b/server/internal/redis/api/list.go @@ -10,7 +10,7 @@ import ( ) func (r *Redis) GetListValue(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) ctx := context.TODO() cmdable := ri.GetCmdable() @@ -34,7 +34,7 @@ func (r *Redis) Lrem(rc *req.Ctx) { option := new(form.LRemOption) ginx.BindJsonAndValid(g, option) - cmd := r.getRedisIns(rc).GetCmdable() + cmd := r.getRedisConn(rc).GetCmdable() res, err := cmd.LRem(context.TODO(), option.Key, int64(option.Count), option.Member).Result() biz.ErrIsNilAppendErr(err, "lrem失败: %s") rc.ResData = res @@ -45,7 +45,7 @@ func (r *Redis) SaveListValue(rc *req.Ctx) { listValue := new(form.ListValue) ginx.BindJsonAndValid(g, listValue) - cmd := r.getRedisIns(rc).GetCmdable() + cmd := r.getRedisConn(rc).GetCmdable() key := listValue.Key ctx := context.TODO() @@ -59,7 +59,7 @@ func (r *Redis) SetListValue(rc *req.Ctx) { listSetValue := new(form.ListSetValue) ginx.BindJsonAndValid(g, listSetValue) - ri := r.getRedisIns(rc) + ri := r.getRedisConn(rc) _, err := ri.GetCmdable().LSet(context.TODO(), listSetValue.Key, listSetValue.Index, listSetValue.Value).Result() biz.ErrIsNilAppendErr(err, "list set失败: %s") diff --git a/server/internal/redis/api/redis.go b/server/internal/redis/api/redis.go index 598593b1..abf10e2e 100644 --- a/server/internal/redis/api/redis.go +++ b/server/internal/redis/api/redis.go @@ -6,6 +6,7 @@ import ( "mayfly-go/internal/redis/api/vo" "mayfly-go/internal/redis/application" "mayfly-go/internal/redis/domain/entity" + "mayfly-go/internal/redis/rdm" tagapp "mayfly-go/internal/tag/application" "mayfly-go/pkg/biz" "mayfly-go/pkg/ginx" @@ -86,7 +87,7 @@ func (r *Redis) DeleteRedis(rc *req.Ctx) { func (r *Redis) RedisInfo(rc *req.Ctx) { g := rc.GinCtx - ri, err := r.RedisApp.GetRedisInstance(uint64(ginx.PathParamInt(g, "id")), 0) + ri, err := r.RedisApp.GetRedisConn(uint64(ginx.PathParamInt(g, "id")), 0) biz.ErrIsNil(err) section := rc.GinCtx.Query("section") @@ -94,9 +95,9 @@ func (r *Redis) RedisInfo(rc *req.Ctx) { ctx := context.Background() var redisCli *redis.Client - if mode == "" || mode == entity.RedisModeStandalone || mode == entity.RedisModeSentinel { + if mode == "" || mode == rdm.StandaloneMode || mode == rdm.SentinelMode { redisCli = ri.Cli - } else if mode == entity.RedisModeCluster { + } else if mode == rdm.ClusterMode { host := rc.GinCtx.Query("host") biz.NotEmpty(host, "集群模式host信息不能为空") clusterClient := ri.ClusterCli @@ -164,9 +165,9 @@ func (r *Redis) RedisInfo(rc *req.Ctx) { func (r *Redis) ClusterInfo(rc *req.Ctx) { g := rc.GinCtx - ri, err := r.RedisApp.GetRedisInstance(uint64(ginx.PathParamInt(g, "id")), 0) + ri, err := r.RedisApp.GetRedisConn(uint64(ginx.PathParamInt(g, "id")), 0) biz.ErrIsNil(err) - biz.IsEquals(ri.Info.Mode, entity.RedisModeCluster, "非集群模式") + biz.IsEquals(ri.Info.Mode, rdm.ClusterMode, "非集群模式") info, _ := ri.ClusterCli.ClusterInfo(context.Background()).Result() nodesStr, _ := ri.ClusterCli.ClusterNodes(context.Background()).Result() @@ -208,14 +209,14 @@ func (r *Redis) ClusterInfo(rc *req.Ctx) { } // 校验查询参数中的key为必填项,并返回redis实例 -func (r *Redis) checkKeyAndGetRedisIns(rc *req.Ctx) (*application.RedisInstance, string) { +func (r *Redis) checkKeyAndGetRedisConn(rc *req.Ctx) (*rdm.RedisConn, string) { key := rc.GinCtx.Query("key") biz.NotEmpty(key, "key不能为空") - return r.getRedisIns(rc), key + return r.getRedisConn(rc), key } -func (r *Redis) getRedisIns(rc *req.Ctx) *application.RedisInstance { - ri, err := r.RedisApp.GetRedisInstance(getIdAndDbNum(rc.GinCtx)) +func (r *Redis) getRedisConn(rc *req.Ctx) *rdm.RedisConn { + ri, err := r.RedisApp.GetRedisConn(getIdAndDbNum(rc.GinCtx)) biz.ErrIsNil(err) biz.ErrIsNilAppendErr(r.TagApp.CanAccess(rc.LoginAccount.Id, ri.Info.TagPath), "%s") return ri diff --git a/server/internal/redis/api/set.go b/server/internal/redis/api/set.go index 0824a5e4..483e225a 100644 --- a/server/internal/redis/api/set.go +++ b/server/internal/redis/api/set.go @@ -11,7 +11,7 @@ import ( ) func (r *Redis) GetSetValue(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) res, err := ri.GetCmdable().SMembers(context.TODO(), key).Result() biz.ErrIsNilAppendErr(err, "获取set值失败: %s") rc.ResData = res @@ -22,7 +22,7 @@ func (r *Redis) SetSetValue(rc *req.Ctx) { keyvalue := new(form.SetValue) ginx.BindJsonAndValid(g, keyvalue) - cmd := r.getRedisIns(rc).GetCmdable() + cmd := r.getRedisConn(rc).GetCmdable() key := keyvalue.Key // 简单处理->先删除,后新增 @@ -35,7 +35,7 @@ func (r *Redis) SetSetValue(rc *req.Ctx) { } func (r *Redis) Scard(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) total, err := ri.GetCmdable().SCard(context.TODO(), key).Result() biz.ErrIsNilAppendErr(err, "scard失败: %s") @@ -47,7 +47,7 @@ func (r *Redis) Sscan(rc *req.Ctx) { scan := new(form.ScanForm) ginx.BindJsonAndValid(g, scan) - cmd := r.getRedisIns(rc).GetCmdable() + cmd := r.getRedisConn(rc).GetCmdable() keys, cursor, err := cmd.SScan(context.TODO(), scan.Key, scan.Cursor, scan.Match, scan.Count).Result() biz.ErrIsNilAppendErr(err, "sscan失败: %s") rc.ResData = collx.M{ @@ -60,7 +60,7 @@ func (r *Redis) Sadd(rc *req.Ctx) { g := rc.GinCtx option := new(form.SmemberOption) ginx.BindJsonAndValid(g, option) - cmd := r.getRedisIns(rc).GetCmdable() + cmd := r.getRedisConn(rc).GetCmdable() res, err := cmd.SAdd(context.TODO(), option.Key, option.Member).Result() biz.ErrIsNilAppendErr(err, "sadd失败: %s") @@ -72,7 +72,7 @@ func (r *Redis) Srem(rc *req.Ctx) { option := new(form.SmemberOption) ginx.BindJsonAndValid(g, option) - cmd := r.getRedisIns(rc).GetCmdable() + cmd := r.getRedisConn(rc).GetCmdable() res, err := cmd.SRem(context.TODO(), option.Key, option.Member).Result() biz.ErrIsNilAppendErr(err, "srem失败: %s") rc.ResData = res diff --git a/server/internal/redis/api/string.go b/server/internal/redis/api/string.go index 5445a722..23a85970 100644 --- a/server/internal/redis/api/string.go +++ b/server/internal/redis/api/string.go @@ -11,7 +11,7 @@ import ( ) func (r *Redis) GetStringValue(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) str, err := ri.GetCmdable().Get(context.TODO(), key).Result() biz.ErrIsNilAppendErr(err, "获取字符串值失败: %s") rc.ResData = str @@ -22,7 +22,7 @@ func (r *Redis) SetStringValue(rc *req.Ctx) { keyValue := new(form.StringValue) ginx.BindJsonAndValid(g, keyValue) - ri := r.getRedisIns(rc) + ri := r.getRedisConn(rc) cmd := ri.GetCmdable() rc.ReqParam = collx.Kvs("redis", ri.Info, "string", keyValue) diff --git a/server/internal/redis/api/zset.go b/server/internal/redis/api/zset.go index 3183c5e5..b9d95afe 100644 --- a/server/internal/redis/api/zset.go +++ b/server/internal/redis/api/zset.go @@ -12,7 +12,7 @@ import ( ) func (r *Redis) ZCard(rc *req.Ctx) { - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) total, err := ri.GetCmdable().ZCard(context.TODO(), key).Result() biz.ErrIsNilAppendErr(err, "zcard失败: %s") @@ -21,7 +21,7 @@ func (r *Redis) ZCard(rc *req.Ctx) { func (r *Redis) ZScan(rc *req.Ctx) { g := rc.GinCtx - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) cursor := uint64(ginx.QueryInt(g, "cursor", 0)) match := ginx.Query(g, "match", "*") @@ -37,7 +37,7 @@ func (r *Redis) ZScan(rc *req.Ctx) { func (r *Redis) ZRevRange(rc *req.Ctx) { g := rc.GinCtx - ri, key := r.checkKeyAndGetRedisIns(rc) + ri, key := r.checkKeyAndGetRedisConn(rc) start := ginx.QueryInt(g, "start", 0) stop := ginx.QueryInt(g, "stop", 50) @@ -51,7 +51,7 @@ func (r *Redis) ZRem(rc *req.Ctx) { option := new(form.SmemberOption) ginx.BindJsonAndValid(g, option) - cmd := r.getRedisIns(rc).GetCmdable() + cmd := r.getRedisConn(rc).GetCmdable() res, err := cmd.ZRem(context.TODO(), option.Key, option.Member).Result() biz.ErrIsNilAppendErr(err, "zrem失败: %s") rc.ResData = res @@ -62,7 +62,7 @@ func (r *Redis) ZAdd(rc *req.Ctx) { option := new(form.ZAddOption) ginx.BindJsonAndValid(g, option) - cmd := r.getRedisIns(rc).GetCmdable() + cmd := r.getRedisConn(rc).GetCmdable() zm := redis.Z{ Score: option.Score, Member: option.Member, diff --git a/server/internal/redis/application/redis.go b/server/internal/redis/application/redis.go index 3cefaff9..77db9b9d 100644 --- a/server/internal/redis/application/redis.go +++ b/server/internal/redis/application/redis.go @@ -1,26 +1,14 @@ package application import ( - "context" - "fmt" - "mayfly-go/internal/common/consts" - machineapp "mayfly-go/internal/machine/application" - "mayfly-go/internal/machine/infrastructure/machine" "mayfly-go/internal/redis/domain/entity" "mayfly-go/internal/redis/domain/repository" + "mayfly-go/internal/redis/rdm" "mayfly-go/pkg/base" - "mayfly-go/pkg/cache" "mayfly-go/pkg/errorx" - "mayfly-go/pkg/logx" "mayfly-go/pkg/model" - "mayfly-go/pkg/utils/netx" - "mayfly-go/pkg/utils/structx" - "net" "strconv" "strings" - "time" - - "github.com/redis/go-redis/v9" ) type Redis interface { @@ -31,7 +19,7 @@ type Redis interface { Count(condition *entity.RedisQuery) int64 - Save(entity *entity.Redis) error + Save(re *entity.Redis) error // 删除数据库信息 Delete(id uint64) error @@ -39,7 +27,10 @@ type Redis interface { // 获取数据库连接实例 // id: 数据库实例id // db: 库号 - GetRedisInstance(id uint64, db int) (*RedisInstance, error) + GetRedisConn(id uint64, db int) (*rdm.RedisConn, error) + + // 测试连接 + TestConn(re *entity.Redis) error } func newRedisApp(redisRepo repository.Redis) Redis { @@ -52,7 +43,7 @@ type redisAppImpl struct { base.AppImpl[*entity.Redis, repository.Redis] } -// 分页获取机器脚本信息列表 +// 分页获取redis列表 func (r *redisAppImpl) GetPageList(condition *entity.RedisQuery, pageParam *model.PageParam, toEntity any, orderBy ...string) (*model.PageResult[any], error) { return r.GetRepo().GetRedisList(condition, pageParam, toEntity, orderBy...) } @@ -64,7 +55,7 @@ func (r *redisAppImpl) Count(condition *entity.RedisQuery) int64 { func (r *redisAppImpl) Save(re *entity.Redis) error { // ’修改信息且密码不为空‘ or ‘新增’需要测试是否可连接 if (re.Id != 0 && re.Password != "") || re.Id == 0 { - if err := TestRedisConnection(re); err != nil { + if err := r.TestConn(re); err != nil { return errorx.NewBiz("Redis连接失败: %s", err.Error()) } } @@ -92,7 +83,7 @@ func (r *redisAppImpl) Save(re *entity.Redis) error { if oldRedis.Db != re.Db || oldRedis.SshTunnelMachineId != re.SshTunnelMachineId { for _, dbStr := range strings.Split(oldRedis.Db, ",") { db, _ := strconv.Atoi(dbStr) - CloseRedis(re.Id, db) + rdm.CloseConn(re.Id, db) } } re.PwdEncrypt() @@ -108,253 +99,35 @@ func (r *redisAppImpl) Delete(id uint64) error { // 如果存在连接,则关闭所有库连接信息 for _, dbStr := range strings.Split(re.Db, ",") { db, _ := strconv.Atoi(dbStr) - CloseRedis(re.Id, db) + rdm.CloseConn(re.Id, db) } return r.DeleteById(id) } // 获取数据库连接实例 -func (r *redisAppImpl) GetRedisInstance(id uint64, db int) (*RedisInstance, error) { - // Id不为0,则为需要缓存 - needCache := id != 0 - if needCache { - load, ok := redisCache.Get(getRedisCacheKey(id, db)) - if ok { - return load.(*RedisInstance), nil - } - } - // 缓存不存在,则回调获取redis信息 - re, err := r.GetById(new(entity.Redis), id) - if err != nil { - return nil, errorx.NewBiz("redis信息不存在") - } - re.PwdDecrypt() - - redisMode := re.Mode - var ri *RedisInstance - if redisMode == "" || redisMode == entity.RedisModeStandalone { - ri = getRedisCient(re, db) - // 测试连接 - _, e := ri.Cli.Ping(context.Background()).Result() - if e != nil { - ri.Close() - return nil, errorx.NewBiz("redis连接失败: %s", e.Error()) - } - } else if redisMode == entity.RedisModeCluster { - ri = getRedisClusterClient(re) - // 测试连接 - _, e := ri.ClusterCli.Ping(context.Background()).Result() - if e != nil { - ri.Close() - return nil, errorx.NewBiz("redis集群连接失败: %s", e.Error()) - } - } else if redisMode == entity.RedisModeSentinel { - ri = getRedisSentinelCient(re, db) - // 测试连接 - _, e := ri.Cli.Ping(context.Background()).Result() - if e != nil { - ri.Close() - return nil, errorx.NewBiz("redis sentinel连接失败: %s", e.Error()) - } - } - - logx.Infof("连接redis: %s/%d", re.Host, db) - if needCache { - redisCache.Put(getRedisCacheKey(id, db), ri) - } - return ri, nil -} - -// 生成redis连接缓存key -func getRedisCacheKey(id uint64, db int) string { - return fmt.Sprintf("%d/%d", id, db) -} - -func toRedisInfo(re *entity.Redis, db int) *RedisInfo { - redisInfo := new(RedisInfo) - structx.Copy(redisInfo, re) - redisInfo.Db = db - return redisInfo -} - -func getRedisCient(re *entity.Redis, db int) *RedisInstance { - ri := &RedisInstance{Id: getRedisCacheKey(re.Id, db), Info: toRedisInfo(re, db)} - - redisOptions := &redis.Options{ - Addr: re.Host, - Username: re.Username, - Password: re.Password, // no password set - DB: db, // use default DB - DialTimeout: 8 * time.Second, - ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines. - WriteTimeout: -1, - } - if re.SshTunnelMachineId > 0 { - redisOptions.Dialer = getRedisDialer(re.SshTunnelMachineId) - } - ri.Cli = redis.NewClient(redisOptions) - return ri -} - -func getRedisClusterClient(re *entity.Redis) *RedisInstance { - ri := &RedisInstance{Id: getRedisCacheKey(re.Id, 0), Info: toRedisInfo(re, 0)} - - redisClusterOptions := &redis.ClusterOptions{ - Addrs: strings.Split(re.Host, ","), - Username: re.Username, - Password: re.Password, - DialTimeout: 8 * time.Second, - } - if re.SshTunnelMachineId > 0 { - redisClusterOptions.Dialer = getRedisDialer(re.SshTunnelMachineId) - } - ri.ClusterCli = redis.NewClusterClient(redisClusterOptions) - return ri -} - -func getRedisSentinelCient(re *entity.Redis, db int) *RedisInstance { - ri := &RedisInstance{Id: getRedisCacheKey(re.Id, db), Info: toRedisInfo(re, db)} - // sentinel模式host为 masterName=host:port,host:port - masterNameAndHosts := strings.Split(re.Host, "=") - sentinelOptions := &redis.FailoverOptions{ - MasterName: masterNameAndHosts[0], - SentinelAddrs: strings.Split(masterNameAndHosts[1], ","), - Username: re.Username, - Password: re.Password, // no password set - SentinelPassword: re.Password, // 哨兵节点密码需与redis节点密码一致 - DB: db, // use default DB - DialTimeout: 8 * time.Second, - ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines. - WriteTimeout: -1, - } - if re.SshTunnelMachineId > 0 { - sentinelOptions.Dialer = getRedisDialer(re.SshTunnelMachineId) - } - ri.Cli = redis.NewFailoverClient(sentinelOptions) - return ri -} - -func getRedisDialer(machineId int) func(ctx context.Context, network, addr string) (net.Conn, error) { - return func(_ context.Context, network, addr string) (net.Conn, error) { - sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(machineId) +func (r *redisAppImpl) GetRedisConn(id uint64, db int) (*rdm.RedisConn, error) { + return rdm.GetRedisConn(id, db, func() (*rdm.RedisInfo, error) { + // 缓存不存在,则回调获取redis信息 + re, err := r.GetById(new(entity.Redis), id) if err != nil { - return nil, err + return nil, errorx.NewBiz("redis信息不存在") } + re.PwdDecrypt() - if sshConn, err := sshTunnel.GetDialConn(network, addr); err == nil { - // 将ssh conn包装,否则redis内部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported - return &netx.WrapSshConn{Conn: sshConn}, nil - } else { - return nil, err - } - } -} - -//------------------------------------------------------------------------------ - -// redis客户端连接缓存,指定时间内没有访问则会被关闭 -var redisCache = cache.NewTimedCache(consts.RedisConnExpireTime, 5*time.Second). - WithUpdateAccessTime(true). - OnEvicted(func(key any, value any) { - logx.Info(fmt.Sprintf("删除redis连接缓存 id = %s", key)) - value.(*RedisInstance).Close() - }) - -// 移除redis连接缓存并关闭redis连接 -func CloseRedis(id uint64, db int) { - redisCache.Delete(getRedisCacheKey(id, db)) -} - -func init() { - machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool { - // 遍历所有redis连接实例,若存在redis实例使用该ssh隧道机器,则返回true,表示还在使用中... - items := redisCache.Items() - for _, v := range items { - if v.Value.(*RedisInstance).Info.SshTunnelMachineId == machineId { - return true - } - } - return false + return re.ToRedisInfo(db), nil }) } -func TestRedisConnection(re *entity.Redis) error { - var cmd redis.Cmdable - // 取第一个库测试连接即可 - dbStr := strings.Split(re.Db, ",")[0] - db, _ := strconv.Atoi(dbStr) - if re.Mode == "" || re.Mode == entity.RedisModeStandalone { - rcli := getRedisCient(re, db) - defer rcli.Close() - cmd = rcli.Cli - } else if re.Mode == entity.RedisModeCluster { - ccli := getRedisClusterClient(re) - defer ccli.Close() - cmd = ccli.ClusterCli - } else if re.Mode == entity.RedisModeSentinel { - rcli := getRedisSentinelCient(re, db) - defer rcli.Close() - cmd = rcli.Cli +func (r *redisAppImpl) TestConn(re *entity.Redis) error { + db := 0 + if re.Db != "" { + db, _ = strconv.Atoi(strings.Split(re.Db, ",")[0]) } - // 测试连接 - _, e := cmd.Ping(context.Background()).Result() - return e -} - -type RedisInfo struct { - Id uint64 `json:"id"` - Host string `json:"host"` - Db int `json:"db"` // 库号 - TagPath string `json:"tagPath"` - Mode string `json:"-"` - Name string `json:"-"` - - SshTunnelMachineId int `json:"-"` -} - -// 获取记录日志的描述 -func (r *RedisInfo) GetLogDesc() string { - return fmt.Sprintf("Redis[id=%d, tag=%s, host=%s, db=%d]", r.Id, r.TagPath, r.Host, r.Db) -} - -// redis实例 -type RedisInstance struct { - Id string - Info *RedisInfo - - Cli *redis.Client - ClusterCli *redis.ClusterClient -} - -// 获取命令执行接口的具体实现 -func (r *RedisInstance) GetCmdable() redis.Cmdable { - redisMode := r.Info.Mode - if redisMode == "" || redisMode == entity.RedisModeStandalone || r.Info.Mode == entity.RedisModeSentinel { - return r.Cli - } - if redisMode == entity.RedisModeCluster { - return r.ClusterCli + rc, err := re.ToRedisInfo(db).Conn() + if err != nil { + return err } + rc.Close() return nil } - -func (r *RedisInstance) Scan(cursor uint64, match string, count int64) ([]string, uint64, error) { - return r.GetCmdable().Scan(context.Background(), cursor, match, count).Result() -} - -func (r *RedisInstance) Close() { - mode := r.Info.Mode - if mode == entity.RedisModeStandalone || mode == entity.RedisModeSentinel { - if err := r.Cli.Close(); err != nil { - logx.Errorf("关闭redis单机实例[%s]连接失败: %s", r.Id, err.Error()) - } - r.Cli = nil - } - if mode == entity.RedisModeCluster { - if err := r.ClusterCli.Close(); err != nil { - logx.Errorf("关闭redis集群实例[%s]连接失败: %s", r.Id, err.Error()) - } - r.ClusterCli = nil - } -} diff --git a/server/internal/redis/domain/entity/redis.go b/server/internal/redis/domain/entity/redis.go index 4a569527..a88cadc8 100644 --- a/server/internal/redis/domain/entity/redis.go +++ b/server/internal/redis/domain/entity/redis.go @@ -2,7 +2,9 @@ package entity import ( "mayfly-go/internal/common/utils" + "mayfly-go/internal/redis/rdm" "mayfly-go/pkg/model" + "mayfly-go/pkg/utils/structx" ) type Redis struct { @@ -20,12 +22,6 @@ type Redis struct { TagPath string } -const ( - RedisModeStandalone = "standalone" - RedisModeCluster = "cluster" - RedisModeSentinel = "sentinel" -) - func (r *Redis) PwdEncrypt() { // 密码替换为加密后的密码 r.Password = utils.PwdAesEncrypt(r.Password) @@ -35,3 +31,11 @@ func (r *Redis) PwdDecrypt() { // 密码替换为解密后的密码 r.Password = utils.PwdAesDecrypt(r.Password) } + +// 转换为redisInfo进行连接 +func (re *Redis) ToRedisInfo(db int) *rdm.RedisInfo { + redisInfo := new(rdm.RedisInfo) + structx.Copy(redisInfo, re) + redisInfo.Db = db + return redisInfo +} diff --git a/server/internal/redis/rdm/conn.go b/server/internal/redis/rdm/conn.go new file mode 100644 index 00000000..cf6a99ea --- /dev/null +++ b/server/internal/redis/rdm/conn.go @@ -0,0 +1,51 @@ +package rdm + +import ( + "context" + "mayfly-go/pkg/logx" + + "github.com/redis/go-redis/v9" +) + +// redis连接信息 +type RedisConn struct { + Id string + Info *RedisInfo + + Cli *redis.Client + ClusterCli *redis.ClusterClient +} + +// 获取命令执行接口的具体实现 +func (r *RedisConn) GetCmdable() redis.Cmdable { + redisMode := r.Info.Mode + if redisMode == "" || redisMode == StandaloneMode || r.Info.Mode == SentinelMode { + return r.Cli + } + if redisMode == ClusterMode { + return r.ClusterCli + } + return nil +} + +func (r *RedisConn) Scan(cursor uint64, match string, count int64) ([]string, uint64, error) { + return r.GetCmdable().Scan(context.Background(), cursor, match, count).Result() +} + +func (r *RedisConn) Close() { + mode := r.Info.Mode + if mode == StandaloneMode || mode == SentinelMode { + if err := r.Cli.Close(); err != nil { + logx.Errorf("关闭redis单机实例[%s]连接失败: %s", r.Id, err.Error()) + } + r.Cli = nil + return + } + + if mode == ClusterMode { + if err := r.ClusterCli.Close(); err != nil { + logx.Errorf("关闭redis集群实例[%s]连接失败: %s", r.Id, err.Error()) + } + r.ClusterCli = nil + } +} diff --git a/server/internal/redis/rdm/conn_cache.go b/server/internal/redis/rdm/conn_cache.go new file mode 100644 index 00000000..33163b26 --- /dev/null +++ b/server/internal/redis/rdm/conn_cache.go @@ -0,0 +1,73 @@ +package rdm + +import ( + "fmt" + "mayfly-go/internal/common/consts" + "mayfly-go/internal/machine/infrastructure/machine" + "mayfly-go/pkg/cache" + "mayfly-go/pkg/logx" + "sync" + "time" +) + +// redis客户端连接缓存,指定时间内没有访问则会被关闭 +var connCache = cache.NewTimedCache(consts.RedisConnExpireTime, 5*time.Second). + WithUpdateAccessTime(true). + OnEvicted(func(key any, value any) { + logx.Info(fmt.Sprintf("删除redis连接缓存 id = %s", key)) + value.(*RedisConn).Close() + }) + +func init() { + machine.AddCheckSshTunnelMachineUseFunc(func(machineId int) bool { + // 遍历所有redis连接实例,若存在redis实例使用该ssh隧道机器,则返回true,表示还在使用中... + items := connCache.Items() + for _, v := range items { + if v.Value.(*RedisConn).Info.SshTunnelMachineId == machineId { + return true + } + } + return false + }) +} + +var mutex sync.Mutex + +// 从缓存中获取redis连接信息, 若缓存中不存在则会使用回调函数获取redisInfo进行连接并缓存 +func GetRedisConn(redisId uint64, db int, getRedisInfo func() (*RedisInfo, error)) (*RedisConn, error) { + connId := getConnId(redisId, db) + + // connId不为空,则为需要缓存 + needCache := connId != "" + if needCache { + load, ok := connCache.Get(connId) + if ok { + return load.(*RedisConn), nil + } + } + + mutex.Lock() + defer mutex.Unlock() + + // 若缓存中不存在,则从回调函数中获取RedisInfo + ri, err := getRedisInfo() + if err != nil { + return nil, err + } + + // 连接数据库 + rc, err := ri.Conn() + if err != nil { + return nil, err + } + + if needCache { + connCache.Put(connId, rc) + } + return rc, nil +} + +// 移除redis连接缓存并关闭redis连接 +func CloseConn(id uint64, db int) { + connCache.Delete(getConnId(id, db)) +} diff --git a/server/internal/redis/rdm/info.go b/server/internal/redis/rdm/info.go new file mode 100644 index 00000000..113f70f0 --- /dev/null +++ b/server/internal/redis/rdm/info.go @@ -0,0 +1,161 @@ +package rdm + +import ( + "context" + "fmt" + machineapp "mayfly-go/internal/machine/application" + "mayfly-go/pkg/errorx" + "mayfly-go/pkg/logx" + "mayfly-go/pkg/utils/netx" + "net" + "strings" + "time" + + "github.com/redis/go-redis/v9" +) + +type RedisMode string + +const ( + StandaloneMode RedisMode = "standalone" + ClusterMode RedisMode = "cluster" + SentinelMode RedisMode = "sentinel" +) + +type RedisInfo struct { + Id uint64 `json:"id"` + + Host string `json:"host"` + Db int `json:"db"` // 库号 + Mode RedisMode `json:"-"` + Username string `json:"-"` + Password string `json:"-"` + + Name string `json:"-"` + TagPath string `json:"tagPath"` + SshTunnelMachineId int `json:"-"` +} + +func (r *RedisInfo) Conn() (*RedisConn, error) { + redisMode := r.Mode + if redisMode == StandaloneMode { + return r.connStandalone() + } + if redisMode == ClusterMode { + return r.connCluster() + } + if redisMode == SentinelMode { + return r.connSentinel() + } + + return nil, errorx.NewBiz("redis mode error") +} + +func (re *RedisInfo) connStandalone() (*RedisConn, error) { + redisOptions := &redis.Options{ + Addr: re.Host, + Username: re.Username, + Password: re.Password, // no password set + DB: re.Db, // use default DB + DialTimeout: 8 * time.Second, + ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines. + WriteTimeout: -1, + } + if re.SshTunnelMachineId > 0 { + redisOptions.Dialer = getRedisDialer(re.SshTunnelMachineId) + } + + cli := redis.NewClient(redisOptions) + _, e := cli.Ping(context.Background()).Result() + if e != nil { + cli.Close() + return nil, errorx.NewBiz("redis连接失败: %s", e.Error()) + } + + logx.Infof("连接redis standalone: %s/%d", re.Host, re.Db) + + rc := &RedisConn{Id: getConnId(re.Id, re.Db), Info: re} + rc.Cli = cli + return rc, nil +} + +func (re *RedisInfo) connCluster() (*RedisConn, error) { + redisClusterOptions := &redis.ClusterOptions{ + Addrs: strings.Split(re.Host, ","), + Username: re.Username, + Password: re.Password, + DialTimeout: 8 * time.Second, + } + if re.SshTunnelMachineId > 0 { + redisClusterOptions.Dialer = getRedisDialer(re.SshTunnelMachineId) + } + cli := redis.NewClusterClient(redisClusterOptions) + // 测试连接 + _, e := cli.Ping(context.Background()).Result() + if e != nil { + cli.Close() + return nil, errorx.NewBiz("redis集群连接失败: %s", e.Error()) + } + + logx.Infof("连接redis cluster: %s/%d", re.Host, re.Db) + + rc := &RedisConn{Id: getConnId(re.Id, re.Db), Info: re} + rc.ClusterCli = cli + return rc, nil +} + +func (re *RedisInfo) connSentinel() (*RedisConn, error) { + // sentinel模式host为 masterName=host:port,host:port + masterNameAndHosts := strings.Split(re.Host, "=") + sentinelOptions := &redis.FailoverOptions{ + MasterName: masterNameAndHosts[0], + SentinelAddrs: strings.Split(masterNameAndHosts[1], ","), + Username: re.Username, + Password: re.Password, // no password set + SentinelPassword: re.Password, // 哨兵节点密码需与redis节点密码一致 + DB: re.Db, // use default DB + DialTimeout: 8 * time.Second, + ReadTimeout: -1, // Disable timeouts, because SSH does not support deadlines. + WriteTimeout: -1, + } + if re.SshTunnelMachineId > 0 { + sentinelOptions.Dialer = getRedisDialer(re.SshTunnelMachineId) + } + cli := redis.NewFailoverClient(sentinelOptions) + + _, e := cli.Ping(context.Background()).Result() + if e != nil { + cli.Close() + return nil, errorx.NewBiz("redis sentinel连接失败: %s", e.Error()) + } + + logx.Infof("连接redis sentinel: %s/%d", re.Host, re.Db) + + rc := &RedisConn{Id: getConnId(re.Id, re.Db), Info: re} + rc.Cli = cli + return rc, nil +} + +func getRedisDialer(machineId int) func(ctx context.Context, network, addr string) (net.Conn, error) { + return func(_ context.Context, network, addr string) (net.Conn, error) { + sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(machineId) + if err != nil { + return nil, err + } + + if sshConn, err := sshTunnel.GetDialConn(network, addr); err == nil { + // 将ssh conn包装,否则redis内部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported + return &netx.WrapSshConn{Conn: sshConn}, nil + } else { + return nil, err + } + } +} + +// 生成redis连接id +func getConnId(id uint64, db int) string { + if id == 0 { + return "" + } + return fmt.Sprintf("%d/%d", id, db) +} diff --git a/server/migrations/2022.go b/server/migrations/2022.go index 94698acf..2af279e5 100644 --- a/server/migrations/2022.go +++ b/server/migrations/2022.go @@ -43,7 +43,7 @@ func T2022() *gormigrate.Migration { return err } - if err := tx.AutoMigrate(&entity2.Instance{}); err != nil { + if err := tx.AutoMigrate(&entity2.DbInstance{}); err != nil { return err } if err := tx.AutoMigrate(&entity2.Db{}); err != nil { diff --git a/server/pkg/base/app.go b/server/pkg/base/app.go index 2c0a0557..74205ed0 100644 --- a/server/pkg/base/app.go +++ b/server/pkg/base/app.go @@ -10,9 +10,15 @@ type App[T any] interface { // 新增一个实体 Insert(e T) error + // 使用指定gorm db执行,主要用于事务执行 + InsertWithDb(db *gorm.DB, e T) error + // 批量新增实体 BatchInsert(models []T) error + // 使用指定gorm db执行,主要用于事务执行 + BatchInsertWithDb(db *gorm.DB, es []T) error + // 根据实体id更新实体信息 UpdateById(e T) error @@ -70,11 +76,21 @@ func (ai *AppImpl[T, R]) Insert(e T) error { return ai.GetRepo().Insert(e) } +// 使用指定gorm db执行,主要用于事务执行 +func (ai *AppImpl[T, R]) InsertWithDb(db *gorm.DB, e T) error { + return ai.GetRepo().InsertWithDb(db, e) +} + // 批量新增实体 (单纯新增,不做其他业务逻辑处理) func (ai *AppImpl[T, R]) BatchInsert(es []T) error { return ai.GetRepo().BatchInsert(es) } +// 使用指定gorm db执行,主要用于事务执行 +func (ai *AppImpl[T, R]) BatchInsertWithDb(db *gorm.DB, es []T) error { + return ai.GetRepo().BatchInsertWithDb(db, es) +} + // 根据实体id更新实体信息 (单纯更新,不做其他业务逻辑处理) func (ai *AppImpl[T, R]) UpdateById(e T) error { return ai.GetRepo().UpdateById(e) diff --git a/server/pkg/ginx/ginx.go b/server/pkg/ginx/ginx.go index 26d75fbd..04233458 100644 --- a/server/pkg/ginx/ginx.go +++ b/server/pkg/ginx/ginx.go @@ -109,6 +109,7 @@ func ErrorRes(g *gin.Context, err any) { g.JSON(http.StatusOK, model.ServerError()) default: logx.Errorf("未知错误: %v", t) + g.JSON(http.StatusOK, model.ServerError()) } }