From bfd346e65a1b1c50090c6443d71ef64fabfc4d4a Mon Sep 17 00:00:00 2001 From: "meilin.huang" <954537473@qq.com> Date: Fri, 12 Jan 2024 13:15:30 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=9C=BA=E5=99=A8=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E4=B8=8B=E8=BD=BD=E9=97=AE=E9=A2=98=E4=BF=AE=E5=A4=8D&dbm?= =?UTF-8?q?=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mayfly_go_web/package.json | 6 +- mayfly_go_web/src/views/ops/machine/api.ts | 1 + .../views/ops/machine/file/MachineFile.vue | 2 +- server/internal/db/api/db.go | 6 +- server/internal/db/application/db.go | 17 +- .../internal/db/application/db_data_sync.go | 12 +- .../internal/db/application/db_scheduler.go | 8 +- server/internal/db/application/db_sql_exec.go | 10 +- server/internal/db/application/instance.go | 7 +- server/internal/db/dbm/{ => dbi}/conn.go | 17 +- .../internal/db/dbm/{ => dbi}/db_program.go | 2 +- server/internal/db/dbm/{ => dbi}/db_type.go | 2 +- .../internal/db/dbm/{ => dbi}/db_type_test.go | 2 +- server/internal/db/dbm/{ => dbi}/dialect.go | 7 +- server/internal/db/dbm/{ => dbi}/info.go | 29 ++-- server/internal/db/dbm/dbi/meta.go | 12 ++ .../db/dbm/{ => dbi}/metasql/dm_meta.sql | 0 .../db/dbm/{ => dbi}/metasql/mysql_meta.sql | 0 .../db/dbm/{ => dbi}/metasql/pgsql_meta.sql | 0 server/internal/db/dbm/{ => dbi}/sqlx.go | 2 +- .../internal/db/dbm/{conn_cache.go => dbm.go} | 45 +++++- .../db/dbm/{dialect_dm.go => dm/dialect.go} | 85 ++++------ server/internal/db/dbm/dm/meta.go | 51 ++++++ .../{dialect_mysql.go => mysql/dialect.go} | 80 ++++------ server/internal/db/dbm/mysql/meta.go | 52 +++++++ .../{db_program_mysql.go => mysql/program.go} | 17 +- .../program_e2e_test.go} | 2 +- .../program_test.go} | 5 +- .../{dialect_pgsql.go => postgres/dialect.go} | 146 ++++-------------- server/internal/db/dbm/postgres/meta.go | 111 +++++++++++++ server/internal/machine/api/machine_file.go | 36 +++-- .../internal/machine/router/machine_file.go | 4 +- 32 files changed, 454 insertions(+), 322 deletions(-) rename server/internal/db/dbm/{ => dbi}/conn.go (95%) rename server/internal/db/dbm/{ => dbi}/db_program.go (98%) rename server/internal/db/dbm/{ => dbi}/db_type.go (99%) rename server/internal/db/dbm/{ => dbi}/db_type_test.go (99%) rename server/internal/db/dbm/{ => dbi}/dialect.go (99%) rename server/internal/db/dbm/{ => dbi}/info.go (85%) create mode 100644 server/internal/db/dbm/dbi/meta.go rename server/internal/db/dbm/{ => dbi}/metasql/dm_meta.sql (100%) rename server/internal/db/dbm/{ => dbi}/metasql/mysql_meta.sql (100%) rename server/internal/db/dbm/{ => dbi}/metasql/pgsql_meta.sql (100%) rename server/internal/db/dbm/{ => dbi}/sqlx.go (99%) rename server/internal/db/dbm/{conn_cache.go => dbm.go} (63%) rename server/internal/db/dbm/{dialect_dm.go => dm/dialect.go} (79%) create mode 100644 server/internal/db/dbm/dm/meta.go rename server/internal/db/dbm/{dialect_mysql.go => mysql/dialect.go} (72%) create mode 100644 server/internal/db/dbm/mysql/meta.go rename server/internal/db/dbm/{db_program_mysql.go => mysql/program.go} (98%) rename server/internal/db/dbm/{db_program_mysql_e2e_test.go => mysql/program_e2e_test.go} (99%) rename server/internal/db/dbm/{db_program_mysql_test.go => mysql/program_test.go} (97%) rename server/internal/db/dbm/{dialect_pgsql.go => postgres/dialect.go} (60%) create mode 100644 server/internal/db/dbm/postgres/meta.go diff --git a/mayfly_go_web/package.json b/mayfly_go_web/package.json index 0006db11..174d2b1f 100644 --- a/mayfly_go_web/package.json +++ b/mayfly_go_web/package.json @@ -17,7 +17,7 @@ "countup.js": "^2.7.0", "cropperjs": "^1.5.11", "echarts": "^5.4.3", - "element-plus": "^2.5.0", + "element-plus": "^2.5.1", "js-base64": "^3.7.5", "jsencrypt": "^3.3.2", "lodash": "^4.17.21", @@ -33,7 +33,7 @@ "splitpanes": "^3.1.5", "sql-formatter": "^14.0.0", "uuid": "^9.0.1", - "vue": "^3.4.8", + "vue": "^3.4.10", "vue-router": "^4.2.5", "xterm": "^5.3.0", "xterm-addon-fit": "^0.8.0", @@ -42,7 +42,7 @@ }, "devDependencies": { "@types/lodash": "^4.14.178", - "@types/node": "^15.6.0", + "@types/node": "^18.14.0", "@types/nprogress": "^0.2.0", "@types/sortablejs": "^1.15.3", "@typescript-eslint/eslint-plugin": "^6.7.4", diff --git a/mayfly_go_web/src/views/ops/machine/api.ts b/mayfly_go_web/src/views/ops/machine/api.ts index 067e136e..10d143d4 100644 --- a/mayfly_go_web/src/views/ops/machine/api.ts +++ b/mayfly_go_web/src/views/ops/machine/api.ts @@ -35,6 +35,7 @@ export const machineApi = { mvFile: Api.newPost('/machines/{machineId}/files/{fileId}/mv'), uploadFile: Api.newPost('/machines/{machineId}/files/{fileId}/upload?' + joinClientParams()), fileContent: Api.newGet('/machines/{machineId}/files/{fileId}/read'), + downloadFile: Api.newGet('/machines/{machineId}/files/{fileId}/download'), createFile: Api.newPost('/machines/{machineId}/files/{id}/create-file'), // 修改文件内容 updateFileContent: Api.newPost('/machines/{machineId}/files/{id}/write'), diff --git a/mayfly_go_web/src/views/ops/machine/file/MachineFile.vue b/mayfly_go_web/src/views/ops/machine/file/MachineFile.vue index 2b3eceae..6867fe17 100755 --- a/mayfly_go_web/src/views/ops/machine/file/MachineFile.vue +++ b/mayfly_go_web/src/views/ops/machine/file/MachineFile.vue @@ -611,7 +611,7 @@ const deleteFile = async (files: any) => { const downloadFile = (data: any) => { const a = document.createElement('a'); - a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/read?type=1&path=${data.path}&${joinClientParams()}`); + a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/download?path=${data.path}&${joinClientParams()}`); a.click(); }; diff --git a/server/internal/db/api/db.go b/server/internal/db/api/db.go index 9b2509b9..b783fa97 100644 --- a/server/internal/db/api/db.go +++ b/server/internal/db/api/db.go @@ -9,7 +9,7 @@ import ( "mayfly-go/internal/db/api/form" "mayfly-go/internal/db/api/vo" "mayfly-go/internal/db/application" - "mayfly-go/internal/db/dbm" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" msgapp "mayfly-go/internal/msg/application" msgdto "mayfly-go/internal/msg/application/dto" @@ -351,7 +351,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table)) writer.WriteString("BEGIN;\n") insertSql := "INSERT INTO %s VALUES (%s);\n" - dbMeta.WalkTableRecord(table, func(record map[string]any, columns []*dbm.QueryColumn) error { + dbMeta.WalkTableRecord(table, func(record map[string]any, columns []*dbi.QueryColumn) error { var values []string writer.TryFlush() for _, column := range columns { @@ -470,7 +470,7 @@ func getDbName(g *gin.Context) string { return db } -func (d *Db) getDbConn(g *gin.Context) *dbm.DbConn { +func (d *Db) getDbConn(g *gin.Context) *dbi.DbConn { dc, err := d.DbApp.GetDbConn(getDbId(g), getDbName(g)) biz.ErrIsNil(err) return dc diff --git a/server/internal/db/application/db.go b/server/internal/db/application/db.go index 53995309..773e959f 100644 --- a/server/internal/db/application/db.go +++ b/server/internal/db/application/db.go @@ -4,6 +4,7 @@ import ( "context" "mayfly-go/internal/common/consts" "mayfly-go/internal/db/dbm" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" tagapp "mayfly-go/internal/tag/application" @@ -33,10 +34,10 @@ type Db interface { // @param id 数据库id // // @param dbName 数据库名 - GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) + GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) // 根据数据库实例id获取连接,随机返回该instanceId下已连接的conn,若不存在则是使用该instanceId关联的db进行连接并返回。 - GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error) + GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error) } func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql, dbInstanceApp Instance, tagApp tagapp.TagTree) Db { @@ -142,8 +143,8 @@ func (d *dbAppImpl) Delete(ctx context.Context, id uint64) error { }) } -func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) { - return dbm.GetDbConn(dbId, dbName, func() (*dbm.DbInfo, error) { +func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) { + return dbm.GetDbConn(dbId, dbName, func() (*dbi.DbInfo, error) { db, err := d.GetById(new(entity.Db), dbId) if err != nil { return nil, errorx.NewBiz("数据库信息不存在") @@ -156,7 +157,7 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) { checkDb := dbName // 兼容pgsql/dm db/schema模式 - if dbm.DbTypePostgres.Equal(instance.Type) || dbm.DbTypeDM.Equal(instance.Type) { + if dbi.DbTypePostgres.Equal(instance.Type) || dbi.DbTypeDM.Equal(instance.Type) { ss := strings.Split(dbName, "/") if len(ss) > 1 { checkDb = ss[0] @@ -174,7 +175,7 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) { }) } -func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error) { +func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error) { conn := dbm.GetDbConnByInstanceId(instanceId) if conn != nil { return conn, nil @@ -193,8 +194,8 @@ func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error return d.GetDbConn(firstDb.Id, strings.Split(firstDb.Database, " ")[0]) } -func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath ...string) *dbm.DbInfo { - di := new(dbm.DbInfo) +func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath ...string) *dbi.DbInfo { + di := new(dbi.DbInfo) di.InstanceId = instance.Id di.Id = dbId di.Database = database diff --git a/server/internal/db/application/db_data_sync.go b/server/internal/db/application/db_data_sync.go index 7fda75ec..13361d3d 100644 --- a/server/internal/db/application/db_data_sync.go +++ b/server/internal/db/application/db_data_sync.go @@ -5,7 +5,7 @@ import ( "database/sql" "encoding/json" "fmt" - "mayfly-go/internal/db/dbm" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" "mayfly-go/pkg/base" @@ -182,20 +182,20 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* if err != nil { return syncLog, errorx.NewBiz("解析字段映射json出错: %s", err.Error()) } - var updFieldType dbm.DataType + var updFieldType dbi.DataType // 记录本次同步数据总数 total := 0 batchSize := task.PageSize result := make([]map[string]any, 0) - var queryColumns []*dbm.QueryColumn + var queryColumns []*dbi.QueryColumn - err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbm.QueryColumn) error { + err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbi.QueryColumn) error { if len(queryColumns) == 0 { queryColumns = columns // 遍历columns 取task.UpdField的字段类型 - updFieldType = dbm.DataTypeString + updFieldType = dbi.DataTypeString for _, column := range columns { if column.Name == task.UpdField { updFieldType = srcDialect.GetDataType(column.Type) @@ -249,7 +249,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* return syncLog, nil } -func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbm.DataType, task *entity.DataSyncTask, srcDialect dbm.DbDialect, targetDbConn *dbm.DbConn, targetDbTx *sql.Tx) error { +func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbi.DataType, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error { var data = make([]map[string]any, 0) // 遍历res,组装插入sql diff --git a/server/internal/db/application/db_scheduler.go b/server/internal/db/application/db_scheduler.go index 22039075..52ca044d 100644 --- a/server/internal/db/application/db_scheduler.go +++ b/server/internal/db/application/db_scheduler.go @@ -4,7 +4,7 @@ import ( "context" "errors" "fmt" - "mayfly-go/internal/db/dbm" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" "mayfly-go/pkg/logx" @@ -314,7 +314,7 @@ func (s *dbScheduler) runnable(job entity.DbJob, next runner.NextFunc) bool { return true } -func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProgram, job *entity.DbRestore) error { +func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbi.DbProgram, job *entity.DbRestore) error { binlogHistory, err := s.binlogHistoryRepo.GetHistoryByTime(job.DbInstanceId, job.PointInTime.Time) if err != nil { return err @@ -341,7 +341,7 @@ func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProg if err != nil { return err } - restoreInfo := &dbm.RestoreInfo{ + restoreInfo := &dbi.RestoreInfo{ BackupHistory: backupHistory, BinlogHistories: binlogHistories, StartPosition: backupHistory.BinlogPosition, @@ -354,7 +354,7 @@ func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProg return program.ReplayBinlog(ctx, job.DbName, job.DbName, restoreInfo) } -func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbm.DbProgram, job *entity.DbRestore) error { +func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbi.DbProgram, job *entity.DbRestore) error { backupHistory := &entity.DbBackupHistory{} if err := s.backupHistoryRepo.GetById(backupHistory, job.DbBackupHistoryId); err != nil { return err diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index ced35eb6..114dc9ac 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -4,7 +4,7 @@ import ( "context" "fmt" "mayfly-go/internal/db/config" - "mayfly-go/internal/db/dbm" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" "mayfly-go/pkg/contextx" @@ -22,11 +22,11 @@ type DbSqlExecReq struct { Db string Sql string Remark string - DbConn *dbm.DbConn + DbConn *dbi.DbConn } type DbSqlExecRes struct { - Columns []*dbm.QueryColumn + Columns []*dbi.QueryColumn Res []map[string]any } @@ -269,7 +269,7 @@ func doInsert(ctx context.Context, insert *sqlparser.Insert, execSqlReq *DbSqlEx return doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn) } -func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) { +func doExec(ctx context.Context, sql string, dbConn *dbi.DbConn) (*DbSqlExecRes, error) { rowsAffected, err := dbConn.ExecContext(ctx, sql) execRes := "success" if err != nil { @@ -283,7 +283,7 @@ func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, res = append(res, resData) return &DbSqlExecRes{ - Columns: []*dbm.QueryColumn{ + Columns: []*dbi.QueryColumn{ {Name: "sql", Type: "string"}, {Name: "rowsAffected", Type: "number"}, {Name: "result", Type: "string"}, diff --git a/server/internal/db/application/instance.go b/server/internal/db/application/instance.go index 40dde07a..d8162c00 100644 --- a/server/internal/db/application/instance.go +++ b/server/internal/db/application/instance.go @@ -3,6 +3,7 @@ package application import ( "context" "mayfly-go/internal/db/dbm" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" "mayfly-go/internal/db/domain/repository" "mayfly-go/pkg/base" @@ -50,7 +51,7 @@ func (app *instanceAppImpl) Count(condition *entity.InstanceQuery) int64 { func (app *instanceAppImpl) TestConn(instanceEntity *entity.DbInstance) error { instanceEntity.Network = instanceEntity.GetNetwork() - dbConn, err := toDbInfo(instanceEntity, 0, "", "").Conn() + dbConn, err := dbm.Conn(toDbInfo(instanceEntity, 0, "", "")) if err != nil { return err } @@ -100,9 +101,9 @@ func (app *instanceAppImpl) Delete(ctx context.Context, id uint64) error { func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance) ([]string, error) { ed.Network = ed.GetNetwork() - metaDb := dbm.ToDbType(ed.Type).MetaDbName() + metaDb := dbi.ToDbType(ed.Type).MetaDbName() - dbConn, err := toDbInfo(ed, 0, metaDb, "").Conn() + dbConn, err := dbm.Conn(toDbInfo(ed, 0, metaDb, "")) if err != nil { return nil, err } diff --git a/server/internal/db/dbm/conn.go b/server/internal/db/dbm/dbi/conn.go similarity index 95% rename from server/internal/db/dbm/conn.go rename to server/internal/db/dbm/dbi/conn.go index ee3b7333..63be1dfa 100644 --- a/server/internal/db/dbm/conn.go +++ b/server/internal/db/dbm/dbi/conn.go @@ -1,4 +1,4 @@ -package dbm +package dbi import ( "context" @@ -117,18 +117,9 @@ func (d *DbConn) Begin() (*sql.Tx, error) { return d.db.Begin() } -// 获取数据库元信息实现接口 -func (d *DbConn) GetDialect() DbDialect { - switch d.Info.Type { - case DbTypeMysql, DbTypeMariadb: - return &MysqlDialect{dc: d} - case DbTypePostgres: - return &PgsqlDialect{dc: d} - case DbTypeDM: - return &DMDialect{dc: d} - default: - panic(fmt.Sprintf("invalid database type: %s", d.Info.Type)) - } +// 获取数据库dialect实现接口 +func (d *DbConn) GetDialect() Dialect { + return d.Info.Meta.GetDialect(d) } // 关闭连接 diff --git a/server/internal/db/dbm/db_program.go b/server/internal/db/dbm/dbi/db_program.go similarity index 98% rename from server/internal/db/dbm/db_program.go rename to server/internal/db/dbm/dbi/db_program.go index f8c7d213..938e4dd8 100644 --- a/server/internal/db/dbm/db_program.go +++ b/server/internal/db/dbm/dbi/db_program.go @@ -1,4 +1,4 @@ -package dbm +package dbi import ( "context" diff --git a/server/internal/db/dbm/db_type.go b/server/internal/db/dbm/dbi/db_type.go similarity index 99% rename from server/internal/db/dbm/db_type.go rename to server/internal/db/dbm/dbi/db_type.go index 81d7a253..0258f68e 100644 --- a/server/internal/db/dbm/db_type.go +++ b/server/internal/db/dbm/dbi/db_type.go @@ -1,4 +1,4 @@ -package dbm +package dbi import ( "fmt" diff --git a/server/internal/db/dbm/db_type_test.go b/server/internal/db/dbm/dbi/db_type_test.go similarity index 99% rename from server/internal/db/dbm/db_type_test.go rename to server/internal/db/dbm/dbi/db_type_test.go index dca23f42..197c092c 100644 --- a/server/internal/db/dbm/db_type_test.go +++ b/server/internal/db/dbm/dbi/db_type_test.go @@ -1,4 +1,4 @@ -package dbm +package dbi import ( "testing" diff --git a/server/internal/db/dbm/dialect.go b/server/internal/db/dbm/dbi/dialect.go similarity index 99% rename from server/internal/db/dbm/dialect.go rename to server/internal/db/dbm/dbi/dialect.go index 020c16f9..cff9e7d4 100644 --- a/server/internal/db/dbm/dialect.go +++ b/server/internal/db/dbm/dbi/dialect.go @@ -1,12 +1,13 @@ -package dbm +package dbi import ( "database/sql" "embed" + "strings" + "mayfly-go/pkg/biz" "mayfly-go/pkg/utils/collx" "mayfly-go/pkg/utils/stringx" - "strings" ) type DataType string @@ -60,7 +61,7 @@ type Index struct { // -----------------------------------元数据接口定义------------------------------------------ // 数据库方言、元信息接口(表、列、获取表数据等元信息) -type DbDialect interface { +type Dialect interface { // 获取数据库服务实例信息 GetDbServer() (*DbServer, error) diff --git a/server/internal/db/dbm/info.go b/server/internal/db/dbm/dbi/info.go similarity index 85% rename from server/internal/db/dbm/info.go rename to server/internal/db/dbm/dbi/info.go index 99c393dd..5811b43b 100644 --- a/server/internal/db/dbm/info.go +++ b/server/internal/db/dbm/dbi/info.go @@ -1,4 +1,4 @@ -package dbm +package dbi import ( "database/sql" @@ -8,6 +8,9 @@ import ( "mayfly-go/pkg/logx" ) +// 获取sql.DB函数 +type GetSqlDbFunc func(*DbInfo) (*sql.DB, error) + type DbInfo struct { InstanceId uint64 // 实例id Id uint64 // dbId @@ -24,6 +27,8 @@ type DbInfo struct { TagPath []string SshTunnelMachineId int + + Meta Meta } // 获取记录日志的描述 @@ -32,22 +37,16 @@ func (d *DbInfo) GetLogDesc() string { } // 连接数据库 -func (dbInfo *DbInfo) Conn() (*DbConn, error) { - var conn *sql.DB - var err error - database := dbInfo.Database - - switch dbInfo.Type { - case DbTypeMysql, DbTypeMariadb: - conn, err = getMysqlDB(dbInfo) - case DbTypePostgres: - conn, err = getPgsqlDB(dbInfo) - case DbTypeDM: - conn, err = getDmDB(dbInfo) - default: - return nil, errorx.NewBiz("invalid database type: %s", dbInfo.Type) +func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) { + if meta == nil { + return nil, errorx.NewBiz("数据库元信息接口不能为空") } + // 赋值Meta,方便后续获取dialect等 + dbInfo.Meta = meta + database := dbInfo.Database + + conn, err := meta.GetSqlDb(dbInfo) if err != nil { logx.Errorf("连接db失败: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error()) return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error())) diff --git a/server/internal/db/dbm/dbi/meta.go b/server/internal/db/dbm/dbi/meta.go new file mode 100644 index 00000000..5f942957 --- /dev/null +++ b/server/internal/db/dbm/dbi/meta.go @@ -0,0 +1,12 @@ +package dbi + +import "database/sql" + +// 数据库元信息获取,如获取sql.DB、Dialect等 +type Meta interface { + // 获取数据库服务实例信息 + GetSqlDb(*DbInfo) (*sql.DB, error) + + // 获取数据库方言,用于获取表结构等信息 + GetDialect(*DbConn) Dialect +} diff --git a/server/internal/db/dbm/metasql/dm_meta.sql b/server/internal/db/dbm/dbi/metasql/dm_meta.sql similarity index 100% rename from server/internal/db/dbm/metasql/dm_meta.sql rename to server/internal/db/dbm/dbi/metasql/dm_meta.sql diff --git a/server/internal/db/dbm/metasql/mysql_meta.sql b/server/internal/db/dbm/dbi/metasql/mysql_meta.sql similarity index 100% rename from server/internal/db/dbm/metasql/mysql_meta.sql rename to server/internal/db/dbm/dbi/metasql/mysql_meta.sql diff --git a/server/internal/db/dbm/metasql/pgsql_meta.sql b/server/internal/db/dbm/dbi/metasql/pgsql_meta.sql similarity index 100% rename from server/internal/db/dbm/metasql/pgsql_meta.sql rename to server/internal/db/dbm/dbi/metasql/pgsql_meta.sql diff --git a/server/internal/db/dbm/sqlx.go b/server/internal/db/dbm/dbi/sqlx.go similarity index 99% rename from server/internal/db/dbm/sqlx.go rename to server/internal/db/dbm/dbi/sqlx.go index 66e2154f..33c1c92a 100644 --- a/server/internal/db/dbm/sqlx.go +++ b/server/internal/db/dbm/dbi/sqlx.go @@ -1,4 +1,4 @@ -package dbm +package dbi import ( "database/sql" diff --git a/server/internal/db/dbm/conn_cache.go b/server/internal/db/dbm/dbm.go similarity index 63% rename from server/internal/db/dbm/conn_cache.go rename to server/internal/db/dbm/dbm.go index 99677215..25a4e53d 100644 --- a/server/internal/db/dbm/conn_cache.go +++ b/server/internal/db/dbm/dbm.go @@ -3,6 +3,10 @@ package dbm import ( "fmt" "mayfly-go/internal/common/consts" + "mayfly-go/internal/db/dbm/dbi" + "mayfly-go/internal/db/dbm/dm" + "mayfly-go/internal/db/dbm/mysql" + "mayfly-go/internal/db/dbm/postgres" "mayfly-go/internal/machine/mcm" "mayfly-go/pkg/cache" "mayfly-go/pkg/logx" @@ -15,7 +19,7 @@ var connCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second). WithUpdateAccessTime(true). OnEvicted(func(key any, value any) { logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key)) - value.(*DbConn).Close() + value.(*dbi.DbConn).Close() }) func init() { @@ -23,7 +27,7 @@ func init() { // 遍历所有db连接实例,若存在db实例使用该ssh隧道机器,则返回true,表示还在使用中... items := connCache.Items() for _, v := range items { - if v.Value.(*DbConn).Info.SshTunnelMachineId == machineId { + if v.Value.(*dbi.DbConn).Info.SshTunnelMachineId == machineId { return true } } @@ -33,8 +37,21 @@ func init() { var mutex sync.Mutex +func getDbMetaByType(dt dbi.DbType) dbi.Meta { + switch dt { + case dbi.DbTypeMysql, dbi.DbTypeMariadb: + return mysql.GetMeta() + case dbi.DbTypePostgres: + return postgres.GetMeta() + case dbi.DbTypeDM: + return dm.GetMeta() + default: + panic(fmt.Sprintf("invalid database type: %s", dt)) + } +} + // 从缓存中获取数据库连接信息,若缓存中不存在则会使用回调函数获取dbInfo进行连接并缓存 -func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error)) (*DbConn, error) { +func GetDbConn(dbId uint64, database string, getDbInfo func() (*dbi.DbInfo, error)) (*dbi.DbConn, error) { connId := GetDbConnId(dbId, database) // connId不为空,则为需要缓存 @@ -42,7 +59,7 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error)) if needCache { load, ok := connCache.Get(connId) if ok { - return load.(*DbConn), nil + return load.(*dbi.DbConn), nil } } @@ -56,7 +73,7 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error)) } // 连接数据库 - dbConn, err := dbInfo.Conn() + dbConn, err := Conn(dbInfo) if err != nil { return nil, err } @@ -67,10 +84,15 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error)) return dbConn, nil } +// 使用指定dbInfo信息进行连接 +func Conn(di *dbi.DbInfo) (*dbi.DbConn, error) { + return di.Conn(getDbMetaByType(di.Type)) +} + // 根据实例id获取连接 -func GetDbConnByInstanceId(instanceId uint64) *DbConn { +func GetDbConnByInstanceId(instanceId uint64) *dbi.DbConn { for _, connItem := range connCache.Items() { - conn := connItem.Value.(*DbConn) + conn := connItem.Value.(*dbi.DbConn) if conn.Info.InstanceId == instanceId { return conn } @@ -82,3 +104,12 @@ func GetDbConnByInstanceId(instanceId uint64) *DbConn { func CloseDb(dbId uint64, db string) { connCache.Delete(GetDbConnId(dbId, db)) } + +// 获取连接id +func GetDbConnId(dbId uint64, db string) string { + if dbId == 0 { + return "" + } + + return fmt.Sprintf("%d:%s", dbId, db) +} diff --git a/server/internal/db/dbm/dialect_dm.go b/server/internal/db/dbm/dm/dialect.go similarity index 79% rename from server/internal/db/dbm/dialect_dm.go rename to server/internal/db/dbm/dm/dialect.go index c90fd580..0f174c24 100644 --- a/server/internal/db/dbm/dialect_dm.go +++ b/server/internal/db/dbm/dm/dialect.go @@ -1,9 +1,10 @@ -package dbm +package dm import ( "context" "database/sql" "fmt" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/errorx" "mayfly-go/pkg/logx" "mayfly-go/pkg/utils/anyx" @@ -14,30 +15,6 @@ import ( _ "gitee.com/chunanyong/dm" ) -func getDmDB(d *DbInfo) (*sql.DB, error) { - driverName := "dm" - db := d.Database - var dbParam string - if db != "" { - // dm database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema - ss := strings.Split(db, "/") - if len(ss) > 1 { - dbParam = fmt.Sprintf("%s?schema=%s", ss[0], ss[len(ss)-1]) - } else { - dbParam = db - } - } - - err := d.IfUseSshTunnelChangeIpPort() - if err != nil { - return nil, err - } - - dsn := fmt.Sprintf("dm://%s:%s@%s:%d/%s", d.Username, d.Password, d.Host, d.Port, dbParam) - return sql.Open(driverName, dsn) -} - -// ---------------------------------- DM元数据 ----------------------------------- const ( DM_META_FILE = "metasql/dm_meta.sql" DM_DB_SCHEMAS = "DM_DB_SCHEMAS" @@ -47,15 +24,15 @@ const ( ) type DMDialect struct { - dc *DbConn + dc *dbi.DbConn } -func (dd *DMDialect) GetDbServer() (*DbServer, error) { +func (dd *DMDialect) GetDbServer() (*dbi.DbServer, error) { _, res, err := dd.dc.Query("select * from v$instance") if err != nil { return nil, err } - ds := &DbServer{ + ds := &dbi.DbServer{ Version: anyx.ConvString(res[0]["SVR_VERSION"]), } return ds, nil @@ -76,20 +53,20 @@ func (dd *DMDialect) GetDbNames() ([]string, error) { } // 获取表基础元信息, 如表名等 -func (dd *DMDialect) GetTables() ([]Table, error) { +func (dd *DMDialect) GetTables() ([]dbi.Table, error) { // 首先执行更新统计信息sql 这个统计信息在数据量比较大的时候就比较耗时,所以最好定时执行 // _, _, err := pd.dc.Query("dbms_stats.GATHER_SCHEMA_stats(SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))") // 查询表信息 - _, res, err := dd.dc.Query(GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY)) + _, res, err := dd.dc.Query(dbi.GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY)) if err != nil { return nil, err } - tables := make([]Table, 0) + tables := make([]dbi.Table, 0) for _, re := range res { - tables = append(tables, Table{ + tables = append(tables, dbi.Table{ TableName: re["TABLE_NAME"].(string), TableComment: anyx.ConvString(re["TABLE_COMMENT"]), CreateTime: anyx.ConvString(re["CREATE_TIME"]), @@ -102,7 +79,7 @@ func (dd *DMDialect) GetTables() ([]Table, error) { } // 获取列元信息, 如列名等 -func (dd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) { +func (dd *DMDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) { tableName := "" for i := 0; i < len(tableNames); i++ { if i != 0 { @@ -111,14 +88,14 @@ func (dd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) { tableName = tableName + "'" + tableNames[i] + "'" } - _, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName)) + _, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName)) if err != nil { return nil, err } - columns := make([]Column, 0) + columns := make([]dbi.Column, 0) for _, re := range res { - columns = append(columns, Column{ + columns = append(columns, dbi.Column{ TableName: re["TABLE_NAME"].(string), ColumnName: re["COLUMN_NAME"].(string), ColumnType: anyx.ConvString(re["COLUMN_TYPE"]), @@ -150,15 +127,15 @@ func (dd *DMDialect) GetPrimaryKey(tablename string) (string, error) { } // 获取表索引信息 -func (dd *DMDialect) GetTableIndex(tableName string) ([]Index, error) { - _, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName)) +func (dd *DMDialect) GetTableIndex(tableName string) ([]dbi.Index, error) { + _, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName)) if err != nil { return nil, err } - indexs := make([]Index, 0) + indexs := make([]dbi.Index, 0) for _, re := range res { - indexs = append(indexs, Index{ + indexs = append(indexs, dbi.Index{ IndexName: re["INDEX_NAME"].(string), ColumnName: anyx.ConvString(re["COLUMN_NAME"]), IndexType: anyx.ConvString(re["INDEX_TYPE"]), @@ -168,7 +145,7 @@ func (dd *DMDialect) GetTableIndex(tableName string) ([]Index, error) { }) } // 把查询结果以索引名分组,索引字段以逗号连接 - result := make([]Index, 0) + result := make([]dbi.Index, 0) key := "" for _, v := range indexs { // 当前的索引名 @@ -255,13 +232,13 @@ func (dd *DMDialect) GetTableDDL(tableName string) (string, error) { return builder.String(), nil } -func (dd *DMDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error { +func (dd *DMDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error { return dd.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn) } // 获取DM当前连接的库可访问的schemaNames func (dd *DMDialect) GetSchemas() ([]string, error) { - sql := GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS) + sql := dbi.GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS) _, res, err := dd.dc.Query(sql) if err != nil { return nil, err @@ -274,27 +251,27 @@ func (dd *DMDialect) GetSchemas() ([]string, error) { } // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 -func (dd *DMDialect) GetDbProgram() DbProgram { +func (dd *DMDialect) GetDbProgram() dbi.DbProgram { panic("implement me") } -func (dd *DMDialect) GetDataType(dbColumnType string) DataType { +func (dd *DMDialect) GetDataType(dbColumnType string) dbi.DataType { if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { - return DataTypeNumber + return dbi.DataTypeNumber } // 日期时间类型 if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) { - return DataTypeDateTime + return dbi.DataTypeDateTime } // 日期类型 if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) { - return DataTypeDate + return dbi.DataTypeDate } // 时间类型 if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) { - return DataTypeTime + return dbi.DataTypeTime } - return DataTypeString + return dbi.DataTypeString } func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { @@ -322,15 +299,15 @@ func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, return int64(effRows), nil } -func (dd *DMDialect) FormatStrData(dbColumnValue string, dataType DataType) string { +func (dd *DMDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { switch dataType { - case DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" + case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00" res, _ := time.Parse(time.RFC3339, dbColumnValue) return res.Format(time.DateTime) - case DataTypeDate: // "2024-01-02T00:00:00+08:00" + case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00" res, _ := time.Parse(time.RFC3339, dbColumnValue) return res.Format(time.DateOnly) - case DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" + case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00" res, _ := time.Parse(time.RFC3339, dbColumnValue) return res.Format(time.TimeOnly) } diff --git a/server/internal/db/dbm/dm/meta.go b/server/internal/db/dbm/dm/meta.go new file mode 100644 index 00000000..148aeb50 --- /dev/null +++ b/server/internal/db/dbm/dm/meta.go @@ -0,0 +1,51 @@ +package dm + +import ( + "database/sql" + "fmt" + "mayfly-go/internal/db/dbm/dbi" + "strings" + "sync" +) + +var ( + meta dbi.Meta + once sync.Once +) + +func GetMeta() dbi.Meta { + once.Do(func() { + meta = new(DmMeta) + }) + return meta +} + +type DmMeta struct { +} + +func (md *DmMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { + driverName := "dm" + db := d.Database + var dbParam string + if db != "" { + // dm database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema + ss := strings.Split(db, "/") + if len(ss) > 1 { + dbParam = fmt.Sprintf("%s?schema=%s", ss[0], ss[len(ss)-1]) + } else { + dbParam = db + } + } + + err := d.IfUseSshTunnelChangeIpPort() + if err != nil { + return nil, err + } + + dsn := fmt.Sprintf("dm://%s:%s@%s:%d/%s", d.Username, d.Password, d.Host, d.Port, dbParam) + return sql.Open(driverName, dsn) +} + +func (md *DmMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { + return &DMDialect{conn} +} diff --git a/server/internal/db/dbm/dialect_mysql.go b/server/internal/db/dbm/mysql/dialect.go similarity index 72% rename from server/internal/db/dbm/dialect_mysql.go rename to server/internal/db/dbm/mysql/dialect.go index 7382e591..65217215 100644 --- a/server/internal/db/dbm/dialect_mysql.go +++ b/server/internal/db/dbm/mysql/dialect.go @@ -1,40 +1,16 @@ -package dbm +package mysql import ( "context" "database/sql" "fmt" - machineapp "mayfly-go/internal/machine/application" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/errorx" "mayfly-go/pkg/utils/anyx" - "net" "regexp" "strings" - - "github.com/go-sql-driver/mysql" ) -func getMysqlDB(d *DbInfo) (*sql.DB, error) { - // SSH Conect - if d.SshTunnelMachineId > 0 { - sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId) - if err != nil { - return nil, err - } - mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) { - return sshTunnelMachine.GetDialConn("tcp", addr) - }) - } - // 设置dataSourceName -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name - dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database) - if d.Params != "" { - dsn = fmt.Sprintf("%s&%s", dsn, d.Params) - } - const driverName = "mysql" - return sql.Open(driverName, dsn) -} - -// ---------------------------------- mysql元数据 ----------------------------------- const ( MYSQL_META_FILE = "metasql/mysql_meta.sql" MYSQL_DBS = "MYSQL_DBS" @@ -44,22 +20,22 @@ const ( ) type MysqlDialect struct { - dc *DbConn + dc *dbi.DbConn } -func (md *MysqlDialect) GetDbServer() (*DbServer, error) { +func (md *MysqlDialect) GetDbServer() (*dbi.DbServer, error) { _, res, err := md.dc.Query("SELECT VERSION() version") if err != nil { return nil, err } - ds := &DbServer{ + ds := &dbi.DbServer{ Version: anyx.ConvString(res[0]["version"]), } return ds, nil } func (md *MysqlDialect) GetDbNames() ([]string, error) { - _, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_DBS)) + _, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_DBS)) if err != nil { return nil, err } @@ -73,15 +49,15 @@ func (md *MysqlDialect) GetDbNames() ([]string, error) { } // 获取表基础元信息, 如表名等 -func (md *MysqlDialect) GetTables() ([]Table, error) { - _, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY)) +func (md *MysqlDialect) GetTables() ([]dbi.Table, error) { + _, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY)) if err != nil { return nil, err } - tables := make([]Table, 0) + tables := make([]dbi.Table, 0) for _, re := range res { - tables = append(tables, Table{ + tables = append(tables, dbi.Table{ TableName: re["tableName"].(string), TableComment: anyx.ConvString(re["tableComment"]), CreateTime: anyx.ConvString(re["createTime"]), @@ -94,7 +70,7 @@ func (md *MysqlDialect) GetTables() ([]Table, error) { } // 获取列元信息, 如列名等 -func (md *MysqlDialect) GetColumns(tableNames ...string) ([]Column, error) { +func (md *MysqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) { tableName := "" for i := 0; i < len(tableNames); i++ { if i != 0 { @@ -103,14 +79,14 @@ func (md *MysqlDialect) GetColumns(tableNames ...string) ([]Column, error) { tableName = tableName + "'" + tableNames[i] + "'" } - _, res, err := md.dc.Query(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName)) + _, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName)) if err != nil { return nil, err } - columns := make([]Column, 0) + columns := make([]dbi.Column, 0) for _, re := range res { - columns = append(columns, Column{ + columns = append(columns, dbi.Column{ TableName: re["tableName"].(string), ColumnName: re["columnName"].(string), ColumnType: anyx.ConvString(re["columnType"]), @@ -144,15 +120,15 @@ func (md *MysqlDialect) GetPrimaryKey(tablename string) (string, error) { } // 获取表索引信息 -func (md *MysqlDialect) GetTableIndex(tableName string) ([]Index, error) { - _, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName) +func (md *MysqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) { + _, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName) if err != nil { return nil, err } - indexs := make([]Index, 0) + indexs := make([]dbi.Index, 0) for _, re := range res { - indexs = append(indexs, Index{ + indexs = append(indexs, dbi.Index{ IndexName: re["indexName"].(string), ColumnName: anyx.ConvString(re["columnName"]), IndexType: anyx.ConvString(re["indexType"]), @@ -162,7 +138,7 @@ func (md *MysqlDialect) GetTableIndex(tableName string) ([]Index, error) { }) } // 把查询结果以索引名分组,索引字段以逗号连接 - result := make([]Index, 0) + result := make([]dbi.Index, 0) key := "" for _, v := range indexs { // 当前的索引名 @@ -189,7 +165,7 @@ func (md *MysqlDialect) GetTableDDL(tableName string) (string, error) { return res[0]["Create Table"].(string) + ";", nil } -func (md *MysqlDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error { +func (md *MysqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error { return md.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn) } @@ -198,27 +174,27 @@ func (md *MysqlDialect) GetSchemas() ([]string, error) { } // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 -func (md *MysqlDialect) GetDbProgram() DbProgram { +func (md *MysqlDialect) GetDbProgram() dbi.DbProgram { return NewDbProgramMysql(md.dc) } -func (md *MysqlDialect) GetDataType(dbColumnType string) DataType { +func (md *MysqlDialect) GetDataType(dbColumnType string) dbi.DataType { if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { - return DataTypeNumber + return dbi.DataTypeNumber } // 日期时间类型 if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) { - return DataTypeDateTime + return dbi.DataTypeDateTime } // 日期类型 if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) { - return DataTypeDate + return dbi.DataTypeDate } // 时间类型 if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) { - return DataTypeTime + return dbi.DataTypeTime } - return DataTypeString + return dbi.DataTypeString } func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { @@ -246,7 +222,7 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri return md.dc.TxExec(tx, sqlStr, args...) } -func (md *MysqlDialect) FormatStrData(dbColumnValue string, dataType DataType) string { +func (md *MysqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { // mysql不需要格式化时间日期等 return dbColumnValue } diff --git a/server/internal/db/dbm/mysql/meta.go b/server/internal/db/dbm/mysql/meta.go new file mode 100644 index 00000000..21ae7568 --- /dev/null +++ b/server/internal/db/dbm/mysql/meta.go @@ -0,0 +1,52 @@ +package mysql + +import ( + "context" + "database/sql" + "fmt" + "mayfly-go/internal/db/dbm/dbi" + machineapp "mayfly-go/internal/machine/application" + "net" + "sync" + + "github.com/go-sql-driver/mysql" +) + +var ( + meta dbi.Meta + once sync.Once +) + +func GetMeta() dbi.Meta { + once.Do(func() { + meta = new(MysqlMeta) + }) + return meta +} + +type MysqlMeta struct { +} + +func (md *MysqlMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { + // SSH Conect + if d.SshTunnelMachineId > 0 { + sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId) + if err != nil { + return nil, err + } + mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) { + return sshTunnelMachine.GetDialConn("tcp", addr) + }) + } + // 设置dataSourceName -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name + dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database) + if d.Params != "" { + dsn = fmt.Sprintf("%s&%s", dsn, d.Params) + } + const driverName = "mysql" + return sql.Open(driverName, dsn) +} + +func (md *MysqlMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { + return &MysqlDialect{conn} +} diff --git a/server/internal/db/dbm/db_program_mysql.go b/server/internal/db/dbm/mysql/program.go similarity index 98% rename from server/internal/db/dbm/db_program_mysql.go rename to server/internal/db/dbm/mysql/program.go index c95b64ea..3fdf7523 100644 --- a/server/internal/db/dbm/db_program_mysql.go +++ b/server/internal/db/dbm/mysql/program.go @@ -1,4 +1,4 @@ -package dbm +package mysql import ( "bufio" @@ -19,27 +19,28 @@ import ( "golang.org/x/sync/singleflight" "mayfly-go/internal/db/config" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/domain/entity" "mayfly-go/pkg/logx" ) -var _ DbProgram = (*DbProgramMysql)(nil) +var _ dbi.DbProgram = (*DbProgramMysql)(nil) type DbProgramMysql struct { - dbConn *DbConn + dbConn *dbi.DbConn // mysqlBin 用于集成测试 mysqlBin *config.MysqlBin // backupPath 用于集成测试 backupPath string } -func NewDbProgramMysql(dbConn *DbConn) *DbProgramMysql { +func NewDbProgramMysql(dbConn *dbi.DbConn) *DbProgramMysql { return &DbProgramMysql{ dbConn: dbConn, } } -func (svc *DbProgramMysql) dbInfo() *DbInfo { +func (svc *DbProgramMysql) dbInfo() *dbi.DbInfo { dbInfo := svc.dbConn.Info err := dbInfo.IfUseSshTunnelChangeIpPort() if err != nil { @@ -55,9 +56,9 @@ func (svc *DbProgramMysql) getMysqlBin() *config.MysqlBin { dbInfo := svc.dbInfo() var mysqlBin *config.MysqlBin switch dbInfo.Type { - case DbTypeMariadb: + case dbi.DbTypeMariadb: mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMariadbBin) - case DbTypeMysql: + case dbi.DbTypeMysql: mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMysqlBin) default: panic(fmt.Sprintf("不兼容 MySQL 的数据库类型: %v", dbInfo.Type)) @@ -488,7 +489,7 @@ func (svc *DbProgramMysql) GetBinlogEventPositionAtOrAfterTime(ctx context.Conte } // ReplayBinlog replays the binlog for `originDatabase` from `startBinlogInfo.Position` to `targetTs`, read binlog from `binlogDir`. -func (svc *DbProgramMysql) ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *RestoreInfo) (replayErr error) { +func (svc *DbProgramMysql) ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *dbi.RestoreInfo) (replayErr error) { const ( // Variable lower_case_table_names related. diff --git a/server/internal/db/dbm/db_program_mysql_e2e_test.go b/server/internal/db/dbm/mysql/program_e2e_test.go similarity index 99% rename from server/internal/db/dbm/db_program_mysql_e2e_test.go rename to server/internal/db/dbm/mysql/program_e2e_test.go index 25c79e3d..63d5c62d 100644 --- a/server/internal/db/dbm/db_program_mysql_e2e_test.go +++ b/server/internal/db/dbm/mysql/program_e2e_test.go @@ -1,6 +1,6 @@ //go:build e2e -package dbm +package mysql import ( "context" diff --git a/server/internal/db/dbm/db_program_mysql_test.go b/server/internal/db/dbm/mysql/program_test.go similarity index 97% rename from server/internal/db/dbm/db_program_mysql_test.go rename to server/internal/db/dbm/mysql/program_test.go index c03a8757..979429a7 100644 --- a/server/internal/db/dbm/db_program_mysql_test.go +++ b/server/internal/db/dbm/mysql/program_test.go @@ -1,10 +1,11 @@ -package dbm +package mysql import ( - "github.com/stretchr/testify/require" "mayfly-go/internal/db/domain/entity" "strings" "testing" + + "github.com/stretchr/testify/require" ) func Test_readBinlogInfoFromBackup(t *testing.T) { diff --git a/server/internal/db/dbm/dialect_pgsql.go b/server/internal/db/dbm/postgres/dialect.go similarity index 60% rename from server/internal/db/dbm/dialect_pgsql.go rename to server/internal/db/dbm/postgres/dialect.go index cc21e2f3..1c5c77ce 100644 --- a/server/internal/db/dbm/dialect_pgsql.go +++ b/server/internal/db/dbm/postgres/dialect.go @@ -1,99 +1,17 @@ -package dbm +package postgres import ( "context" "database/sql" - "database/sql/driver" "fmt" - machineapp "mayfly-go/internal/machine/application" + "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/errorx" "mayfly-go/pkg/utils/anyx" - "mayfly-go/pkg/utils/collx" - "mayfly-go/pkg/utils/netx" - "net" "regexp" "strings" "time" - - pq "gitee.com/liuzongyang/libpq" ) -func getPgsqlDB(d *DbInfo) (*sql.DB, error) { - driverName := string(d.Type) - // SSH Conect - if d.SshTunnelMachineId > 0 { - // 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名 - driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId) - if !collx.ArrayContains(sql.Drivers(), driverName) { - sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId}) - } - sql.Drivers() - } - - db := d.Database - var dbParam string - exsitSchema := false - if db != "" { - // postgres database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema - ss := strings.Split(db, "/") - if len(ss) > 1 { - exsitSchema = true - dbParam = fmt.Sprintf("dbname=%s search_path=%s", ss[0], ss[len(ss)-1]) - } else { - dbParam = "dbname=" + db - } - } - - dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s %s sslmode=disable connect_timeout=8", d.Host, d.Port, d.Username, d.Password, dbParam) - // 存在额外指定参数,则拼接该连接参数 - if d.Params != "" { - // 存在指定的db,则需要将dbInstance配置中的parmas排除掉dbname和search_path - if db != "" { - paramArr := strings.Split(d.Params, "&") - paramArr = collx.ArrayRemoveFunc(paramArr, func(param string) bool { - if strings.HasPrefix(param, "dbname=") { - return true - } - if exsitSchema && strings.HasPrefix(param, "search_path") { - return true - } - return false - }) - d.Params = strings.Join(paramArr, " ") - } - dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " ")) - } - - return sql.Open(driverName, dsn) -} - -// pgsql dialer -type PqSqlDialer struct { - sshTunnelMachineId int -} - -func (d *PqSqlDialer) Open(name string) (driver.Conn, error) { - return pq.DialOpen(d, name) -} - -func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) { - sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId) - if err != nil { - return nil, err - } - if sshConn, err := sshTunnel.GetDialConn("tcp", address); err == nil { - // 将ssh conn包装,否则会返回错误: ssh: tcpChan: deadline not supported - return &netx.WrapSshConn{Conn: sshConn}, nil - } else { - return nil, err - } -} - -func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { - return pd.Dial(network, address) -} - -// ---------------------------------- pgsql元数据 ----------------------------------- const ( PGSQL_META_FILE = "metasql/pgsql_meta.sql" PGSQL_DB_SCHEMAS = "PGSQL_DB_SCHEMAS" @@ -104,15 +22,15 @@ const ( ) type PgsqlDialect struct { - dc *DbConn + dc *dbi.DbConn } -func (pd *PgsqlDialect) GetDbServer() (*DbServer, error) { +func (pd *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) { _, res, err := pd.dc.Query("SHOW server_version") if err != nil { return nil, err } - ds := &DbServer{ + ds := &dbi.DbServer{ Version: anyx.ConvString(res[0]["server_version"]), } return ds, nil @@ -133,15 +51,15 @@ func (pd *PgsqlDialect) GetDbNames() ([]string, error) { } // 获取表基础元信息, 如表名等 -func (pd *PgsqlDialect) GetTables() ([]Table, error) { - _, res, err := pd.dc.Query(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY)) +func (pd *PgsqlDialect) GetTables() ([]dbi.Table, error) { + _, res, err := pd.dc.Query(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY)) if err != nil { return nil, err } - tables := make([]Table, 0) + tables := make([]dbi.Table, 0) for _, re := range res { - tables = append(tables, Table{ + tables = append(tables, dbi.Table{ TableName: re["tableName"].(string), TableComment: anyx.ConvString(re["tableComment"]), CreateTime: anyx.ConvString(re["createTime"]), @@ -154,7 +72,7 @@ func (pd *PgsqlDialect) GetTables() ([]Table, error) { } // 获取列元信息, 如列名等 -func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]Column, error) { +func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) { tableName := "" for i := 0; i < len(tableNames); i++ { if i != 0 { @@ -163,14 +81,14 @@ func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]Column, error) { tableName = tableName + "'" + tableNames[i] + "'" } - _, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName)) + _, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName)) if err != nil { return nil, err } - columns := make([]Column, 0) + columns := make([]dbi.Column, 0) for _, re := range res { - columns = append(columns, Column{ + columns = append(columns, dbi.Column{ TableName: re["tableName"].(string), ColumnName: re["columnName"].(string), ColumnType: anyx.ConvString(re["columnType"]), @@ -202,15 +120,15 @@ func (pd *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) { } // 获取表索引信息 -func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) { - _, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName)) +func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) { + _, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName)) if err != nil { return nil, err } - indexs := make([]Index, 0) + indexs := make([]dbi.Index, 0) for _, re := range res { - indexs = append(indexs, Index{ + indexs = append(indexs, dbi.Index{ IndexName: re["indexName"].(string), ColumnName: anyx.ConvString(re["columnName"]), IndexType: anyx.ConvString(re["IndexType"]), @@ -220,7 +138,7 @@ func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) { }) } // 把查询结果以索引名分组,索引字段以逗号连接 - result := make([]Index, 0) + result := make([]dbi.Index, 0) key := "" for _, v := range indexs { // 当前的索引名 @@ -240,7 +158,7 @@ func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) { // 获取建表ddl func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) { - _, err := pd.dc.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY)) + _, err := pd.dc.Exec(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY)) if err != nil { return "", err } @@ -257,13 +175,13 @@ func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) { return res[0]["sql"].(string), nil } -func (pd *PgsqlDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error { +func (pd *PgsqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error { return pd.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn) } // 获取pgsql当前连接的库可访问的schemaNames func (pd *PgsqlDialect) GetSchemas() ([]string, error) { - sql := GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS) + sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS) _, res, err := pd.dc.Query(sql) if err != nil { return nil, err @@ -276,27 +194,27 @@ func (pd *PgsqlDialect) GetSchemas() ([]string, error) { } // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 -func (pd *PgsqlDialect) GetDbProgram() DbProgram { +func (pd *PgsqlDialect) GetDbProgram() dbi.DbProgram { panic("implement me") } -func (pd *PgsqlDialect) GetDataType(dbColumnType string) DataType { +func (pd *PgsqlDialect) GetDataType(dbColumnType string) dbi.DataType { if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { - return DataTypeNumber + return dbi.DataTypeNumber } // 日期时间类型 if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) { - return DataTypeDateTime + return dbi.DataTypeDateTime } // 日期类型 if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) { - return DataTypeDate + return dbi.DataTypeDate } // 时间类型 if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) { - return DataTypeTime + return dbi.DataTypeTime } - return DataTypeString + return dbi.DataTypeString } func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) { @@ -325,15 +243,15 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri return pd.dc.TxExec(tx, sqlStr, args...) } -func (pd *PgsqlDialect) FormatStrData(dbColumnValue string, dataType DataType) string { +func (pd *PgsqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string { switch dataType { - case DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00" + case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00" res, _ := time.Parse(time.RFC3339, dbColumnValue) return res.Format(time.DateTime) - case DataTypeDate: // "2024-01-02T00:00:00Z" + case dbi.DataTypeDate: // "2024-01-02T00:00:00Z" res, _ := time.Parse(time.RFC3339, dbColumnValue) return res.Format(time.DateOnly) - case DataTypeTime: // "0000-01-01T22:16:28.545075+08:00" + case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00" res, _ := time.Parse(time.RFC3339, dbColumnValue) return res.Format(time.TimeOnly) } diff --git a/server/internal/db/dbm/postgres/meta.go b/server/internal/db/dbm/postgres/meta.go new file mode 100644 index 00000000..f4617757 --- /dev/null +++ b/server/internal/db/dbm/postgres/meta.go @@ -0,0 +1,111 @@ +package postgres + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "mayfly-go/internal/db/dbm/dbi" + machineapp "mayfly-go/internal/machine/application" + "mayfly-go/pkg/utils/collx" + "mayfly-go/pkg/utils/netx" + "net" + "strings" + "sync" + "time" + + pq "gitee.com/liuzongyang/libpq" +) + +var ( + meta dbi.Meta + once sync.Once +) + +func GetMeta() dbi.Meta { + once.Do(func() { + meta = new(PostgresMeta) + }) + return meta +} + +type PostgresMeta struct { +} + +func (md *PostgresMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) { + driverName := string(d.Type) + // SSH Conect + if d.SshTunnelMachineId > 0 { + // 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名 + driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId) + if !collx.ArrayContains(sql.Drivers(), driverName) { + sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId}) + } + sql.Drivers() + } + + db := d.Database + var dbParam string + exsitSchema := false + if db != "" { + // postgres database可以使用db/schema表示,方便连接指定schema, 若不存在schema则使用默认schema + ss := strings.Split(db, "/") + if len(ss) > 1 { + exsitSchema = true + dbParam = fmt.Sprintf("dbname=%s search_path=%s", ss[0], ss[len(ss)-1]) + } else { + dbParam = "dbname=" + db + } + } + + dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s %s sslmode=disable connect_timeout=8", d.Host, d.Port, d.Username, d.Password, dbParam) + // 存在额外指定参数,则拼接该连接参数 + if d.Params != "" { + // 存在指定的db,则需要将dbInstance配置中的parmas排除掉dbname和search_path + if db != "" { + paramArr := strings.Split(d.Params, "&") + paramArr = collx.ArrayRemoveFunc(paramArr, func(param string) bool { + if strings.HasPrefix(param, "dbname=") { + return true + } + if exsitSchema && strings.HasPrefix(param, "search_path") { + return true + } + return false + }) + d.Params = strings.Join(paramArr, " ") + } + dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " ")) + } + + return sql.Open(driverName, dsn) +} + +func (md *PostgresMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect { + return &PgsqlDialect{conn} +} + +// pgsql dialer +type PqSqlDialer struct { + sshTunnelMachineId int +} + +func (d *PqSqlDialer) Open(name string) (driver.Conn, error) { + return pq.DialOpen(d, name) +} + +func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) { + sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId) + if err != nil { + return nil, err + } + if sshConn, err := sshTunnel.GetDialConn("tcp", address); err == nil { + // 将ssh conn包装,否则会返回错误: ssh: tcpChan: deadline not supported + return &netx.WrapSshConn{Conn: sshConn}, nil + } else { + return nil, err + } +} + +func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { + return pd.Dial(network, address) +} diff --git a/server/internal/machine/api/machine_file.go b/server/internal/machine/api/machine_file.go index dcc9a120..02f5ee11 100644 --- a/server/internal/machine/api/machine_file.go +++ b/server/internal/machine/api/machine_file.go @@ -90,7 +90,6 @@ func (m *MachineFile) ReadFileContent(rc *req.Ctx) { g := rc.GinCtx fid := GetMachineFileId(g) readPath := g.Query("path") - readType := g.Query("type") sftpFile, mi, err := m.MachineFileApp.ReadFile(fid, readPath) rc.ReqParam = collx.Kvs("machine", mi, "path", readPath) @@ -99,20 +98,27 @@ func (m *MachineFile) ReadFileContent(rc *req.Ctx) { fileInfo, _ := sftpFile.Stat() filesize := fileInfo.Size() - // 如果是读取文件内容,则校验文件大小 - if readType != "1" { - biz.IsTrue(filesize < max_read_size, "文件超过1m,请使用下载查看") - } - // 如果读取类型为下载,则下载文件,否则获取文件内容 - if readType == "1" { - // 截取文件名,如/usr/local/test.java -》 test.java - path := strings.Split(readPath, "/") - rc.Download(sftpFile, path[len(path)-1]) - } else { - datas, err := io.ReadAll(sftpFile) - biz.ErrIsNilAppendErr(err, "读取文件内容失败: %s") - rc.ResData = string(datas) - } + + biz.IsTrue(filesize < max_read_size, "文件超过1m,请使用下载查看") + datas, err := io.ReadAll(sftpFile) + biz.ErrIsNilAppendErr(err, "读取文件内容失败: %s") + + rc.ResData = string(datas) +} + +func (m *MachineFile) DownloadFile(rc *req.Ctx) { + g := rc.GinCtx + fid := GetMachineFileId(g) + readPath := g.Query("path") + + sftpFile, mi, err := m.MachineFileApp.ReadFile(fid, readPath) + rc.ReqParam = collx.Kvs("machine", mi, "path", readPath) + biz.ErrIsNilAppendErr(err, "打开文件失败: %s") + defer sftpFile.Close() + + // 截取文件名,如/usr/local/test.java -》 test.java + path := strings.Split(readPath, "/") + rc.Download(sftpFile, path[len(path)-1]) } func (m *MachineFile) GetDirEntry(rc *req.Ctx) { diff --git a/server/internal/machine/router/machine_file.go b/server/internal/machine/router/machine_file.go index eaceaf58..59877cb9 100644 --- a/server/internal/machine/router/machine_file.go +++ b/server/internal/machine/router/machine_file.go @@ -25,7 +25,9 @@ func InitMachineFileRouter(router *gin.RouterGroup) { req.NewDelete(":machineId/files/:fileId", mf.DeleteFile).Log(req.NewLogSave("机器-删除文件配置")).RequiredPermissionCode("machine:file:del"), - req.NewGet(":machineId/files/:fileId/read", mf.ReadFileContent).Log(req.NewLogSave("机器-获取文件内容")), + req.NewGet(":machineId/files/:fileId/read", mf.ReadFileContent).Log(req.NewLogSave("机器-读取文件内容")), + + req.NewGet(":machineId/files/:fileId/download", mf.DownloadFile).NoRes().Log(req.NewLogSave("机器-文件下载")), req.NewGet(":machineId/files/:fileId/read-dir", mf.GetDirEntry),