reafctor: pool

This commit is contained in:
meilin.huang
2025-05-22 23:29:50 +08:00
parent 142bbd265d
commit 778cb7f4de
50 changed files with 1146 additions and 874 deletions

View File

@@ -1,6 +1,7 @@
package mcm
import (
"context"
"errors"
"fmt"
"io"
@@ -9,7 +10,6 @@ import (
"mayfly-go/pkg/utils/netx"
"net"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
@@ -20,7 +20,7 @@ var (
// 所有检测ssh隧道机器是否被使用的函数
checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc
tunnelPool = make(map[int]pool.Pool)
tunnelPool = make(map[int]pool.Pool[*SshTunnelMachine])
)
// 检查ssh隧道机器是否有被使用
@@ -43,11 +43,36 @@ type SshTunnelMachine struct {
tunnels map[string]*Tunnel // 隧道id -> 隧道
}
/******************* pool.Conn impl *******************/
func (stm *SshTunnelMachine) Ping() error {
_, _, err := stm.SshClient.Conn.SendRequest("ping", true, nil)
return err
}
func (stm *SshTunnelMachine) Close() error {
stm.mutex.Lock()
defer stm.mutex.Unlock()
for id, tunnel := range stm.tunnels {
if tunnel != nil {
tunnel.Close()
delete(stm.tunnels, id)
}
}
if stm.SshClient != nil {
logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId)
err := stm.SshClient.Close()
if err != nil {
logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error())
}
}
delete(tunnelPool, stm.machineId)
return nil
}
func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (exposedIp string, exposedPort int, err error) {
stm.mutex.Lock()
defer stm.mutex.Unlock()
@@ -92,84 +117,44 @@ func (stm *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn,
return stm.SshClient.Dial(network, addr)
}
func (stm *SshTunnelMachine) Close() {
stm.mutex.Lock()
defer stm.mutex.Unlock()
for id, tunnel := range stm.tunnels {
if tunnel != nil {
tunnel.Close()
delete(stm.tunnels, id)
}
}
if stm.SshClient != nil {
logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId)
err := stm.SshClient.Close()
if err != nil {
logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error())
}
}
delete(tunnelPool, stm.machineId)
}
func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error)) (pool.Pool, error) {
func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error)) (pool.Pool[*SshTunnelMachine], error) {
// 获取连接池,如果没有,则创建一个
if p, ok := tunnelPool[machineId]; !ok {
var err error
p, err = pool.NewChannelPool(&pool.Config{
InitialCap: 1, //资源池初始连接数
MaxCap: 10, //最大空闲连接数
MaxIdle: 10, //最大并发连接数
IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效
Factory: func() (interface{}, error) {
mi, err := getMachine(uint64(machineId))
if err != nil {
return nil, err
}
if mi == nil {
return nil, errors.New("error get machine info")
}
sshClient, err := GetSshClient(mi, nil)
if err != nil {
return nil, err
}
stm := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[string]*Tunnel{}, mi: mi}
logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port)
if p, ok := tunnelPool[machineId]; ok {
return p, nil
}
return stm, err
},
Close: func(v interface{}) error {
v.(*SshTunnelMachine).Close()
return nil
},
Ping: func(v interface{}) error {
return v.(*SshTunnelMachine).Ping()
},
})
p := pool.NewChannelPool(func() (*SshTunnelMachine, error) {
mi, err := getMachine(uint64(machineId))
if err != nil {
return nil, err
}
tunnelPool[machineId] = p
return p, nil
} else {
return p, nil
}
if mi == nil {
return nil, errors.New("error get machine info")
}
sshClient, err := GetSshClient(mi, nil)
if err != nil {
return nil, err
}
stm := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[string]*Tunnel{}, mi: mi}
logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port)
return stm, err
}, pool.WithOnPoolClose(func() error {
delete(tunnelPool, machineId)
return nil
}))
tunnelPool[machineId] = p
return p, nil
}
// 获取ssh隧道机器方便统一管理充当ssh隧道的机器避免创建多个ssh client
func GetSshTunnelMachine(machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) {
func GetSshTunnelMachine(ctx context.Context, machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) {
p, err := getTunnelPool(machineId, getMachine)
if err != nil {
return nil, err
}
// 从连接池中获取一个可用的连接
c, err := p.Get()
if err != nil {
return nil, err
}
return c.(*SshTunnelMachine), nil
return p.Get(ctx)
}
// 关闭ssh隧道机器的指定隧道