mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 00:10:25 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			238 lines
		
	
	
		
			6.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			238 lines
		
	
	
		
			6.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package mcm
 | 
						||
 | 
						||
import (
 | 
						||
	"context"
 | 
						||
	"errors"
 | 
						||
	"fmt"
 | 
						||
	"io"
 | 
						||
	"mayfly-go/pkg/logx"
 | 
						||
	"mayfly-go/pkg/pool"
 | 
						||
	"mayfly-go/pkg/utils/netx"
 | 
						||
	"net"
 | 
						||
	"sync"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"golang.org/x/crypto/ssh"
 | 
						||
)
 | 
						||
 | 
						||
// type SshTunnelAble interface {
 | 
						||
// 	GetSshTunnelMachineId() int
 | 
						||
// }
 | 
						||
 | 
						||
var (
 | 
						||
	// 所有检测ssh隧道机器是否被使用的函数
 | 
						||
	checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc
 | 
						||
 | 
						||
	tunnelPoolGroup = pool.NewPoolGroup[*SshTunnelMachine]()
 | 
						||
)
 | 
						||
 | 
						||
// CheckSshTunnelMachineHasUseFunc reports whether the ssh tunnel machine with
// the given machine id is still in use.
type CheckSshTunnelMachineHasUseFunc func(machineId int) bool
 | 
						||
 | 
						||
// 添加ssh隧道机器检测是否使用函数
 | 
						||
func AddCheckSshTunnelMachineUseFunc(checkFunc CheckSshTunnelMachineHasUseFunc) {
 | 
						||
	if checkSshTunnelMachineHasUseFuncs == nil {
 | 
						||
		checkSshTunnelMachineHasUseFuncs = make([]CheckSshTunnelMachineHasUseFunc, 0)
 | 
						||
	}
 | 
						||
	checkSshTunnelMachineHasUseFuncs = append(checkSshTunnelMachineHasUseFuncs, checkFunc)
 | 
						||
}
 | 
						||
 | 
						||
// ssh隧道机器
 | 
						||
type SshTunnelMachine struct {
 | 
						||
	mi        *MachineInfo
 | 
						||
	machineId int // 隧道机器id
 | 
						||
	SshClient *ssh.Client
 | 
						||
	mutex     sync.Mutex
 | 
						||
	tunnels   map[string]*Tunnel // 隧道id -> 隧道
 | 
						||
}
 | 
						||
 | 
						||
/******************* pool.Conn impl *******************/
 | 
						||
 | 
						||
func (stm *SshTunnelMachine) Ping() error {
 | 
						||
	_, _, err := stm.SshClient.Conn.SendRequest("ping", true, nil)
 | 
						||
	return err
 | 
						||
}
 | 
						||
 | 
						||
func (stm *SshTunnelMachine) Close() error {
 | 
						||
	stm.mutex.Lock()
 | 
						||
	defer stm.mutex.Unlock()
 | 
						||
 | 
						||
	for id, tunnel := range stm.tunnels {
 | 
						||
		if tunnel != nil {
 | 
						||
			tunnel.Close()
 | 
						||
			delete(stm.tunnels, id)
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	if stm.SshClient != nil {
 | 
						||
		logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId)
 | 
						||
		err := stm.SshClient.Close()
 | 
						||
		if err != nil {
 | 
						||
			logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error())
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (exposedIp string, exposedPort int, err error) {
 | 
						||
	stm.mutex.Lock()
 | 
						||
	defer stm.mutex.Unlock()
 | 
						||
 | 
						||
	tunnel := stm.tunnels[id]
 | 
						||
	// 已存在该id隧道,则直接返回
 | 
						||
	if tunnel != nil {
 | 
						||
		// FIXME 后期改成池化连接,定时60秒检查连接可用性
 | 
						||
		return tunnel.localHost, tunnel.localPort, nil
 | 
						||
	}
 | 
						||
 | 
						||
	localPort, err := netx.GetAvailablePort()
 | 
						||
	if err != nil {
 | 
						||
		return "", 0, err
 | 
						||
	}
 | 
						||
 | 
						||
	localHost := "127.0.0.1"
 | 
						||
	localAddr := fmt.Sprintf("%s:%d", localHost, localPort)
 | 
						||
	listener, err := net.Listen("tcp", localAddr)
 | 
						||
	if err != nil {
 | 
						||
		return "", 0, err
 | 
						||
	}
 | 
						||
 | 
						||
	tunnel = &Tunnel{
 | 
						||
		id:         id,
 | 
						||
		machineId:  stm.machineId,
 | 
						||
		localHost:  localHost,
 | 
						||
		localPort:  localPort,
 | 
						||
		remoteHost: ip,
 | 
						||
		remotePort: port,
 | 
						||
		listener:   listener,
 | 
						||
	}
 | 
						||
	go tunnel.Open(stm.SshClient)
 | 
						||
	stm.tunnels[tunnel.id] = tunnel
 | 
						||
 | 
						||
	return localHost, localPort, nil
 | 
						||
}
 | 
						||
 | 
						||
func (stm *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn, error) {
 | 
						||
	stm.mutex.Lock()
 | 
						||
	defer stm.mutex.Unlock()
 | 
						||
	return stm.SshClient.Dial(network, addr)
 | 
						||
}
 | 
						||
 | 
						||
// 获取ssh隧道机器,方便统一管理充当ssh隧道的机器,避免创建多个ssh client
 | 
						||
func GetSshTunnelMachine(ctx context.Context, machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) {
 | 
						||
	pool, err := tunnelPoolGroup.GetCachePool(fmt.Sprintf("machine-tunnel-%d", machineId), func() (*SshTunnelMachine, error) {
 | 
						||
		mi, err := getMachine(uint64(machineId))
 | 
						||
		if err != nil {
 | 
						||
			return nil, err
 | 
						||
		}
 | 
						||
		if mi == nil {
 | 
						||
			return nil, errors.New("error get machine info")
 | 
						||
		}
 | 
						||
		sshClient, err := GetSshClient(mi, nil)
 | 
						||
		if err != nil {
 | 
						||
			return nil, err
 | 
						||
		}
 | 
						||
		stm := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[string]*Tunnel{}, mi: mi}
 | 
						||
		logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port)
 | 
						||
 | 
						||
		return stm, err
 | 
						||
	}, pool.WithIdleTimeout[*SshTunnelMachine](50*time.Minute), pool.WithOnConnClose(func(conn *SshTunnelMachine) error {
 | 
						||
		mid := int(conn.mi.Id)
 | 
						||
		logx.Debugf("periodically check if the ssh tunnel machine [%d] is still in use...", mid)
 | 
						||
 | 
						||
		for _, checkUseFunc := range checkSshTunnelMachineHasUseFuncs {
 | 
						||
			// 如果一个在使用则返回不关闭,不继续后续检查
 | 
						||
			if checkUseFunc(mid) {
 | 
						||
				return fmt.Errorf("ssh tunnel machine [%s] is still in use", conn.mi.Name)
 | 
						||
			}
 | 
						||
		}
 | 
						||
 | 
						||
		return nil
 | 
						||
	}))
 | 
						||
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
	// 从连接池中获取一个可用的连接
 | 
						||
	return pool.Get(ctx)
 | 
						||
}
 | 
						||
 | 
						||
// 关闭ssh隧道机器的指定隧道
 | 
						||
func CloseSshTunnelMachine(machineId uint64, tunnelId string) {
 | 
						||
	sshTunnelMachinePool, ok := tunnelPoolGroup.Get(fmt.Sprintf("machine-tunnel-%d", machineId))
 | 
						||
	if !ok {
 | 
						||
		return
 | 
						||
	}
 | 
						||
	sshTunnelMachine, err := sshTunnelMachinePool.Get(context.Background())
 | 
						||
	if err != nil {
 | 
						||
		return
 | 
						||
	}
 | 
						||
	t := sshTunnelMachine.tunnels[tunnelId]
 | 
						||
	if t != nil {
 | 
						||
		t.Close()
 | 
						||
		delete(sshTunnelMachine.tunnels, tunnelId)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// Tunnel forwards connections accepted on a local listener to a remote address
// reached through an ssh client.
//
// NOTE(review): Open appends to localConnections and remoteConnections while
// Close iterates and resets them, and Tunnel itself carries no lock — confirm
// that callers never run the two concurrently.
type Tunnel struct {
	id                string       // unique identifier of the tunnel
	machineId         int          // id of the tunnel machine
	localHost         string       // local listen host
	localPort         int          // local listen port
	remoteHost        string       // remote host to connect to
	remotePort        int          // remote port to connect to
	listener          net.Listener // local listener that Open accepts on
	localConnections  []net.Conn   // accepted local connections, closed by Close
	remoteConnections []net.Conn   // connections dialed to the remote address, closed by Close
}
 | 
						||
 | 
						||
func (r *Tunnel) Open(sshClient *ssh.Client) {
 | 
						||
	localAddr := fmt.Sprintf("%s:%d", r.localHost, r.localPort)
 | 
						||
 | 
						||
	for {
 | 
						||
		logx.Debugf("隧道 %v 等待客户端访问 %v", r.id, localAddr)
 | 
						||
		localConn, err := r.listener.Accept()
 | 
						||
		if err != nil {
 | 
						||
			logx.Debugf("隧道 %v 接受连接失败 %v, 退出循环", r.id, err.Error())
 | 
						||
			logx.Debug("-------------------------------------------------")
 | 
						||
			return
 | 
						||
		}
 | 
						||
		r.localConnections = append(r.localConnections, localConn)
 | 
						||
 | 
						||
		logx.Debugf("隧道 %v 新增本地连接 %v", r.id, localConn.RemoteAddr().String())
 | 
						||
		remoteAddr := fmt.Sprintf("%s:%d", r.remoteHost, r.remotePort)
 | 
						||
		logx.Debugf("隧道 %v 连接远程地址 %v ...", r.id, remoteAddr)
 | 
						||
		remoteConn, err := sshClient.Dial("tcp", remoteAddr)
 | 
						||
		if err != nil {
 | 
						||
			logx.Debugf("隧道 %v 连接远程地址 %v, 退出循环", r.id, err.Error())
 | 
						||
			logx.Debug("-------------------------------------------------")
 | 
						||
			return
 | 
						||
		}
 | 
						||
		r.remoteConnections = append(r.remoteConnections, remoteConn)
 | 
						||
 | 
						||
		logx.Debugf("隧道 %v 连接远程主机成功", r.id)
 | 
						||
		go r.copyConn(localConn, remoteConn)
 | 
						||
		go r.copyConn(remoteConn, localConn)
 | 
						||
		logx.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr)
 | 
						||
		logx.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~")
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
func (r *Tunnel) Close() {
 | 
						||
	for i := range r.localConnections {
 | 
						||
		_ = r.localConnections[i].Close()
 | 
						||
	}
 | 
						||
	r.localConnections = nil
 | 
						||
	for i := range r.remoteConnections {
 | 
						||
		_ = r.remoteConnections[i].Close()
 | 
						||
	}
 | 
						||
	r.remoteConnections = nil
 | 
						||
	_ = r.listener.Close()
 | 
						||
	logx.Debugf("隧道 %s 监听器关闭", r.id)
 | 
						||
}
 | 
						||
 | 
						||
func (r *Tunnel) copyConn(writer, reader net.Conn) {
 | 
						||
	_, _ = io.Copy(writer, reader)
 | 
						||
}
 |