Files
mayfly-go/server/internal/machine/mcm/sshtunnel.go

236 lines
6.5 KiB
Go
Raw Normal View History

2023-10-30 17:34:56 +08:00
package mcm
import (
2025-05-22 23:29:50 +08:00
"context"
"errors"
"fmt"
"io"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/pool"
"mayfly-go/pkg/utils/netx"
"net"
"sync"
"golang.org/x/crypto/ssh"
)
var (
mutex sync.Mutex
// 所有检测ssh隧道机器是否被使用的函数
checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc
2025-05-22 23:29:50 +08:00
tunnelPool = make(map[int]pool.Pool[*SshTunnelMachine])
)
// CheckSshTunnelMachineHasUseFunc checks whether an SSH tunnel machine is
// currently in use. The int argument is presumably the tunnel machine id —
// confirm against the registered implementations.
type CheckSshTunnelMachineHasUseFunc func(int) bool
// AddCheckSshTunnelMachineUseFunc registers checkFunc in the package-level
// list of "is this SSH tunnel machine in use" callbacks.
//
// The list needs no explicit initialization: append allocates the backing
// array when the slice is still nil.
func AddCheckSshTunnelMachineUseFunc(checkFunc CheckSshTunnelMachineHasUseFunc) {
	checkSshTunnelMachineHasUseFuncs = append(checkSshTunnelMachineHasUseFuncs, checkFunc)
}
// ssh隧道机器
type SshTunnelMachine struct {
mi *MachineInfo
machineId int // 隧道机器id
SshClient *ssh.Client
mutex sync.Mutex
2023-12-20 23:01:51 +08:00
tunnels map[string]*Tunnel // 隧道id -> 隧道
}
2025-05-22 23:29:50 +08:00
/******************* pool.Conn impl *******************/
// Ping sends a "ping" global request over the SSH connection, asking for a
// reply, and returns whatever error the request produced (nil on success).
func (stm *SshTunnelMachine) Ping() error {
	// SendRequest is promoted from the ssh.Conn embedded in ssh.Client.
	if _, _, err := stm.SshClient.SendRequest("ping", true, nil); err != nil {
		return err
	}
	return nil
}
2025-05-22 23:29:50 +08:00
// Close shuts down every tunnel opened through this machine, closes the SSH
// client and removes the machine's entry from the package-level tunnelPool.
//
// Closing is best-effort: a failure to close the SSH client is logged and
// Close still returns nil.
func (stm *SshTunnelMachine) Close() error {
	stm.mutex.Lock()
	defer stm.mutex.Unlock()

	// Deleting entries while ranging over a map is safe in Go. Every entry is
	// dropped, including a nil one, so no stale key survives the close.
	for id, tunnel := range stm.tunnels {
		if tunnel != nil {
			tunnel.Close()
		}
		delete(stm.tunnels, id)
	}

	if stm.SshClient != nil {
		logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId)
		if err := stm.SshClient.Close(); err != nil {
			logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error())
		}
	}

	// tunnelPool is shared package state, so the write is done under the
	// package-level mutex instead of racing with other goroutines.
	// NOTE(review): this unregisters the whole pool when a single pooled
	// connection is closed — confirm that is intended.
	mutex.Lock()
	delete(tunnelPool, stm.machineId)
	mutex.Unlock()
	return nil
}
2023-12-20 23:01:51 +08:00
// OpenSshTunnel exposes ip:port, reached through this SSH machine, on a local
// loopback port and returns the local host and port to connect to.
//
// Tunnels are keyed by id: when a tunnel with the same id already exists its
// local address is returned and nothing new is opened. Otherwise a listener is
// bound on 127.0.0.1 at a port obtained from netx.GetAvailablePort, the tunnel
// is registered, and its accept loop is started in a new goroutine.
func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (exposedIp string, exposedPort int, err error) {
	stm.mutex.Lock()
	defer stm.mutex.Unlock()

	// Reuse the tunnel already registered under this id.
	// FIXME: later switch to pooled connections that check availability every 60 seconds.
	if existing := stm.tunnels[id]; existing != nil {
		return existing.localHost, existing.localPort, nil
	}

	const loopback = "127.0.0.1"

	freePort, err := netx.GetAvailablePort()
	if err != nil {
		return "", 0, err
	}

	ln, err := net.Listen("tcp", fmt.Sprintf("%s:%d", loopback, freePort))
	if err != nil {
		return "", 0, err
	}

	created := &Tunnel{
		id:         id,
		machineId:  stm.machineId,
		localHost:  loopback,
		localPort:  freePort,
		remoteHost: ip,
		remotePort: port,
		listener:   ln,
	}
	stm.tunnels[created.id] = created

	// The accept loop runs until the listener is closed.
	go created.Open(stm.SshClient)

	return loopback, freePort, nil
}
// GetDialConn dials addr on the given network through the SSH client and
// returns the resulting connection. The machine's mutex is held for the whole
// dial.
func (stm *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn, error) {
	stm.mutex.Lock()
	defer stm.mutex.Unlock()

	conn, dialErr := stm.SshClient.Dial(network, addr)
	return conn, dialErr
}
2025-05-22 23:29:50 +08:00
func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error)) (pool.Pool[*SshTunnelMachine], error) {
// 获取连接池,如果没有,则创建一个
if p, ok := tunnelPool[machineId]; ok {
return p, nil
}
2025-05-22 23:29:50 +08:00
p := pool.NewChannelPool(func() (*SshTunnelMachine, error) {
mi, err := getMachine(uint64(machineId))
2022-07-24 15:37:13 +08:00
if err != nil {
2025-05-22 23:29:50 +08:00
return nil, err
2022-07-24 15:37:13 +08:00
}
2025-05-22 23:29:50 +08:00
if mi == nil {
return nil, errors.New("error get machine info")
}
sshClient, err := GetSshClient(mi, nil)
if err != nil {
return nil, err
}
2025-05-22 23:29:50 +08:00
stm := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[string]*Tunnel{}, mi: mi}
logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port)
return stm, err
}, pool.WithOnPoolClose(func() error {
delete(tunnelPool, machineId)
return nil
}))
tunnelPool[machineId] = p
return p, nil
}
// 获取ssh隧道机器方便统一管理充当ssh隧道的机器避免创建多个ssh client
2025-05-22 23:29:50 +08:00
func GetSshTunnelMachine(ctx context.Context, machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) {
p, err := getTunnelPool(machineId, getMachine)
if err != nil {
return nil, err
}
// 从连接池中获取一个可用的连接
2025-05-22 23:29:50 +08:00
return p.Get(ctx)
}
// CloseSshTunnelMachine is meant to close the tunnel tunnelId on the SSH
// tunnel machine machineId.
//
// NOTE: the whole body is commented out, so this function currently does
// nothing and both arguments are ignored. The disabled code below looks the
// machine up in mcIdPool, which is not declared in the visible part of this
// file.
func CloseSshTunnelMachine(machineId uint64, tunnelId string) {
//sshTunnelMachine := mcIdPool[machineId]
//if sshTunnelMachine == nil {
// return
//}
//
//sshTunnelMachine.mutex.Lock()
//defer sshTunnelMachine.mutex.Unlock()
//t := sshTunnelMachine.tunnels[tunnelId]
//if t != nil {
// t.Close()
// delete(sshTunnelMachine.tunnels, tunnelId)
//}
}
type Tunnel struct {
2023-12-20 23:01:51 +08:00
id string // 唯一标识
machineId int // 隧道机器id
localHost string // 本地监听地址
localPort int // 本地端口
remoteHost string // 远程连接地址
remotePort int // 远程端口
listener net.Listener
localConnections []net.Conn
remoteConnections []net.Conn
}
func (r *Tunnel) Open(sshClient *ssh.Client) {
localAddr := fmt.Sprintf("%s:%d", r.localHost, r.localPort)
for {
logx.Debugf("隧道 %v 等待客户端访问 %v", r.id, localAddr)
localConn, err := r.listener.Accept()
if err != nil {
logx.Debugf("隧道 %v 接受连接失败 %v, 退出循环", r.id, err.Error())
logx.Debug("-------------------------------------------------")
return
}
r.localConnections = append(r.localConnections, localConn)
logx.Debugf("隧道 %v 新增本地连接 %v", r.id, localConn.RemoteAddr().String())
remoteAddr := fmt.Sprintf("%s:%d", r.remoteHost, r.remotePort)
logx.Debugf("隧道 %v 连接远程地址 %v ...", r.id, remoteAddr)
remoteConn, err := sshClient.Dial("tcp", remoteAddr)
if err != nil {
logx.Debugf("隧道 %v 连接远程地址 %v, 退出循环", r.id, err.Error())
logx.Debug("-------------------------------------------------")
return
}
r.remoteConnections = append(r.remoteConnections, remoteConn)
logx.Debugf("隧道 %v 连接远程主机成功", r.id)
go r.copyConn(localConn, remoteConn)
go r.copyConn(remoteConn, localConn)
logx.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr)
logx.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~")
}
}
func (r *Tunnel) Close() {
for i := range r.localConnections {
_ = r.localConnections[i].Close()
}
r.localConnections = nil
for i := range r.remoteConnections {
_ = r.remoteConnections[i].Close()
}
r.remoteConnections = nil
_ = r.listener.Close()
2023-12-20 23:01:51 +08:00
logx.Debugf("隧道 %s 监听器关闭", r.id)
}
func (r *Tunnel) copyConn(writer, reader net.Conn) {
_, _ = io.Copy(writer, reader)
}