mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 16:30:25 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			254 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			254 lines
		
	
	
		
			6.1 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package machine
 | 
						||
 | 
						||
import (
 | 
						||
	"errors"
 | 
						||
	"fmt"
 | 
						||
	"mayfly-go/internal/constant"
 | 
						||
	"mayfly-go/internal/devops/domain/entity"
 | 
						||
	"mayfly-go/pkg/biz"
 | 
						||
	"mayfly-go/pkg/cache"
 | 
						||
	"mayfly-go/pkg/global"
 | 
						||
	"net"
 | 
						||
	"time"
 | 
						||
 | 
						||
	"github.com/pkg/sftp"
 | 
						||
 | 
						||
	"golang.org/x/crypto/ssh"
 | 
						||
)
 | 
						||
 | 
						||
// 客户端信息
 | 
						||
type Cli struct {
 | 
						||
	machine *entity.Machine
 | 
						||
 | 
						||
	client     *ssh.Client  // ssh客户端
 | 
						||
	sftpClient *sftp.Client // sftp客户端
 | 
						||
 | 
						||
	enableSshTunnel    int8
 | 
						||
	sshTunnelMachineId uint64
 | 
						||
}
 | 
						||
 | 
						||
//连接
 | 
						||
func (c *Cli) connect() error {
 | 
						||
	// 如果已经有client则直接返回
 | 
						||
	if c.client != nil {
 | 
						||
		return nil
 | 
						||
	}
 | 
						||
	m := c.machine
 | 
						||
	sshClient, err := GetSshClient(m)
 | 
						||
	if err != nil {
 | 
						||
		return err
 | 
						||
	}
 | 
						||
	c.client = sshClient
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
// 关闭client并从缓存中移除,如果使用隧道则也关闭
 | 
						||
func (c *Cli) Close() {
 | 
						||
	m := c.machine
 | 
						||
	global.Log.Info(fmt.Sprintf("关闭机器客户端连接-> id: %d, name: %s, ip: %s", m.Id, m.Name, m.Ip))
 | 
						||
	if c.client != nil {
 | 
						||
		c.client.Close()
 | 
						||
		c.client = nil
 | 
						||
	}
 | 
						||
	if c.sftpClient != nil {
 | 
						||
		c.sftpClient.Close()
 | 
						||
		c.sftpClient = nil
 | 
						||
	}
 | 
						||
	if c.enableSshTunnel == 1 {
 | 
						||
		CloseSshTunnelMachine(c.sshTunnelMachineId, c.machine.Id)
 | 
						||
	}
 | 
						||
}
 | 
						||
 | 
						||
// 获取sftp client
 | 
						||
func (c *Cli) GetSftpCli() *sftp.Client {
 | 
						||
	if c.client == nil {
 | 
						||
		if err := c.connect(); err != nil {
 | 
						||
			panic(biz.NewBizErr("连接ssh失败:" + err.Error()))
 | 
						||
		}
 | 
						||
	}
 | 
						||
	sftpclient := c.sftpClient
 | 
						||
	// 如果sftpClient为nil,则连接
 | 
						||
	if sftpclient == nil {
 | 
						||
		sc, serr := sftp.NewClient(c.client)
 | 
						||
		if serr != nil {
 | 
						||
			panic(biz.NewBizErr("获取sftp client失败:" + serr.Error()))
 | 
						||
		}
 | 
						||
		sftpclient = sc
 | 
						||
		c.sftpClient = sftpclient
 | 
						||
	}
 | 
						||
 | 
						||
	return sftpclient
 | 
						||
}
 | 
						||
 | 
						||
// 获取session
 | 
						||
func (c *Cli) GetSession() (*ssh.Session, error) {
 | 
						||
	if c.client == nil {
 | 
						||
		if err := c.connect(); err != nil {
 | 
						||
			return nil, err
 | 
						||
		}
 | 
						||
	}
 | 
						||
	return c.client.NewSession()
 | 
						||
}
 | 
						||
 | 
						||
//执行shell
 | 
						||
//@param shell shell脚本命令
 | 
						||
func (c *Cli) Run(shell string) (*string, error) {
 | 
						||
	session, err := c.GetSession()
 | 
						||
	if err != nil {
 | 
						||
		c.Close()
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
	defer session.Close()
 | 
						||
	buf, rerr := session.CombinedOutput(shell)
 | 
						||
	if rerr != nil {
 | 
						||
		return nil, rerr
 | 
						||
	}
 | 
						||
	res := string(buf)
 | 
						||
	return &res, nil
 | 
						||
}
 | 
						||
 | 
						||
func (c *Cli) GetMachine() *entity.Machine {
 | 
						||
	return c.machine
 | 
						||
}
 | 
						||
 | 
						||
// 机器客户端连接缓存,指定时间内没有访问则会被关闭
 | 
						||
var cliCache = cache.NewTimedCache(constant.MachineConnExpireTime, 5*time.Second).
 | 
						||
	WithUpdateAccessTime(true).
 | 
						||
	OnEvicted(func(_, value interface{}) {
 | 
						||
		value.(*Cli).Close()
 | 
						||
	})
 | 
						||
 | 
						||
func init() {
 | 
						||
	AddCheckSshTunnelMachineUseFunc(func(machineId uint64) bool {
 | 
						||
		// 遍历所有机器连接实例,若存在机器连接实例使用该ssh隧道机器,则返回true,表示还在使用中...
 | 
						||
		items := cliCache.Items()
 | 
						||
		for _, v := range items {
 | 
						||
			if v.Value.(*Cli).sshTunnelMachineId == machineId {
 | 
						||
				return true
 | 
						||
			}
 | 
						||
		}
 | 
						||
		return false
 | 
						||
	})
 | 
						||
}
 | 
						||
 | 
						||
// 是否存在指定id的客户端连接
 | 
						||
func HasCli(machineId uint64) bool {
 | 
						||
	if _, ok := cliCache.Get(machineId); ok {
 | 
						||
		return true
 | 
						||
	}
 | 
						||
	return false
 | 
						||
}
 | 
						||
 | 
						||
// 删除指定机器客户端,并关闭客户端连接
 | 
						||
func DeleteCli(id uint64) {
 | 
						||
	cliCache.Delete(id)
 | 
						||
}
 | 
						||
 | 
						||
// 从缓存中获取客户端信息,不存在则回调获取机器信息函数,并新建
 | 
						||
func GetCli(machineId uint64, getMachine func(uint64) *entity.Machine) (*Cli, error) {
 | 
						||
	cli, err := cliCache.ComputeIfAbsent(machineId, func(_ interface{}) (interface{}, error) {
 | 
						||
		me := getMachine(machineId)
 | 
						||
		err := IfUseSshTunnelChangeIpPort(me, getMachine)
 | 
						||
		if err != nil {
 | 
						||
			return nil, fmt.Errorf("ssh隧道连接失败: %s", err.Error())
 | 
						||
		}
 | 
						||
		c, err := newClient(me)
 | 
						||
		if err != nil {
 | 
						||
			CloseSshTunnelMachine(me.SshTunnelMachineId, me.Id)
 | 
						||
			return nil, err
 | 
						||
		}
 | 
						||
		c.enableSshTunnel = me.EnableSshTunnel
 | 
						||
		c.sshTunnelMachineId = me.SshTunnelMachineId
 | 
						||
		return c, nil
 | 
						||
	})
 | 
						||
 | 
						||
	if cli != nil {
 | 
						||
		return cli.(*Cli), err
 | 
						||
	}
 | 
						||
	return nil, err
 | 
						||
}
 | 
						||
 | 
						||
// 测试连接,使用传值的方式,而非引用。因为如果使用了ssh隧道,则ip和端口会变为本地映射地址与端口
 | 
						||
func TestConn(me entity.Machine, getSshTunnelMachine func(uint64) *entity.Machine) error {
 | 
						||
	originId := me.Id
 | 
						||
	if originId == 0 {
 | 
						||
		// 随机设置一个ip,如果使用了隧道则用于临时保存隧道
 | 
						||
		me.Id = uint64(time.Now().Nanosecond())
 | 
						||
	}
 | 
						||
 | 
						||
	err := IfUseSshTunnelChangeIpPort(&me, getSshTunnelMachine)
 | 
						||
	biz.ErrIsNilAppendErr(err, "ssh隧道连接失败: %s")
 | 
						||
	if me.EnableSshTunnel == 1 {
 | 
						||
		defer CloseSshTunnelMachine(me.SshTunnelMachineId, me.Id)
 | 
						||
	}
 | 
						||
	sshClient, err := GetSshClient(&me)
 | 
						||
	if err != nil {
 | 
						||
		return err
 | 
						||
	}
 | 
						||
	defer sshClient.Close()
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
// 如果使用了ssh隧道,则修改机器ip port为暴露的ip port
 | 
						||
func IfUseSshTunnelChangeIpPort(me *entity.Machine, getMachine func(uint64) *entity.Machine) error {
 | 
						||
	if me.EnableSshTunnel != 1 {
 | 
						||
		return nil
 | 
						||
	}
 | 
						||
	sshTunnelMachine, err := GetSshTunnelMachine(me.SshTunnelMachineId, func(u uint64) *entity.Machine {
 | 
						||
		return getMachine(u)
 | 
						||
	})
 | 
						||
	if err != nil {
 | 
						||
		return err
 | 
						||
	}
 | 
						||
	exposeIp, exposePort, err := sshTunnelMachine.OpenSshTunnel(me.Id, me.Ip, me.Port)
 | 
						||
	if err != nil {
 | 
						||
		return err
 | 
						||
	}
 | 
						||
	// 修改机器ip地址
 | 
						||
	me.Ip = exposeIp
 | 
						||
	me.Port = exposePort
 | 
						||
	return nil
 | 
						||
}
 | 
						||
 | 
						||
func GetSshClient(m *entity.Machine) (*ssh.Client, error) {
 | 
						||
	config := ssh.ClientConfig{
 | 
						||
		User: m.Username,
 | 
						||
		HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
 | 
						||
			return nil
 | 
						||
		},
 | 
						||
		Timeout: 5 * time.Second,
 | 
						||
	}
 | 
						||
	if m.AuthMethod == entity.MachineAuthMethodPassword {
 | 
						||
		config.Auth = []ssh.AuthMethod{ssh.Password(m.Password)}
 | 
						||
	} else if m.AuthMethod == entity.MachineAuthMethodPublicKey {
 | 
						||
		if signer, err := ssh.ParsePrivateKey([]byte(m.Password)); err != nil {
 | 
						||
			return nil, err
 | 
						||
		} else {
 | 
						||
			config.Auth = []ssh.AuthMethod{ssh.PublicKeys(signer)}
 | 
						||
		}
 | 
						||
	}
 | 
						||
 | 
						||
	addr := fmt.Sprintf("%s:%d", m.Ip, m.Port)
 | 
						||
	sshClient, err := ssh.Dial("tcp", addr, &config)
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
	return sshClient, nil
 | 
						||
}
 | 
						||
 | 
						||
//根据机器信息创建客户端对象
 | 
						||
func newClient(machine *entity.Machine) (*Cli, error) {
 | 
						||
	if machine == nil {
 | 
						||
		return nil, errors.New("机器不存在")
 | 
						||
	}
 | 
						||
 | 
						||
	global.Log.Infof("[%s]机器连接:%s:%d", machine.Name, machine.Ip, machine.Port)
 | 
						||
	cli := new(Cli)
 | 
						||
	cli.machine = machine
 | 
						||
	err := cli.connect()
 | 
						||
	if err != nil {
 | 
						||
		return nil, err
 | 
						||
	}
 | 
						||
	return cli, nil
 | 
						||
}
 |