diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index 8d6bac7b..4af65bf4 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -63,18 +63,6 @@ func (d *Db) Save(rc *req.Ctx) { d.DbApp.Save(db) } -// 获取数据库实例的所有数据库名 -func (d *Db) GetDatabaseNames(rc *req.Ctx) { - form := &form.DbForm{} - ginx.BindJsonAndValid(rc.GinCtx, form) - - instance := d.InstanceApp.GetById(form.InstanceId, "Password") - biz.NotNil(instance, "获取数据库实例错误") - instance.PwdDecrypt() - rc.ResData = d.InstanceApp.GetDatabases(instance) - rc.ResData = d.InstanceApp.GetDatabases(instance) -} - func (d *Db) DeleteDb(rc *req.Ctx) { idsStr := ginx.PathParam(rc.GinCtx, "dbId") rc.ReqParam = idsStr @@ -94,9 +82,7 @@ func (d *Db) getDbConnection(g *gin.Context) *application.DbConnection { dbName := g.Query("db") biz.NotEmpty(dbName, "db不能为空") dbId := getDbId(g) - db := d.DbApp.GetById(dbId) - instance := d.InstanceApp.GetById(db.InstanceId) - return d.DbApp.GetDbConnection(db, instance, dbName) + return d.DbApp.GetDbConnection(dbId, dbName) } func (d *Db) TableInfos(rc *req.Ctx) { @@ -121,12 +107,10 @@ func (d *Db) ExecSql(rc *req.Ctx) { ginx.BindJsonAndValid(g, form) dbId := getDbId(g) - db := d.DbApp.GetById(dbId) - instance := d.InstanceApp.GetById(db.InstanceId) - dbInstance := d.DbApp.GetDbConnection(db, instance, form.Db) - biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbInstance.Info.TagPath), "%s") + dbConn := d.DbApp.GetDbConnection(dbId, form.Db) + biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") - rc.ReqParam = fmt.Sprintf("%s\n-> %s", dbInstance.Info.GetLogDesc(), form.Sql) + rc.ReqParam = fmt.Sprintf("%s\n-> %s", dbConn.Info.GetLogDesc(), form.Sql) biz.NotEmpty(form.Sql, "sql不能为空") // 去除前后空格及换行符 @@ -136,7 +120,7 @@ func (d *Db) ExecSql(rc *req.Ctx) { DbId: dbId, Db: form.Db, Remark: form.Remark, - DbInstance: dbInstance, + DbConn: dbConn, LoginAccount: rc.LoginAccount, } @@ -180,9 +164,9 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { dbId := getDbId(g) dbName := getDbName(g) - dbInstance := d.getDbConnection(rc.GinCtx) - biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbInstance.Info.TagPath), "%s") - rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbInstance.Info.GetLogDesc(), filename) + dbConn := d.getDbConnection(rc.GinCtx) + biz.ErrIsNilAppendErr(d.TagApp.CanAccess(rc.LoginAccount.Id, dbConn.Info.TagPath), "%s") + rc.ReqParam = fmt.Sprintf("%s -> filename: %s", dbConn.Info.GetLogDesc(), filename) logExecRecord := true // 如果执行sql文件大于该值则不记录sql执行记录 @@ -195,7 +179,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { if err := recover(); err != nil { switch t := err.(type) { case *biz.BizError: - d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbInstance.Info.GetLogDesc(), t.Error()))) + d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s]%s执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), t.Error()))) } } }() @@ -204,7 +188,7 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { DbId: dbId, Db: dbName, Remark: fileheader.Filename, - DbInstance: dbInstance, + DbConn: dbConn, LoginAccount: rc.LoginAccount, } @@ -220,15 +204,15 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { if logExecRecord { _, err = d.DbSqlExecApp.Exec(execReq) } else { - _, err = dbInstance.Exec(sql) + _, err = dbConn.Exec(sql) } if err != nil { - d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbInstance.Info.GetLogDesc(), sql, err.Error()))) + d.MsgApp.CreateAndSend(rc.LoginAccount, ws.ErrMsg("sql脚本执行失败", fmt.Sprintf("[%s][%s] -> sql=[%s] 执行失败: [%s]", filename, dbConn.Info.GetLogDesc(), sql, err.Error()))) return } } - d.MsgApp.CreateAndSend(rc.LoginAccount, ws.SuccessMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbInstance.Info.GetLogDesc()))) + d.MsgApp.CreateAndSend(rc.LoginAccount, ws.SuccessMsg("sql脚本执行成功", fmt.Sprintf("[%s]执行完成 -> %s", filename, dbConn.Info.GetLogDesc()))) }() } diff --git a/server/internal/db/api/instance.go b/server/internal/db/api/instance.go index 3772986b..ad48ec17 100644 --- a/server/internal/db/api/instance.go +++ b/server/internal/db/api/instance.go @@ -86,12 +86,6 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) { } } -func getInstanceId(g *gin.Context) uint64 { - instanceId, _ := strconv.Atoi(g.Param("instanceId")) - biz.IsTrue(instanceId > 0, "instanceId 错误") - return uint64(instanceId) -} - // 获取数据库实例的所有数据库名 func (d *Instance) GetDatabaseNames(rc *req.Ctx) { instanceId := getInstanceId(rc.GinCtx) @@ -100,3 +94,9 @@ func (d *Instance) GetDatabaseNames(rc *req.Ctx) { instance.PwdDecrypt() rc.ResData = d.InstanceApp.GetDatabases(instance) } + +func getInstanceId(g *gin.Context) uint64 { + instanceId, _ := strconv.Atoi(g.Param("instanceId")) + biz.IsTrue(instanceId > 0, "instanceId 错误") + return uint64(instanceId) +} diff --git a/server/internal/db/application/application.go b/server/internal/db/application/application.go index 3dd32c4e..910a0e09 100644 --- a/server/internal/db/application/application.go +++ b/server/internal/db/application/application.go @@ -6,7 +6,7 @@ import ( var ( instanceApp Instance = newInstanceApp(persistence.GetInstanceRepo()) - dbApp Db = newDbApp(persistence.GetDbRepo(), persistence.GetDbSqlRepo()) + dbApp Db = newDbApp(persistence.GetDbRepo(), persistence.GetDbSqlRepo(), instanceApp) dbSqlExecApp DbSqlExec = newDbSqlExecApp(persistence.GetDbSqlExecRepo()) ) diff --git a/server/internal/db/application/db.go b/server/internal/db/application/db.go index 77fcbd5e..4dfcd067 100644 --- a/server/internal/db/application/db.go +++ b/server/internal/db/application/db.go @@ -39,19 +39,21 @@ type Db interface { // 获取数据库连接实例 // @param id 数据库实例id // @param db 数据库 - GetDbConnection(db *entity.Db, instance *entity.Instance, dbName string) *DbConnection + GetDbConnection(dbId uint64, dbName string) *DbConnection } -func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql) Db { +func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql, dbInstanceApp Instance) Db { return &dbAppImpl{ - dbRepo: dbRepo, - dbSqlRepo: dbSqlRepo, + dbRepo: dbRepo, + dbSqlRepo: dbSqlRepo, + dbInstanceApp: dbInstanceApp, } } type dbAppImpl struct { - dbRepo repository.Db - dbSqlRepo repository.DbSql + dbRepo repository.Db + dbSqlRepo repository.DbSql + dbInstanceApp Instance } // 分页获取数据库信息列表 @@ -129,11 +131,11 @@ func (d *dbAppImpl) Delete(id uint64) { var mutex sync.Mutex -func (d *dbAppImpl) GetDbConnection(db *entity.Db, instance *entity.Instance, dbName string) *DbConnection { - cacheKey := GetDbCacheKey(db.Id, dbName) +func (d *dbAppImpl) GetDbConnection(dbId uint64, dbName string) *DbConnection { + cacheKey := GetDbCacheKey(dbId, dbName) // Id不为0,则为需要缓存 - needCache := db.Id != 0 + needCache := dbId != 0 if needCache { load, ok := dbCache.Get(cacheKey) if ok { @@ -143,9 +145,11 @@ func (d *dbAppImpl) GetDbConnection(db *entity.Db, instance *entity.Instance, db mutex.Lock() defer mutex.Unlock() + db := d.GetById(dbId) biz.NotNil(db, "数据库信息不存在") biz.IsTrue(strings.Contains(" "+db.Database+" ", " "+dbName+" "), "未配置该库的操作权限") + instance := d.dbInstanceApp.GetById(db.InstanceId) // 密码解密 instance.PwdDecrypt() @@ -208,7 +212,7 @@ func (d *DbInfo) GetLogDesc() string { return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.TagPath, d.Name, d.Host, d.Port, d.Database) } -// db实例 +// db实例连接信息 type DbConnection struct { Id string Info *DbInfo @@ -286,13 +290,6 @@ func GetDbCacheKey(dbId uint64, db string) string { return fmt.Sprintf("%d:%s", dbId, db) } -func GetDbInstanceByCache(id string) *DbConnection { - if load, ok := dbCache.Get(id); ok { - return load.(*DbConnection) - } - return nil -} - func SelectDataByDb(db *sql.DB, selectSql string) ([]string, []map[string]any, error) { rows, err := db.Query(selectSql) if err != nil { diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index 591411ef..3c5ab4d6 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -20,7 +20,7 @@ type DbSqlExecReq struct { Sql string Remark string LoginAccount *model.LoginAccount - DbInstance *DbConnection + DbConn *DbConnection } type DbSqlExecRes struct { @@ -98,7 +98,7 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) if isSelect || strings.HasPrefix(lowerSql, "show") { execRes, execErr = doRead(execSqlReq) } else { - execRes, execErr = doExec(execSqlReq.Sql, execSqlReq.DbInstance) + execRes, execErr = doExec(execSqlReq.Sql, execSqlReq.DbConn) } if execErr != nil { return nil, execErr @@ -124,7 +124,7 @@ func (d *dbSqlExecAppImpl) Exec(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) case *sqlparser.Insert: execRes, err = doInsert(stmt, execSqlReq, dbSqlExecRecord) default: - execRes, err = doExec(execSqlReq.Sql, execSqlReq.DbInstance) + execRes, err = doExec(execSqlReq.Sql, execSqlReq.DbConn) } if err != nil { return nil, err @@ -174,9 +174,9 @@ func doSelect(selectStmt *sqlparser.Select, execSqlReq *DbSqlExecReq) (*DbSqlExe } func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) { - dbInstance := execSqlReq.DbInstance + dbConn := execSqlReq.DbConn sql := execSqlReq.Sql - colNames, res, err := dbInstance.SelectData(sql) + colNames, res, err := dbConn.SelectData(sql) if err != nil { return nil, err } @@ -187,7 +187,7 @@ func doRead(execSqlReq *DbSqlExecReq) (*DbSqlExecRes, error) { } func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) { - dbInstance := execSqlReq.DbInstance + dbConn := execSqlReq.DbConn tableStr := sqlparser.String(update.TableExprs) // 可能使用别名,故空格切割 @@ -202,12 +202,12 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent } // 获取表主键列名,排除使用别名 - primaryKey := dbInstance.GetMeta().GetPrimaryKey(tableName) + primaryKey := dbConn.GetMeta().GetPrimaryKey(tableName) updateColumnsAndPrimaryKey := strings.Join(updateColumns, ",") + "," + primaryKey // 查询要更新字段数据的旧值,以及主键值 selectSql := fmt.Sprintf("SELECT %s FROM %s %s LIMIT 200", updateColumnsAndPrimaryKey, tableStr, where) - _, res, err := dbInstance.SelectData(selectSql) + _, res, err := dbConn.SelectData(selectSql) if err == nil { bytes, _ := json.Marshal(res) dbSqlExec.OldValue = string(bytes) @@ -218,11 +218,11 @@ func doUpdate(update *sqlparser.Update, execSqlReq *DbSqlExecReq, dbSqlExec *ent dbSqlExec.Table = tableName dbSqlExec.Type = entity.DbSqlExecTypeUpdate - return doExec(execSqlReq.Sql, dbInstance) + return doExec(execSqlReq.Sql, dbConn) } func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) { - dbInstance := execSqlReq.DbInstance + dbConn := execSqlReq.DbConn tableStr := sqlparser.String(delete.TableExprs) // 可能使用别名,故空格切割 @@ -232,14 +232,14 @@ func doDelete(delete *sqlparser.Delete, execSqlReq *DbSqlExecReq, dbSqlExec *ent // 查询删除数据 selectSql := fmt.Sprintf("SELECT * FROM %s %s LIMIT 200", tableStr, where) - _, res, _ := dbInstance.SelectData(selectSql) + _, res, _ := dbConn.SelectData(selectSql) bytes, _ := json.Marshal(res) dbSqlExec.OldValue = string(bytes) dbSqlExec.Table = table dbSqlExec.Type = entity.DbSqlExecTypeDelete - return doExec(execSqlReq.Sql, dbInstance) + return doExec(execSqlReq.Sql, dbConn) } func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *entity.DbSqlExec) (*DbSqlExecRes, error) { @@ -249,11 +249,11 @@ func doInsert(insert *sqlparser.Insert, execSqlReq *DbSqlExecReq, dbSqlExec *ent dbSqlExec.Table = table dbSqlExec.Type = entity.DbSqlExecTypeInsert - return doExec(execSqlReq.Sql, execSqlReq.DbInstance) + return doExec(execSqlReq.Sql, execSqlReq.DbConn) } -func doExec(sql string, dbInstance *DbConnection) (*DbSqlExecRes, error) { - rowsAffected, err := dbInstance.Exec(sql) +func doExec(sql string, dbConn *DbConnection) (*DbSqlExecRes, error) { + rowsAffected, err := dbConn.Exec(sql) execRes := "success" if err != nil { execRes = err.Error() diff --git a/server/internal/db/router/db.go b/server/internal/db/router/db.go index 25800869..1b5136bb 100644 --- a/server/internal/db/router/db.go +++ b/server/internal/db/router/db.go @@ -29,9 +29,6 @@ func InitDbRouter(router *gin.RouterGroup) { req.NewPost("", d.Save).Log(req.NewLogSave("db-保存数据库信息")), - // 获取数据库实例的所有数据库名 - req.NewPost("/databases", d.GetDatabaseNames), - req.NewDelete(":dbId", d.DeleteDb).Log(req.NewLogSave("db-删除数据库信息")), req.NewGet(":dbId/t-infos", d.TableInfos),