mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-02 23:40:24 +08:00
fix: 机器文件下载问题修复&dbm重构
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
"countup.js": "^2.7.0",
|
||||
"cropperjs": "^1.5.11",
|
||||
"echarts": "^5.4.3",
|
||||
"element-plus": "^2.5.0",
|
||||
"element-plus": "^2.5.1",
|
||||
"js-base64": "^3.7.5",
|
||||
"jsencrypt": "^3.3.2",
|
||||
"lodash": "^4.17.21",
|
||||
@@ -33,7 +33,7 @@
|
||||
"splitpanes": "^3.1.5",
|
||||
"sql-formatter": "^14.0.0",
|
||||
"uuid": "^9.0.1",
|
||||
"vue": "^3.4.8",
|
||||
"vue": "^3.4.10",
|
||||
"vue-router": "^4.2.5",
|
||||
"xterm": "^5.3.0",
|
||||
"xterm-addon-fit": "^0.8.0",
|
||||
@@ -42,7 +42,7 @@
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/lodash": "^4.14.178",
|
||||
"@types/node": "^15.6.0",
|
||||
"@types/node": "^18.14.0",
|
||||
"@types/nprogress": "^0.2.0",
|
||||
"@types/sortablejs": "^1.15.3",
|
||||
"@typescript-eslint/eslint-plugin": "^6.7.4",
|
||||
|
||||
@@ -35,6 +35,7 @@ export const machineApi = {
|
||||
mvFile: Api.newPost('/machines/{machineId}/files/{fileId}/mv'),
|
||||
uploadFile: Api.newPost('/machines/{machineId}/files/{fileId}/upload?' + joinClientParams()),
|
||||
fileContent: Api.newGet('/machines/{machineId}/files/{fileId}/read'),
|
||||
downloadFile: Api.newGet('/machines/{machineId}/files/{fileId}/download'),
|
||||
createFile: Api.newPost('/machines/{machineId}/files/{id}/create-file'),
|
||||
// 修改文件内容
|
||||
updateFileContent: Api.newPost('/machines/{machineId}/files/{id}/write'),
|
||||
|
||||
@@ -611,7 +611,7 @@ const deleteFile = async (files: any) => {
|
||||
|
||||
const downloadFile = (data: any) => {
|
||||
const a = document.createElement('a');
|
||||
a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/read?type=1&path=${data.path}&${joinClientParams()}`);
|
||||
a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/download?path=${data.path}&${joinClientParams()}`);
|
||||
a.click();
|
||||
};
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
"mayfly-go/internal/db/api/form"
|
||||
"mayfly-go/internal/db/api/vo"
|
||||
"mayfly-go/internal/db/application"
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
msgapp "mayfly-go/internal/msg/application"
|
||||
msgdto "mayfly-go/internal/msg/application/dto"
|
||||
@@ -351,7 +351,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
|
||||
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table))
|
||||
writer.WriteString("BEGIN;\n")
|
||||
insertSql := "INSERT INTO %s VALUES (%s);\n"
|
||||
dbMeta.WalkTableRecord(table, func(record map[string]any, columns []*dbm.QueryColumn) error {
|
||||
dbMeta.WalkTableRecord(table, func(record map[string]any, columns []*dbi.QueryColumn) error {
|
||||
var values []string
|
||||
writer.TryFlush()
|
||||
for _, column := range columns {
|
||||
@@ -470,7 +470,7 @@ func getDbName(g *gin.Context) string {
|
||||
return db
|
||||
}
|
||||
|
||||
func (d *Db) getDbConn(g *gin.Context) *dbm.DbConn {
|
||||
func (d *Db) getDbConn(g *gin.Context) *dbi.DbConn {
|
||||
dc, err := d.DbApp.GetDbConn(getDbId(g), getDbName(g))
|
||||
biz.ErrIsNil(err)
|
||||
return dc
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"mayfly-go/internal/common/consts"
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
tagapp "mayfly-go/internal/tag/application"
|
||||
@@ -33,10 +34,10 @@ type Db interface {
|
||||
// @param id 数据库id
|
||||
//
|
||||
// @param dbName 数据库名
|
||||
GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error)
|
||||
GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error)
|
||||
|
||||
// 根据数据库实例id获取连接,随机返回该instanceId下已连接的conn,若不存在则是使用该instanceId关联的db进行连接并返回。
|
||||
GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error)
|
||||
GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error)
|
||||
}
|
||||
|
||||
func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql, dbInstanceApp Instance, tagApp tagapp.TagTree) Db {
|
||||
@@ -142,8 +143,8 @@ func (d *dbAppImpl) Delete(ctx context.Context, id uint64) error {
|
||||
})
|
||||
}
|
||||
|
||||
func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
|
||||
return dbm.GetDbConn(dbId, dbName, func() (*dbm.DbInfo, error) {
|
||||
func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) {
|
||||
return dbm.GetDbConn(dbId, dbName, func() (*dbi.DbInfo, error) {
|
||||
db, err := d.GetById(new(entity.Db), dbId)
|
||||
if err != nil {
|
||||
return nil, errorx.NewBiz("数据库信息不存在")
|
||||
@@ -156,7 +157,7 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
|
||||
|
||||
checkDb := dbName
|
||||
// 兼容pgsql/dm db/schema模式
|
||||
if dbm.DbTypePostgres.Equal(instance.Type) || dbm.DbTypeDM.Equal(instance.Type) {
|
||||
if dbi.DbTypePostgres.Equal(instance.Type) || dbi.DbTypeDM.Equal(instance.Type) {
|
||||
ss := strings.Split(dbName, "/")
|
||||
if len(ss) > 1 {
|
||||
checkDb = ss[0]
|
||||
@@ -174,7 +175,7 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
|
||||
})
|
||||
}
|
||||
|
||||
func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error) {
|
||||
func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error) {
|
||||
conn := dbm.GetDbConnByInstanceId(instanceId)
|
||||
if conn != nil {
|
||||
return conn, nil
|
||||
@@ -193,8 +194,8 @@ func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error
|
||||
return d.GetDbConn(firstDb.Id, strings.Split(firstDb.Database, " ")[0])
|
||||
}
|
||||
|
||||
func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath ...string) *dbm.DbInfo {
|
||||
di := new(dbm.DbInfo)
|
||||
func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath ...string) *dbi.DbInfo {
|
||||
di := new(dbi.DbInfo)
|
||||
di.InstanceId = instance.Id
|
||||
di.Id = dbId
|
||||
di.Database = database
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
"mayfly-go/pkg/base"
|
||||
@@ -182,20 +182,20 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
|
||||
if err != nil {
|
||||
return syncLog, errorx.NewBiz("解析字段映射json出错: %s", err.Error())
|
||||
}
|
||||
var updFieldType dbm.DataType
|
||||
var updFieldType dbi.DataType
|
||||
|
||||
// 记录本次同步数据总数
|
||||
total := 0
|
||||
batchSize := task.PageSize
|
||||
result := make([]map[string]any, 0)
|
||||
var queryColumns []*dbm.QueryColumn
|
||||
var queryColumns []*dbi.QueryColumn
|
||||
|
||||
err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbm.QueryColumn) error {
|
||||
err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbi.QueryColumn) error {
|
||||
if len(queryColumns) == 0 {
|
||||
queryColumns = columns
|
||||
|
||||
// 遍历columns 取task.UpdField的字段类型
|
||||
updFieldType = dbm.DataTypeString
|
||||
updFieldType = dbi.DataTypeString
|
||||
for _, column := range columns {
|
||||
if column.Name == task.UpdField {
|
||||
updFieldType = srcDialect.GetDataType(column.Type)
|
||||
@@ -249,7 +249,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
|
||||
return syncLog, nil
|
||||
}
|
||||
|
||||
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbm.DataType, task *entity.DataSyncTask, srcDialect dbm.DbDialect, targetDbConn *dbm.DbConn, targetDbTx *sql.Tx) error {
|
||||
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbi.DataType, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error {
|
||||
var data = make([]map[string]any, 0)
|
||||
|
||||
// 遍历res,组装插入sql
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
"mayfly-go/pkg/logx"
|
||||
@@ -314,7 +314,7 @@ func (s *dbScheduler) runnable(job entity.DbJob, next runner.NextFunc) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProgram, job *entity.DbRestore) error {
|
||||
func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbi.DbProgram, job *entity.DbRestore) error {
|
||||
binlogHistory, err := s.binlogHistoryRepo.GetHistoryByTime(job.DbInstanceId, job.PointInTime.Time)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -341,7 +341,7 @@ func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProg
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
restoreInfo := &dbm.RestoreInfo{
|
||||
restoreInfo := &dbi.RestoreInfo{
|
||||
BackupHistory: backupHistory,
|
||||
BinlogHistories: binlogHistories,
|
||||
StartPosition: backupHistory.BinlogPosition,
|
||||
@@ -354,7 +354,7 @@ func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProg
|
||||
return program.ReplayBinlog(ctx, job.DbName, job.DbName, restoreInfo)
|
||||
}
|
||||
|
||||
func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbm.DbProgram, job *entity.DbRestore) error {
|
||||
func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbi.DbProgram, job *entity.DbRestore) error {
|
||||
backupHistory := &entity.DbBackupHistory{}
|
||||
if err := s.backupHistoryRepo.GetById(backupHistory, job.DbBackupHistoryId); err != nil {
|
||||
return err
|
||||
|
||||
@@ -4,7 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/config"
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
"mayfly-go/pkg/contextx"
|
||||
@@ -22,11 +22,11 @@ type DbSqlExecReq struct {
|
||||
Db string
|
||||
Sql string
|
||||
Remark string
|
||||
DbConn *dbm.DbConn
|
||||
DbConn *dbi.DbConn
|
||||
}
|
||||
|
||||
type DbSqlExecRes struct {
|
||||
Columns []*dbm.QueryColumn
|
||||
Columns []*dbi.QueryColumn
|
||||
Res []map[string]any
|
||||
}
|
||||
|
||||
@@ -269,7 +269,7 @@ func doInsert(ctx context.Context, insert *sqlparser.Insert, execSqlReq *DbSqlEx
|
||||
return doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
|
||||
}
|
||||
|
||||
func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
|
||||
func doExec(ctx context.Context, sql string, dbConn *dbi.DbConn) (*DbSqlExecRes, error) {
|
||||
rowsAffected, err := dbConn.ExecContext(ctx, sql)
|
||||
execRes := "success"
|
||||
if err != nil {
|
||||
@@ -283,7 +283,7 @@ func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes,
|
||||
res = append(res, resData)
|
||||
|
||||
return &DbSqlExecRes{
|
||||
Columns: []*dbm.QueryColumn{
|
||||
Columns: []*dbi.QueryColumn{
|
||||
{Name: "sql", Type: "string"},
|
||||
{Name: "rowsAffected", Type: "number"},
|
||||
{Name: "result", Type: "string"},
|
||||
|
||||
@@ -3,6 +3,7 @@ package application
|
||||
import (
|
||||
"context"
|
||||
"mayfly-go/internal/db/dbm"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/internal/db/domain/repository"
|
||||
"mayfly-go/pkg/base"
|
||||
@@ -50,7 +51,7 @@ func (app *instanceAppImpl) Count(condition *entity.InstanceQuery) int64 {
|
||||
|
||||
func (app *instanceAppImpl) TestConn(instanceEntity *entity.DbInstance) error {
|
||||
instanceEntity.Network = instanceEntity.GetNetwork()
|
||||
dbConn, err := toDbInfo(instanceEntity, 0, "", "").Conn()
|
||||
dbConn, err := dbm.Conn(toDbInfo(instanceEntity, 0, "", ""))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -100,9 +101,9 @@ func (app *instanceAppImpl) Delete(ctx context.Context, id uint64) error {
|
||||
|
||||
func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance) ([]string, error) {
|
||||
ed.Network = ed.GetNetwork()
|
||||
metaDb := dbm.ToDbType(ed.Type).MetaDbName()
|
||||
metaDb := dbi.ToDbType(ed.Type).MetaDbName()
|
||||
|
||||
dbConn, err := toDbInfo(ed, 0, metaDb, "").Conn()
|
||||
dbConn, err := dbm.Conn(toDbInfo(ed, 0, metaDb, ""))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package dbm
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -117,18 +117,9 @@ func (d *DbConn) Begin() (*sql.Tx, error) {
|
||||
return d.db.Begin()
|
||||
}
|
||||
|
||||
// 获取数据库元信息实现接口
|
||||
func (d *DbConn) GetDialect() DbDialect {
|
||||
switch d.Info.Type {
|
||||
case DbTypeMysql, DbTypeMariadb:
|
||||
return &MysqlDialect{dc: d}
|
||||
case DbTypePostgres:
|
||||
return &PgsqlDialect{dc: d}
|
||||
case DbTypeDM:
|
||||
return &DMDialect{dc: d}
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", d.Info.Type))
|
||||
}
|
||||
// 获取数据库dialect实现接口
|
||||
func (d *DbConn) GetDialect() Dialect {
|
||||
return d.Info.Meta.GetDialect(d)
|
||||
}
|
||||
|
||||
// 关闭连接
|
||||
@@ -1,4 +1,4 @@
|
||||
package dbm
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,4 +1,4 @@
|
||||
package dbm
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -1,4 +1,4 @@
|
||||
package dbm
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"testing"
|
||||
@@ -1,12 +1,13 @@
|
||||
package dbm
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"strings"
|
||||
|
||||
"mayfly-go/pkg/biz"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type DataType string
|
||||
@@ -60,7 +61,7 @@ type Index struct {
|
||||
|
||||
// -----------------------------------元数据接口定义------------------------------------------
|
||||
// 数据库方言、元信息接口(表、列、获取表数据等元信息)
|
||||
type DbDialect interface {
|
||||
type Dialect interface {
|
||||
// 获取数据库服务实例信息
|
||||
GetDbServer() (*DbServer, error)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package dbm
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
@@ -8,6 +8,9 @@ import (
|
||||
"mayfly-go/pkg/logx"
|
||||
)
|
||||
|
||||
// 获取sql.DB函数
|
||||
type GetSqlDbFunc func(*DbInfo) (*sql.DB, error)
|
||||
|
||||
type DbInfo struct {
|
||||
InstanceId uint64 // 实例id
|
||||
Id uint64 // dbId
|
||||
@@ -24,6 +27,8 @@ type DbInfo struct {
|
||||
|
||||
TagPath []string
|
||||
SshTunnelMachineId int
|
||||
|
||||
Meta Meta
|
||||
}
|
||||
|
||||
// 获取记录日志的描述
|
||||
@@ -32,22 +37,16 @@ func (d *DbInfo) GetLogDesc() string {
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
func (dbInfo *DbInfo) Conn() (*DbConn, error) {
|
||||
var conn *sql.DB
|
||||
var err error
|
||||
database := dbInfo.Database
|
||||
|
||||
switch dbInfo.Type {
|
||||
case DbTypeMysql, DbTypeMariadb:
|
||||
conn, err = getMysqlDB(dbInfo)
|
||||
case DbTypePostgres:
|
||||
conn, err = getPgsqlDB(dbInfo)
|
||||
case DbTypeDM:
|
||||
conn, err = getDmDB(dbInfo)
|
||||
default:
|
||||
return nil, errorx.NewBiz("invalid database type: %s", dbInfo.Type)
|
||||
func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) {
|
||||
if meta == nil {
|
||||
return nil, errorx.NewBiz("数据库元信息接口不能为空")
|
||||
}
|
||||
|
||||
// 赋值Meta,方便后续获取dialect等
|
||||
dbInfo.Meta = meta
|
||||
database := dbInfo.Database
|
||||
|
||||
conn, err := meta.GetSqlDb(dbInfo)
|
||||
if err != nil {
|
||||
logx.Errorf("连接db失败: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error())
|
||||
return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
|
||||
12
server/internal/db/dbm/dbi/meta.go
Normal file
12
server/internal/db/dbm/dbi/meta.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package dbi
|
||||
|
||||
import "database/sql"
|
||||
|
||||
// 数据库元信息获取,如获取sql.DB、Dialect等
|
||||
type Meta interface {
|
||||
// 获取数据库服务实例信息
|
||||
GetSqlDb(*DbInfo) (*sql.DB, error)
|
||||
|
||||
// 获取数据库方言,用于获取表结构等信息
|
||||
GetDialect(*DbConn) Dialect
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package dbm
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
@@ -3,6 +3,10 @@ package dbm
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/common/consts"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/dbm/dm"
|
||||
"mayfly-go/internal/db/dbm/mysql"
|
||||
"mayfly-go/internal/db/dbm/postgres"
|
||||
"mayfly-go/internal/machine/mcm"
|
||||
"mayfly-go/pkg/cache"
|
||||
"mayfly-go/pkg/logx"
|
||||
@@ -15,7 +19,7 @@ var connCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second).
|
||||
WithUpdateAccessTime(true).
|
||||
OnEvicted(func(key any, value any) {
|
||||
logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key))
|
||||
value.(*DbConn).Close()
|
||||
value.(*dbi.DbConn).Close()
|
||||
})
|
||||
|
||||
func init() {
|
||||
@@ -23,7 +27,7 @@ func init() {
|
||||
// 遍历所有db连接实例,若存在db实例使用该ssh隧道机器,则返回true,表示还在使用中...
|
||||
items := connCache.Items()
|
||||
for _, v := range items {
|
||||
if v.Value.(*DbConn).Info.SshTunnelMachineId == machineId {
|
||||
if v.Value.(*dbi.DbConn).Info.SshTunnelMachineId == machineId {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@@ -33,8 +37,21 @@ func init() {
|
||||
|
||||
var mutex sync.Mutex
|
||||
|
||||
func getDbMetaByType(dt dbi.DbType) dbi.Meta {
|
||||
switch dt {
|
||||
case dbi.DbTypeMysql, dbi.DbTypeMariadb:
|
||||
return mysql.GetMeta()
|
||||
case dbi.DbTypePostgres:
|
||||
return postgres.GetMeta()
|
||||
case dbi.DbTypeDM:
|
||||
return dm.GetMeta()
|
||||
default:
|
||||
panic(fmt.Sprintf("invalid database type: %s", dt))
|
||||
}
|
||||
}
|
||||
|
||||
// 从缓存中获取数据库连接信息,若缓存中不存在则会使用回调函数获取dbInfo进行连接并缓存
|
||||
func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error)) (*DbConn, error) {
|
||||
func GetDbConn(dbId uint64, database string, getDbInfo func() (*dbi.DbInfo, error)) (*dbi.DbConn, error) {
|
||||
connId := GetDbConnId(dbId, database)
|
||||
|
||||
// connId不为空,则为需要缓存
|
||||
@@ -42,7 +59,7 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error))
|
||||
if needCache {
|
||||
load, ok := connCache.Get(connId)
|
||||
if ok {
|
||||
return load.(*DbConn), nil
|
||||
return load.(*dbi.DbConn), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,7 +73,7 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error))
|
||||
}
|
||||
|
||||
// 连接数据库
|
||||
dbConn, err := dbInfo.Conn()
|
||||
dbConn, err := Conn(dbInfo)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -67,10 +84,15 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error))
|
||||
return dbConn, nil
|
||||
}
|
||||
|
||||
// 使用指定dbInfo信息进行连接
|
||||
func Conn(di *dbi.DbInfo) (*dbi.DbConn, error) {
|
||||
return di.Conn(getDbMetaByType(di.Type))
|
||||
}
|
||||
|
||||
// 根据实例id获取连接
|
||||
func GetDbConnByInstanceId(instanceId uint64) *DbConn {
|
||||
func GetDbConnByInstanceId(instanceId uint64) *dbi.DbConn {
|
||||
for _, connItem := range connCache.Items() {
|
||||
conn := connItem.Value.(*DbConn)
|
||||
conn := connItem.Value.(*dbi.DbConn)
|
||||
if conn.Info.InstanceId == instanceId {
|
||||
return conn
|
||||
}
|
||||
@@ -82,3 +104,12 @@ func GetDbConnByInstanceId(instanceId uint64) *DbConn {
|
||||
func CloseDb(dbId uint64, db string) {
|
||||
connCache.Delete(GetDbConnId(dbId, db))
|
||||
}
|
||||
|
||||
// 获取连接id
|
||||
func GetDbConnId(dbId uint64, db string) string {
|
||||
if dbId == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%d:%s", dbId, db)
|
||||
}
|
||||
@@ -1,9 +1,10 @@
|
||||
package dbm
|
||||
package dm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
@@ -14,30 +15,6 @@ import (
|
||||
_ "gitee.com/chunanyong/dm"
|
||||
)
|
||||
|
||||
func getDmDB(d *DbInfo) (*sql.DB, error) {
|
||||
driverName := "dm"
|
||||
db := d.Database
|
||||
var dbParam string
|
||||
if db != "" {
|
||||
// dm database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
|
||||
ss := strings.Split(db, "/")
|
||||
if len(ss) > 1 {
|
||||
dbParam = fmt.Sprintf("%s?schema=%s", ss[0], ss[len(ss)-1])
|
||||
} else {
|
||||
dbParam = db
|
||||
}
|
||||
}
|
||||
|
||||
err := d.IfUseSshTunnelChangeIpPort()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("dm://%s:%s@%s:%d/%s", d.Username, d.Password, d.Host, d.Port, dbParam)
|
||||
return sql.Open(driverName, dsn)
|
||||
}
|
||||
|
||||
// ---------------------------------- DM元数据 -----------------------------------
|
||||
const (
|
||||
DM_META_FILE = "metasql/dm_meta.sql"
|
||||
DM_DB_SCHEMAS = "DM_DB_SCHEMAS"
|
||||
@@ -47,15 +24,15 @@ const (
|
||||
)
|
||||
|
||||
type DMDialect struct {
|
||||
dc *DbConn
|
||||
dc *dbi.DbConn
|
||||
}
|
||||
|
||||
func (dd *DMDialect) GetDbServer() (*DbServer, error) {
|
||||
func (dd *DMDialect) GetDbServer() (*dbi.DbServer, error) {
|
||||
_, res, err := dd.dc.Query("select * from v$instance")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ds := &DbServer{
|
||||
ds := &dbi.DbServer{
|
||||
Version: anyx.ConvString(res[0]["SVR_VERSION"]),
|
||||
}
|
||||
return ds, nil
|
||||
@@ -76,20 +53,20 @@ func (dd *DMDialect) GetDbNames() ([]string, error) {
|
||||
}
|
||||
|
||||
// 获取表基础元信息, 如表名等
|
||||
func (dd *DMDialect) GetTables() ([]Table, error) {
|
||||
func (dd *DMDialect) GetTables() ([]dbi.Table, error) {
|
||||
|
||||
// 首先执行更新统计信息sql 这个统计信息在数据量比较大的时候就比较耗时,所以最好定时执行
|
||||
// _, _, err := pd.dc.Query("dbms_stats.GATHER_SCHEMA_stats(SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))")
|
||||
|
||||
// 查询表信息
|
||||
_, res, err := dd.dc.Query(GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY))
|
||||
_, res, err := dd.dc.Query(dbi.GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tables := make([]Table, 0)
|
||||
tables := make([]dbi.Table, 0)
|
||||
for _, re := range res {
|
||||
tables = append(tables, Table{
|
||||
tables = append(tables, dbi.Table{
|
||||
TableName: re["TABLE_NAME"].(string),
|
||||
TableComment: anyx.ConvString(re["TABLE_COMMENT"]),
|
||||
CreateTime: anyx.ConvString(re["CREATE_TIME"]),
|
||||
@@ -102,7 +79,7 @@ func (dd *DMDialect) GetTables() ([]Table, error) {
|
||||
}
|
||||
|
||||
// 获取列元信息, 如列名等
|
||||
func (dd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
|
||||
func (dd *DMDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
tableName := ""
|
||||
for i := 0; i < len(tableNames); i++ {
|
||||
if i != 0 {
|
||||
@@ -111,14 +88,14 @@ func (dd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
|
||||
tableName = tableName + "'" + tableNames[i] + "'"
|
||||
}
|
||||
|
||||
_, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
|
||||
_, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]Column, 0)
|
||||
columns := make([]dbi.Column, 0)
|
||||
for _, re := range res {
|
||||
columns = append(columns, Column{
|
||||
columns = append(columns, dbi.Column{
|
||||
TableName: re["TABLE_NAME"].(string),
|
||||
ColumnName: re["COLUMN_NAME"].(string),
|
||||
ColumnType: anyx.ConvString(re["COLUMN_TYPE"]),
|
||||
@@ -150,15 +127,15 @@ func (dd *DMDialect) GetPrimaryKey(tablename string) (string, error) {
|
||||
}
|
||||
|
||||
// 获取表索引信息
|
||||
func (dd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
|
||||
_, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName))
|
||||
func (dd *DMDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
_, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexs := make([]Index, 0)
|
||||
indexs := make([]dbi.Index, 0)
|
||||
for _, re := range res {
|
||||
indexs = append(indexs, Index{
|
||||
indexs = append(indexs, dbi.Index{
|
||||
IndexName: re["INDEX_NAME"].(string),
|
||||
ColumnName: anyx.ConvString(re["COLUMN_NAME"]),
|
||||
IndexType: anyx.ConvString(re["INDEX_TYPE"]),
|
||||
@@ -168,7 +145,7 @@ func (dd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
|
||||
})
|
||||
}
|
||||
// 把查询结果以索引名分组,索引字段以逗号连接
|
||||
result := make([]Index, 0)
|
||||
result := make([]dbi.Index, 0)
|
||||
key := ""
|
||||
for _, v := range indexs {
|
||||
// 当前的索引名
|
||||
@@ -255,13 +232,13 @@ func (dd *DMDialect) GetTableDDL(tableName string) (string, error) {
|
||||
return builder.String(), nil
|
||||
}
|
||||
|
||||
func (dd *DMDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error {
|
||||
func (dd *DMDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error {
|
||||
return dd.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
|
||||
}
|
||||
|
||||
// 获取DM当前连接的库可访问的schemaNames
|
||||
func (dd *DMDialect) GetSchemas() ([]string, error) {
|
||||
sql := GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS)
|
||||
sql := dbi.GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS)
|
||||
_, res, err := dd.dc.Query(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -274,27 +251,27 @@ func (dd *DMDialect) GetSchemas() ([]string, error) {
|
||||
}
|
||||
|
||||
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
|
||||
func (dd *DMDialect) GetDbProgram() DbProgram {
|
||||
func (dd *DMDialect) GetDbProgram() dbi.DbProgram {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (dd *DMDialect) GetDataType(dbColumnType string) DataType {
|
||||
func (dd *DMDialect) GetDataType(dbColumnType string) dbi.DataType {
|
||||
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
|
||||
return DataTypeNumber
|
||||
return dbi.DataTypeNumber
|
||||
}
|
||||
// 日期时间类型
|
||||
if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
|
||||
return DataTypeDateTime
|
||||
return dbi.DataTypeDateTime
|
||||
}
|
||||
// 日期类型
|
||||
if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
|
||||
return DataTypeDate
|
||||
return dbi.DataTypeDate
|
||||
}
|
||||
// 时间类型
|
||||
if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
|
||||
return DataTypeTime
|
||||
return dbi.DataTypeTime
|
||||
}
|
||||
return DataTypeString
|
||||
return dbi.DataTypeString
|
||||
}
|
||||
|
||||
func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
|
||||
@@ -322,15 +299,15 @@ func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string,
|
||||
return int64(effRows), nil
|
||||
}
|
||||
|
||||
func (dd *DMDialect) FormatStrData(dbColumnValue string, dataType DataType) string {
|
||||
func (dd *DMDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string {
|
||||
switch dataType {
|
||||
case DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
|
||||
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
|
||||
res, _ := time.Parse(time.RFC3339, dbColumnValue)
|
||||
return res.Format(time.DateTime)
|
||||
case DataTypeDate: // "2024-01-02T00:00:00+08:00"
|
||||
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
|
||||
res, _ := time.Parse(time.RFC3339, dbColumnValue)
|
||||
return res.Format(time.DateOnly)
|
||||
case DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
|
||||
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
|
||||
res, _ := time.Parse(time.RFC3339, dbColumnValue)
|
||||
return res.Format(time.TimeOnly)
|
||||
}
|
||||
51
server/internal/db/dbm/dm/meta.go
Normal file
51
server/internal/db/dbm/dm/meta.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package dm
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
meta dbi.Meta
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func GetMeta() dbi.Meta {
|
||||
once.Do(func() {
|
||||
meta = new(DmMeta)
|
||||
})
|
||||
return meta
|
||||
}
|
||||
|
||||
type DmMeta struct {
|
||||
}
|
||||
|
||||
func (md *DmMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
|
||||
driverName := "dm"
|
||||
db := d.Database
|
||||
var dbParam string
|
||||
if db != "" {
|
||||
// dm database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
|
||||
ss := strings.Split(db, "/")
|
||||
if len(ss) > 1 {
|
||||
dbParam = fmt.Sprintf("%s?schema=%s", ss[0], ss[len(ss)-1])
|
||||
} else {
|
||||
dbParam = db
|
||||
}
|
||||
}
|
||||
|
||||
err := d.IfUseSshTunnelChangeIpPort()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("dm://%s:%s@%s:%d/%s", d.Username, d.Password, d.Host, d.Port, dbParam)
|
||||
return sql.Open(driverName, dsn)
|
||||
}
|
||||
|
||||
func (md *DmMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
|
||||
return &DMDialect{conn}
|
||||
}
|
||||
@@ -1,40 +1,16 @@
|
||||
package dbm
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
func getMysqlDB(d *DbInfo) (*sql.DB, error) {
|
||||
// SSH Conect
|
||||
if d.SshTunnelMachineId > 0 {
|
||||
sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return sshTunnelMachine.GetDialConn("tcp", addr)
|
||||
})
|
||||
}
|
||||
// 设置dataSourceName -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
|
||||
if d.Params != "" {
|
||||
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
|
||||
}
|
||||
const driverName = "mysql"
|
||||
return sql.Open(driverName, dsn)
|
||||
}
|
||||
|
||||
// ---------------------------------- mysql元数据 -----------------------------------
|
||||
const (
|
||||
MYSQL_META_FILE = "metasql/mysql_meta.sql"
|
||||
MYSQL_DBS = "MYSQL_DBS"
|
||||
@@ -44,22 +20,22 @@ const (
|
||||
)
|
||||
|
||||
type MysqlDialect struct {
|
||||
dc *DbConn
|
||||
dc *dbi.DbConn
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) GetDbServer() (*DbServer, error) {
|
||||
func (md *MysqlDialect) GetDbServer() (*dbi.DbServer, error) {
|
||||
_, res, err := md.dc.Query("SELECT VERSION() version")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ds := &DbServer{
|
||||
ds := &dbi.DbServer{
|
||||
Version: anyx.ConvString(res[0]["version"]),
|
||||
}
|
||||
return ds, nil
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) GetDbNames() ([]string, error) {
|
||||
_, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_DBS))
|
||||
_, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_DBS))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -73,15 +49,15 @@ func (md *MysqlDialect) GetDbNames() ([]string, error) {
|
||||
}
|
||||
|
||||
// 获取表基础元信息, 如表名等
|
||||
func (md *MysqlDialect) GetTables() ([]Table, error) {
|
||||
_, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
|
||||
func (md *MysqlDialect) GetTables() ([]dbi.Table, error) {
|
||||
_, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tables := make([]Table, 0)
|
||||
tables := make([]dbi.Table, 0)
|
||||
for _, re := range res {
|
||||
tables = append(tables, Table{
|
||||
tables = append(tables, dbi.Table{
|
||||
TableName: re["tableName"].(string),
|
||||
TableComment: anyx.ConvString(re["tableComment"]),
|
||||
CreateTime: anyx.ConvString(re["createTime"]),
|
||||
@@ -94,7 +70,7 @@ func (md *MysqlDialect) GetTables() ([]Table, error) {
|
||||
}
|
||||
|
||||
// 获取列元信息, 如列名等
|
||||
func (md *MysqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
|
||||
func (md *MysqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
tableName := ""
|
||||
for i := 0; i < len(tableNames); i++ {
|
||||
if i != 0 {
|
||||
@@ -103,14 +79,14 @@ func (md *MysqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
|
||||
tableName = tableName + "'" + tableNames[i] + "'"
|
||||
}
|
||||
|
||||
_, res, err := md.dc.Query(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
|
||||
_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]Column, 0)
|
||||
columns := make([]dbi.Column, 0)
|
||||
for _, re := range res {
|
||||
columns = append(columns, Column{
|
||||
columns = append(columns, dbi.Column{
|
||||
TableName: re["tableName"].(string),
|
||||
ColumnName: re["columnName"].(string),
|
||||
ColumnType: anyx.ConvString(re["columnType"]),
|
||||
@@ -144,15 +120,15 @@ func (md *MysqlDialect) GetPrimaryKey(tablename string) (string, error) {
|
||||
}
|
||||
|
||||
// 获取表索引信息
|
||||
func (md *MysqlDialect) GetTableIndex(tableName string) ([]Index, error) {
|
||||
_, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName)
|
||||
func (md *MysqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
_, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexs := make([]Index, 0)
|
||||
indexs := make([]dbi.Index, 0)
|
||||
for _, re := range res {
|
||||
indexs = append(indexs, Index{
|
||||
indexs = append(indexs, dbi.Index{
|
||||
IndexName: re["indexName"].(string),
|
||||
ColumnName: anyx.ConvString(re["columnName"]),
|
||||
IndexType: anyx.ConvString(re["indexType"]),
|
||||
@@ -162,7 +138,7 @@ func (md *MysqlDialect) GetTableIndex(tableName string) ([]Index, error) {
|
||||
})
|
||||
}
|
||||
// 把查询结果以索引名分组,索引字段以逗号连接
|
||||
result := make([]Index, 0)
|
||||
result := make([]dbi.Index, 0)
|
||||
key := ""
|
||||
for _, v := range indexs {
|
||||
// 当前的索引名
|
||||
@@ -189,7 +165,7 @@ func (md *MysqlDialect) GetTableDDL(tableName string) (string, error) {
|
||||
return res[0]["Create Table"].(string) + ";", nil
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error {
|
||||
func (md *MysqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error {
|
||||
return md.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
|
||||
}
|
||||
|
||||
@@ -198,27 +174,27 @@ func (md *MysqlDialect) GetSchemas() ([]string, error) {
|
||||
}
|
||||
|
||||
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
|
||||
func (md *MysqlDialect) GetDbProgram() DbProgram {
|
||||
func (md *MysqlDialect) GetDbProgram() dbi.DbProgram {
|
||||
return NewDbProgramMysql(md.dc)
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) GetDataType(dbColumnType string) DataType {
|
||||
func (md *MysqlDialect) GetDataType(dbColumnType string) dbi.DataType {
|
||||
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
|
||||
return DataTypeNumber
|
||||
return dbi.DataTypeNumber
|
||||
}
|
||||
// 日期时间类型
|
||||
if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
|
||||
return DataTypeDateTime
|
||||
return dbi.DataTypeDateTime
|
||||
}
|
||||
// 日期类型
|
||||
if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
|
||||
return DataTypeDate
|
||||
return dbi.DataTypeDate
|
||||
}
|
||||
// 时间类型
|
||||
if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
|
||||
return DataTypeTime
|
||||
return dbi.DataTypeTime
|
||||
}
|
||||
return DataTypeString
|
||||
return dbi.DataTypeString
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
|
||||
@@ -246,7 +222,7 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
|
||||
return md.dc.TxExec(tx, sqlStr, args...)
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) FormatStrData(dbColumnValue string, dataType DataType) string {
|
||||
func (md *MysqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string {
|
||||
// mysql不需要格式化时间日期等
|
||||
return dbColumnValue
|
||||
}
|
||||
52
server/internal/db/dbm/mysql/meta.go
Normal file
52
server/internal/db/dbm/mysql/meta.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
var (
|
||||
meta dbi.Meta
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func GetMeta() dbi.Meta {
|
||||
once.Do(func() {
|
||||
meta = new(MysqlMeta)
|
||||
})
|
||||
return meta
|
||||
}
|
||||
|
||||
type MysqlMeta struct {
|
||||
}
|
||||
|
||||
func (md *MysqlMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
|
||||
// SSH Conect
|
||||
if d.SshTunnelMachineId > 0 {
|
||||
sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
return sshTunnelMachine.GetDialConn("tcp", addr)
|
||||
})
|
||||
}
|
||||
// 设置dataSourceName -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
|
||||
if d.Params != "" {
|
||||
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
|
||||
}
|
||||
const driverName = "mysql"
|
||||
return sql.Open(driverName, dsn)
|
||||
}
|
||||
|
||||
func (md *MysqlMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
|
||||
return &MysqlDialect{conn}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package dbm
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
@@ -19,27 +19,28 @@ import (
|
||||
"golang.org/x/sync/singleflight"
|
||||
|
||||
"mayfly-go/internal/db/config"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"mayfly-go/pkg/logx"
|
||||
)
|
||||
|
||||
var _ DbProgram = (*DbProgramMysql)(nil)
|
||||
var _ dbi.DbProgram = (*DbProgramMysql)(nil)
|
||||
|
||||
type DbProgramMysql struct {
|
||||
dbConn *DbConn
|
||||
dbConn *dbi.DbConn
|
||||
// mysqlBin 用于集成测试
|
||||
mysqlBin *config.MysqlBin
|
||||
// backupPath 用于集成测试
|
||||
backupPath string
|
||||
}
|
||||
|
||||
func NewDbProgramMysql(dbConn *DbConn) *DbProgramMysql {
|
||||
func NewDbProgramMysql(dbConn *dbi.DbConn) *DbProgramMysql {
|
||||
return &DbProgramMysql{
|
||||
dbConn: dbConn,
|
||||
}
|
||||
}
|
||||
|
||||
func (svc *DbProgramMysql) dbInfo() *DbInfo {
|
||||
func (svc *DbProgramMysql) dbInfo() *dbi.DbInfo {
|
||||
dbInfo := svc.dbConn.Info
|
||||
err := dbInfo.IfUseSshTunnelChangeIpPort()
|
||||
if err != nil {
|
||||
@@ -55,9 +56,9 @@ func (svc *DbProgramMysql) getMysqlBin() *config.MysqlBin {
|
||||
dbInfo := svc.dbInfo()
|
||||
var mysqlBin *config.MysqlBin
|
||||
switch dbInfo.Type {
|
||||
case DbTypeMariadb:
|
||||
case dbi.DbTypeMariadb:
|
||||
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMariadbBin)
|
||||
case DbTypeMysql:
|
||||
case dbi.DbTypeMysql:
|
||||
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMysqlBin)
|
||||
default:
|
||||
panic(fmt.Sprintf("不兼容 MySQL 的数据库类型: %v", dbInfo.Type))
|
||||
@@ -488,7 +489,7 @@ func (svc *DbProgramMysql) GetBinlogEventPositionAtOrAfterTime(ctx context.Conte
|
||||
}
|
||||
|
||||
// ReplayBinlog replays the binlog for `originDatabase` from `startBinlogInfo.Position` to `targetTs`, read binlog from `binlogDir`.
|
||||
func (svc *DbProgramMysql) ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *RestoreInfo) (replayErr error) {
|
||||
func (svc *DbProgramMysql) ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *dbi.RestoreInfo) (replayErr error) {
|
||||
const (
|
||||
// Variable lower_case_table_names related.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
//go:build e2e
|
||||
|
||||
package dbm
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -1,10 +1,11 @@
|
||||
package dbm
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"mayfly-go/internal/db/domain/entity"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_readBinlogInfoFromBackup(t *testing.T) {
|
||||
@@ -1,99 +1,17 @@
|
||||
package dbm
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/netx"
|
||||
"net"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
pq "gitee.com/liuzongyang/libpq"
|
||||
)
|
||||
|
||||
func getPgsqlDB(d *DbInfo) (*sql.DB, error) {
|
||||
driverName := string(d.Type)
|
||||
// SSH Conect
|
||||
if d.SshTunnelMachineId > 0 {
|
||||
// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
|
||||
driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
|
||||
if !collx.ArrayContains(sql.Drivers(), driverName) {
|
||||
sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId})
|
||||
}
|
||||
sql.Drivers()
|
||||
}
|
||||
|
||||
db := d.Database
|
||||
var dbParam string
|
||||
exsitSchema := false
|
||||
if db != "" {
|
||||
// postgres database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
|
||||
ss := strings.Split(db, "/")
|
||||
if len(ss) > 1 {
|
||||
exsitSchema = true
|
||||
dbParam = fmt.Sprintf("dbname=%s search_path=%s", ss[0], ss[len(ss)-1])
|
||||
} else {
|
||||
dbParam = "dbname=" + db
|
||||
}
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s %s sslmode=disable connect_timeout=8", d.Host, d.Port, d.Username, d.Password, dbParam)
|
||||
// 存在额外指定参数,则拼接该连接参数
|
||||
if d.Params != "" {
|
||||
// 存在指定的db,则需要将dbInstance配置中的parmas排除掉dbname和search_path
|
||||
if db != "" {
|
||||
paramArr := strings.Split(d.Params, "&")
|
||||
paramArr = collx.ArrayRemoveFunc(paramArr, func(param string) bool {
|
||||
if strings.HasPrefix(param, "dbname=") {
|
||||
return true
|
||||
}
|
||||
if exsitSchema && strings.HasPrefix(param, "search_path") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
d.Params = strings.Join(paramArr, " ")
|
||||
}
|
||||
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
|
||||
}
|
||||
|
||||
return sql.Open(driverName, dsn)
|
||||
}
|
||||
|
||||
// pgsql dialer
|
||||
type PqSqlDialer struct {
|
||||
sshTunnelMachineId int
|
||||
}
|
||||
|
||||
func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
|
||||
return pq.DialOpen(d, name)
|
||||
}
|
||||
|
||||
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
|
||||
sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sshConn, err := sshTunnel.GetDialConn("tcp", address); err == nil {
|
||||
// 将ssh conn包装,否则会返回错误: ssh: tcpChan: deadline not supported
|
||||
return &netx.WrapSshConn{Conn: sshConn}, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
||||
return pd.Dial(network, address)
|
||||
}
|
||||
|
||||
// ---------------------------------- pgsql元数据 -----------------------------------
|
||||
const (
|
||||
PGSQL_META_FILE = "metasql/pgsql_meta.sql"
|
||||
PGSQL_DB_SCHEMAS = "PGSQL_DB_SCHEMAS"
|
||||
@@ -104,15 +22,15 @@ const (
|
||||
)
|
||||
|
||||
type PgsqlDialect struct {
|
||||
dc *DbConn
|
||||
dc *dbi.DbConn
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) GetDbServer() (*DbServer, error) {
|
||||
func (pd *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) {
|
||||
_, res, err := pd.dc.Query("SHOW server_version")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ds := &DbServer{
|
||||
ds := &dbi.DbServer{
|
||||
Version: anyx.ConvString(res[0]["server_version"]),
|
||||
}
|
||||
return ds, nil
|
||||
@@ -133,15 +51,15 @@ func (pd *PgsqlDialect) GetDbNames() ([]string, error) {
|
||||
}
|
||||
|
||||
// 获取表基础元信息, 如表名等
|
||||
func (pd *PgsqlDialect) GetTables() ([]Table, error) {
|
||||
_, res, err := pd.dc.Query(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
|
||||
func (pd *PgsqlDialect) GetTables() ([]dbi.Table, error) {
|
||||
_, res, err := pd.dc.Query(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tables := make([]Table, 0)
|
||||
tables := make([]dbi.Table, 0)
|
||||
for _, re := range res {
|
||||
tables = append(tables, Table{
|
||||
tables = append(tables, dbi.Table{
|
||||
TableName: re["tableName"].(string),
|
||||
TableComment: anyx.ConvString(re["tableComment"]),
|
||||
CreateTime: anyx.ConvString(re["createTime"]),
|
||||
@@ -154,7 +72,7 @@ func (pd *PgsqlDialect) GetTables() ([]Table, error) {
|
||||
}
|
||||
|
||||
// 获取列元信息, 如列名等
|
||||
func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
|
||||
func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
tableName := ""
|
||||
for i := 0; i < len(tableNames); i++ {
|
||||
if i != 0 {
|
||||
@@ -163,14 +81,14 @@ func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
|
||||
tableName = tableName + "'" + tableNames[i] + "'"
|
||||
}
|
||||
|
||||
_, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
|
||||
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := make([]Column, 0)
|
||||
columns := make([]dbi.Column, 0)
|
||||
for _, re := range res {
|
||||
columns = append(columns, Column{
|
||||
columns = append(columns, dbi.Column{
|
||||
TableName: re["tableName"].(string),
|
||||
ColumnName: re["columnName"].(string),
|
||||
ColumnType: anyx.ConvString(re["columnType"]),
|
||||
@@ -202,15 +120,15 @@ func (pd *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) {
|
||||
}
|
||||
|
||||
// 获取表索引信息
|
||||
func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) {
|
||||
_, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
|
||||
func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
indexs := make([]Index, 0)
|
||||
indexs := make([]dbi.Index, 0)
|
||||
for _, re := range res {
|
||||
indexs = append(indexs, Index{
|
||||
indexs = append(indexs, dbi.Index{
|
||||
IndexName: re["indexName"].(string),
|
||||
ColumnName: anyx.ConvString(re["columnName"]),
|
||||
IndexType: anyx.ConvString(re["IndexType"]),
|
||||
@@ -220,7 +138,7 @@ func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) {
|
||||
})
|
||||
}
|
||||
// 把查询结果以索引名分组,索引字段以逗号连接
|
||||
result := make([]Index, 0)
|
||||
result := make([]dbi.Index, 0)
|
||||
key := ""
|
||||
for _, v := range indexs {
|
||||
// 当前的索引名
|
||||
@@ -240,7 +158,7 @@ func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) {
|
||||
|
||||
// 获取建表ddl
|
||||
func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
|
||||
_, err := pd.dc.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
|
||||
_, err := pd.dc.Exec(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -257,13 +175,13 @@ func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
|
||||
return res[0]["sql"].(string), nil
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error {
|
||||
func (pd *PgsqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error {
|
||||
return pd.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
|
||||
}
|
||||
|
||||
// 获取pgsql当前连接的库可访问的schemaNames
|
||||
func (pd *PgsqlDialect) GetSchemas() ([]string, error) {
|
||||
sql := GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS)
|
||||
sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS)
|
||||
_, res, err := pd.dc.Query(sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -276,27 +194,27 @@ func (pd *PgsqlDialect) GetSchemas() ([]string, error) {
|
||||
}
|
||||
|
||||
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
|
||||
func (pd *PgsqlDialect) GetDbProgram() DbProgram {
|
||||
func (pd *PgsqlDialect) GetDbProgram() dbi.DbProgram {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) GetDataType(dbColumnType string) DataType {
|
||||
func (pd *PgsqlDialect) GetDataType(dbColumnType string) dbi.DataType {
|
||||
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
|
||||
return DataTypeNumber
|
||||
return dbi.DataTypeNumber
|
||||
}
|
||||
// 日期时间类型
|
||||
if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
|
||||
return DataTypeDateTime
|
||||
return dbi.DataTypeDateTime
|
||||
}
|
||||
// 日期类型
|
||||
if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
|
||||
return DataTypeDate
|
||||
return dbi.DataTypeDate
|
||||
}
|
||||
// 时间类型
|
||||
if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
|
||||
return DataTypeTime
|
||||
return dbi.DataTypeTime
|
||||
}
|
||||
return DataTypeString
|
||||
return dbi.DataTypeString
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
|
||||
@@ -325,15 +243,15 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
|
||||
return pd.dc.TxExec(tx, sqlStr, args...)
|
||||
}
|
||||
|
||||
func (pd *PgsqlDialect) FormatStrData(dbColumnValue string, dataType DataType) string {
|
||||
func (pd *PgsqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string {
|
||||
switch dataType {
|
||||
case DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00"
|
||||
case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00"
|
||||
res, _ := time.Parse(time.RFC3339, dbColumnValue)
|
||||
return res.Format(time.DateTime)
|
||||
case DataTypeDate: // "2024-01-02T00:00:00Z"
|
||||
case dbi.DataTypeDate: // "2024-01-02T00:00:00Z"
|
||||
res, _ := time.Parse(time.RFC3339, dbColumnValue)
|
||||
return res.Format(time.DateOnly)
|
||||
case DataTypeTime: // "0000-01-01T22:16:28.545075+08:00"
|
||||
case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00"
|
||||
res, _ := time.Parse(time.RFC3339, dbColumnValue)
|
||||
return res.Format(time.TimeOnly)
|
||||
}
|
||||
111
server/internal/db/dbm/postgres/meta.go
Normal file
111
server/internal/db/dbm/postgres/meta.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
machineapp "mayfly-go/internal/machine/application"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/netx"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
pq "gitee.com/liuzongyang/libpq"
|
||||
)
|
||||
|
||||
var (
|
||||
meta dbi.Meta
|
||||
once sync.Once
|
||||
)
|
||||
|
||||
func GetMeta() dbi.Meta {
|
||||
once.Do(func() {
|
||||
meta = new(PostgresMeta)
|
||||
})
|
||||
return meta
|
||||
}
|
||||
|
||||
type PostgresMeta struct {
|
||||
}
|
||||
|
||||
func (md *PostgresMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
|
||||
driverName := string(d.Type)
|
||||
// SSH Conect
|
||||
if d.SshTunnelMachineId > 0 {
|
||||
// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
|
||||
driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
|
||||
if !collx.ArrayContains(sql.Drivers(), driverName) {
|
||||
sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId})
|
||||
}
|
||||
sql.Drivers()
|
||||
}
|
||||
|
||||
db := d.Database
|
||||
var dbParam string
|
||||
exsitSchema := false
|
||||
if db != "" {
|
||||
// postgres database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
|
||||
ss := strings.Split(db, "/")
|
||||
if len(ss) > 1 {
|
||||
exsitSchema = true
|
||||
dbParam = fmt.Sprintf("dbname=%s search_path=%s", ss[0], ss[len(ss)-1])
|
||||
} else {
|
||||
dbParam = "dbname=" + db
|
||||
}
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s %s sslmode=disable connect_timeout=8", d.Host, d.Port, d.Username, d.Password, dbParam)
|
||||
// 存在额外指定参数,则拼接该连接参数
|
||||
if d.Params != "" {
|
||||
// 存在指定的db,则需要将dbInstance配置中的parmas排除掉dbname和search_path
|
||||
if db != "" {
|
||||
paramArr := strings.Split(d.Params, "&")
|
||||
paramArr = collx.ArrayRemoveFunc(paramArr, func(param string) bool {
|
||||
if strings.HasPrefix(param, "dbname=") {
|
||||
return true
|
||||
}
|
||||
if exsitSchema && strings.HasPrefix(param, "search_path") {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
})
|
||||
d.Params = strings.Join(paramArr, " ")
|
||||
}
|
||||
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
|
||||
}
|
||||
|
||||
return sql.Open(driverName, dsn)
|
||||
}
|
||||
|
||||
func (md *PostgresMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
|
||||
return &PgsqlDialect{conn}
|
||||
}
|
||||
|
||||
// pgsql dialer
|
||||
type PqSqlDialer struct {
|
||||
sshTunnelMachineId int
|
||||
}
|
||||
|
||||
func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
|
||||
return pq.DialOpen(d, name)
|
||||
}
|
||||
|
||||
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
|
||||
sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if sshConn, err := sshTunnel.GetDialConn("tcp", address); err == nil {
|
||||
// 将ssh conn包装,否则会返回错误: ssh: tcpChan: deadline not supported
|
||||
return &netx.WrapSshConn{Conn: sshConn}, nil
|
||||
} else {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
|
||||
return pd.Dial(network, address)
|
||||
}
|
||||
@@ -90,7 +90,6 @@ func (m *MachineFile) ReadFileContent(rc *req.Ctx) {
|
||||
g := rc.GinCtx
|
||||
fid := GetMachineFileId(g)
|
||||
readPath := g.Query("path")
|
||||
readType := g.Query("type")
|
||||
|
||||
sftpFile, mi, err := m.MachineFileApp.ReadFile(fid, readPath)
|
||||
rc.ReqParam = collx.Kvs("machine", mi, "path", readPath)
|
||||
@@ -99,20 +98,27 @@ func (m *MachineFile) ReadFileContent(rc *req.Ctx) {
|
||||
|
||||
fileInfo, _ := sftpFile.Stat()
|
||||
filesize := fileInfo.Size()
|
||||
// 如果是读取文件内容,则校验文件大小
|
||||
if readType != "1" {
|
||||
biz.IsTrue(filesize < max_read_size, "文件超过1m,请使用下载查看")
|
||||
}
|
||||
// 如果读取类型为下载,则下载文件,否则获取文件内容
|
||||
if readType == "1" {
|
||||
// 截取文件名,如/usr/local/test.java -》 test.java
|
||||
path := strings.Split(readPath, "/")
|
||||
rc.Download(sftpFile, path[len(path)-1])
|
||||
} else {
|
||||
datas, err := io.ReadAll(sftpFile)
|
||||
biz.ErrIsNilAppendErr(err, "读取文件内容失败: %s")
|
||||
rc.ResData = string(datas)
|
||||
}
|
||||
|
||||
biz.IsTrue(filesize < max_read_size, "文件超过1m,请使用下载查看")
|
||||
datas, err := io.ReadAll(sftpFile)
|
||||
biz.ErrIsNilAppendErr(err, "读取文件内容失败: %s")
|
||||
|
||||
rc.ResData = string(datas)
|
||||
}
|
||||
|
||||
func (m *MachineFile) DownloadFile(rc *req.Ctx) {
|
||||
g := rc.GinCtx
|
||||
fid := GetMachineFileId(g)
|
||||
readPath := g.Query("path")
|
||||
|
||||
sftpFile, mi, err := m.MachineFileApp.ReadFile(fid, readPath)
|
||||
rc.ReqParam = collx.Kvs("machine", mi, "path", readPath)
|
||||
biz.ErrIsNilAppendErr(err, "打开文件失败: %s")
|
||||
defer sftpFile.Close()
|
||||
|
||||
// 截取文件名,如/usr/local/test.java -》 test.java
|
||||
path := strings.Split(readPath, "/")
|
||||
rc.Download(sftpFile, path[len(path)-1])
|
||||
}
|
||||
|
||||
func (m *MachineFile) GetDirEntry(rc *req.Ctx) {
|
||||
|
||||
@@ -25,7 +25,9 @@ func InitMachineFileRouter(router *gin.RouterGroup) {
|
||||
|
||||
req.NewDelete(":machineId/files/:fileId", mf.DeleteFile).Log(req.NewLogSave("机器-删除文件配置")).RequiredPermissionCode("machine:file:del"),
|
||||
|
||||
req.NewGet(":machineId/files/:fileId/read", mf.ReadFileContent).Log(req.NewLogSave("机器-获取文件内容")),
|
||||
req.NewGet(":machineId/files/:fileId/read", mf.ReadFileContent).Log(req.NewLogSave("机器-读取文件内容")),
|
||||
|
||||
req.NewGet(":machineId/files/:fileId/download", mf.DownloadFile).NoRes().Log(req.NewLogSave("机器-文件下载")),
|
||||
|
||||
req.NewGet(":machineId/files/:fileId/read-dir", mf.GetDirEntry),
|
||||
|
||||
|
||||
Reference in New Issue
Block a user