mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 08:20:25 +08:00 
			
		
		
		
	
		
			
				
	
	
		
			155 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			155 lines
		
	
	
		
			3.9 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
package dbi
 | 
						||
 | 
						||
import (
 | 
						||
	"fmt"
 | 
						||
	machineapp "mayfly-go/internal/machine/application"
 | 
						||
	"mayfly-go/internal/machine/mcm"
 | 
						||
	"mayfly-go/pkg/errorx"
 | 
						||
	"mayfly-go/pkg/logx"
 | 
						||
	"strings"
 | 
						||
)
 | 
						||
 | 
						||
// DbType names a database dialect, such as "mysql" or "postgres".
type DbType string

// Database types known to this package. Their string values are what Equal
// compares against (verbatim, case-sensitive).
const (
	DbTypeMysql      DbType = "mysql"
	DbTypeMariadb    DbType = "mariadb"
	DbTypePostgres   DbType = "postgres"
	DbTypeGauss      DbType = "gauss"
	DbTypeDM         DbType = "dm"
	DbTypeOracle     DbType = "oracle"
	DbTypeSqlite     DbType = "sqlite"
	DbTypeMssql      DbType = "mssql"
	DbTypeKingbaseEs DbType = "kingbaseEs"
	DbTypeVastbase   DbType = "vastbase"
)

// ToDbType converts a raw type string into a DbType. No validation or
// normalization (e.g. lower-casing) is performed.
func ToDbType(dbType string) DbType {
	converted := DbType(dbType)
	return converted
}

// Equal reports whether typ, taken verbatim, names the same database type as
// dbType. The comparison is case-sensitive.
func (dbType DbType) Equal(typ string) bool {
	return dbType == DbType(typ)
}
 | 
						||
 | 
						||
type DbInfo struct {
 | 
						||
	InstanceId uint64 // 实例id
 | 
						||
	Id         uint64 // dbId
 | 
						||
	Name       string
 | 
						||
 | 
						||
	Type     DbType // 类型,mysql postgres等
 | 
						||
	Host     string
 | 
						||
	Port     int
 | 
						||
	Extra    string // 连接需要的其他额外参数(json字符串),如oracle数据库需要指定sid等
 | 
						||
	Network  string
 | 
						||
	Username string
 | 
						||
	Password string
 | 
						||
	Params   string
 | 
						||
	Database string // 若有schema的库则为'database/scheam'格式
 | 
						||
 | 
						||
	FlowProcdefKey     string // 流程定义key
 | 
						||
	TagPath            []string
 | 
						||
	SshTunnelMachineId int
 | 
						||
 | 
						||
	Meta Meta
 | 
						||
}
 | 
						||
 | 
						||
// 获取记录日志的描述
 | 
						||
func (d *DbInfo) GetLogDesc() string {
 | 
						||
	return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.TagPath, d.Name, d.Host, d.Port, d.Database)
 | 
						||
}
 | 
						||
 | 
						||
// Conn opens a connection to the database described by dbInfo using the given
// dialect meta, and verifies it with a ping.
//
// Side effects on dbInfo: meta is stored in dbInfo.Meta, and when
// dbInfo.Database is empty it is replaced by the dialect's default database
// (meta.GetMetaData(...).DefaultDb()).
//
// On success the returned DbConn has Id GetDbConnId(dbInfo.Id, database) and
// wraps a handle limited to 5 open / 1 idle connections. On failure a business
// error (errorx.NewBiz) is returned and the failure is logged. If the ping
// fails, the handle just opened is closed so it is not leaked.
func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) {
	if meta == nil {
		return nil, errorx.NewBiz("数据库元信息接口不能为空")
	}

	// Keep the meta on the info so later code can reach the dialect.
	dbInfo.Meta = meta
	database := dbInfo.Database
	// No database configured: connect to the dialect's default database.
	if database == "" {
		database = meta.GetMetaData(&DbConn{Info: dbInfo}).DefaultDb()
		dbInfo.Database = database
	}

	conn, err := meta.GetSqlDb(dbInfo)
	if err != nil {
		logx.Errorf("连接db失败: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error())
		return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
	}

	if err = conn.Ping(); err != nil {
		logx.Errorf("db ping失败: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error())
		// The handle was opened but is unusable. Release it so its connections
		// and resources are not leaked. This is best-effort: the ping error is
		// the one reported, so a Close error is deliberately ignored.
		_ = conn.Close()
		return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
	}

	dbc := &DbConn{Id: GetDbConnId(dbInfo.Id, database), Info: dbInfo}

	// Maximum connection lifetime (connections older than this are closed);
	// intentionally left disabled:
	// conn.SetConnMaxLifetime(100 * time.Second)
	// Maximum number of open connections.
	conn.SetMaxOpenConns(5)
	// Maximum number of idle connections.
	conn.SetMaxIdleConns(1)
	dbc.db = conn
	logx.Infof("连接db: %s:%d/%s", dbInfo.Host, dbInfo.Port, database)

	return dbc, nil
}
 | 
						||
 | 
						||
// IfUseSshTunnelChangeIpPort routes the connection through an SSH tunnel when
// SshTunnelMachineId > 0. It looks up the tunnel machine, opens a tunnel keyed
// "db:<Id>" to Host:Port, and replaces Host and Port with the locally mapped
// address the tunnel returns. With no tunnel machine configured it does nothing.
// Lookup or tunnel errors are returned as-is, and Host/Port stay unchanged.
func (di *DbInfo) IfUseSshTunnelChangeIpPort() error {
	if di.SshTunnelMachineId <= 0 {
		return nil
	}

	tunnel, err := GetSshTunnel(di.SshTunnelMachineId)
	if err != nil {
		return err
	}

	tunnelKey := fmt.Sprintf("db:%d", di.Id)
	localHost, localPort, err := tunnel.OpenSshTunnel(tunnelKey, di.Host, di.Port)
	if err != nil {
		return err
	}

	di.Host, di.Port = localHost, localPort
	return nil
}
 | 
						||
 | 
						||
// CurrentSchema returns the schema part of Database when it has the form
// "database/schema", i.e. it contains exactly one '/'. Otherwise it returns "".
func (di *DbInfo) CurrentSchema() string {
	if strings.Count(di.Database, "/") != 1 {
		return ""
	}
	_, schema, _ := strings.Cut(di.Database, "/")
	return schema
}
 | 
						||
 | 
						||
// GetDatabase returns the database part of Database. For the "database/schema"
// form this is the text before the first '/'. Without a '/', Database is
// returned unchanged.
func (di *DbInfo) GetDatabase() string {
	database, _, _ := strings.Cut(di.Database, "/")
	return database
}
 | 
						||
 | 
						||
// 根据ssh tunnel机器id返回ssh tunnel
 | 
						||
func GetSshTunnel(sshTunnelMachineId int) (*mcm.SshTunnelMachine, error) {
 | 
						||
	return machineapp.GetMachineApp().GetSshTunnelMachine(sshTunnelMachineId)
 | 
						||
}
 | 
						||
 | 
						||
// GetDbConnId builds the connection id "<dbId>:<db>". A zero dbId yields the
// empty string.
func GetDbConnId(dbId uint64, db string) string {
	if dbId != 0 {
		return fmt.Sprintf("%d:%s", dbId, db)
	}
	return ""
}
 |