feat: linux支持ssh隧道访问&其他优化

This commit is contained in:
meilin.huang
2022-07-23 16:41:04 +08:00
parent f0540559bb
commit 76d6fc3ba5
26 changed files with 2003 additions and 1556 deletions

View File

@@ -3,6 +3,7 @@ package machine
import (
"errors"
"fmt"
"mayfly-go/internal/constant"
"mayfly-go/internal/devops/domain/entity"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/cache"
@@ -18,10 +19,12 @@ import (
// 客户端信息
type Cli struct {
machine *entity.Machine
// ssh客户端
client *ssh.Client
sftpClient *sftp.Client
client *ssh.Client // ssh客户端
sftpClient *sftp.Client // sftp客户端
enableSshTunnel int8
sshTunnelMachineId uint64
}
//连接
@@ -39,7 +42,7 @@ func (c *Cli) connect() error {
return nil
}
// 关闭client并从缓存中移除
// 关闭client并从缓存中移除,如果使用隧道则也关闭
func (c *Cli) Close() {
m := c.machine
global.Log.Info(fmt.Sprintf("关闭机器客户端连接-> id: %d, name: %s, ip: %s", m.Id, m.Name, m.Ip))
@@ -51,6 +54,9 @@ func (c *Cli) Close() {
c.sftpClient.Close()
c.sftpClient = nil
}
if c.enableSshTunnel == 1 {
CloseSshTunnelMachine(c.sshTunnelMachineId, c.machine.Id)
}
}
// 获取sftp client
@@ -105,13 +111,26 @@ func (c *Cli) GetMachine() *entity.Machine {
return c.machine
}
// 机器客户端连接缓存,45分钟内没有访问则会被关闭
var cliCache = cache.NewTimedCache(45*time.Minute, 5*time.Second).
// 机器客户端连接缓存,指定时间内没有访问则会被关闭
var cliCache = cache.NewTimedCache(constant.MachineConnExpireTime, 5*time.Second).
WithUpdateAccessTime(true).
OnEvicted(func(_, value interface{}) {
value.(*Cli).Close()
})
func init() {
AddCheckSshTunnelMachineUseFunc(func(machineId uint64) bool {
// 遍历所有机器连接实例若存在机器连接实例使用该ssh隧道机器则返回true表示还在使用中...
items := cliCache.Items()
for _, v := range items {
if v.Value.(*Cli).sshTunnelMachineId == machineId {
return true
}
}
return false
})
}
// 是否存在指定id的客户端连接
func HasCli(machineId uint64) bool {
if _, ok := cliCache.Get(machineId); ok {
@@ -128,10 +147,18 @@ func DeleteCli(id uint64) {
// 从缓存中获取客户端信息,不存在则回调获取机器信息函数,并新建
func GetCli(machineId uint64, getMachine func(uint64) *entity.Machine) (*Cli, error) {
cli, err := cliCache.ComputeIfAbsent(machineId, func(_ interface{}) (interface{}, error) {
c, err := newClient(getMachine(machineId))
me := getMachine(machineId)
err := IfUseSshTunnelChangeIpPort(me, getMachine)
if err != nil {
return nil, fmt.Errorf("ssh隧道连接失败: %s", err.Error())
}
c, err := newClient(me)
if err != nil {
CloseSshTunnelMachine(me.SshTunnelMachineId, me.Id)
return nil, err
}
c.enableSshTunnel = me.EnableSshTunnel
c.sshTunnelMachineId = me.SshTunnelMachineId
return c, nil
})
@@ -141,9 +168,20 @@ func GetCli(machineId uint64, getMachine func(uint64) *entity.Machine) (*Cli, er
return nil, err
}
// 测试连接
func TestConn(m *entity.Machine) error {
sshClient, err := GetSshClient(m)
// 测试连接使用传值的方式而非引用。因为如果使用了ssh隧道则ip和端口会变为本地映射地址与端口
func TestConn(me entity.Machine, getSshTunnelMachine func(uint64) *entity.Machine) error {
originId := me.Id
if originId == 0 {
// 随机设置一个ip如果使用了隧道则用于临时保存隧道
me.Id = uint64(time.Now().Nanosecond())
}
err := IfUseSshTunnelChangeIpPort(&me, getSshTunnelMachine)
biz.ErrIsNilAppendErr(err, "ssh隧道连接失败: %s")
if me.EnableSshTunnel == 1 {
defer CloseSshTunnelMachine(me.SshTunnelMachineId, me.Id)
}
sshClient, err := GetSshClient(&me)
if err != nil {
return err
}
@@ -151,6 +189,27 @@ func TestConn(m *entity.Machine) error {
return nil
}
// 如果使用了ssh隧道则修改机器ip port为暴露的ip port
func IfUseSshTunnelChangeIpPort(me *entity.Machine, getMachine func(uint64) *entity.Machine) error {
if me.EnableSshTunnel != 1 {
return nil
}
sshTunnelMachine, err := GetSshTunnelMachine(me.SshTunnelMachineId, func(u uint64) *entity.Machine {
return getMachine(u)
})
if err != nil {
return err
}
exposeIp, exposePort, err := sshTunnelMachine.OpenSshTunnel(me.Id, me.Ip, me.Port)
if err != nil {
return err
}
// 修改机器ip地址
me.Ip = exposeIp
me.Port = exposePort
return nil
}
func GetSshClient(m *entity.Machine) (*ssh.Client, error) {
config := ssh.ClientConfig{
User: m.Username,

View File

@@ -0,0 +1,240 @@
package machine
import (
"fmt"
"io"
"mayfly-go/internal/devops/domain/entity"
"mayfly-go/pkg/global"
"mayfly-go/pkg/utils"
"net"
"os"
"sync"
"time"
"golang.org/x/crypto/ssh"
)
var (
sshTunnelMachines map[uint64]*SshTunnelMachine = make(map[uint64]*SshTunnelMachine)
mutex sync.Mutex
// 所有检测ssh隧道机器是否被使用的函数
checkSshTunnelMachineHasUseFuncs []CheckSshTunnelMachineHasUseFunc
// 是否开启检查ssh隧道机器是否被使用只有使用到了隧道机器才启用
startCheckSshTunnelHasUse bool = false
)
// 检查ssh隧道机器是否有被使用
type CheckSshTunnelMachineHasUseFunc func(uint64) bool
func startCheckUse() {
global.Log.Info("开启定时检测ssh隧道机器是否还有被使用")
heartbeat := time.Duration(10) * time.Minute
tick := time.NewTicker(heartbeat)
go func() {
for range tick.C {
func() {
if !mutex.TryLock() {
return
}
defer mutex.Unlock()
// 遍历隧道机器,都未被使用将会被关闭
for mid, sshTunnelMachine := range sshTunnelMachines {
global.Log.Debugf("开始定时检查ssh隧道机器[%d]是否还有被使用...", mid)
for _, checkUseFunc := range checkSshTunnelMachineHasUseFuncs {
// 如果一个在使用则返回不关闭,不继续后续检查
if checkUseFunc(mid) {
return
}
}
// 都未被使用,则关闭
sshTunnelMachine.Close()
}
}()
}
}()
}
// 添加ssh隧道机器检测是否使用函数
func AddCheckSshTunnelMachineUseFunc(checkFunc CheckSshTunnelMachineHasUseFunc) {
if checkSshTunnelMachineHasUseFuncs == nil {
checkSshTunnelMachineHasUseFuncs = make([]CheckSshTunnelMachineHasUseFunc, 0)
}
checkSshTunnelMachineHasUseFuncs = append(checkSshTunnelMachineHasUseFuncs, checkFunc)
}
// ssh隧道机器
type SshTunnelMachine struct {
machineId uint64 // 隧道机器id
SshClient *ssh.Client
mutex sync.Mutex
tunnels map[uint64]*Tunnel // 机器id -> 隧道
}
func (stm *SshTunnelMachine) OpenSshTunnel(id uint64, ip string, port int) (exposedIp string, exposedPort int, err error) {
stm.mutex.Lock()
defer stm.mutex.Unlock()
localPort, err := utils.GetAvailablePort()
if err != nil {
return "", 0, err
}
hostname, err := os.Hostname()
if err != nil {
return "", 0, err
}
// debug
//hostname = "0.0.0.0"
localAddr := fmt.Sprintf("%s:%d", hostname, localPort)
listener, err := net.Listen("tcp", localAddr)
if err != nil {
return "", 0, err
}
tunnel := &Tunnel{
id: id,
machineId: stm.machineId,
localHost: hostname,
localPort: localPort,
remoteHost: ip,
remotePort: port,
listener: listener,
}
go tunnel.Open(stm.SshClient)
stm.tunnels[tunnel.id] = tunnel
return tunnel.localHost, tunnel.localPort, nil
}
func (st *SshTunnelMachine) GetDialConn(network string, addr string) (net.Conn, error) {
st.mutex.Lock()
defer st.mutex.Unlock()
return st.SshClient.Dial(network, addr)
}
func (stm *SshTunnelMachine) Close() {
stm.mutex.Lock()
defer stm.mutex.Unlock()
for id, tunnel := range stm.tunnels {
if tunnel != nil {
tunnel.Close()
delete(stm.tunnels, id)
}
}
if stm.SshClient != nil {
global.Log.Infof("ssh隧道机器[%d]未被使用, 关闭隧道...", stm.machineId)
stm.SshClient.Close()
}
delete(sshTunnelMachines, stm.machineId)
}
// 获取ssh隧道机器方便统一管理充当ssh隧道的机器避免创建多个ssh client
func GetSshTunnelMachine(machineId uint64, getMachine func(uint64) *entity.Machine) (*SshTunnelMachine, error) {
sshTunnelMachine := sshTunnelMachines[machineId]
if sshTunnelMachine != nil {
return sshTunnelMachine, nil
}
mutex.Lock()
defer mutex.Unlock()
me := getMachine(machineId)
sshClient, err := GetSshClient(me)
if err != nil {
return nil, err
}
sshTunnelMachine = &SshTunnelMachine{SshClient: sshClient, machineId: machineId, tunnels: map[uint64]*Tunnel{}}
global.Log.Infof("初次连接ssh隧道机器[%d][%s:%d]", machineId, me.Ip, me.Port)
sshTunnelMachines[machineId] = sshTunnelMachine
// 如果实用了隧道机器且还没开始定时检查是否还被实用,则执行定时任务检测隧道是否还被使用
if !startCheckSshTunnelHasUse {
startCheckUse()
startCheckSshTunnelHasUse = true
}
return sshTunnelMachine, nil
}
// 关闭ssh隧道机器的指定隧道
func CloseSshTunnelMachine(machineId uint64, tunnelId uint64) {
sshTunnelMachine := sshTunnelMachines[machineId]
if sshTunnelMachine == nil {
return
}
sshTunnelMachine.mutex.Lock()
defer sshTunnelMachine.mutex.Unlock()
t := sshTunnelMachine.tunnels[tunnelId]
if t != nil {
t.Close()
delete(sshTunnelMachine.tunnels, tunnelId)
}
}
type Tunnel struct {
id uint64 // 唯一标识
machineId uint64 // 隧道机器id
localHost string // 本地监听地址
localPort int // 本地端口
remoteHost string // 远程连接地址
remotePort int // 远程端口
listener net.Listener
localConnections []net.Conn
remoteConnections []net.Conn
}
func (r *Tunnel) Open(sshClient *ssh.Client) {
localAddr := fmt.Sprintf("%s:%d", r.localHost, r.localPort)
for {
global.Log.Debugf("隧道 %v 等待客户端访问 %v", r.id, localAddr)
localConn, err := r.listener.Accept()
if err != nil {
global.Log.Debugf("隧道 %v 接受连接失败 %v, 退出循环", r.id, err.Error())
global.Log.Debug("-------------------------------------------------")
return
}
r.localConnections = append(r.localConnections, localConn)
global.Log.Debugf("隧道 %v 新增本地连接 %v", r.id, localConn.RemoteAddr().String())
remoteAddr := fmt.Sprintf("%s:%d", r.remoteHost, r.remotePort)
global.Log.Debugf("隧道 %v 连接远程地址 %v ...", r.id, remoteAddr)
remoteConn, err := sshClient.Dial("tcp", remoteAddr)
if err != nil {
global.Log.Debugf("隧道 %v 连接远程地址 %v, 退出循环", r.id, err.Error())
global.Log.Debug("-------------------------------------------------")
return
}
r.remoteConnections = append(r.remoteConnections, remoteConn)
global.Log.Debugf("隧道 %v 连接远程主机成功", r.id)
go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
global.Log.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr)
global.Log.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~")
}
}
func (r *Tunnel) Close() {
for i := range r.localConnections {
_ = r.localConnections[i].Close()
}
r.localConnections = nil
for i := range r.remoteConnections {
_ = r.remoteConnections[i].Close()
}
r.remoteConnections = nil
_ = r.listener.Close()
global.Log.Debugf("隧道 %d 监听器关闭", r.id)
}
func copyConn(writer, reader net.Conn) {
_, _ = io.Copy(writer, reader)
}