refactor: 达梦ssh连接调整

This commit is contained in:
meilin.huang
2023-12-20 23:01:51 +08:00
parent f29a1560aa
commit 550631c03b
10 changed files with 53 additions and 89 deletions

View File

@@ -4,93 +4,20 @@ import (
"context"
"database/sql"
"fmt"
"io"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/anyx"
"net"
"strings"
_ "gitee.com/chunanyong/dm"
)
type ConnectionInfo struct {
Port int
Listener net.Listener
remoteConn net.Conn
}
var connectionMap = make(map[string]ConnectionInfo)
func getLocalListener() (net.Listener, int, error) {
// Setup localListener (type net.Listener)
localListener, err := net.Listen("tcp", "localhost:0")
if err != nil {
return nil, 0, err
}
// 获取本地端口
localPort := localListener.Addr().(*net.TCPAddr).Port
return localListener, localPort, nil
}
func acceptConn(listener net.Listener, sshConn net.Conn) {
for {
localConn, err := listener.Accept()
if err != nil {
logx.Warn("端口转发出错", err)
return
}
go forward(localConn, sshConn)
}
}
func forward(localConn net.Conn, remoteConn net.Conn) {
copyConn := func(writer, reader net.Conn) {
_, err := io.Copy(writer, reader)
if err != nil {
logx.Warnf("io.Copy error: %s", err)
}
}
go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
}
func openSsh(d *DbInfo) error {
sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
if err != nil {
return err
}
remoteConn, err := sshTunnelMachine.GetDialConn("tcp", fmt.Sprintf("%s:%d", d.Host, d.Port))
if err != nil {
return err
}
// 获取sshConn的本地端口
localLister, localPort, err := getLocalListener()
// defer localLister.Close()
go acceptConn(localLister, remoteConn)
connectionMap[d.Network] = ConnectionInfo{
Port: localPort,
Listener: localLister,
remoteConn: remoteConn,
}
d.Host = "127.0.0.1"
d.Port = localPort
return nil
}
// 创建一个成员变量存放ssh隧道转发对应的本地连接
func getDmDB(d *DbInfo) (*sql.DB, error) {
driverName := "dm"
// SSH Conect 暂时不支持隧道连接
db := d.Database
var dbParam string
if db != "" {
// postgres database可以使用db/schema表示方便连接指定schema, 若不存在schema则使用默认schema
// dm database可以使用db/schema表示方便连接指定schema, 若不存在schema则使用默认schema
ss := strings.Split(db, "/")
if len(ss) > 1 {
dbParam = fmt.Sprintf("%s?schema=%s", ss[0], ss[len(ss)-1])
@@ -101,10 +28,16 @@ func getDmDB(d *DbInfo) (*sql.DB, error) {
// 开启ssh隧道
if d.SshTunnelMachineId > 0 {
err := openSsh(d)
sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
if err != nil {
return nil, err
}
exposedIp, exposedPort, err := sshTunnelMachine.OpenSshTunnel(fmt.Sprintf("db:%d", d.Id), d.Host, d.Port)
if err != nil {
return nil, err
}
d.Host = exposedIp
d.Port = exposedPort
}
dsn := fmt.Sprintf("dm://%s:%s@%s:%d/%s", d.Username, d.Password, d.Host, d.Port, dbParam)