mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 15:30:25 +08:00
fix: pgsql隧道连接问题修复
This commit is contained in:
@@ -1,28 +1,22 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/constant"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/internal/machine/infrastructure/machine"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/global"
|
||||
"mayfly-go/pkg/model"
|
||||
"mayfly-go/pkg/utils"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type Db interface {
|
||||
@@ -190,10 +184,10 @@ func (da *dbAppImpl) GetDbInstance(id uint64, db string) *DbInstance {
|
||||
defer mutex.Unlock()
|
||||
|
||||
d := da.GetById(id)
|
||||
// 密码解密
|
||||
d.PwdDecrypt()
|
||||
biz.NotNil(d, "数据库信息不存在")
|
||||
biz.IsTrue(strings.Contains(d.Database, db), "未配置该库的操作权限")
|
||||
// 密码解密
|
||||
d.PwdDecrypt()
|
||||
|
||||
dbInfo := new(DbInfo)
|
||||
utils.Copy(dbInfo, d)
|
||||
@@ -347,22 +341,14 @@ func TestConnection(d *entity.Db) {
|
||||
|
||||
// 获取数据库连接
|
||||
func GetDbConn(d *entity.Db, db string) (*sql.DB, error) {
|
||||
// SSH Conect
|
||||
if d.EnableSshTunnel == 1 && d.SshTunnelMachineId != 0 {
|
||||
sshTunnelMachine := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
|
||||
if d.Type == entity.DbTypeMysql {
|
||||
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return sshTunnelMachine.GetDialConn("tcp", addr)
|
||||
})
|
||||
} else if d.Type == entity.DbTypePostgres {
|
||||
_, err := pq.DialOpen(&PqSqlDialer{sshTunnelMachine: sshTunnelMachine}, getDsn(d, db))
|
||||
if err != nil {
|
||||
panic(biz.NewBizErr(fmt.Sprintf("postgres隧道连接失败: %s", err.Error())))
|
||||
}
|
||||
}
|
||||
var DB *sql.DB
|
||||
var err error
|
||||
if d.Type == entity.DbTypeMysql {
|
||||
DB, err = getMysqlDB(d, db)
|
||||
} else if d.Type == entity.DbTypePostgres {
|
||||
DB, err = getPgsqlDB(d, db)
|
||||
}
|
||||
|
||||
DB, err := sql.Open(d.Type, getDsn(d, db))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -375,28 +361,6 @@ func GetDbConn(d *entity.Db, db string) (*sql.DB, error) {
|
||||
return DB, nil
|
||||
}
|
||||
|
||||
// 获取dataSourceName
|
||||
func getDsn(d *entity.Db, db string) string {
|
||||
var dsn string
|
||||
if d.Type == entity.DbTypeMysql {
|
||||
// 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
dsn = fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db)
|
||||
if d.Params != "" {
|
||||
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
|
||||
}
|
||||
return dsn
|
||||
}
|
||||
|
||||
if d.Type == entity.DbTypePostgres {
|
||||
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", d.Host, d.Port, d.Username, d.Password, db)
|
||||
if d.Params != "" {
|
||||
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
|
||||
}
|
||||
return dsn
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func SelectDataByDb(db *sql.DB, selectSql string, isInner bool) ([]string, []map[string]interface{}, error) {
|
||||
rows, err := db.Query(selectSql)
|
||||
if err != nil {
|
||||
@@ -525,20 +489,3 @@ func Select2StructByDb(db *sql.DB, selectSql string, dest interface{}) error {
|
||||
func CloseDb(dbId uint64, db string) {
|
||||
dbCache.Delete(GetDbCacheKey(dbId, db))
|
||||
}
|
||||
|
||||
type PqSqlDialer struct {
|
||||
sshTunnelMachine *machine.SshTunnelMachine
|
||||
}
|
||||
|
||||
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
|
||||
if sshConn, err := pd.sshTunnelMachine.GetDialConn("tcp", address); err == nil {
|
||||
// 将ssh conn包装,否则redis内部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
|
||||
return &utils.WrapSshConn{Conn: sshConn}, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
||||
return pd.Dial(network, address)
|
||||
}
|
||||
|
||||
@@ -1,10 +1,33 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/pkg/biz"
|
||||
"net"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
func getMysqlDB(d *entity.Db, db string) (*sql.DB, error) {
|
||||
// SSH Conect
|
||||
if d.EnableSshTunnel == 1 && d.SshTunnelMachineId != 0 {
|
||||
sshTunnelMachine := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
|
||||
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return sshTunnelMachine.GetDialConn("tcp", addr)
|
||||
})
|
||||
}
|
||||
// 设置dataSourceName -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db)
|
||||
if d.Params != "" {
|
||||
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
|
||||
}
|
||||
return sql.Open(d.Type, dsn)
|
||||
}
|
||||
|
||||
// ---------------------------------- mysql元数据 -----------------------------------
|
||||
const (
|
||||
// mysql 表信息元数据
|
||||
|
||||
@@ -1,10 +1,60 @@
|
||||
package application
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/utils"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func getPgsqlDB(d *entity.Db, db string) (*sql.DB, error) {
|
||||
driverName := d.Type
|
||||
// SSH Conect
|
||||
if d.EnableSshTunnel == 1 && d.SshTunnelMachineId != 0 {
|
||||
// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
|
||||
driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
|
||||
if !utils.ArrContains(sql.Drivers(), driverName) {
|
||||
sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId})
|
||||
}
|
||||
sql.Drivers()
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", d.Host, d.Port, d.Username, d.Password, db)
|
||||
if d.Params != "" {
|
||||
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
|
||||
}
|
||||
return sql.Open(driverName, dsn)
|
||||
}
|
||||
|
||||
// pgsql dialer
|
||||
type PqSqlDialer struct {
|
||||
sshTunnelMachineId uint64
|
||||
}
|
||||
|
||||
func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
|
||||
return pq.DialOpen(d, name)
|
||||
}
|
||||
|
||||
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
|
||||
if sshConn, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId).GetDialConn("tcp", address); err == nil {
|
||||
return sshConn, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
||||
return pd.Dial(network, address)
|
||||
}
|
||||
|
||||
// ---------------------------------- pgsql元数据 -----------------------------------
|
||||
const (
|
||||
// postgres 表信息元数据
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package utils
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// 数组比较
|
||||
// 依次返回,新增值,删除值,以及不变值
|
||||
@@ -49,3 +51,13 @@ func NumberArr2StrArr[T NumT](numberArr []T) []string {
|
||||
}
|
||||
return strArr
|
||||
}
|
||||
|
||||
// 判断数组中是否含有指定元素
|
||||
func ArrContains[T comparable](arr []T, el T) bool {
|
||||
for _, v := range arr {
|
||||
if v == el {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user