mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-03 16:00:25 +08:00
fix: pgsql隧道连接问题修复
This commit is contained in:
@@ -1,28 +1,22 @@
|
|||||||
package application
|
package application
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"mayfly-go/internal/constant"
|
"mayfly-go/internal/constant"
|
||||||
"mayfly-go/internal/db/domain/entity"
|
"mayfly-go/internal/db/domain/entity"
|
||||||
"mayfly-go/internal/db/domain/repository"
|
"mayfly-go/internal/db/domain/repository"
|
||||||
machineapp "mayfly-go/internal/machine/application"
|
|
||||||
"mayfly-go/internal/machine/infrastructure/machine"
|
"mayfly-go/internal/machine/infrastructure/machine"
|
||||||
"mayfly-go/pkg/biz"
|
"mayfly-go/pkg/biz"
|
||||||
"mayfly-go/pkg/cache"
|
"mayfly-go/pkg/cache"
|
||||||
"mayfly-go/pkg/global"
|
"mayfly-go/pkg/global"
|
||||||
"mayfly-go/pkg/model"
|
"mayfly-go/pkg/model"
|
||||||
"mayfly-go/pkg/utils"
|
"mayfly-go/pkg/utils"
|
||||||
"net"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-sql-driver/mysql"
|
|
||||||
"github.com/lib/pq"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Db interface {
|
type Db interface {
|
||||||
@@ -190,10 +184,10 @@ func (da *dbAppImpl) GetDbInstance(id uint64, db string) *DbInstance {
|
|||||||
defer mutex.Unlock()
|
defer mutex.Unlock()
|
||||||
|
|
||||||
d := da.GetById(id)
|
d := da.GetById(id)
|
||||||
// 密码解密
|
|
||||||
d.PwdDecrypt()
|
|
||||||
biz.NotNil(d, "数据库信息不存在")
|
biz.NotNil(d, "数据库信息不存在")
|
||||||
biz.IsTrue(strings.Contains(d.Database, db), "未配置该库的操作权限")
|
biz.IsTrue(strings.Contains(d.Database, db), "未配置该库的操作权限")
|
||||||
|
// 密码解密
|
||||||
|
d.PwdDecrypt()
|
||||||
|
|
||||||
dbInfo := new(DbInfo)
|
dbInfo := new(DbInfo)
|
||||||
utils.Copy(dbInfo, d)
|
utils.Copy(dbInfo, d)
|
||||||
@@ -347,22 +341,14 @@ func TestConnection(d *entity.Db) {
|
|||||||
|
|
||||||
// 获取数据库连接
|
// 获取数据库连接
|
||||||
func GetDbConn(d *entity.Db, db string) (*sql.DB, error) {
|
func GetDbConn(d *entity.Db, db string) (*sql.DB, error) {
|
||||||
// SSH Conect
|
var DB *sql.DB
|
||||||
if d.EnableSshTunnel == 1 && d.SshTunnelMachineId != 0 {
|
var err error
|
||||||
sshTunnelMachine := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
|
if d.Type == entity.DbTypeMysql {
|
||||||
if d.Type == entity.DbTypeMysql {
|
DB, err = getMysqlDB(d, db)
|
||||||
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
|
} else if d.Type == entity.DbTypePostgres {
|
||||||
return sshTunnelMachine.GetDialConn("tcp", addr)
|
DB, err = getPgsqlDB(d, db)
|
||||||
})
|
|
||||||
} else if d.Type == entity.DbTypePostgres {
|
|
||||||
_, err := pq.DialOpen(&PqSqlDialer{sshTunnelMachine: sshTunnelMachine}, getDsn(d, db))
|
|
||||||
if err != nil {
|
|
||||||
panic(biz.NewBizErr(fmt.Sprintf("postgres隧道连接失败: %s", err.Error())))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
DB, err := sql.Open(d.Type, getDsn(d, db))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -375,28 +361,6 @@ func GetDbConn(d *entity.Db, db string) (*sql.DB, error) {
|
|||||||
return DB, nil
|
return DB, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取dataSourceName
|
|
||||||
func getDsn(d *entity.Db, db string) string {
|
|
||||||
var dsn string
|
|
||||||
if d.Type == entity.DbTypeMysql {
|
|
||||||
// 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
|
||||||
dsn = fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db)
|
|
||||||
if d.Params != "" {
|
|
||||||
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
|
|
||||||
}
|
|
||||||
return dsn
|
|
||||||
}
|
|
||||||
|
|
||||||
if d.Type == entity.DbTypePostgres {
|
|
||||||
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", d.Host, d.Port, d.Username, d.Password, db)
|
|
||||||
if d.Params != "" {
|
|
||||||
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
|
|
||||||
}
|
|
||||||
return dsn
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func SelectDataByDb(db *sql.DB, selectSql string, isInner bool) ([]string, []map[string]interface{}, error) {
|
func SelectDataByDb(db *sql.DB, selectSql string, isInner bool) ([]string, []map[string]interface{}, error) {
|
||||||
rows, err := db.Query(selectSql)
|
rows, err := db.Query(selectSql)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -525,20 +489,3 @@ func Select2StructByDb(db *sql.DB, selectSql string, dest interface{}) error {
|
|||||||
func CloseDb(dbId uint64, db string) {
|
func CloseDb(dbId uint64, db string) {
|
||||||
dbCache.Delete(GetDbCacheKey(dbId, db))
|
dbCache.Delete(GetDbCacheKey(dbId, db))
|
||||||
}
|
}
|
||||||
|
|
||||||
type PqSqlDialer struct {
|
|
||||||
sshTunnelMachine *machine.SshTunnelMachine
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
|
|
||||||
if sshConn, err := pd.sshTunnelMachine.GetDialConn("tcp", address); err == nil {
|
|
||||||
// 将ssh conn包装,否则redis内部设置超时会报错,ssh conn不支持设置超时会返回错误: ssh: tcpChan: deadline not supported
|
|
||||||
return &utils.WrapSshConn{Conn: sshConn}, nil
|
|
||||||
} else {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
|
||||||
return pd.Dial(network, address)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,10 +1,33 @@
|
|||||||
package application
|
package application
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"mayfly-go/internal/db/domain/entity"
|
||||||
|
machineapp "mayfly-go/internal/machine/application"
|
||||||
"mayfly-go/pkg/biz"
|
"mayfly-go/pkg/biz"
|
||||||
|
"net"
|
||||||
|
|
||||||
|
"github.com/go-sql-driver/mysql"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getMysqlDB(d *entity.Db, db string) (*sql.DB, error) {
|
||||||
|
// SSH Conect
|
||||||
|
if d.EnableSshTunnel == 1 && d.SshTunnelMachineId != 0 {
|
||||||
|
sshTunnelMachine := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
|
||||||
|
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
|
||||||
|
return sshTunnelMachine.GetDialConn("tcp", addr)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
// 设置dataSourceName -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||||
|
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, db)
|
||||||
|
if d.Params != "" {
|
||||||
|
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
|
||||||
|
}
|
||||||
|
return sql.Open(d.Type, dsn)
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------- mysql元数据 -----------------------------------
|
// ---------------------------------- mysql元数据 -----------------------------------
|
||||||
const (
|
const (
|
||||||
// mysql 表信息元数据
|
// mysql 表信息元数据
|
||||||
|
|||||||
@@ -1,10 +1,60 @@
|
|||||||
package application
|
package application
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"mayfly-go/internal/db/domain/entity"
|
||||||
|
machineapp "mayfly-go/internal/machine/application"
|
||||||
"mayfly-go/pkg/biz"
|
"mayfly-go/pkg/biz"
|
||||||
|
"mayfly-go/pkg/utils"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getPgsqlDB(d *entity.Db, db string) (*sql.DB, error) {
|
||||||
|
driverName := d.Type
|
||||||
|
// SSH Conect
|
||||||
|
if d.EnableSshTunnel == 1 && d.SshTunnelMachineId != 0 {
|
||||||
|
// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
|
||||||
|
driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
|
||||||
|
if !utils.ArrContains(sql.Drivers(), driverName) {
|
||||||
|
sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId})
|
||||||
|
}
|
||||||
|
sql.Drivers()
|
||||||
|
}
|
||||||
|
|
||||||
|
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", d.Host, d.Port, d.Username, d.Password, db)
|
||||||
|
if d.Params != "" {
|
||||||
|
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
|
||||||
|
}
|
||||||
|
return sql.Open(driverName, dsn)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pgsql dialer
|
||||||
|
type PqSqlDialer struct {
|
||||||
|
sshTunnelMachineId uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
|
||||||
|
return pq.DialOpen(d, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
|
||||||
|
if sshConn, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId).GetDialConn("tcp", address); err == nil {
|
||||||
|
return sshConn, nil
|
||||||
|
} else {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
return pd.Dial(network, address)
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------- pgsql元数据 -----------------------------------
|
// ---------------------------------- pgsql元数据 -----------------------------------
|
||||||
const (
|
const (
|
||||||
// postgres 表信息元数据
|
// postgres 表信息元数据
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
// 数组比较
|
// 数组比较
|
||||||
// 依次返回,新增值,删除值,以及不变值
|
// 依次返回,新增值,删除值,以及不变值
|
||||||
@@ -49,3 +51,13 @@ func NumberArr2StrArr[T NumT](numberArr []T) []string {
|
|||||||
}
|
}
|
||||||
return strArr
|
return strArr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 判断数组中是否含有指定元素
|
||||||
|
func ArrContains[T comparable](arr []T, el T) bool {
|
||||||
|
for _, v := range arr {
|
||||||
|
if v == el {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user