package machine

import (
	"fmt"
	"io"
	"net"
	"os"
	"strconv"
	"sync"
	"time"

	"mayfly-go/internal/devops/domain/entity"
	"mayfly-go/pkg/global"
	"mayfly-go/pkg/utils"

	"golang.org/x/crypto/ssh"
)

var (
	// sshTunnelMachines is the registry of live ssh tunnel machines, keyed by machine id.
	sshTunnelMachines = make(map[uint64]*SshTunnelMachine)

	// mutex guards sshTunnelMachines and startCheckSshTunnelHasUse.
	mutex sync.Mutex

	// checkSshTunnelMachineHasUseFuncs holds every registered "is this tunnel machine still in use" check.
	checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc

	// startCheckSshTunnelHasUse records whether the periodic usage check has been started.
	// It is only started once a tunnel machine is actually used.
	startCheckSshTunnelHasUse bool
)

// CheckSshTunnelMachineHasUseFunc reports whether the ssh tunnel machine with the given id is still in use.
type CheckSshTunnelMachineHasUseFunc func(uint64) bool

// startCheckUse starts a background goroutine that, every 10 minutes, closes each
// registered ssh tunnel machine for which no registered check function reports usage.
// The ticker is never stopped, so the goroutine runs for the rest of the process lifetime.
func startCheckUse() {
	global.Log.Info("开启定时检测ssh隧道机器是否还有被使用")
	tick := time.NewTicker(10 * time.Minute)
	go func() {
		for range tick.C {
			func() {
				// Skip this round if the registry is busy rather than blocking the ticker.
				if !mutex.TryLock() {
					return
				}
				defer mutex.Unlock()

				for mid, sshTunnelMachine := range sshTunnelMachines {
					global.Log.Debugf("开始定时检查ssh隧道机器[%d]是否还有被使用...", mid)
					inUse := false
					for _, checkUseFunc := range checkSshTunnelMachineHasUseFuncs {
						// One positive answer is enough; skip the remaining checks for this machine.
						if checkUseFunc(mid) {
							inUse = true
							break
						}
					}
					// A machine that is in use must not stop the sweep: keep examining the others.
					if inUse {
						continue
					}
					// Nobody uses it: close it. Close removes the entry from
					// sshTunnelMachines, which is safe while ranging over a map.
					sshTunnelMachine.Close()
				}
			}()
		}
	}()
}

// AddCheckSshTunnelMachineUseFunc registers a function used by the periodic check to
// decide whether a tunnel machine is still in use.
// It does not take the package mutex.
func AddCheckSshTunnelMachineUseFunc(checkFunc CheckSshTunnelMachineHasUseFunc) {
	checkSshTunnelMachineHasUseFuncs = append(checkSshTunnelMachineHasUseFuncs, checkFunc)
}

// SshTunnelMachine is a machine acting as an ssh jump host. It owns one ssh client
// and the tunnels opened through it.
type SshTunnelMachine struct {
	machineId uint64 // id of the tunnel machine
	SshClient *ssh.Client
	mutex     sync.Mutex
	tunnels   map[uint64]*Tunnel // tunnel id -> tunnel
}

// OpenSshTunnel opens a local listener that forwards to ip:port through this tunnel
// machine and registers it under id. It returns the local host name and port that
// clients should connect to. An existing tunnel stored under the same id is replaced
// in the map without being closed.
func (stm *SshTunnelMachine) OpenSshTunnel(id uint64, ip string, port int) (exposedIp string, exposedPort int, err error) {
	stm.mutex.Lock()
	defer stm.mutex.Unlock()

	localPort, err := utils.GetAvailablePort()
	if err != nil {
		return "", 0, err
	}

	hostname, err := os.Hostname()
	if err != nil {
		return "", 0, err
	}
	// debug
	// hostname = "0.0.0.0"

	// JoinHostPort brackets IPv6 literals, which a plain "%s:%d" would not.
	localAddr := net.JoinHostPort(hostname, strconv.Itoa(localPort))
	listener, err := net.Listen("tcp", localAddr)
	if err != nil {
		return "", 0, err
	}

	tunnel := &Tunnel{
		id:         id,
		machineId:  stm.machineId,
		localHost:  hostname,
		localPort:  localPort,
		remoteHost: ip,
		remotePort: port,
		listener:   listener,
	}
	go tunnel.Open(stm.SshClient)
	stm.tunnels[tunnel.id] = tunnel

	return tunnel.localHost, tunnel.localPort, nil
}

// GetDialConn dials addr on the given network through the tunnel machine's ssh client.
// The machine mutex is held for the duration of the dial.
func (stm *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn, error) {
	stm.mutex.Lock()
	defer stm.mutex.Unlock()
	return stm.SshClient.Dial(network, addr)
}

// Close closes every tunnel of this machine, closes the ssh client and removes the
// machine from the package registry. It locks only the machine's own mutex; the
// periodic usage check calls it while already holding the package mutex.
func (stm *SshTunnelMachine) Close() {
	stm.mutex.Lock()
	defer stm.mutex.Unlock()

	for id, tunnel := range stm.tunnels {
		if tunnel != nil {
			tunnel.Close()
			delete(stm.tunnels, id)
		}
	}

	if stm.SshClient != nil {
		global.Log.Infof("ssh隧道机器[%d]未被使用, 关闭隧道...", stm.machineId)
		stm.SshClient.Close()
	}
	delete(sshTunnelMachines, stm.machineId)
}

// GetSshTunnelMachine returns the shared tunnel machine for machineId, creating it
// (and its ssh client) on first use so that only one ssh client exists per tunnel
// machine. getMachine resolves the machine entity and is called with the package
// mutex held. The first successful creation also starts the periodic usage check.
func GetSshTunnelMachine(machineId uint64, getMachine func(uint64) *entity.Machine) (*SshTunnelMachine, error) {
	// Lock before the lookup: an unlocked map read races with writers, and without
	// re-checking under the lock two callers could each create (and one leak) a client.
	mutex.Lock()
	defer mutex.Unlock()

	if sshTunnelMachine := sshTunnelMachines[machineId]; sshTunnelMachine != nil {
		return sshTunnelMachine, nil
	}

	me := getMachine(machineId)
	if me == nil {
		return nil, fmt.Errorf("ssh隧道机器[%d]不存在", machineId)
	}

	sshClient, err := GetSshClient(me)
	if err != nil {
		return nil, err
	}
	sshTunnelMachine := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[uint64]*Tunnel{}}

	global.Log.Infof("初次连接ssh隧道机器[%d][%s:%d]", machineId, me.Ip, me.Port)
	sshTunnelMachines[machineId] = sshTunnelMachine

	// A tunnel machine is now in use: start the periodic usage check if it is not running yet.
	if !startCheckSshTunnelHasUse {
		startCheckUse()
		startCheckSshTunnelHasUse = true
	}
	return sshTunnelMachine, nil
}

// CloseSshTunnelMachine closes the tunnel tunnelId of tunnel machine machineId and
// removes it from that machine. It is a no-op when the machine or tunnel is unknown.
func CloseSshTunnelMachine(machineId uint64, tunnelId uint64) {
	// Look the machine up under the package mutex, then release it before taking the
	// machine's own mutex so the two locks are never held together here.
	mutex.Lock()
	sshTunnelMachine := sshTunnelMachines[machineId]
	mutex.Unlock()
	if sshTunnelMachine == nil {
		return
	}

	sshTunnelMachine.mutex.Lock()
	defer sshTunnelMachine.mutex.Unlock()
	if t := sshTunnelMachine.tunnels[tunnelId]; t != nil {
		t.Close()
		delete(sshTunnelMachine.tunnels, tunnelId)
	}
}

// Tunnel forwards connections accepted on a local listener to a remote address
// reached through an ssh client.
type Tunnel struct {
	id         uint64 // unique identifier
	machineId  uint64 // id of the tunnel machine
	localHost  string // local listen host
	localPort  int    // local listen port
	remoteHost string // remote host to connect to
	remotePort int    // remote port to connect to
	listener   net.Listener

	mu                sync.Mutex // guards closed and the connection slices below
	closed            bool       // set by Close
	localConnections  []net.Conn
	remoteConnections []net.Conn
}

// Open accepts connections on the tunnel's listener and, for each one, dials the
// remote address through sshClient and copies data in both directions. It returns
// when Accept fails (for example after Close closed the listener). A failed remote
// dial closes only the affected local connection; the tunnel keeps accepting.
func (r *Tunnel) Open(sshClient *ssh.Client) {
	localAddr := net.JoinHostPort(r.localHost, strconv.Itoa(r.localPort))
	remoteAddr := net.JoinHostPort(r.remoteHost, strconv.Itoa(r.remotePort))

	for {
		global.Log.Debugf("隧道 %v 等待客户端访问 %v", r.id, localAddr)
		localConn, err := r.listener.Accept()
		if err != nil {
			global.Log.Debugf("隧道 %v 接受连接失败 %v, 退出循环", r.id, err.Error())
			global.Log.Debug("-------------------------------------------------")
			return
		}
		global.Log.Debugf("隧道 %v 新增本地连接 %v", r.id, localConn.RemoteAddr().String())

		global.Log.Debugf("隧道 %v 连接远程地址 %v ...", r.id, remoteAddr)
		remoteConn, err := sshClient.Dial("tcp", remoteAddr)
		if err != nil {
			// Release the client right away instead of leaving it hanging, and keep
			// serving: returning here would leave the listener open with nobody accepting.
			global.Log.Debugf("隧道 %v 连接远程地址 %v 失败: %v, 关闭本地连接", r.id, remoteAddr, err)
			global.Log.Debug("-------------------------------------------------")
			_ = localConn.Close()
			continue
		}

		r.mu.Lock()
		if r.closed {
			// Close ran while we were dialing; these connections would never be cleaned up.
			r.mu.Unlock()
			_ = localConn.Close()
			_ = remoteConn.Close()
			return
		}
		r.localConnections = append(r.localConnections, localConn)
		r.remoteConnections = append(r.remoteConnections, remoteConn)
		r.mu.Unlock()
		global.Log.Debugf("隧道 %v 连接远程主机成功", r.id)

		go copyConn(localConn, remoteConn)
		go copyConn(remoteConn, localConn)
		global.Log.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr)
		global.Log.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~")
	}
}

// Close closes every tracked local and remote connection and then the listener,
// which makes the accept loop in Open return. Close errors are ignored on purpose:
// this is best-effort teardown.
func (r *Tunnel) Close() {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.closed = true

	for _, conn := range r.localConnections {
		_ = conn.Close()
	}
	r.localConnections = nil

	for _, conn := range r.remoteConnections {
		_ = conn.Close()
	}
	r.remoteConnections = nil

	_ = r.listener.Close()
	global.Log.Debugf("隧道 %d 监听器关闭", r.id)
}

// copyConn copies from reader to writer until io.Copy returns. It does not close
// either connection; they are closed by Tunnel.Close.
func copyConn(writer, reader net.Conn) {
	_, _ = io.Copy(writer, reader)
}