refactor: 获取数据库连接调整

This commit is contained in:
meilin.huang
2023-09-05 14:41:12 +08:00
parent d0ac7de4cb
commit 618d782af3
6 changed files with 49 additions and 71 deletions

View File

@@ -63,18 +63,6 @@ func (d *Db) Save(rc *req.Ctx) {
d.DbApp.Save(db) d.DbApp.Save(db)
} }
// 获取数据库实例的所有数据库名
func (d *Db) GetDatabaseNames(rc *req.Ctx) {
form := &form.DbForm{}
ginx.BindJsonAndValid(rc.GinCtx, form)
instance := d.InstanceApp.GetById(form.InstanceId, "Password")
biz.NotNil(instance, "获取数据库实例错误")
instance.PwdDecrypt()
rc.ResData = d.InstanceApp.GetDatabases(instance)
rc.ResData = d.InstanceApp.GetDatabases(instance)
}
func (d *Db) DeleteDb(rc *req.Ctx) { func (d *Db) DeleteDb(rc *req.Ctx) {
idsStr := ginx.PathParam(rc.GinCtx, "dbId") idsStr := ginx.PathParam(rc.GinCtx, "dbId")
rc.ReqParam = idsStr rc.ReqParam = idsStr
@@ -94,9 +82,7 @@ func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection {
dbName := g.Query("db") dbName := g.Query("db")
biz.NotEmpty(dbName, "db不能为空") biz.NotEmpty(dbName, "db不能为空")
dbId := getDbId(g) dbId := getDbId(g)
db := d.DbApp.GetById(dbId) return d.DbApp.GetDbConnection(dbId, dbName)
instance := d.InstanceApp.GetById(db.InstanceId)
return d.DbApp.GetDbConnection(db, instance, dbName)
} }
func (d *Db) TableInfos(rc *req.Ctx) { func (d *Db) TableInfos(rc *req.Ctx) {
@@ -121,12 +107,10 @@ func (d *Db) ExecSql(rc *req.Ctx) {
ginx.BindJsonAndValid(g, form) ginx.BindJsonAndValid(g, form)
dbId := getDbId(g) dbId := getDbId(g)
db := d.DbApp.GetById(dbId) dbConn := d.DbApp.GetDbConnection(dbId, form.Db)
instance := d.InstanceApp.GetById(db.InstanceId) biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
dbInstance := d.DbApp.GetDbConnection(db, instance, form.Db)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbInstance.Info.TagPath), "%s")
rc.ReqParam = fmt.Sprintf("%s\n-> %s", dbInstance.Info.GetLogDesc(), form.Sql) rc.ReqParam = fmt.Sprintf("%s\n-> %s", dbConn.Info.GetLogDesc(), form.Sql)
biz.NotEmpty(form.Sql, "sql不能为空") biz.NotEmpty(form.Sql, "sql不能为空")
// 去除前后空格及换行符 // 去除前后空格及换行符
@@ -136,7 +120,7 @@ func (d *Db) ExecSql(rc *req.Ctx) {
DbId: dbId, DbId: dbId,
Db: form.Db, Db: form.Db,
Remark: form.Remark, Remark: form.Remark,
DbInstance: dbInstance, DbConn: dbConn,
LoginAccount: rc.LoginAccount, LoginAccount: rc.LoginAccount,
} }
@@ -180,9 +164,9 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
dbId := getDbId(g) dbId := getDbId(g)
dbName := getDbName(g) dbName := getDbName(g)
dbInstance := d.getDbConnection(rc.GinCtx) dbConn := d.getDbConnection(rc.GinCtx)
biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbInstance.Info.TagPath), "%s") biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s")
rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbInstance.Info.GetLogDesc(), filename) rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename)
logExecRecord := true logExecRecord := true
// 如果执行sql文件大于该值则不记录sql执行记录 // 如果执行sql文件大于该值则不记录sql执行记录
@@ -195,7 +179,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
if err := recover(); err != nil { if err := recover(); err != nil {
switch t := err.(type) { switch t := err.(type) {
case *biz.BizError: case *biz.BizError:
d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbInstance.Info.GetLogDesc(), t.Error()))) d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), t.Error())))
} }
} }
}() }()
@@ -204,7 +188,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
DbId: dbId, DbId: dbId,
Db: dbName, Db: dbName,
Remark: fileheader.Filename, Remark: fileheader.Filename,
DbInstance: dbInstance, DbConn: dbConn,
LoginAccount: rc.LoginAccount, LoginAccount: rc.LoginAccount,
} }
@@ -220,15 +204,15 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
if logExecRecord { if logExecRecord {
_, err = d.DbSqlExecApp.Exec(execReq) _, err = d.DbSqlExecApp.Exec(execReq)
} else { } else {
_, err = dbInstance.Exec(sql) _, err = dbConn.Exec(sql)
} }
if err != nil { if err != nil {
d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbInstance.Info.GetLogDesc(), sql, err.Error()))) d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error())))
return return
} }
} }
d.MsgApp.CreateAndSend(rc.LoginAccount, ws.SuccessMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbInstance.Info.GetLogDesc()))) d.MsgApp.CreateAndSend(rc.LoginAccount, ws.SuccessMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc())))
}() }()
} }

View File

@@ -86,12 +86,6 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) {
} }
} }
func getInstanceId(g *gin.Context) uint64 {
instanceId, _ := strconv.Atoi(g.Param("instanceId"))
biz.IsTrue(instanceId > 0, "instanceId 错误")
return uint64(instanceId)
}
// 获取数据库实例的所有数据库名 // 获取数据库实例的所有数据库名
func (d *Instance) GetDatabaseNames(rc *req.Ctx) { func (d *Instance) GetDatabaseNames(rc *req.Ctx) {
instanceId := getInstanceId(rc.GinCtx) instanceId := getInstanceId(rc.GinCtx)
@@ -100,3 +94,9 @@ func (d *Instance) GetDatabaseNames(rc *req.Ctx) {
instance.PwdDecrypt() instance.PwdDecrypt()
rc.ResData = d.InstanceApp.GetDatabases(instance) rc.ResData = d.InstanceApp.GetDatabases(instance)
} }
func getInstanceId(g *gin.Context) uint64 {
instanceId, _ := strconv.Atoi(g.Param("instanceId"))
biz.IsTrue(instanceId > 0, "instanceId 错误")
return uint64(instanceId)
}

View File

@@ -6,7 +6,7 @@ import (
var ( var (
instanceApp Instance = newInstanceApp(persistence.GetInstanceRepo()) instanceApp Instance = newInstanceApp(persistence.GetInstanceRepo())
dbApp Db = newDbApp(persistence.GetDbRepo(), persistence.GetDbSqlRepo()) dbApp Db = newDbApp(persistence.GetDbRepo(), persistence.GetDbSqlRepo(), instanceApp)
dbSqlExecApp DbSqlExec = newDbSqlExecApp(persistence.GetDbSqlExecRepo()) dbSqlExecApp DbSqlExec = newDbSqlExecApp(persistence.GetDbSqlExecRepo())
) )

View File

@@ -39,19 +39,21 @@ type Db interface {
// 获取数据库连接实例 // 获取数据库连接实例
// @param id 数据库实例id // @param id 数据库实例id
// @param db 数据库 // @param db 数据库
GetDbConnection(db *entity.Db, instance *entity.Instance, dbName string) *DbConnection GetDbConnection(dbId uint64, dbName string) *DbConnection
} }
func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql) Db { func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql, dbInstanceApp Instance) Db {
return &dbAppImpl{ return &dbAppImpl{
dbRepo: dbRepo, dbRepo: dbRepo,
dbSqlRepo: dbSqlRepo, dbSqlRepo: dbSqlRepo,
dbInstanceApp: dbInstanceApp,
} }
} }
type dbAppImpl struct { type dbAppImpl struct {
dbRepo repository.Db dbRepo repository.Db
dbSqlRepo repository.DbSql dbSqlRepo repository.DbSql
dbInstanceApp Instance
} }
// 分页获取数据库信息列表 // 分页获取数据库信息列表
@@ -129,11 +131,11 @@ func (d *dbAppImpl) Delete(id uint64) {
var mutex sync.Mutex var mutex sync.Mutex
func (d *dbAppImpl) GetDbConnection(db *entity.Db, instance *entity.Instance, dbName string) *DbConnection { func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection {
cacheKey := GetDbCacheKey(db.Id, dbName) cacheKey := GetDbCacheKey(dbId, dbName)
// Id不为0则为需要缓存 // Id不为0则为需要缓存
needCache := db.Id != 0 needCache := dbId != 0
if needCache { if needCache {
load, ok := dbCache.Get(cacheKey) load, ok := dbCache.Get(cacheKey)
if ok { if ok {
@@ -143,9 +145,11 @@ func (d *dbAppImpl) GetDbConnection(db *entity.Db, instance *entity.Instance, db
mutex.Lock() mutex.Lock()
defer mutex.Unlock() defer mutex.Unlock()
db := d.GetById(dbId)
biz.NotNil(db, "数据库信息不存在") biz.NotNil(db, "数据库信息不存在")
biz.IsTrue(strings.Contains(" "+db.Database+" ", " "+dbName+" "), "未配置该库的操作权限") biz.IsTrue(strings.Contains(" "+db.Database+" ", " "+dbName+" "), "未配置该库的操作权限")
instance := d.dbInstanceApp.GetById(db.InstanceId)
// 密码解密 // 密码解密
instance.PwdDecrypt() instance.PwdDecrypt()
@@ -208,7 +212,7 @@ func (d *DbInfo) GetLogDesc() string {
return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.TagPath, d.Name, d.Host, d.Port, d.Database) return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.TagPath, d.Name, d.Host, d.Port, d.Database)
} }
// db实例 // db实例连接信息
type DbConnection struct { type DbConnection struct {
Id string Id string
Info *DbInfo Info *DbInfo
@@ -286,13 +290,6 @@ func GetDbCacheKey(dbId uint64, db string) string {
return fmt.Sprintf("%d:%s", dbId, db) return fmt.Sprintf("%d:%s", dbId, db)
} }
func GetDbInstanceByCache(id string) *DbConnection {
if load, ok := dbCache.Get(id); ok {
return load.(*DbConnection)
}
return nil
}
func SelectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]any, error) { func SelectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]any, error) {
rows, err := db.Query(selectSql) rows, err := db.Query(selectSql)
if err != nil { if err != nil {

View File

@@ -20,7 +20,7 @@ type DbSqlExecReq struct {
Sql string Sql string
Remark string Remark string
LoginAccount *model.LoginAccount LoginAccount *model.LoginAccount
DbInstance *DbConnection DbConn *DbConnection
} }
type DbSqlExecRes struct { type DbSqlExecRes struct {
@@ -98,7 +98,7 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
if isSelect || strings.HasPrefix(lowerSql, "show") { if isSelect || strings.HasPrefix(lowerSql, "show") {
execRes, execErr = doRead(execSqlReq) execRes, execErr = doRead(execSqlReq)
} else { } else {
execRes, execErr = doExec(execSqlReq.Sql, execSqlReq.DbInstance) execRes, execErr = doExec(execSqlReq.Sql, execSqlReq.DbConn)
} }
if execErr != nil { if execErr != nil {
return nil, execErr return nil, execErr
@@ -124,7 +124,7 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error)
case *sqlparser.Insert: case *sqlparser.Insert:
execRes, err = doInsert(stmt, execSqlReq, dbSqlExecRecord) execRes, err = doInsert(stmt, execSqlReq, dbSqlExecRecord)
default: default:
execRes, err = doExec(execSqlReq.Sql, execSqlReq.DbInstance) execRes, err = doExec(execSqlReq.Sql, execSqlReq.DbConn)
} }
if err != nil { if err != nil {
return nil, err return nil, err
@@ -174,9 +174,9 @@ func doSelect(selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExe
} }
func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) { func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
dbInstance := execSqlReq.DbInstance dbConn := execSqlReq.DbConn
sql := execSqlReq.Sql sql := execSqlReq.Sql
colNames, res, err := dbInstance.SelectData(sql) colNames, res, err := dbConn.SelectData(sql)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -187,7 +187,7 @@ func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) {
} }
func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) { func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
dbInstance := execSqlReq.DbInstance dbConn := execSqlReq.DbConn
tableStr := sqlparser.String(update.TableExprs) tableStr := sqlparser.String(update.TableExprs)
// 可能使用别名,故空格切割 // 可能使用别名,故空格切割
@@ -202,12 +202,12 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent
} }
// 获取表主键列名,排除使用别名 // 获取表主键列名,排除使用别名
primaryKey := dbInstance.GetMeta().GetPrimaryKey(tableName) primaryKey := dbConn.GetMeta().GetPrimaryKey(tableName)
updateColumnsAndPrimaryKey := strings.Join(updateColumns, ",") + "," + primaryKey updateColumnsAndPrimaryKey := strings.Join(updateColumns, ",") + "," + primaryKey
// 查询要更新字段数据的旧值,以及主键值 // 查询要更新字段数据的旧值,以及主键值
selectSql := fmt.Sprintf("SELECT %s FROM %s %s LIMIT 200", updateColumnsAndPrimaryKey, tableStr, where) selectSql := fmt.Sprintf("SELECT %s FROM %s %s LIMIT 200", updateColumnsAndPrimaryKey, tableStr, where)
_, res, err := dbInstance.SelectData(selectSql) _, res, err := dbConn.SelectData(selectSql)
if err == nil { if err == nil {
bytes, _ := json.Marshal(res) bytes, _ := json.Marshal(res)
dbSqlExec.OldValue = string(bytes) dbSqlExec.OldValue = string(bytes)
@@ -218,11 +218,11 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent
dbSqlExec.Table = tableName dbSqlExec.Table = tableName
dbSqlExec.Type = entity.DbSqlExecTypeUpdate dbSqlExec.Type = entity.DbSqlExecTypeUpdate
return doExec(execSqlReq.Sql, dbInstance) return doExec(execSqlReq.Sql, dbConn)
} }
func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) { func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
dbInstance := execSqlReq.DbInstance dbConn := execSqlReq.DbConn
tableStr := sqlparser.String(delete.TableExprs) tableStr := sqlparser.String(delete.TableExprs)
// 可能使用别名,故空格切割 // 可能使用别名,故空格切割
@@ -232,14 +232,14 @@ func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *ent
// 查询删除数据 // 查询删除数据
selectSql := fmt.Sprintf("SELECT * FROM %s %s LIMIT 200", tableStr, where) selectSql := fmt.Sprintf("SELECT * FROM %s %s LIMIT 200", tableStr, where)
_, res, _ := dbInstance.SelectData(selectSql) _, res, _ := dbConn.SelectData(selectSql)
bytes, _ := json.Marshal(res) bytes, _ := json.Marshal(res)
dbSqlExec.OldValue = string(bytes) dbSqlExec.OldValue = string(bytes)
dbSqlExec.Table = table dbSqlExec.Table = table
dbSqlExec.Type = entity.DbSqlExecTypeDelete dbSqlExec.Type = entity.DbSqlExecTypeDelete
return doExec(execSqlReq.Sql, dbInstance) return doExec(execSqlReq.Sql, dbConn)
} }
func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) { func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) {
@@ -249,11 +249,11 @@ func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *ent
dbSqlExec.Table = table dbSqlExec.Table = table
dbSqlExec.Type = entity.DbSqlExecTypeInsert dbSqlExec.Type = entity.DbSqlExecTypeInsert
return doExec(execSqlReq.Sql, execSqlReq.DbInstance) return doExec(execSqlReq.Sql, execSqlReq.DbConn)
} }
func doExec(sql string, dbInstance *DbConnection) (*DbSqlExecRes, error) { func doExec(sql string, dbConn *DbConnection) (*DbSqlExecRes, error) {
rowsAffected, err := dbInstance.Exec(sql) rowsAffected, err := dbConn.Exec(sql)
execRes := "success" execRes := "success"
if err != nil { if err != nil {
execRes = err.Error() execRes = err.Error()

View File

@@ -29,9 +29,6 @@ func InitDbRouter(router *gin.RouterGroup) {
req.NewPost("", d.Save).Log(req.NewLogSave("db-保存数据库信息")), req.NewPost("", d.Save).Log(req.NewLogSave("db-保存数据库信息")),
// 获取数据库实例的所有数据库名
req.NewPost("/databases", d.GetDatabaseNames),
req.NewDelete(":dbId", d.DeleteDb).Log(req.NewLogSave("db-删除数据库信息")), req.NewDelete(":dbId", d.DeleteDb).Log(req.NewLogSave("db-删除数据库信息")),
req.NewGet(":dbId/t-infos", d.TableInfos), req.NewGet(":dbId/t-infos", d.TableInfos),