diff --git a/README.md b/README.md index 501a8491..9901a466 100644 --- a/README.md +++ b/README.md @@ -15,11 +15,14 @@ github star github fork + + github star + docker pulls - golang + golang vue @@ -106,7 +109,7 @@ http://go.mayfly.run ## 💌 支持作者 -如果觉得项目不错,或者已经在使用了,希望你可以去 Github 或者 Gitee 帮我点个 ⭐ Star,这将是对我极大的鼓励与支持。 +如果觉得项目不错,或者已经在使用了,希望你可以去 GithubGiteeGitcode 帮我点个 ⭐ Star,这将是对我极大的鼓励与支持。 > 喝杯咖啡 ☕️ 或者来杯奶茶 🧋,让作者更有精神,写出更棒的代码! diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index ec1db401..f826f145 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -8,7 +8,6 @@ import ( "mayfly-go/internal/db/application" "mayfly-go/internal/db/application/dto" "mayfly-go/internal/db/config" - "mayfly-go/internal/db/dbm" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/imsg" @@ -140,10 +139,12 @@ func (d *Db) DeleteDb(rc *req.Ctx) { func (d *Db) ExecSql(rc *req.Ctx) { form := req.BindJsonAndValid(rc, new(form.DbSqlExecForm)) + ctx, cancel := context.WithTimeout(rc.MetaCtx, time.Duration(config.GetDbms().SqlExecTl)*time.Second) + defer cancel() + dbId := getDbId(rc) - dbConn, err := d.dbApp.GetDbConn(dbId, form.Db) + dbConn, err := d.dbApp.GetDbConn(ctx, dbId, form.Db) biz.ErrIsNil(err) - defer dbm.PutDbConn(dbConn) biz.ErrIsNilAppendErr(d.tagApp.CanAccess(rc.GetLoginAccount().Id, dbConn.Info.CodePath...), "%s") @@ -163,9 +164,6 @@ func (d *Db) ExecSql(rc *req.Ctx) { CheckFlow: true, } - ctx, cancel := context.WithTimeout(rc.MetaCtx, time.Duration(config.GetDbms().SqlExecTl)*time.Second) - defer cancel() - execRes, err := d.dbSqlExecApp.Exec(ctx, execReq) biz.ErrIsNil(err) rc.ResData = execRes @@ -194,9 +192,8 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) { dbName := getDbName(rc) clientId := rc.Query("clientId") - dbConn, err := d.dbApp.GetDbConn(dbId, dbName) + dbConn, err := d.dbApp.GetDbConn(rc.MetaCtx, dbId, dbName) biz.ErrIsNil(err) - defer dbm.PutDbConn(dbConn) biz.ErrIsNilAppendErr(d.tagApp.CanAccess(rc.GetLoginAccount().Id, dbConn.Info.CodePath...), "%s") rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc()) @@ -228,9 +225,8 @@ func (d *Db) DumpSql(rc *req.Ctx) { needData := dumpType == "2" || dumpType == "3" la := rc.GetLoginAccount() - dbConn, err := d.dbApp.GetDbConn(dbId, dbName) + dbConn, err := d.dbApp.GetDbConn(rc.MetaCtx, dbId, dbName) biz.ErrIsNil(err) - defer dbm.PutDbConn(dbConn) biz.ErrIsNilAppendErr(d.tagApp.CanAccess(la.Id, dbConn.Info.CodePath...), "%s") @@ -358,9 +354,8 @@ func (d *Db) CopyTable(rc *req.Ctx) { form := &form.DbCopyTableForm{} copy := req.BindJsonAndCopyTo[*dbi.DbCopyTable](rc, form, new(dbi.DbCopyTable)) - conn, err := d.dbApp.GetDbConn(form.Id, form.Db) + conn, err := d.dbApp.GetDbConn(rc.MetaCtx, form.Id, form.Db) biz.ErrIsNilAppendErr(err, "copy table error: %s") - defer dbm.PutDbConn(conn) err = conn.GetDialect().CopyTable(copy) if err != nil { @@ -382,8 +377,7 @@ func getDbName(rc *req.Ctx) string { } func (d *Db) getDbConn(rc *req.Ctx) *dbi.DbConn { - dc, err := d.dbApp.GetDbConn(getDbId(rc), getDbName(rc)) + dc, err := d.dbApp.GetDbConn(rc.MetaCtx, getDbId(rc), getDbName(rc)) biz.ErrIsNil(err) - defer dbm.PutDbConn(dc) return dc } diff --git a/server/internal/db/api/db_instance.go b/server/internal/db/api/db_instance.go index 1105c9ca..3eb90289 100644 --- a/server/internal/db/api/db_instance.go +++ b/server/internal/db/api/db_instance.go @@ -93,7 +93,7 @@ func (d *Instance) TestConn(rc *req.Ctx) { form := &form.InstanceForm{} instance := req.BindJsonAndCopyTo[*entity.DbInstance](rc, form, new(entity.DbInstance)) - biz.ErrIsNil(d.instanceApp.TestConn(instance, form.AuthCerts[0])) + biz.ErrIsNil(d.instanceApp.TestConn(rc.MetaCtx, instance, form.AuthCerts[0])) } // SaveInstance 保存数据库实例信息 @@ -137,13 +137,13 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) { func (d *Instance) GetDatabaseNames(rc *req.Ctx) { form := &form.InstanceDbNamesForm{} instance := req.BindJsonAndCopyTo[*entity.DbInstance](rc, form, new(entity.DbInstance)) - res, err := d.instanceApp.GetDatabases(instance, form.AuthCert) + res, err := d.instanceApp.GetDatabases(rc.MetaCtx, instance, form.AuthCert) biz.ErrIsNil(err) rc.ResData = res } func (d *Instance) GetDatabaseNamesByAc(rc *req.Ctx) { - res, err := d.instanceApp.GetDatabasesByAc(rc.PathParam("ac")) + res, err := d.instanceApp.GetDatabasesByAc(rc.MetaCtx, rc.PathParam("ac")) biz.ErrIsNil(err) rc.ResData = res } @@ -151,7 +151,7 @@ func (d *Instance) GetDatabaseNamesByAc(rc *req.Ctx) { // 获取数据库实例server信息 func (d *Instance) GetDbServer(rc *req.Ctx) { instanceId := getInstanceId(rc) - conn, err := d.dbApp.GetDbConnByInstanceId(instanceId) + conn, err := d.dbApp.GetDbConnByInstanceId(rc.MetaCtx, instanceId) biz.ErrIsNil(err) res, err := conn.GetMetadata().GetDbServer() biz.ErrIsNil(err) diff --git a/server/internal/db/api/db_transfer.go b/server/internal/db/api/db_transfer.go index 517c86a8..d4281916 100644 --- a/server/internal/db/api/db_transfer.go +++ b/server/internal/db/api/db_transfer.go @@ -6,7 +6,6 @@ import ( "mayfly-go/internal/db/api/vo" "mayfly-go/internal/db/application" "mayfly-go/internal/db/application/dto" - "mayfly-go/internal/db/dbm" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/imsg" fileapp "mayfly-go/internal/file/application" @@ -150,9 +149,8 @@ func (d *DbTransferTask) FileRun(rc *req.Ctx) { tFile, err := d.dbTransferFile.GetById(fm.Id) biz.IsTrue(tFile != nil && err == nil, "file not found") - targetDbConn, err := d.dbApp.GetDbConn(fm.TargetDbId, fm.TargetDbName) + targetDbConn, err := d.dbApp.GetDbConn(rc.MetaCtx, fm.TargetDbId, fm.TargetDbName) biz.ErrIsNilAppendErr(err, "failed to connect to the target database: %s") - defer dbm.PutDbConn(targetDbConn) biz.ErrIsNilAppendErr(d.tagApp.CanAccess(rc.GetLoginAccount().Id, targetDbConn.Info.CodePath...), "%s") diff --git a/server/internal/db/application/db.go b/server/internal/db/application/db.go index b4181f5c..8df76dd4 100644 --- a/server/internal/db/application/db.go +++ b/server/internal/db/application/db.go @@ -40,10 +40,10 @@ type Db interface { // @param id 数据库id // // @param dbName 数据库名 - GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) + GetDbConn(ctx context.Context, dbId uint64, dbName string) (*dbi.DbConn, error) // 根据数据库实例id获取连接,随机返回该instanceId下已连接的conn,若不存在则是使用该instanceId关联的db进行连接并返回。 - GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error) + GetDbConnByInstanceId(ctx context.Context, instanceId uint64) (*dbi.DbConn, error) // DumpDb dumpDb DumpDb(ctx context.Context, reqParam *dto.DumpDb) error @@ -170,8 +170,8 @@ func (d *dbAppImpl) Delete(ctx context.Context, id uint64) error { }) } -func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) { - return dbm.GetDbConn(dbId, dbName, func() (*dbi.DbInfo, error) { +func (d *dbAppImpl) GetDbConn(ctx context.Context, dbId uint64, dbName string) (*dbi.DbConn, error) { + return dbm.GetDbConn(ctx, dbId, dbName, func() (*dbi.DbInfo, error) { db, err := d.GetById(dbId) if err != nil { return nil, errorx.NewBiz("db not found") @@ -198,8 +198,8 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) { }) } -func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error) { - conn := dbm.GetDbConnByInstanceId(instanceId) +func (d *dbAppImpl) GetDbConnByInstanceId(ctx context.Context, instanceId uint64) (*dbi.DbConn, error) { + conn := dbm.GetDbConnByInstanceId(ctx, instanceId) if conn != nil { return conn, nil } @@ -214,7 +214,7 @@ func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error // 使用该实例关联的已配置数据库中的第一个库进行连接并返回 firstDb := dbs[0] - return d.GetDbConn(firstDb.Id, strings.Split(firstDb.Database, " ")[0]) + return d.GetDbConn(ctx, firstDb.Id, strings.Split(firstDb.Database, " ")[0]) } func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error { @@ -233,7 +233,7 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error { dbName := reqParam.DbName tables := reqParam.Tables - dbConn, err := d.GetDbConn(dbId, dbName) + dbConn, err := d.GetDbConn(ctx, dbId, dbName) if err != nil { return err } diff --git a/server/internal/db/application/db_data_sync.go b/server/internal/db/application/db_data_sync.go index 0907e7df..26d6ca2b 100644 --- a/server/internal/db/application/db_data_sync.go +++ b/server/internal/db/application/db_data_sync.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "mayfly-go/internal/db/dbm" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" @@ -110,7 +109,9 @@ func (app *dataSyncAppImpl) AddCronJob(ctx context.Context, taskEntity *entity.D taskId := taskEntity.Id if err := scheduler.AddFunByKey(key, taskEntity.TaskCron, func() { logx.Infof("start the data synchronization task: %d", taskId) - if err := app.RunCronJob(ctx, taskId); err != nil { + cancelCtx, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + if err := app.RunCronJob(cancelCtx, taskId); err != nil { logx.Errorf("the data synchronization task failed to execute at a scheduled time: %s", err.Error()) } }); err != nil { @@ -150,8 +151,7 @@ func (app *dataSyncAppImpl) RunCronJob(ctx context.Context, id uint64) error { logx.ErrorfContext(ctx, "data source connection unavailable: %s", err.Error()) return } - srcConn, err := app.dbApp.GetDbConn(uint64(task.SrcDbId), task.SrcDbName) - defer dbm.PutDbConn(srcConn) + srcConn, err := app.dbApp.GetDbConn(ctx, uint64(task.SrcDbId), task.SrcDbName) if err != nil { logx.ErrorfContext(ctx, "failed to connect to the source database: %s", err.Error()) return @@ -205,16 +205,14 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en } // 获取源数据库连接 - srcConn, err := app.dbApp.GetDbConn(uint64(task.SrcDbId), task.SrcDbName) - defer dbm.PutDbConn(srcConn) + srcConn, err := app.dbApp.GetDbConn(ctx, uint64(task.SrcDbId), task.SrcDbName) if err != nil { return syncLog, errorx.NewBiz("failed to connect to the source database: %s", err.Error()) } // 获取目标数据库连接 - targetConn, err := app.dbApp.GetDbConn(uint64(task.TargetDbId), task.TargetDbName) - defer dbm.PutDbConn(targetConn) + targetConn, err := app.dbApp.GetDbConn(ctx, uint64(task.TargetDbId), task.TargetDbName) if err != nil { return syncLog, errorx.NewBiz("failed to connect to the target database: %s", err.Error()) } diff --git a/server/internal/db/application/db_instance.go b/server/internal/db/application/db_instance.go index 811ab49a..7d832130 100644 --- a/server/internal/db/application/db_instance.go +++ b/server/internal/db/application/db_instance.go @@ -27,7 +27,7 @@ type Instance interface { // GetPageList 分页获取数据库实例 GetPageList(condition *entity.InstanceQuery, orderBy ...string) (*model.PageResult[*entity.DbInstance], error) - TestConn(instanceEntity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) error + TestConn(ctx context.Context, instanceEntity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) error SaveDbInstance(ctx context.Context, instance *dto.SaveDbInstance) (uint64, error) @@ -35,10 +35,10 @@ type Instance interface { Delete(ctx context.Context, id uint64) error // GetDatabases 获取数据库实例的所有数据库列表 - GetDatabases(entity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) ([]string, error) + GetDatabases(ctx context.Context, entity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) ([]string, error) // GetDatabasesByAc 根据授权凭证名获取所有数据库名称列表 - GetDatabasesByAc(acName string) ([]string, error) + GetDatabasesByAc(ctx context.Context, acName string) ([]string, error) // ToDbInfo 根据实例与授权凭证返回对应的DbInfo ToDbInfo(instance *entity.DbInstance, authCertName string, database string) (*dbi.DbInfo, error) @@ -59,7 +59,7 @@ func (app *instanceAppImpl) GetPageList(condition *entity.InstanceQuery, orderBy return app.GetRepo().GetInstanceList(condition, orderBy...) } -func (app *instanceAppImpl) TestConn(instanceEntity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) error { +func (app *instanceAppImpl) TestConn(ctx context.Context, instanceEntity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) error { instanceEntity.Network = instanceEntity.GetNetwork() authCert, err := app.resourceAuthCertApp.GetRealAuthCert(authCert) @@ -67,7 +67,7 @@ func (app *instanceAppImpl) TestConn(instanceEntity *entity.DbInstance, authCert return err } - dbConn, err := dbm.Conn(app.toDbInfoByAc(instanceEntity, authCert, "")) + dbConn, err := dbm.Conn(ctx, app.toDbInfoByAc(instanceEntity, authCert, "")) if err != nil { return err } @@ -185,7 +185,7 @@ func (app *instanceAppImpl) Delete(ctx context.Context, instanceId uint64) error }) } -func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance, authCert *tagentity.ResourceAuthCert) ([]string, error) { +func (app *instanceAppImpl) GetDatabases(ctx context.Context, ed *entity.DbInstance, authCert *tagentity.ResourceAuthCert) ([]string, error) { if authCert.Id != 0 { // 密文可能被清除,故需要重新获取 authCert, _ = app.resourceAuthCertApp.GetAuthCert(authCert.Name) @@ -199,10 +199,10 @@ func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance, authCert *tagent } } - return app.getDatabases(ed, authCert) + return app.getDatabases(ctx, ed, authCert) } -func (app *instanceAppImpl) GetDatabasesByAc(acName string) ([]string, error) { +func (app *instanceAppImpl) GetDatabasesByAc(ctx context.Context, acName string) ([]string, error) { ac, err := app.resourceAuthCertApp.GetAuthCert(acName) if err != nil { return nil, errorx.NewBiz("db ac not found") @@ -214,7 +214,7 @@ func (app *instanceAppImpl) GetDatabasesByAc(acName string) ([]string, error) { return nil, errorx.NewBiz("the db instance information for this ac does not exist") } - return app.getDatabases(instance, ac) + return app.getDatabases(ctx, instance, ac) } func (app *instanceAppImpl) ToDbInfo(instance *entity.DbInstance, authCertName string, database string) (*dbi.DbInfo, error) { @@ -226,11 +226,11 @@ func (app *instanceAppImpl) ToDbInfo(instance *entity.DbInstance, authCertName s return app.toDbInfoByAc(instance, ac, database), nil } -func (app *instanceAppImpl) getDatabases(instance *entity.DbInstance, ac *tagentity.ResourceAuthCert) ([]string, error) { +func (app *instanceAppImpl) getDatabases(ctx context.Context, instance *entity.DbInstance, ac *tagentity.ResourceAuthCert) ([]string, error) { instance.Network = instance.GetNetwork() dbi := app.toDbInfoByAc(instance, ac, "") - dbConn, err := dbm.Conn(dbi) + dbConn, err := dbm.Conn(ctx, dbi) if err != nil { return nil, err } diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index 38da80bc..89add493 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -5,7 +5,6 @@ import ( "fmt" "mayfly-go/internal/db/application/dto" "mayfly-go/internal/db/config" - "mayfly-go/internal/db/dbm" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/dbm/sqlparser" "mayfly-go/internal/db/dbm/sqlparser/sqlstmt" @@ -283,8 +282,7 @@ func (d *dbSqlExecAppImpl) FlowBizHandle(ctx context.Context, bizHandleParam *fl return nil, errorx.NewBiz("failed to parse the business form information: %s", err.Error()) } - dbConn, err := d.dbApp.GetDbConn(execSqlBizForm.DbId, execSqlBizForm.DbName) - defer dbm.PutDbConn(dbConn) + dbConn, err := d.dbApp.GetDbConn(ctx, execSqlBizForm.DbId, execSqlBizForm.DbName) if err != nil { return nil, err } diff --git a/server/internal/db/application/db_transfer.go b/server/internal/db/application/db_transfer.go index 99e32cc2..e6c1d54f 100644 --- a/server/internal/db/application/db_transfer.go +++ b/server/internal/db/application/db_transfer.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "mayfly-go/internal/db/application/dto" - "mayfly-go/internal/db/dbm" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/dbm/sqlparser" "mayfly-go/internal/db/domain/entity" @@ -209,8 +208,7 @@ func (app *dbTransferAppImpl) Run(ctx context.Context, taskId uint64, logId uint // 获取源库连接、目标库连接,判断连接可用性,否则记录日志:xx连接不可用 // 获取源库表信息 - srcConn, err := app.dbApp.GetDbConn(uint64(task.SrcDbId), task.SrcDbName) - defer dbm.PutDbConn(srcConn) + srcConn, err := app.dbApp.GetDbConn(ctx, uint64(task.SrcDbId), task.SrcDbName) if err != nil { app.EndTransfer(ctx, logId, taskId, "failed to obtain source db connection", err, nil) return @@ -248,8 +246,7 @@ func (app *dbTransferAppImpl) Run(ctx context.Context, taskId uint64, logId uint func (app *dbTransferAppImpl) transfer2Db(ctx context.Context, taskId uint64, logId uint64, task *entity.DbTransferTask, srcConn *dbi.DbConn, start time.Time, tables []dbi.Table) { // 获取目标库表信息 - targetConn, err := app.dbApp.GetDbConn(uint64(task.TargetDbId), task.TargetDbName) - defer dbm.PutDbConn(targetConn) + targetConn, err := app.dbApp.GetDbConn(ctx, uint64(task.TargetDbId), task.TargetDbName) if err != nil { app.EndTransfer(ctx, logId, taskId, "failed to get target db connection", err, nil) return diff --git a/server/internal/db/dbm/dbi/conn.go b/server/internal/db/dbm/dbi/conn.go index d5caab7a..04cbe675 100644 --- a/server/internal/db/dbm/dbi/conn.go +++ b/server/internal/db/dbm/dbi/conn.go @@ -21,6 +21,29 @@ type DbConn struct { db *sql.DB } +/******************* pool.Conn impl *******************/ + +// 关闭连接 +func (d *DbConn) Close() error { + if d.db != nil { + logx.Debugf("dbm - conn close, connId: %s", d.Id) + if err := d.db.Close(); err != nil { + logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error()) + } + // TODO 关闭实例隧道会影响其他正在使用的连接,所以暂时不关闭 + //if d.Info.useSshTunnel { + // mcm.CloseSshTunnelMachine(d.Info.SshTunnelMachineId, fmt.Sprintf("db:%d", d.Info.Id)) + //} + d.db = nil + } + + return nil +} + +func (d *DbConn) Ping() error { + return d.db.Ping() +} + // 执行数据库查询返回的列信息 type QueryColumn struct { Name string `json:"name"` // 列名 @@ -167,24 +190,6 @@ func (d *DbConn) Stats(ctx context.Context, execSql string, args ...any) sql.DBS return d.db.Stats() } -// 关闭连接 -func (d *DbConn) Close() { - if d.db != nil { - if err := d.db.Close(); err != nil { - logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error()) - } - // TODO 关闭实例隧道会影响其他正在使用的连接,所以暂时不关闭 - //if d.Info.useSshTunnel { - // mcm.CloseSshTunnelMachine(d.Info.SshTunnelMachineId, fmt.Sprintf("db:%d", d.Info.Id)) - //} - d.db = nil - } -} - -func (d *DbConn) Ping() error { - return d.db.Ping() -} - // 游标方式遍历查询rows, walkFn error不为nil, 则跳出遍历 func (d *DbConn) walkQueryRows(ctx context.Context, selectSql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) { cancelCtx, cancelFunc := context.WithCancel(ctx) diff --git a/server/internal/db/dbm/dbi/db_info.go b/server/internal/db/dbm/dbi/db_info.go index 50e77639..8e01551d 100644 --- a/server/internal/db/dbm/dbi/db_info.go +++ b/server/internal/db/dbm/dbi/db_info.go @@ -1,6 +1,7 @@ package dbi import ( + "context" "fmt" machineapp "mayfly-go/internal/machine/application" "mayfly-go/internal/machine/mcm" @@ -52,7 +53,7 @@ func (di *DbInfo) GetLogDesc() string { } // 连接数据库 -func (di *DbInfo) Conn(meta Meta) (*DbConn, error) { +func (di *DbInfo) Conn(ctx context.Context, meta Meta) (*DbConn, error) { if meta == nil { return nil, errorx.NewBiz("the database meta information interface cannot be empty") } @@ -66,7 +67,7 @@ func (di *DbInfo) Conn(meta Meta) (*DbConn, error) { di.Database = database } - conn, err := meta.GetSqlDb(di) + conn, err := meta.GetSqlDb(ctx, di) if err != nil { logx.Errorf("db connection failed: %s:%d/%s, err:%s", di.Host, di.Port, database, err.Error()) return nil, errorx.NewBiz("db connection failed: %s", err.Error()) @@ -83,7 +84,7 @@ func (di *DbInfo) Conn(meta Meta) (*DbConn, error) { // 最大连接周期,超过时间的连接就close // conn.SetConnMaxLifetime(100 * time.Second) // 设置最大连接数 - conn.SetMaxOpenConns(5) + conn.SetMaxOpenConns(6) // 设置闲置连接数 conn.SetMaxIdleConns(1) dbc.db = conn @@ -93,10 +94,10 @@ func (di *DbInfo) Conn(meta Meta) (*DbConn, error) { } // 如果使用了ssh隧道,将其host port改变其本地映射host port -func (di *DbInfo) IfUseSshTunnelChangeIpPort() error { +func (di *DbInfo) IfUseSshTunnelChangeIpPort(ctx context.Context) error { // 开启ssh隧道 if di.SshTunnelMachineId > 0 { - sshTunnelMachine, err := GetSshTunnel(di.SshTunnelMachineId) + sshTunnelMachine, err := GetSshTunnel(ctx, di.SshTunnelMachineId) if err != nil { return err } @@ -133,8 +134,8 @@ func (di *DbInfo) GetDatabase() string { } // 根据ssh tunnel机器id返回ssh tunnel -func GetSshTunnel(sshTunnelMachineId int) (*mcm.SshTunnelMachine, error) { - return machineapp.GetMachineApp().GetSshTunnelMachine(sshTunnelMachineId) +func GetSshTunnel(ctx context.Context, sshTunnelMachineId int) (*mcm.SshTunnelMachine, error) { + return machineapp.GetMachineApp().GetSshTunnelMachine(ctx, sshTunnelMachineId) } // 获取连接id @@ -143,5 +144,5 @@ func GetDbConnId(dbId uint64, db string) string { return "" } - return fmt.Sprintf("%d:%s", dbId, db) + return fmt.Sprintf("db-%d:%s", dbId, db) } diff --git a/server/internal/db/dbm/dbi/meta.go b/server/internal/db/dbm/dbi/meta.go index 5e25983b..79bc0ffe 100644 --- a/server/internal/db/dbm/dbi/meta.go +++ b/server/internal/db/dbm/dbi/meta.go @@ -1,6 +1,7 @@ package dbi import ( + "context" "database/sql" ) @@ -14,7 +15,7 @@ type DbVersion string // Meta 数据库元信息,如获取sql.DB、Dialect等 type Meta interface { // GetSqlDb 根据数据库信息获取sql.DB - GetSqlDb(*DbInfo) (*sql.DB, error) + GetSqlDb(context.Context, *DbInfo) (*sql.DB, error) // GetDialect 获取数据库方言, 若一些接口不需要DbConn,则可以传nil GetDialect(*DbConn) Dialect diff --git a/server/internal/db/dbm/dbm.go b/server/internal/db/dbm/dbm.go index fa67e4fe..37dfab37 100644 --- a/server/internal/db/dbm/dbm.go +++ b/server/internal/db/dbm/dbm.go @@ -1,7 +1,7 @@ package dbm import ( - "fmt" + "context" "mayfly-go/internal/db/dbm/dbi" _ "mayfly-go/internal/db/dbm/dm" _ "mayfly-go/internal/db/dbm/mssql" @@ -11,103 +11,54 @@ import ( _ "mayfly-go/internal/db/dbm/sqlite" "mayfly-go/pkg/logx" "mayfly-go/pkg/pool" - "time" ) -var connPool = make(map[string]pool.Pool) -var instPool = make(map[uint64]pool.Pool) +var ( + poolGroup = pool.NewPoolGroup[*dbi.DbConn]() +) -func init() { -} - -// PutDbConn 释放连接 -func PutDbConn(c *dbi.DbConn) { - if nil == c { - return - } - connId := dbi.GetDbConnId(c.Info.Id, c.Info.Database) - if p, ok := connPool[connId]; ok { - p.Put(c) - } -} - -func getPool(dbId uint64, database string, getDbInfo func() (*dbi.DbInfo, error)) (pool.Pool, error) { +// GetDbConn 从连接池中获取连接信息,记的用完连接后必须调用 PutDbConn 还回池 +func GetDbConn(ctx context.Context, dbId uint64, database string, getDbInfo func() (*dbi.DbInfo, error)) (*dbi.DbConn, error) { connId := dbi.GetDbConnId(dbId, database) - // 获取连接池,如果没有,则创建一个 - if p, ok := connPool[connId]; !ok { - var err error - p, err = pool.NewChannelPool(&pool.Config{ - InitialCap: 1, //资源池初始连接数 - MaxCap: 10, //最大空闲连接数 - MaxIdle: 10, //最大并发连接数 - IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效 - Factory: func() (interface{}, error) { - // 若缓存中不存在,则从回调函数中获取DbInfo - dbInfo, err := getDbInfo() - if err != nil { - return nil, err - } - - // 连接数据库 - return Conn(dbInfo) - }, - Close: func(v interface{}) error { - v.(*dbi.DbConn).Close() - return nil - }, - Ping: func(v interface{}) error { - return v.(*dbi.DbConn).Ping() - }, - }) + pool, err := poolGroup.GetCachePool(connId, func() (*dbi.DbConn, error) { + // 若缓存中不存在,则从回调函数中获取DbInfo + dbInfo, err := getDbInfo() if err != nil { return nil, err } - connPool[connId] = p - instPool[dbId] = p - return p, nil - } else { - return p, nil - } -} + logx.Debugf("dbm - conn create, connId: %s, dbInfo: %v", connId, dbInfo) + // 连接数据库 + return Conn(ctx, dbInfo) + }) -// GetDbConn 从连接池中获取连接信息,记的用完连接后必须调用 PutDbConn 还回池 -func GetDbConn(dbId uint64, database string, getDbInfo func() (*dbi.DbInfo, error)) (*dbi.DbConn, error) { - - p, err := getPool(dbId, database, getDbInfo) if err != nil { return nil, err } // 从连接池中获取一个可用的连接 - c, err := p.Get() - if err != nil { - return nil, err - } - ec := c.(*dbi.DbConn) - return ec, nil - + return pool.Get(ctx) } // 使用指定dbInfo信息进行连接 -func Conn(di *dbi.DbInfo) (*dbi.DbConn, error) { - return di.Conn(dbi.GetMeta(di.Type)) +func Conn(ctx context.Context, di *dbi.DbInfo) (*dbi.DbConn, error) { + return di.Conn(ctx, dbi.GetMeta(di.Type)) } // 根据实例id获取连接 -func GetDbConnByInstanceId(instanceId uint64) *dbi.DbConn { - if p, ok := instPool[instanceId]; ok { - c, err := p.Get() +func GetDbConnByInstanceId(ctx context.Context, instanceId uint64) *dbi.DbConn { + for _, pool := range poolGroup.AllPool() { + conn, err := pool.Get(ctx) if err != nil { - logx.Error(fmt.Sprintf("实例id[%d]连接获取失败:%s", instanceId, err)) - return nil + continue + } + if conn.Info.InstanceId == instanceId { + return conn } - return c.(*dbi.DbConn) } return nil } // 删除db缓存并关闭该数据库所有连接 func CloseDb(dbId uint64, db string) { - delete(connPool, dbi.GetDbConnId(dbId, db)) - delete(instPool, dbId) + poolGroup.Close(dbi.GetDbConnId(dbId, db)) } diff --git a/server/internal/db/dbm/dm/meta.go b/server/internal/db/dbm/dm/meta.go index 0442715f..8f05f05a 100644 --- a/server/internal/db/dbm/dm/meta.go +++ b/server/internal/db/dbm/dm/meta.go @@ -1,6 +1,7 @@ package dm import ( + "context" "database/sql" "fmt" "mayfly-go/internal/db/dbm/dbi" @@ -20,7 +21,7 @@ const ( type Meta struct { } -func (dm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { +func (dm *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) { driverName := "dm" db := d.Database var dbParam string @@ -36,7 +37,7 @@ func (dm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { dbParam = "?escapeProcess=true" } - err := d.IfUseSshTunnelChangeIpPort() + err := d.IfUseSshTunnelChangeIpPort(ctx) if err != nil { return nil, err } diff --git a/server/internal/db/dbm/mssql/meta.go b/server/internal/db/dbm/mssql/meta.go index 451b34bd..db671e22 100644 --- a/server/internal/db/dbm/mssql/meta.go +++ b/server/internal/db/dbm/mssql/meta.go @@ -1,6 +1,7 @@ package mssql import ( + "context" "database/sql" "fmt" "mayfly-go/internal/db/dbm/dbi" @@ -23,8 +24,8 @@ const ( type Meta struct { } -func (mm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { - err := d.IfUseSshTunnelChangeIpPort() +func (mm *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) { + err := d.IfUseSshTunnelChangeIpPort(ctx) if err != nil { return nil, err } diff --git a/server/internal/db/dbm/mysql/meta.go b/server/internal/db/dbm/mysql/meta.go index dea04e1e..eff76a16 100644 --- a/server/internal/db/dbm/mysql/meta.go +++ b/server/internal/db/dbm/mysql/meta.go @@ -25,10 +25,10 @@ const ( type Meta struct { } -func (mm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { +func (mm *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) { // SSH Conect if d.SshTunnelMachineId > 0 { - sshTunnelMachine, err := dbi.GetSshTunnel(d.SshTunnelMachineId) + sshTunnelMachine, err := dbi.GetSshTunnel(ctx, d.SshTunnelMachineId) if err != nil { return nil, err } diff --git a/server/internal/db/dbm/oracle/meta.go b/server/internal/db/dbm/oracle/meta.go index 46238a4d..82517b62 100644 --- a/server/internal/db/dbm/oracle/meta.go +++ b/server/internal/db/dbm/oracle/meta.go @@ -1,6 +1,7 @@ package oracle import ( + "context" "database/sql" "fmt" "mayfly-go/internal/db/dbm/dbi" @@ -23,8 +24,8 @@ const ( type Meta struct { } -func (om *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { - err := d.IfUseSshTunnelChangeIpPort() +func (om *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) { + err := d.IfUseSshTunnelChangeIpPort(ctx) if err != nil { return nil, err } diff --git a/server/internal/db/dbm/postgres/meta.go b/server/internal/db/dbm/postgres/meta.go index 9c31d3cf..29058cac 100644 --- a/server/internal/db/dbm/postgres/meta.go +++ b/server/internal/db/dbm/postgres/meta.go @@ -1,6 +1,7 @@ package postgres import ( + "context" "database/sql" "database/sql/driver" "fmt" @@ -37,7 +38,7 @@ type Meta struct { Param string } -func (pm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { +func (pm *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) { driverName := "postgres" // SSH Conect if d.SshTunnelMachineId > 0 { @@ -120,7 +121,8 @@ func (pd *PqSqlDialer) Open(name string) (driver.Conn, error) { } func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) { - sshTunnel, err := dbi.GetSshTunnel(pd.sshTunnelMachineId) + // todo context.Background可能存在问题 + sshTunnel, err := dbi.GetSshTunnel(context.Background(), pd.sshTunnelMachineId) if err != nil { return nil, err } diff --git a/server/internal/db/dbm/sqlite/meta.go b/server/internal/db/dbm/sqlite/meta.go index 3b822631..21b5b3b2 100644 --- a/server/internal/db/dbm/sqlite/meta.go +++ b/server/internal/db/dbm/sqlite/meta.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "errors" "mayfly-go/internal/db/dbm/dbi" @@ -19,7 +20,7 @@ const ( type Meta struct { } -func (md *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { +func (md *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) { // 用host字段来存sqlite的文件路径 // 检查文件是否存在,否则报错,基于sqlite会自动创建文件,为了服务器文件安全,所以先确定文件存在再连接,不自动创建 if _, err := os.Stat(d.Host); err != nil { diff --git a/server/internal/es/api/es_instance.go b/server/internal/es/api/es_instance.go index 18b94b87..c507328b 100644 --- a/server/internal/es/api/es_instance.go +++ b/server/internal/es/api/es_instance.go @@ -1,7 +1,6 @@ package api import ( - "github.com/may-fly/cast" "mayfly-go/internal/es/api/form" "mayfly-go/internal/es/api/vo" "mayfly-go/internal/es/application" @@ -19,6 +18,8 @@ import ( "net/http" "net/url" "strings" + + "github.com/may-fly/cast" ) type Instance struct { @@ -99,7 +100,7 @@ func (d *Instance) TestConn(rc *req.Ctx) { ac = fm.AuthCerts[0] } - res, err := d.inst.TestConn(instance, ac) + res, err := d.inst.TestConn(rc.MetaCtx, instance, ac) biz.ErrIsNil(err) rc.ResData = res } @@ -133,7 +134,7 @@ func (d *Instance) Proxy(rc *req.Ctx) { r := rc.GetRequest() _ = RemoveQueryParam(r, "id", "path") - err := d.inst.DoConn(instanceId, func(conn *esi.EsConn) error { + err := d.inst.DoConn(rc.MetaCtx, instanceId, func(conn *esi.EsConn) error { conn.Proxy(rc.GetWriter(), r, path) return nil }) diff --git a/server/internal/es/application/es_instance.go b/server/internal/es/application/es_instance.go index 06a44270..19177365 100644 --- a/server/internal/es/application/es_instance.go +++ b/server/internal/es/application/es_instance.go @@ -19,7 +19,6 @@ import ( "mayfly-go/pkg/utils/collx" "mayfly-go/pkg/utils/stringx" "mayfly-go/pkg/utils/structx" - "time" ) type Instance interface { @@ -28,9 +27,9 @@ type Instance interface { GetPageList(condition *entity.InstanceQuery, orderBy ...string) (*model.PageResult[*entity.EsInstance], error) // DoConn 获取连接并执行函数 - DoConn(instanceId uint64, fn func(*esi.EsConn) error) error + DoConn(ctx context.Context, instanceId uint64, fn func(*esi.EsConn) error) error - TestConn(instance *entity.EsInstance, ac *tagentity.ResourceAuthCert) (map[string]any, error) + TestConn(ctx context.Context, instance *entity.EsInstance, ac *tagentity.ResourceAuthCert) (map[string]any, error) SaveInst(ctx context.Context, d *dto.SaveEsInstance) (uint64, error) @@ -39,7 +38,7 @@ type Instance interface { var _ Instance = &instanceAppImpl{} -var connPool = make(map[uint64]pool.Pool) +var poolGroup = pool.NewPoolGroup[*esi.EsConn]() type instanceAppImpl struct { base.AppImpl[*entity.EsInstance, repository.EsInstance] @@ -53,54 +52,24 @@ func (app *instanceAppImpl) GetPageList(condition *entity.InstanceQuery, orderBy return app.GetRepo().GetInstanceList(condition, orderBy...) } -func (app *instanceAppImpl) DoConn(instanceId uint64, fn func(*esi.EsConn) error) error { - // 通过实例id获取实例连接信息 - p, err := app.getPool(instanceId) +func (app *instanceAppImpl) DoConn(ctx context.Context, instanceId uint64, fn func(*esi.EsConn) error) error { + p, err := poolGroup.GetChanPool(fmt.Sprintf("es-%d", instanceId), func() (*esi.EsConn, error) { + return app.createConn(ctx, instanceId) + }) + if err != nil { return err } // 从连接池中获取一个可用的连接 - c, err := p.Get() + c, err := p.Get(ctx) if err != nil { return err } - ec := c.(*esi.EsConn) - // 用完后放回连接池 - defer p.Put(c) - - return fn(ec) + return fn(c) } -func (app *instanceAppImpl) getPool(instanceId uint64) (pool.Pool, error) { - // 获取连接池,如果没有,则创建一个 - if p, ok := connPool[instanceId]; !ok { - var err error - p, err = pool.NewChannelPool(&pool.Config{ - InitialCap: 1, //资源池初始连接数 - MaxCap: 10, //最大空闲连接数 - MaxIdle: 10, //最大并发连接数 - IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效 - Factory: func() (interface{}, error) { - return app.createConn(instanceId) - }, - Close: func(v interface{}) error { - return v.(*esi.EsConn).Close() - }, - Ping: func(v interface{}) error { - return v.(*esi.EsConn).Ping() - }, - }) - if err != nil { - return nil, err - } - connPool[instanceId] = p - return p, nil - } else { - return p, nil - } -} -func (app *instanceAppImpl) createConn(instanceId uint64) (*esi.EsConn, error) { +func (app *instanceAppImpl) createConn(ctx context.Context, instanceId uint64) (*esi.EsConn, error) { // 缓存不存在,则重新连接 instance, err := app.GetById(instanceId) if err != nil { @@ -113,7 +82,7 @@ func (app *instanceAppImpl) createConn(instanceId uint64) (*esi.EsConn, error) { } ei.CodePath = app.tagApp.ListTagPathByTypeAndCode(int8(tagentity.TagTypeEsInstance), instance.Code) - conn, _, err := ei.Conn() + conn, _, err := ei.Conn(ctx) if err != nil { return nil, err } @@ -151,7 +120,7 @@ func (app *instanceAppImpl) ToEsInfo(instance *entity.EsInstance, ac *tagentity. return ei, nil } -func (app *instanceAppImpl) TestConn(instance *entity.EsInstance, ac *tagentity.ResourceAuthCert) (map[string]any, error) { +func (app *instanceAppImpl) TestConn(ctx context.Context, instance *entity.EsInstance, ac *tagentity.ResourceAuthCert) (map[string]any, error) { instance.Network = instance.GetNetwork() ei, err := app.ToEsInfo(instance, ac) @@ -159,7 +128,7 @@ func (app *instanceAppImpl) TestConn(instance *entity.EsInstance, ac *tagentity. return nil, err } - _, res, err := ei.Conn() + _, res, err := ei.Conn(ctx) if err != nil { return nil, err } diff --git a/server/internal/es/esm/esi/conn.go b/server/internal/es/esm/esi/conn.go index 2369860a..2b811684 100644 --- a/server/internal/es/esm/esi/conn.go +++ b/server/internal/es/esm/esi/conn.go @@ -16,6 +16,21 @@ type EsConn struct { proxy *httputil.ReverseProxy } +/******************* pool.Conn impl *******************/ + +func (d *EsConn) Close() error { + // 如果是使用了ssh隧道转发,则需要手动将其关闭 + if d.Info.useSshTunnel { + mcm.CloseSshTunnelMachine(uint64(d.Info.SshTunnelMachineId), fmt.Sprintf("es:%d", d.Id)) + } + return nil +} + +func (d *EsConn) Ping() error { + _, err := d.Info.Ping() + return err +} + // StartProxy 开始代理 func (d *EsConn) StartProxy() error { // 目标 URL @@ -40,16 +55,3 @@ func (d *EsConn) Proxy(w http.ResponseWriter, r *http.Request, path string) { r.Header.Set("Accept", "application/json") d.proxy.ServeHTTP(w, r) } - -func (d *EsConn) Close() error { - // 如果是使用了ssh隧道转发,则需要手动将其关闭 - if d.Info.useSshTunnel { - mcm.CloseSshTunnelMachine(uint64(d.Info.SshTunnelMachineId), fmt.Sprintf("es:%d", d.Id)) - } - return nil -} - -func (d *EsConn) Ping() error { - _, err := d.Info.Ping() - return err -} diff --git a/server/internal/es/esm/esi/es_info.go b/server/internal/es/esm/esi/es_info.go index 9cdee3df..04ab21d5 100644 --- a/server/internal/es/esm/esi/es_info.go +++ b/server/internal/es/esm/esi/es_info.go @@ -1,6 +1,7 @@ package esi import ( + "context" "encoding/base64" "fmt" machineapp "mayfly-go/internal/machine/application" @@ -46,7 +47,7 @@ func (di *EsInfo) GetLogDesc() string { } // 连接数据库 -func (di *EsInfo) Conn() (*EsConn, map[string]any, error) { +func (di *EsInfo) Conn(ctx context.Context) (*EsConn, map[string]any, error) { // 使用basic加密用户名和密码 if di.Username != "" && di.Password != "" { encodeString := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", di.Username, di.Password))) @@ -54,7 +55,7 @@ func (di *EsInfo) Conn() (*EsConn, map[string]any, error) { } // 使用ssh隧道 - err := di.IfUseSshTunnelChangeIpPort() + err := di.IfUseSshTunnelChangeIpPort(ctx) if err != nil { logx.Errorf("es ssh failed: %s, err:%s", di.baseUrl, err.Error()) return nil, nil, errorx.NewBiz("es ssh failed: %s", err.Error()) @@ -115,10 +116,10 @@ func (di *EsInfo) ExecApi(method, path string, data any, timeoutSecond ...int) ( } // 如果使用了ssh隧道,将其host port改变其本地映射host port -func (di *EsInfo) IfUseSshTunnelChangeIpPort() error { +func (di *EsInfo) IfUseSshTunnelChangeIpPort(ctx context.Context) error { // 开启ssh隧道 if di.SshTunnelMachineId > 0 { - stm, err := GetSshTunnel(di.SshTunnelMachineId) + stm, err := GetSshTunnel(ctx, di.SshTunnelMachineId) if err != nil { return err } @@ -137,6 +138,6 @@ func (di *EsInfo) IfUseSshTunnelChangeIpPort() error { } // 根据ssh tunnel机器id返回ssh tunnel -func GetSshTunnel(sshTunnelMachineId int) (*mcm.SshTunnelMachine, error) { - return machineapp.GetMachineApp().GetSshTunnelMachine(sshTunnelMachineId) +func GetSshTunnel(ctx context.Context, sshTunnelMachineId int) (*mcm.SshTunnelMachine, error) { + return machineapp.GetMachineApp().GetSshTunnelMachine(ctx, sshTunnelMachineId) } diff --git a/server/internal/machine/api/machine.go b/server/internal/machine/api/machine.go index 631b2109..59acf925 100644 --- a/server/internal/machine/api/machine.go +++ b/server/internal/machine/api/machine.go @@ -135,9 +135,8 @@ func (m *Machine) SimpleMachieInfo(rc *req.Ctx) { } func (m *Machine) MachineStats(rc *req.Ctx) { - cli, err := m.machineApp.GetCli(GetMachineId(rc)) + cli, err := m.machineApp.GetCli(rc.MetaCtx, GetMachineId(rc)) biz.ErrIsNilAppendErr(err, "connection error: %s") - defer mcm.PutMachineCli(cli) rc.ResData = cli.GetAllStats() } @@ -160,7 +159,7 @@ func (m *Machine) TestConn(rc *req.Ctx) { machineForm := new(form.MachineForm) me := req.BindJsonAndCopyTo(rc, machineForm, new(entity.Machine)) // 测试连接 - biz.ErrIsNilAppendErr(m.machineApp.TestConn(me, machineForm.AuthCerts[0]), "connection error: %s") + biz.ErrIsNilAppendErr(m.machineApp.TestConn(rc.MetaCtx, me, machineForm.AuthCerts[0]), "connection error: %s") } func (m *Machine) ChangeStatus(rc *req.Ctx) { @@ -198,9 +197,8 @@ func (m *Machine) GetProcess(rc *req.Ctx) { count := rc.QueryIntDefault("count", 10) cmd += "| head -n " + fmt.Sprintf("%d", count) - cli, err := m.machineApp.GetCli(GetMachineId(rc)) + cli, err := m.machineApp.GetCli(rc.MetaCtx, GetMachineId(rc)) biz.ErrIsNilAppendErr(err, "connection error: %s") - defer mcm.PutMachineCli(cli) biz.ErrIsNilAppendErr(m.tagTreeApp.CanAccess(rc.GetLoginAccount().Id, cli.Info.CodePath...), "%s") @@ -214,9 +212,8 @@ func (m *Machine) KillProcess(rc *req.Ctx) { pid := rc.Query("pid") biz.NotEmpty(pid, "pid cannot be empty") - cli, err := m.machineApp.GetCli(GetMachineId(rc)) + cli, err := m.machineApp.GetCli(rc.MetaCtx, GetMachineId(rc)) biz.ErrIsNilAppendErr(err, "connection error: %s") - defer mcm.PutMachineCli(cli) biz.ErrIsNilAppendErr(m.tagTreeApp.CanAccess(rc.GetLoginAccount().Id, cli.Info.CodePath...), "%s") @@ -225,9 +222,8 @@ func (m *Machine) KillProcess(rc *req.Ctx) { } func (m *Machine) GetUsers(rc *req.Ctx) { - cli, err := m.machineApp.GetCli(GetMachineId(rc)) + cli, err := m.machineApp.GetCli(rc.MetaCtx, GetMachineId(rc)) biz.ErrIsNilAppendErr(err, "connection error: %s") - defer mcm.PutMachineCli(cli) res, err := cli.GetUsers() biz.ErrIsNil(err) @@ -235,9 +231,8 @@ func (m *Machine) GetUsers(rc *req.Ctx) { } func (m *Machine) GetGroups(rc *req.Ctx) { - cli, err := m.machineApp.GetCli(GetMachineId(rc)) + cli, err := m.machineApp.GetCli(rc.MetaCtx, GetMachineId(rc)) biz.ErrIsNilAppendErr(err, "connection error: %s") - defer mcm.PutMachineCli(cli) res, err := cli.GetGroups() biz.ErrIsNil(err) @@ -262,12 +257,9 @@ func (m *Machine) WsSSH(rc *req.Ctx) { err = req.PermissionHandler(rc) biz.ErrIsNil(err, mcm.GetErrorContentRn("You do not have permission to operate the machine terminal, please log in again and try again ~")) - cli, err := m.machineApp.GetCliByAc(GetMachineAc(rc)) + cli, err := m.machineApp.NewCli(rc.MetaCtx, GetMachineAc(rc)) biz.ErrIsNilAppendErr(err, mcm.GetErrorContentRn("connection error: %s")) - defer func() { - cli.Close() - mcm.PutMachineCli(cli) - }() + defer cli.Close() biz.ErrIsNilAppendErr(m.tagTreeApp.CanAccess(rc.GetLoginAccount().Id, cli.Info.CodePath...), mcm.GetErrorContentRn("%s")) global.EventBus.Publish(rc.MetaCtx, event.EventTopicResourceOp, cli.Info.CodePath[0]) @@ -327,7 +319,7 @@ func (m *Machine) WsGuacamole(rc *req.Ctx) { return } - err = mi.IfUseSshTunnelChangeIpPort(true) + err = mi.IfUseSshTunnelChangeIpPort(rc.MetaCtx, true) if err != nil { return } diff --git a/server/internal/machine/api/machine_file.go b/server/internal/machine/api/machine_file.go index 9649e540..7a8e5ae5 100644 --- a/server/internal/machine/api/machine_file.go +++ b/server/internal/machine/api/machine_file.go @@ -326,9 +326,8 @@ func (m *MachineFile) UploadFolder(rc *req.Ctx) { } folderName := filepath.Dir(paths[0]) - mcli, err := m.machineFileApp.GetMachineCli(authCertName) + mcli, err := m.machineFileApp.GetMachineCli(rc.MetaCtx, authCertName) biz.ErrIsNil(err) - defer mcm.PutMachineCli(mcli) mi := mcli.Info diff --git a/server/internal/machine/api/machine_script.go b/server/internal/machine/api/machine_script.go index af1a1c71..83e6e812 100644 --- a/server/internal/machine/api/machine_script.go +++ b/server/internal/machine/api/machine_script.go @@ -5,7 +5,6 @@ import ( "mayfly-go/internal/machine/api/vo" "mayfly-go/internal/machine/application" "mayfly-go/internal/machine/domain/entity" - "mayfly-go/internal/machine/mcm" tagapp "mayfly-go/internal/tag/application" "mayfly-go/pkg/biz" "mayfly-go/pkg/model" @@ -78,9 +77,8 @@ func (m *MachineScript) RunMachineScript(rc *req.Ctx) { script, err = stringx.TemplateParse(ms.Script, p) biz.ErrIsNilAppendErr(err, "failed to parse the script template parameter: %s") } - cli, err := m.machineApp.GetCliByAc(ac) + cli, err := m.machineApp.GetCliByAc(rc.MetaCtx, ac) biz.ErrIsNilAppendErr(err, "connection error: %s") - defer mcm.PutMachineCli(cli) biz.ErrIsNilAppendErr(m.tagApp.CanAccess(rc.GetLoginAccount().Id, cli.Info.CodePath...), "%s") diff --git a/server/internal/machine/application/machine.go b/server/internal/machine/application/machine.go index db92023b..575788b5 100644 --- a/server/internal/machine/application/machine.go +++ b/server/internal/machine/application/machine.go @@ -27,7 +27,7 @@ type Machine interface { SaveMachine(ctx context.Context, param *dto.SaveMachine) error // 测试机器连接 - TestConn(me *entity.Machine, authCert *tagentity.ResourceAuthCert) error + TestConn(ctx context.Context, me *entity.Machine, authCert *tagentity.ResourceAuthCert) error // 调整机器状态 ChangeStatus(ctx context.Context, id uint64, status int8) error @@ -38,16 +38,16 @@ type Machine interface { GetMachineList(condition *entity.MachineQuery, orderBy ...string) (*model.PageResult[*entity.Machine], error) // 新建机器客户端连接(需手动调用Close) - NewCli(authCertName string) (*mcm.Cli, error) + NewCli(ctx context.Context, authCertName string) (*mcm.Cli, error) // 获取已缓存的机器连接,若不存在则新建客户端连接并缓存,主要用于定时获取状态等(避免频繁创建连接) - GetCli(id uint64) (*mcm.Cli, error) + GetCli(ctx context.Context, id uint64) (*mcm.Cli, error) // 根据授权凭证获取客户端连接 - GetCliByAc(authCertName string) (*mcm.Cli, error) + GetCliByAc(ctx context.Context, authCertName string) (*mcm.Cli, error) // 获取ssh隧道机器连接 - GetSshTunnelMachine(id int) (*mcm.SshTunnelMachine, error) + GetSshTunnelMachine(ctx context.Context, id int) (*mcm.SshTunnelMachine, error) // 定时更新机器状态信息 TimerUpdateStats() @@ -159,7 +159,7 @@ func (m *machineAppImpl) SaveMachine(ctx context.Context, param *dto.SaveMachine }) } -func (m *machineAppImpl) TestConn(me *entity.Machine, authCert *tagentity.ResourceAuthCert) error { +func (m *machineAppImpl) TestConn(ctx context.Context, me *entity.Machine, authCert *tagentity.ResourceAuthCert) error { me.Id = 0 authCert, err := m.resourceAuthCertApp.GetRealAuthCert(authCert) @@ -171,7 +171,7 @@ func (m *machineAppImpl) TestConn(me *entity.Machine, authCert *tagentity.Resour if err != nil { return err } - cli, err := mi.Conn() + cli, err := mi.Conn(ctx) if err != nil { return err } @@ -224,30 +224,30 @@ func (m *machineAppImpl) Delete(ctx context.Context, id uint64) error { }) } -func (m *machineAppImpl) NewCli(authCertName string) (*mcm.Cli, error) { +func (m *machineAppImpl) NewCli(ctx context.Context, authCertName string) (*mcm.Cli, error) { if mi, err := m.ToMachineInfoByAc(authCertName); err != nil { return nil, err } else { - return mi.Conn() + return mi.Conn(ctx) } } -func (m *machineAppImpl) GetCliByAc(authCertName string) (*mcm.Cli, error) { - return mcm.GetMachineCli(authCertName, func(ac string) (*mcm.MachineInfo, error) { +func (m *machineAppImpl) GetCliByAc(ctx context.Context, authCertName string) (*mcm.Cli, error) { + return mcm.GetMachineCli(ctx, authCertName, func(ac string) (*mcm.MachineInfo, error) { return m.ToMachineInfoByAc(ac) }) } -func (m *machineAppImpl) GetCli(machineId uint64) (*mcm.Cli, error) { +func (m *machineAppImpl) GetCli(ctx context.Context, machineId uint64) (*mcm.Cli, error) { _, authCert, err := m.getMachineAndAuthCert(machineId) if err != nil { return nil, err } - return m.GetCliByAc(authCert.Name) + return m.GetCliByAc(ctx, authCert.Name) } -func (m *machineAppImpl) GetSshTunnelMachine(machineId int) (*mcm.SshTunnelMachine, error) { - return mcm.GetSshTunnelMachine(machineId, func(mid uint64) (*mcm.MachineInfo, error) { +func (m *machineAppImpl) GetSshTunnelMachine(ctx context.Context, machineId int) (*mcm.SshTunnelMachine, error) { + return mcm.GetSshTunnelMachine(ctx, machineId, func(mid uint64) (*mcm.MachineInfo, error) { return m.ToMachineInfoById(mid) }) } @@ -264,7 +264,9 @@ func (m *machineAppImpl) TimerUpdateStats() { } }() logx.Debugf("time to get machine [id=%d] status information start", mid) - cli, err := m.GetCli(mid) + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + cli, err := m.GetCli(ctx, mid) if err != nil { logx.Errorf("failed to get machine [id=%d] status information periodically, failed to get machine cli: %s", mid, err.Error()) return diff --git a/server/internal/machine/application/machine_cronjob.go b/server/internal/machine/application/machine_cronjob.go index cd553ae2..b0f58baf 100644 --- a/server/internal/machine/application/machine_cronjob.go +++ b/server/internal/machine/application/machine_cronjob.go @@ -5,7 +5,6 @@ import ( "mayfly-go/internal/machine/application/dto" "mayfly-go/internal/machine/domain/entity" "mayfly-go/internal/machine/domain/repository" - "mayfly-go/internal/machine/mcm" tagapp "mayfly-go/internal/tag/application" tagentity "mayfly-go/internal/tag/domain/entity" "mayfly-go/pkg/base" @@ -179,14 +178,14 @@ func (m *machineCronJobAppImpl) runCronJob0(mid uint64, cronJob *entity.MachineC ExecTime: time.Now(), } - machineCli, err := m.machineApp.GetCli(mid) + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + machineCli, err := m.machineApp.GetCli(ctx, mid) res := "" if err != nil { machine, _ := m.machineApp.GetById(mid) execRes.MachineCode = machine.Code } else { - defer mcm.PutMachineCli(machineCli) - execRes.MachineCode = machineCli.Info.Code res, err = machineCli.Run(cronJob.Script) if err != nil { diff --git a/server/internal/machine/application/machine_file.go b/server/internal/machine/application/machine_file.go index 207a8228..4bba3666 100644 --- a/server/internal/machine/application/machine_file.go +++ b/server/internal/machine/application/machine_file.go @@ -38,7 +38,7 @@ type MachineFile interface { Save(ctx context.Context, entity *entity.MachineFile) error // 获取机器cli - GetMachineCli(authCertName string) (*mcm.Cli, error) + GetMachineCli(ctx context.Context, authCertName string) (*mcm.Cli, error) GetRdpFilePath(ua *model.LoginAccount, path string) string @@ -86,6 +86,8 @@ type machineFileAppImpl struct { machineApp Machine `inject:"T"` } +var _ MachineFile = (*machineFileAppImpl)(nil) + // 注入MachineFileRepo func (m *machineFileAppImpl) InjectMachineFileRepo(repo repository.MachineFile) { m.Repo = repo @@ -134,7 +136,7 @@ func (m *machineFileAppImpl) ReadDir(ctx context.Context, opParam *dto.MachineFi }), nil } - _, sftpCli, err := m.GetMachineSftpCli(opParam) + _, sftpCli, err := m.GetMachineSftpCli(ctx, opParam) if err != nil { return nil, err } @@ -166,11 +168,10 @@ func (m *machineFileAppImpl) GetDirSize(ctx context.Context, opParam *dto.Machin return bytex.FormatSize(totalSize), nil } - mcli, err := m.GetMachineCli(opParam.AuthCertName) + mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName) if err != nil { return "", err } - defer mcm.PutMachineCli(mcli) res, err := mcli.Run(fmt.Sprintf("du -sh %s", path)) if err != nil { @@ -200,11 +201,10 @@ func (m *machineFileAppImpl) FileStat(ctx context.Context, opParam *dto.MachineF return fmt.Sprintf("%v", stat), err } - mcli, err := m.GetMachineCli(opParam.AuthCertName) + mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName) if err != nil { return "", err } - defer mcm.PutMachineCli(mcli) return mcli.Run(fmt.Sprintf("stat -L %s", path)) } @@ -221,7 +221,7 @@ func (m *machineFileAppImpl) MkDir(ctx context.Context, opParam *dto.MachineFile return &mcm.MachineInfo{Name: opParam.AuthCertName, Ip: opParam.AuthCertName}, nil } - mi, sftpCli, err := m.GetMachineSftpCli(opParam) + mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam) if err != nil { return nil, err } @@ -241,7 +241,7 @@ func (m *machineFileAppImpl) CreateFile(ctx context.Context, opParam *dto.Machin return &mcm.MachineInfo{Name: opParam.AuthCertName, Ip: opParam.AuthCertName}, err } - mi, sftpCli, err := m.GetMachineSftpCli(opParam) + mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam) if err != nil { return nil, err } @@ -254,7 +254,7 @@ func (m *machineFileAppImpl) CreateFile(ctx context.Context, opParam *dto.Machin } func (m *machineFileAppImpl) ReadFile(ctx context.Context, opParam *dto.MachineFileOp) (*sftp.File, *mcm.MachineInfo, error) { - mi, sftpCli, err := m.GetMachineSftpCli(opParam) + mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam) if err != nil { return nil, nil, err } @@ -278,7 +278,7 @@ func (m *machineFileAppImpl) WriteFileContent(ctx context.Context, opParam *dto. return &mcm.MachineInfo{Name: opParam.AuthCertName, Ip: opParam.AuthCertName}, err } - mi, sftpCli, err := m.GetMachineSftpCli(opParam) + mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam) if err != nil { return nil, err } @@ -313,7 +313,7 @@ func (m *machineFileAppImpl) UploadFile(ctx context.Context, opParam *dto.Machin return &mcm.MachineInfo{Name: opParam.AuthCertName, Ip: opParam.AuthCertName}, nil } - mi, sftpCli, err := m.GetMachineSftpCli(opParam) + mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam) if err != nil { return nil, err } @@ -379,11 +379,10 @@ func (m *machineFileAppImpl) RemoveFile(ctx context.Context, opParam *dto.Machin return nil, nil } - mcli, err := m.GetMachineCli(opParam.AuthCertName) + mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName) if err != nil { return nil, err } - defer mcm.PutMachineCli(mcli) minfo := mcli.Info @@ -431,11 +430,10 @@ func (m *machineFileAppImpl) Copy(ctx context.Context, opParam *dto.MachineFileO return nil, nil } - mcli, err := m.GetMachineCli(opParam.AuthCertName) + mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName) if err != nil { return nil, err } - defer mcm.PutMachineCli(mcli) mi := mcli.Info res, err := mcli.Run(fmt.Sprintf("cp -r %s %s", strings.Join(path, " "), toPath)) @@ -461,11 +459,10 @@ func (m *machineFileAppImpl) Mv(ctx context.Context, opParam *dto.MachineFileOp, return nil, nil } - mcli, err := m.GetMachineCli(opParam.AuthCertName) + mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName) if err != nil { return nil, err } - defer mcm.PutMachineCli(mcli) mi := mcli.Info res, err := mcli.Run(fmt.Sprintf("mv %s %s", strings.Join(path, " "), toPath)) @@ -483,7 +480,7 @@ func (m *machineFileAppImpl) Rename(ctx context.Context, opParam *dto.MachineFil return nil, os.Rename(oldname, newname) } - mi, sftpCli, err := m.GetMachineSftpCli(opParam) + mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam) if err != nil { return nil, err } @@ -491,17 +488,16 @@ func (m *machineFileAppImpl) Rename(ctx context.Context, opParam *dto.MachineFil } // 获取文件机器cli -func (m *machineFileAppImpl) GetMachineCli(authCertName string) (*mcm.Cli, error) { - return m.machineApp.GetCliByAc(authCertName) +func (m *machineFileAppImpl) GetMachineCli(ctx context.Context, authCertName string) (*mcm.Cli, error) { + return m.machineApp.GetCliByAc(ctx, authCertName) } // 获取文件机器 sftp cli -func (m *machineFileAppImpl) GetMachineSftpCli(opParam *dto.MachineFileOp) (*mcm.MachineInfo, *sftp.Client, error) { - mcli, err := m.GetMachineCli(opParam.AuthCertName) +func (m *machineFileAppImpl) GetMachineSftpCli(ctx context.Context, opParam *dto.MachineFileOp) (*mcm.MachineInfo, *sftp.Client, error) { + mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName) if err != nil { return nil, nil, err } - defer mcm.PutMachineCli(mcli) sftpCli, err := mcli.GetSftpCli() if err != nil { diff --git a/server/internal/machine/mcm/client.go b/server/internal/machine/mcm/client.go index e61564e1..0e6f2def 100644 --- a/server/internal/machine/mcm/client.go +++ b/server/internal/machine/mcm/client.go @@ -18,11 +18,41 @@ type Cli struct { sftpClient *sftp.Client // sftp客户端 } +/******************* pool.Conn impl *******************/ + func (c *Cli) Ping() error { - _, _, err := c.sshClient.Conn.SendRequest("ping", true, nil) + _, _, err := c.sshClient.SendRequest("ping", true, nil) return err } +// Close 关闭client并从缓存中移除,如果使用隧道则也关闭 +func (c *Cli) Close() error { + m := c.Info + logx.Debugf("close machine cli -> id=%d, name=%s, ip=%s", m.Id, m.Name, m.Ip) + if c.sshClient != nil { + c.sshClient.Close() + c.sshClient = nil + } + if c.sftpClient != nil { + c.sftpClient.Close() + c.sftpClient = nil + } + + var sshTunnelMachineId uint64 + if m.SshTunnelMachine != nil { + sshTunnelMachineId = m.SshTunnelMachine.Id + } + if m.TempSshMachineId != 0 { + sshTunnelMachineId = m.TempSshMachineId + } + if sshTunnelMachineId != 0 { + logx.Debugf("close machine ssh tunnel -> machineId=%d, sshTunnelMachineId=%d", m.Id, sshTunnelMachineId) + CloseSshTunnelMachine(sshTunnelMachineId, m.GetTunnelId()) + } + + return nil +} + // GetSftpCli 获取sftp client func (c *Cli) GetSftpCli() (*sftp.Client, error) { if c.sshClient == nil { @@ -72,32 +102,6 @@ func (c *Cli) Run(shell string) (string, error) { return string(buf), nil } -// Close 关闭client并从缓存中移除,如果使用隧道则也关闭 -func (c *Cli) Close() { - m := c.Info - logx.Debugf("close machine cli -> id=%d, name=%s, ip=%s", m.Id, m.Name, m.Ip) - if c.sshClient != nil { - c.sshClient.Close() - c.sshClient = nil - } - if c.sftpClient != nil { - c.sftpClient.Close() - c.sftpClient = nil - } - - var sshTunnelMachineId uint64 - if m.SshTunnelMachine != nil { - sshTunnelMachineId = m.SshTunnelMachine.Id - } - if m.TempSshMachineId != 0 { - sshTunnelMachineId = m.TempSshMachineId - } - if sshTunnelMachineId != 0 { - logx.Debugf("close machine ssh tunnel -> machineId=%d, sshTunnelMachineId=%d", m.Id, sshTunnelMachineId) - CloseSshTunnelMachine(sshTunnelMachineId, m.GetTunnelId()) - } -} - // GetAllStats 获取机器的所有状态信息 func (c *Cli) GetAllStats() *Stats { stats := new(Stats) diff --git a/server/internal/machine/mcm/client_cache.go b/server/internal/machine/mcm/client_cache.go index e2924e5f..32186e3b 100644 --- a/server/internal/machine/mcm/client_cache.go +++ b/server/internal/machine/mcm/client_cache.go @@ -1,77 +1,44 @@ package mcm import ( + "context" "mayfly-go/pkg/pool" - "time" ) -var mcConnPool = make(map[string]pool.Pool) -var mcIdPool = make(map[uint64]pool.Pool) - -func init() { -} -func getMcPool(authCertName string, getMachine func(string) (*MachineInfo, error)) (pool.Pool, error) { - // 获取连接池,如果没有,则创建一个 - if p, ok := mcConnPool[authCertName]; !ok { - var err error - p, err = pool.NewChannelPool(&pool.Config{ - InitialCap: 1, //资源池初始连接数 - MaxCap: 10, //最大空闲连接数 - MaxIdle: 10, //最大并发连接数 - IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效 - Factory: func() (interface{}, error) { - mi, err := getMachine(authCertName) - if err != nil { - return nil, err - } - mi.Key = authCertName - return mi.Conn() - }, - Close: func(v interface{}) error { - v.(*Cli).Close() - return nil - }, - Ping: func(v interface{}) error { - return v.(*Cli).Ping() - }, - }) - if err != nil { - return nil, err - } - mcConnPool[authCertName] = p - return p, nil - } else { - return p, nil - } -} - -func PutMachineCli(c *Cli) { - if nil == c { - return - } - if p, ok := mcConnPool[c.Info.AuthCertName]; ok { - p.Put(c) - } - -} +var ( + poolGroup = pool.NewPoolGroup[*Cli]() +) // 从缓存中获取客户端信息,不存在则回调获取机器信息函数,并新建。 // @param 机器的授权凭证名 -func GetMachineCli(authCertName string, getMachine func(string) (*MachineInfo, error)) (*Cli, error) { - p, err := getMcPool(authCertName, getMachine) +func GetMachineCli(ctx context.Context, authCertName string, getMachine func(string) (*MachineInfo, error)) (*Cli, error) { + pool, err := poolGroup.GetCachePool(authCertName, func() (*Cli, error) { + mi, err := getMachine(authCertName) + if err != nil { + return nil, err + } + mi.Key = authCertName + return mi.Conn(ctx) + }) + if err != nil { return nil, err } // 从连接池中获取一个可用的连接 - c, err := p.Get() - if err != nil { - return nil, err - } - return c.(*Cli), nil + return pool.Get(ctx) } // 删除指定机器缓存客户端,并关闭客户端连接 func DeleteCli(id uint64) { - // 遍历所有机器连接实例,删除指定机器id关联的连接... - delete(mcIdPool, id) + for _, pool := range poolGroup.AllPool() { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() + conn, err := pool.Get(ctx) + if err != nil { + continue + } + if conn.Info.Id == id { + pool.Close() + } + } } diff --git a/server/internal/machine/mcm/machine.go b/server/internal/machine/mcm/machine.go index 6a11b992..63d97194 100644 --- a/server/internal/machine/mcm/machine.go +++ b/server/internal/machine/mcm/machine.go @@ -1,6 +1,7 @@ package mcm import ( + "context" "fmt" tagentity "mayfly-go/internal/tag/domain/entity" "mayfly-go/pkg/errorx" @@ -49,11 +50,11 @@ func (mi *MachineInfo) GetTunnelId() string { } // 连接 -func (mi *MachineInfo) Conn() (*Cli, error) { +func (mi *MachineInfo) Conn(ctx context.Context) (*Cli, error) { logx.Infof("the machine[%s] is connecting: %s:%d", mi.Name, mi.Ip, mi.Port) // 如果使用了ssh隧道,则修改机器ip port为暴露的ip port - err := mi.IfUseSshTunnelChangeIpPort(false) + err := mi.IfUseSshTunnelChangeIpPort(ctx, false) if err != nil { return nil, errorx.NewBiz("ssh tunnel connection failed: %s", err.Error()) } @@ -71,7 +72,7 @@ func (mi *MachineInfo) Conn() (*Cli, error) { } // 如果使用了ssh隧道,则修改机器ip port为暴露的ip port -func (mi *MachineInfo) IfUseSshTunnelChangeIpPort(out bool) error { +func (mi *MachineInfo) IfUseSshTunnelChangeIpPort(ctx context.Context, out bool) error { if !mi.UseSshTunnel() { return nil } @@ -82,8 +83,9 @@ func (mi *MachineInfo) IfUseSshTunnelChangeIpPort(out bool) error { mi.Id = uint64(time.Now().Nanosecond()) } - sshTunnelMachine, err := GetSshTunnelMachine(int(mi.SshTunnelMachine.Id), func(u uint64) (*MachineInfo, error) { - return mi.SshTunnelMachine, nil + stm := mi.SshTunnelMachine + sshTunnelMachine, err := GetSshTunnelMachine(ctx, int(stm.Id), func(u uint64) (*MachineInfo, error) { + return stm, nil }) if err != nil { return err @@ -102,7 +104,7 @@ func (mi *MachineInfo) IfUseSshTunnelChangeIpPort(out bool) error { mi.Ip = exposeIp mi.Port = exposePort // 代理之后置空跳板机信息,防止重复跳 - mi.TempSshMachineId = mi.SshTunnelMachine.Id + mi.TempSshMachineId = stm.Id mi.SshTunnelMachine = nil return nil } diff --git a/server/internal/machine/mcm/sshtunnel.go b/server/internal/machine/mcm/sshtunnel.go index c7a5f09e..ded0c395 100644 --- a/server/internal/machine/mcm/sshtunnel.go +++ b/server/internal/machine/mcm/sshtunnel.go @@ -1,6 +1,7 @@ package mcm import ( + "context" "errors" "fmt" "io" @@ -9,7 +10,6 @@ import ( "mayfly-go/pkg/utils/netx" "net" "sync" - "time" "golang.org/x/crypto/ssh" ) @@ -20,7 +20,7 @@ var ( // 所有检测ssh隧道机器是否被使用的函数 checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc - tunnelPool = make(map[int]pool.Pool) + tunnelPool = make(map[int]pool.Pool[*SshTunnelMachine]) ) // 检查ssh隧道机器是否有被使用 @@ -43,11 +43,36 @@ type SshTunnelMachine struct { tunnels map[string]*Tunnel // 隧道id -> 隧道 } +/******************* pool.Conn impl *******************/ + func (stm *SshTunnelMachine) Ping() error { _, _, err := stm.SshClient.Conn.SendRequest("ping", true, nil) return err } +func (stm *SshTunnelMachine) Close() error { + stm.mutex.Lock() + defer stm.mutex.Unlock() + + for id, tunnel := range stm.tunnels { + if tunnel != nil { + tunnel.Close() + delete(stm.tunnels, id) + } + } + + if stm.SshClient != nil { + logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId) + err := stm.SshClient.Close() + if err != nil { + logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error()) + } + } + delete(tunnelPool, stm.machineId) + + return nil +} + func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (exposedIp string, exposedPort int, err error) { stm.mutex.Lock() defer stm.mutex.Unlock() @@ -92,84 +117,44 @@ func (stm *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn, return stm.SshClient.Dial(network, addr) } -func (stm *SshTunnelMachine) Close() { - stm.mutex.Lock() - defer stm.mutex.Unlock() - - for id, tunnel := range stm.tunnels { - if tunnel != nil { - tunnel.Close() - delete(stm.tunnels, id) - } - } - - if stm.SshClient != nil { - logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId) - err := stm.SshClient.Close() - if err != nil { - logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error()) - } - } - delete(tunnelPool, stm.machineId) -} - -func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error)) (pool.Pool, error) { +func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error)) (pool.Pool[*SshTunnelMachine], error) { // 获取连接池,如果没有,则创建一个 - if p, ok := tunnelPool[machineId]; !ok { - var err error - p, err = pool.NewChannelPool(&pool.Config{ - InitialCap: 1, //资源池初始连接数 - MaxCap: 10, //最大空闲连接数 - MaxIdle: 10, //最大并发连接数 - IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效 - Factory: func() (interface{}, error) { - mi, err := getMachine(uint64(machineId)) - if err != nil { - return nil, err - } - if mi == nil { - return nil, errors.New("error get machine info") - } - sshClient, err := GetSshClient(mi, nil) - if err != nil { - return nil, err - } - stm := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[string]*Tunnel{}, mi: mi} - logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port) + if p, ok := tunnelPool[machineId]; ok { + return p, nil + } - return stm, err - }, - Close: func(v interface{}) error { - v.(*SshTunnelMachine).Close() - return nil - }, - Ping: func(v interface{}) error { - return v.(*SshTunnelMachine).Ping() - }, - }) + p := pool.NewChannelPool(func() (*SshTunnelMachine, error) { + mi, err := getMachine(uint64(machineId)) if err != nil { return nil, err } - tunnelPool[machineId] = p - return p, nil - } else { - return p, nil - } + if mi == nil { + return nil, errors.New("error get machine info") + } + sshClient, err := GetSshClient(mi, nil) + if err != nil { + return nil, err + } + stm := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[string]*Tunnel{}, mi: mi} + logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port) + + return stm, err + }, pool.WithOnPoolClose(func() error { + delete(tunnelPool, machineId) + return nil + })) + tunnelPool[machineId] = p + return p, nil } // 获取ssh隧道机器,方便统一管理充当ssh隧道的机器,避免创建多个ssh client -func GetSshTunnelMachine(machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) { +func GetSshTunnelMachine(ctx context.Context, machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) { p, err := getTunnelPool(machineId, getMachine) if err != nil { return nil, err } // 从连接池中获取一个可用的连接 - c, err := p.Get() - if err != nil { - return nil, err - } - - return c.(*SshTunnelMachine), nil + return p.Get(ctx) } // 关闭ssh隧道机器的指定隧道 diff --git a/server/internal/mongo/api/mongo.go b/server/internal/mongo/api/mongo.go index 362ed7b7..48e023c5 100644 --- a/server/internal/mongo/api/mongo.go +++ b/server/internal/mongo/api/mongo.go @@ -8,7 +8,6 @@ import ( "mayfly-go/internal/mongo/application" "mayfly-go/internal/mongo/domain/entity" "mayfly-go/internal/mongo/imsg" - "mayfly-go/internal/mongo/mgm" "mayfly-go/internal/pkg/consts" tagapp "mayfly-go/internal/tag/application" tagentity "mayfly-go/internal/tag/domain/entity" @@ -126,9 +125,8 @@ func (m *Mongo) DeleteMongo(rc *req.Ctx) { } func (m *Mongo) Databases(rc *req.Ctx) { - conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc)) + conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) - defer mgm.PutMongoConn(conn) res, err := conn.Cli.ListDatabases(context.TODO(), bson.D{}) biz.ErrIsNilAppendErr(err, "get mongo dbs error: %s") @@ -136,9 +134,8 @@ func (m *Mongo) Databases(rc *req.Ctx) { } func (m *Mongo) Collections(rc *req.Ctx) { - conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc)) + conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) - defer mgm.PutMongoConn(conn) global.EventBus.Publish(rc.MetaCtx, event.EventTopicResourceOp, conn.Info.CodePath[0]) @@ -154,9 +151,8 @@ func (m *Mongo) RunCommand(rc *req.Ctx) { commandForm := new(form.MongoRunCommand) req.BindJsonAndValid(rc, commandForm) - conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc)) + conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) - defer mgm.PutMongoConn(conn) rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm) @@ -185,9 +181,8 @@ func (m *Mongo) RunCommand(rc *req.Ctx) { func (m *Mongo) FindCommand(rc *req.Ctx) { commandForm := req.BindJsonAndValid(rc, new(form.MongoFindCommand)) - conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc)) + conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) - defer mgm.PutMongoConn(conn) cli := conn.Cli @@ -221,9 +216,8 @@ func (m *Mongo) FindCommand(rc *req.Ctx) { func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) { commandForm := req.BindJsonAndValid(rc, new(form.MongoUpdateByIdCommand)) - conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc)) + conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) - defer mgm.PutMongoConn(conn) rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm) @@ -246,9 +240,8 @@ func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) { func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) { commandForm := req.BindJsonAndValid(rc, new(form.MongoUpdateByIdCommand)) - conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc)) + conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) - defer mgm.PutMongoConn(conn) rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm) @@ -270,9 +263,8 @@ func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) { func (m *Mongo) InsertOneCommand(rc *req.Ctx) { commandForm := req.BindJsonAndValid(rc, new(form.MongoInsertCommand)) - conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc)) + conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc)) biz.ErrIsNil(err) - defer mgm.PutMongoConn(conn) rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm) diff --git a/server/internal/mongo/application/mongo.go b/server/internal/mongo/application/mongo.go index 0df1e7d8..f2afc3fe 100644 --- a/server/internal/mongo/application/mongo.go +++ b/server/internal/mongo/application/mongo.go @@ -31,7 +31,7 @@ type Mongo interface { // 获取mongo连接实例 // @param id mongo id - GetMongoConn(id uint64) (*mgm.MongoConn, error) + GetMongoConn(ctx context.Context, id uint64) (*mgm.MongoConn, error) } type mongoAppImpl struct { @@ -131,8 +131,8 @@ func (d *mongoAppImpl) SaveMongo(ctx context.Context, m *entity.Mongo, tagCodePa }) } -func (d *mongoAppImpl) GetMongoConn(id uint64) (*mgm.MongoConn, error) { - return mgm.GetMongoConn(id, func() (*mgm.MongoInfo, error) { +func (d *mongoAppImpl) GetMongoConn(ctx context.Context, id uint64) (*mgm.MongoConn, error) { + return mgm.GetMongoConn(ctx, id, func() (*mgm.MongoInfo, error) { me, err := d.GetById(id) if err != nil { return nil, errorx.NewBiz("mongo not found") diff --git a/server/internal/mongo/mgm/conn.go b/server/internal/mongo/mgm/conn.go index b3d4619d..a5d82dae 100644 --- a/server/internal/mongo/mgm/conn.go +++ b/server/internal/mongo/mgm/conn.go @@ -14,11 +14,19 @@ type MongoConn struct { Cli *mongo.Client } -func (mc *MongoConn) Close() { +/******************* pool.Conn impl *******************/ + +func (mc *MongoConn) Close() error { if mc.Cli != nil { if err := mc.Cli.Disconnect(context.Background()); err != nil { logx.Errorf("关闭mongo实例[%s]连接失败: %s", mc.Id, err) + return err } mc.Cli = nil } + return nil +} + +func (mc *MongoConn) Ping() error { + return mc.Cli.Ping(context.Background(), nil) } diff --git a/server/internal/mongo/mgm/conn_cache.go b/server/internal/mongo/mgm/conn_cache.go index 57823fea..ad5dd239 100644 --- a/server/internal/mongo/mgm/conn_cache.go +++ b/server/internal/mongo/mgm/conn_cache.go @@ -3,78 +3,33 @@ package mgm import ( "context" "mayfly-go/pkg/pool" - "time" ) -var connPool = make(map[string]pool.Pool) +var ( + poolGroup = pool.NewPoolGroup[*MongoConn]() +) -func init() { -} - -func getPool(mongoId uint64, getMongoInfo func() (*MongoInfo, error)) (pool.Pool, error) { - connId := getConnId(mongoId) - - // 获取连接池,如果没有,则创建一个 - if p, ok := connPool[connId]; !ok { - var err error - p, err = pool.NewChannelPool(&pool.Config{ - InitialCap: 1, //资源池初始连接数 - MaxCap: 10, //最大空闲连接数 - MaxIdle: 10, //最大并发连接数 - IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效 - Factory: func() (interface{}, error) { - // 若缓存中不存在,则从回调函数中获取MongoInfo - mi, err := getMongoInfo() - if err != nil { - return nil, err - } - - // 连接mongo - return mi.Conn() - }, - Close: func(v interface{}) error { - v.(*MongoConn).Close() - return nil - }, - Ping: func(v interface{}) error { - return v.(*MongoConn).Cli.Ping(context.Background(), nil) - }, - }) +// 从缓存中获取mongo连接信息, 若缓存中不存在则会使用回调函数获取mongoInfo进行连接并缓存 +func GetMongoConn(ctx context.Context, mongoId uint64, getMongoInfo func() (*MongoInfo, error)) (*MongoConn, error) { + pool, err := poolGroup.GetCachePool(getConnId(mongoId), func() (*MongoConn, error) { + // 若缓存中不存在,则从回调函数中获取MongoInfo + mi, err := getMongoInfo() if err != nil { return nil, err } - connPool[connId] = p - return p, nil - } else { - return p, nil - } -} -func PutMongoConn(c *MongoConn) { - if nil == c { - return - } - if p, ok := connPool[getConnId(c.Info.Id)]; ok { - p.Put(c) - } -} + // 连接mongo + return mi.Conn() + }) -// 从缓存中获取mongo连接信息, 若缓存中不存在则会使用回调函数获取mongoInfo进行连接并缓存 -func GetMongoConn(mongoId uint64, getMongoInfo func() (*MongoInfo, error)) (*MongoConn, error) { - p, err := getPool(mongoId, getMongoInfo) if err != nil { return nil, err } // 从连接池中获取一个可用的连接 - c, err := p.Get() - if err != nil { - return nil, err - } - return c.(*MongoConn), nil + return pool.Get(ctx) } // 关闭连接,并移除缓存连接 func CloseConn(mongoId uint64) { - connId := getConnId(mongoId) - delete(connPool, connId) + poolGroup.Close(getConnId(mongoId)) } diff --git a/server/internal/mongo/mgm/info.go b/server/internal/mongo/mgm/info.go index adcf0d05..46a79540 100644 --- a/server/internal/mongo/mgm/info.go +++ b/server/internal/mongo/mgm/info.go @@ -58,7 +58,7 @@ type MongoSshDialer struct { } func (sd *MongoSshDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - stm, err := machineapp.GetMachineApp().GetSshTunnelMachine(sd.machineId) + stm, err := machineapp.GetMachineApp().GetSshTunnelMachine(ctx, sd.machineId) if err != nil { return nil, err } diff --git a/server/internal/redis/api/redis.go b/server/internal/redis/api/redis.go index 6e7c0edd..065ab718 100644 --- a/server/internal/redis/api/redis.go +++ b/server/internal/redis/api/redis.go @@ -150,9 +150,8 @@ func (r *Redis) DeleteRedis(rc *req.Ctx) { } func (r *Redis) RedisInfo(rc *req.Ctx) { - ri, err := r.redisApp.GetRedisConn(uint64(rc.PathParamInt("id")), 0) + ri, err := r.redisApp.GetRedisConn(rc.MetaCtx, uint64(rc.PathParamInt("id")), 0) biz.ErrIsNil(err) - defer rdm.PutRedisConn(ri) section := rc.Query("section") mode := ri.Info.Mode @@ -228,9 +227,8 @@ func (r *Redis) RedisInfo(rc *req.Ctx) { } func (r *Redis) ClusterInfo(rc *req.Ctx) { - ri, err := r.redisApp.GetRedisConn(uint64(rc.PathParamInt("id")), 0) + ri, err := r.redisApp.GetRedisConn(rc.MetaCtx, uint64(rc.PathParamInt("id")), 0) biz.ErrIsNil(err) - defer rdm.PutRedisConn(ri) biz.IsEquals(ri.Info.Mode, rdm.ClusterMode, "non-cluster mode") info, _ := ri.ClusterCli.ClusterInfo(context.Background()).Result() @@ -281,9 +279,9 @@ func (r *Redis) checkKeyAndGetRedisConn(rc *req.Ctx) (*rdm.RedisConn, string) { } func (r *Redis) getRedisConn(rc *req.Ctx) *rdm.RedisConn { - ri, err := r.redisApp.GetRedisConn(getIdAndDbNum(rc)) + id, db := getIdAndDbNum(rc) + ri, err := r.redisApp.GetRedisConn(rc.MetaCtx, id, db) biz.ErrIsNil(err) - defer rdm.PutRedisConn(ri) biz.ErrIsNilAppendErr(r.tagApp.CanAccess(rc.GetLoginAccount().Id, ri.Info.CodePath...), "%s") return ri diff --git a/server/internal/redis/application/redis.go b/server/internal/redis/application/redis.go index 701a180f..e381dfad 100644 --- a/server/internal/redis/application/redis.go +++ b/server/internal/redis/application/redis.go @@ -46,7 +46,7 @@ type Redis interface { // 获取数据库连接实例 // id: 数据库实例id // db: 库号 - GetRedisConn(id uint64, db int) (*rdm.RedisConn, error) + GetRedisConn(ctx context.Context, id uint64, db int) (*rdm.RedisConn, error) // 执行redis命令 RunCmd(ctx context.Context, redisConn *rdm.RedisConn, cmdParam *dto.RunCmd) (any, error) @@ -196,8 +196,8 @@ func (r *redisAppImpl) Delete(ctx context.Context, id uint64) error { } // 获取数据库连接实例 -func (r *redisAppImpl) GetRedisConn(id uint64, db int) (*rdm.RedisConn, error) { - return rdm.GetRedisConn(id, db, func() (*rdm.RedisInfo, error) { +func (r *redisAppImpl) GetRedisConn(ctx context.Context, id uint64, db int) (*rdm.RedisConn, error) { + return rdm.GetRedisConn(ctx, id, db, func() (*rdm.RedisInfo, error) { // 缓存不存在,则回调获取redis信息 re, err := r.GetById(id) if err != nil { @@ -258,7 +258,7 @@ func (r *redisAppImpl) FlowBizHandle(ctx context.Context, bizHandleParam *flowap return nil, errorx.NewBiz("failed to parse the business form information: %s", err.Error()) } - redisConn, err := r.GetRedisConn(runCmdParam.Id, runCmdParam.Db) + redisConn, err := r.GetRedisConn(ctx, runCmdParam.Id, runCmdParam.Db) if err != nil { return nil, err } diff --git a/server/internal/redis/rdm/conn.go b/server/internal/redis/rdm/conn.go index 9e967053..20b08b4e 100644 --- a/server/internal/redis/rdm/conn.go +++ b/server/internal/redis/rdm/conn.go @@ -17,6 +17,34 @@ type RedisConn struct { ClusterCli *redis.ClusterClient } +/******************* pool.Conn impl *******************/ + +func (r *RedisConn) Close() error { + mode := r.Info.Mode + if mode == StandaloneMode || mode == SentinelMode { + if err := r.Cli.Close(); err != nil { + logx.Errorf("close redis standalone instance [%s] connection failed: %s", r.Id, err.Error()) + return err + } + r.Cli = nil + return nil + } + + if mode == ClusterMode { + if err := r.ClusterCli.Close(); err != nil { + logx.Errorf("close redis cluster instance [%s] connection failed: %s", r.Id, err.Error()) + return err + } + r.ClusterCli = nil + } + return nil +} + +func (r *RedisConn) Ping() error { + _, err := r.Cli.Ping(context.Background()).Result() + return err +} + // 获取命令执行接口的具体实现 func (r *RedisConn) GetCmdable() redis.Cmdable { redisMode := r.Info.Mode @@ -45,21 +73,3 @@ func (r *RedisConn) RunCmd(ctx context.Context, args ...any) (any, error) { } return nil, errorx.NewBiz("redis mode error") } - -func (r *RedisConn) Close() { - mode := r.Info.Mode - if mode == StandaloneMode || mode == SentinelMode { - if err := r.Cli.Close(); err != nil { - logx.Errorf("close redis standalone instance [%s] connection failed: %s", r.Id, err.Error()) - } - r.Cli = nil - return - } - - if mode == ClusterMode { - if err := r.ClusterCli.Close(); err != nil { - logx.Errorf("close redis cluster instance [%s] connection failed: %s", r.Id, err.Error()) - } - r.ClusterCli = nil - } -} diff --git a/server/internal/redis/rdm/conn_cache.go b/server/internal/redis/rdm/conn_cache.go index 3be5cdef..4f5323b3 100644 --- a/server/internal/redis/rdm/conn_cache.go +++ b/server/internal/redis/rdm/conn_cache.go @@ -3,79 +3,36 @@ package rdm import ( "context" "mayfly-go/pkg/pool" - "time" ) func init() { } -var connPool = make(map[string]pool.Pool) +var ( + poolGroup = pool.NewPoolGroup[*RedisConn]() +) -func getPool(redisId uint64, db int, getRedisInfo func() (*RedisInfo, error)) (pool.Pool, error) { - connId := getConnId(redisId, db) - // 获取连接池,如果没有,则创建一个 - if p, ok := connPool[connId]; !ok { - var err error - p, err = pool.NewChannelPool(&pool.Config{ - InitialCap: 1, //资源池初始连接数 - MaxCap: 10, //最大空闲连接数 - MaxIdle: 10, //最大并发连接数 - IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效 - Factory: func() (interface{}, error) { - // 若缓存中不存在,则从回调函数中获取RedisInfo - ri, err := getRedisInfo() - if err != nil { - return nil, err - } - // 连接数据库 - return ri.Conn() - }, - Close: func(v interface{}) error { - v.(*RedisConn).Close() - return nil - }, - Ping: func(v interface{}) error { - _, err := v.(*RedisConn).Cli.Ping(context.Background()).Result() - return err - }, - }) +// 从缓存中获取redis连接信息, 若缓存中不存在则会使用回调函数获取redisInfo进行连接并缓存 +func GetRedisConn(ctx context.Context, redisId uint64, db int, getRedisInfo func() (*RedisInfo, error)) (*RedisConn, error) { + p, err := poolGroup.GetCachePool(getConnId(redisId, db), func() (*RedisConn, error) { + // 若缓存中不存在,则从回调函数中获取RedisInfo + ri, err := getRedisInfo() if err != nil { return nil, err } - connPool[connId] = p - return p, nil - } else { - return p, nil - } -} + // 连接数据库 + return ri.Conn() + }) -func PutRedisConn(c *RedisConn) { - if nil == c { - return - } - - if p, ok := connPool[getConnId(c.Info.Id, c.Info.Db)]; ok { - p.Put(c) - } -} - -// 从缓存中获取redis连接信息, 若缓存中不存在则会使用回调函数获取redisInfo进行连接并缓存 -func GetRedisConn(redisId uint64, db int, getRedisInfo func() (*RedisInfo, error)) (*RedisConn, error) { - p, err := getPool(redisId, db, getRedisInfo) if err != nil { return nil, err } // 从连接池中获取一个可用的连接 - c, err := p.Get() - if err != nil { - return nil, err - } - // 用完后记的放回连接池 - return c.(*RedisConn), nil + return p.Get(ctx) } // 移除redis连接缓存并关闭redis连接 func CloseConn(id uint64, db int) { - delete(connPool, getConnId(id, db)) + poolGroup.Close(getConnId(id, db)) } diff --git a/server/internal/redis/rdm/info.go b/server/internal/redis/rdm/info.go index daf40443..109a5ae9 100644 --- a/server/internal/redis/rdm/info.go +++ b/server/internal/redis/rdm/info.go @@ -139,8 +139,8 @@ func (re *RedisInfo) connSentinel() (*RedisConn, error) { } func getRedisDialer(machineId int) func(ctx context.Context, network, addr string) (net.Conn, error) { - return func(_ context.Context, network, addr string) (net.Conn, error) { - sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(machineId) + return func(ctx context.Context, network, addr string) (net.Conn, error) { + sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(ctx, machineId) if err != nil { return nil, err } diff --git a/server/migration/migrations/v1_10.go b/server/migration/migrations/v1_10.go index 8a06c8e2..dea26e19 100644 --- a/server/migration/migrations/v1_10.go +++ b/server/migration/migrations/v1_10.go @@ -1,26 +1,26 @@ package migrations import ( - "github.com/go-gormigrate/gormigrate/v2" - "gorm.io/gorm" esentity "mayfly-go/internal/es/domain/entity" flowentity "mayfly-go/internal/flow/domain/entity" sysentity "mayfly-go/internal/sys/domain/entity" "mayfly-go/pkg/model" "time" + + "github.com/go-gormigrate/gormigrate/v2" + "gorm.io/gorm" ) func V1_10() []*gormigrate.Migration { var migrations []*gormigrate.Migration migrations = append(migrations, V1_10_0()...) - migrations = append(migrations, V1_10_1()...) return migrations } func V1_10_0() []*gormigrate.Migration { return []*gormigrate.Migration{ { - ID: "20250213-v1.10.0-flow-recode", + ID: "20250520-v1.10.0-flow-recode", Migrate: func(tx *gorm.DB) error { err := tx.AutoMigrate(&flowentity.Procdef{}, &flowentity.Procinst{}, @@ -32,19 +32,6 @@ func V1_10_0() []*gormigrate.Migration { return err } - return nil - }, - Rollback: func(tx *gorm.DB) error { - return nil - }, - }, - } -} -func V1_10_1() []*gormigrate.Migration { - return []*gormigrate.Migration{ - { - ID: "20250422-v1.10.1-es", - Migrate: func(tx *gorm.DB) error { // 添加实例表 entities := [...]any{ new(esentity.EsInstance), @@ -55,7 +42,7 @@ func V1_10_1() []*gormigrate.Migration { } } - // 添加菜单资源 + // 添加ES相关菜单资源 resources := []*sysentity.Resource{ { Model: model.Model{CreateModel: model.CreateModel{DeletedModel: model.DeletedModel{IdModel: model.IdModel{Id: 1745292787}}}}, @@ -65,7 +52,7 @@ func V1_10_1() []*gormigrate.Migration { Code: "/es", Type: 1, Meta: `{"icon":"icon es/es-color","isKeepAlive":true,"routeName":"ES"}`, - Weight: 7, + Weight: 50000001, }, { Model: model.Model{CreateModel: model.CreateModel{DeletedModel: model.DeletedModel{IdModel: model.IdModel{Id: 1745319348}}}}, diff --git a/server/pkg/pool/cache_pool.go b/server/pkg/pool/cache_pool.go new file mode 100644 index 00000000..b8e19e63 --- /dev/null +++ b/server/pkg/pool/cache_pool.go @@ -0,0 +1,209 @@ +package pool + +import ( + "context" + "mayfly-go/pkg/logx" + "mayfly-go/pkg/utils/stringx" + "sync" + "time" +) + +var CachePoolDefaultConfig = PoolConfig{ + MaxConns: 1, + IdleTimeout: 60 * time.Minute, + WaitTimeout: 10 * time.Second, + HealthCheckInterval: 10 * time.Minute, +} + +type cacheEntry[T Conn] struct { + conn T + lastActive time.Time +} + +type CachePool[T Conn] struct { + factory func() (T, error) + mu sync.RWMutex + cache map[string]*cacheEntry[T] // 使用字符串键的缓存 + config PoolConfig + closeCh chan struct{} + closed bool +} + +func NewCachePool[T Conn](factory func() (T, error), opts ...Option) *CachePool[T] { + config := CachePoolDefaultConfig + for _, opt := range opts { + opt(&config) + } + + p := &CachePool[T]{ + factory: factory, + cache: make(map[string]*cacheEntry[T]), + config: config, + closeCh: make(chan struct{}), + } + + go p.backgroundMaintenance() + return p +} + +// Get 获取连接(自动创建或复用缓存连接) +func (p *CachePool[T]) Get(ctx context.Context) (T, error) { + p.mu.Lock() + defer p.mu.Unlock() + var zero T + + if p.closed { + return zero, ErrPoolClosed + } + + // 1. 尝试从缓存中获取可用连接 + for key, entry := range p.cache { + if time.Since(entry.lastActive) <= p.config.IdleTimeout { + entry.lastActive = time.Now() // 更新活跃时间 + return entry.conn, nil + } + // 自动清理闲置超时的连接 + entry.conn.Close() + delete(p.cache, key) + } + + // 2. 创建新连接并缓存 + conn, err := p.factory() + if err != nil { + return zero, err + } + + p.cache[generateCacheKey()] = &cacheEntry[T]{ + conn: conn, + lastActive: time.Now(), + } + + return conn, nil +} + +// Put 将连接放回缓存 +func (p *CachePool[T]) Put(conn T) error { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return conn.Close() + } + + p.cache[generateCacheKey()] = &cacheEntry[T]{ + conn: conn, + lastActive: time.Now(), + } + + // 如果超出最大连接数,清理最久未使用的 + if len(p.cache) > p.config.MaxConns { + p.removeOldest() + } + + return nil +} + +// 移除最久未使用的连接 +func (p *CachePool[T]) removeOldest() { + var oldestKey string + var oldestTime time.Time + + for key, entry := range p.cache { + if oldestKey == "" || entry.lastActive.Before(oldestTime) { + oldestKey = key + oldestTime = entry.lastActive + } + } + + if oldestKey != "" { + p.cache[oldestKey].conn.Close() + delete(p.cache, oldestKey) + } +} + +// Close 关闭连接池 +func (p *CachePool[T]) Close() { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return + } + + p.closed = true + close(p.closeCh) + + for _, entry := range p.cache { + if err := entry.conn.Close(); err != nil { + logx.Errorf("cache pool - error closing connection: %v", err) + } + } + + // 触发关闭回调 + if p.config.OnPoolClose != nil { + p.config.OnPoolClose() + } + + p.cache = make(map[string]*cacheEntry[T]) +} + +// Resize 动态调整大小 +func (p *CachePool[T]) Resize(newSize int) { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed || newSize == p.config.MaxConns { + return + } + + p.config.MaxConns = newSize + + // 如果新大小小于当前缓存数量,清理多余的连接 + for len(p.cache) > newSize { + p.removeOldest() + } +} + +// Stats 获取统计信息 +func (p *CachePool[T]) Stats() PoolStats { + p.mu.RLock() + defer p.mu.RUnlock() + + return PoolStats{ + TotalConns: int32(len(p.cache)), + } +} + +// 后台维护协程 +func (p *CachePool[T]) backgroundMaintenance() { + ticker := time.NewTicker(p.config.HealthCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.cleanupIdle() + case <-p.closeCh: + return + } + } +} + +// 清理闲置超时的连接 +func (p *CachePool[T]) cleanupIdle() { + p.mu.Lock() + defer p.mu.Unlock() + + cutoff := time.Now().Add(-p.config.IdleTimeout) + for key, entry := range p.cache { + if entry.lastActive.Before(cutoff) { + entry.conn.Close() + delete(p.cache, key) + } + } +} + +// 生成缓存键 +func generateCacheKey() string { + return stringx.RandUUID() +} diff --git a/server/pkg/pool/chan_pool.go b/server/pkg/pool/chan_pool.go new file mode 100644 index 00000000..577cd3a5 --- /dev/null +++ b/server/pkg/pool/chan_pool.go @@ -0,0 +1,366 @@ +package pool + +import ( + "context" + "errors" + "mayfly-go/pkg/logx" + "mayfly-go/pkg/utils/anyx" + "sync" + "sync/atomic" + "time" +) + +var ChanPoolDefaultConfig = PoolConfig{ + MaxConns: 5, + IdleTimeout: 60 * time.Minute, + WaitTimeout: 10 * time.Second, + HealthCheckInterval: 10 * time.Minute, +} + +// ConnWrapper 封装连接及其元数据 +type ConnWrapper[T Conn] struct { + conn T + lastActive time.Time // 最后活跃时间 + isValid bool // 连接是否有效 +} + +func (w *ConnWrapper[T]) Ping() error { + if !w.isValid { + return errors.New("connection marked invalid") + } + return w.conn.Ping() +} + +func (w *ConnWrapper[T]) Close() error { + w.isValid = false + return w.conn.Close() +} + +// ChanPool 连接池结构 +type ChanPool[T Conn] struct { + mu sync.RWMutex + factory func() (T, error) + idleConns chan *ConnWrapper[T] + config PoolConfig + currentConns int32 + stats PoolStats + closeChan chan struct{} // 用于关闭健康检查 goroutine + closed bool // 关闭状态标识 +} + +// PoolStats 统计信息 +type PoolStats struct { + TotalConns int32 // 总连接数 + IdleConns int32 // 空闲连接数 + ActiveConns int32 // 活跃连接数 + WaitCount int64 // 等待连接次数 +} + +func NewChannelPool[T Conn](factory func() (T, error), opts ...Option) *ChanPool[T] { + // 1. 初始化配置(使用默认值 + Option 覆盖) + config := ChanPoolDefaultConfig + for _, opt := range opts { + opt(&config) + } + + // 2. 创建连接池 + p := &ChanPool[T]{ + factory: factory, + idleConns: make(chan *ConnWrapper[T], config.MaxConns), + config: config, + closeChan: make(chan struct{}), + } + + // 3. 启动健康检查 + go p.healthCheck() + return p +} + +func (p *ChanPool[T]) Get(ctx context.Context) (T, error) { + connChan := make(chan T, 1) + errChan := make(chan error, 1) + + go func() { + conn, err := p.get() + if err != nil { + errChan <- err + } else { + connChan <- conn + } + }() + + var zero T + select { + case <-ctx.Done(): + return zero, ctx.Err() + case err := <-errChan: + return zero, err + case conn := <-connChan: + // 启动监控协程 + go func() { + <-ctx.Done() + // 上下文被取消后,将连接放回连接池 + if err := p.Put(conn); err != nil { + logx.Errorf("Failed to return leaked connection: %v", err) + conn.Close() + atomic.AddInt32(&p.currentConns, -1) + } + }() + return conn, nil + } +} + +func (p *ChanPool[T]) get() (T, error) { + // 优先从 channel 获取空闲连接(无锁) + select { + case wrapper := <-p.idleConns: + atomic.AddInt32(&p.stats.IdleConns, -1) + atomic.AddInt32(&p.stats.ActiveConns, 1) + wrapper.lastActive = time.Now() + return wrapper.conn, nil + default: + return p.createConn() + } +} + +func (p *ChanPool[T]) createConn() (T, error) { + var zero T + + // 使用CAS保证原子性 + for { + current := atomic.LoadInt32(&p.currentConns) + if current >= int32(p.config.MaxConns) { + if p.config.WaitTimeout > 0 { + return p.waitForConn() + } + return zero, errors.New("connection pool exhausted") + } + + if atomic.CompareAndSwapInt32(&p.currentConns, current, current+1) { + break + } + } + + // 直接创建新连接 + conn, err := p.factory() + if err != nil { + atomic.AddInt32(&p.currentConns, -1) + return zero, err + } + + // 更新状态 + atomic.AddInt32(&p.stats.ActiveConns, 1) + return conn, nil +} + +// 新增等待连接方法 +func (p *ChanPool[T]) waitForConn() (T, error) { + var zero T + timeout := time.NewTimer(p.config.WaitTimeout) + defer timeout.Stop() + + for { + select { + case wrapper := <-p.idleConns: + if wrapper.isValid && wrapper.Ping() == nil { + atomic.AddInt32(&p.stats.IdleConns, -1) + atomic.AddInt32(&p.stats.ActiveConns, 1) + wrapper.lastActive = time.Now() + return wrapper.conn, nil + } + wrapper.Close() + atomic.AddInt32(&p.currentConns, -1) + case <-timeout.C: + atomic.AddInt64(&p.stats.WaitCount, 1) + return zero, errors.New("connection pool wait timeout") + default: + // 非阻塞检查后短暂休眠避免CPU空转 + time.Sleep(10 * time.Millisecond) + } + } +} + +func (p *ChanPool[T]) Put(conn T) error { + if anyx.IsBlank(conn) { + return nil + } + + // 快速路径 + select { + case p.idleConns <- &ConnWrapper[T]{conn: conn, lastActive: time.Now(), isValid: true}: + atomic.AddInt32(&p.stats.IdleConns, 1) + atomic.AddInt32(&p.stats.ActiveConns, -1) + return nil + default: + } + + // 慢速路径 + p.mu.Lock() + defer p.mu.Unlock() + + // 检查是否超过最大连接数 + if atomic.LoadInt32(&p.currentConns) > int32(p.config.MaxConns) { + conn.Close() + atomic.AddInt32(&p.currentConns, -1) + } else { + // 直接放入空闲队列 + select { + case p.idleConns <- &ConnWrapper[T]{conn: conn, lastActive: time.Now(), isValid: true}: + default: + conn.Close() + atomic.AddInt32(&p.currentConns, -1) + } + } + atomic.AddInt32(&p.stats.ActiveConns, -1) + + return nil +} + +func (p *ChanPool[T]) Close() { + p.mu.Lock() + if p.closed { + p.mu.Unlock() + return + } + p.closed = true + + // 1. 停止健康检查 + close(p.closeChan) + + // 2. 临时转移空闲连接 + idle := make([]*ConnWrapper[T], 0, len(p.idleConns)) + for len(p.idleConns) > 0 { + idle = append(idle, <-p.idleConns) + } + close(p.idleConns) // 安全关闭通道 + + p.mu.Unlock() // 提前释放锁,避免阻塞其他操作 + + // 3. 关闭所有连接(无需持有锁) + for _, wrapper := range idle { + wrapper.Close() + } + + // 4. 触发关闭回调 + if p.config.OnPoolClose != nil { + p.config.OnPoolClose() + } +} + +func (p *ChanPool[T]) healthCheck() { + ticker := time.NewTicker(p.config.HealthCheckInterval) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + p.checkIdleConns() + case <-p.closeChan: + return + } + } +} + +func (p *ChanPool[T]) checkIdleConns() { + p.mu.Lock() + defer p.mu.Unlock() + + if p.closed { + return + } + + idle := make([]*ConnWrapper[T], 0, len(p.idleConns)) + for len(p.idleConns) > 0 { + idle = append(idle, <-p.idleConns) + } + + now := time.Now() + for _, wrapper := range idle { + if now.Sub(wrapper.lastActive) > p.config.IdleTimeout || wrapper.Ping() != nil { + wrapper.Close() + atomic.AddInt32(&p.currentConns, -1) + } else { + select { + case p.idleConns <- wrapper: + default: + wrapper.Close() + atomic.AddInt32(&p.currentConns, -1) + } + } + } +} + +func (p *ChanPool[T]) Resize(newMaxConns int) { + p.mu.Lock() + defer p.mu.Unlock() + + oldMax := p.config.MaxConns + p.config.MaxConns = newMaxConns + + // 缩小连接池:关闭多余的空闲连接 + if newMaxConns < oldMax { + toClose := oldMax - newMaxConns + closed := 0 + + // 非阻塞取出待关闭的连接 + var wrappers []*ConnWrapper[T] + for len(p.idleConns) > 0 && closed < toClose { + wrappers = append(wrappers, <-p.idleConns) + closed++ + } + + // 关闭连接并更新计数 + for _, wrapper := range wrappers { + wrapper.Close() + atomic.AddInt32(&p.currentConns, -1) + atomic.AddInt32(&p.stats.IdleConns, -1) + } + } + + // 重建空闲连接通道(无需迁移连接,因 channel 本身无状态) + p.idleConns = make(chan *ConnWrapper[T], newMaxConns) +} + +func (p *ChanPool[T]) CheckLeaks() []T { + p.mu.Lock() + defer p.mu.Unlock() + + var leaks []T + now := time.Now() + + // 检查所有空闲连接 + idle := make([]*ConnWrapper[T], 0, len(p.idleConns)) + for len(p.idleConns) > 0 { + idle = append(idle, <-p.idleConns) + } + + for _, wrapper := range idle { + // 判定泄漏条件:长期未使用且未被标记为活跃 + if now.Sub(wrapper.lastActive) > 10*p.config.IdleTimeout { + leaks = append(leaks, wrapper.conn) + wrapper.Close() + atomic.AddInt32(&p.currentConns, -1) + atomic.AddInt32(&p.stats.IdleConns, -1) + } else { + // 放回空闲池 + select { + case p.idleConns <- wrapper: + default: + wrapper.Close() + atomic.AddInt32(&p.currentConns, -1) + } + } + } + return leaks +} + +func (p *ChanPool[T]) Stats() PoolStats { + p.mu.RLock() + defer p.mu.RUnlock() + return PoolStats{ + TotalConns: atomic.LoadInt32(&p.currentConns), + IdleConns: int32(len(p.idleConns)), // 直接读取通道长度 + ActiveConns: atomic.LoadInt32(&p.stats.ActiveConns), + WaitCount: atomic.LoadInt64(&p.stats.WaitCount), + } +} diff --git a/server/pkg/pool/channel.go b/server/pkg/pool/channel.go deleted file mode 100644 index 59f703bc..00000000 --- a/server/pkg/pool/channel.go +++ /dev/null @@ -1,216 +0,0 @@ -package pool - -import ( - "errors" - "fmt" - "mayfly-go/pkg/logx" - "sync" - "time" - //"reflect" -) - -var ( - //ErrMaxActiveConnReached 连接池超限 - ErrMaxActiveConnReached = errors.New("MaxActiveConnReached") -) - -// Config 连接池相关配置 -type Config struct { - //连接池中拥有的最小连接数 - InitialCap int - //最大并发存活连接数 - MaxCap int - //最大空闲连接 - MaxIdle int - //生成连接的方法 - Factory func() (interface{}, error) - //关闭连接的方法 - Close func(interface{}) error - //检查连接是否有效的方法 - Ping func(interface{}) error - //连接最大空闲时间,超过该事件则将失效 - IdleTimeout time.Duration -} - -// channelPool 存放连接信息 -type channelPool struct { - mu sync.RWMutex - conns chan *idleConn - factory func() (interface{}, error) - close func(interface{}) error - ping func(interface{}) error - idleTimeout, waitTimeOut time.Duration - maxActive int - openingConns int -} - -type idleConn struct { - conn interface{} - t time.Time -} - -// NewChannelPool 初始化连接 -func NewChannelPool(poolConfig *Config) (Pool, error) { - if !(poolConfig.InitialCap <= poolConfig.MaxIdle && poolConfig.MaxCap >= poolConfig.MaxIdle && poolConfig.InitialCap >= 0) { - return nil, errors.New("invalid capacity settings") - } - if poolConfig.Factory == nil { - return nil, errors.New("invalid factory func settings") - } - if poolConfig.Close == nil { - return nil, errors.New("invalid close func settings") - } - - c := &channelPool{ - conns: make(chan *idleConn, poolConfig.MaxIdle), - factory: poolConfig.Factory, - close: poolConfig.Close, - idleTimeout: poolConfig.IdleTimeout, - maxActive: poolConfig.MaxCap, - openingConns: poolConfig.InitialCap, - } - - if poolConfig.Ping != nil { - c.ping = poolConfig.Ping - } - - for i := 0; i < poolConfig.InitialCap; i++ { - conn, err := c.factory() - if err != nil { - c.Release() - return nil, fmt.Errorf("factory is not able to fill the pool: %s", err) - } - c.conns <- &idleConn{conn: conn, t: time.Now()} - } - - return c, nil -} - -// getConns 获取所有连接 -func (c *channelPool) getConns() chan *idleConn { - c.mu.Lock() - conns := c.conns - c.mu.Unlock() - return conns -} - -// Get 从pool中取一个连接 -func (c *channelPool) Get() (interface{}, error) { - conns := c.getConns() - if conns == nil { - return nil, ErrClosed - } - for { - select { - case wrapConn := <-conns: - if wrapConn == nil { - return nil, ErrClosed - } - //判断是否超时,超时则丢弃 - if timeout := c.idleTimeout; timeout > 0 { - if wrapConn.t.Add(timeout).Before(time.Now()) { - //丢弃并关闭该连接 - c.Close(wrapConn.conn) - continue - } - } - //判断是否失效,失效则丢弃,如果用户没有设定 ping 方法,就不检查 - if c.ping != nil { - if err := c.Ping(wrapConn.conn); err != nil { - c.Close(wrapConn.conn) - continue - } - } - return wrapConn.conn, nil - default: - c.mu.Lock() - logx.Debugf("openConn %v %v", c.openingConns, c.maxActive) - defer c.mu.Unlock() - if c.openingConns >= c.maxActive { - return nil, ErrMaxActiveConnReached - } - if c.factory == nil { - return nil, ErrClosed - } - conn, err := c.factory() - if err != nil { - return nil, err - } - c.openingConns++ - return conn, nil - } - } -} - -// Put 将连接放回pool中 -func (c *channelPool) Put(conn interface{}) error { - if conn == nil { - return errors.New("connection is nil. rejecting") - } - - c.mu.Lock() - - if c.conns == nil { - c.mu.Unlock() - return c.Close(conn) - } - - select { - case c.conns <- &idleConn{conn: conn, t: time.Now()}: - c.mu.Unlock() - return nil - default: - c.mu.Unlock() - //连接池已满,直接关闭该连接 - return c.Close(conn) - } -} - -// Close 关闭单条连接 -func (c *channelPool) Close(conn interface{}) error { - if conn == nil { - return errors.New("connection is nil. rejecting") - } - c.mu.Lock() - defer c.mu.Unlock() - if c.close == nil { - return nil - } - c.openingConns-- - return c.close(conn) -} - -// Ping 检查单条连接是否有效 -func (c *channelPool) Ping(conn interface{}) error { - if conn == nil { - return errors.New("connection is nil. rejecting") - } - return c.ping(conn) -} - -// Release 释放连接池中所有连接 -func (c *channelPool) Release() { - c.mu.Lock() - conns := c.conns - c.conns = nil - c.factory = nil - c.ping = nil - closeFun := c.close - c.close = nil - c.mu.Unlock() - - if conns == nil { - return - } - - close(conns) - for wrapConn := range conns { - //log.Printf("Type %v\n",reflect.TypeOf(wrapConn.conn)) - _ = closeFun(wrapConn.conn) - } -} - -// Len 连接池中已有的连接 -func (c *channelPool) Len() int { - return len(c.getConns()) -} diff --git a/server/pkg/pool/config.go b/server/pkg/pool/config.go new file mode 100644 index 00000000..88550982 --- /dev/null +++ b/server/pkg/pool/config.go @@ -0,0 +1,55 @@ +package pool + +import ( + "errors" + "time" +) + +var ErrPoolClosed = errors.New("pool is closed") + +// PoolConfig 连接池配置 +type PoolConfig struct { + MaxConns int // 最大连接数 + IdleTimeout time.Duration // 空闲连接超时时间 + WaitTimeout time.Duration // 获取连接超时时间 + HealthCheckInterval time.Duration // 健康检查间隔 + OnPoolClose func() error // 连接池关闭时的回调 +} + +// Option 函数类型,用于配置 Pool +type Option func(*PoolConfig) + +// WithMaxConns 设置最大连接数 +func WithMaxConns(maxConns int) Option { + return func(c *PoolConfig) { + c.MaxConns = maxConns + } +} + +// WithIdleTimeout 设置空闲超时 +func WithIdleTimeout(timeout time.Duration) Option { + return func(c *PoolConfig) { + c.IdleTimeout = timeout + } +} + +// WithWaitTimeout 设置等待超时 +func WithWaitTimeout(timeout time.Duration) Option { + return func(c *PoolConfig) { + c.WaitTimeout = timeout + } +} + +// WithHealthCheckInterval 设置健康检查间隔 +func WithHealthCheckInterval(interval time.Duration) Option { + return func(c *PoolConfig) { + c.HealthCheckInterval = interval + } +} + +// WithOnPoolClose 设置连接池关闭回调 +func WithOnPoolClose(fn func() error) Option { + return func(c *PoolConfig) { + c.OnPoolClose = fn + } +} diff --git a/server/pkg/pool/group.go b/server/pkg/pool/group.go new file mode 100644 index 00000000..38515707 --- /dev/null +++ b/server/pkg/pool/group.go @@ -0,0 +1,76 @@ +package pool + +import ( + "mayfly-go/pkg/logx" + + "golang.org/x/sync/singleflight" +) + +type PoolGroup[T Conn] struct { + poolGroup map[string]Pool[T] + createGroup singleflight.Group +} + +func NewPoolGroup[T Conn]() *PoolGroup[T] { + return &PoolGroup[T]{ + poolGroup: make(map[string]Pool[T]), + createGroup: singleflight.Group{}, + } +} + +func (pg *PoolGroup[T]) GetOrCreate( + key string, + poolFactory func() Pool[T], + opts ...Option, +) (Pool[T], error) { + if p, ok := pg.poolGroup[key]; ok { + return p, nil + } + + v, err, _ := pg.createGroup.Do(key, func() (any, error) { + logx.Infof("pool group - create pool, key: %s", key) + p := poolFactory() + pg.poolGroup[key] = p + return p, nil + }) + + if err != nil { + return nil, err + } + + return v.(Pool[T]), nil +} + +// GetChanPool 获取或创建 ChannelPool 类型连接池 +func (pg *PoolGroup[T]) GetChanPool(key string, factory func() (T, error), opts ...Option) (Pool[T], error) { + return pg.GetOrCreate(key, func() Pool[T] { + return NewChannelPool(factory, opts...) + }, opts...) +} + +// GetCachePool 获取或创建 CachePool 类型连接池 +func (pg *PoolGroup[T]) GetCachePool(key string, factory func() (T, error), opts ...Option) (Pool[T], error) { + return pg.GetOrCreate(key, func() Pool[T] { + return NewCachePool(factory, opts...) + }, opts...) +} + +func (pg *PoolGroup[T]) Close(key string) error { + if p, ok := pg.poolGroup[key]; ok { + logx.Infof("pool group - close pool, key: %s", key) + p.Close() + pg.createGroup.Forget(key) + delete(pg.poolGroup, key) + } + return nil +} + +func (pg *PoolGroup[T]) CloseAll() { + for key := range pg.poolGroup { + pg.Close(key) + } +} + +func (pg *PoolGroup[T]) AllPool() map[string]Pool[T] { + return pg.poolGroup +} diff --git a/server/pkg/pool/pool.go b/server/pkg/pool/pool.go index 16030c9e..cab6ea03 100644 --- a/server/pkg/pool/pool.go +++ b/server/pkg/pool/pool.go @@ -1,21 +1,27 @@ package pool -import "errors" - -var ( - //ErrClosed 连接池已经关闭Error - ErrClosed = errors.New("pool is closed") +import ( + "context" ) -// Pool 基本方法 -type Pool interface { - Get() (interface{}, error) +// Conn 连接接口 +// 连接池的连接必须实现 Conn 接口 +type Conn interface { + // Close 关闭连接 + Close() error - Put(interface{}) error - - Close(interface{}) error - - Release() - - Len() int + // Ping 检查连接是否有效 + Ping() error +} + +// Pool 连接池接口 +type Pool[T Conn] interface { + // 核心方法 + Get(ctx context.Context) (T, error) + Put(T) error + Close() + + // 管理方法 + Resize(int) // 动态调整大小 + Stats() PoolStats // 获取统计信息 }