2024-01-12 13:15:30 +08:00
|
|
|
|
package postgres
|
|
|
|
|
|
|
|
|
|
|
|
import (
	"database/sql"
	"database/sql/driver"
	"fmt"
	"net"
	"strings"
	"sync"
	"time"

	"mayfly-go/internal/db/dbm/dbi"
	"mayfly-go/pkg/utils/collx"
	"mayfly-go/pkg/utils/netx"

	pq "gitee.com/liuzongyang/libpq"
)
|
|
|
|
|
|
|
2024-01-24 17:01:17 +08:00
|
|
|
|
func init() {
|
2024-02-06 07:32:03 +00:00
|
|
|
|
meta := new(PostgresMeta)
|
|
|
|
|
|
dbi.Register(dbi.DbTypePostgres, meta)
|
|
|
|
|
|
dbi.Register(dbi.DbTypeKingbaseEs, meta)
|
|
|
|
|
|
dbi.Register(dbi.DbTypeVastbase, meta)
|
2024-01-30 13:09:26 +00:00
|
|
|
|
|
2024-03-01 04:03:03 +00:00
|
|
|
|
gauss := &PostgresMeta{
|
|
|
|
|
|
Param: "dbtype=gauss",
|
|
|
|
|
|
}
|
2024-01-30 13:09:26 +00:00
|
|
|
|
dbi.Register(dbi.DbTypeGauss, gauss)
|
2024-01-12 13:15:30 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// PostgresMeta is the dbi meta implementation used for PostgreSQL and the
// PostgreSQL-compatible database types registered in init.
type PostgresMeta struct {
	// Param is an extra key=value connection parameter (for example
	// "dbtype=gauss") that GetSqlDb appends to the DSN; empty means none.
	Param string
}
|
|
|
|
|
|
|
|
|
|
|
|
func (md *PostgresMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
|
2024-01-30 13:09:26 +00:00
|
|
|
|
driverName := "postgres"
|
2024-01-12 13:15:30 +08:00
|
|
|
|
// SSH Conect
|
|
|
|
|
|
if d.SshTunnelMachineId > 0 {
|
|
|
|
|
|
// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
|
|
|
|
|
|
driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
|
|
|
|
|
|
if !collx.ArrayContains(sql.Drivers(), driverName) {
|
|
|
|
|
|
sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId})
|
|
|
|
|
|
}
|
|
|
|
|
|
sql.Drivers()
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
db := d.Database
|
|
|
|
|
|
var dbParam string
|
2024-02-06 07:32:03 +00:00
|
|
|
|
existSchema := false
|
|
|
|
|
|
if db == "" {
|
|
|
|
|
|
db = d.Type.MetaDbName()
|
|
|
|
|
|
}
|
|
|
|
|
|
// postgres database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
|
|
|
|
|
|
ss := strings.Split(db, "/")
|
|
|
|
|
|
if len(ss) > 1 {
|
|
|
|
|
|
existSchema = true
|
|
|
|
|
|
dbParam = fmt.Sprintf("dbname=%s search_path=%s", ss[0], ss[len(ss)-1])
|
|
|
|
|
|
} else {
|
|
|
|
|
|
dbParam = "dbname=" + db
|
2024-01-12 13:15:30 +08:00
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s %s sslmode=disable connect_timeout=8", d.Host, d.Port, d.Username, d.Password, dbParam)
|
|
|
|
|
|
// 存在额外指定参数,则拼接该连接参数
|
|
|
|
|
|
if d.Params != "" {
|
|
|
|
|
|
// 存在指定的db,则需要将dbInstance配置中的parmas排除掉dbname和search_path
|
|
|
|
|
|
if db != "" {
|
|
|
|
|
|
paramArr := strings.Split(d.Params, "&")
|
|
|
|
|
|
paramArr = collx.ArrayRemoveFunc(paramArr, func(param string) bool {
|
|
|
|
|
|
if strings.HasPrefix(param, "dbname=") {
|
|
|
|
|
|
return true
|
|
|
|
|
|
}
|
2024-02-06 07:32:03 +00:00
|
|
|
|
if existSchema && strings.HasPrefix(param, "search_path") {
|
2024-01-12 13:15:30 +08:00
|
|
|
|
return true
|
|
|
|
|
|
}
|
|
|
|
|
|
return false
|
|
|
|
|
|
})
|
|
|
|
|
|
d.Params = strings.Join(paramArr, " ")
|
|
|
|
|
|
}
|
|
|
|
|
|
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2024-01-30 13:09:26 +00:00
|
|
|
|
if md.Param != "" && !strings.Contains(dsn, "dbtype") {
|
|
|
|
|
|
dsn = fmt.Sprintf("%s %s", dsn, md.Param)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
2024-01-12 13:15:30 +08:00
|
|
|
|
return sql.Open(driverName, dsn)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (md *PostgresMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
|
|
|
|
|
|
return &PgsqlDialect{conn}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// PqSqlDialer is a pq dialer that reaches the database server through the
// SSH tunnel of a machine. GetSqlDb also registers it as a database/sql
// driver, one per tunnel machine.
type PqSqlDialer struct {
	// sshTunnelMachineId identifies the machine whose SSH tunnel carries the traffic.
	sshTunnelMachineId int
}
|
|
|
|
|
|
|
|
|
|
|
|
func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
|
|
|
|
|
|
return pq.DialOpen(d, name)
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
|
2024-01-13 13:38:53 +08:00
|
|
|
|
sshTunnel, err := dbi.GetSshTunnel(pd.sshTunnelMachineId)
|
2024-01-12 13:15:30 +08:00
|
|
|
|
if err != nil {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
if sshConn, err := sshTunnel.GetDialConn("tcp", address); err == nil {
|
|
|
|
|
|
// 将ssh conn包装,否则会返回错误: ssh: tcpChan: deadline not supported
|
|
|
|
|
|
return &netx.WrapSshConn{Conn: sshConn}, nil
|
|
|
|
|
|
} else {
|
|
|
|
|
|
return nil, err
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
|
|
|
|
|
return pd.Dial(network, address)
|
|
|
|
|
|
}
|