Files
mayfly-go/server/internal/machine/mcm/sshtunnel.go
2025-05-23 17:26:12 +08:00

219 lines
6.1 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mcm
import (
"context"
"errors"
"fmt"
"io"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/pool"
"mayfly-go/pkg/utils/netx"
"net"
"sync"
"golang.org/x/crypto/ssh"
)
// checkSshTunnelMachineHasUseFuncs holds every registered function that checks
// whether an ssh tunnel machine is still in use.
var checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc

// tunnelPoolGroup caches one pool of *SshTunnelMachine per tunnel machine
// (keyed "machine-tunnel-<id>" by GetSshTunnelMachine), so tunnels routed
// through the same machine reuse its ssh client.
var tunnelPoolGroup = pool.NewPoolGroup[*SshTunnelMachine]()
// CheckSshTunnelMachineHasUseFunc checks whether an ssh tunnel machine is
// still in use; presumably it returns true while the machine is in use.
// NOTE(review): the int argument is presumably the tunnel machine id — confirm with callers.
type CheckSshTunnelMachineHasUseFunc func(int) bool
// AddCheckSshTunnelMachineUseFunc registers checkFunc as one of the functions
// used to check whether an ssh tunnel machine is still in use.
//
// A nil checkFunc is ignored, because calling it later would panic.
// NOTE(review): registration is not synchronized; presumably it only happens
// during initialization — confirm with callers.
func AddCheckSshTunnelMachineUseFunc(checkFunc CheckSshTunnelMachineHasUseFunc) {
	if checkFunc == nil {
		return
	}
	// append allocates the backing array on first use, so no explicit make is needed.
	checkSshTunnelMachineHasUseFuncs = append(checkSshTunnelMachineHasUseFuncs, checkFunc)
}
// SshTunnelMachine is a machine used as an ssh tunnel (jump host). A single
// ssh client to it is shared by every local-port tunnel opened through it.
type SshTunnelMachine struct {
	mi        *MachineInfo // connection info of the tunnel machine
	machineId int          // tunnel machine id
	SshClient *ssh.Client  // ssh client connected to the tunnel machine
	// mutex is locked by Close, OpenSshTunnel and GetDialConn; it guards tunnels.
	mutex   sync.Mutex
	tunnels map[string]*Tunnel // tunnel id -> tunnel
}
/******************* pool.Conn impl *******************/
// Ping implements pool.Conn. It checks that the ssh connection is alive by
// sending a global "ping" request and waiting for the server's reply; only
// transport errors are reported, the reply payload is discarded.
func (stm *SshTunnelMachine) Ping() error {
	conn := stm.SshClient.Conn
	if _, _, err := conn.SendRequest("ping", true, nil); err != nil {
		return err
	}
	return nil
}
// Close implements pool.Conn. It closes and forgets every tunnel opened on this
// machine, then closes the shared ssh client. A failure to close the client is
// logged rather than returned, so Close always reports nil.
func (stm *SshTunnelMachine) Close() error {
	stm.mutex.Lock()
	defer stm.mutex.Unlock()

	for tunnelID, t := range stm.tunnels {
		if t == nil {
			continue
		}
		t.Close()
		// Deleting the current key while ranging over a map is permitted in Go.
		delete(stm.tunnels, tunnelID)
	}

	if stm.SshClient == nil {
		return nil
	}
	logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId)
	if closeErr := stm.SshClient.Close(); closeErr != nil {
		logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, closeErr.Error())
	}
	return nil
}
// OpenSshTunnel exposes ip:port on a local 127.0.0.1 port by forwarding through
// this machine's ssh client, and returns the local host and port to connect to.
//
// Tunnels are cached by id: if a tunnel with this id already exists, its local
// address is returned unchanged. Otherwise a free local port is chosen,
// a listener is opened on it, and the tunnel is served in a new goroutine.
func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (exposedIp string, exposedPort int, err error) {
	stm.mutex.Lock()
	defer stm.mutex.Unlock()

	// Reuse the tunnel already registered under this id.
	// FIXME: later switch to pooled connections with a periodic 60s availability check.
	if existing := stm.tunnels[id]; existing != nil {
		return existing.localHost, existing.localPort, nil
	}

	freePort, portErr := netx.GetAvailablePort()
	if portErr != nil {
		return "", 0, portErr
	}
	const host = "127.0.0.1"
	ln, listenErr := net.Listen("tcp", fmt.Sprintf("%s:%d", host, freePort))
	if listenErr != nil {
		return "", 0, listenErr
	}

	t := &Tunnel{
		id:         id,
		machineId:  stm.machineId,
		localHost:  host,
		localPort:  freePort,
		remoteHost: ip,
		remotePort: port,
		listener:   ln,
	}
	go t.Open(stm.SshClient)
	stm.tunnels[t.id] = t
	return host, freePort, nil
}
// GetDialConn dials addr over network through this machine's ssh connection
// and returns the resulting connection.
//
// The mutex is held only while reading SshClient, not across the dial itself.
// ssh.Client.Dial is safe for concurrent use, and holding the lock for a network
// round trip would serialize every dial and block Close/OpenSshTunnel until a
// slow or hung dial returns.
func (stm *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn, error) {
	stm.mutex.Lock()
	client := stm.SshClient
	stm.mutex.Unlock()

	// Report a missing client instead of panicking on a nil pointer.
	if client == nil {
		return nil, fmt.Errorf("ssh tunnel machine [%d] has no ssh client", stm.machineId)
	}
	return client.Dial(network, addr)
}
// GetSshTunnelMachine returns an SshTunnelMachine for machineId from a pool
// cached per machine. All machines acting as ssh tunnels are managed in one
// place, so repeated calls do not each create a new ssh client.
//
// getMachine loads the machine's connection info; it is called from the pool's
// factory, so presumably only when a new connection must be created (depends
// on pool.GetCachePool semantics — confirm). A nil MachineInfo is reported as
// an error.
func GetSshTunnelMachine(ctx context.Context, machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) {
	poolKey := fmt.Sprintf("machine-tunnel-%d", machineId)
	// Named tunnelPool (not "pool") so the imported pool package is not shadowed.
	tunnelPool, err := tunnelPoolGroup.GetCachePool(poolKey, func() (*SshTunnelMachine, error) {
		mi, err := getMachine(uint64(machineId))
		if err != nil {
			return nil, err
		}
		if mi == nil {
			return nil, errors.New("error get machine info")
		}
		sshClient, err := GetSshClient(mi, nil)
		if err != nil {
			return nil, err
		}
		stm := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[string]*Tunnel{}, mi: mi}
		logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port)
		return stm, nil
	})
	if err != nil {
		return nil, err
	}
	// Borrow an available connection from the pool.
	return tunnelPool.Get(ctx)
}
// CloseSshTunnelMachine is intended to close the tunnel tunnelId on the ssh
// tunnel machine machineId.
//
// NOTE(review): the whole implementation below is commented out, so this
// function is currently a no-op and closes nothing; callers relying on it leak
// the tunnel's listener. The disabled code refers to mcIdPool, which is not
// declared in this file (tunnel machines are obtained via tunnelPoolGroup in
// GetSshTunnelMachine), so it must be rewritten against the pool before being
// re-enabled.
func CloseSshTunnelMachine(machineId uint64, tunnelId string) {
	//sshTunnelMachine := mcIdPool[machineId]
	//if sshTunnelMachine == nil {
	//	return
	//}
	//
	//sshTunnelMachine.mutex.Lock()
	//defer sshTunnelMachine.mutex.Unlock()
	//t := sshTunnelMachine.tunnels[tunnelId]
	//if t != nil {
	//	t.Close()
	//	delete(sshTunnelMachine.tunnels, tunnelId)
	//}
}
// Tunnel forwards every TCP connection accepted on a local listener to a fixed
// remote address, opening the remote side through an ssh client.
type Tunnel struct {
	id                string // unique identifier
	machineId         int    // id of the ssh tunnel machine
	localHost         string // local listen address
	localPort         int    // local listen port
	remoteHost        string // remote host to connect to
	remotePort        int    // remote port to connect to
	listener          net.Listener
	localConnections  []net.Conn // local connections currently being forwarded
	remoteConnections []net.Conn // remote connections currently being forwarded

	// mu guards localConnections, remoteConnections and closed. They are updated
	// by the goroutines started from Open and read by Close.
	mu     sync.Mutex
	closed bool // set by Close; connection pairs accepted afterwards are rejected
}

// Open serves the tunnel. It accepts connections on r.listener and, for each one,
// dials remoteHost:remotePort through sshClient and forwards data both ways.
// It blocks until Accept fails (e.g. after Close closes the listener), so
// callers run it in its own goroutine.
//
// A failed remote dial drops only the affected client connection and keeps
// accepting. Exiting instead would leave the listener open with nobody
// accepting on it, and the cached tunnel would hang every later client.
func (r *Tunnel) Open(sshClient *ssh.Client) {
	localAddr := fmt.Sprintf("%s:%d", r.localHost, r.localPort)
	remoteAddr := fmt.Sprintf("%s:%d", r.remoteHost, r.remotePort)
	for {
		logx.Debugf("隧道 %v 等待客户端访问 %v", r.id, localAddr)
		localConn, err := r.listener.Accept()
		if err != nil {
			logx.Debugf("隧道 %v 接受连接失败 %v, 退出循环", r.id, err.Error())
			logx.Debug("-------------------------------------------------")
			return
		}
		logx.Debugf("隧道 %v 新增本地连接 %v", r.id, localConn.RemoteAddr().String())

		logx.Debugf("隧道 %v 连接远程地址 %v ...", r.id, remoteAddr)
		remoteConn, err := sshClient.Dial("tcp", remoteAddr)
		if err != nil {
			logx.Debugf("隧道 %v 连接远程地址 %v 失败: %v", r.id, remoteAddr, err.Error())
			// Close the client connection so it does not hang on a pair nobody serves.
			_ = localConn.Close()
			continue
		}
		if !r.track(localConn, remoteConn) {
			// Close ran while we were dialing: discard this pair and stop serving.
			_ = localConn.Close()
			_ = remoteConn.Close()
			return
		}
		logx.Debugf("隧道 %v 连接远程主机成功", r.id)
		go r.pipe(localConn, remoteConn)
		logx.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr)
		logx.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~")
	}
}

// Close stops the tunnel. It marks the tunnel closed, closes every tracked local
// and remote connection, then closes the listener, which makes the Accept in
// Open fail and its loop return. Close errors are ignored (best-effort teardown).
func (r *Tunnel) Close() {
	r.mu.Lock()
	r.closed = true
	locals, remotes := r.localConnections, r.remoteConnections
	r.localConnections, r.remoteConnections = nil, nil
	r.mu.Unlock()

	// Close outside the lock: pipe goroutines woken by these closes call untrack,
	// which needs the mutex.
	for _, c := range locals {
		_ = c.Close()
	}
	for _, c := range remotes {
		_ = c.Close()
	}
	_ = r.listener.Close()
	logx.Debugf("隧道 %s 监听器关闭", r.id)
}

// copyConn copies from reader to writer until either side fails or reaches
// EOF, then closes both connections. Closing both unblocks the copy running in
// the opposite direction, so neither goroutine nor connection is leaked.
func (r *Tunnel) copyConn(writer, reader net.Conn) {
	// Copy errors are expected when a peer or Close tears the connection down.
	_, _ = io.Copy(writer, reader)
	_ = writer.Close()
	_ = reader.Close()
}

// pipe forwards data in both directions between localConn and remoteConn until
// either direction ends, then stops tracking the pair.
func (r *Tunnel) pipe(localConn, remoteConn net.Conn) {
	done := make(chan struct{})
	go func() {
		defer close(done)
		r.copyConn(localConn, remoteConn) // remote -> local
	}()
	r.copyConn(remoteConn, localConn) // local -> remote
	<-done
	r.untrack(localConn, remoteConn)
}

// track registers an active connection pair so Close can tear it down.
// It reports false, without registering, if the tunnel is already closed.
func (r *Tunnel) track(localConn, remoteConn net.Conn) bool {
	r.mu.Lock()
	defer r.mu.Unlock()
	if r.closed {
		return false
	}
	r.localConnections = append(r.localConnections, localConn)
	r.remoteConnections = append(r.remoteConnections, remoteConn)
	return true
}

// untrack forgets a finished connection pair so the tracking slices do not grow
// for the lifetime of the tunnel.
func (r *Tunnel) untrack(localConn, remoteConn net.Conn) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.localConnections = removeTunnelConn(r.localConnections, localConn)
	r.remoteConnections = removeTunnelConn(r.remoteConnections, remoteConn)
}

// removeTunnelConn removes the first occurrence of target from conns, preserving
// order. It clears the vacated tail slot so the backing array does not retain it.
func removeTunnelConn(conns []net.Conn, target net.Conn) []net.Conn {
	for i, c := range conns {
		if c != target {
			continue
		}
		last := len(conns) - 1
		copy(conns[i:], conns[i+1:])
		conns[last] = nil
		return conns[:last]
	}
	return conns
}