mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 08:20:25 +08:00 
			
		
		
		
	fix: 机器文件下载问题修复&dbm重构
This commit is contained in:
		@@ -17,7 +17,7 @@
 | 
			
		||||
    "countup.js": "^2.7.0",
 | 
			
		||||
    "cropperjs": "^1.5.11",
 | 
			
		||||
    "echarts": "^5.4.3",
 | 
			
		||||
    "element-plus": "^2.5.0",
 | 
			
		||||
    "element-plus": "^2.5.1",
 | 
			
		||||
    "js-base64": "^3.7.5",
 | 
			
		||||
    "jsencrypt": "^3.3.2",
 | 
			
		||||
    "lodash": "^4.17.21",
 | 
			
		||||
@@ -33,7 +33,7 @@
 | 
			
		||||
    "splitpanes": "^3.1.5",
 | 
			
		||||
    "sql-formatter": "^14.0.0",
 | 
			
		||||
    "uuid": "^9.0.1",
 | 
			
		||||
    "vue": "^3.4.8",
 | 
			
		||||
    "vue": "^3.4.10",
 | 
			
		||||
    "vue-router": "^4.2.5",
 | 
			
		||||
    "xterm": "^5.3.0",
 | 
			
		||||
    "xterm-addon-fit": "^0.8.0",
 | 
			
		||||
@@ -42,7 +42,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  "devDependencies": {
 | 
			
		||||
    "@types/lodash": "^4.14.178",
 | 
			
		||||
    "@types/node": "^15.6.0",
 | 
			
		||||
    "@types/node": "^18.14.0",
 | 
			
		||||
    "@types/nprogress": "^0.2.0",
 | 
			
		||||
    "@types/sortablejs": "^1.15.3",
 | 
			
		||||
    "@typescript-eslint/eslint-plugin": "^6.7.4",
 | 
			
		||||
 
 | 
			
		||||
@@ -35,6 +35,7 @@ export const machineApi = {
 | 
			
		||||
    mvFile: Api.newPost('/machines/{machineId}/files/{fileId}/mv'),
 | 
			
		||||
    uploadFile: Api.newPost('/machines/{machineId}/files/{fileId}/upload?' + joinClientParams()),
 | 
			
		||||
    fileContent: Api.newGet('/machines/{machineId}/files/{fileId}/read'),
 | 
			
		||||
    downloadFile: Api.newGet('/machines/{machineId}/files/{fileId}/download'),
 | 
			
		||||
    createFile: Api.newPost('/machines/{machineId}/files/{id}/create-file'),
 | 
			
		||||
    // 修改文件内容
 | 
			
		||||
    updateFileContent: Api.newPost('/machines/{machineId}/files/{id}/write'),
 | 
			
		||||
 
 | 
			
		||||
@@ -611,7 +611,7 @@ const deleteFile = async (files: any) => {
 | 
			
		||||
 | 
			
		||||
const downloadFile = (data: any) => {
 | 
			
		||||
    const a = document.createElement('a');
 | 
			
		||||
    a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/read?type=1&path=${data.path}&${joinClientParams()}`);
 | 
			
		||||
    a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/download?path=${data.path}&${joinClientParams()}`);
 | 
			
		||||
    a.click();
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -9,7 +9,7 @@ import (
 | 
			
		||||
	"mayfly-go/internal/db/api/form"
 | 
			
		||||
	"mayfly-go/internal/db/api/vo"
 | 
			
		||||
	"mayfly-go/internal/db/application"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	msgapp "mayfly-go/internal/msg/application"
 | 
			
		||||
	msgdto "mayfly-go/internal/msg/application/dto"
 | 
			
		||||
@@ -351,7 +351,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
 | 
			
		||||
		writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table))
 | 
			
		||||
		writer.WriteString("BEGIN;\n")
 | 
			
		||||
		insertSql := "INSERT INTO %s VALUES (%s);\n"
 | 
			
		||||
		dbMeta.WalkTableRecord(table, func(record map[string]any, columns []*dbm.QueryColumn) error {
 | 
			
		||||
		dbMeta.WalkTableRecord(table, func(record map[string]any, columns []*dbi.QueryColumn) error {
 | 
			
		||||
			var values []string
 | 
			
		||||
			writer.TryFlush()
 | 
			
		||||
			for _, column := range columns {
 | 
			
		||||
@@ -470,7 +470,7 @@ func getDbName(g *gin.Context) string {
 | 
			
		||||
	return db
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *Db) getDbConn(g *gin.Context) *dbm.DbConn {
 | 
			
		||||
func (d *Db) getDbConn(g *gin.Context) *dbi.DbConn {
 | 
			
		||||
	dc, err := d.DbApp.GetDbConn(getDbId(g), getDbName(g))
 | 
			
		||||
	biz.ErrIsNil(err)
 | 
			
		||||
	return dc
 | 
			
		||||
 
 | 
			
		||||
@@ -4,6 +4,7 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"mayfly-go/internal/common/consts"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/domain/repository"
 | 
			
		||||
	tagapp "mayfly-go/internal/tag/application"
 | 
			
		||||
@@ -33,10 +34,10 @@ type Db interface {
 | 
			
		||||
	// @param id 数据库id
 | 
			
		||||
	//
 | 
			
		||||
	// @param dbName 数据库名
 | 
			
		||||
	GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error)
 | 
			
		||||
	GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error)
 | 
			
		||||
 | 
			
		||||
	// 根据数据库实例id获取连接,随机返回该instanceId下已连接的conn,若不存在则是使用该instanceId关联的db进行连接并返回。
 | 
			
		||||
	GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error)
 | 
			
		||||
	GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql, dbInstanceApp Instance, tagApp tagapp.TagTree) Db {
 | 
			
		||||
@@ -142,8 +143,8 @@ func (d *dbAppImpl) Delete(ctx context.Context, id uint64) error {
 | 
			
		||||
		})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
 | 
			
		||||
	return dbm.GetDbConn(dbId, dbName, func() (*dbm.DbInfo, error) {
 | 
			
		||||
func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) {
 | 
			
		||||
	return dbm.GetDbConn(dbId, dbName, func() (*dbi.DbInfo, error) {
 | 
			
		||||
		db, err := d.GetById(new(entity.Db), dbId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, errorx.NewBiz("数据库信息不存在")
 | 
			
		||||
@@ -156,7 +157,7 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
 | 
			
		||||
 | 
			
		||||
		checkDb := dbName
 | 
			
		||||
		// 兼容pgsql/dm db/schema模式
 | 
			
		||||
		if dbm.DbTypePostgres.Equal(instance.Type) || dbm.DbTypeDM.Equal(instance.Type) {
 | 
			
		||||
		if dbi.DbTypePostgres.Equal(instance.Type) || dbi.DbTypeDM.Equal(instance.Type) {
 | 
			
		||||
			ss := strings.Split(dbName, "/")
 | 
			
		||||
			if len(ss) > 1 {
 | 
			
		||||
				checkDb = ss[0]
 | 
			
		||||
@@ -174,7 +175,7 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
 | 
			
		||||
	})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error) {
 | 
			
		||||
func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error) {
 | 
			
		||||
	conn := dbm.GetDbConnByInstanceId(instanceId)
 | 
			
		||||
	if conn != nil {
 | 
			
		||||
		return conn, nil
 | 
			
		||||
@@ -193,8 +194,8 @@ func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error
 | 
			
		||||
	return d.GetDbConn(firstDb.Id, strings.Split(firstDb.Database, " ")[0])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath ...string) *dbm.DbInfo {
 | 
			
		||||
	di := new(dbm.DbInfo)
 | 
			
		||||
func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath ...string) *dbi.DbInfo {
 | 
			
		||||
	di := new(dbi.DbInfo)
 | 
			
		||||
	di.InstanceId = instance.Id
 | 
			
		||||
	di.Id = dbId
 | 
			
		||||
	di.Database = database
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,7 @@ import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/domain/repository"
 | 
			
		||||
	"mayfly-go/pkg/base"
 | 
			
		||||
@@ -182,20 +182,20 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return syncLog, errorx.NewBiz("解析字段映射json出错: %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	var updFieldType dbm.DataType
 | 
			
		||||
	var updFieldType dbi.DataType
 | 
			
		||||
 | 
			
		||||
	// 记录本次同步数据总数
 | 
			
		||||
	total := 0
 | 
			
		||||
	batchSize := task.PageSize
 | 
			
		||||
	result := make([]map[string]any, 0)
 | 
			
		||||
	var queryColumns []*dbm.QueryColumn
 | 
			
		||||
	var queryColumns []*dbi.QueryColumn
 | 
			
		||||
 | 
			
		||||
	err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbm.QueryColumn) error {
 | 
			
		||||
	err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbi.QueryColumn) error {
 | 
			
		||||
		if len(queryColumns) == 0 {
 | 
			
		||||
			queryColumns = columns
 | 
			
		||||
 | 
			
		||||
			// 遍历columns 取task.UpdField的字段类型
 | 
			
		||||
			updFieldType = dbm.DataTypeString
 | 
			
		||||
			updFieldType = dbi.DataTypeString
 | 
			
		||||
			for _, column := range columns {
 | 
			
		||||
				if column.Name == task.UpdField {
 | 
			
		||||
					updFieldType = srcDialect.GetDataType(column.Type)
 | 
			
		||||
@@ -249,7 +249,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
 | 
			
		||||
	return syncLog, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbm.DataType, task *entity.DataSyncTask, srcDialect dbm.DbDialect, targetDbConn *dbm.DbConn, targetDbTx *sql.Tx) error {
 | 
			
		||||
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbi.DataType, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error {
 | 
			
		||||
	var data = make([]map[string]any, 0)
 | 
			
		||||
 | 
			
		||||
	// 遍历res,组装插入sql
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,7 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/domain/repository"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
@@ -314,7 +314,7 @@ func (s *dbScheduler) runnable(job entity.DbJob, next runner.NextFunc) bool {
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProgram, job *entity.DbRestore) error {
 | 
			
		||||
func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbi.DbProgram, job *entity.DbRestore) error {
 | 
			
		||||
	binlogHistory, err := s.binlogHistoryRepo.GetHistoryByTime(job.DbInstanceId, job.PointInTime.Time)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
@@ -341,7 +341,7 @@ func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProg
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	restoreInfo := &dbm.RestoreInfo{
 | 
			
		||||
	restoreInfo := &dbi.RestoreInfo{
 | 
			
		||||
		BackupHistory:   backupHistory,
 | 
			
		||||
		BinlogHistories: binlogHistories,
 | 
			
		||||
		StartPosition:   backupHistory.BinlogPosition,
 | 
			
		||||
@@ -354,7 +354,7 @@ func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProg
 | 
			
		||||
	return program.ReplayBinlog(ctx, job.DbName, job.DbName, restoreInfo)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbm.DbProgram, job *entity.DbRestore) error {
 | 
			
		||||
func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbi.DbProgram, job *entity.DbRestore) error {
 | 
			
		||||
	backupHistory := &entity.DbBackupHistory{}
 | 
			
		||||
	if err := s.backupHistoryRepo.GetById(backupHistory, job.DbBackupHistoryId); err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
 
 | 
			
		||||
@@ -4,7 +4,7 @@ import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/config"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/domain/repository"
 | 
			
		||||
	"mayfly-go/pkg/contextx"
 | 
			
		||||
@@ -22,11 +22,11 @@ type DbSqlExecReq struct {
 | 
			
		||||
	Db     string
 | 
			
		||||
	Sql    string
 | 
			
		||||
	Remark string
 | 
			
		||||
	DbConn *dbm.DbConn
 | 
			
		||||
	DbConn *dbi.DbConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DbSqlExecRes struct {
 | 
			
		||||
	Columns []*dbm.QueryColumn
 | 
			
		||||
	Columns []*dbi.QueryColumn
 | 
			
		||||
	Res     []map[string]any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -269,7 +269,7 @@ func doInsert(ctx context.Context, insert *sqlparser.Insert, execSqlReq *DbSqlEx
 | 
			
		||||
	return doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
 | 
			
		||||
func doExec(ctx context.Context, sql string, dbConn *dbi.DbConn) (*DbSqlExecRes, error) {
 | 
			
		||||
	rowsAffected, err := dbConn.ExecContext(ctx, sql)
 | 
			
		||||
	execRes := "success"
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -283,7 +283,7 @@ func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes,
 | 
			
		||||
	res = append(res, resData)
 | 
			
		||||
 | 
			
		||||
	return &DbSqlExecRes{
 | 
			
		||||
		Columns: []*dbm.QueryColumn{
 | 
			
		||||
		Columns: []*dbi.QueryColumn{
 | 
			
		||||
			{Name: "sql", Type: "string"},
 | 
			
		||||
			{Name: "rowsAffected", Type: "number"},
 | 
			
		||||
			{Name: "result", Type: "string"},
 | 
			
		||||
 
 | 
			
		||||
@@ -3,6 +3,7 @@ package application
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"mayfly-go/internal/db/dbm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/internal/db/domain/repository"
 | 
			
		||||
	"mayfly-go/pkg/base"
 | 
			
		||||
@@ -50,7 +51,7 @@ func (app *instanceAppImpl) Count(condition *entity.InstanceQuery) int64 {
 | 
			
		||||
 | 
			
		||||
func (app *instanceAppImpl) TestConn(instanceEntity *entity.DbInstance) error {
 | 
			
		||||
	instanceEntity.Network = instanceEntity.GetNetwork()
 | 
			
		||||
	dbConn, err := toDbInfo(instanceEntity, 0, "", "").Conn()
 | 
			
		||||
	dbConn, err := dbm.Conn(toDbInfo(instanceEntity, 0, "", ""))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
@@ -100,9 +101,9 @@ func (app *instanceAppImpl) Delete(ctx context.Context, id uint64) error {
 | 
			
		||||
 | 
			
		||||
func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance) ([]string, error) {
 | 
			
		||||
	ed.Network = ed.GetNetwork()
 | 
			
		||||
	metaDb := dbm.ToDbType(ed.Type).MetaDbName()
 | 
			
		||||
	metaDb := dbi.ToDbType(ed.Type).MetaDbName()
 | 
			
		||||
 | 
			
		||||
	dbConn, err := toDbInfo(ed, 0, metaDb, "").Conn()
 | 
			
		||||
	dbConn, err := dbm.Conn(toDbInfo(ed, 0, metaDb, ""))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package dbi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
@@ -117,18 +117,9 @@ func (d *DbConn) Begin() (*sql.Tx, error) {
 | 
			
		||||
	return d.db.Begin()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取数据库元信息实现接口
 | 
			
		||||
func (d *DbConn) GetDialect() DbDialect {
 | 
			
		||||
	switch d.Info.Type {
 | 
			
		||||
	case DbTypeMysql, DbTypeMariadb:
 | 
			
		||||
		return &MysqlDialect{dc: d}
 | 
			
		||||
	case DbTypePostgres:
 | 
			
		||||
		return &PgsqlDialect{dc: d}
 | 
			
		||||
	case DbTypeDM:
 | 
			
		||||
		return &DMDialect{dc: d}
 | 
			
		||||
	default:
 | 
			
		||||
		panic(fmt.Sprintf("invalid database type: %s", d.Info.Type))
 | 
			
		||||
	}
 | 
			
		||||
// 获取数据库dialect实现接口
 | 
			
		||||
func (d *DbConn) GetDialect() Dialect {
 | 
			
		||||
	return d.Info.Meta.GetDialect(d)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 关闭连接
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package dbi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package dbi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package dbi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"testing"
 | 
			
		||||
@@ -1,12 +1,13 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package dbi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"embed"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"mayfly-go/pkg/biz"
 | 
			
		||||
	"mayfly-go/pkg/utils/collx"
 | 
			
		||||
	"mayfly-go/pkg/utils/stringx"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DataType string
 | 
			
		||||
@@ -60,7 +61,7 @@ type Index struct {
 | 
			
		||||
 | 
			
		||||
// -----------------------------------元数据接口定义------------------------------------------
 | 
			
		||||
// 数据库方言、元信息接口(表、列、获取表数据等元信息)
 | 
			
		||||
type DbDialect interface {
 | 
			
		||||
type Dialect interface {
 | 
			
		||||
	// 获取数据库服务实例信息
 | 
			
		||||
	GetDbServer() (*DbServer, error)
 | 
			
		||||
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package dbi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
@@ -8,6 +8,9 @@ import (
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 获取sql.DB函数
 | 
			
		||||
type GetSqlDbFunc func(*DbInfo) (*sql.DB, error)
 | 
			
		||||
 | 
			
		||||
type DbInfo struct {
 | 
			
		||||
	InstanceId uint64 // 实例id
 | 
			
		||||
	Id         uint64 // dbId
 | 
			
		||||
@@ -24,6 +27,8 @@ type DbInfo struct {
 | 
			
		||||
 | 
			
		||||
	TagPath            []string
 | 
			
		||||
	SshTunnelMachineId int
 | 
			
		||||
 | 
			
		||||
	Meta Meta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取记录日志的描述
 | 
			
		||||
@@ -32,22 +37,16 @@ func (d *DbInfo) GetLogDesc() string {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 连接数据库
 | 
			
		||||
func (dbInfo *DbInfo) Conn() (*DbConn, error) {
 | 
			
		||||
	var conn *sql.DB
 | 
			
		||||
	var err error
 | 
			
		||||
	database := dbInfo.Database
 | 
			
		||||
 | 
			
		||||
	switch dbInfo.Type {
 | 
			
		||||
	case DbTypeMysql, DbTypeMariadb:
 | 
			
		||||
		conn, err = getMysqlDB(dbInfo)
 | 
			
		||||
	case DbTypePostgres:
 | 
			
		||||
		conn, err = getPgsqlDB(dbInfo)
 | 
			
		||||
	case DbTypeDM:
 | 
			
		||||
		conn, err = getDmDB(dbInfo)
 | 
			
		||||
	default:
 | 
			
		||||
		return nil, errorx.NewBiz("invalid database type: %s", dbInfo.Type)
 | 
			
		||||
func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) {
 | 
			
		||||
	if meta == nil {
 | 
			
		||||
		return nil, errorx.NewBiz("数据库元信息接口不能为空")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 赋值Meta,方便后续获取dialect等
 | 
			
		||||
	dbInfo.Meta = meta
 | 
			
		||||
	database := dbInfo.Database
 | 
			
		||||
 | 
			
		||||
	conn, err := meta.GetSqlDb(dbInfo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logx.Errorf("连接db失败: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error())
 | 
			
		||||
		return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))
 | 
			
		||||
							
								
								
									
										12
									
								
								server/internal/db/dbm/dbi/meta.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								server/internal/db/dbm/dbi/meta.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,12 @@
 | 
			
		||||
package dbi
 | 
			
		||||
 | 
			
		||||
import "database/sql"
 | 
			
		||||
 | 
			
		||||
// 数据库元信息获取,如获取sql.DB、Dialect等
 | 
			
		||||
type Meta interface {
 | 
			
		||||
	// 获取数据库服务实例信息
 | 
			
		||||
	GetSqlDb(*DbInfo) (*sql.DB, error)
 | 
			
		||||
 | 
			
		||||
	// 获取数据库方言,用于获取表结构等信息
 | 
			
		||||
	GetDialect(*DbConn) Dialect
 | 
			
		||||
}
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package dbi
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
@@ -3,6 +3,10 @@ package dbm
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/common/consts"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dm"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/mysql"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/postgres"
 | 
			
		||||
	"mayfly-go/internal/machine/mcm"
 | 
			
		||||
	"mayfly-go/pkg/cache"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
@@ -15,7 +19,7 @@ var connCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second).
 | 
			
		||||
	WithUpdateAccessTime(true).
 | 
			
		||||
	OnEvicted(func(key any, value any) {
 | 
			
		||||
		logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key))
 | 
			
		||||
		value.(*DbConn).Close()
 | 
			
		||||
		value.(*dbi.DbConn).Close()
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
func init() {
 | 
			
		||||
@@ -23,7 +27,7 @@ func init() {
 | 
			
		||||
		// 遍历所有db连接实例,若存在db实例使用该ssh隧道机器,则返回true,表示还在使用中...
 | 
			
		||||
		items := connCache.Items()
 | 
			
		||||
		for _, v := range items {
 | 
			
		||||
			if v.Value.(*DbConn).Info.SshTunnelMachineId == machineId {
 | 
			
		||||
			if v.Value.(*dbi.DbConn).Info.SshTunnelMachineId == machineId {
 | 
			
		||||
				return true
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
@@ -33,8 +37,21 @@ func init() {
 | 
			
		||||
 | 
			
		||||
var mutex sync.Mutex
 | 
			
		||||
 | 
			
		||||
func getDbMetaByType(dt dbi.DbType) dbi.Meta {
 | 
			
		||||
	switch dt {
 | 
			
		||||
	case dbi.DbTypeMysql, dbi.DbTypeMariadb:
 | 
			
		||||
		return mysql.GetMeta()
 | 
			
		||||
	case dbi.DbTypePostgres:
 | 
			
		||||
		return postgres.GetMeta()
 | 
			
		||||
	case dbi.DbTypeDM:
 | 
			
		||||
		return dm.GetMeta()
 | 
			
		||||
	default:
 | 
			
		||||
		panic(fmt.Sprintf("invalid database type: %s", dt))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 从缓存中获取数据库连接信息,若缓存中不存在则会使用回调函数获取dbInfo进行连接并缓存
 | 
			
		||||
func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error)) (*DbConn, error) {
 | 
			
		||||
func GetDbConn(dbId uint64, database string, getDbInfo func() (*dbi.DbInfo, error)) (*dbi.DbConn, error) {
 | 
			
		||||
	connId := GetDbConnId(dbId, database)
 | 
			
		||||
 | 
			
		||||
	// connId不为空,则为需要缓存
 | 
			
		||||
@@ -42,7 +59,7 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error))
 | 
			
		||||
	if needCache {
 | 
			
		||||
		load, ok := connCache.Get(connId)
 | 
			
		||||
		if ok {
 | 
			
		||||
			return load.(*DbConn), nil
 | 
			
		||||
			return load.(*dbi.DbConn), nil
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -56,7 +73,7 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 连接数据库
 | 
			
		||||
	dbConn, err := dbInfo.Conn()
 | 
			
		||||
	dbConn, err := Conn(dbInfo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -67,10 +84,15 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error))
 | 
			
		||||
	return dbConn, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 使用指定dbInfo信息进行连接
 | 
			
		||||
func Conn(di *dbi.DbInfo) (*dbi.DbConn, error) {
 | 
			
		||||
	return di.Conn(getDbMetaByType(di.Type))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 根据实例id获取连接
 | 
			
		||||
func GetDbConnByInstanceId(instanceId uint64) *DbConn {
 | 
			
		||||
func GetDbConnByInstanceId(instanceId uint64) *dbi.DbConn {
 | 
			
		||||
	for _, connItem := range connCache.Items() {
 | 
			
		||||
		conn := connItem.Value.(*DbConn)
 | 
			
		||||
		conn := connItem.Value.(*dbi.DbConn)
 | 
			
		||||
		if conn.Info.InstanceId == instanceId {
 | 
			
		||||
			return conn
 | 
			
		||||
		}
 | 
			
		||||
@@ -82,3 +104,12 @@ func GetDbConnByInstanceId(instanceId uint64) *DbConn {
 | 
			
		||||
func CloseDb(dbId uint64, db string) {
 | 
			
		||||
	connCache.Delete(GetDbConnId(dbId, db))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取连接id
 | 
			
		||||
func GetDbConnId(dbId uint64, db string) string {
 | 
			
		||||
	if dbId == 0 {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Sprintf("%d:%s", dbId, db)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,9 +1,10 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package dm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/pkg/errorx"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
	"mayfly-go/pkg/utils/anyx"
 | 
			
		||||
@@ -14,30 +15,6 @@ import (
 | 
			
		||||
	_ "gitee.com/chunanyong/dm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getDmDB(d *DbInfo) (*sql.DB, error) {
 | 
			
		||||
	driverName := "dm"
 | 
			
		||||
	db := d.Database
 | 
			
		||||
	var dbParam string
 | 
			
		||||
	if db != "" {
 | 
			
		||||
		// dm database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
 | 
			
		||||
		ss := strings.Split(db, "/")
 | 
			
		||||
		if len(ss) > 1 {
 | 
			
		||||
			dbParam = fmt.Sprintf("%s?schema=%s", ss[0], ss[len(ss)-1])
 | 
			
		||||
		} else {
 | 
			
		||||
			dbParam = db
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := d.IfUseSshTunnelChangeIpPort()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dsn := fmt.Sprintf("dm://%s:%s@%s:%d/%s", d.Username, d.Password, d.Host, d.Port, dbParam)
 | 
			
		||||
	return sql.Open(driverName, dsn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ---------------------------------- DM元数据 -----------------------------------
 | 
			
		||||
const (
 | 
			
		||||
	DM_META_FILE      = "metasql/dm_meta.sql"
 | 
			
		||||
	DM_DB_SCHEMAS     = "DM_DB_SCHEMAS"
 | 
			
		||||
@@ -47,15 +24,15 @@ const (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DMDialect struct {
 | 
			
		||||
	dc *DbConn
 | 
			
		||||
	dc *dbi.DbConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dd *DMDialect) GetDbServer() (*DbServer, error) {
 | 
			
		||||
func (dd *DMDialect) GetDbServer() (*dbi.DbServer, error) {
 | 
			
		||||
	_, res, err := dd.dc.Query("select * from v$instance")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	ds := &DbServer{
 | 
			
		||||
	ds := &dbi.DbServer{
 | 
			
		||||
		Version: anyx.ConvString(res[0]["SVR_VERSION"]),
 | 
			
		||||
	}
 | 
			
		||||
	return ds, nil
 | 
			
		||||
@@ -76,20 +53,20 @@ func (dd *DMDialect) GetDbNames() ([]string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取表基础元信息, 如表名等
 | 
			
		||||
func (dd *DMDialect) GetTables() ([]Table, error) {
 | 
			
		||||
func (dd *DMDialect) GetTables() ([]dbi.Table, error) {
 | 
			
		||||
 | 
			
		||||
	// 首先执行更新统计信息sql 这个统计信息在数据量比较大的时候就比较耗时,所以最好定时执行
 | 
			
		||||
	// _, _, err := pd.dc.Query("dbms_stats.GATHER_SCHEMA_stats(SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))")
 | 
			
		||||
 | 
			
		||||
	// 查询表信息
 | 
			
		||||
	_, res, err := dd.dc.Query(GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY))
 | 
			
		||||
	_, res, err := dd.dc.Query(dbi.GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tables := make([]Table, 0)
 | 
			
		||||
	tables := make([]dbi.Table, 0)
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		tables = append(tables, Table{
 | 
			
		||||
		tables = append(tables, dbi.Table{
 | 
			
		||||
			TableName:    re["TABLE_NAME"].(string),
 | 
			
		||||
			TableComment: anyx.ConvString(re["TABLE_COMMENT"]),
 | 
			
		||||
			CreateTime:   anyx.ConvString(re["CREATE_TIME"]),
 | 
			
		||||
@@ -102,7 +79,7 @@ func (dd *DMDialect) GetTables() ([]Table, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取列元信息, 如列名等
 | 
			
		||||
func (dd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
 | 
			
		||||
func (dd *DMDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
 | 
			
		||||
	tableName := ""
 | 
			
		||||
	for i := 0; i < len(tableNames); i++ {
 | 
			
		||||
		if i != 0 {
 | 
			
		||||
@@ -111,14 +88,14 @@ func (dd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
 | 
			
		||||
		tableName = tableName + "'" + tableNames[i] + "'"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
 | 
			
		||||
	_, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	columns := make([]Column, 0)
 | 
			
		||||
	columns := make([]dbi.Column, 0)
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		columns = append(columns, Column{
 | 
			
		||||
		columns = append(columns, dbi.Column{
 | 
			
		||||
			TableName:     re["TABLE_NAME"].(string),
 | 
			
		||||
			ColumnName:    re["COLUMN_NAME"].(string),
 | 
			
		||||
			ColumnType:    anyx.ConvString(re["COLUMN_TYPE"]),
 | 
			
		||||
@@ -150,15 +127,15 @@ func (dd *DMDialect) GetPrimaryKey(tablename string) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取表索引信息
 | 
			
		||||
func (dd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
 | 
			
		||||
	_, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName))
 | 
			
		||||
func (dd *DMDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
 | 
			
		||||
	_, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	indexs := make([]Index, 0)
 | 
			
		||||
	indexs := make([]dbi.Index, 0)
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		indexs = append(indexs, Index{
 | 
			
		||||
		indexs = append(indexs, dbi.Index{
 | 
			
		||||
			IndexName:    re["INDEX_NAME"].(string),
 | 
			
		||||
			ColumnName:   anyx.ConvString(re["COLUMN_NAME"]),
 | 
			
		||||
			IndexType:    anyx.ConvString(re["INDEX_TYPE"]),
 | 
			
		||||
@@ -168,7 +145,7 @@ func (dd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	// 把查询结果以索引名分组,索引字段以逗号连接
 | 
			
		||||
	result := make([]Index, 0)
 | 
			
		||||
	result := make([]dbi.Index, 0)
 | 
			
		||||
	key := ""
 | 
			
		||||
	for _, v := range indexs {
 | 
			
		||||
		// 当前的索引名
 | 
			
		||||
@@ -255,13 +232,13 @@ func (dd *DMDialect) GetTableDDL(tableName string) (string, error) {
 | 
			
		||||
	return builder.String(), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dd *DMDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error {
 | 
			
		||||
func (dd *DMDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error {
 | 
			
		||||
	return dd.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取DM当前连接的库可访问的schemaNames
 | 
			
		||||
func (dd *DMDialect) GetSchemas() ([]string, error) {
 | 
			
		||||
	sql := GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS)
 | 
			
		||||
	sql := dbi.GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS)
 | 
			
		||||
	_, res, err := dd.dc.Query(sql)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -274,27 +251,27 @@ func (dd *DMDialect) GetSchemas() ([]string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
 | 
			
		||||
func (dd *DMDialect) GetDbProgram() DbProgram {
 | 
			
		||||
func (dd *DMDialect) GetDbProgram() dbi.DbProgram {
 | 
			
		||||
	panic("implement me")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dd *DMDialect) GetDataType(dbColumnType string) DataType {
 | 
			
		||||
func (dd *DMDialect) GetDataType(dbColumnType string) dbi.DataType {
 | 
			
		||||
	if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeNumber
 | 
			
		||||
		return dbi.DataTypeNumber
 | 
			
		||||
	}
 | 
			
		||||
	// 日期时间类型
 | 
			
		||||
	if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeDateTime
 | 
			
		||||
		return dbi.DataTypeDateTime
 | 
			
		||||
	}
 | 
			
		||||
	// 日期类型
 | 
			
		||||
	if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeDate
 | 
			
		||||
		return dbi.DataTypeDate
 | 
			
		||||
	}
 | 
			
		||||
	// 时间类型
 | 
			
		||||
	if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeTime
 | 
			
		||||
		return dbi.DataTypeTime
 | 
			
		||||
	}
 | 
			
		||||
	return DataTypeString
 | 
			
		||||
	return dbi.DataTypeString
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
 | 
			
		||||
@@ -322,15 +299,15 @@ func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string,
 | 
			
		||||
	return int64(effRows), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (dd *DMDialect) FormatStrData(dbColumnValue string, dataType DataType) string {
 | 
			
		||||
func (dd *DMDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string {
 | 
			
		||||
	switch dataType {
 | 
			
		||||
	case DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
 | 
			
		||||
	case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
 | 
			
		||||
		res, _ := time.Parse(time.RFC3339, dbColumnValue)
 | 
			
		||||
		return res.Format(time.DateTime)
 | 
			
		||||
	case DataTypeDate: // "2024-01-02T00:00:00+08:00"
 | 
			
		||||
	case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
 | 
			
		||||
		res, _ := time.Parse(time.RFC3339, dbColumnValue)
 | 
			
		||||
		return res.Format(time.DateOnly)
 | 
			
		||||
	case DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
 | 
			
		||||
	case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
 | 
			
		||||
		res, _ := time.Parse(time.RFC3339, dbColumnValue)
 | 
			
		||||
		return res.Format(time.TimeOnly)
 | 
			
		||||
	}
 | 
			
		||||
							
								
								
									
										51
									
								
								server/internal/db/dbm/dm/meta.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										51
									
								
								server/internal/db/dbm/dm/meta.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,51 @@
 | 
			
		||||
package dm
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	meta dbi.Meta
 | 
			
		||||
	once sync.Once
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetMeta() dbi.Meta {
 | 
			
		||||
	once.Do(func() {
 | 
			
		||||
		meta = new(DmMeta)
 | 
			
		||||
	})
 | 
			
		||||
	return meta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type DmMeta struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *DmMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	driverName := "dm"
 | 
			
		||||
	db := d.Database
 | 
			
		||||
	var dbParam string
 | 
			
		||||
	if db != "" {
 | 
			
		||||
		// dm database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
 | 
			
		||||
		ss := strings.Split(db, "/")
 | 
			
		||||
		if len(ss) > 1 {
 | 
			
		||||
			dbParam = fmt.Sprintf("%s?schema=%s", ss[0], ss[len(ss)-1])
 | 
			
		||||
		} else {
 | 
			
		||||
			dbParam = db
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err := d.IfUseSshTunnelChangeIpPort()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dsn := fmt.Sprintf("dm://%s:%s@%s:%d/%s", d.Username, d.Password, d.Host, d.Port, dbParam)
 | 
			
		||||
	return sql.Open(driverName, dsn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *DmMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
 | 
			
		||||
	return &DMDialect{conn}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,40 +1,16 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package mysql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	machineapp "mayfly-go/internal/machine/application"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/pkg/errorx"
 | 
			
		||||
	"mayfly-go/pkg/utils/anyx"
 | 
			
		||||
	"net"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-sql-driver/mysql"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getMysqlDB(d *DbInfo) (*sql.DB, error) {
 | 
			
		||||
	// SSH Conect
 | 
			
		||||
	if d.SshTunnelMachineId > 0 {
 | 
			
		||||
		sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
 | 
			
		||||
			return sshTunnelMachine.GetDialConn("tcp", addr)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	// 设置dataSourceName  -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
 | 
			
		||||
	dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
 | 
			
		||||
	if d.Params != "" {
 | 
			
		||||
		dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
 | 
			
		||||
	}
 | 
			
		||||
	const driverName = "mysql"
 | 
			
		||||
	return sql.Open(driverName, dsn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ---------------------------------- mysql元数据 -----------------------------------
 | 
			
		||||
const (
 | 
			
		||||
	MYSQL_META_FILE      = "metasql/mysql_meta.sql"
 | 
			
		||||
	MYSQL_DBS            = "MYSQL_DBS"
 | 
			
		||||
@@ -44,22 +20,22 @@ const (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MysqlDialect struct {
 | 
			
		||||
	dc *DbConn
 | 
			
		||||
	dc *dbi.DbConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *MysqlDialect) GetDbServer() (*DbServer, error) {
 | 
			
		||||
func (md *MysqlDialect) GetDbServer() (*dbi.DbServer, error) {
 | 
			
		||||
	_, res, err := md.dc.Query("SELECT VERSION() version")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	ds := &DbServer{
 | 
			
		||||
	ds := &dbi.DbServer{
 | 
			
		||||
		Version: anyx.ConvString(res[0]["version"]),
 | 
			
		||||
	}
 | 
			
		||||
	return ds, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *MysqlDialect) GetDbNames() ([]string, error) {
 | 
			
		||||
	_, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_DBS))
 | 
			
		||||
	_, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_DBS))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
@@ -73,15 +49,15 @@ func (md *MysqlDialect) GetDbNames() ([]string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取表基础元信息, 如表名等
 | 
			
		||||
func (md *MysqlDialect) GetTables() ([]Table, error) {
 | 
			
		||||
	_, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
 | 
			
		||||
func (md *MysqlDialect) GetTables() ([]dbi.Table, error) {
 | 
			
		||||
	_, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tables := make([]Table, 0)
 | 
			
		||||
	tables := make([]dbi.Table, 0)
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		tables = append(tables, Table{
 | 
			
		||||
		tables = append(tables, dbi.Table{
 | 
			
		||||
			TableName:    re["tableName"].(string),
 | 
			
		||||
			TableComment: anyx.ConvString(re["tableComment"]),
 | 
			
		||||
			CreateTime:   anyx.ConvString(re["createTime"]),
 | 
			
		||||
@@ -94,7 +70,7 @@ func (md *MysqlDialect) GetTables() ([]Table, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取列元信息, 如列名等
 | 
			
		||||
func (md *MysqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
 | 
			
		||||
func (md *MysqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
 | 
			
		||||
	tableName := ""
 | 
			
		||||
	for i := 0; i < len(tableNames); i++ {
 | 
			
		||||
		if i != 0 {
 | 
			
		||||
@@ -103,14 +79,14 @@ func (md *MysqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
 | 
			
		||||
		tableName = tableName + "'" + tableNames[i] + "'"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, res, err := md.dc.Query(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
 | 
			
		||||
	_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	columns := make([]Column, 0)
 | 
			
		||||
	columns := make([]dbi.Column, 0)
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		columns = append(columns, Column{
 | 
			
		||||
		columns = append(columns, dbi.Column{
 | 
			
		||||
			TableName:     re["tableName"].(string),
 | 
			
		||||
			ColumnName:    re["columnName"].(string),
 | 
			
		||||
			ColumnType:    anyx.ConvString(re["columnType"]),
 | 
			
		||||
@@ -144,15 +120,15 @@ func (md *MysqlDialect) GetPrimaryKey(tablename string) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取表索引信息
 | 
			
		||||
func (md *MysqlDialect) GetTableIndex(tableName string) ([]Index, error) {
 | 
			
		||||
	_, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName)
 | 
			
		||||
func (md *MysqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
 | 
			
		||||
	_, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	indexs := make([]Index, 0)
 | 
			
		||||
	indexs := make([]dbi.Index, 0)
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		indexs = append(indexs, Index{
 | 
			
		||||
		indexs = append(indexs, dbi.Index{
 | 
			
		||||
			IndexName:    re["indexName"].(string),
 | 
			
		||||
			ColumnName:   anyx.ConvString(re["columnName"]),
 | 
			
		||||
			IndexType:    anyx.ConvString(re["indexType"]),
 | 
			
		||||
@@ -162,7 +138,7 @@ func (md *MysqlDialect) GetTableIndex(tableName string) ([]Index, error) {
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	// 把查询结果以索引名分组,索引字段以逗号连接
 | 
			
		||||
	result := make([]Index, 0)
 | 
			
		||||
	result := make([]dbi.Index, 0)
 | 
			
		||||
	key := ""
 | 
			
		||||
	for _, v := range indexs {
 | 
			
		||||
		// 当前的索引名
 | 
			
		||||
@@ -189,7 +165,7 @@ func (md *MysqlDialect) GetTableDDL(tableName string) (string, error) {
 | 
			
		||||
	return res[0]["Create Table"].(string) + ";", nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *MysqlDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error {
 | 
			
		||||
func (md *MysqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error {
 | 
			
		||||
	return md.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -198,27 +174,27 @@ func (md *MysqlDialect) GetSchemas() ([]string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
 | 
			
		||||
func (md *MysqlDialect) GetDbProgram() DbProgram {
 | 
			
		||||
func (md *MysqlDialect) GetDbProgram() dbi.DbProgram {
 | 
			
		||||
	return NewDbProgramMysql(md.dc)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *MysqlDialect) GetDataType(dbColumnType string) DataType {
 | 
			
		||||
func (md *MysqlDialect) GetDataType(dbColumnType string) dbi.DataType {
 | 
			
		||||
	if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeNumber
 | 
			
		||||
		return dbi.DataTypeNumber
 | 
			
		||||
	}
 | 
			
		||||
	// 日期时间类型
 | 
			
		||||
	if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeDateTime
 | 
			
		||||
		return dbi.DataTypeDateTime
 | 
			
		||||
	}
 | 
			
		||||
	// 日期类型
 | 
			
		||||
	if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeDate
 | 
			
		||||
		return dbi.DataTypeDate
 | 
			
		||||
	}
 | 
			
		||||
	// 时间类型
 | 
			
		||||
	if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeTime
 | 
			
		||||
		return dbi.DataTypeTime
 | 
			
		||||
	}
 | 
			
		||||
	return DataTypeString
 | 
			
		||||
	return dbi.DataTypeString
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
 | 
			
		||||
@@ -246,7 +222,7 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
 | 
			
		||||
	return md.dc.TxExec(tx, sqlStr, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *MysqlDialect) FormatStrData(dbColumnValue string, dataType DataType) string {
 | 
			
		||||
func (md *MysqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string {
 | 
			
		||||
	// mysql不需要格式化时间日期等
 | 
			
		||||
	return dbColumnValue
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										52
									
								
								server/internal/db/dbm/mysql/meta.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								server/internal/db/dbm/mysql/meta.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
			
		||||
package mysql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	machineapp "mayfly-go/internal/machine/application"
 | 
			
		||||
	"net"
 | 
			
		||||
	"sync"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-sql-driver/mysql"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	meta dbi.Meta
 | 
			
		||||
	once sync.Once
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetMeta() dbi.Meta {
 | 
			
		||||
	once.Do(func() {
 | 
			
		||||
		meta = new(MysqlMeta)
 | 
			
		||||
	})
 | 
			
		||||
	return meta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MysqlMeta struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *MysqlMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	// SSH Conect
 | 
			
		||||
	if d.SshTunnelMachineId > 0 {
 | 
			
		||||
		sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
 | 
			
		||||
			return sshTunnelMachine.GetDialConn("tcp", addr)
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	// 设置dataSourceName  -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
 | 
			
		||||
	dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
 | 
			
		||||
	if d.Params != "" {
 | 
			
		||||
		dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
 | 
			
		||||
	}
 | 
			
		||||
	const driverName = "mysql"
 | 
			
		||||
	return sql.Open(driverName, dsn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *MysqlMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
 | 
			
		||||
	return &MysqlDialect{conn}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package mysql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
@@ -19,27 +19,28 @@ import (
 | 
			
		||||
	"golang.org/x/sync/singleflight"
 | 
			
		||||
 | 
			
		||||
	"mayfly-go/internal/db/config"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"mayfly-go/pkg/logx"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var _ DbProgram = (*DbProgramMysql)(nil)
 | 
			
		||||
var _ dbi.DbProgram = (*DbProgramMysql)(nil)
 | 
			
		||||
 | 
			
		||||
type DbProgramMysql struct {
 | 
			
		||||
	dbConn *DbConn
 | 
			
		||||
	dbConn *dbi.DbConn
 | 
			
		||||
	// mysqlBin 用于集成测试
 | 
			
		||||
	mysqlBin *config.MysqlBin
 | 
			
		||||
	// backupPath 用于集成测试
 | 
			
		||||
	backupPath string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDbProgramMysql(dbConn *DbConn) *DbProgramMysql {
 | 
			
		||||
func NewDbProgramMysql(dbConn *dbi.DbConn) *DbProgramMysql {
 | 
			
		||||
	return &DbProgramMysql{
 | 
			
		||||
		dbConn: dbConn,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (svc *DbProgramMysql) dbInfo() *DbInfo {
 | 
			
		||||
func (svc *DbProgramMysql) dbInfo() *dbi.DbInfo {
 | 
			
		||||
	dbInfo := svc.dbConn.Info
 | 
			
		||||
	err := dbInfo.IfUseSshTunnelChangeIpPort()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -55,9 +56,9 @@ func (svc *DbProgramMysql) getMysqlBin() *config.MysqlBin {
 | 
			
		||||
	dbInfo := svc.dbInfo()
 | 
			
		||||
	var mysqlBin *config.MysqlBin
 | 
			
		||||
	switch dbInfo.Type {
 | 
			
		||||
	case DbTypeMariadb:
 | 
			
		||||
	case dbi.DbTypeMariadb:
 | 
			
		||||
		mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMariadbBin)
 | 
			
		||||
	case DbTypeMysql:
 | 
			
		||||
	case dbi.DbTypeMysql:
 | 
			
		||||
		mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMysqlBin)
 | 
			
		||||
	default:
 | 
			
		||||
		panic(fmt.Sprintf("不兼容 MySQL 的数据库类型: %v", dbInfo.Type))
 | 
			
		||||
@@ -488,7 +489,7 @@ func (svc *DbProgramMysql) GetBinlogEventPositionAtOrAfterTime(ctx context.Conte
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ReplayBinlog replays the binlog for `originDatabase` from `startBinlogInfo.Position` to `targetTs`, read binlog from `binlogDir`.
 | 
			
		||||
func (svc *DbProgramMysql) ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *RestoreInfo) (replayErr error) {
 | 
			
		||||
func (svc *DbProgramMysql) ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *dbi.RestoreInfo) (replayErr error) {
 | 
			
		||||
	const (
 | 
			
		||||
		// Variable lower_case_table_names related.
 | 
			
		||||
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
//go:build e2e
 | 
			
		||||
 | 
			
		||||
package dbm
 | 
			
		||||
package mysql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
@@ -1,10 +1,11 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package mysql
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
	"mayfly-go/internal/db/domain/entity"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"testing"
 | 
			
		||||
 | 
			
		||||
	"github.com/stretchr/testify/require"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func Test_readBinlogInfoFromBackup(t *testing.T) {
 | 
			
		||||
@@ -1,99 +1,17 @@
 | 
			
		||||
package dbm
 | 
			
		||||
package postgres
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	machineapp "mayfly-go/internal/machine/application"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	"mayfly-go/pkg/errorx"
 | 
			
		||||
	"mayfly-go/pkg/utils/anyx"
 | 
			
		||||
	"mayfly-go/pkg/utils/collx"
 | 
			
		||||
	"mayfly-go/pkg/utils/netx"
 | 
			
		||||
	"net"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	pq "gitee.com/liuzongyang/libpq"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func getPgsqlDB(d *DbInfo) (*sql.DB, error) {
 | 
			
		||||
	driverName := string(d.Type)
 | 
			
		||||
	// SSH Conect
 | 
			
		||||
	if d.SshTunnelMachineId > 0 {
 | 
			
		||||
		// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
 | 
			
		||||
		driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
 | 
			
		||||
		if !collx.ArrayContains(sql.Drivers(), driverName) {
 | 
			
		||||
			sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId})
 | 
			
		||||
		}
 | 
			
		||||
		sql.Drivers()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db := d.Database
 | 
			
		||||
	var dbParam string
 | 
			
		||||
	exsitSchema := false
 | 
			
		||||
	if db != "" {
 | 
			
		||||
		// postgres database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
 | 
			
		||||
		ss := strings.Split(db, "/")
 | 
			
		||||
		if len(ss) > 1 {
 | 
			
		||||
			exsitSchema = true
 | 
			
		||||
			dbParam = fmt.Sprintf("dbname=%s search_path=%s", ss[0], ss[len(ss)-1])
 | 
			
		||||
		} else {
 | 
			
		||||
			dbParam = "dbname=" + db
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s %s sslmode=disable connect_timeout=8", d.Host, d.Port, d.Username, d.Password, dbParam)
 | 
			
		||||
	// 存在额外指定参数,则拼接该连接参数
 | 
			
		||||
	if d.Params != "" {
 | 
			
		||||
		// 存在指定的db,则需要将dbInstance配置中的parmas排除掉dbname和search_path
 | 
			
		||||
		if db != "" {
 | 
			
		||||
			paramArr := strings.Split(d.Params, "&")
 | 
			
		||||
			paramArr = collx.ArrayRemoveFunc(paramArr, func(param string) bool {
 | 
			
		||||
				if strings.HasPrefix(param, "dbname=") {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
				if exsitSchema && strings.HasPrefix(param, "search_path") {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
				return false
 | 
			
		||||
			})
 | 
			
		||||
			d.Params = strings.Join(paramArr, " ")
 | 
			
		||||
		}
 | 
			
		||||
		dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sql.Open(driverName, dsn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// pgsql dialer
 | 
			
		||||
type PqSqlDialer struct {
 | 
			
		||||
	sshTunnelMachineId int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
 | 
			
		||||
	return pq.DialOpen(d, name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
 | 
			
		||||
	sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if sshConn, err := sshTunnel.GetDialConn("tcp", address); err == nil {
 | 
			
		||||
		// 将ssh conn包装,否则会返回错误: ssh: tcpChan: deadline not supported
 | 
			
		||||
		return &netx.WrapSshConn{Conn: sshConn}, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
 | 
			
		||||
	return pd.Dial(network, address)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ---------------------------------- pgsql元数据 -----------------------------------
 | 
			
		||||
const (
 | 
			
		||||
	PGSQL_META_FILE      = "metasql/pgsql_meta.sql"
 | 
			
		||||
	PGSQL_DB_SCHEMAS     = "PGSQL_DB_SCHEMAS"
 | 
			
		||||
@@ -104,15 +22,15 @@ const (
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type PgsqlDialect struct {
 | 
			
		||||
	dc *DbConn
 | 
			
		||||
	dc *dbi.DbConn
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) GetDbServer() (*DbServer, error) {
 | 
			
		||||
func (pd *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) {
 | 
			
		||||
	_, res, err := pd.dc.Query("SHOW server_version")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	ds := &DbServer{
 | 
			
		||||
	ds := &dbi.DbServer{
 | 
			
		||||
		Version: anyx.ConvString(res[0]["server_version"]),
 | 
			
		||||
	}
 | 
			
		||||
	return ds, nil
 | 
			
		||||
@@ -133,15 +51,15 @@ func (pd *PgsqlDialect) GetDbNames() ([]string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取表基础元信息, 如表名等
 | 
			
		||||
func (pd *PgsqlDialect) GetTables() ([]Table, error) {
 | 
			
		||||
	_, res, err := pd.dc.Query(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
 | 
			
		||||
func (pd *PgsqlDialect) GetTables() ([]dbi.Table, error) {
 | 
			
		||||
	_, res, err := pd.dc.Query(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tables := make([]Table, 0)
 | 
			
		||||
	tables := make([]dbi.Table, 0)
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		tables = append(tables, Table{
 | 
			
		||||
		tables = append(tables, dbi.Table{
 | 
			
		||||
			TableName:    re["tableName"].(string),
 | 
			
		||||
			TableComment: anyx.ConvString(re["tableComment"]),
 | 
			
		||||
			CreateTime:   anyx.ConvString(re["createTime"]),
 | 
			
		||||
@@ -154,7 +72,7 @@ func (pd *PgsqlDialect) GetTables() ([]Table, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取列元信息, 如列名等
 | 
			
		||||
func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
 | 
			
		||||
func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
 | 
			
		||||
	tableName := ""
 | 
			
		||||
	for i := 0; i < len(tableNames); i++ {
 | 
			
		||||
		if i != 0 {
 | 
			
		||||
@@ -163,14 +81,14 @@ func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
 | 
			
		||||
		tableName = tableName + "'" + tableNames[i] + "'"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
 | 
			
		||||
	_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	columns := make([]Column, 0)
 | 
			
		||||
	columns := make([]dbi.Column, 0)
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		columns = append(columns, Column{
 | 
			
		||||
		columns = append(columns, dbi.Column{
 | 
			
		||||
			TableName:     re["tableName"].(string),
 | 
			
		||||
			ColumnName:    re["columnName"].(string),
 | 
			
		||||
			ColumnType:    anyx.ConvString(re["columnType"]),
 | 
			
		||||
@@ -202,15 +120,15 @@ func (pd *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取表索引信息
 | 
			
		||||
func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) {
 | 
			
		||||
	_, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
 | 
			
		||||
func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
 | 
			
		||||
	_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	indexs := make([]Index, 0)
 | 
			
		||||
	indexs := make([]dbi.Index, 0)
 | 
			
		||||
	for _, re := range res {
 | 
			
		||||
		indexs = append(indexs, Index{
 | 
			
		||||
		indexs = append(indexs, dbi.Index{
 | 
			
		||||
			IndexName:    re["indexName"].(string),
 | 
			
		||||
			ColumnName:   anyx.ConvString(re["columnName"]),
 | 
			
		||||
			IndexType:    anyx.ConvString(re["IndexType"]),
 | 
			
		||||
@@ -220,7 +138,7 @@ func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) {
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	// 把查询结果以索引名分组,索引字段以逗号连接
 | 
			
		||||
	result := make([]Index, 0)
 | 
			
		||||
	result := make([]dbi.Index, 0)
 | 
			
		||||
	key := ""
 | 
			
		||||
	for _, v := range indexs {
 | 
			
		||||
		// 当前的索引名
 | 
			
		||||
@@ -240,7 +158,7 @@ func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) {
 | 
			
		||||
 | 
			
		||||
// 获取建表ddl
 | 
			
		||||
func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
 | 
			
		||||
	_, err := pd.dc.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
 | 
			
		||||
	_, err := pd.dc.Exec(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
@@ -257,13 +175,13 @@ func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
 | 
			
		||||
	return res[0]["sql"].(string), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error {
 | 
			
		||||
func (pd *PgsqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error {
 | 
			
		||||
	return pd.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 获取pgsql当前连接的库可访问的schemaNames
 | 
			
		||||
func (pd *PgsqlDialect) GetSchemas() ([]string, error) {
 | 
			
		||||
	sql := GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS)
 | 
			
		||||
	sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS)
 | 
			
		||||
	_, res, err := pd.dc.Query(sql)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
@@ -276,27 +194,27 @@ func (pd *PgsqlDialect) GetSchemas() ([]string, error) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
 | 
			
		||||
func (pd *PgsqlDialect) GetDbProgram() DbProgram {
 | 
			
		||||
func (pd *PgsqlDialect) GetDbProgram() dbi.DbProgram {
 | 
			
		||||
	panic("implement me")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) GetDataType(dbColumnType string) DataType {
 | 
			
		||||
func (pd *PgsqlDialect) GetDataType(dbColumnType string) dbi.DataType {
 | 
			
		||||
	if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeNumber
 | 
			
		||||
		return dbi.DataTypeNumber
 | 
			
		||||
	}
 | 
			
		||||
	// 日期时间类型
 | 
			
		||||
	if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeDateTime
 | 
			
		||||
		return dbi.DataTypeDateTime
 | 
			
		||||
	}
 | 
			
		||||
	// 日期类型
 | 
			
		||||
	if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeDate
 | 
			
		||||
		return dbi.DataTypeDate
 | 
			
		||||
	}
 | 
			
		||||
	// 时间类型
 | 
			
		||||
	if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
 | 
			
		||||
		return DataTypeTime
 | 
			
		||||
		return dbi.DataTypeTime
 | 
			
		||||
	}
 | 
			
		||||
	return DataTypeString
 | 
			
		||||
	return dbi.DataTypeString
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
 | 
			
		||||
@@ -325,15 +243,15 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
 | 
			
		||||
	return pd.dc.TxExec(tx, sqlStr, args...)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PgsqlDialect) FormatStrData(dbColumnValue string, dataType DataType) string {
 | 
			
		||||
func (pd *PgsqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string {
 | 
			
		||||
	switch dataType {
 | 
			
		||||
	case DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00"
 | 
			
		||||
	case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00"
 | 
			
		||||
		res, _ := time.Parse(time.RFC3339, dbColumnValue)
 | 
			
		||||
		return res.Format(time.DateTime)
 | 
			
		||||
	case DataTypeDate: //  "2024-01-02T00:00:00Z"
 | 
			
		||||
	case dbi.DataTypeDate: //  "2024-01-02T00:00:00Z"
 | 
			
		||||
		res, _ := time.Parse(time.RFC3339, dbColumnValue)
 | 
			
		||||
		return res.Format(time.DateOnly)
 | 
			
		||||
	case DataTypeTime: // "0000-01-01T22:16:28.545075+08:00"
 | 
			
		||||
	case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00"
 | 
			
		||||
		res, _ := time.Parse(time.RFC3339, dbColumnValue)
 | 
			
		||||
		return res.Format(time.TimeOnly)
 | 
			
		||||
	}
 | 
			
		||||
							
								
								
									
										111
									
								
								server/internal/db/dbm/postgres/meta.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								server/internal/db/dbm/postgres/meta.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,111 @@
 | 
			
		||||
package postgres
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"database/sql/driver"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"mayfly-go/internal/db/dbm/dbi"
 | 
			
		||||
	machineapp "mayfly-go/internal/machine/application"
 | 
			
		||||
	"mayfly-go/pkg/utils/collx"
 | 
			
		||||
	"mayfly-go/pkg/utils/netx"
 | 
			
		||||
	"net"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	pq "gitee.com/liuzongyang/libpq"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var (
 | 
			
		||||
	meta dbi.Meta
 | 
			
		||||
	once sync.Once
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func GetMeta() dbi.Meta {
 | 
			
		||||
	once.Do(func() {
 | 
			
		||||
		meta = new(PostgresMeta)
 | 
			
		||||
	})
 | 
			
		||||
	return meta
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PostgresMeta struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *PostgresMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
 | 
			
		||||
	driverName := string(d.Type)
 | 
			
		||||
	// SSH Conect
 | 
			
		||||
	if d.SshTunnelMachineId > 0 {
 | 
			
		||||
		// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
 | 
			
		||||
		driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
 | 
			
		||||
		if !collx.ArrayContains(sql.Drivers(), driverName) {
 | 
			
		||||
			sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId})
 | 
			
		||||
		}
 | 
			
		||||
		sql.Drivers()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	db := d.Database
 | 
			
		||||
	var dbParam string
 | 
			
		||||
	exsitSchema := false
 | 
			
		||||
	if db != "" {
 | 
			
		||||
		// postgres database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema
 | 
			
		||||
		ss := strings.Split(db, "/")
 | 
			
		||||
		if len(ss) > 1 {
 | 
			
		||||
			exsitSchema = true
 | 
			
		||||
			dbParam = fmt.Sprintf("dbname=%s search_path=%s", ss[0], ss[len(ss)-1])
 | 
			
		||||
		} else {
 | 
			
		||||
			dbParam = "dbname=" + db
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s %s sslmode=disable connect_timeout=8", d.Host, d.Port, d.Username, d.Password, dbParam)
 | 
			
		||||
	// 存在额外指定参数,则拼接该连接参数
 | 
			
		||||
	if d.Params != "" {
 | 
			
		||||
		// 存在指定的db,则需要将dbInstance配置中的parmas排除掉dbname和search_path
 | 
			
		||||
		if db != "" {
 | 
			
		||||
			paramArr := strings.Split(d.Params, "&")
 | 
			
		||||
			paramArr = collx.ArrayRemoveFunc(paramArr, func(param string) bool {
 | 
			
		||||
				if strings.HasPrefix(param, "dbname=") {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
				if exsitSchema && strings.HasPrefix(param, "search_path") {
 | 
			
		||||
					return true
 | 
			
		||||
				}
 | 
			
		||||
				return false
 | 
			
		||||
			})
 | 
			
		||||
			d.Params = strings.Join(paramArr, " ")
 | 
			
		||||
		}
 | 
			
		||||
		dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sql.Open(driverName, dsn)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (md *PostgresMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
 | 
			
		||||
	return &PgsqlDialect{conn}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// pgsql dialer
 | 
			
		||||
type PqSqlDialer struct {
 | 
			
		||||
	sshTunnelMachineId int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
 | 
			
		||||
	return pq.DialOpen(d, name)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
 | 
			
		||||
	sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	if sshConn, err := sshTunnel.GetDialConn("tcp", address); err == nil {
 | 
			
		||||
		// 将ssh conn包装,否则会返回错误: ssh: tcpChan: deadline not supported
 | 
			
		||||
		return &netx.WrapSshConn{Conn: sshConn}, nil
 | 
			
		||||
	} else {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
 | 
			
		||||
	return pd.Dial(network, address)
 | 
			
		||||
}
 | 
			
		||||
@@ -90,7 +90,6 @@ func (m *MachineFile) ReadFileContent(rc *req.Ctx) {
 | 
			
		||||
	g := rc.GinCtx
 | 
			
		||||
	fid := GetMachineFileId(g)
 | 
			
		||||
	readPath := g.Query("path")
 | 
			
		||||
	readType := g.Query("type")
 | 
			
		||||
 | 
			
		||||
	sftpFile, mi, err := m.MachineFileApp.ReadFile(fid, readPath)
 | 
			
		||||
	rc.ReqParam = collx.Kvs("machine", mi, "path", readPath)
 | 
			
		||||
@@ -99,20 +98,27 @@ func (m *MachineFile) ReadFileContent(rc *req.Ctx) {
 | 
			
		||||
 | 
			
		||||
	fileInfo, _ := sftpFile.Stat()
 | 
			
		||||
	filesize := fileInfo.Size()
 | 
			
		||||
	// 如果是读取文件内容,则校验文件大小
 | 
			
		||||
	if readType != "1" {
 | 
			
		||||
 | 
			
		||||
	biz.IsTrue(filesize < max_read_size, "文件超过1m,请使用下载查看")
 | 
			
		||||
	}
 | 
			
		||||
	// 如果读取类型为下载,则下载文件,否则获取文件内容
 | 
			
		||||
	if readType == "1" {
 | 
			
		||||
	datas, err := io.ReadAll(sftpFile)
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "读取文件内容失败: %s")
 | 
			
		||||
 | 
			
		||||
	rc.ResData = string(datas)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *MachineFile) DownloadFile(rc *req.Ctx) {
 | 
			
		||||
	g := rc.GinCtx
 | 
			
		||||
	fid := GetMachineFileId(g)
 | 
			
		||||
	readPath := g.Query("path")
 | 
			
		||||
 | 
			
		||||
	sftpFile, mi, err := m.MachineFileApp.ReadFile(fid, readPath)
 | 
			
		||||
	rc.ReqParam = collx.Kvs("machine", mi, "path", readPath)
 | 
			
		||||
	biz.ErrIsNilAppendErr(err, "打开文件失败: %s")
 | 
			
		||||
	defer sftpFile.Close()
 | 
			
		||||
 | 
			
		||||
	// 截取文件名,如/usr/local/test.java -》 test.java
 | 
			
		||||
	path := strings.Split(readPath, "/")
 | 
			
		||||
	rc.Download(sftpFile, path[len(path)-1])
 | 
			
		||||
	} else {
 | 
			
		||||
		datas, err := io.ReadAll(sftpFile)
 | 
			
		||||
		biz.ErrIsNilAppendErr(err, "读取文件内容失败: %s")
 | 
			
		||||
		rc.ResData = string(datas)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (m *MachineFile) GetDirEntry(rc *req.Ctx) {
 | 
			
		||||
 
 | 
			
		||||
@@ -25,7 +25,9 @@ func InitMachineFileRouter(router *gin.RouterGroup) {
 | 
			
		||||
 | 
			
		||||
		req.NewDelete(":machineId/files/:fileId", mf.DeleteFile).Log(req.NewLogSave("机器-删除文件配置")).RequiredPermissionCode("machine:file:del"),
 | 
			
		||||
 | 
			
		||||
		req.NewGet(":machineId/files/:fileId/read", mf.ReadFileContent).Log(req.NewLogSave("机器-获取文件内容")),
 | 
			
		||||
		req.NewGet(":machineId/files/:fileId/read", mf.ReadFileContent).Log(req.NewLogSave("机器-读取文件内容")),
 | 
			
		||||
 | 
			
		||||
		req.NewGet(":machineId/files/:fileId/download", mf.DownloadFile).NoRes().Log(req.NewLogSave("机器-文件下载")),
 | 
			
		||||
 | 
			
		||||
		req.NewGet(":machineId/files/:fileId/read-dir", mf.GetDirEntry),
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user