Files
mayfly-go/server/internal/db/dbm/postgres/meta.go

112 lines
2.9 KiB
Go
Raw Normal View History

package postgres
import (
"database/sql"
"database/sql/driver"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/netx"
"net"
"strings"
"sync"
"time"
pq "gitee.com/liuzongyang/libpq"
)
var (
	// meta is the lazily created instance handed out by GetMeta.
	meta dbi.Meta
	// once ensures meta is created exactly once.
	once sync.Once
)

// GetMeta returns the package-wide postgres dbi.Meta, creating the underlying
// PostgresMeta on the first call and returning the same value afterwards.
func GetMeta() dbi.Meta {
	once.Do(func() { meta = &PostgresMeta{} })
	return meta
}
// PostgresMeta carries no state; its methods build postgres *sql.DB handles
// and dialects from dbi connection info.
type PostgresMeta struct{}
// pgSshDriverMu serialises the check-and-register of per-tunnel drivers done by
// GetSqlDb: sql.Register panics if a driver name is registered twice, so two
// concurrent calls for the same tunnel must not both try to register it.
var pgSshDriverMu sync.Mutex

// pgRegisterSshDriver registers a PqSqlDialer-backed driver under driverName
// unless a driver with that name is already registered.
func pgRegisterSshDriver(driverName string, sshTunnelMachineId int) {
	pgSshDriverMu.Lock()
	defer pgSshDriverMu.Unlock()
	if !collx.ArrayContains(sql.Drivers(), driverName) {
		sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: sshTunnelMachineId})
	}
}

// pgQuoteDsnValue renders v as a quoted value for a libpq keyword/value
// connection string. libpq requires empty values and values containing spaces
// to be single-quoted, with single quotes and backslashes inside escaped by a
// backslash.
func pgQuoteDsnValue(v string) string {
	return "'" + strings.NewReplacer(`\`, `\\`, `'`, `\'`).Replace(v) + "'"
}

// GetSqlDb opens a *sql.DB for the given postgres connection info.
//
// When d.SshTunnelMachineId > 0 the connection goes through that SSH tunnel
// machine. A driver named "postgres:ssh:<machine id>" backed by PqSqlDialer is
// registered once and used instead of the driver named by d.Type.
//
// d.Database may be "db" or "db/schema". In the latter form the first segment
// is the dbname and the last "/"-separated segment is used as search_path.
// d.Params holds extra "&"-separated key=value parameters that are appended to
// the DSN. When a database is given, any dbname entry (and, if a schema was
// given, any search_path entry) is dropped from them so it cannot override
// d.Database. d is not modified.
//
// The result comes straight from sql.Open.
func (md *PostgresMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
	driverName := string(d.Type)
	// SSH connect
	if d.SshTunnelMachineId > 0 {
		// When a tunnel is used, connect through the driver registered as
		// `postgres:ssh:<tunnel machine id>`.
		driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
		pgRegisterSshDriver(driverName, d.SshTunnelMachineId)
	}

	db := d.Database
	var dbParam string
	existSchema := false
	if db != "" {
		// The database may be written as db/schema to connect to a specific
		// schema; without a schema the default schema is used.
		ss := strings.Split(db, "/")
		if len(ss) > 1 {
			existSchema = true
			dbParam = fmt.Sprintf("dbname=%s search_path=%s", pgQuoteDsnValue(ss[0]), pgQuoteDsnValue(ss[len(ss)-1]))
		} else {
			dbParam = "dbname=" + pgQuoteDsnValue(db)
		}
	}
	// Values are quoted so that an empty value (e.g. no password) or one with
	// spaces/quotes cannot swallow or split the neighbouring keywords.
	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s %s sslmode=disable connect_timeout=8",
		pgQuoteDsnValue(d.Host), d.Port, pgQuoteDsnValue(d.Username), pgQuoteDsnValue(d.Password), dbParam)

	// Append any extra connection parameters.
	if d.Params != "" {
		paramArr := strings.Split(d.Params, "&")
		if db != "" {
			// A database was given explicitly, so drop dbname (and search_path when a
			// schema was given) from the instance-level params. Work on a local copy
			// rather than writing back to d.Params.
			paramArr = collx.ArrayRemoveFunc(paramArr, func(param string) bool {
				return strings.HasPrefix(param, "dbname=") ||
					(existSchema && strings.HasPrefix(param, "search_path"))
			})
		}
		dsn = dsn + " " + strings.Join(paramArr, " ")
	}
	return sql.Open(driverName, dsn)
}
// GetDialect wraps conn in a PgsqlDialect and returns it as a dbi.Dialect.
func (md *PostgresMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
	dialect := &PgsqlDialect{conn}
	return dialect
}
// PqSqlDialer is registered as a database/sql driver and also acts as the pq
// dialer for the connections it opens. Its dials go through the SSH tunnel of
// a machine.
type PqSqlDialer struct {
	// sshTunnelMachineId is the id of the machine whose SSH tunnel is dialled through.
	sshTunnelMachineId int
}
// Compile-time check that PqSqlDialer can be passed to sql.Register.
var _ driver.Driver = (*PqSqlDialer)(nil)

// Open implements driver.Driver. It opens a pq connection for the DSN name,
// passing pd as the dialer so the connection is made through pd's Dial methods.
func (pd *PqSqlDialer) Open(name string) (driver.Conn, error) {
	return pq.DialOpen(pd, name)
}
// Dial connects to address through the SSH tunnel of machine
// pd.sshTunnelMachineId. The network argument is ignored: a "tcp" connection
// is always requested from the tunnel.
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
	sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId)
	if err != nil {
		return nil, err
	}
	sshConn, err := sshTunnel.GetDialConn("tcp", address)
	if err != nil {
		return nil, err
	}
	// Wrap the ssh conn, otherwise an error is returned:
	// ssh: tcpChan: deadline not supported
	return &netx.WrapSshConn{Conn: sshConn}, nil
}
// DialTimeout behaves like Dial but gives up once timeout has elapsed. A
// non-positive timeout means no limit.
//
// Dial takes no context and cannot be cancelled, so on timeout the dial keeps
// running in the background. If it later succeeds, that connection is closed
// rather than leaked.
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	if timeout <= 0 {
		return pd.Dial(network, address)
	}

	type dialResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the dialling goroutine can always deliver its result and exit.
	done := make(chan dialResult, 1)
	go func() {
		conn, err := pd.Dial(network, address)
		done <- dialResult{conn: conn, err: err}
	}()

	timer := time.NewTimer(timeout)
	defer timer.Stop()

	select {
	case res := <-done:
		return res.conn, res.err
	case <-timer.C:
		// Nobody will use a connection that arrives after the timeout; close it.
		go func() {
			if res := <-done; res.conn != nil {
				_ = res.conn.Close() // best effort
			}
		}()
		return nil, fmt.Errorf("ssh tunnel dial %s: timed out after %s", address, timeout)
	}
}