2023-10-30 17:34:56 +08:00
|
|
|
|
package mcm
|
2022-07-23 16:41:04 +08:00
|
|
|
|
|
|
|
|
|
|
import (
|
2025-05-21 04:42:30 +00:00
|
|
|
|
"errors"
|
2022-07-23 16:41:04 +08:00
|
|
|
|
"fmt"
|
|
|
|
|
|
"io"
|
2023-09-02 17:24:18 +08:00
|
|
|
|
"mayfly-go/pkg/logx"
|
2025-05-21 04:42:30 +00:00
|
|
|
|
"mayfly-go/pkg/pool"
|
2023-07-21 17:07:04 +08:00
|
|
|
|
"mayfly-go/pkg/utils/netx"
|
2022-07-23 16:41:04 +08:00
|
|
|
|
"net"
|
|
|
|
|
|
"sync"
|
2025-05-21 04:42:30 +00:00
|
|
|
|
"time"
|
2022-07-23 16:41:04 +08:00
|
|
|
|
|
|
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
var (
|
|
|
|
|
|
mutex sync.Mutex
|
|
|
|
|
|
|
|
|
|
|
|
// 所有检测ssh隧道机器是否被使用的函数
|
|
|
|
|
|
checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc
|
|
|
|
|
|
|
2025-05-21 04:42:30 +00:00
|
|
|
|
tunnelPool = make(map[int]pool.Pool)
|
2022-07-23 16:41:04 +08:00
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// CheckSshTunnelMachineHasUseFunc reports whether an SSH tunnel machine is
// still in use. The int argument is taken to be the tunnel machine id.
type CheckSshTunnelMachineHasUseFunc func(machineId int) bool
|
2022-07-23 16:41:04 +08:00
|
|
|
|
|
|
|
|
|
|
// 添加ssh隧道机器检测是否使用函数
|
|
|
|
|
|
func AddCheckSshTunnelMachineUseFunc(checkFunc CheckSshTunnelMachineHasUseFunc) {
|
|
|
|
|
|
if checkSshTunnelMachineHasUseFuncs == nil {
|
|
|
|
|
|
checkSshTunnelMachineHasUseFuncs = make([]CheckSshTunnelMachineHasUseFunc, 0)
|
|
|
|
|
|
}
|
|
|
|
|
|
checkSshTunnelMachineHasUseFuncs = append(checkSshTunnelMachineHasUseFuncs, checkFunc)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// SshTunnelMachine is a machine used as an SSH tunnel: it owns one SSH client
// and the set of local port-forwarding tunnels opened through that client.
type SshTunnelMachine struct {
	mi        *MachineInfo       // machine info the SSH client was created from
	machineId int                // tunnel machine id
	SshClient *ssh.Client        // SSH client connected to the tunnel machine
	mutex     sync.Mutex         // held while opening tunnels, dialing and closing
	tunnels   map[string]*Tunnel // tunnel id -> tunnel
}
|
|
|
|
|
|
|
2025-05-21 04:42:30 +00:00
|
|
|
|
func (stm *SshTunnelMachine) Ping() error {
|
|
|
|
|
|
_, _, err := stm.SshClient.Conn.SendRequest("ping", true, nil)
|
|
|
|
|
|
return err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2023-12-20 23:01:51 +08:00
|
|
|
|
func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (exposedIp string, exposedPort int, err error) {
|
2022-07-23 16:41:04 +08:00
|
|
|
|
stm.mutex.Lock()
|
|
|
|
|
|
defer stm.mutex.Unlock()
|
|
|
|
|
|
|
2024-01-05 22:16:38 +08:00
|
|
|
|
tunnel := stm.tunnels[id]
|
|
|
|
|
|
// 已存在该id隧道,则直接返回
|
|
|
|
|
|
if tunnel != nil {
|
2025-05-21 04:42:30 +00:00
|
|
|
|
// FIXME 后期改成池化连接,定时60秒检查连接可用性
|
2024-01-05 22:16:38 +08:00
|
|
|
|
return tunnel.localHost, tunnel.localPort, nil
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2023-07-21 17:07:04 +08:00
|
|
|
|
localPort, err := netx.GetAvailablePort()
|
2022-07-23 16:41:04 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return "", 0, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-05-21 04:42:30 +00:00
|
|
|
|
localHost := "127.0.0.1"
|
2024-03-01 04:03:03 +00:00
|
|
|
|
localAddr := fmt.Sprintf("%s:%d", localHost, localPort)
|
2022-07-23 16:41:04 +08:00
|
|
|
|
listener, err := net.Listen("tcp", localAddr)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return "", 0, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2024-01-05 22:16:38 +08:00
|
|
|
|
tunnel = &Tunnel{
|
2022-07-23 16:41:04 +08:00
|
|
|
|
id: id,
|
|
|
|
|
|
machineId: stm.machineId,
|
2024-03-01 04:03:03 +00:00
|
|
|
|
localHost: localHost,
|
2022-07-23 16:41:04 +08:00
|
|
|
|
localPort: localPort,
|
|
|
|
|
|
remoteHost: ip,
|
|
|
|
|
|
remotePort: port,
|
|
|
|
|
|
listener: listener,
|
|
|
|
|
|
}
|
|
|
|
|
|
go tunnel.Open(stm.SshClient)
|
|
|
|
|
|
stm.tunnels[tunnel.id] = tunnel
|
|
|
|
|
|
|
2025-05-21 04:42:30 +00:00
|
|
|
|
return localHost, localPort, nil
|
2022-07-23 16:41:04 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-05-21 04:42:30 +00:00
|
|
|
|
func (stm *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn, error) {
|
|
|
|
|
|
stm.mutex.Lock()
|
|
|
|
|
|
defer stm.mutex.Unlock()
|
|
|
|
|
|
return stm.SshClient.Dial(network, addr)
|
2022-07-23 16:41:04 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (stm *SshTunnelMachine) Close() {
|
|
|
|
|
|
stm.mutex.Lock()
|
|
|
|
|
|
defer stm.mutex.Unlock()
|
|
|
|
|
|
|
|
|
|
|
|
for id, tunnel := range stm.tunnels {
|
|
|
|
|
|
if tunnel != nil {
|
|
|
|
|
|
tunnel.Close()
|
|
|
|
|
|
delete(stm.tunnels, id)
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if stm.SshClient != nil {
|
2024-11-20 22:43:53 +08:00
|
|
|
|
logx.Infof("ssh tunnel machine [%d] is not in use, close tunnel...", stm.machineId)
|
2022-07-24 15:37:13 +08:00
|
|
|
|
err := stm.SshClient.Close()
|
|
|
|
|
|
if err != nil {
|
2024-11-20 22:43:53 +08:00
|
|
|
|
logx.Errorf("error in closing ssh tunnel machine [%d]: %s", stm.machineId, err.Error())
|
2022-07-24 15:37:13 +08:00
|
|
|
|
}
|
2022-07-23 16:41:04 +08:00
|
|
|
|
}
|
2025-05-21 04:42:30 +00:00
|
|
|
|
delete(tunnelPool, stm.machineId)
|
2022-07-23 16:41:04 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-05-21 04:42:30 +00:00
|
|
|
|
func getTunnelPool(machineId int, getMachine func(uint64) (*MachineInfo, error)) (pool.Pool, error) {
|
|
|
|
|
|
// 获取连接池,如果没有,则创建一个
|
|
|
|
|
|
if p, ok := tunnelPool[machineId]; !ok {
|
|
|
|
|
|
var err error
|
|
|
|
|
|
p, err = pool.NewChannelPool(&pool.Config{
|
|
|
|
|
|
InitialCap: 1, //资源池初始连接数
|
|
|
|
|
|
MaxCap: 10, //最大空闲连接数
|
|
|
|
|
|
MaxIdle: 10, //最大并发连接数
|
|
|
|
|
|
IdleTimeout: 10 * time.Minute, // 连接最大空闲时间,过期则失效
|
|
|
|
|
|
Factory: func() (interface{}, error) {
|
|
|
|
|
|
mi, err := getMachine(uint64(machineId))
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
if mi == nil {
|
|
|
|
|
|
return nil, errors.New("error get machine info")
|
|
|
|
|
|
}
|
|
|
|
|
|
sshClient, err := GetSshClient(mi, nil)
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
stm := &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[string]*Tunnel{}, mi: mi}
|
|
|
|
|
|
logx.Infof("connect to the ssh tunnel machine for the first time[%d][%s:%d]", machineId, mi.Ip, mi.Port)
|
|
|
|
|
|
|
|
|
|
|
|
return stm, err
|
|
|
|
|
|
},
|
|
|
|
|
|
Close: func(v interface{}) error {
|
|
|
|
|
|
v.(*SshTunnelMachine).Close()
|
|
|
|
|
|
return nil
|
|
|
|
|
|
},
|
|
|
|
|
|
Ping: func(v interface{}) error {
|
|
|
|
|
|
return v.(*SshTunnelMachine).Ping()
|
|
|
|
|
|
},
|
|
|
|
|
|
})
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
tunnelPool[machineId] = p
|
|
|
|
|
|
return p, nil
|
|
|
|
|
|
} else {
|
|
|
|
|
|
return p, nil
|
2022-07-23 16:41:04 +08:00
|
|
|
|
}
|
2025-05-21 04:42:30 +00:00
|
|
|
|
}
|
2022-07-23 16:41:04 +08:00
|
|
|
|
|
2025-05-21 04:42:30 +00:00
|
|
|
|
// 获取ssh隧道机器,方便统一管理充当ssh隧道的机器,避免创建多个ssh client
|
|
|
|
|
|
func GetSshTunnelMachine(machineId int, getMachine func(uint64) (*MachineInfo, error)) (*SshTunnelMachine, error) {
|
|
|
|
|
|
p, err := getTunnelPool(machineId, getMachine)
|
2023-10-26 17:15:49 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
2025-05-21 04:42:30 +00:00
|
|
|
|
// 从连接池中获取一个可用的连接
|
|
|
|
|
|
c, err := p.Get()
|
2022-07-23 16:41:04 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2025-05-21 04:42:30 +00:00
|
|
|
|
return c.(*SshTunnelMachine), nil
|
2022-07-23 16:41:04 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// CloseSshTunnelMachine is intended to close the tunnel tunnelId on the SSH
// tunnel machine machineId.
//
// The implementation below is entirely commented out, so the function is
// currently a no-op and both parameters are ignored.
func CloseSshTunnelMachine(machineId uint64, tunnelId string) {
	//sshTunnelMachine := mcIdPool[machineId]
	//if sshTunnelMachine == nil {
	//	return
	//}
	//
	//sshTunnelMachine.mutex.Lock()
	//defer sshTunnelMachine.mutex.Unlock()
	//t := sshTunnelMachine.tunnels[tunnelId]
	//if t != nil {
	//	t.Close()
	//	delete(sshTunnelMachine.tunnels, tunnelId)
	//}
}
|
|
|
|
|
|
|
|
|
|
|
|
// Tunnel is a single port-forwarding tunnel: it accepts connections on a
// local listener and forwards each one to remoteHost:remotePort over SSH.
type Tunnel struct {
	id                string       // unique identifier of the tunnel
	machineId         int          // id of the tunnel machine
	localHost         string       // local listening address
	localPort         int          // local listening port
	remoteHost        string       // remote address to connect to
	remotePort        int          // remote port
	listener          net.Listener // listener that accepts the local client connections
	localConnections  []net.Conn   // accepted local connections, closed by Close
	remoteConnections []net.Conn   // connections dialed to the remote address, closed by Close
}
|
|
|
|
|
|
|
|
|
|
|
|
func (r *Tunnel) Open(sshClient *ssh.Client) {
|
|
|
|
|
|
localAddr := fmt.Sprintf("%s:%d", r.localHost, r.localPort)
|
|
|
|
|
|
|
|
|
|
|
|
for {
|
2023-09-02 17:24:18 +08:00
|
|
|
|
logx.Debugf("隧道 %v 等待客户端访问 %v", r.id, localAddr)
|
2022-07-23 16:41:04 +08:00
|
|
|
|
localConn, err := r.listener.Accept()
|
|
|
|
|
|
if err != nil {
|
2023-09-02 17:24:18 +08:00
|
|
|
|
logx.Debugf("隧道 %v 接受连接失败 %v, 退出循环", r.id, err.Error())
|
|
|
|
|
|
logx.Debug("-------------------------------------------------")
|
2022-07-23 16:41:04 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
r.localConnections = append(r.localConnections, localConn)
|
|
|
|
|
|
|
2023-09-02 17:24:18 +08:00
|
|
|
|
logx.Debugf("隧道 %v 新增本地连接 %v", r.id, localConn.RemoteAddr().String())
|
2022-07-23 16:41:04 +08:00
|
|
|
|
remoteAddr := fmt.Sprintf("%s:%d", r.remoteHost, r.remotePort)
|
2023-09-02 17:24:18 +08:00
|
|
|
|
logx.Debugf("隧道 %v 连接远程地址 %v ...", r.id, remoteAddr)
|
2022-07-23 16:41:04 +08:00
|
|
|
|
remoteConn, err := sshClient.Dial("tcp", remoteAddr)
|
|
|
|
|
|
if err != nil {
|
2023-09-02 17:24:18 +08:00
|
|
|
|
logx.Debugf("隧道 %v 连接远程地址 %v, 退出循环", r.id, err.Error())
|
|
|
|
|
|
logx.Debug("-------------------------------------------------")
|
2022-07-23 16:41:04 +08:00
|
|
|
|
return
|
|
|
|
|
|
}
|
|
|
|
|
|
r.remoteConnections = append(r.remoteConnections, remoteConn)
|
|
|
|
|
|
|
2023-09-02 17:24:18 +08:00
|
|
|
|
logx.Debugf("隧道 %v 连接远程主机成功", r.id)
|
2025-01-17 03:53:15 +00:00
|
|
|
|
go r.copyConn(localConn, remoteConn)
|
|
|
|
|
|
go r.copyConn(remoteConn, localConn)
|
2023-09-02 17:24:18 +08:00
|
|
|
|
logx.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr)
|
|
|
|
|
|
logx.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~")
|
2022-07-23 16:41:04 +08:00
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (r *Tunnel) Close() {
|
|
|
|
|
|
for i := range r.localConnections {
|
|
|
|
|
|
_ = r.localConnections[i].Close()
|
|
|
|
|
|
}
|
|
|
|
|
|
r.localConnections = nil
|
|
|
|
|
|
for i := range r.remoteConnections {
|
|
|
|
|
|
_ = r.remoteConnections[i].Close()
|
|
|
|
|
|
}
|
|
|
|
|
|
r.remoteConnections = nil
|
|
|
|
|
|
_ = r.listener.Close()
|
2023-12-20 23:01:51 +08:00
|
|
|
|
logx.Debugf("隧道 %s 监听器关闭", r.id)
|
2022-07-23 16:41:04 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
2025-01-17 03:53:15 +00:00
|
|
|
|
func (r *Tunnel) copyConn(writer, reader net.Conn) {
|
2022-07-23 16:41:04 +08:00
|
|
|
|
_, _ = io.Copy(writer, reader)
|
|
|
|
|
|
}
|