fix: pgsql隧道连接问题修复

This commit is contained in:
meilin.huang
2022-12-22 18:41:34 +08:00
parent 85349df8a1
commit 4fec38724d
4 changed files with 94 additions and 62 deletions

View File

@@ -1,28 +1,22 @@
package application
import (
"context"
"database/sql"
"fmt"
"mayfly-go/internal/constant"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/internal/machine/infrastructure/machine"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/global"
"mayfly-go/pkg/model"
"mayfly-go/pkg/utils"
"net"
"reflect"
"strconv"
"strings"
"sync"
"time"
"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
)
type Db interface {
@@ -190,10 +184,10 @@ func (da *dbAppImpl) GetDbInstance(id uint64, db string) *DbInstance {
defer mutex.Unlock()
d := da.GetById(id)
// 密码解密
d.PwdDecrypt()
biz.NotNil(d, "数据库信息不存在")
biz.IsTrue(strings.Contains(d.Database, db), "未配置该库的操作权限")
// 密码解密
d.PwdDecrypt()
dbInfo := new(DbInfo)
utils.Copy(dbInfo, d)
@@ -347,22 +341,14 @@ func TestConnection(d *entity.Db) {
// 获取数据库连接
func GetDbConn(d *entity.Db, db string) (*sql.DB, error) {
// SSH Conect
if d.EnableSshTunnel == 1 && d.SshTunnelMachineId != 0 {
sshTunnelMachine := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
if d.Type == entity.DbTypeMysql {
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
return sshTunnelMachine.GetDialConn("tcp", addr)
})
} else if d.Type == entity.DbTypePostgres {
_, err := pq.DialOpen(&PqSqlDialer{sshTunnelMachine: sshTunnelMachine}, getDsn(d, db))
if err != nil {
panic(biz.NewBizErr(fmt.Sprintf("postgres隧道连接失败: %s", err.Error())))
}
}
var DB *sql.DB
var err error
if d.Type == entity.DbTypeMysql {
DB, err = getMysqlDB(d, db)
} else if d.Type == entity.DbTypePostgres {
DB, err = getPgsqlDB(d, db)
}
DB, err := sql.Open(d.Type, getDsn(d, db))
if err != nil {
return nil, err
}
@@ -375,28 +361,6 @@ func GetDbConn(d *entity.Db, db string) (*sql.DB, error) {
return DB, nil
}
// 获取dataSourceName
func getDsn(d *entity.Db, db string) string {
var dsn string
if d.Type == entity.DbTypeMysql {
// 更多参数参考https://github.com/go-sql-driver/mysql#dsn-data-source-name
dsn = fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db)
if d.Params != "" {
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
}
return dsn
}
if d.Type == entity.DbTypePostgres {
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", d.Host, d.Port, d.Username, d.Password, db)
if d.Params != "" {
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
}
return dsn
}
return ""
}
func SelectDataByDb(db *sql.DB, selectSql string, isInner bool) ([]string, []map[string]interface{}, error) {
rows, err := db.Query(selectSql)
if err != nil {
@@ -525,20 +489,3 @@ func Select2StructByDb(db *sql.DB, selectSql string, dest interface{}) error {
func CloseDb(dbId uint64, db string) {
dbCache.Delete(GetDbCacheKey(dbId, db))
}
type PqSqlDialer struct {
sshTunnelMachine *machine.SshTunnelMachine
}
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
if sshConn, err := pd.sshTunnelMachine.GetDialConn("tcp", address); err == nil {
// 将ssh conn包装否则redis内部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
return &utils.WrapSshConn{Conn: sshConn}, nil
} else {
return nil, err
}
}
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
return pd.Dial(network, address)
}

View File

@@ -1,10 +1,33 @@
package application
import (
"context"
"database/sql"
"fmt"
"mayfly-go/internal/db/domain/entity"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/pkg/biz"
"net"
"github.com/go-sql-driver/mysql"
)
func getMysqlDB(d *entity.Db, db string) (*sql.DB, error) {
// SSH Conect
if d.EnableSshTunnel == 1 && d.SshTunnelMachineId != 0 {
sshTunnelMachine := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
return sshTunnelMachine.GetDialConn("tcp", addr)
})
}
// 设置dataSourceName -> 更多参数参考https://github.com/go-sql-driver/mysql#dsn-data-source-name
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db)
if d.Params != "" {
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
}
return sql.Open(d.Type, dsn)
}
// ---------------------------------- mysql元数据 -----------------------------------
const (
// mysql 表信息元数据

View File

@@ -1,10 +1,60 @@
package application
import (
"database/sql"
"database/sql/driver"
"fmt"
"mayfly-go/internal/db/domain/entity"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/utils"
"net"
"strings"
"time"
"github.com/lib/pq"
)
func getPgsqlDB(d *entity.Db, db string) (*sql.DB, error) {
driverName := d.Type
// SSH Conect
if d.EnableSshTunnel == 1 && d.SshTunnelMachineId != 0 {
// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
if !utils.ArrContains(sql.Drivers(), driverName) {
sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId})
}
sql.Drivers()
}
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", d.Host, d.Port, d.Username, d.Password, db)
if d.Params != "" {
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
}
return sql.Open(driverName, dsn)
}
// pgsql dialer
type PqSqlDialer struct {
sshTunnelMachineId uint64
}
func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
return pq.DialOpen(d, name)
}
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
if sshConn, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId).GetDialConn("tcp", address); err == nil {
return sshConn, nil
} else {
return nil, err
}
}
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
return pd.Dial(network, address)
}
// ---------------------------------- pgsql元数据 -----------------------------------
const (
// postgres 表信息元数据

View File

@@ -1,6 +1,8 @@
package utils
import "fmt"
import (
"fmt"
)
// 数组比较
// 依次返回,新增值,删除值,以及不变值
@@ -49,3 +51,13 @@ func NumberArr2StrArr[T NumT](numberArr []T) []string {
}
return strArr
}
// 判断数组中是否含有指定元素
func ArrContains[T comparable](arr []T, el T) bool {
for _, v := range arr {
if v == el {
return true
}
}
return false
}