mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 00:10:25 +08:00 
			
		
		
		
	reafctor: pool
This commit is contained in:
		@@ -15,11 +15,14 @@
 | 
			
		||||
    <img src="https://img.shields.io/github/stars/dromara/mayfly-go.svg?style=social" alt="github star"/>
 | 
			
		||||
    <img src="https://img.shields.io/github/forks/dromara/mayfly-go.svg?style=social" alt="github fork"/>
 | 
			
		||||
  </a>
 | 
			
		||||
  <a href="https://github.com/dromara/mayfly-go" target="_blank">
 | 
			
		||||
    <img src="https://gitcode.com/dromara/mayfly-go/star/badge.svg" alt="github star"/>
 | 
			
		||||
  </a>
 | 
			
		||||
  <a href="https://hub.docker.com/r/mayflygo/mayfly-go/tags" target="_blank">
 | 
			
		||||
    <img src="https://img.shields.io/docker/pulls/mayflygo/mayfly-go.svg?label=docker%20pulls&color=fac858" alt="docker pulls"/>
 | 
			
		||||
  </a>
 | 
			
		||||
  <a href="https://github.com/golang/go" target="_blank">
 | 
			
		||||
    <img src="https://img.shields.io/badge/Golang-1.22%2B-yellow.svg" alt="golang"/>
 | 
			
		||||
    <img src="https://img.shields.io/badge/Golang-1.24%2B-yellow.svg" alt="golang"/>
 | 
			
		||||
  </a>
 | 
			
		||||
  <a href="https://cn.vuejs.org" target="_blank">
 | 
			
		||||
    <img src="https://img.shields.io/badge/Vue-3.x-green.svg" alt="vue">
 | 
			
		||||
@@ -106,7 +109,7 @@ http://go.mayfly.run
 | 
			
		||||
 | 
			
		||||
## 💌 支持作者
 | 
			
		||||
 | 
			
		||||
如果觉得项目不错,或者已经在使用了,希望你可以去 <a target="_blank" href="https://github.com/dromara/mayfly-go">Github</a> 或者 <a target="_blank" href="https://gitee.com/dromara/mayfly-go">Gitee</a> 帮我点个 ⭐ Star,这将是对我极大的鼓励与支持。
 | 
			
		||||
如果觉得项目不错,或者已经在使用了,希望你可以去 <a target="_blank" href="https://github.com/dromara/mayfly-go">Github</a> 或 <a target="_blank" href="https://gitee.com/dromara/mayfly-go">Gitee</a> 或 <a target="_blank" href="https://gitcode.com/dromara/mayfly-go">Gitcode</a> 帮我点个 ⭐ Star,这将是对我极大的鼓励与支持。
 | 
			
		||||
 | 
			
		||||
> 喝杯咖啡 ☕️ 或者来杯奶茶 🧋,让作者更有精神,写出更棒的代码!
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,6 @@ import (
 | 
			
		||||
	"mayfly-go/internal/db/application"
 | 
			
		||||
	"mayfly-go/internal/db/application/dto"
 | 
			
		||||
	"mayfly-go/internal/db/config"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/imsg"
 | 
			
		||||
@@ -140,10 +139,12 @@ func (d *Db) DeleteDb(rc *req.Ctx) {
 | 
			
		||||
func (d *Db) ExecSql(rc *req.Ctx) {
 | 
			
		||||
	form := req.BindJsonAndValid(rc, new(form.DbSqlExecForm))
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(rc.MetaCtx, time.Duration(config.GetDbms().SqlExecTl)*time.Second)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	dbId := getDbId(rc)
 | 
			
		||||
	dbConn, err := d.dbApp.GetDbConn(dbId, form.Db)
 | 
			
		||||
	dbConn, err := d.dbApp.GetDbConn(ctx, dbId, form.Db)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer dbm.PutDbConn(dbConn)
 | 
			
		||||
 | 
			
		||||
	biz.ErrIsNilAppendErr(d.tagApp.CanAccess(rc.GetLoginAccount().Id, dbConn.Info.CodePath...), "%s")
 | 
			
		||||
 | 
			
		||||
@@ -163,9 +164,6 @@ func (d *Db) ExecSql(rc *req.Ctx) {
 | 
			
		||||
		CheckFlow: true,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	ctx, cancel := context.WithTimeout(rc.MetaCtx, time.Duration(config.GetDbms().SqlExecTl)*time.Second)
 | 
			
		||||
	defer cancel()
 | 
			
		||||
 | 
			
		||||
	execRes, err := d.dbSqlExecApp.Exec(ctx, execReq)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	rc.ResData = execRes
 | 
			
		||||
@@ -194,9 +192,8 @@ func (d *Db) ExecSqlFile(rc *req.Ctx) {
 | 
			
		||||
	dbName := getDbName(rc)
 | 
			
		||||
	clientId := rc.Query("clientId")
 | 
			
		||||
 | 
			
		||||
	dbConn, err := d.dbApp.GetDbConn(dbId, dbName)
 | 
			
		||||
	dbConn, err := d.dbApp.GetDbConn(rc.MetaCtx, dbId, dbName)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer dbm.PutDbConn(dbConn)
 | 
			
		||||
	biz.ErrIsNilAppendErr(d.tagApp.CanAccess(rc.GetLoginAccount().Id, dbConn.Info.CodePath...), "%s")
 | 
			
		||||
	rc.ReqParam = fmt.Sprintf("filename: %s -> %s", filename, dbConn.Info.GetLogDesc())
 | 
			
		||||
 | 
			
		||||
@@ -228,9 +225,8 @@ func (d *Db) DumpSql(rc *req.Ctx) {
 | 
			
		||||
	needData := dumpType == "2" || dumpType == "3"
 | 
			
		||||
 | 
			
		||||
	la := rc.GetLoginAccount()
 | 
			
		||||
	dbConn, err := d.dbApp.GetDbConn(dbId, dbName)
 | 
			
		||||
	dbConn, err := d.dbApp.GetDbConn(rc.MetaCtx, dbId, dbName)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer dbm.PutDbConn(dbConn)
 | 
			
		||||
 | 
			
		||||
	biz.ErrIsNilAppendErr(d.tagApp.CanAccess(la.Id, dbConn.Info.CodePath...), "%s")
 | 
			
		||||
 | 
			
		||||
@@ -358,9 +354,8 @@ func (d *Db) CopyTable(rc *req.Ctx) {
 | 
			
		||||
	form := &form.DbCopyTableForm{}
 | 
			
		||||
	copy := req.BindJsonAndCopyTo[*dbi.DbCopyTable](rc, form, new(dbi.DbCopyTable))
 | 
			
		||||
 | 
			
		||||
	conn, err := d.dbApp.GetDbConn(form.Id, form.Db)
 | 
			
		||||
	conn, err := d.dbApp.GetDbConn(rc.MetaCtx, form.Id, form.Db)
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "copy table error: %s")
 | 
			
		||||
	defer dbm.PutDbConn(conn)
 | 
			
		||||
 | 
			
		||||
	err = conn.GetDialect().CopyTable(copy)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -382,8 +377,7 @@ func getDbName(rc *req.Ctx) string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Db) getDbConn(rc *req.Ctx) *dbi.DbConn {
 | 
			
		||||
	dc, err := d.dbApp.GetDbConn(getDbId(rc), getDbName(rc))
 | 
			
		||||
	dc, err := d.dbApp.GetDbConn(rc.MetaCtx, getDbId(rc), getDbName(rc))
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer dbm.PutDbConn(dc)
 | 
			
		||||
	return dc
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -93,7 +93,7 @@ func (d *Instance) TestConn(rc *req.Ctx) {
 | 
			
		||||
	form := &form.InstanceForm{}
 | 
			
		||||
	instance := req.BindJsonAndCopyTo[*entity.DbInstance](rc, form, new(entity.DbInstance))
 | 
			
		||||
 | 
			
		||||
	biz.ErrIsNil(d.instanceApp.TestConn(instance, form.AuthCerts[0]))
 | 
			
		||||
	biz.ErrIsNil(d.instanceApp.TestConn(rc.MetaCtx, instance, form.AuthCerts[0]))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SaveInstance 保存数据库实例信息
 | 
			
		||||
@@ -137,13 +137,13 @@ func (d *Instance) DeleteInstance(rc *req.Ctx) {
 | 
			
		||||
func (d *Instance) GetDatabaseNames(rc *req.Ctx) {
 | 
			
		||||
	form := &form.InstanceDbNamesForm{}
 | 
			
		||||
	instance := req.BindJsonAndCopyTo[*entity.DbInstance](rc, form, new(entity.DbInstance))
 | 
			
		||||
	res, err := d.instanceApp.GetDatabases(instance, form.AuthCert)
 | 
			
		||||
	res, err := d.instanceApp.GetDatabases(rc.MetaCtx, instance, form.AuthCert)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	rc.ResData = res
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Instance) GetDatabaseNamesByAc(rc *req.Ctx) {
 | 
			
		||||
	res, err := d.instanceApp.GetDatabasesByAc(rc.PathParam("ac"))
 | 
			
		||||
	res, err := d.instanceApp.GetDatabasesByAc(rc.MetaCtx, rc.PathParam("ac"))
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	rc.ResData = res
 | 
			
		||||
}
 | 
			
		||||
@@ -151,7 +151,7 @@ func (d *Instance) GetDatabaseNamesByAc(rc *req.Ctx) {
 | 
			
		||||
// 获取数据库实例server信息
 | 
			
		||||
func (d *Instance) GetDbServer(rc *req.Ctx) {
 | 
			
		||||
	instanceId := getInstanceId(rc)
 | 
			
		||||
	conn, err := d.dbApp.GetDbConnByInstanceId(instanceId)
 | 
			
		||||
	conn, err := d.dbApp.GetDbConnByInstanceId(rc.MetaCtx, instanceId)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	res, err := conn.GetMetadata().GetDbServer()
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,6 @@ import (
 | 
			
		||||
	"mayfly-go/internal/db/api/vo"
 | 
			
		||||
	"mayfly-go/internal/db/application"
 | 
			
		||||
	"mayfly-go/internal/db/application/dto"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/imsg"
 | 
			
		||||
	fileapp "mayfly-go/internal/file/application"
 | 
			
		||||
@@ -150,9 +149,8 @@ func (d *DbTransferTask) FileRun(rc *req.Ctx) {
 | 
			
		||||
	tFile, err := d.dbTransferFile.GetById(fm.Id)
 | 
			
		||||
	biz.IsTrue(tFile != nil && err == nil, "file not found")
 | 
			
		||||
 | 
			
		||||
	targetDbConn, err := d.dbApp.GetDbConn(fm.TargetDbId, fm.TargetDbName)
 | 
			
		||||
	targetDbConn, err := d.dbApp.GetDbConn(rc.MetaCtx, fm.TargetDbId, fm.TargetDbName)
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "failed to connect to the target database: %s")
 | 
			
		||||
	defer dbm.PutDbConn(targetDbConn)
 | 
			
		||||
 | 
			
		||||
	biz.ErrIsNilAppendErr(d.tagApp.CanAccess(rc.GetLoginAccount().Id, targetDbConn.Info.CodePath...), "%s")
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -40,10 +40,10 @@ type Db interface {
 | 
			
		||||
	// @param id 数据库id
 | 
			
		||||
	//
 | 
			
		||||
	// @param dbName 数据库名
 | 
			
		||||
	GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error)
 | 
			
		||||
	GetDbConn(ctx context.Context, dbId uint64, dbName string) (*dbi.DbConn, error)
 | 
			
		||||
 | 
			
		||||
	// 根据数据库实例id获取连接,随机返回该instanceId下已连接的conn,若不存在则是使用该instanceId关联的db进行连接并返回。
 | 
			
		||||
	GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error)
 | 
			
		||||
	GetDbConnByInstanceId(ctx context.Context, instanceId uint64) (*dbi.DbConn, error)
 | 
			
		||||
 | 
			
		||||
	// DumpDb dumpDb
 | 
			
		||||
	DumpDb(ctx context.Context, reqParam *dto.DumpDb) error
 | 
			
		||||
@@ -170,8 +170,8 @@ func (d *dbAppImpl) Delete(ctx context.Context, id uint64) error {
 | 
			
		||||
		})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) {
 | 
			
		||||
	return dbm.GetDbConn(dbId, dbName, func() (*dbi.DbInfo, error) {
 | 
			
		||||
func (d *dbAppImpl) GetDbConn(ctx context.Context, dbId uint64, dbName string) (*dbi.DbConn, error) {
 | 
			
		||||
	return dbm.GetDbConn(ctx, dbId, dbName, func() (*dbi.DbInfo, error) {
 | 
			
		||||
		db, err := d.GetById(dbId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, errorx.NewBiz("db not found")
 | 
			
		||||
@@ -198,8 +198,8 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) {
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error) {
 | 
			
		||||
	conn := dbm.GetDbConnByInstanceId(instanceId)
 | 
			
		||||
func (d *dbAppImpl) GetDbConnByInstanceId(ctx context.Context, instanceId uint64) (*dbi.DbConn, error) {
 | 
			
		||||
	conn := dbm.GetDbConnByInstanceId(ctx, instanceId)
 | 
			
		||||
	if conn != nil {
 | 
			
		||||
		return conn, nil
 | 
			
		||||
	}
 | 
			
		||||
@@ -214,7 +214,7 @@ func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error
 | 
			
		||||
 | 
			
		||||
	// 使用该实例关联的已配置数据库中的第一个库进行连接并返回
 | 
			
		||||
	firstDb := dbs[0]
 | 
			
		||||
	return d.GetDbConn(firstDb.Id, strings.Split(firstDb.Database, " ")[0])
 | 
			
		||||
	return d.GetDbConn(ctx, firstDb.Id, strings.Split(firstDb.Database, " ")[0])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error {
 | 
			
		||||
@@ -233,7 +233,7 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error {
 | 
			
		||||
	dbName := reqParam.DbName
 | 
			
		||||
	tables := reqParam.Tables
 | 
			
		||||
 | 
			
		||||
	dbConn, err := d.GetDbConn(dbId, dbName)
 | 
			
		||||
	dbConn, err := d.GetDbConn(ctx, dbId, dbName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,6 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/domain/repository"
 | 
			
		||||
@@ -110,7 +109,9 @@ func (app *dataSyncAppImpl) AddCronJob(ctx context.Context, taskEntity *entity.D
 | 
			
		||||
		taskId := taskEntity.Id
 | 
			
		||||
		if err := scheduler.AddFunByKey(key, taskEntity.TaskCron, func() {
 | 
			
		||||
			logx.Infof("start the data synchronization task: %d", taskId)
 | 
			
		||||
			if err := app.RunCronJob(ctx, taskId); err != nil {
 | 
			
		||||
			cancelCtx, cancelFunc := context.WithCancel(ctx)
 | 
			
		||||
			defer cancelFunc()
 | 
			
		||||
			if err := app.RunCronJob(cancelCtx, taskId); err != nil {
 | 
			
		||||
				logx.Errorf("the data synchronization task failed to execute at a scheduled time: %s", err.Error())
 | 
			
		||||
			}
 | 
			
		||||
		}); err != nil {
 | 
			
		||||
@@ -150,8 +151,7 @@ func (app *dataSyncAppImpl) RunCronJob(ctx context.Context, id uint64) error {
 | 
			
		||||
				logx.ErrorfContext(ctx, "data source connection unavailable: %s", err.Error())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			srcConn, err := app.dbApp.GetDbConn(uint64(task.SrcDbId), task.SrcDbName)
 | 
			
		||||
			defer dbm.PutDbConn(srcConn)
 | 
			
		||||
			srcConn, err := app.dbApp.GetDbConn(ctx, uint64(task.SrcDbId), task.SrcDbName)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logx.ErrorfContext(ctx, "failed to connect to the source database: %s", err.Error())
 | 
			
		||||
				return
 | 
			
		||||
@@ -205,16 +205,14 @@ func (app *dataSyncAppImpl) doDataSync(ctx context.Context, sql string, task *en
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 获取源数据库连接
 | 
			
		||||
	srcConn, err := app.dbApp.GetDbConn(uint64(task.SrcDbId), task.SrcDbName)
 | 
			
		||||
	defer dbm.PutDbConn(srcConn)
 | 
			
		||||
	srcConn, err := app.dbApp.GetDbConn(ctx, uint64(task.SrcDbId), task.SrcDbName)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return syncLog, errorx.NewBiz("failed to connect to the source database: %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 获取目标数据库连接
 | 
			
		||||
	targetConn, err := app.dbApp.GetDbConn(uint64(task.TargetDbId), task.TargetDbName)
 | 
			
		||||
	defer dbm.PutDbConn(targetConn)
 | 
			
		||||
	targetConn, err := app.dbApp.GetDbConn(ctx, uint64(task.TargetDbId), task.TargetDbName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return syncLog, errorx.NewBiz("failed to connect to the target database: %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -27,7 +27,7 @@ type Instance interface {
 | 
			
		||||
	// GetPageList 分页获取数据库实例
 | 
			
		||||
	GetPageList(condition *entity.InstanceQuery, orderBy ...string) (*model.PageResult[*entity.DbInstance], error)
 | 
			
		||||
 | 
			
		||||
	TestConn(instanceEntity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) error
 | 
			
		||||
	TestConn(ctx context.Context, instanceEntity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) error
 | 
			
		||||
 | 
			
		||||
	SaveDbInstance(ctx context.Context, instance *dto.SaveDbInstance) (uint64, error)
 | 
			
		||||
 | 
			
		||||
@@ -35,10 +35,10 @@ type Instance interface {
 | 
			
		||||
	Delete(ctx context.Context, id uint64) error
 | 
			
		||||
 | 
			
		||||
	// GetDatabases 获取数据库实例的所有数据库列表
 | 
			
		||||
	GetDatabases(entity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) ([]string, error)
 | 
			
		||||
	GetDatabases(ctx context.Context, entity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) ([]string, error)
 | 
			
		||||
 | 
			
		||||
	// GetDatabasesByAc 根据授权凭证名获取所有数据库名称列表
 | 
			
		||||
	GetDatabasesByAc(acName string) ([]string, error)
 | 
			
		||||
	GetDatabasesByAc(ctx context.Context, acName string) ([]string, error)
 | 
			
		||||
 | 
			
		||||
	// ToDbInfo 根据实例与授权凭证返回对应的DbInfo
 | 
			
		||||
	ToDbInfo(instance *entity.DbInstance, authCertName string, database string) (*dbi.DbInfo, error)
 | 
			
		||||
@@ -59,7 +59,7 @@ func (app *instanceAppImpl) GetPageList(condition *entity.InstanceQuery, orderBy
 | 
			
		||||
	return app.GetRepo().GetInstanceList(condition, orderBy...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *instanceAppImpl) TestConn(instanceEntity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) error {
 | 
			
		||||
func (app *instanceAppImpl) TestConn(ctx context.Context, instanceEntity *entity.DbInstance, authCert *tagentity.ResourceAuthCert) error {
 | 
			
		||||
	instanceEntity.Network = instanceEntity.GetNetwork()
 | 
			
		||||
 | 
			
		||||
	authCert, err := app.resourceAuthCertApp.GetRealAuthCert(authCert)
 | 
			
		||||
@@ -67,7 +67,7 @@ func (app *instanceAppImpl) TestConn(instanceEntity *entity.DbInstance, authCert
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbConn, err := dbm.Conn(app.toDbInfoByAc(instanceEntity, authCert, ""))
 | 
			
		||||
	dbConn, err := dbm.Conn(ctx, app.toDbInfoByAc(instanceEntity, authCert, ""))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
@@ -185,7 +185,7 @@ func (app *instanceAppImpl) Delete(ctx context.Context, instanceId uint64) error
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance, authCert *tagentity.ResourceAuthCert) ([]string, error) {
 | 
			
		||||
func (app *instanceAppImpl) GetDatabases(ctx context.Context, ed *entity.DbInstance, authCert *tagentity.ResourceAuthCert) ([]string, error) {
 | 
			
		||||
	if authCert.Id != 0 {
 | 
			
		||||
		// 密文可能被清除,故需要重新获取
 | 
			
		||||
		authCert, _ = app.resourceAuthCertApp.GetAuthCert(authCert.Name)
 | 
			
		||||
@@ -199,10 +199,10 @@ func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance, authCert *tagent
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return app.getDatabases(ed, authCert)
 | 
			
		||||
	return app.getDatabases(ctx, ed, authCert)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *instanceAppImpl) GetDatabasesByAc(acName string) ([]string, error) {
 | 
			
		||||
func (app *instanceAppImpl) GetDatabasesByAc(ctx context.Context, acName string) ([]string, error) {
 | 
			
		||||
	ac, err := app.resourceAuthCertApp.GetAuthCert(acName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errorx.NewBiz("db ac not found")
 | 
			
		||||
@@ -214,7 +214,7 @@ func (app *instanceAppImpl) GetDatabasesByAc(acName string) ([]string, error) {
 | 
			
		||||
		return nil, errorx.NewBiz("the db instance information for this ac does not exist")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return app.getDatabases(instance, ac)
 | 
			
		||||
	return app.getDatabases(ctx, instance, ac)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *instanceAppImpl) ToDbInfo(instance *entity.DbInstance, authCertName string, database string) (*dbi.DbInfo, error) {
 | 
			
		||||
@@ -226,11 +226,11 @@ func (app *instanceAppImpl) ToDbInfo(instance *entity.DbInstance, authCertName s
 | 
			
		||||
	return app.toDbInfoByAc(instance, ac, database), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *instanceAppImpl) getDatabases(instance *entity.DbInstance, ac *tagentity.ResourceAuthCert) ([]string, error) {
 | 
			
		||||
func (app *instanceAppImpl) getDatabases(ctx context.Context, instance *entity.DbInstance, ac *tagentity.ResourceAuthCert) ([]string, error) {
 | 
			
		||||
	instance.Network = instance.GetNetwork()
 | 
			
		||||
	dbi := app.toDbInfoByAc(instance, ac, "")
 | 
			
		||||
 | 
			
		||||
	dbConn, err := dbm.Conn(dbi)
 | 
			
		||||
	dbConn, err := dbm.Conn(ctx, dbi)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,6 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/application/dto"
 | 
			
		||||
	"mayfly-go/internal/db/config"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/sqlparser"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/sqlparser/sqlstmt"
 | 
			
		||||
@@ -283,8 +282,7 @@ func (d *dbSqlExecAppImpl) FlowBizHandle(ctx context.Context, bizHandleParam *fl
 | 
			
		||||
		return nil, errorx.NewBiz("failed to parse the business form information: %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dbConn, err := d.dbApp.GetDbConn(execSqlBizForm.DbId, execSqlBizForm.DbName)
 | 
			
		||||
	defer dbm.PutDbConn(dbConn)
 | 
			
		||||
	dbConn, err := d.dbApp.GetDbConn(ctx, execSqlBizForm.DbId, execSqlBizForm.DbName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -6,7 +6,6 @@ import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	"mayfly-go/internal/db/application/dto"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/sqlparser"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
@@ -209,8 +208,7 @@ func (app *dbTransferAppImpl) Run(ctx context.Context, taskId uint64, logId uint
 | 
			
		||||
 | 
			
		||||
	// 获取源库连接、目标库连接,判断连接可用性,否则记录日志:xx连接不可用
 | 
			
		||||
	// 获取源库表信息
 | 
			
		||||
	srcConn, err := app.dbApp.GetDbConn(uint64(task.SrcDbId), task.SrcDbName)
 | 
			
		||||
	defer dbm.PutDbConn(srcConn)
 | 
			
		||||
	srcConn, err := app.dbApp.GetDbConn(ctx, uint64(task.SrcDbId), task.SrcDbName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		app.EndTransfer(ctx, logId, taskId, "failed to obtain source db connection", err, nil)
 | 
			
		||||
		return
 | 
			
		||||
@@ -248,8 +246,7 @@ func (app *dbTransferAppImpl) Run(ctx context.Context, taskId uint64, logId uint
 | 
			
		||||
 | 
			
		||||
func (app *dbTransferAppImpl) transfer2Db(ctx context.Context, taskId uint64, logId uint64, task *entity.DbTransferTask, srcConn *dbi.DbConn, start time.Time, tables []dbi.Table) {
 | 
			
		||||
	// 获取目标库表信息
 | 
			
		||||
	targetConn, err := app.dbApp.GetDbConn(uint64(task.TargetDbId), task.TargetDbName)
 | 
			
		||||
	defer dbm.PutDbConn(targetConn)
 | 
			
		||||
	targetConn, err := app.dbApp.GetDbConn(ctx, uint64(task.TargetDbId), task.TargetDbName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		app.EndTransfer(ctx, logId, taskId, "failed to get target db connection", err, nil)
 | 
			
		||||
		return
 | 
			
		||||
 
 | 
			
		||||
@@ -21,6 +21,29 @@ type DbConn struct {
 | 
			
		||||
	db *sql.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/******************* pool.Conn impl *******************/
 | 
			
		||||
 | 
			
		||||
// 关闭连接
 | 
			
		||||
func (d *DbConn) Close() error {
 | 
			
		||||
	if d.db != nil {
 | 
			
		||||
		logx.Debugf("dbm - conn close, connId: %s", d.Id)
 | 
			
		||||
		if err := d.db.Close(); err != nil {
 | 
			
		||||
			logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		// TODO 关闭实例隧道会影响其他正在使用的连接,所以暂时不关闭
 | 
			
		||||
		//if d.Info.useSshTunnel {
 | 
			
		||||
		//	mcm.CloseSshTunnelMachine(d.Info.SshTunnelMachineId, fmt.Sprintf("db:%d", d.Info.Id))
 | 
			
		||||
		//}
 | 
			
		||||
		d.db = nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *DbConn) Ping() error {
 | 
			
		||||
	return d.db.Ping()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 执行数据库查询返回的列信息
 | 
			
		||||
type QueryColumn struct {
 | 
			
		||||
	Name string `json:"name"` // 列名
 | 
			
		||||
@@ -167,24 +190,6 @@ func (d *DbConn) Stats(ctx context.Context, execSql string, args ...any) sql.DBS
 | 
			
		||||
	return d.db.Stats()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 关闭连接
 | 
			
		||||
func (d *DbConn) Close() {
 | 
			
		||||
	if d.db != nil {
 | 
			
		||||
		if err := d.db.Close(); err != nil {
 | 
			
		||||
			logx.Errorf("关闭数据库实例[%s]连接失败: %s", d.Id, err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		// TODO 关闭实例隧道会影响其他正在使用的连接,所以暂时不关闭
 | 
			
		||||
		//if d.Info.useSshTunnel {
 | 
			
		||||
		//	mcm.CloseSshTunnelMachine(d.Info.SshTunnelMachineId, fmt.Sprintf("db:%d", d.Info.Id))
 | 
			
		||||
		//}
 | 
			
		||||
		d.db = nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *DbConn) Ping() error {
 | 
			
		||||
	return d.db.Ping()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 游标方式遍历查询rows, walkFn error不为nil, 则跳出遍历
 | 
			
		||||
func (d *DbConn) walkQueryRows(ctx context.Context, selectSql string, walkFn WalkQueryRowsFunc, args ...any) ([]*QueryColumn, error) {
 | 
			
		||||
	cancelCtx, cancelFunc := context.WithCancel(ctx)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package dbi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	machineapp "mayfly-go/internal/machine/application"
 | 
			
		||||
	"mayfly-go/internal/machine/mcm"
 | 
			
		||||
@@ -52,7 +53,7 @@ func (di *DbInfo) GetLogDesc() string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 连接数据库
 | 
			
		||||
func (di *DbInfo) Conn(meta Meta) (*DbConn, error) {
 | 
			
		||||
func (di *DbInfo) Conn(ctx context.Context, meta Meta) (*DbConn, error) {
 | 
			
		||||
	if meta == nil {
 | 
			
		||||
		return nil, errorx.NewBiz("the database meta information interface cannot be empty")
 | 
			
		||||
	}
 | 
			
		||||
@@ -66,7 +67,7 @@ func (di *DbInfo) Conn(meta Meta) (*DbConn, error) {
 | 
			
		||||
		di.Database = database
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	conn, err := meta.GetSqlDb(di)
 | 
			
		||||
	conn, err := meta.GetSqlDb(ctx, di)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logx.Errorf("db connection failed: %s:%d/%s, err:%s", di.Host, di.Port, database, err.Error())
 | 
			
		||||
		return nil, errorx.NewBiz("db connection failed: %s", err.Error())
 | 
			
		||||
@@ -83,7 +84,7 @@ func (di *DbInfo) Conn(meta Meta) (*DbConn, error) {
 | 
			
		||||
	// 最大连接周期,超过时间的连接就close
 | 
			
		||||
	// conn.SetConnMaxLifetime(100 * time.Second)
 | 
			
		||||
	// 设置最大连接数
 | 
			
		||||
	conn.SetMaxOpenConns(5)
 | 
			
		||||
	conn.SetMaxOpenConns(6)
 | 
			
		||||
	// 设置闲置连接数
 | 
			
		||||
	conn.SetMaxIdleConns(1)
 | 
			
		||||
	dbc.db = conn
 | 
			
		||||
@@ -93,10 +94,10 @@ func (di *DbInfo) Conn(meta Meta) (*DbConn, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 如果使用了ssh隧道,将其host port改变其本地映射host port
 | 
			
		||||
func (di *DbInfo) IfUseSshTunnelChangeIpPort() error {
 | 
			
		||||
func (di *DbInfo) IfUseSshTunnelChangeIpPort(ctx context.Context) error {
 | 
			
		||||
	// 开启ssh隧道
 | 
			
		||||
	if di.SshTunnelMachineId > 0 {
 | 
			
		||||
		sshTunnelMachine, err := GetSshTunnel(di.SshTunnelMachineId)
 | 
			
		||||
		sshTunnelMachine, err := GetSshTunnel(ctx, di.SshTunnelMachineId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
@@ -133,8 +134,8 @@ func (di *DbInfo) GetDatabase() string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据ssh tunnel机器id返回ssh tunnel
 | 
			
		||||
func GetSshTunnel(sshTunnelMachineId int) (*mcm.SshTunnelMachine, error) {
 | 
			
		||||
	return machineapp.GetMachineApp().GetSshTunnelMachine(sshTunnelMachineId)
 | 
			
		||||
func GetSshTunnel(ctx context.Context, sshTunnelMachineId int) (*mcm.SshTunnelMachine, error) {
 | 
			
		||||
	return machineapp.GetMachineApp().GetSshTunnelMachine(ctx, sshTunnelMachineId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取连接id
 | 
			
		||||
@@ -143,5 +144,5 @@ func GetDbConnId(dbId uint64, db string) string {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Sprintf("%d:%s", dbId, db)
 | 
			
		||||
	return fmt.Sprintf("db-%d:%s", dbId, db)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package dbi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -14,7 +15,7 @@ type DbVersion string
 | 
			
		||||
// Meta 数据库元信息,如获取sql.DB、Dialect等
 | 
			
		||||
type Meta interface {
 | 
			
		||||
	// GetSqlDb 根据数据库信息获取sql.DB
 | 
			
		||||
	GetSqlDb(*DbInfo) (*sql.DB, error)
 | 
			
		||||
	GetSqlDb(context.Context, *DbInfo) (*sql.DB, error)
 | 
			
		||||
 | 
			
		||||
	// GetDialect 获取数据库方言, 若一些接口不需要DbConn,则可以传nil
 | 
			
		||||
	GetDialect(*DbConn) Dialect
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,7 @@
 | 
			
		||||
package dbm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"context"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	_ "mayfly-go/internal/db/dbm/dm"
 | 
			
		||||
	_ "mayfly-go/internal/db/dbm/mssql"
 | 
			
		||||
@@ -11,103 +11,54 @@ import (
 | 
			
		||||
	_ "mayfly-go/internal/db/dbm/sqlite"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
	"mayfly-go/pkg/pool"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var connPool = make(map[string]pool.Pool)
 | 
			
		||||
var instPool = make(map[uint64]pool.Pool)
 | 
			
		||||
var (
 | 
			
		||||
	poolGroup = pool.NewPoolGroup[*dbi.DbConn]()
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PutDbConn 释放连接
 | 
			
		||||
func PutDbConn(c *dbi.DbConn) {
 | 
			
		||||
	if nil == c {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	connId := dbi.GetDbConnId(c.Info.Id, c.Info.Database)
 | 
			
		||||
	if p, ok := connPool[connId]; ok {
 | 
			
		||||
		p.Put(c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getPool(dbId uint64, database string, getDbInfo func() (*dbi.DbInfo, error)) (pool.Pool, error) {
 | 
			
		||||
// GetDbConn 从连接池中获取连接信息,记的用完连接后必须调用 PutDbConn 还回池
 | 
			
		||||
func GetDbConn(ctx context.Context, dbId uint64, database string, getDbInfo func() (*dbi.DbInfo, error)) (*dbi.DbConn, error) {
 | 
			
		||||
	connId := dbi.GetDbConnId(dbId, database)
 | 
			
		||||
 | 
			
		||||
	// 获取连接池,如果没有,则创建一个
 | 
			
		||||
	if p, ok := connPool[connId]; !ok {
 | 
			
		||||
		var err error
 | 
			
		||||
		p, err = pool.NewChannelPool(&pool.Config{
 | 
			
		||||
			InitialCap:  1,                //资源池初始连接数
 | 
			
		||||
			MaxCap:      10,               //最大空闲连接数
 | 
			
		||||
			MaxIdle:     10,               //最大并发连接数
 | 
			
		||||
			IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效
 | 
			
		||||
			Factory: func() (interface{}, error) {
 | 
			
		||||
	pool, err := poolGroup.GetCachePool(connId, func() (*dbi.DbConn, error) {
 | 
			
		||||
		// 若缓存中不存在,则从回调函数中获取DbInfo
 | 
			
		||||
		dbInfo, err := getDbInfo()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logx.Debugf("dbm - conn create, connId: %s, dbInfo: %v", connId, dbInfo)
 | 
			
		||||
		// 连接数据库
 | 
			
		||||
				return Conn(dbInfo)
 | 
			
		||||
			},
 | 
			
		||||
			Close: func(v interface{}) error {
 | 
			
		||||
				v.(*dbi.DbConn).Close()
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
			Ping: func(v interface{}) error {
 | 
			
		||||
				return v.(*dbi.DbConn).Ping()
 | 
			
		||||
			},
 | 
			
		||||
		return Conn(ctx, dbInfo)
 | 
			
		||||
	})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		connPool[connId] = p
 | 
			
		||||
		instPool[dbId] = p
 | 
			
		||||
		return p, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return p, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetDbConn 从连接池中获取连接信息,记的用完连接后必须调用 PutDbConn 还回池
 | 
			
		||||
func GetDbConn(dbId uint64, database string, getDbInfo func() (*dbi.DbInfo, error)) (*dbi.DbConn, error) {
 | 
			
		||||
 | 
			
		||||
	p, err := getPool(dbId, database, getDbInfo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	// 从连接池中获取一个可用的连接
 | 
			
		||||
	c, err := p.Get()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	ec := c.(*dbi.DbConn)
 | 
			
		||||
	return ec, nil
 | 
			
		||||
 | 
			
		||||
	return pool.Get(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 使用指定dbInfo信息进行连接
 | 
			
		||||
func Conn(di *dbi.DbInfo) (*dbi.DbConn, error) {
 | 
			
		||||
	return di.Conn(dbi.GetMeta(di.Type))
 | 
			
		||||
func Conn(ctx context.Context, di *dbi.DbInfo) (*dbi.DbConn, error) {
 | 
			
		||||
	return di.Conn(ctx, dbi.GetMeta(di.Type))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据实例id获取连接
 | 
			
		||||
func GetDbConnByInstanceId(instanceId uint64) *dbi.DbConn {
 | 
			
		||||
	if p, ok := instPool[instanceId]; ok {
 | 
			
		||||
		c, err := p.Get()
 | 
			
		||||
func GetDbConnByInstanceId(ctx context.Context, instanceId uint64) *dbi.DbConn {
 | 
			
		||||
	for _, pool := range poolGroup.AllPool() {
 | 
			
		||||
		conn, err := pool.Get(ctx)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logx.Error(fmt.Sprintf("实例id[%d]连接获取失败:%s", instanceId, err))
 | 
			
		||||
			return nil
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if conn.Info.InstanceId == instanceId {
 | 
			
		||||
			return conn
 | 
			
		||||
		}
 | 
			
		||||
		return c.(*dbi.DbConn)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 删除db缓存并关闭该数据库所有连接
 | 
			
		||||
func CloseDb(dbId uint64, db string) {
 | 
			
		||||
	delete(connPool, dbi.GetDbConnId(dbId, db))
 | 
			
		||||
	delete(instPool, dbId)
 | 
			
		||||
	poolGroup.Close(dbi.GetDbConnId(dbId, db))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package dm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
@@ -20,7 +21,7 @@ const (
 | 
			
		||||
type Meta struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
func (dm *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	driverName := "dm"
 | 
			
		||||
	db := d.Database
 | 
			
		||||
	var dbParam string
 | 
			
		||||
@@ -36,7 +37,7 @@ func (dm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
		dbParam = "?escapeProcess=true"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := d.IfUseSshTunnelChangeIpPort()
 | 
			
		||||
	err := d.IfUseSshTunnelChangeIpPort(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package mssql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
@@ -23,8 +24,8 @@ const (
 | 
			
		||||
type Meta struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	err := d.IfUseSshTunnelChangeIpPort()
 | 
			
		||||
func (mm *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	err := d.IfUseSshTunnelChangeIpPort(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -25,10 +25,10 @@ const (
 | 
			
		||||
type Meta struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
func (mm *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	// SSH Conect
 | 
			
		||||
	if d.SshTunnelMachineId > 0 {
 | 
			
		||||
		sshTunnelMachine, err := dbi.GetSshTunnel(d.SshTunnelMachineId)
 | 
			
		||||
		sshTunnelMachine, err := dbi.GetSshTunnel(ctx, d.SshTunnelMachineId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package oracle
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
@@ -23,8 +24,8 @@ const (
 | 
			
		||||
type Meta struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (om *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	err := d.IfUseSshTunnelChangeIpPort()
 | 
			
		||||
func (om *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	err := d.IfUseSshTunnelChangeIpPort(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package postgres
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -37,7 +38,7 @@ type Meta struct {
 | 
			
		||||
	Param string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
func (pm *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	driverName := "postgres"
 | 
			
		||||
	// SSH Conect
 | 
			
		||||
	if d.SshTunnelMachineId > 0 {
 | 
			
		||||
@@ -120,7 +121,8 @@ func (pd *PqSqlDialer) Open(name string) (driver.Conn, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
 | 
			
		||||
	sshTunnel, err := dbi.GetSshTunnel(pd.sshTunnelMachineId)
 | 
			
		||||
	// todo context.Background可能存在问题
 | 
			
		||||
	sshTunnel, err := dbi.GetSshTunnel(context.Background(), pd.sshTunnelMachineId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package sqlite
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
@@ -19,7 +20,7 @@ const (
 | 
			
		||||
type Meta struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
func (md *Meta) GetSqlDb(ctx context.Context, d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	// 用host字段来存sqlite的文件路径
 | 
			
		||||
	// 检查文件是否存在,否则报错,基于sqlite会自动创建文件,为了服务器文件安全,所以先确定文件存在再连接,不自动创建
 | 
			
		||||
	if _, err := os.Stat(d.Host); err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
package api
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/may-fly/cast"
 | 
			
		||||
	"mayfly-go/internal/es/api/form"
 | 
			
		||||
	"mayfly-go/internal/es/api/vo"
 | 
			
		||||
	"mayfly-go/internal/es/application"
 | 
			
		||||
@@ -19,6 +18,8 @@ import (
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/may-fly/cast"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Instance struct {
 | 
			
		||||
@@ -99,7 +100,7 @@ func (d *Instance) TestConn(rc *req.Ctx) {
 | 
			
		||||
		ac = fm.AuthCerts[0]
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res, err := d.inst.TestConn(instance, ac)
 | 
			
		||||
	res, err := d.inst.TestConn(rc.MetaCtx, instance, ac)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	rc.ResData = res
 | 
			
		||||
}
 | 
			
		||||
@@ -133,7 +134,7 @@ func (d *Instance) Proxy(rc *req.Ctx) {
 | 
			
		||||
	r := rc.GetRequest()
 | 
			
		||||
	_ = RemoveQueryParam(r, "id", "path")
 | 
			
		||||
 | 
			
		||||
	err := d.inst.DoConn(instanceId, func(conn *esi.EsConn) error {
 | 
			
		||||
	err := d.inst.DoConn(rc.MetaCtx, instanceId, func(conn *esi.EsConn) error {
 | 
			
		||||
		conn.Proxy(rc.GetWriter(), r, path)
 | 
			
		||||
		return nil
 | 
			
		||||
	})
 | 
			
		||||
 
 | 
			
		||||
@@ -19,7 +19,6 @@ import (
 | 
			
		||||
	"mayfly-go/pkg/utils/collx"
 | 
			
		||||
	"mayfly-go/pkg/utils/stringx"
 | 
			
		||||
	"mayfly-go/pkg/utils/structx"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Instance interface {
 | 
			
		||||
@@ -28,9 +27,9 @@ type Instance interface {
 | 
			
		||||
	GetPageList(condition *entity.InstanceQuery, orderBy ...string) (*model.PageResult[*entity.EsInstance], error)
 | 
			
		||||
 | 
			
		||||
	// DoConn 获取连接并执行函数
 | 
			
		||||
	DoConn(instanceId uint64, fn func(*esi.EsConn) error) error
 | 
			
		||||
	DoConn(ctx context.Context, instanceId uint64, fn func(*esi.EsConn) error) error
 | 
			
		||||
 | 
			
		||||
	TestConn(instance *entity.EsInstance, ac *tagentity.ResourceAuthCert) (map[string]any, error)
 | 
			
		||||
	TestConn(ctx context.Context, instance *entity.EsInstance, ac *tagentity.ResourceAuthCert) (map[string]any, error)
 | 
			
		||||
 | 
			
		||||
	SaveInst(ctx context.Context, d *dto.SaveEsInstance) (uint64, error)
 | 
			
		||||
 | 
			
		||||
@@ -39,7 +38,7 @@ type Instance interface {
 | 
			
		||||
 | 
			
		||||
var _ Instance = &instanceAppImpl{}
 | 
			
		||||
 | 
			
		||||
var connPool = make(map[uint64]pool.Pool)
 | 
			
		||||
var poolGroup = pool.NewPoolGroup[*esi.EsConn]()
 | 
			
		||||
 | 
			
		||||
type instanceAppImpl struct {
 | 
			
		||||
	base.AppImpl[*entity.EsInstance, repository.EsInstance]
 | 
			
		||||
@@ -53,54 +52,24 @@ func (app *instanceAppImpl) GetPageList(condition *entity.InstanceQuery, orderBy
 | 
			
		||||
	return app.GetRepo().GetInstanceList(condition, orderBy...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *instanceAppImpl) DoConn(instanceId uint64, fn func(*esi.EsConn) error) error {
 | 
			
		||||
	// 通过实例id获取实例连接信息
 | 
			
		||||
	p, err := app.getPool(instanceId)
 | 
			
		||||
func (app *instanceAppImpl) DoConn(ctx context.Context, instanceId uint64, fn func(*esi.EsConn) error) error {
 | 
			
		||||
	p, err := poolGroup.GetChanPool(fmt.Sprintf("es-%d", instanceId), func() (*esi.EsConn, error) {
 | 
			
		||||
		return app.createConn(ctx, instanceId)
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	// 从连接池中获取一个可用的连接
 | 
			
		||||
	c, err := p.Get()
 | 
			
		||||
	c, err := p.Get(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	ec := c.(*esi.EsConn)
 | 
			
		||||
 | 
			
		||||
	// 用完后放回连接池
 | 
			
		||||
	defer p.Put(c)
 | 
			
		||||
 | 
			
		||||
	return fn(ec)
 | 
			
		||||
	return fn(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *instanceAppImpl) getPool(instanceId uint64) (pool.Pool, error) {
 | 
			
		||||
	// 获取连接池,如果没有,则创建一个
 | 
			
		||||
	if p, ok := connPool[instanceId]; !ok {
 | 
			
		||||
		var err error
 | 
			
		||||
		p, err = pool.NewChannelPool(&pool.Config{
 | 
			
		||||
			InitialCap:  1,                //资源池初始连接数
 | 
			
		||||
			MaxCap:      10,               //最大空闲连接数
 | 
			
		||||
			MaxIdle:     10,               //最大并发连接数
 | 
			
		||||
			IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效
 | 
			
		||||
			Factory: func() (interface{}, error) {
 | 
			
		||||
				return app.createConn(instanceId)
 | 
			
		||||
			},
 | 
			
		||||
			Close: func(v interface{}) error {
 | 
			
		||||
				return v.(*esi.EsConn).Close()
 | 
			
		||||
			},
 | 
			
		||||
			Ping: func(v interface{}) error {
 | 
			
		||||
				return v.(*esi.EsConn).Ping()
 | 
			
		||||
			},
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		connPool[instanceId] = p
 | 
			
		||||
		return p, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return p, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
func (app *instanceAppImpl) createConn(instanceId uint64) (*esi.EsConn, error) {
 | 
			
		||||
func (app *instanceAppImpl) createConn(ctx context.Context, instanceId uint64) (*esi.EsConn, error) {
 | 
			
		||||
	// 缓存不存在,则重新连接
 | 
			
		||||
	instance, err := app.GetById(instanceId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -113,7 +82,7 @@ func (app *instanceAppImpl) createConn(instanceId uint64) (*esi.EsConn, error) {
 | 
			
		||||
	}
 | 
			
		||||
	ei.CodePath = app.tagApp.ListTagPathByTypeAndCode(int8(tagentity.TagTypeEsInstance), instance.Code)
 | 
			
		||||
 | 
			
		||||
	conn, _, err := ei.Conn()
 | 
			
		||||
	conn, _, err := ei.Conn(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -151,7 +120,7 @@ func (app *instanceAppImpl) ToEsInfo(instance *entity.EsInstance, ac *tagentity.
 | 
			
		||||
	return ei, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *instanceAppImpl) TestConn(instance *entity.EsInstance, ac *tagentity.ResourceAuthCert) (map[string]any, error) {
 | 
			
		||||
func (app *instanceAppImpl) TestConn(ctx context.Context, instance *entity.EsInstance, ac *tagentity.ResourceAuthCert) (map[string]any, error) {
 | 
			
		||||
	instance.Network = instance.GetNetwork()
 | 
			
		||||
 | 
			
		||||
	ei, err := app.ToEsInfo(instance, ac)
 | 
			
		||||
@@ -159,7 +128,7 @@ func (app *instanceAppImpl) TestConn(instance *entity.EsInstance, ac *tagentity.
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, res, err := ei.Conn()
 | 
			
		||||
	_, res, err := ei.Conn(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -16,6 +16,21 @@ type EsConn struct {
 | 
			
		||||
	proxy *httputil.ReverseProxy
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/******************* pool.Conn impl *******************/
 | 
			
		||||
 | 
			
		||||
func (d *EsConn) Close() error {
 | 
			
		||||
	// 如果是使用了ssh隧道转发,则需要手动将其关闭
 | 
			
		||||
	if d.Info.useSshTunnel {
 | 
			
		||||
		mcm.CloseSshTunnelMachine(uint64(d.Info.SshTunnelMachineId), fmt.Sprintf("es:%d", d.Id))
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *EsConn) Ping() error {
 | 
			
		||||
	_, err := d.Info.Ping()
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// StartProxy 开始代理
 | 
			
		||||
func (d *EsConn) StartProxy() error {
 | 
			
		||||
	// 目标 URL
 | 
			
		||||
@@ -40,16 +55,3 @@ func (d *EsConn) Proxy(w http.ResponseWriter, r *http.Request, path string) {
 | 
			
		||||
	r.Header.Set("Accept", "application/json")
 | 
			
		||||
	d.proxy.ServeHTTP(w, r)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *EsConn) Close() error {
 | 
			
		||||
	// 如果是使用了ssh隧道转发,则需要手动将其关闭
 | 
			
		||||
	if d.Info.useSshTunnel {
 | 
			
		||||
		mcm.CloseSshTunnelMachine(uint64(d.Info.SshTunnelMachineId), fmt.Sprintf("es:%d", d.Id))
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *EsConn) Ping() error {
 | 
			
		||||
	_, err := d.Info.Ping()
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package esi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	machineapp "mayfly-go/internal/machine/application"
 | 
			
		||||
@@ -46,7 +47,7 @@ func (di *EsInfo) GetLogDesc() string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 连接数据库
 | 
			
		||||
func (di *EsInfo) Conn() (*EsConn, map[string]any, error) {
 | 
			
		||||
func (di *EsInfo) Conn(ctx context.Context) (*EsConn, map[string]any, error) {
 | 
			
		||||
	// 使用basic加密用户名和密码
 | 
			
		||||
	if di.Username != "" && di.Password != "" {
 | 
			
		||||
		encodeString := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%s:%s", di.Username, di.Password)))
 | 
			
		||||
@@ -54,7 +55,7 @@ func (di *EsInfo) Conn() (*EsConn, map[string]any, error) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 使用ssh隧道
 | 
			
		||||
	err := di.IfUseSshTunnelChangeIpPort()
 | 
			
		||||
	err := di.IfUseSshTunnelChangeIpPort(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logx.Errorf("es ssh failed: %s, err:%s", di.baseUrl, err.Error())
 | 
			
		||||
		return nil, nil, errorx.NewBiz("es ssh failed: %s", err.Error())
 | 
			
		||||
@@ -115,10 +116,10 @@ func (di *EsInfo) ExecApi(method, path string, data any, timeoutSecond ...int) (
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 如果使用了ssh隧道,将其host port改变其本地映射host port
 | 
			
		||||
func (di *EsInfo) IfUseSshTunnelChangeIpPort() error {
 | 
			
		||||
func (di *EsInfo) IfUseSshTunnelChangeIpPort(ctx context.Context) error {
 | 
			
		||||
	// 开启ssh隧道
 | 
			
		||||
	if di.SshTunnelMachineId > 0 {
 | 
			
		||||
		stm, err := GetSshTunnel(di.SshTunnelMachineId)
 | 
			
		||||
		stm, err := GetSshTunnel(ctx, di.SshTunnelMachineId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
@@ -137,6 +138,6 @@ func (di *EsInfo) IfUseSshTunnelChangeIpPort() error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据ssh tunnel机器id返回ssh tunnel
 | 
			
		||||
func GetSshTunnel(sshTunnelMachineId int) (*mcm.SshTunnelMachine, error) {
 | 
			
		||||
	return machineapp.GetMachineApp().GetSshTunnelMachine(sshTunnelMachineId)
 | 
			
		||||
func GetSshTunnel(ctx context.Context, sshTunnelMachineId int) (*mcm.SshTunnelMachine, error) {
 | 
			
		||||
	return machineapp.GetMachineApp().GetSshTunnelMachine(ctx, sshTunnelMachineId)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -135,9 +135,8 @@ func (m *Machine) SimpleMachieInfo(rc *req.Ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *Machine) MachineStats(rc *req.Ctx) {
 | 
			
		||||
	cli, err := m.machineApp.GetCli(GetMachineId(rc))
 | 
			
		||||
	cli, err := m.machineApp.GetCli(rc.MetaCtx, GetMachineId(rc))
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "connection error: %s")
 | 
			
		||||
	defer mcm.PutMachineCli(cli)
 | 
			
		||||
 | 
			
		||||
	rc.ResData = cli.GetAllStats()
 | 
			
		||||
}
 | 
			
		||||
@@ -160,7 +159,7 @@ func (m *Machine) TestConn(rc *req.Ctx) {
 | 
			
		||||
	machineForm := new(form.MachineForm)
 | 
			
		||||
	me := req.BindJsonAndCopyTo(rc, machineForm, new(entity.Machine))
 | 
			
		||||
	// 测试连接
 | 
			
		||||
	biz.ErrIsNilAppendErr(m.machineApp.TestConn(me, machineForm.AuthCerts[0]), "connection error: %s")
 | 
			
		||||
	biz.ErrIsNilAppendErr(m.machineApp.TestConn(rc.MetaCtx, me, machineForm.AuthCerts[0]), "connection error: %s")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *Machine) ChangeStatus(rc *req.Ctx) {
 | 
			
		||||
@@ -198,9 +197,8 @@ func (m *Machine) GetProcess(rc *req.Ctx) {
 | 
			
		||||
	count := rc.QueryIntDefault("count", 10)
 | 
			
		||||
	cmd += "| head -n " + fmt.Sprintf("%d", count)
 | 
			
		||||
 | 
			
		||||
	cli, err := m.machineApp.GetCli(GetMachineId(rc))
 | 
			
		||||
	cli, err := m.machineApp.GetCli(rc.MetaCtx, GetMachineId(rc))
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "connection error: %s")
 | 
			
		||||
	defer mcm.PutMachineCli(cli)
 | 
			
		||||
 | 
			
		||||
	biz.ErrIsNilAppendErr(m.tagTreeApp.CanAccess(rc.GetLoginAccount().Id, cli.Info.CodePath...), "%s")
 | 
			
		||||
 | 
			
		||||
@@ -214,9 +212,8 @@ func (m *Machine) KillProcess(rc *req.Ctx) {
 | 
			
		||||
	pid := rc.Query("pid")
 | 
			
		||||
	biz.NotEmpty(pid, "pid cannot be empty")
 | 
			
		||||
 | 
			
		||||
	cli, err := m.machineApp.GetCli(GetMachineId(rc))
 | 
			
		||||
	cli, err := m.machineApp.GetCli(rc.MetaCtx, GetMachineId(rc))
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "connection error: %s")
 | 
			
		||||
	defer mcm.PutMachineCli(cli)
 | 
			
		||||
 | 
			
		||||
	biz.ErrIsNilAppendErr(m.tagTreeApp.CanAccess(rc.GetLoginAccount().Id, cli.Info.CodePath...), "%s")
 | 
			
		||||
 | 
			
		||||
@@ -225,9 +222,8 @@ func (m *Machine) KillProcess(rc *req.Ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *Machine) GetUsers(rc *req.Ctx) {
 | 
			
		||||
	cli, err := m.machineApp.GetCli(GetMachineId(rc))
 | 
			
		||||
	cli, err := m.machineApp.GetCli(rc.MetaCtx, GetMachineId(rc))
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "connection error: %s")
 | 
			
		||||
	defer mcm.PutMachineCli(cli)
 | 
			
		||||
 | 
			
		||||
	res, err := cli.GetUsers()
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
@@ -235,9 +231,8 @@ func (m *Machine) GetUsers(rc *req.Ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *Machine) GetGroups(rc *req.Ctx) {
 | 
			
		||||
	cli, err := m.machineApp.GetCli(GetMachineId(rc))
 | 
			
		||||
	cli, err := m.machineApp.GetCli(rc.MetaCtx, GetMachineId(rc))
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "connection error: %s")
 | 
			
		||||
	defer mcm.PutMachineCli(cli)
 | 
			
		||||
 | 
			
		||||
	res, err := cli.GetGroups()
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
@@ -262,12 +257,9 @@ func (m *Machine) WsSSH(rc *req.Ctx) {
 | 
			
		||||
	err = req.PermissionHandler(rc)
 | 
			
		||||
	biz.ErrIsNil(err, mcm.GetErrorContentRn("You do not have permission to operate the machine terminal, please log in again and try again ~"))
 | 
			
		||||
 | 
			
		||||
	cli, err := m.machineApp.GetCliByAc(GetMachineAc(rc))
 | 
			
		||||
	cli, err := m.machineApp.NewCli(rc.MetaCtx, GetMachineAc(rc))
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, mcm.GetErrorContentRn("connection error: %s"))
 | 
			
		||||
	defer func() {
 | 
			
		||||
		cli.Close()
 | 
			
		||||
		mcm.PutMachineCli(cli)
 | 
			
		||||
	}()
 | 
			
		||||
	defer cli.Close()
 | 
			
		||||
	biz.ErrIsNilAppendErr(m.tagTreeApp.CanAccess(rc.GetLoginAccount().Id, cli.Info.CodePath...), mcm.GetErrorContentRn("%s"))
 | 
			
		||||
 | 
			
		||||
	global.EventBus.Publish(rc.MetaCtx, event.EventTopicResourceOp, cli.Info.CodePath[0])
 | 
			
		||||
@@ -327,7 +319,7 @@ func (m *Machine) WsGuacamole(rc *req.Ctx) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = mi.IfUseSshTunnelChangeIpPort(true)
 | 
			
		||||
	err = mi.IfUseSshTunnelChangeIpPort(rc.MetaCtx, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -326,9 +326,8 @@ func (m *MachineFile) UploadFolder(rc *req.Ctx) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	folderName := filepath.Dir(paths[0])
 | 
			
		||||
	mcli, err := m.machineFileApp.GetMachineCli(authCertName)
 | 
			
		||||
	mcli, err := m.machineFileApp.GetMachineCli(rc.MetaCtx, authCertName)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer mcm.PutMachineCli(mcli)
 | 
			
		||||
 | 
			
		||||
	mi := mcli.Info
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,6 @@ import (
 | 
			
		||||
	"mayfly-go/internal/machine/api/vo"
 | 
			
		||||
	"mayfly-go/internal/machine/application"
 | 
			
		||||
	"mayfly-go/internal/machine/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/machine/mcm"
 | 
			
		||||
	tagapp "mayfly-go/internal/tag/application"
 | 
			
		||||
	"mayfly-go/pkg/biz"
 | 
			
		||||
	"mayfly-go/pkg/model"
 | 
			
		||||
@@ -78,9 +77,8 @@ func (m *MachineScript) RunMachineScript(rc *req.Ctx) {
 | 
			
		||||
		script, err = stringx.TemplateParse(ms.Script, p)
 | 
			
		||||
		biz.ErrIsNilAppendErr(err, "failed to parse the script template parameter: %s")
 | 
			
		||||
	}
 | 
			
		||||
	cli, err := m.machineApp.GetCliByAc(ac)
 | 
			
		||||
	cli, err := m.machineApp.GetCliByAc(rc.MetaCtx, ac)
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "connection error: %s")
 | 
			
		||||
	defer mcm.PutMachineCli(cli)
 | 
			
		||||
 | 
			
		||||
	biz.ErrIsNilAppendErr(m.tagApp.CanAccess(rc.GetLoginAccount().Id, cli.Info.CodePath...), "%s")
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -27,7 +27,7 @@ type Machine interface {
 | 
			
		||||
	SaveMachine(ctx context.Context, param *dto.SaveMachine) error
 | 
			
		||||
 | 
			
		||||
	// 测试机器连接
 | 
			
		||||
	TestConn(me *entity.Machine, authCert *tagentity.ResourceAuthCert) error
 | 
			
		||||
	TestConn(ctx context.Context, me *entity.Machine, authCert *tagentity.ResourceAuthCert) error
 | 
			
		||||
 | 
			
		||||
	// 调整机器状态
 | 
			
		||||
	ChangeStatus(ctx context.Context, id uint64, status int8) error
 | 
			
		||||
@@ -38,16 +38,16 @@ type Machine interface {
 | 
			
		||||
	GetMachineList(condition *entity.MachineQuery, orderBy ...string) (*model.PageResult[*entity.Machine], error)
 | 
			
		||||
 | 
			
		||||
	// 新建机器客户端连接(需手动调用Close)
 | 
			
		||||
	NewCli(authCertName string) (*mcm.Cli, error)
 | 
			
		||||
	NewCli(ctx context.Context, authCertName string) (*mcm.Cli, error)
 | 
			
		||||
 | 
			
		||||
	// 获取已缓存的机器连接,若不存在则新建客户端连接并缓存,主要用于定时获取状态等(避免频繁创建连接)
 | 
			
		||||
	GetCli(id uint64) (*mcm.Cli, error)
 | 
			
		||||
	GetCli(ctx context.Context, id uint64) (*mcm.Cli, error)
 | 
			
		||||
 | 
			
		||||
	// 根据授权凭证获取客户端连接
 | 
			
		||||
	GetCliByAc(authCertName string) (*mcm.Cli, error)
 | 
			
		||||
	GetCliByAc(ctx context.Context, authCertName string) (*mcm.Cli, error)
 | 
			
		||||
 | 
			
		||||
	// 获取ssh隧道机器连接
 | 
			
		||||
	GetSshTunnelMachine(id int) (*mcm.SshTunnelMachine, error)
 | 
			
		||||
	GetSshTunnelMachine(ctx context.Context, id int) (*mcm.SshTunnelMachine, error)
 | 
			
		||||
 | 
			
		||||
	// 定时更新机器状态信息
 | 
			
		||||
	TimerUpdateStats()
 | 
			
		||||
@@ -159,7 +159,7 @@ func (m *machineAppImpl) SaveMachine(ctx context.Context, param *dto.SaveMachine
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *machineAppImpl) TestConn(me *entity.Machine, authCert *tagentity.ResourceAuthCert) error {
 | 
			
		||||
func (m *machineAppImpl) TestConn(ctx context.Context, me *entity.Machine, authCert *tagentity.ResourceAuthCert) error {
 | 
			
		||||
	me.Id = 0
 | 
			
		||||
 | 
			
		||||
	authCert, err := m.resourceAuthCertApp.GetRealAuthCert(authCert)
 | 
			
		||||
@@ -171,7 +171,7 @@ func (m *machineAppImpl) TestConn(me *entity.Machine, authCert *tagentity.Resour
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	cli, err := mi.Conn()
 | 
			
		||||
	cli, err := mi.Conn(ctx)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
@@ -224,30 +224,30 @@ func (m *machineAppImpl) Delete(ctx context.Context, id uint64) error {
 | 
			
		||||
		})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *machineAppImpl) NewCli(authCertName string) (*mcm.Cli, error) {
 | 
			
		||||
func (m *machineAppImpl) NewCli(ctx context.Context, authCertName string) (*mcm.Cli, error) {
 | 
			
		||||
	if mi, err := m.ToMachineInfoByAc(authCertName); err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	} else {
 | 
			
		||||
		return mi.Conn()
 | 
			
		||||
		return mi.Conn(ctx)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *machineAppImpl) GetCliByAc(authCertName string) (*mcm.Cli, error) {
 | 
			
		||||
	return mcm.GetMachineCli(authCertName, func(ac string) (*mcm.MachineInfo, error) {
 | 
			
		||||
func (m *machineAppImpl) GetCliByAc(ctx context.Context, authCertName string) (*mcm.Cli, error) {
 | 
			
		||||
	return mcm.GetMachineCli(ctx, authCertName, func(ac string) (*mcm.MachineInfo, error) {
 | 
			
		||||
		return m.ToMachineInfoByAc(ac)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *machineAppImpl) GetCli(machineId uint64) (*mcm.Cli, error) {
 | 
			
		||||
func (m *machineAppImpl) GetCli(ctx context.Context, machineId uint64) (*mcm.Cli, error) {
 | 
			
		||||
	_, authCert, err := m.getMachineAndAuthCert(machineId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return m.GetCliByAc(authCert.Name)
 | 
			
		||||
	return m.GetCliByAc(ctx, authCert.Name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *machineAppImpl) GetSshTunnelMachine(machineId int) (*mcm.SshTunnelMachine, error) {
 | 
			
		||||
	return mcm.GetSshTunnelMachine(machineId, func(mid uint64) (*mcm.MachineInfo, error) {
 | 
			
		||||
func (m *machineAppImpl) GetSshTunnelMachine(ctx context.Context, machineId int) (*mcm.SshTunnelMachine, error) {
 | 
			
		||||
	return mcm.GetSshTunnelMachine(ctx, machineId, func(mid uint64) (*mcm.MachineInfo, error) {
 | 
			
		||||
		return m.ToMachineInfoById(mid)
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
@@ -264,7 +264,9 @@ func (m *machineAppImpl) TimerUpdateStats() {
 | 
			
		||||
					}
 | 
			
		||||
				}()
 | 
			
		||||
				logx.Debugf("time to get machine [id=%d] status information start", mid)
 | 
			
		||||
				cli, err := m.GetCli(mid)
 | 
			
		||||
				ctx, cancelFunc := context.WithCancel(context.Background())
 | 
			
		||||
				defer cancelFunc()
 | 
			
		||||
				cli, err := m.GetCli(ctx, mid)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logx.Errorf("failed to get machine [id=%d] status information periodically, failed to get machine cli: %s", mid, err.Error())
 | 
			
		||||
					return
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,6 @@ import (
 | 
			
		||||
	"mayfly-go/internal/machine/application/dto"
 | 
			
		||||
	"mayfly-go/internal/machine/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/machine/domain/repository"
 | 
			
		||||
	"mayfly-go/internal/machine/mcm"
 | 
			
		||||
	tagapp "mayfly-go/internal/tag/application"
 | 
			
		||||
	tagentity "mayfly-go/internal/tag/domain/entity"
 | 
			
		||||
	"mayfly-go/pkg/base"
 | 
			
		||||
@@ -179,14 +178,14 @@ func (m *machineCronJobAppImpl) runCronJob0(mid uint64, cronJob *entity.MachineC
 | 
			
		||||
		ExecTime:  time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	machineCli, err := m.machineApp.GetCli(mid)
 | 
			
		||||
	ctx, cancelFunc := context.WithCancel(context.Background())
 | 
			
		||||
	defer cancelFunc()
 | 
			
		||||
	machineCli, err := m.machineApp.GetCli(ctx, mid)
 | 
			
		||||
	res := ""
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		machine, _ := m.machineApp.GetById(mid)
 | 
			
		||||
		execRes.MachineCode = machine.Code
 | 
			
		||||
	} else {
 | 
			
		||||
		defer mcm.PutMachineCli(machineCli)
 | 
			
		||||
 | 
			
		||||
		execRes.MachineCode = machineCli.Info.Code
 | 
			
		||||
		res, err = machineCli.Run(cronJob.Script)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -38,7 +38,7 @@ type MachineFile interface {
 | 
			
		||||
	Save(ctx context.Context, entity *entity.MachineFile) error
 | 
			
		||||
 | 
			
		||||
	// 获取机器cli
 | 
			
		||||
	GetMachineCli(authCertName string) (*mcm.Cli, error)
 | 
			
		||||
	GetMachineCli(ctx context.Context, authCertName string) (*mcm.Cli, error)
 | 
			
		||||
 | 
			
		||||
	GetRdpFilePath(ua *model.LoginAccount, path string) string
 | 
			
		||||
 | 
			
		||||
@@ -86,6 +86,8 @@ type machineFileAppImpl struct {
 | 
			
		||||
	machineApp Machine `inject:"T"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ MachineFile = (*machineFileAppImpl)(nil)
 | 
			
		||||
 | 
			
		||||
// 注入MachineFileRepo
 | 
			
		||||
func (m *machineFileAppImpl) InjectMachineFileRepo(repo repository.MachineFile) {
 | 
			
		||||
	m.Repo = repo
 | 
			
		||||
@@ -134,7 +136,7 @@ func (m *machineFileAppImpl) ReadDir(ctx context.Context, opParam *dto.MachineFi
 | 
			
		||||
		}), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, sftpCli, err := m.GetMachineSftpCli(opParam)
 | 
			
		||||
	_, sftpCli, err := m.GetMachineSftpCli(ctx, opParam)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -166,11 +168,10 @@ func (m *machineFileAppImpl) GetDirSize(ctx context.Context, opParam *dto.Machin
 | 
			
		||||
		return bytex.FormatSize(totalSize), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mcli, err := m.GetMachineCli(opParam.AuthCertName)
 | 
			
		||||
	mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	defer mcm.PutMachineCli(mcli)
 | 
			
		||||
 | 
			
		||||
	res, err := mcli.Run(fmt.Sprintf("du -sh %s", path))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -200,11 +201,10 @@ func (m *machineFileAppImpl) FileStat(ctx context.Context, opParam *dto.MachineF
 | 
			
		||||
		return fmt.Sprintf("%v", stat), err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mcli, err := m.GetMachineCli(opParam.AuthCertName)
 | 
			
		||||
	mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	defer mcm.PutMachineCli(mcli)
 | 
			
		||||
 | 
			
		||||
	return mcli.Run(fmt.Sprintf("stat -L %s", path))
 | 
			
		||||
}
 | 
			
		||||
@@ -221,7 +221,7 @@ func (m *machineFileAppImpl) MkDir(ctx context.Context, opParam *dto.MachineFile
 | 
			
		||||
		return &mcm.MachineInfo{Name: opParam.AuthCertName, Ip: opParam.AuthCertName}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(opParam)
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -241,7 +241,7 @@ func (m *machineFileAppImpl) CreateFile(ctx context.Context, opParam *dto.Machin
 | 
			
		||||
		return &mcm.MachineInfo{Name: opParam.AuthCertName, Ip: opParam.AuthCertName}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(opParam)
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -254,7 +254,7 @@ func (m *machineFileAppImpl) CreateFile(ctx context.Context, opParam *dto.Machin
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *machineFileAppImpl) ReadFile(ctx context.Context, opParam *dto.MachineFileOp) (*sftp.File, *mcm.MachineInfo, error) {
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(opParam)
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -278,7 +278,7 @@ func (m *machineFileAppImpl) WriteFileContent(ctx context.Context, opParam *dto.
 | 
			
		||||
		return &mcm.MachineInfo{Name: opParam.AuthCertName, Ip: opParam.AuthCertName}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(opParam)
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -313,7 +313,7 @@ func (m *machineFileAppImpl) UploadFile(ctx context.Context, opParam *dto.Machin
 | 
			
		||||
		return &mcm.MachineInfo{Name: opParam.AuthCertName, Ip: opParam.AuthCertName}, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(opParam)
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -379,11 +379,10 @@ func (m *machineFileAppImpl) RemoveFile(ctx context.Context, opParam *dto.Machin
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mcli, err := m.GetMachineCli(opParam.AuthCertName)
 | 
			
		||||
	mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer mcm.PutMachineCli(mcli)
 | 
			
		||||
 | 
			
		||||
	minfo := mcli.Info
 | 
			
		||||
 | 
			
		||||
@@ -431,11 +430,10 @@ func (m *machineFileAppImpl) Copy(ctx context.Context, opParam *dto.MachineFileO
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mcli, err := m.GetMachineCli(opParam.AuthCertName)
 | 
			
		||||
	mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer mcm.PutMachineCli(mcli)
 | 
			
		||||
 | 
			
		||||
	mi := mcli.Info
 | 
			
		||||
	res, err := mcli.Run(fmt.Sprintf("cp -r %s %s", strings.Join(path, " "), toPath))
 | 
			
		||||
@@ -461,11 +459,10 @@ func (m *machineFileAppImpl) Mv(ctx context.Context, opParam *dto.MachineFileOp,
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mcli, err := m.GetMachineCli(opParam.AuthCertName)
 | 
			
		||||
	mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer mcm.PutMachineCli(mcli)
 | 
			
		||||
 | 
			
		||||
	mi := mcli.Info
 | 
			
		||||
	res, err := mcli.Run(fmt.Sprintf("mv %s %s", strings.Join(path, " "), toPath))
 | 
			
		||||
@@ -483,7 +480,7 @@ func (m *machineFileAppImpl) Rename(ctx context.Context, opParam *dto.MachineFil
 | 
			
		||||
		return nil, os.Rename(oldname, newname)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(opParam)
 | 
			
		||||
	mi, sftpCli, err := m.GetMachineSftpCli(ctx, opParam)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -491,17 +488,16 @@ func (m *machineFileAppImpl) Rename(ctx context.Context, opParam *dto.MachineFil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取文件机器cli
 | 
			
		||||
func (m *machineFileAppImpl) GetMachineCli(authCertName string) (*mcm.Cli, error) {
 | 
			
		||||
	return m.machineApp.GetCliByAc(authCertName)
 | 
			
		||||
func (m *machineFileAppImpl) GetMachineCli(ctx context.Context, authCertName string) (*mcm.Cli, error) {
 | 
			
		||||
	return m.machineApp.GetCliByAc(ctx, authCertName)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取文件机器 sftp cli
 | 
			
		||||
func (m *machineFileAppImpl) GetMachineSftpCli(opParam *dto.MachineFileOp) (*mcm.MachineInfo, *sftp.Client, error) {
 | 
			
		||||
	mcli, err := m.GetMachineCli(opParam.AuthCertName)
 | 
			
		||||
func (m *machineFileAppImpl) GetMachineSftpCli(ctx context.Context, opParam *dto.MachineFileOp) (*mcm.MachineInfo, *sftp.Client, error) {
 | 
			
		||||
	mcli, err := m.GetMachineCli(ctx, opParam.AuthCertName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, nil, err
 | 
			
		||||
	}
 | 
			
		||||
	defer mcm.PutMachineCli(mcli)
 | 
			
		||||
 | 
			
		||||
	sftpCli, err := mcli.GetSftpCli()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
 
 | 
			
		||||
@@ -18,11 +18,41 @@ type Cli struct {
 | 
			
		||||
	sftpClient *sftp.Client // sftp客户端
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/******************* pool.Conn impl *******************/
 | 
			
		||||
 | 
			
		||||
func (c *Cli) Ping() error {
 | 
			
		||||
	_, _, err := c.sshClient.Conn.SendRequest("ping", true, nil)
 | 
			
		||||
	_, _, err := c.sshClient.SendRequest("ping", true, nil)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close 关闭client并从缓存中移除,如果使用隧道则也关闭
 | 
			
		||||
func (c *Cli) Close() error {
 | 
			
		||||
	m := c.Info
 | 
			
		||||
	logx.Debugf("close machine cli -> id=%d, name=%s, ip=%s", m.Id, m.Name, m.Ip)
 | 
			
		||||
	if c.sshClient != nil {
 | 
			
		||||
		c.sshClient.Close()
 | 
			
		||||
		c.sshClient = nil
 | 
			
		||||
	}
 | 
			
		||||
	if c.sftpClient != nil {
 | 
			
		||||
		c.sftpClient.Close()
 | 
			
		||||
		c.sftpClient = nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var sshTunnelMachineId uint64
 | 
			
		||||
	if m.SshTunnelMachine != nil {
 | 
			
		||||
		sshTunnelMachineId = m.SshTunnelMachine.Id
 | 
			
		||||
	}
 | 
			
		||||
	if m.TempSshMachineId != 0 {
 | 
			
		||||
		sshTunnelMachineId = m.TempSshMachineId
 | 
			
		||||
	}
 | 
			
		||||
	if sshTunnelMachineId != 0 {
 | 
			
		||||
		logx.Debugf("close machine ssh tunnel -> machineId=%d, sshTunnelMachineId=%d", m.Id, sshTunnelMachineId)
 | 
			
		||||
		CloseSshTunnelMachine(sshTunnelMachineId, m.GetTunnelId())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetSftpCli 获取sftp client
 | 
			
		||||
func (c *Cli) GetSftpCli() (*sftp.Client, error) {
 | 
			
		||||
	if c.sshClient == nil {
 | 
			
		||||
@@ -72,32 +102,6 @@ func (c *Cli) Run(shell string) (string, error) {
 | 
			
		||||
	return string(buf), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close 关闭client并从缓存中移除,如果使用隧道则也关闭
 | 
			
		||||
func (c *Cli) Close() {
 | 
			
		||||
	m := c.Info
 | 
			
		||||
	logx.Debugf("close machine cli -> id=%d, name=%s, ip=%s", m.Id, m.Name, m.Ip)
 | 
			
		||||
	if c.sshClient != nil {
 | 
			
		||||
		c.sshClient.Close()
 | 
			
		||||
		c.sshClient = nil
 | 
			
		||||
	}
 | 
			
		||||
	if c.sftpClient != nil {
 | 
			
		||||
		c.sftpClient.Close()
 | 
			
		||||
		c.sftpClient = nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var sshTunnelMachineId uint64
 | 
			
		||||
	if m.SshTunnelMachine != nil {
 | 
			
		||||
		sshTunnelMachineId = m.SshTunnelMachine.Id
 | 
			
		||||
	}
 | 
			
		||||
	if m.TempSshMachineId != 0 {
 | 
			
		||||
		sshTunnelMachineId = m.TempSshMachineId
 | 
			
		||||
	}
 | 
			
		||||
	if sshTunnelMachineId != 0 {
 | 
			
		||||
		logx.Debugf("close machine ssh tunnel -> machineId=%d, sshTunnelMachineId=%d", m.Id, sshTunnelMachineId)
 | 
			
		||||
		CloseSshTunnelMachine(sshTunnelMachineId, m.GetTunnelId())
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetAllStats 获取机器的所有状态信息
 | 
			
		||||
func (c *Cli) GetAllStats() *Stats {
 | 
			
		||||
	stats := new(Stats)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,77 +1,44 @@
 | 
			
		||||
package mcm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"mayfly-go/pkg/pool"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var mcConnPool = make(map[string]pool.Pool)
 | 
			
		||||
var mcIdPool = make(map[uint64]pool.Pool)
 | 
			
		||||
var (
 | 
			
		||||
	poolGroup = pool.NewPoolGroup[*Cli]()
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
}
 | 
			
		||||
func getMcPool(authCertName string, getMachine func(string) (*MachineInfo, error)) (pool.Pool, error) {
 | 
			
		||||
	// 获取连接池,如果没有,则创建一个
 | 
			
		||||
	if p, ok := mcConnPool[authCertName]; !ok {
 | 
			
		||||
		var err error
 | 
			
		||||
		p, err = pool.NewChannelPool(&pool.Config{
 | 
			
		||||
			InitialCap:  1,                //资源池初始连接数
 | 
			
		||||
			MaxCap:      10,               //最大空闲连接数
 | 
			
		||||
			MaxIdle:     10,               //最大并发连接数
 | 
			
		||||
			IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效
 | 
			
		||||
			Factory: func() (interface{}, error) {
 | 
			
		||||
// 从缓存中获取客户端信息,不存在则回调获取机器信息函数,并新建。
 | 
			
		||||
// @param 机器的授权凭证名
 | 
			
		||||
func GetMachineCli(ctx context.Context, authCertName string, getMachine func(string) (*MachineInfo, error)) (*Cli, error) {
 | 
			
		||||
	pool, err := poolGroup.GetCachePool(authCertName, func() (*Cli, error) {
 | 
			
		||||
		mi, err := getMachine(authCertName)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		mi.Key = authCertName
 | 
			
		||||
				return mi.Conn()
 | 
			
		||||
			},
 | 
			
		||||
			Close: func(v interface{}) error {
 | 
			
		||||
				v.(*Cli).Close()
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
			Ping: func(v interface{}) error {
 | 
			
		||||
				return v.(*Cli).Ping()
 | 
			
		||||
			},
 | 
			
		||||
		return mi.Conn(ctx)
 | 
			
		||||
	})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		mcConnPool[authCertName] = p
 | 
			
		||||
		return p, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return p, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PutMachineCli(c *Cli) {
 | 
			
		||||
	if nil == c {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if p, ok := mcConnPool[c.Info.AuthCertName]; ok {
 | 
			
		||||
		p.Put(c)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 从缓存中获取客户端信息,不存在则回调获取机器信息函数,并新建。
 | 
			
		||||
// @param 机器的授权凭证名
 | 
			
		||||
func GetMachineCli(authCertName string, getMachine func(string) (*MachineInfo, error)) (*Cli, error) {
 | 
			
		||||
	p, err := getMcPool(authCertName, getMachine)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	// 从连接池中获取一个可用的连接
 | 
			
		||||
	c, err := p.Get()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return c.(*Cli), nil
 | 
			
		||||
	return pool.Get(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 删除指定机器缓存客户端,并关闭客户端连接
 | 
			
		||||
func DeleteCli(id uint64) {
 | 
			
		||||
	// 遍历所有机器连接实例,删除指定机器id关联的连接...
 | 
			
		||||
	delete(mcIdPool, id)
 | 
			
		||||
	for _, pool := range poolGroup.AllPool() {
 | 
			
		||||
		ctx, cancelFunc := context.WithCancel(context.Background())
 | 
			
		||||
		defer cancelFunc()
 | 
			
		||||
		conn, err := pool.Get(ctx)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if conn.Info.Id == id {
 | 
			
		||||
			pool.Close()
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package mcm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	tagentity "mayfly-go/internal/tag/domain/entity"
 | 
			
		||||
	"mayfly-go/pkg/errorx"
 | 
			
		||||
@@ -49,11 +50,11 @@ func (mi *MachineInfo) GetTunnelId() string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 连接
 | 
			
		||||
func (mi *MachineInfo) Conn() (*Cli, error) {
 | 
			
		||||
func (mi *MachineInfo) Conn(ctx context.Context) (*Cli, error) {
 | 
			
		||||
	logx.Infof("the machine[%s] is connecting: %s:%d", mi.Name, mi.Ip, mi.Port)
 | 
			
		||||
 | 
			
		||||
	// 如果使用了ssh隧道,则修改机器ip port为暴露的ip port
 | 
			
		||||
	err := mi.IfUseSshTunnelChangeIpPort(false)
 | 
			
		||||
	err := mi.IfUseSshTunnelChangeIpPort(ctx, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, errorx.NewBiz("ssh tunnel connection failed: %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
@@ -71,7 +72,7 @@ func (mi *MachineInfo) Conn() (*Cli, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 如果使用了ssh隧道,则修改机器ip port为暴露的ip port
 | 
			
		||||
func (mi *MachineInfo) IfUseSshTunnelChangeIpPort(out bool) error {
 | 
			
		||||
func (mi *MachineInfo) IfUseSshTunnelChangeIpPort(ctx context.Context, out bool) error {
 | 
			
		||||
	if !mi.UseSshTunnel() {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
@@ -82,8 +83,9 @@ func (mi *MachineInfo) IfUseSshTunnelChangeIpPort(out bool) error {
 | 
			
		||||
		mi.Id = uint64(time.Now().Nanosecond())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sshTunnelMachine, err := GetSshTunnelMachine(int(mi.SshTunnelMachine.Id), func(u uint64) (*MachineInfo, error) {
 | 
			
		||||
		return mi.SshTunnelMachine, nil
 | 
			
		||||
	stm := mi.SshTunnelMachine
 | 
			
		||||
	sshTunnelMachine, err := GetSshTunnelMachine(ctx, int(stm.Id), func(u uint64) (*MachineInfo, error) {
 | 
			
		||||
		return stm, nil
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -102,7 +104,7 @@ func (mi *MachineInfo) IfUseSshTunnelChangeIpPort(out bool) error {
 | 
			
		||||
	mi.Ip = exposeIp
 | 
			
		||||
	mi.Port = exposePort
 | 
			
		||||
	// 代理之后置空跳板机信息,防止重复跳
 | 
			
		||||
	mi.TempSshMachineId = mi.SshTunnelMachine.Id
 | 
			
		||||
	mi.TempSshMachineId = stm.Id
 | 
			
		||||
	mi.SshTunnelMachine = nil
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package mcm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
@@ -9,7 +10,6 @@ import (
 | 
			
		||||
	"mayfly-go/pkg/utils/netx"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/crypto/ssh"
 | 
			
		||||
)
 | 
			
		||||
@@ -20,7 +20,7 @@ var (
 | 
			
		||||
	// 所有检测ssh隧道机器是否被使用的函数
 | 
			
		||||
	checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc
 | 
			
		||||
 | 
			
		||||
	tunnelPool = make(map[int]pool.Pool)
 | 
			
		||||
	tunnelPool = make(map[int]pool.Pool[*SshTunnelMachine])
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 检查ssh隧道机器是否有被使用
 | 
			
		||||
@@ -43,11 +43,36 @@ type SshTunnelMachine struct {
 | 
			
		||||
	tunnels   map[string]*Tunnel // 隧道id -> 隧道
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/******************* pool.Conn impl *******************/
 | 
			
		||||
 | 
			
		||||
func (stm *SshTunnelMachine) Ping() error {
 | 
			
		||||
	_, _, err := stm.SshClient.Conn.SendRequest("ping", true, nil)
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (stm *SshTunnelMachine) Close() error {
 | 
			
		||||
	stm.mutex.Lock()
 | 
			
		||||
	defer stm.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	for id, tunnel := range stm.tunnels {
 | 
			
		||||
		if tunnel != nil {
 | 
			
		||||
			tunnel.Close()
 | 
			
		||||
			delete(stm.tunnels, id)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if stm.SshClient != nil {
 | 
			
		||||
		logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId)
 | 
			
		||||
		err := stm.SshClient.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	delete(tunnelPool, stm.machineId)
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (exposedIp string, exposedPort int, err error) {
 | 
			
		||||
	stm.mutex.Lock()
 | 
			
		||||
	defer stm.mutex.Unlock()
 | 
			
		||||
@@ -92,37 +117,13 @@ func (stm *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn,
 | 
			
		||||
	return stm.SshClient.Dial(network, addr)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (stm *SshTunnelMachine) Close() {
 | 
			
		||||
	stm.mutex.Lock()
 | 
			
		||||
	defer stm.mutex.Unlock()
 | 
			
		||||
 | 
			
		||||
	for id, tunnel := range stm.tunnels {
 | 
			
		||||
		if tunnel != nil {
 | 
			
		||||
			tunnel.Close()
 | 
			
		||||
			delete(stm.tunnels, id)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if stm.SshClient != nil {
 | 
			
		||||
		logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId)
 | 
			
		||||
		err := stm.SshClient.Close()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error())
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	delete(tunnelPool, stm.machineId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error)) (pool.Pool, error) {
 | 
			
		||||
func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error)) (pool.Pool[*SshTunnelMachine], error) {
 | 
			
		||||
	// 获取连接池,如果没有,则创建一个
 | 
			
		||||
	if p, ok := tunnelPool[machineId]; !ok {
 | 
			
		||||
		var err error
 | 
			
		||||
		p, err = pool.NewChannelPool(&pool.Config{
 | 
			
		||||
			InitialCap:  1,                //资源池初始连接数
 | 
			
		||||
			MaxCap:      10,               //最大空闲连接数
 | 
			
		||||
			MaxIdle:     10,               //最大并发连接数
 | 
			
		||||
			IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效
 | 
			
		||||
			Factory: func() (interface{}, error) {
 | 
			
		||||
	if p, ok := tunnelPool[machineId]; ok {
 | 
			
		||||
		return p, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p := pool.NewChannelPool(func() (*SshTunnelMachine, error) {
 | 
			
		||||
		mi, err := getMachine(uint64(machineId))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
@@ -138,38 +139,22 @@ func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error))
 | 
			
		||||
		logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port)
 | 
			
		||||
 | 
			
		||||
		return stm, err
 | 
			
		||||
			},
 | 
			
		||||
			Close: func(v interface{}) error {
 | 
			
		||||
				v.(*SshTunnelMachine).Close()
 | 
			
		||||
	}, pool.WithOnPoolClose(func() error {
 | 
			
		||||
		delete(tunnelPool, machineId)
 | 
			
		||||
		return nil
 | 
			
		||||
			},
 | 
			
		||||
			Ping: func(v interface{}) error {
 | 
			
		||||
				return v.(*SshTunnelMachine).Ping()
 | 
			
		||||
			},
 | 
			
		||||
		})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
	}))
 | 
			
		||||
	tunnelPool[machineId] = p
 | 
			
		||||
	return p, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return p, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取ssh隧道机器,方便统一管理充当ssh隧道的机器,避免创建多个ssh client
 | 
			
		||||
func GetSshTunnelMachine(machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) {
 | 
			
		||||
func GetSshTunnelMachine(ctx context.Context, machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) {
 | 
			
		||||
	p, err := getTunnelPool(machineId, getMachine)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	// 从连接池中获取一个可用的连接
 | 
			
		||||
	c, err := p.Get()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return c.(*SshTunnelMachine), nil
 | 
			
		||||
	return p.Get(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 关闭ssh隧道机器的指定隧道
 | 
			
		||||
 
 | 
			
		||||
@@ -8,7 +8,6 @@ import (
 | 
			
		||||
	"mayfly-go/internal/mongo/application"
 | 
			
		||||
	"mayfly-go/internal/mongo/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/mongo/imsg"
 | 
			
		||||
	"mayfly-go/internal/mongo/mgm"
 | 
			
		||||
	"mayfly-go/internal/pkg/consts"
 | 
			
		||||
	tagapp "mayfly-go/internal/tag/application"
 | 
			
		||||
	tagentity "mayfly-go/internal/tag/domain/entity"
 | 
			
		||||
@@ -126,9 +125,8 @@ func (m *Mongo) DeleteMongo(rc *req.Ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *Mongo) Databases(rc *req.Ctx) {
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc))
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc))
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer mgm.PutMongoConn(conn)
 | 
			
		||||
 | 
			
		||||
	res, err := conn.Cli.ListDatabases(context.TODO(), bson.D{})
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "get mongo dbs error: %s")
 | 
			
		||||
@@ -136,9 +134,8 @@ func (m *Mongo) Databases(rc *req.Ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *Mongo) Collections(rc *req.Ctx) {
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc))
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc))
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer mgm.PutMongoConn(conn)
 | 
			
		||||
 | 
			
		||||
	global.EventBus.Publish(rc.MetaCtx, event.EventTopicResourceOp, conn.Info.CodePath[0])
 | 
			
		||||
 | 
			
		||||
@@ -154,9 +151,8 @@ func (m *Mongo) RunCommand(rc *req.Ctx) {
 | 
			
		||||
	commandForm := new(form.MongoRunCommand)
 | 
			
		||||
	req.BindJsonAndValid(rc, commandForm)
 | 
			
		||||
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc))
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc))
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer mgm.PutMongoConn(conn)
 | 
			
		||||
 | 
			
		||||
	rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
 | 
			
		||||
 | 
			
		||||
@@ -185,9 +181,8 @@ func (m *Mongo) RunCommand(rc *req.Ctx) {
 | 
			
		||||
func (m *Mongo) FindCommand(rc *req.Ctx) {
 | 
			
		||||
	commandForm := req.BindJsonAndValid(rc, new(form.MongoFindCommand))
 | 
			
		||||
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc))
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc))
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer mgm.PutMongoConn(conn)
 | 
			
		||||
 | 
			
		||||
	cli := conn.Cli
 | 
			
		||||
 | 
			
		||||
@@ -221,9 +216,8 @@ func (m *Mongo) FindCommand(rc *req.Ctx) {
 | 
			
		||||
func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) {
 | 
			
		||||
	commandForm := req.BindJsonAndValid(rc, new(form.MongoUpdateByIdCommand))
 | 
			
		||||
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc))
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc))
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer mgm.PutMongoConn(conn)
 | 
			
		||||
 | 
			
		||||
	rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
 | 
			
		||||
 | 
			
		||||
@@ -246,9 +240,8 @@ func (m *Mongo) UpdateByIdCommand(rc *req.Ctx) {
 | 
			
		||||
func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) {
 | 
			
		||||
	commandForm := req.BindJsonAndValid(rc, new(form.MongoUpdateByIdCommand))
 | 
			
		||||
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc))
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc))
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer mgm.PutMongoConn(conn)
 | 
			
		||||
 | 
			
		||||
	rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
 | 
			
		||||
 | 
			
		||||
@@ -270,9 +263,8 @@ func (m *Mongo) DeleteByIdCommand(rc *req.Ctx) {
 | 
			
		||||
func (m *Mongo) InsertOneCommand(rc *req.Ctx) {
 | 
			
		||||
	commandForm := req.BindJsonAndValid(rc, new(form.MongoInsertCommand))
 | 
			
		||||
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(m.GetMongoId(rc))
 | 
			
		||||
	conn, err := m.mongoApp.GetMongoConn(rc.MetaCtx, m.GetMongoId(rc))
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer mgm.PutMongoConn(conn)
 | 
			
		||||
 | 
			
		||||
	rc.ReqParam = collx.Kvs("mongo", conn.Info, "cmd", commandForm)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -31,7 +31,7 @@ type Mongo interface {
 | 
			
		||||
 | 
			
		||||
	// 获取mongo连接实例
 | 
			
		||||
	// @param id mongo id
 | 
			
		||||
	GetMongoConn(id uint64) (*mgm.MongoConn, error)
 | 
			
		||||
	GetMongoConn(ctx context.Context, id uint64) (*mgm.MongoConn, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type mongoAppImpl struct {
 | 
			
		||||
@@ -131,8 +131,8 @@ func (d *mongoAppImpl) SaveMongo(ctx context.Context, m *entity.Mongo, tagCodePa
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *mongoAppImpl) GetMongoConn(id uint64) (*mgm.MongoConn, error) {
 | 
			
		||||
	return mgm.GetMongoConn(id, func() (*mgm.MongoInfo, error) {
 | 
			
		||||
func (d *mongoAppImpl) GetMongoConn(ctx context.Context, id uint64) (*mgm.MongoConn, error) {
 | 
			
		||||
	return mgm.GetMongoConn(ctx, id, func() (*mgm.MongoInfo, error) {
 | 
			
		||||
		me, err := d.GetById(id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, errorx.NewBiz("mongo not found")
 | 
			
		||||
 
 | 
			
		||||
@@ -14,11 +14,19 @@ type MongoConn struct {
 | 
			
		||||
	Cli *mongo.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mc *MongoConn) Close() {
 | 
			
		||||
/******************* pool.Conn impl *******************/
 | 
			
		||||
 | 
			
		||||
func (mc *MongoConn) Close() error {
 | 
			
		||||
	if mc.Cli != nil {
 | 
			
		||||
		if err := mc.Cli.Disconnect(context.Background()); err != nil {
 | 
			
		||||
			logx.Errorf("关闭mongo实例[%s]连接失败: %s", mc.Id, err)
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		mc.Cli = nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (mc *MongoConn) Ping() error {
 | 
			
		||||
	return mc.Cli.Ping(context.Background(), nil)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,26 +3,15 @@ package mgm
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"mayfly-go/pkg/pool"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var connPool = make(map[string]pool.Pool)
 | 
			
		||||
var (
 | 
			
		||||
	poolGroup = pool.NewPoolGroup[*MongoConn]()
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getPool(mongoId uint64, getMongoInfo func() (*MongoInfo, error)) (pool.Pool, error) {
 | 
			
		||||
	connId := getConnId(mongoId)
 | 
			
		||||
 | 
			
		||||
	// 获取连接池,如果没有,则创建一个
 | 
			
		||||
	if p, ok := connPool[connId]; !ok {
 | 
			
		||||
		var err error
 | 
			
		||||
		p, err = pool.NewChannelPool(&pool.Config{
 | 
			
		||||
			InitialCap:  1,                //资源池初始连接数
 | 
			
		||||
			MaxCap:      10,               //最大空闲连接数
 | 
			
		||||
			MaxIdle:     10,               //最大并发连接数
 | 
			
		||||
			IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效
 | 
			
		||||
			Factory: func() (interface{}, error) {
 | 
			
		||||
// 从缓存中获取mongo连接信息, 若缓存中不存在则会使用回调函数获取mongoInfo进行连接并缓存
 | 
			
		||||
func GetMongoConn(ctx context.Context, mongoId uint64, getMongoInfo func() (*MongoInfo, error)) (*MongoConn, error) {
 | 
			
		||||
	pool, err := poolGroup.GetCachePool(getConnId(mongoId), func() (*MongoConn, error) {
 | 
			
		||||
		// 若缓存中不存在,则从回调函数中获取MongoInfo
 | 
			
		||||
		mi, err := getMongoInfo()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -31,50 +20,16 @@ func getPool(mongoId uint64, getMongoInfo func() (*MongoInfo, error)) (pool.Pool
 | 
			
		||||
 | 
			
		||||
		// 连接mongo
 | 
			
		||||
		return mi.Conn()
 | 
			
		||||
			},
 | 
			
		||||
			Close: func(v interface{}) error {
 | 
			
		||||
				v.(*MongoConn).Close()
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
			Ping: func(v interface{}) error {
 | 
			
		||||
				return v.(*MongoConn).Cli.Ping(context.Background(), nil)
 | 
			
		||||
			},
 | 
			
		||||
	})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		connPool[connId] = p
 | 
			
		||||
		return p, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return p, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PutMongoConn(c *MongoConn) {
 | 
			
		||||
	if nil == c {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if p, ok := connPool[getConnId(c.Info.Id)]; ok {
 | 
			
		||||
		p.Put(c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 从缓存中获取mongo连接信息, 若缓存中不存在则会使用回调函数获取mongoInfo进行连接并缓存
 | 
			
		||||
func GetMongoConn(mongoId uint64, getMongoInfo func() (*MongoInfo, error)) (*MongoConn, error) {
 | 
			
		||||
	p, err := getPool(mongoId, getMongoInfo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	// 从连接池中获取一个可用的连接
 | 
			
		||||
	c, err := p.Get()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	return c.(*MongoConn), nil
 | 
			
		||||
	return pool.Get(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 关闭连接,并移除缓存连接
 | 
			
		||||
func CloseConn(mongoId uint64) {
 | 
			
		||||
	connId := getConnId(mongoId)
 | 
			
		||||
	delete(connPool, connId)
 | 
			
		||||
	poolGroup.Close(getConnId(mongoId))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -58,7 +58,7 @@ type MongoSshDialer struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (sd *MongoSshDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
 | 
			
		||||
	stm, err := machineapp.GetMachineApp().GetSshTunnelMachine(sd.machineId)
 | 
			
		||||
	stm, err := machineapp.GetMachineApp().GetSshTunnelMachine(ctx, sd.machineId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -150,9 +150,8 @@ func (r *Redis) DeleteRedis(rc *req.Ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Redis) RedisInfo(rc *req.Ctx) {
 | 
			
		||||
	ri, err := r.redisApp.GetRedisConn(uint64(rc.PathParamInt("id")), 0)
 | 
			
		||||
	ri, err := r.redisApp.GetRedisConn(rc.MetaCtx, uint64(rc.PathParamInt("id")), 0)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer rdm.PutRedisConn(ri)
 | 
			
		||||
 | 
			
		||||
	section := rc.Query("section")
 | 
			
		||||
	mode := ri.Info.Mode
 | 
			
		||||
@@ -228,9 +227,8 @@ func (r *Redis) RedisInfo(rc *req.Ctx) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Redis) ClusterInfo(rc *req.Ctx) {
 | 
			
		||||
	ri, err := r.redisApp.GetRedisConn(uint64(rc.PathParamInt("id")), 0)
 | 
			
		||||
	ri, err := r.redisApp.GetRedisConn(rc.MetaCtx, uint64(rc.PathParamInt("id")), 0)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer rdm.PutRedisConn(ri)
 | 
			
		||||
 | 
			
		||||
	biz.IsEquals(ri.Info.Mode, rdm.ClusterMode, "non-cluster mode")
 | 
			
		||||
	info, _ := ri.ClusterCli.ClusterInfo(context.Background()).Result()
 | 
			
		||||
@@ -281,9 +279,9 @@ func (r *Redis) checkKeyAndGetRedisConn(rc *req.Ctx) (*rdm.RedisConn, string) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Redis) getRedisConn(rc *req.Ctx) *rdm.RedisConn {
 | 
			
		||||
	ri, err := r.redisApp.GetRedisConn(getIdAndDbNum(rc))
 | 
			
		||||
	id, db := getIdAndDbNum(rc)
 | 
			
		||||
	ri, err := r.redisApp.GetRedisConn(rc.MetaCtx, id, db)
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	defer rdm.PutRedisConn(ri)
 | 
			
		||||
 | 
			
		||||
	biz.ErrIsNilAppendErr(r.tagApp.CanAccess(rc.GetLoginAccount().Id, ri.Info.CodePath...), "%s")
 | 
			
		||||
	return ri
 | 
			
		||||
 
 | 
			
		||||
@@ -46,7 +46,7 @@ type Redis interface {
 | 
			
		||||
	// 获取数据库连接实例
 | 
			
		||||
	// id: 数据库实例id
 | 
			
		||||
	// db: 库号
 | 
			
		||||
	GetRedisConn(id uint64, db int) (*rdm.RedisConn, error)
 | 
			
		||||
	GetRedisConn(ctx context.Context, id uint64, db int) (*rdm.RedisConn, error)
 | 
			
		||||
 | 
			
		||||
	// 执行redis命令
 | 
			
		||||
	RunCmd(ctx context.Context, redisConn *rdm.RedisConn, cmdParam *dto.RunCmd) (any, error)
 | 
			
		||||
@@ -196,8 +196,8 @@ func (r *redisAppImpl) Delete(ctx context.Context, id uint64) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取数据库连接实例
 | 
			
		||||
func (r *redisAppImpl) GetRedisConn(id uint64, db int) (*rdm.RedisConn, error) {
 | 
			
		||||
	return rdm.GetRedisConn(id, db, func() (*rdm.RedisInfo, error) {
 | 
			
		||||
func (r *redisAppImpl) GetRedisConn(ctx context.Context, id uint64, db int) (*rdm.RedisConn, error) {
 | 
			
		||||
	return rdm.GetRedisConn(ctx, id, db, func() (*rdm.RedisInfo, error) {
 | 
			
		||||
		// 缓存不存在,则回调获取redis信息
 | 
			
		||||
		re, err := r.GetById(id)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -258,7 +258,7 @@ func (r *redisAppImpl) FlowBizHandle(ctx context.Context, bizHandleParam *flowap
 | 
			
		||||
		return nil, errorx.NewBiz("failed to parse the business form information: %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	redisConn, err := r.GetRedisConn(runCmdParam.Id, runCmdParam.Db)
 | 
			
		||||
	redisConn, err := r.GetRedisConn(ctx, runCmdParam.Id, runCmdParam.Db)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -17,6 +17,34 @@ type RedisConn struct {
 | 
			
		||||
	ClusterCli *redis.ClusterClient
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/******************* pool.Conn impl *******************/
 | 
			
		||||
 | 
			
		||||
func (r *RedisConn) Close() error {
 | 
			
		||||
	mode := r.Info.Mode
 | 
			
		||||
	if mode == StandaloneMode || mode == SentinelMode {
 | 
			
		||||
		if err := r.Cli.Close(); err != nil {
 | 
			
		||||
			logx.Errorf("close redis standalone instance [%s] connection failed: %s", r.Id, err.Error())
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		r.Cli = nil
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if mode == ClusterMode {
 | 
			
		||||
		if err := r.ClusterCli.Close(); err != nil {
 | 
			
		||||
			logx.Errorf("close redis cluster instance [%s] connection failed: %s", r.Id, err.Error())
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		r.ClusterCli = nil
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RedisConn) Ping() error {
 | 
			
		||||
	_, err := r.Cli.Ping(context.Background()).Result()
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取命令执行接口的具体实现
 | 
			
		||||
func (r *RedisConn) GetCmdable() redis.Cmdable {
 | 
			
		||||
	redisMode := r.Info.Mode
 | 
			
		||||
@@ -45,21 +73,3 @@ func (r *RedisConn) RunCmd(ctx context.Context, args ...any) (any, error) {
 | 
			
		||||
	}
 | 
			
		||||
	return nil, errorx.NewBiz("redis mode error")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *RedisConn) Close() {
 | 
			
		||||
	mode := r.Info.Mode
 | 
			
		||||
	if mode == StandaloneMode || mode == SentinelMode {
 | 
			
		||||
		if err := r.Cli.Close(); err != nil {
 | 
			
		||||
			logx.Errorf("close redis standalone instance [%s] connection failed: %s", r.Id, err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		r.Cli = nil
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if mode == ClusterMode {
 | 
			
		||||
		if err := r.ClusterCli.Close(); err != nil {
 | 
			
		||||
			logx.Errorf("close redis cluster instance [%s] connection failed: %s", r.Id, err.Error())
 | 
			
		||||
		}
 | 
			
		||||
		r.ClusterCli = nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -3,25 +3,18 @@ package rdm
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"mayfly-go/pkg/pool"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var connPool = make(map[string]pool.Pool)
 | 
			
		||||
var (
 | 
			
		||||
	poolGroup = pool.NewPoolGroup[*RedisConn]()
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getPool(redisId uint64, db int, getRedisInfo func() (*RedisInfo, error)) (pool.Pool, error) {
 | 
			
		||||
	connId := getConnId(redisId, db)
 | 
			
		||||
	// 获取连接池,如果没有,则创建一个
 | 
			
		||||
	if p, ok := connPool[connId]; !ok {
 | 
			
		||||
		var err error
 | 
			
		||||
		p, err = pool.NewChannelPool(&pool.Config{
 | 
			
		||||
			InitialCap:  1,                //资源池初始连接数
 | 
			
		||||
			MaxCap:      10,               //最大空闲连接数
 | 
			
		||||
			MaxIdle:     10,               //最大并发连接数
 | 
			
		||||
			IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效
 | 
			
		||||
			Factory: func() (interface{}, error) {
 | 
			
		||||
// 从缓存中获取redis连接信息, 若缓存中不存在则会使用回调函数获取redisInfo进行连接并缓存
 | 
			
		||||
func GetRedisConn(ctx context.Context, redisId uint64, db int, getRedisInfo func() (*RedisInfo, error)) (*RedisConn, error) {
 | 
			
		||||
	p, err := poolGroup.GetCachePool(getConnId(redisId, db), func() (*RedisConn, error) {
 | 
			
		||||
		// 若缓存中不存在,则从回调函数中获取RedisInfo
 | 
			
		||||
		ri, err := getRedisInfo()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
@@ -29,53 +22,17 @@ func getPool(redisId uint64, db int, getRedisInfo func() (*RedisInfo, error)) (p
 | 
			
		||||
		}
 | 
			
		||||
		// 连接数据库
 | 
			
		||||
		return ri.Conn()
 | 
			
		||||
			},
 | 
			
		||||
			Close: func(v interface{}) error {
 | 
			
		||||
				v.(*RedisConn).Close()
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
			Ping: func(v interface{}) error {
 | 
			
		||||
				_, err := v.(*RedisConn).Cli.Ping(context.Background()).Result()
 | 
			
		||||
				return err
 | 
			
		||||
			},
 | 
			
		||||
	})
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		connPool[connId] = p
 | 
			
		||||
		return p, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return p, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func PutRedisConn(c *RedisConn) {
 | 
			
		||||
	if nil == c {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if p, ok := connPool[getConnId(c.Info.Id, c.Info.Db)]; ok {
 | 
			
		||||
		p.Put(c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 从缓存中获取redis连接信息, 若缓存中不存在则会使用回调函数获取redisInfo进行连接并缓存
 | 
			
		||||
func GetRedisConn(redisId uint64, db int, getRedisInfo func() (*RedisInfo, error)) (*RedisConn, error) {
 | 
			
		||||
	p, err := getPool(redisId, db, getRedisInfo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 从连接池中获取一个可用的连接
 | 
			
		||||
	c, err := p.Get()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	// 用完后记的放回连接池
 | 
			
		||||
	return c.(*RedisConn), nil
 | 
			
		||||
	return p.Get(ctx)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 移除redis连接缓存并关闭redis连接
 | 
			
		||||
func CloseConn(id uint64, db int) {
 | 
			
		||||
	delete(connPool, getConnId(id, db))
 | 
			
		||||
	poolGroup.Close(getConnId(id, db))
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -139,8 +139,8 @@ func (re *RedisInfo) connSentinel() (*RedisConn, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func getRedisDialer(machineId int) func(ctx context.Context, network, addr string) (net.Conn, error) {
 | 
			
		||||
	return func(_ context.Context, network, addr string) (net.Conn, error) {
 | 
			
		||||
		sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(machineId)
 | 
			
		||||
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
 | 
			
		||||
		sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(ctx, machineId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,26 +1,26 @@
 | 
			
		||||
package migrations
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/go-gormigrate/gormigrate/v2"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	esentity "mayfly-go/internal/es/domain/entity"
 | 
			
		||||
	flowentity "mayfly-go/internal/flow/domain/entity"
 | 
			
		||||
	sysentity "mayfly-go/internal/sys/domain/entity"
 | 
			
		||||
	"mayfly-go/pkg/model"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-gormigrate/gormigrate/v2"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func V1_10() []*gormigrate.Migration {
 | 
			
		||||
	var migrations []*gormigrate.Migration
 | 
			
		||||
	migrations = append(migrations, V1_10_0()...)
 | 
			
		||||
	migrations = append(migrations, V1_10_1()...)
 | 
			
		||||
	return migrations
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func V1_10_0() []*gormigrate.Migration {
 | 
			
		||||
	return []*gormigrate.Migration{
 | 
			
		||||
		{
 | 
			
		||||
			ID: "20250213-v1.10.0-flow-recode",
 | 
			
		||||
			ID: "20250520-v1.10.0-flow-recode",
 | 
			
		||||
			Migrate: func(tx *gorm.DB) error {
 | 
			
		||||
				err := tx.AutoMigrate(&flowentity.Procdef{},
 | 
			
		||||
					&flowentity.Procinst{},
 | 
			
		||||
@@ -32,19 +32,6 @@ func V1_10_0() []*gormigrate.Migration {
 | 
			
		||||
					return err
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
			Rollback: func(tx *gorm.DB) error {
 | 
			
		||||
				return nil
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
func V1_10_1() []*gormigrate.Migration {
 | 
			
		||||
	return []*gormigrate.Migration{
 | 
			
		||||
		{
 | 
			
		||||
			ID: "20250422-v1.10.1-es",
 | 
			
		||||
			Migrate: func(tx *gorm.DB) error {
 | 
			
		||||
				// 添加实例表
 | 
			
		||||
				entities := [...]any{
 | 
			
		||||
					new(esentity.EsInstance),
 | 
			
		||||
@@ -55,7 +42,7 @@ func V1_10_1() []*gormigrate.Migration {
 | 
			
		||||
					}
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 添加菜单资源
 | 
			
		||||
				// 添加ES相关菜单资源
 | 
			
		||||
				resources := []*sysentity.Resource{
 | 
			
		||||
					{
 | 
			
		||||
						Model:  model.Model{CreateModel: model.CreateModel{DeletedModel: model.DeletedModel{IdModel: model.IdModel{Id: 1745292787}}}},
 | 
			
		||||
@@ -65,7 +52,7 @@ func V1_10_1() []*gormigrate.Migration {
 | 
			
		||||
						Code:   "/es",
 | 
			
		||||
						Type:   1,
 | 
			
		||||
						Meta:   `{"icon":"icon es/es-color","isKeepAlive":true,"routeName":"ES"}`,
 | 
			
		||||
						Weight: 7,
 | 
			
		||||
						Weight: 50000001,
 | 
			
		||||
					},
 | 
			
		||||
					{
 | 
			
		||||
						Model:  model.Model{CreateModel: model.CreateModel{DeletedModel: model.DeletedModel{IdModel: model.IdModel{Id: 1745319348}}}},
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										209
									
								
								server/pkg/pool/cache_pool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										209
									
								
								server/pkg/pool/cache_pool.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,209 @@
 | 
			
		||||
package pool
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
	"mayfly-go/pkg/utils/stringx"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var CachePoolDefaultConfig = PoolConfig{
 | 
			
		||||
	MaxConns:            1,
 | 
			
		||||
	IdleTimeout:         60 * time.Minute,
 | 
			
		||||
	WaitTimeout:         10 * time.Second,
 | 
			
		||||
	HealthCheckInterval: 10 * time.Minute,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type cacheEntry[T Conn] struct {
 | 
			
		||||
	conn       T
 | 
			
		||||
	lastActive time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CachePool[T Conn] struct {
 | 
			
		||||
	factory func() (T, error)
 | 
			
		||||
	mu      sync.RWMutex
 | 
			
		||||
	cache   map[string]*cacheEntry[T] // 使用字符串键的缓存
 | 
			
		||||
	config  PoolConfig
 | 
			
		||||
	closeCh chan struct{}
 | 
			
		||||
	closed  bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewCachePool[T Conn](factory func() (T, error), opts ...Option) *CachePool[T] {
 | 
			
		||||
	config := CachePoolDefaultConfig
 | 
			
		||||
	for _, opt := range opts {
 | 
			
		||||
		opt(&config)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p := &CachePool[T]{
 | 
			
		||||
		factory: factory,
 | 
			
		||||
		cache:   make(map[string]*cacheEntry[T]),
 | 
			
		||||
		config:  config,
 | 
			
		||||
		closeCh: make(chan struct{}),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	go p.backgroundMaintenance()
 | 
			
		||||
	return p
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get 获取连接(自动创建或复用缓存连接)
 | 
			
		||||
func (p *CachePool[T]) Get(ctx context.Context) (T, error) {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
	var zero T
 | 
			
		||||
 | 
			
		||||
	if p.closed {
 | 
			
		||||
		return zero, ErrPoolClosed
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 1. 尝试从缓存中获取可用连接
 | 
			
		||||
	for key, entry := range p.cache {
 | 
			
		||||
		if time.Since(entry.lastActive) <= p.config.IdleTimeout {
 | 
			
		||||
			entry.lastActive = time.Now() // 更新活跃时间
 | 
			
		||||
			return entry.conn, nil
 | 
			
		||||
		}
 | 
			
		||||
		// 自动清理闲置超时的连接
 | 
			
		||||
		entry.conn.Close()
 | 
			
		||||
		delete(p.cache, key)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 2. 创建新连接并缓存
 | 
			
		||||
	conn, err := p.factory()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return zero, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p.cache[generateCacheKey()] = &cacheEntry[T]{
 | 
			
		||||
		conn:       conn,
 | 
			
		||||
		lastActive: time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return conn, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Put 将连接放回缓存
 | 
			
		||||
func (p *CachePool[T]) Put(conn T) error {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if p.closed {
 | 
			
		||||
		return conn.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p.cache[generateCacheKey()] = &cacheEntry[T]{
 | 
			
		||||
		conn:       conn,
 | 
			
		||||
		lastActive: time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果超出最大连接数,清理最久未使用的
 | 
			
		||||
	if len(p.cache) > p.config.MaxConns {
 | 
			
		||||
		p.removeOldest()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 移除最久未使用的连接
 | 
			
		||||
func (p *CachePool[T]) removeOldest() {
 | 
			
		||||
	var oldestKey string
 | 
			
		||||
	var oldestTime time.Time
 | 
			
		||||
 | 
			
		||||
	for key, entry := range p.cache {
 | 
			
		||||
		if oldestKey == "" || entry.lastActive.Before(oldestTime) {
 | 
			
		||||
			oldestKey = key
 | 
			
		||||
			oldestTime = entry.lastActive
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if oldestKey != "" {
 | 
			
		||||
		p.cache[oldestKey].conn.Close()
 | 
			
		||||
		delete(p.cache, oldestKey)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close 关闭连接池
 | 
			
		||||
func (p *CachePool[T]) Close() {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if p.closed {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p.closed = true
 | 
			
		||||
	close(p.closeCh)
 | 
			
		||||
 | 
			
		||||
	for _, entry := range p.cache {
 | 
			
		||||
		if err := entry.conn.Close(); err != nil {
 | 
			
		||||
			logx.Errorf("cache pool - error closing connection: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 触发关闭回调
 | 
			
		||||
	if p.config.OnPoolClose != nil {
 | 
			
		||||
		p.config.OnPoolClose()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p.cache = make(map[string]*cacheEntry[T])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Resize 动态调整大小
 | 
			
		||||
func (p *CachePool[T]) Resize(newSize int) {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if p.closed || newSize == p.config.MaxConns {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	p.config.MaxConns = newSize
 | 
			
		||||
 | 
			
		||||
	// 如果新大小小于当前缓存数量,清理多余的连接
 | 
			
		||||
	for len(p.cache) > newSize {
 | 
			
		||||
		p.removeOldest()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Stats 获取统计信息
 | 
			
		||||
func (p *CachePool[T]) Stats() PoolStats {
 | 
			
		||||
	p.mu.RLock()
 | 
			
		||||
	defer p.mu.RUnlock()
 | 
			
		||||
 | 
			
		||||
	return PoolStats{
 | 
			
		||||
		TotalConns: int32(len(p.cache)),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 后台维护协程
 | 
			
		||||
func (p *CachePool[T]) backgroundMaintenance() {
 | 
			
		||||
	ticker := time.NewTicker(p.config.HealthCheckInterval)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ticker.C:
 | 
			
		||||
			p.cleanupIdle()
 | 
			
		||||
		case <-p.closeCh:
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 清理闲置超时的连接
 | 
			
		||||
func (p *CachePool[T]) cleanupIdle() {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	cutoff := time.Now().Add(-p.config.IdleTimeout)
 | 
			
		||||
	for key, entry := range p.cache {
 | 
			
		||||
		if entry.lastActive.Before(cutoff) {
 | 
			
		||||
			entry.conn.Close()
 | 
			
		||||
			delete(p.cache, key)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 生成缓存键
 | 
			
		||||
func generateCacheKey() string {
 | 
			
		||||
	return stringx.RandUUID()
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										366
									
								
								server/pkg/pool/chan_pool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										366
									
								
								server/pkg/pool/chan_pool.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,366 @@
 | 
			
		||||
package pool
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
	"mayfly-go/pkg/utils/anyx"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ChanPoolDefaultConfig = PoolConfig{
 | 
			
		||||
	MaxConns:            5,
 | 
			
		||||
	IdleTimeout:         60 * time.Minute,
 | 
			
		||||
	WaitTimeout:         10 * time.Second,
 | 
			
		||||
	HealthCheckInterval: 10 * time.Minute,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ConnWrapper 封装连接及其元数据
 | 
			
		||||
type ConnWrapper[T Conn] struct {
 | 
			
		||||
	conn       T
 | 
			
		||||
	lastActive time.Time // 最后活跃时间
 | 
			
		||||
	isValid    bool      // 连接是否有效
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (w *ConnWrapper[T]) Ping() error {
 | 
			
		||||
	if !w.isValid {
 | 
			
		||||
		return errors.New("connection marked invalid")
 | 
			
		||||
	}
 | 
			
		||||
	return w.conn.Ping()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (w *ConnWrapper[T]) Close() error {
 | 
			
		||||
	w.isValid = false
 | 
			
		||||
	return w.conn.Close()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ChanPool 连接池结构
 | 
			
		||||
type ChanPool[T Conn] struct {
 | 
			
		||||
	mu           sync.RWMutex
 | 
			
		||||
	factory      func() (T, error)
 | 
			
		||||
	idleConns    chan *ConnWrapper[T]
 | 
			
		||||
	config       PoolConfig
 | 
			
		||||
	currentConns int32
 | 
			
		||||
	stats        PoolStats
 | 
			
		||||
	closeChan    chan struct{} // 用于关闭健康检查 goroutine
 | 
			
		||||
	closed       bool          // 关闭状态标识
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PoolStats 统计信息
 | 
			
		||||
type PoolStats struct {
 | 
			
		||||
	TotalConns  int32 // 总连接数
 | 
			
		||||
	IdleConns   int32 // 空闲连接数
 | 
			
		||||
	ActiveConns int32 // 活跃连接数
 | 
			
		||||
	WaitCount   int64 // 等待连接次数
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChannelPool[T Conn](factory func() (T, error), opts ...Option) *ChanPool[T] {
 | 
			
		||||
	// 1. 初始化配置(使用默认值 + Option 覆盖)
 | 
			
		||||
	config := ChanPoolDefaultConfig
 | 
			
		||||
	for _, opt := range opts {
 | 
			
		||||
		opt(&config)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 2. 创建连接池
 | 
			
		||||
	p := &ChanPool[T]{
 | 
			
		||||
		factory:   factory,
 | 
			
		||||
		idleConns: make(chan *ConnWrapper[T], config.MaxConns),
 | 
			
		||||
		config:    config,
 | 
			
		||||
		closeChan: make(chan struct{}),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 3. 启动健康检查
 | 
			
		||||
	go p.healthCheck()
 | 
			
		||||
	return p
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ChanPool[T]) Get(ctx context.Context) (T, error) {
 | 
			
		||||
	connChan := make(chan T, 1)
 | 
			
		||||
	errChan := make(chan error, 1)
 | 
			
		||||
 | 
			
		||||
	go func() {
 | 
			
		||||
		conn, err := p.get()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			errChan <- err
 | 
			
		||||
		} else {
 | 
			
		||||
			connChan <- conn
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
 | 
			
		||||
	var zero T
 | 
			
		||||
	select {
 | 
			
		||||
	case <-ctx.Done():
 | 
			
		||||
		return zero, ctx.Err()
 | 
			
		||||
	case err := <-errChan:
 | 
			
		||||
		return zero, err
 | 
			
		||||
	case conn := <-connChan:
 | 
			
		||||
		// 启动监控协程
 | 
			
		||||
		go func() {
 | 
			
		||||
			<-ctx.Done()
 | 
			
		||||
			// 上下文被取消后,将连接放回连接池
 | 
			
		||||
			if err := p.Put(conn); err != nil {
 | 
			
		||||
				logx.Errorf("Failed to return leaked connection: %v", err)
 | 
			
		||||
				conn.Close()
 | 
			
		||||
				atomic.AddInt32(&p.currentConns, -1)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
		return conn, nil
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ChanPool[T]) get() (T, error) {
 | 
			
		||||
	// 优先从 channel 获取空闲连接(无锁)
 | 
			
		||||
	select {
 | 
			
		||||
	case wrapper := <-p.idleConns:
 | 
			
		||||
		atomic.AddInt32(&p.stats.IdleConns, -1)
 | 
			
		||||
		atomic.AddInt32(&p.stats.ActiveConns, 1)
 | 
			
		||||
		wrapper.lastActive = time.Now()
 | 
			
		||||
		return wrapper.conn, nil
 | 
			
		||||
	default:
 | 
			
		||||
		return p.createConn()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ChanPool[T]) createConn() (T, error) {
 | 
			
		||||
	var zero T
 | 
			
		||||
 | 
			
		||||
	// 使用CAS保证原子性
 | 
			
		||||
	for {
 | 
			
		||||
		current := atomic.LoadInt32(&p.currentConns)
 | 
			
		||||
		if current >= int32(p.config.MaxConns) {
 | 
			
		||||
			if p.config.WaitTimeout > 0 {
 | 
			
		||||
				return p.waitForConn()
 | 
			
		||||
			}
 | 
			
		||||
			return zero, errors.New("connection pool exhausted")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if atomic.CompareAndSwapInt32(&p.currentConns, current, current+1) {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 直接创建新连接
 | 
			
		||||
	conn, err := p.factory()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		atomic.AddInt32(&p.currentConns, -1)
 | 
			
		||||
		return zero, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新状态
 | 
			
		||||
	atomic.AddInt32(&p.stats.ActiveConns, 1)
 | 
			
		||||
	return conn, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 新增等待连接方法
 | 
			
		||||
func (p *ChanPool[T]) waitForConn() (T, error) {
 | 
			
		||||
	var zero T
 | 
			
		||||
	timeout := time.NewTimer(p.config.WaitTimeout)
 | 
			
		||||
	defer timeout.Stop()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case wrapper := <-p.idleConns:
 | 
			
		||||
			if wrapper.isValid && wrapper.Ping() == nil {
 | 
			
		||||
				atomic.AddInt32(&p.stats.IdleConns, -1)
 | 
			
		||||
				atomic.AddInt32(&p.stats.ActiveConns, 1)
 | 
			
		||||
				wrapper.lastActive = time.Now()
 | 
			
		||||
				return wrapper.conn, nil
 | 
			
		||||
			}
 | 
			
		||||
			wrapper.Close()
 | 
			
		||||
			atomic.AddInt32(&p.currentConns, -1)
 | 
			
		||||
		case <-timeout.C:
 | 
			
		||||
			atomic.AddInt64(&p.stats.WaitCount, 1)
 | 
			
		||||
			return zero, errors.New("connection pool wait timeout")
 | 
			
		||||
		default:
 | 
			
		||||
			// 非阻塞检查后短暂休眠避免CPU空转
 | 
			
		||||
			time.Sleep(10 * time.Millisecond)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ChanPool[T]) Put(conn T) error {
 | 
			
		||||
	if anyx.IsBlank(conn) {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 快速路径
 | 
			
		||||
	select {
 | 
			
		||||
	case p.idleConns <- &ConnWrapper[T]{conn: conn, lastActive: time.Now(), isValid: true}:
 | 
			
		||||
		atomic.AddInt32(&p.stats.IdleConns, 1)
 | 
			
		||||
		atomic.AddInt32(&p.stats.ActiveConns, -1)
 | 
			
		||||
		return nil
 | 
			
		||||
	default:
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 慢速路径
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	// 检查是否超过最大连接数
 | 
			
		||||
	if atomic.LoadInt32(&p.currentConns) > int32(p.config.MaxConns) {
 | 
			
		||||
		conn.Close()
 | 
			
		||||
		atomic.AddInt32(&p.currentConns, -1)
 | 
			
		||||
	} else {
 | 
			
		||||
		// 直接放入空闲队列
 | 
			
		||||
		select {
 | 
			
		||||
		case p.idleConns <- &ConnWrapper[T]{conn: conn, lastActive: time.Now(), isValid: true}:
 | 
			
		||||
		default:
 | 
			
		||||
			conn.Close()
 | 
			
		||||
			atomic.AddInt32(&p.currentConns, -1)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	atomic.AddInt32(&p.stats.ActiveConns, -1)
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ChanPool[T]) Close() {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	if p.closed {
 | 
			
		||||
		p.mu.Unlock()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	p.closed = true
 | 
			
		||||
 | 
			
		||||
	// 1. 停止健康检查
 | 
			
		||||
	close(p.closeChan)
 | 
			
		||||
 | 
			
		||||
	// 2. 临时转移空闲连接
 | 
			
		||||
	idle := make([]*ConnWrapper[T], 0, len(p.idleConns))
 | 
			
		||||
	for len(p.idleConns) > 0 {
 | 
			
		||||
		idle = append(idle, <-p.idleConns)
 | 
			
		||||
	}
 | 
			
		||||
	close(p.idleConns) // 安全关闭通道
 | 
			
		||||
 | 
			
		||||
	p.mu.Unlock() // 提前释放锁,避免阻塞其他操作
 | 
			
		||||
 | 
			
		||||
	// 3. 关闭所有连接(无需持有锁)
 | 
			
		||||
	for _, wrapper := range idle {
 | 
			
		||||
		wrapper.Close()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 4. 触发关闭回调
 | 
			
		||||
	if p.config.OnPoolClose != nil {
 | 
			
		||||
		p.config.OnPoolClose()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ChanPool[T]) healthCheck() {
 | 
			
		||||
	ticker := time.NewTicker(p.config.HealthCheckInterval)
 | 
			
		||||
	defer ticker.Stop()
 | 
			
		||||
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case <-ticker.C:
 | 
			
		||||
			p.checkIdleConns()
 | 
			
		||||
		case <-p.closeChan:
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ChanPool[T]) checkIdleConns() {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if p.closed {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	idle := make([]*ConnWrapper[T], 0, len(p.idleConns))
 | 
			
		||||
	for len(p.idleConns) > 0 {
 | 
			
		||||
		idle = append(idle, <-p.idleConns)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	for _, wrapper := range idle {
 | 
			
		||||
		if now.Sub(wrapper.lastActive) > p.config.IdleTimeout || wrapper.Ping() != nil {
 | 
			
		||||
			wrapper.Close()
 | 
			
		||||
			atomic.AddInt32(&p.currentConns, -1)
 | 
			
		||||
		} else {
 | 
			
		||||
			select {
 | 
			
		||||
			case p.idleConns <- wrapper:
 | 
			
		||||
			default:
 | 
			
		||||
				wrapper.Close()
 | 
			
		||||
				atomic.AddInt32(&p.currentConns, -1)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ChanPool[T]) Resize(newMaxConns int) {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	oldMax := p.config.MaxConns
 | 
			
		||||
	p.config.MaxConns = newMaxConns
 | 
			
		||||
 | 
			
		||||
	// 缩小连接池:关闭多余的空闲连接
 | 
			
		||||
	if newMaxConns < oldMax {
 | 
			
		||||
		toClose := oldMax - newMaxConns
 | 
			
		||||
		closed := 0
 | 
			
		||||
 | 
			
		||||
		// 非阻塞取出待关闭的连接
 | 
			
		||||
		var wrappers []*ConnWrapper[T]
 | 
			
		||||
		for len(p.idleConns) > 0 && closed < toClose {
 | 
			
		||||
			wrappers = append(wrappers, <-p.idleConns)
 | 
			
		||||
			closed++
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 关闭连接并更新计数
 | 
			
		||||
		for _, wrapper := range wrappers {
 | 
			
		||||
			wrapper.Close()
 | 
			
		||||
			atomic.AddInt32(&p.currentConns, -1)
 | 
			
		||||
			atomic.AddInt32(&p.stats.IdleConns, -1)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 重建空闲连接通道(无需迁移连接,因 channel 本身无状态)
 | 
			
		||||
	p.idleConns = make(chan *ConnWrapper[T], newMaxConns)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ChanPool[T]) CheckLeaks() []T {
 | 
			
		||||
	p.mu.Lock()
 | 
			
		||||
	defer p.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	var leaks []T
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
 | 
			
		||||
	// 检查所有空闲连接
 | 
			
		||||
	idle := make([]*ConnWrapper[T], 0, len(p.idleConns))
 | 
			
		||||
	for len(p.idleConns) > 0 {
 | 
			
		||||
		idle = append(idle, <-p.idleConns)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, wrapper := range idle {
 | 
			
		||||
		// 判定泄漏条件:长期未使用且未被标记为活跃
 | 
			
		||||
		if now.Sub(wrapper.lastActive) > 10*p.config.IdleTimeout {
 | 
			
		||||
			leaks = append(leaks, wrapper.conn)
 | 
			
		||||
			wrapper.Close()
 | 
			
		||||
			atomic.AddInt32(&p.currentConns, -1)
 | 
			
		||||
			atomic.AddInt32(&p.stats.IdleConns, -1)
 | 
			
		||||
		} else {
 | 
			
		||||
			// 放回空闲池
 | 
			
		||||
			select {
 | 
			
		||||
			case p.idleConns <- wrapper:
 | 
			
		||||
			default:
 | 
			
		||||
				wrapper.Close()
 | 
			
		||||
				atomic.AddInt32(&p.currentConns, -1)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return leaks
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ChanPool[T]) Stats() PoolStats {
 | 
			
		||||
	p.mu.RLock()
 | 
			
		||||
	defer p.mu.RUnlock()
 | 
			
		||||
	return PoolStats{
 | 
			
		||||
		TotalConns:  atomic.LoadInt32(&p.currentConns),
 | 
			
		||||
		IdleConns:   int32(len(p.idleConns)), // 直接读取通道长度
 | 
			
		||||
		ActiveConns: atomic.LoadInt32(&p.stats.ActiveConns),
 | 
			
		||||
		WaitCount:   atomic.LoadInt64(&p.stats.WaitCount),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,216 +0,0 @@
 | 
			
		||||
package pool
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
	//"reflect"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	//ErrMaxActiveConnReached 连接池超限
 | 
			
		||||
	ErrMaxActiveConnReached = errors.New("MaxActiveConnReached")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Config 连接池相关配置
 | 
			
		||||
type Config struct {
 | 
			
		||||
	//连接池中拥有的最小连接数
 | 
			
		||||
	InitialCap int
 | 
			
		||||
	//最大并发存活连接数
 | 
			
		||||
	MaxCap int
 | 
			
		||||
	//最大空闲连接
 | 
			
		||||
	MaxIdle int
 | 
			
		||||
	//生成连接的方法
 | 
			
		||||
	Factory func() (interface{}, error)
 | 
			
		||||
	//关闭连接的方法
 | 
			
		||||
	Close func(interface{}) error
 | 
			
		||||
	//检查连接是否有效的方法
 | 
			
		||||
	Ping func(interface{}) error
 | 
			
		||||
	//连接最大空闲时间,超过该事件则将失效
 | 
			
		||||
	IdleTimeout time.Duration
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// channelPool 存放连接信息
 | 
			
		||||
type channelPool struct {
 | 
			
		||||
	mu                       sync.RWMutex
 | 
			
		||||
	conns                    chan *idleConn
 | 
			
		||||
	factory                  func() (interface{}, error)
 | 
			
		||||
	close                    func(interface{}) error
 | 
			
		||||
	ping                     func(interface{}) error
 | 
			
		||||
	idleTimeout, waitTimeOut time.Duration
 | 
			
		||||
	maxActive                int
 | 
			
		||||
	openingConns             int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type idleConn struct {
 | 
			
		||||
	conn interface{}
 | 
			
		||||
	t    time.Time
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewChannelPool 初始化连接
 | 
			
		||||
func NewChannelPool(poolConfig *Config) (Pool, error) {
 | 
			
		||||
	if !(poolConfig.InitialCap <= poolConfig.MaxIdle && poolConfig.MaxCap >= poolConfig.MaxIdle && poolConfig.InitialCap >= 0) {
 | 
			
		||||
		return nil, errors.New("invalid capacity settings")
 | 
			
		||||
	}
 | 
			
		||||
	if poolConfig.Factory == nil {
 | 
			
		||||
		return nil, errors.New("invalid factory func settings")
 | 
			
		||||
	}
 | 
			
		||||
	if poolConfig.Close == nil {
 | 
			
		||||
		return nil, errors.New("invalid close func settings")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c := &channelPool{
 | 
			
		||||
		conns:        make(chan *idleConn, poolConfig.MaxIdle),
 | 
			
		||||
		factory:      poolConfig.Factory,
 | 
			
		||||
		close:        poolConfig.Close,
 | 
			
		||||
		idleTimeout:  poolConfig.IdleTimeout,
 | 
			
		||||
		maxActive:    poolConfig.MaxCap,
 | 
			
		||||
		openingConns: poolConfig.InitialCap,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if poolConfig.Ping != nil {
 | 
			
		||||
		c.ping = poolConfig.Ping
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for i := 0; i < poolConfig.InitialCap; i++ {
 | 
			
		||||
		conn, err := c.factory()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			c.Release()
 | 
			
		||||
			return nil, fmt.Errorf("factory is not able to fill the pool: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
		c.conns <- &idleConn{conn: conn, t: time.Now()}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return c, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// getConns 获取所有连接
 | 
			
		||||
func (c *channelPool) getConns() chan *idleConn {
 | 
			
		||||
	c.mu.Lock()
 | 
			
		||||
	conns := c.conns
 | 
			
		||||
	c.mu.Unlock()
 | 
			
		||||
	return conns
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get 从pool中取一个连接
 | 
			
		||||
func (c *channelPool) Get() (interface{}, error) {
 | 
			
		||||
	conns := c.getConns()
 | 
			
		||||
	if conns == nil {
 | 
			
		||||
		return nil, ErrClosed
 | 
			
		||||
	}
 | 
			
		||||
	for {
 | 
			
		||||
		select {
 | 
			
		||||
		case wrapConn := <-conns:
 | 
			
		||||
			if wrapConn == nil {
 | 
			
		||||
				return nil, ErrClosed
 | 
			
		||||
			}
 | 
			
		||||
			//判断是否超时,超时则丢弃
 | 
			
		||||
			if timeout := c.idleTimeout; timeout > 0 {
 | 
			
		||||
				if wrapConn.t.Add(timeout).Before(time.Now()) {
 | 
			
		||||
					//丢弃并关闭该连接
 | 
			
		||||
					c.Close(wrapConn.conn)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			//判断是否失效,失效则丢弃,如果用户没有设定 ping 方法,就不检查
 | 
			
		||||
			if c.ping != nil {
 | 
			
		||||
				if err := c.Ping(wrapConn.conn); err != nil {
 | 
			
		||||
					c.Close(wrapConn.conn)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			return wrapConn.conn, nil
 | 
			
		||||
		default:
 | 
			
		||||
			c.mu.Lock()
 | 
			
		||||
			logx.Debugf("openConn %v %v", c.openingConns, c.maxActive)
 | 
			
		||||
			defer c.mu.Unlock()
 | 
			
		||||
			if c.openingConns >= c.maxActive {
 | 
			
		||||
				return nil, ErrMaxActiveConnReached
 | 
			
		||||
			}
 | 
			
		||||
			if c.factory == nil {
 | 
			
		||||
				return nil, ErrClosed
 | 
			
		||||
			}
 | 
			
		||||
			conn, err := c.factory()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return nil, err
 | 
			
		||||
			}
 | 
			
		||||
			c.openingConns++
 | 
			
		||||
			return conn, nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Put 将连接放回pool中
 | 
			
		||||
func (c *channelPool) Put(conn interface{}) error {
 | 
			
		||||
	if conn == nil {
 | 
			
		||||
		return errors.New("connection is nil. rejecting")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.mu.Lock()
 | 
			
		||||
 | 
			
		||||
	if c.conns == nil {
 | 
			
		||||
		c.mu.Unlock()
 | 
			
		||||
		return c.Close(conn)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	select {
 | 
			
		||||
	case c.conns <- &idleConn{conn: conn, t: time.Now()}:
 | 
			
		||||
		c.mu.Unlock()
 | 
			
		||||
		return nil
 | 
			
		||||
	default:
 | 
			
		||||
		c.mu.Unlock()
 | 
			
		||||
		//连接池已满,直接关闭该连接
 | 
			
		||||
		return c.Close(conn)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Close 关闭单条连接
 | 
			
		||||
func (c *channelPool) Close(conn interface{}) error {
 | 
			
		||||
	if conn == nil {
 | 
			
		||||
		return errors.New("connection is nil. rejecting")
 | 
			
		||||
	}
 | 
			
		||||
	c.mu.Lock()
 | 
			
		||||
	defer c.mu.Unlock()
 | 
			
		||||
	if c.close == nil {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	c.openingConns--
 | 
			
		||||
	return c.close(conn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Ping 检查单条连接是否有效
 | 
			
		||||
func (c *channelPool) Ping(conn interface{}) error {
 | 
			
		||||
	if conn == nil {
 | 
			
		||||
		return errors.New("connection is nil. rejecting")
 | 
			
		||||
	}
 | 
			
		||||
	return c.ping(conn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Release 释放连接池中所有连接
 | 
			
		||||
func (c *channelPool) Release() {
 | 
			
		||||
	c.mu.Lock()
 | 
			
		||||
	conns := c.conns
 | 
			
		||||
	c.conns = nil
 | 
			
		||||
	c.factory = nil
 | 
			
		||||
	c.ping = nil
 | 
			
		||||
	closeFun := c.close
 | 
			
		||||
	c.close = nil
 | 
			
		||||
	c.mu.Unlock()
 | 
			
		||||
 | 
			
		||||
	if conns == nil {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	close(conns)
 | 
			
		||||
	for wrapConn := range conns {
 | 
			
		||||
		//log.Printf("Type %v\n",reflect.TypeOf(wrapConn.conn))
 | 
			
		||||
		_ = closeFun(wrapConn.conn)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Len 连接池中已有的连接
 | 
			
		||||
func (c *channelPool) Len() int {
 | 
			
		||||
	return len(c.getConns())
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										55
									
								
								server/pkg/pool/config.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								server/pkg/pool/config.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,55 @@
 | 
			
		||||
package pool
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var ErrPoolClosed = errors.New("pool is closed")
 | 
			
		||||
 | 
			
		||||
// PoolConfig 连接池配置
 | 
			
		||||
type PoolConfig struct {
 | 
			
		||||
	MaxConns            int           // 最大连接数
 | 
			
		||||
	IdleTimeout         time.Duration // 空闲连接超时时间
 | 
			
		||||
	WaitTimeout         time.Duration // 获取连接超时时间
 | 
			
		||||
	HealthCheckInterval time.Duration // 健康检查间隔
 | 
			
		||||
	OnPoolClose         func() error  // 连接池关闭时的回调
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Option 函数类型,用于配置 Pool
 | 
			
		||||
type Option func(*PoolConfig)
 | 
			
		||||
 | 
			
		||||
// WithMaxConns 设置最大连接数
 | 
			
		||||
func WithMaxConns(maxConns int) Option {
 | 
			
		||||
	return func(c *PoolConfig) {
 | 
			
		||||
		c.MaxConns = maxConns
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithIdleTimeout 设置空闲超时
 | 
			
		||||
func WithIdleTimeout(timeout time.Duration) Option {
 | 
			
		||||
	return func(c *PoolConfig) {
 | 
			
		||||
		c.IdleTimeout = timeout
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithWaitTimeout 设置等待超时
 | 
			
		||||
func WithWaitTimeout(timeout time.Duration) Option {
 | 
			
		||||
	return func(c *PoolConfig) {
 | 
			
		||||
		c.WaitTimeout = timeout
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithHealthCheckInterval 设置健康检查间隔
 | 
			
		||||
func WithHealthCheckInterval(interval time.Duration) Option {
 | 
			
		||||
	return func(c *PoolConfig) {
 | 
			
		||||
		c.HealthCheckInterval = interval
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WithOnPoolClose 设置连接池关闭回调
 | 
			
		||||
func WithOnPoolClose(fn func() error) Option {
 | 
			
		||||
	return func(c *PoolConfig) {
 | 
			
		||||
		c.OnPoolClose = fn
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										76
									
								
								server/pkg/pool/group.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										76
									
								
								server/pkg/pool/group.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,76 @@
 | 
			
		||||
package pool
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
 | 
			
		||||
	"golang.org/x/sync/singleflight"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type PoolGroup[T Conn] struct {
 | 
			
		||||
	poolGroup   map[string]Pool[T]
 | 
			
		||||
	createGroup singleflight.Group
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPoolGroup[T Conn]() *PoolGroup[T] {
 | 
			
		||||
	return &PoolGroup[T]{
 | 
			
		||||
		poolGroup:   make(map[string]Pool[T]),
 | 
			
		||||
		createGroup: singleflight.Group{},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pg *PoolGroup[T]) GetOrCreate(
 | 
			
		||||
	key string,
 | 
			
		||||
	poolFactory func() Pool[T],
 | 
			
		||||
	opts ...Option,
 | 
			
		||||
) (Pool[T], error) {
 | 
			
		||||
	if p, ok := pg.poolGroup[key]; ok {
 | 
			
		||||
		return p, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	v, err, _ := pg.createGroup.Do(key, func() (any, error) {
 | 
			
		||||
		logx.Infof("pool group - create pool, key: %s", key)
 | 
			
		||||
		p := poolFactory()
 | 
			
		||||
		pg.poolGroup[key] = p
 | 
			
		||||
		return p, nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return v.(Pool[T]), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetChanPool 获取或创建 ChannelPool 类型连接池
 | 
			
		||||
func (pg *PoolGroup[T]) GetChanPool(key string, factory func() (T, error), opts ...Option) (Pool[T], error) {
 | 
			
		||||
	return pg.GetOrCreate(key, func() Pool[T] {
 | 
			
		||||
		return NewChannelPool(factory, opts...)
 | 
			
		||||
	}, opts...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetCachePool 获取或创建 CachePool 类型连接池
 | 
			
		||||
func (pg *PoolGroup[T]) GetCachePool(key string, factory func() (T, error), opts ...Option) (Pool[T], error) {
 | 
			
		||||
	return pg.GetOrCreate(key, func() Pool[T] {
 | 
			
		||||
		return NewCachePool(factory, opts...)
 | 
			
		||||
	}, opts...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pg *PoolGroup[T]) Close(key string) error {
 | 
			
		||||
	if p, ok := pg.poolGroup[key]; ok {
 | 
			
		||||
		logx.Infof("pool group - close pool, key: %s", key)
 | 
			
		||||
		p.Close()
 | 
			
		||||
		pg.createGroup.Forget(key)
 | 
			
		||||
		delete(pg.poolGroup, key)
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pg *PoolGroup[T]) CloseAll() {
 | 
			
		||||
	for key := range pg.poolGroup {
 | 
			
		||||
		pg.Close(key)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pg *PoolGroup[T]) AllPool() map[string]Pool[T] {
 | 
			
		||||
	return pg.poolGroup
 | 
			
		||||
}
 | 
			
		||||
@@ -1,21 +1,27 @@
 | 
			
		||||
package pool
 | 
			
		||||
 | 
			
		||||
import "errors"
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	//ErrClosed 连接池已经关闭Error
 | 
			
		||||
	ErrClosed = errors.New("pool is closed")
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// Pool 基本方法
 | 
			
		||||
type Pool interface {
 | 
			
		||||
	Get() (interface{}, error)
 | 
			
		||||
// Conn 连接接口
 | 
			
		||||
// 连接池的连接必须实现 Conn 接口
 | 
			
		||||
type Conn interface {
 | 
			
		||||
	// Close 关闭连接
 | 
			
		||||
	Close() error
 | 
			
		||||
 | 
			
		||||
	Put(interface{}) error
 | 
			
		||||
 | 
			
		||||
	Close(interface{}) error
 | 
			
		||||
 | 
			
		||||
	Release()
 | 
			
		||||
 | 
			
		||||
	Len() int
 | 
			
		||||
	// Ping 检查连接是否有效
 | 
			
		||||
	Ping() error
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Pool 连接池接口
 | 
			
		||||
type Pool[T Conn] interface {
 | 
			
		||||
	// 核心方法
 | 
			
		||||
	Get(ctx context.Context) (T, error)
 | 
			
		||||
	Put(T) error
 | 
			
		||||
	Close()
 | 
			
		||||
 | 
			
		||||
	// 管理方法
 | 
			
		||||
	Resize(int)       // 动态调整大小
 | 
			
		||||
	Stats() PoolStats // 获取统计信息
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user