mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 08:20:25 +08:00 
			
		
		
		
	refactor: 达梦ssh连接调整
This commit is contained in:
		@@ -4,93 +4,20 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"io"
 | 
			
		||||
	machineapp "mayfly-go/internal/machine/application"
 | 
			
		||||
	"mayfly-go/pkg/errorx"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
	"mayfly-go/pkg/utils/anyx"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	_ "gitee.com/chunanyong/dm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ConnectionInfo struct {
 | 
			
		||||
	Port       int
 | 
			
		||||
	Listener   net.Listener
 | 
			
		||||
	remoteConn net.Conn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var connectionMap = make(map[string]ConnectionInfo)
 | 
			
		||||
 | 
			
		||||
func getLocalListener() (net.Listener, int, error) {
 | 
			
		||||
	// Setup localListener (type net.Listener)
 | 
			
		||||
	localListener, err := net.Listen("tcp", "localhost:0")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 获取本地端口
 | 
			
		||||
	localPort := localListener.Addr().(*net.TCPAddr).Port
 | 
			
		||||
	return localListener, localPort, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func acceptConn(listener net.Listener, sshConn net.Conn) {
 | 
			
		||||
	for {
 | 
			
		||||
		localConn, err := listener.Accept()
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logx.Warn("端口转发出错", err)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		go forward(localConn, sshConn)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func forward(localConn net.Conn, remoteConn net.Conn) {
 | 
			
		||||
	copyConn := func(writer, reader net.Conn) {
 | 
			
		||||
		_, err := io.Copy(writer, reader)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logx.Warnf("io.Copy error: %s", err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	go copyConn(localConn, remoteConn)
 | 
			
		||||
	go copyConn(remoteConn, localConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func openSsh(d *DbInfo) error {
 | 
			
		||||
 | 
			
		||||
	sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	remoteConn, err := sshTunnelMachine.GetDialConn("tcp", fmt.Sprintf("%s:%d", d.Host, d.Port))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	// 获取sshConn的本地端口
 | 
			
		||||
	localLister, localPort, err := getLocalListener()
 | 
			
		||||
	// defer localLister.Close()
 | 
			
		||||
	go acceptConn(localLister, remoteConn)
 | 
			
		||||
	connectionMap[d.Network] = ConnectionInfo{
 | 
			
		||||
		Port:       localPort,
 | 
			
		||||
		Listener:   localLister,
 | 
			
		||||
		remoteConn: remoteConn,
 | 
			
		||||
	}
 | 
			
		||||
	d.Host = "127.0.0.1"
 | 
			
		||||
	d.Port = localPort
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 创建一个成员变量存放ssh隧道转发对应的本地连接
 | 
			
		||||
 | 
			
		||||
func getDmDB(d *DbInfo) (*sql.DB, error) {
 | 
			
		||||
	driverName := "dm"
 | 
			
		||||
	// SSH Conect 暂时不支持隧道连接
 | 
			
		||||
	db := d.Database
 | 
			
		||||
	var dbParam string
 | 
			
		||||
	if db != "" {
 | 
			
		||||
		// postgres database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
 | 
			
		||||
		// dm database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
 | 
			
		||||
		ss := strings.Split(db, "/")
 | 
			
		||||
		if len(ss) > 1 {
 | 
			
		||||
			dbParam = fmt.Sprintf("%s?schema=%s", ss[0], ss[len(ss)-1])
 | 
			
		||||
@@ -101,10 +28,16 @@ func getDmDB(d *DbInfo) (*sql.DB, error) {
 | 
			
		||||
 | 
			
		||||
	// 开启ssh隧道
 | 
			
		||||
	if d.SshTunnelMachineId > 0 {
 | 
			
		||||
		err := openSsh(d)
 | 
			
		||||
		sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		exposedIp, exposedPort, err := sshTunnelMachine.OpenSshTunnel(fmt.Sprintf("db:%d", d.Id), d.Host, d.Port)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		d.Host = exposedIp
 | 
			
		||||
		d.Port = exposedPort
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dsn := fmt.Sprintf("dm://%s:%s@%s:%d/%s", d.Username, d.Password, d.Host, d.Port, dbParam)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user