mirror of
https://gitee.com/dromara/mayfly-go
synced 2026-01-06 14:45:48 +08:00
reafctor: pool
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package mcm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -9,7 +10,6 @@ import (
|
||||
"mayfly-go/pkg/utils/netx"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
@@ -20,7 +20,7 @@ var (
|
||||
// 所有检测ssh隧道机器是否被使用的函数
|
||||
checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc
|
||||
|
||||
tunnelPool = make(map[int]pool.Pool)
|
||||
tunnelPool = make(map[int]pool.Pool[*SshTunnelMachine])
|
||||
)
|
||||
|
||||
// 检查ssh隧道机器是否有被使用
|
||||
@@ -43,11 +43,36 @@ type SshTunnelMachine struct {
|
||||
tunnels map[string]*Tunnel // 隧道id -> 隧道
|
||||
}
|
||||
|
||||
/******************* pool.Conn impl *******************/
|
||||
|
||||
func (stm *SshTunnelMachine) Ping() error {
|
||||
_, _, err := stm.SshClient.Conn.SendRequest("ping", true, nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (stm *SshTunnelMachine) Close() error {
|
||||
stm.mutex.Lock()
|
||||
defer stm.mutex.Unlock()
|
||||
|
||||
for id, tunnel := range stm.tunnels {
|
||||
if tunnel != nil {
|
||||
tunnel.Close()
|
||||
delete(stm.tunnels, id)
|
||||
}
|
||||
}
|
||||
|
||||
if stm.SshClient != nil {
|
||||
logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId)
|
||||
err := stm.SshClient.Close()
|
||||
if err != nil {
|
||||
logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error())
|
||||
}
|
||||
}
|
||||
delete(tunnelPool, stm.machineId)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (exposedIp string, exposedPort int, err error) {
|
||||
stm.mutex.Lock()
|
||||
defer stm.mutex.Unlock()
|
||||
@@ -92,84 +117,44 @@ func (stm *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn,
|
||||
return stm.SshClient.Dial(network, addr)
|
||||
}
|
||||
|
||||
func (stm *SshTunnelMachine) Close() {
|
||||
stm.mutex.Lock()
|
||||
defer stm.mutex.Unlock()
|
||||
|
||||
for id, tunnel := range stm.tunnels {
|
||||
if tunnel != nil {
|
||||
tunnel.Close()
|
||||
delete(stm.tunnels, id)
|
||||
}
|
||||
}
|
||||
|
||||
if stm.SshClient != nil {
|
||||
logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId)
|
||||
err := stm.SshClient.Close()
|
||||
if err != nil {
|
||||
logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error())
|
||||
}
|
||||
}
|
||||
delete(tunnelPool, stm.machineId)
|
||||
}
|
||||
|
||||
func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error)) (pool.Pool, error) {
|
||||
func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error)) (pool.Pool[*SshTunnelMachine], error) {
|
||||
// 获取连接池,如果没有,则创建一个
|
||||
if p, ok := tunnelPool[machineId]; !ok {
|
||||
var err error
|
||||
p, err = pool.NewChannelPool(&pool.Config{
|
||||
InitialCap: 1, //资源池初始连接数
|
||||
MaxCap: 10, //最大空闲连接数
|
||||
MaxIdle: 10, //最大并发连接数
|
||||
IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效
|
||||
Factory: func() (interface{}, error) {
|
||||
mi, err := getMachine(uint64(machineId))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if mi == nil {
|
||||
return nil, errors.New("error get machine info")
|
||||
}
|
||||
sshClient, err := GetSshClient(mi, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stm := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[string]*Tunnel{}, mi: mi}
|
||||
logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port)
|
||||
if p, ok := tunnelPool[machineId]; ok {
|
||||
return p, nil
|
||||
}
|
||||
|
||||
return stm, err
|
||||
},
|
||||
Close: func(v interface{}) error {
|
||||
v.(*SshTunnelMachine).Close()
|
||||
return nil
|
||||
},
|
||||
Ping: func(v interface{}) error {
|
||||
return v.(*SshTunnelMachine).Ping()
|
||||
},
|
||||
})
|
||||
p := pool.NewChannelPool(func() (*SshTunnelMachine, error) {
|
||||
mi, err := getMachine(uint64(machineId))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tunnelPool[machineId] = p
|
||||
return p, nil
|
||||
} else {
|
||||
return p, nil
|
||||
}
|
||||
if mi == nil {
|
||||
return nil, errors.New("error get machine info")
|
||||
}
|
||||
sshClient, err := GetSshClient(mi, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stm := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[string]*Tunnel{}, mi: mi}
|
||||
logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port)
|
||||
|
||||
return stm, err
|
||||
}, pool.WithOnPoolClose(func() error {
|
||||
delete(tunnelPool, machineId)
|
||||
return nil
|
||||
}))
|
||||
tunnelPool[machineId] = p
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// 获取ssh隧道机器,方便统一管理充当ssh隧道的机器,避免创建多个ssh client
|
||||
func GetSshTunnelMachine(machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) {
|
||||
func GetSshTunnelMachine(ctx context.Context, machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) {
|
||||
p, err := getTunnelPool(machineId, getMachine)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 从连接池中获取一个可用的连接
|
||||
c, err := p.Get()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c.(*SshTunnelMachine), nil
|
||||
return p.Get(ctx)
|
||||
}
|
||||
|
||||
// 关闭ssh隧道机器的指定隧道
|
||||
|
||||
Reference in New Issue
Block a user