mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-04 08:20:25 +08:00
feat: linux支持ssh隧道访问&其他优化
This commit is contained in:
240
server/internal/devops/infrastructure/machine/sshtunnel.go
Normal file
240
server/internal/devops/infrastructure/machine/sshtunnel.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package machine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"mayfly-go/internal/devops/domain/entity"
|
||||
"mayfly-go/pkg/global"
|
||||
"mayfly-go/pkg/utils"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
var (
|
||||
sshTunnelMachines map[uint64]*SshTunnelMachine = make(map[uint64]*SshTunnelMachine)
|
||||
|
||||
mutex sync.Mutex
|
||||
|
||||
// 所有检测ssh隧道机器是否被使用的函数
|
||||
checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc
|
||||
|
||||
// 是否开启检查ssh隧道机器是否被使用,只有使用到了隧道机器才启用
|
||||
startCheckSshTunnelHasUse bool = false
|
||||
)
|
||||
|
||||
// 检查ssh隧道机器是否有被使用
|
||||
type CheckSshTunnelMachineHasUseFunc func(uint64) bool
|
||||
|
||||
func startCheckUse() {
|
||||
global.Log.Info("开启定时检测ssh隧道机器是否还有被使用")
|
||||
heartbeat := time.Duration(10) * time.Minute
|
||||
tick := time.NewTicker(heartbeat)
|
||||
go func() {
|
||||
for range tick.C {
|
||||
func() {
|
||||
if !mutex.TryLock() {
|
||||
return
|
||||
}
|
||||
defer mutex.Unlock()
|
||||
// 遍历隧道机器,都未被使用将会被关闭
|
||||
for mid, sshTunnelMachine := range sshTunnelMachines {
|
||||
global.Log.Debugf("开始定时检查ssh隧道机器[%d]是否还有被使用...", mid)
|
||||
for _, checkUseFunc := range checkSshTunnelMachineHasUseFuncs {
|
||||
// 如果一个在使用则返回不关闭,不继续后续检查
|
||||
if checkUseFunc(mid) {
|
||||
return
|
||||
}
|
||||
}
|
||||
// 都未被使用,则关闭
|
||||
sshTunnelMachine.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// 添加ssh隧道机器检测是否使用函数
|
||||
func AddCheckSshTunnelMachineUseFunc(checkFunc CheckSshTunnelMachineHasUseFunc) {
|
||||
if checkSshTunnelMachineHasUseFuncs == nil {
|
||||
checkSshTunnelMachineHasUseFuncs = make([]CheckSshTunnelMachineHasUseFunc, 0)
|
||||
}
|
||||
checkSshTunnelMachineHasUseFuncs = append(checkSshTunnelMachineHasUseFuncs, checkFunc)
|
||||
}
|
||||
|
||||
// ssh隧道机器
|
||||
type SshTunnelMachine struct {
|
||||
machineId uint64 // 隧道机器id
|
||||
SshClient *ssh.Client
|
||||
mutex sync.Mutex
|
||||
tunnels map[uint64]*Tunnel // 机器id -> 隧道
|
||||
}
|
||||
|
||||
func (stm *SshTunnelMachine) OpenSshTunnel(id uint64, ip string, port int) (exposedIp string, exposedPort int, err error) {
|
||||
stm.mutex.Lock()
|
||||
defer stm.mutex.Unlock()
|
||||
|
||||
localPort, err := utils.GetAvailablePort()
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
// debug
|
||||
//hostname = "0.0.0.0"
|
||||
|
||||
localAddr := fmt.Sprintf("%s:%d", hostname, localPort)
|
||||
listener, err := net.Listen("tcp", localAddr)
|
||||
if err != nil {
|
||||
return "", 0, err
|
||||
}
|
||||
|
||||
tunnel := &Tunnel{
|
||||
id: id,
|
||||
machineId: stm.machineId,
|
||||
localHost: hostname,
|
||||
localPort: localPort,
|
||||
remoteHost: ip,
|
||||
remotePort: port,
|
||||
listener: listener,
|
||||
}
|
||||
go tunnel.Open(stm.SshClient)
|
||||
stm.tunnels[tunnel.id] = tunnel
|
||||
|
||||
return tunnel.localHost, tunnel.localPort, nil
|
||||
}
|
||||
|
||||
func (st *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn, error) {
|
||||
st.mutex.Lock()
|
||||
defer st.mutex.Unlock()
|
||||
return st.SshClient.Dial(network, addr)
|
||||
}
|
||||
|
||||
func (stm *SshTunnelMachine) Close() {
|
||||
stm.mutex.Lock()
|
||||
defer stm.mutex.Unlock()
|
||||
|
||||
for id, tunnel := range stm.tunnels {
|
||||
if tunnel != nil {
|
||||
tunnel.Close()
|
||||
delete(stm.tunnels, id)
|
||||
}
|
||||
}
|
||||
|
||||
if stm.SshClient != nil {
|
||||
global.Log.Infof("ssh隧道机器[%d]未被使用, 关闭隧道...", stm.machineId)
|
||||
stm.SshClient.Close()
|
||||
}
|
||||
delete(sshTunnelMachines, stm.machineId)
|
||||
}
|
||||
|
||||
// 获取ssh隧道机器,方便统一管理充当ssh隧道的机器,避免创建多个ssh client
|
||||
func GetSshTunnelMachine(machineId uint64, getMachine func(uint64) *entity.Machine) (*SshTunnelMachine, error) {
|
||||
sshTunnelMachine := sshTunnelMachines[machineId]
|
||||
if sshTunnelMachine != nil {
|
||||
return sshTunnelMachine, nil
|
||||
}
|
||||
|
||||
mutex.Lock()
|
||||
defer mutex.Unlock()
|
||||
|
||||
me := getMachine(machineId)
|
||||
sshClient, err := GetSshClient(me)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sshTunnelMachine = &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[uint64]*Tunnel{}}
|
||||
|
||||
global.Log.Infof("初次连接ssh隧道机器[%d][%s:%d]", machineId, me.Ip, me.Port)
|
||||
sshTunnelMachines[machineId] = sshTunnelMachine
|
||||
|
||||
// 如果实用了隧道机器且还没开始定时检查是否还被实用,则执行定时任务检测隧道是否还被使用
|
||||
if !startCheckSshTunnelHasUse {
|
||||
startCheckUse()
|
||||
startCheckSshTunnelHasUse = true
|
||||
}
|
||||
return sshTunnelMachine, nil
|
||||
}
|
||||
|
||||
// 关闭ssh隧道机器的指定隧道
|
||||
func CloseSshTunnelMachine(machineId uint64, tunnelId uint64) {
|
||||
sshTunnelMachine := sshTunnelMachines[machineId]
|
||||
if sshTunnelMachine == nil {
|
||||
return
|
||||
}
|
||||
|
||||
sshTunnelMachine.mutex.Lock()
|
||||
defer sshTunnelMachine.mutex.Unlock()
|
||||
t := sshTunnelMachine.tunnels[tunnelId]
|
||||
if t != nil {
|
||||
t.Close()
|
||||
delete(sshTunnelMachine.tunnels, tunnelId)
|
||||
}
|
||||
}
|
||||
|
||||
type Tunnel struct {
|
||||
id uint64 // 唯一标识
|
||||
machineId uint64 // 隧道机器id
|
||||
localHost string // 本地监听地址
|
||||
localPort int // 本地端口
|
||||
remoteHost string // 远程连接地址
|
||||
remotePort int // 远程端口
|
||||
listener net.Listener
|
||||
localConnections []net.Conn
|
||||
remoteConnections []net.Conn
|
||||
}
|
||||
|
||||
func (r *Tunnel) Open(sshClient *ssh.Client) {
|
||||
localAddr := fmt.Sprintf("%s:%d", r.localHost, r.localPort)
|
||||
|
||||
for {
|
||||
global.Log.Debugf("隧道 %v 等待客户端访问 %v", r.id, localAddr)
|
||||
localConn, err := r.listener.Accept()
|
||||
if err != nil {
|
||||
global.Log.Debugf("隧道 %v 接受连接失败 %v, 退出循环", r.id, err.Error())
|
||||
global.Log.Debug("-------------------------------------------------")
|
||||
return
|
||||
}
|
||||
r.localConnections = append(r.localConnections, localConn)
|
||||
|
||||
global.Log.Debugf("隧道 %v 新增本地连接 %v", r.id, localConn.RemoteAddr().String())
|
||||
remoteAddr := fmt.Sprintf("%s:%d", r.remoteHost, r.remotePort)
|
||||
global.Log.Debugf("隧道 %v 连接远程地址 %v ...", r.id, remoteAddr)
|
||||
remoteConn, err := sshClient.Dial("tcp", remoteAddr)
|
||||
if err != nil {
|
||||
global.Log.Debugf("隧道 %v 连接远程地址 %v, 退出循环", r.id, err.Error())
|
||||
global.Log.Debug("-------------------------------------------------")
|
||||
return
|
||||
}
|
||||
r.remoteConnections = append(r.remoteConnections, remoteConn)
|
||||
|
||||
global.Log.Debugf("隧道 %v 连接远程主机成功", r.id)
|
||||
go copyConn(localConn, remoteConn)
|
||||
go copyConn(remoteConn, localConn)
|
||||
global.Log.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr)
|
||||
global.Log.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Tunnel) Close() {
|
||||
for i := range r.localConnections {
|
||||
_ = r.localConnections[i].Close()
|
||||
}
|
||||
r.localConnections = nil
|
||||
for i := range r.remoteConnections {
|
||||
_ = r.remoteConnections[i].Close()
|
||||
}
|
||||
r.remoteConnections = nil
|
||||
_ = r.listener.Close()
|
||||
global.Log.Debugf("隧道 %d 监听器关闭", r.id)
|
||||
}
|
||||
|
||||
func copyConn(writer, reader net.Conn) {
|
||||
_, _ = io.Copy(writer, reader)
|
||||
}
|
||||
Reference in New Issue
Block a user