fix: 机器文件下载问题修复&dbm重构

This commit is contained in:
meilin.huang
2024-01-12 13:15:30 +08:00
parent bc811cbd49
commit bfd346e65a
32 changed files with 454 additions and 322 deletions

View File

@@ -17,7 +17,7 @@
"countup.js": "^2.7.0",
"cropperjs": "^1.5.11",
"echarts": "^5.4.3",
"element-plus": "^2.5.0",
"element-plus": "^2.5.1",
"js-base64": "^3.7.5",
"jsencrypt": "^3.3.2",
"lodash": "^4.17.21",
@@ -33,7 +33,7 @@
"splitpanes": "^3.1.5",
"sql-formatter": "^14.0.0",
"uuid": "^9.0.1",
"vue": "^3.4.8",
"vue": "^3.4.10",
"vue-router": "^4.2.5",
"xterm": "^5.3.0",
"xterm-addon-fit": "^0.8.0",
@@ -42,7 +42,7 @@
},
"devDependencies": {
"@types/lodash": "^4.14.178",
"@types/node": "^15.6.0",
"@types/node": "^18.14.0",
"@types/nprogress": "^0.2.0",
"@types/sortablejs": "^1.15.3",
"@typescript-eslint/eslint-plugin": "^6.7.4",

View File

@@ -35,6 +35,7 @@ export const machineApi = {
mvFile: Api.newPost('/machines/{machineId}/files/{fileId}/mv'),
uploadFile: Api.newPost('/machines/{machineId}/files/{fileId}/upload?' + joinClientParams()),
fileContent: Api.newGet('/machines/{machineId}/files/{fileId}/read'),
downloadFile: Api.newGet('/machines/{machineId}/files/{fileId}/download'),
createFile: Api.newPost('/machines/{machineId}/files/{id}/create-file'),
// 修改文件内容
updateFileContent: Api.newPost('/machines/{machineId}/files/{id}/write'),

View File

@@ -611,7 +611,7 @@ const deleteFile = async (files: any) => {
const downloadFile = (data: any) => {
const a = document.createElement('a');
a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/read?type=1&path=${data.path}&${joinClientParams()}`);
a.setAttribute('href', `${config.baseApiUrl}/machines/${props.machineId}/files/${props.fileId}/download?path=${data.path}&${joinClientParams()}`);
a.click();
};

View File

@@ -9,7 +9,7 @@ import (
"mayfly-go/internal/db/api/form"
"mayfly-go/internal/db/api/vo"
"mayfly-go/internal/db/application"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/domain/entity"
msgapp "mayfly-go/internal/msg/application"
msgdto "mayfly-go/internal/msg/application/dto"
@@ -351,7 +351,7 @@ func (d *Db) dumpDb(writer *gzipWriter, dbId uint64, dbName string, tables []str
writer.WriteString(fmt.Sprintf("\n-- ----------------------------\n-- 表记录: %s \n-- ----------------------------\n", table))
writer.WriteString("BEGIN;\n")
insertSql := "INSERT INTO %s VALUES (%s);\n"
dbMeta.WalkTableRecord(table, func(record map[string]any, columns []*dbm.QueryColumn) error {
dbMeta.WalkTableRecord(table, func(record map[string]any, columns []*dbi.QueryColumn) error {
var values []string
writer.TryFlush()
for _, column := range columns {
@@ -470,7 +470,7 @@ func getDbName(g *gin.Context) string {
return db
}
func (d *Db) getDbConn(g *gin.Context) *dbm.DbConn {
func (d *Db) getDbConn(g *gin.Context) *dbi.DbConn {
dc, err := d.DbApp.GetDbConn(getDbId(g), getDbName(g))
biz.ErrIsNil(err)
return dc

View File

@@ -4,6 +4,7 @@ import (
"context"
"mayfly-go/internal/common/consts"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
tagapp "mayfly-go/internal/tag/application"
@@ -33,10 +34,10 @@ type Db interface {
// @param id 数据库id
//
// @param dbName 数据库名
GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error)
GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error)
// 根据数据库实例id获取连接随机返回该instanceId下已连接的conn若不存在则是使用该instanceId关联的db进行连接并返回。
GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error)
GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error)
}
func newDbApp(dbRepo repository.Db, dbSqlRepo repository.DbSql, dbInstanceApp Instance, tagApp tagapp.TagTree) Db {
@@ -142,8 +143,8 @@ func (d *dbAppImpl) Delete(ctx context.Context, id uint64) error {
})
}
func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
return dbm.GetDbConn(dbId, dbName, func() (*dbm.DbInfo, error) {
func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbi.DbConn, error) {
return dbm.GetDbConn(dbId, dbName, func() (*dbi.DbInfo, error) {
db, err := d.GetById(new(entity.Db), dbId)
if err != nil {
return nil, errorx.NewBiz("数据库信息不存在")
@@ -156,7 +157,7 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
checkDb := dbName
// 兼容pgsql/dm db/schema模式
if dbm.DbTypePostgres.Equal(instance.Type) || dbm.DbTypeDM.Equal(instance.Type) {
if dbi.DbTypePostgres.Equal(instance.Type) || dbi.DbTypeDM.Equal(instance.Type) {
ss := strings.Split(dbName, "/")
if len(ss) > 1 {
checkDb = ss[0]
@@ -174,7 +175,7 @@ func (d *dbAppImpl) GetDbConn(dbId uint64, dbName string) (*dbm.DbConn, error) {
})
}
func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error) {
func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbi.DbConn, error) {
conn := dbm.GetDbConnByInstanceId(instanceId)
if conn != nil {
return conn, nil
@@ -193,8 +194,8 @@ func (d *dbAppImpl) GetDbConnByInstanceId(instanceId uint64) (*dbm.DbConn, error
return d.GetDbConn(firstDb.Id, strings.Split(firstDb.Database, " ")[0])
}
func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath ...string) *dbm.DbInfo {
di := new(dbm.DbInfo)
func toDbInfo(instance *entity.DbInstance, dbId uint64, database string, tagPath ...string) *dbi.DbInfo {
di := new(dbi.DbInfo)
di.InstanceId = instance.Id
di.Id = dbId
di.Database = database

View File

@@ -5,7 +5,7 @@ import (
"database/sql"
"encoding/json"
"fmt"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
@@ -182,20 +182,20 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
if err != nil {
return syncLog, errorx.NewBiz("解析字段映射json出错: %s", err.Error())
}
var updFieldType dbm.DataType
var updFieldType dbi.DataType
// 记录本次同步数据总数
total := 0
batchSize := task.PageSize
result := make([]map[string]any, 0)
var queryColumns []*dbm.QueryColumn
var queryColumns []*dbi.QueryColumn
err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbm.QueryColumn) error {
err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbi.QueryColumn) error {
if len(queryColumns) == 0 {
queryColumns = columns
// 遍历columns 取task.UpdField的字段类型
updFieldType = dbm.DataTypeString
updFieldType = dbi.DataTypeString
for _, column := range columns {
if column.Name == task.UpdField {
updFieldType = srcDialect.GetDataType(column.Type)
@@ -249,7 +249,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
return syncLog, nil
}
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbm.DataType, task *entity.DataSyncTask, srcDialect dbm.DbDialect, targetDbConn *dbm.DbConn, targetDbTx *sql.Tx) error {
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbi.DataType, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error {
var data = make([]map[string]any, 0)
// 遍历res组装插入sql

View File

@@ -4,7 +4,7 @@ import (
"context"
"errors"
"fmt"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/logx"
@@ -314,7 +314,7 @@ func (s *dbScheduler) runnable(job entity.DbJob, next runner.NextFunc) bool {
return true
}
func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProgram, job *entity.DbRestore) error {
func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbi.DbProgram, job *entity.DbRestore) error {
binlogHistory, err := s.binlogHistoryRepo.GetHistoryByTime(job.DbInstanceId, job.PointInTime.Time)
if err != nil {
return err
@@ -341,7 +341,7 @@ func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProg
if err != nil {
return err
}
restoreInfo := &dbm.RestoreInfo{
restoreInfo := &dbi.RestoreInfo{
BackupHistory: backupHistory,
BinlogHistories: binlogHistories,
StartPosition: backupHistory.BinlogPosition,
@@ -354,7 +354,7 @@ func (s *dbScheduler) restorePointInTime(ctx context.Context, program dbm.DbProg
return program.ReplayBinlog(ctx, job.DbName, job.DbName, restoreInfo)
}
func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbm.DbProgram, job *entity.DbRestore) error {
func (s *dbScheduler) restoreBackupHistory(ctx context.Context, program dbi.DbProgram, job *entity.DbRestore) error {
backupHistory := &entity.DbBackupHistory{}
if err := s.backupHistoryRepo.GetById(backupHistory, job.DbBackupHistoryId); err != nil {
return err

View File

@@ -4,7 +4,7 @@ import (
"context"
"fmt"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/contextx"
@@ -22,11 +22,11 @@ type DbSqlExecReq struct {
Db string
Sql string
Remark string
DbConn *dbm.DbConn
DbConn *dbi.DbConn
}
type DbSqlExecRes struct {
Columns []*dbm.QueryColumn
Columns []*dbi.QueryColumn
Res []map[string]any
}
@@ -269,7 +269,7 @@ func doInsert(ctx context.Context, insert *sqlparser.Insert, execSqlReq *DbSqlEx
return doExec(ctx, execSqlReq.Sql, execSqlReq.DbConn)
}
func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes, error) {
func doExec(ctx context.Context, sql string, dbConn *dbi.DbConn) (*DbSqlExecRes, error) {
rowsAffected, err := dbConn.ExecContext(ctx, sql)
execRes := "success"
if err != nil {
@@ -283,7 +283,7 @@ func doExec(ctx context.Context, sql string, dbConn *dbm.DbConn) (*DbSqlExecRes,
res = append(res, resData)
return &DbSqlExecRes{
Columns: []*dbm.QueryColumn{
Columns: []*dbi.QueryColumn{
{Name: "sql", Type: "string"},
{Name: "rowsAffected", Type: "number"},
{Name: "result", Type: "string"},

View File

@@ -3,6 +3,7 @@ package application
import (
"context"
"mayfly-go/internal/db/dbm"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/internal/db/domain/repository"
"mayfly-go/pkg/base"
@@ -50,7 +51,7 @@ func (app *instanceAppImpl) Count(condition *entity.InstanceQuery) int64 {
func (app *instanceAppImpl) TestConn(instanceEntity *entity.DbInstance) error {
instanceEntity.Network = instanceEntity.GetNetwork()
dbConn, err := toDbInfo(instanceEntity, 0, "", "").Conn()
dbConn, err := dbm.Conn(toDbInfo(instanceEntity, 0, "", ""))
if err != nil {
return err
}
@@ -100,9 +101,9 @@ func (app *instanceAppImpl) Delete(ctx context.Context, id uint64) error {
func (app *instanceAppImpl) GetDatabases(ed *entity.DbInstance) ([]string, error) {
ed.Network = ed.GetNetwork()
metaDb := dbm.ToDbType(ed.Type).MetaDbName()
metaDb := dbi.ToDbType(ed.Type).MetaDbName()
dbConn, err := toDbInfo(ed, 0, metaDb, "").Conn()
dbConn, err := dbm.Conn(toDbInfo(ed, 0, metaDb, ""))
if err != nil {
return nil, err
}

View File

@@ -1,4 +1,4 @@
package dbm
package dbi
import (
"context"
@@ -117,18 +117,9 @@ func (d *DbConn) Begin() (*sql.Tx, error) {
return d.db.Begin()
}
// 获取数据库元信息实现接口
func (d *DbConn) GetDialect() DbDialect {
switch d.Info.Type {
case DbTypeMysql, DbTypeMariadb:
return &MysqlDialect{dc: d}
case DbTypePostgres:
return &PgsqlDialect{dc: d}
case DbTypeDM:
return &DMDialect{dc: d}
default:
panic(fmt.Sprintf("invalid database type: %s", d.Info.Type))
}
// 获取数据库dialect实现接口
func (d *DbConn) GetDialect() Dialect {
return d.Info.Meta.GetDialect(d)
}
// 关闭连接

View File

@@ -1,4 +1,4 @@
package dbm
package dbi
import (
"context"

View File

@@ -1,4 +1,4 @@
package dbm
package dbi
import (
"fmt"

View File

@@ -1,4 +1,4 @@
package dbm
package dbi
import (
"testing"

View File

@@ -1,12 +1,13 @@
package dbm
package dbi
import (
"database/sql"
"embed"
"strings"
"mayfly-go/pkg/biz"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
"strings"
)
type DataType string
@@ -60,7 +61,7 @@ type Index struct {
// -----------------------------------元数据接口定义------------------------------------------
// 数据库方言、元信息接口(表、列、获取表数据等元信息)
type DbDialect interface {
type Dialect interface {
// 获取数据库服务实例信息
GetDbServer() (*DbServer, error)

View File

@@ -1,4 +1,4 @@
package dbm
package dbi
import (
"database/sql"
@@ -8,6 +8,9 @@ import (
"mayfly-go/pkg/logx"
)
// 获取sql.DB函数
type GetSqlDbFunc func(*DbInfo) (*sql.DB, error)
type DbInfo struct {
InstanceId uint64 // 实例id
Id uint64 // dbId
@@ -24,6 +27,8 @@ type DbInfo struct {
TagPath []string
SshTunnelMachineId int
Meta Meta
}
// 获取记录日志的描述
@@ -32,22 +37,16 @@ func (d *DbInfo) GetLogDesc() string {
}
// 连接数据库
func (dbInfo *DbInfo) Conn() (*DbConn, error) {
var conn *sql.DB
var err error
database := dbInfo.Database
switch dbInfo.Type {
case DbTypeMysql, DbTypeMariadb:
conn, err = getMysqlDB(dbInfo)
case DbTypePostgres:
conn, err = getPgsqlDB(dbInfo)
case DbTypeDM:
conn, err = getDmDB(dbInfo)
default:
return nil, errorx.NewBiz("invalid database type: %s", dbInfo.Type)
func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) {
if meta == nil {
return nil, errorx.NewBiz("数据库元信息接口不能为空")
}
// 赋值Meta方便后续获取dialect等
dbInfo.Meta = meta
database := dbInfo.Database
conn, err := meta.GetSqlDb(dbInfo)
if err != nil {
logx.Errorf("连接db失败: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error())
return nil, errorx.NewBiz(fmt.Sprintf("数据库连接失败: %s", err.Error()))

View File

@@ -0,0 +1,12 @@
package dbi
import "database/sql"
// 数据库元信息获取如获取sql.DB、Dialect等
type Meta interface {
// 获取数据库服务实例信息
GetSqlDb(*DbInfo) (*sql.DB, error)
// 获取数据库方言,用于获取表结构等信息
GetDialect(*DbConn) Dialect
}

View File

@@ -1,4 +1,4 @@
package dbm
package dbi
import (
"database/sql"

View File

@@ -3,6 +3,10 @@ package dbm
import (
"fmt"
"mayfly-go/internal/common/consts"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/dbm/dm"
"mayfly-go/internal/db/dbm/mysql"
"mayfly-go/internal/db/dbm/postgres"
"mayfly-go/internal/machine/mcm"
"mayfly-go/pkg/cache"
"mayfly-go/pkg/logx"
@@ -15,7 +19,7 @@ var connCache = cache.NewTimedCache(consts.DbConnExpireTime, 5*time.Second).
WithUpdateAccessTime(true).
OnEvicted(func(key any, value any) {
logx.Info(fmt.Sprintf("删除db连接缓存 id = %s", key))
value.(*DbConn).Close()
value.(*dbi.DbConn).Close()
})
func init() {
@@ -23,7 +27,7 @@ func init() {
// 遍历所有db连接实例若存在db实例使用该ssh隧道机器则返回true表示还在使用中...
items := connCache.Items()
for _, v := range items {
if v.Value.(*DbConn).Info.SshTunnelMachineId == machineId {
if v.Value.(*dbi.DbConn).Info.SshTunnelMachineId == machineId {
return true
}
}
@@ -33,8 +37,21 @@ func init() {
var mutex sync.Mutex
func getDbMetaByType(dt dbi.DbType) dbi.Meta {
switch dt {
case dbi.DbTypeMysql, dbi.DbTypeMariadb:
return mysql.GetMeta()
case dbi.DbTypePostgres:
return postgres.GetMeta()
case dbi.DbTypeDM:
return dm.GetMeta()
default:
panic(fmt.Sprintf("invalid database type: %s", dt))
}
}
// 从缓存中获取数据库连接信息若缓存中不存在则会使用回调函数获取dbInfo进行连接并缓存
func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error)) (*DbConn, error) {
func GetDbConn(dbId uint64, database string, getDbInfo func() (*dbi.DbInfo, error)) (*dbi.DbConn, error) {
connId := GetDbConnId(dbId, database)
// connId不为空则为需要缓存
@@ -42,7 +59,7 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error))
if needCache {
load, ok := connCache.Get(connId)
if ok {
return load.(*DbConn), nil
return load.(*dbi.DbConn), nil
}
}
@@ -56,7 +73,7 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error))
}
// 连接数据库
dbConn, err := dbInfo.Conn()
dbConn, err := Conn(dbInfo)
if err != nil {
return nil, err
}
@@ -67,10 +84,15 @@ func GetDbConn(dbId uint64, database string, getDbInfo func() (*DbInfo, error))
return dbConn, nil
}
// 使用指定dbInfo信息进行连接
func Conn(di *dbi.DbInfo) (*dbi.DbConn, error) {
return di.Conn(getDbMetaByType(di.Type))
}
// 根据实例id获取连接
func GetDbConnByInstanceId(instanceId uint64) *DbConn {
func GetDbConnByInstanceId(instanceId uint64) *dbi.DbConn {
for _, connItem := range connCache.Items() {
conn := connItem.Value.(*DbConn)
conn := connItem.Value.(*dbi.DbConn)
if conn.Info.InstanceId == instanceId {
return conn
}
@@ -82,3 +104,12 @@ func GetDbConnByInstanceId(instanceId uint64) *DbConn {
func CloseDb(dbId uint64, db string) {
connCache.Delete(GetDbConnId(dbId, db))
}
// 获取连接id
func GetDbConnId(dbId uint64, db string) string {
if dbId == 0 {
return ""
}
return fmt.Sprintf("%d:%s", dbId, db)
}

View File

@@ -1,9 +1,10 @@
package dbm
package dm
import (
"context"
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/anyx"
@@ -14,30 +15,6 @@ import (
_ "gitee.com/chunanyong/dm"
)
func getDmDB(d *DbInfo) (*sql.DB, error) {
driverName := "dm"
db := d.Database
var dbParam string
if db != "" {
// dm database可以使用db/schema表示方便连接指定schema, 若不存在schema则使用默认schema
ss := strings.Split(db, "/")
if len(ss) > 1 {
dbParam = fmt.Sprintf("%s?schema=%s", ss[0], ss[len(ss)-1])
} else {
dbParam = db
}
}
err := d.IfUseSshTunnelChangeIpPort()
if err != nil {
return nil, err
}
dsn := fmt.Sprintf("dm://%s:%s@%s:%d/%s", d.Username, d.Password, d.Host, d.Port, dbParam)
return sql.Open(driverName, dsn)
}
// ---------------------------------- DM元数据 -----------------------------------
const (
DM_META_FILE = "metasql/dm_meta.sql"
DM_DB_SCHEMAS = "DM_DB_SCHEMAS"
@@ -47,15 +24,15 @@ const (
)
type DMDialect struct {
dc *DbConn
dc *dbi.DbConn
}
func (dd *DMDialect) GetDbServer() (*DbServer, error) {
func (dd *DMDialect) GetDbServer() (*dbi.DbServer, error) {
_, res, err := dd.dc.Query("select * from v$instance")
if err != nil {
return nil, err
}
ds := &DbServer{
ds := &dbi.DbServer{
Version: anyx.ConvString(res[0]["SVR_VERSION"]),
}
return ds, nil
@@ -76,20 +53,20 @@ func (dd *DMDialect) GetDbNames() ([]string, error) {
}
// 获取表基础元信息, 如表名等
func (dd *DMDialect) GetTables() ([]Table, error) {
func (dd *DMDialect) GetTables() ([]dbi.Table, error) {
// 首先执行更新统计信息sql 这个统计信息在数据量比较大的时候就比较耗时,所以最好定时执行
// _, _, err := pd.dc.Query("dbms_stats.GATHER_SCHEMA_stats(SELECT SF_GET_SCHEMA_NAME_BY_ID(CURRENT_SCHID))")
// 查询表信息
_, res, err := dd.dc.Query(GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY))
_, res, err := dd.dc.Query(dbi.GetLocalSql(DM_META_FILE, DM_TABLE_INFO_KEY))
if err != nil {
return nil, err
}
tables := make([]Table, 0)
tables := make([]dbi.Table, 0)
for _, re := range res {
tables = append(tables, Table{
tables = append(tables, dbi.Table{
TableName: re["TABLE_NAME"].(string),
TableComment: anyx.ConvString(re["TABLE_COMMENT"]),
CreateTime: anyx.ConvString(re["CREATE_TIME"]),
@@ -102,7 +79,7 @@ func (dd *DMDialect) GetTables() ([]Table, error) {
}
// 获取列元信息, 如列名等
func (dd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
func (dd *DMDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
tableName := ""
for i := 0; i < len(tableNames); i++ {
if i != 0 {
@@ -111,14 +88,14 @@ func (dd *DMDialect) GetColumns(tableNames ...string) ([]Column, error) {
tableName = tableName + "'" + tableNames[i] + "'"
}
_, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
_, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_COLUMN_MA_KEY), tableName))
if err != nil {
return nil, err
}
columns := make([]Column, 0)
columns := make([]dbi.Column, 0)
for _, re := range res {
columns = append(columns, Column{
columns = append(columns, dbi.Column{
TableName: re["TABLE_NAME"].(string),
ColumnName: re["COLUMN_NAME"].(string),
ColumnType: anyx.ConvString(re["COLUMN_TYPE"]),
@@ -150,15 +127,15 @@ func (dd *DMDialect) GetPrimaryKey(tablename string) (string, error) {
}
// 获取表索引信息
func (dd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
_, res, err := dd.dc.Query(fmt.Sprintf(GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName))
func (dd *DMDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
_, res, err := dd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(DM_META_FILE, DM_INDEX_INFO_KEY), tableName))
if err != nil {
return nil, err
}
indexs := make([]Index, 0)
indexs := make([]dbi.Index, 0)
for _, re := range res {
indexs = append(indexs, Index{
indexs = append(indexs, dbi.Index{
IndexName: re["INDEX_NAME"].(string),
ColumnName: anyx.ConvString(re["COLUMN_NAME"]),
IndexType: anyx.ConvString(re["INDEX_TYPE"]),
@@ -168,7 +145,7 @@ func (dd *DMDialect) GetTableIndex(tableName string) ([]Index, error) {
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
result := make([]Index, 0)
result := make([]dbi.Index, 0)
key := ""
for _, v := range indexs {
// 当前的索引名
@@ -255,13 +232,13 @@ func (dd *DMDialect) GetTableDDL(tableName string) (string, error) {
return builder.String(), nil
}
func (dd *DMDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error {
func (dd *DMDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error {
return dd.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
}
// 获取DM当前连接的库可访问的schemaNames
func (dd *DMDialect) GetSchemas() ([]string, error) {
sql := GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS)
sql := dbi.GetLocalSql(DM_META_FILE, DM_DB_SCHEMAS)
_, res, err := dd.dc.Query(sql)
if err != nil {
return nil, err
@@ -274,27 +251,27 @@ func (dd *DMDialect) GetSchemas() ([]string, error) {
}
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
func (dd *DMDialect) GetDbProgram() DbProgram {
func (dd *DMDialect) GetDbProgram() dbi.DbProgram {
panic("implement me")
}
func (dd *DMDialect) GetDataType(dbColumnType string) DataType {
func (dd *DMDialect) GetDataType(dbColumnType string) dbi.DataType {
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
return DataTypeNumber
return dbi.DataTypeNumber
}
// 日期时间类型
if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
return DataTypeDateTime
return dbi.DataTypeDateTime
}
// 日期类型
if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
return DataTypeDate
return dbi.DataTypeDate
}
// 时间类型
if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
return DataTypeTime
return dbi.DataTypeTime
}
return DataTypeString
return dbi.DataTypeString
}
func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
@@ -322,15 +299,15 @@ func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string,
return int64(effRows), nil
}
func (dd *DMDialect) FormatStrData(dbColumnValue string, dataType DataType) string {
func (dd *DMDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string {
switch dataType {
case DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
case dbi.DataTypeDateTime: // "2024-01-02T22:08:22.275697+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue)
return res.Format(time.DateTime)
case DataTypeDate: // "2024-01-02T00:00:00+08:00"
case dbi.DataTypeDate: // "2024-01-02T00:00:00+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue)
return res.Format(time.DateOnly)
case DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
case dbi.DataTypeTime: // "0000-01-01T22:08:22.275688+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue)
return res.Format(time.TimeOnly)
}

View File

@@ -0,0 +1,51 @@
package dm
import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"strings"
"sync"
)
var (
meta dbi.Meta
once sync.Once
)
func GetMeta() dbi.Meta {
once.Do(func() {
meta = new(DmMeta)
})
return meta
}
type DmMeta struct {
}
func (md *DmMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
driverName := "dm"
db := d.Database
var dbParam string
if db != "" {
// dm database可以使用db/schema表示方便连接指定schema, 若不存在schema则使用默认schema
ss := strings.Split(db, "/")
if len(ss) > 1 {
dbParam = fmt.Sprintf("%s?schema=%s", ss[0], ss[len(ss)-1])
} else {
dbParam = db
}
}
err := d.IfUseSshTunnelChangeIpPort()
if err != nil {
return nil, err
}
dsn := fmt.Sprintf("dm://%s:%s@%s:%d/%s", d.Username, d.Password, d.Host, d.Port, dbParam)
return sql.Open(driverName, dsn)
}
func (md *DmMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
return &DMDialect{conn}
}

View File

@@ -1,40 +1,16 @@
package dbm
package mysql
import (
"context"
"database/sql"
"fmt"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/utils/anyx"
"net"
"regexp"
"strings"
"github.com/go-sql-driver/mysql"
)
func getMysqlDB(d *DbInfo) (*sql.DB, error) {
// SSH Conect
if d.SshTunnelMachineId > 0 {
sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
if err != nil {
return nil, err
}
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
return sshTunnelMachine.GetDialConn("tcp", addr)
})
}
// 设置dataSourceName -> 更多参数参考https://github.com/go-sql-driver/mysql#dsn-data-source-name
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
if d.Params != "" {
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
}
const driverName = "mysql"
return sql.Open(driverName, dsn)
}
// ---------------------------------- mysql元数据 -----------------------------------
const (
MYSQL_META_FILE = "metasql/mysql_meta.sql"
MYSQL_DBS = "MYSQL_DBS"
@@ -44,22 +20,22 @@ const (
)
type MysqlDialect struct {
dc *DbConn
dc *dbi.DbConn
}
func (md *MysqlDialect) GetDbServer() (*DbServer, error) {
func (md *MysqlDialect) GetDbServer() (*dbi.DbServer, error) {
_, res, err := md.dc.Query("SELECT VERSION() version")
if err != nil {
return nil, err
}
ds := &DbServer{
ds := &dbi.DbServer{
Version: anyx.ConvString(res[0]["version"]),
}
return ds, nil
}
func (md *MysqlDialect) GetDbNames() ([]string, error) {
_, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_DBS))
_, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_DBS))
if err != nil {
return nil, err
}
@@ -73,15 +49,15 @@ func (md *MysqlDialect) GetDbNames() ([]string, error) {
}
// 获取表基础元信息, 如表名等
func (md *MysqlDialect) GetTables() ([]Table, error) {
_, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
func (md *MysqlDialect) GetTables() ([]dbi.Table, error) {
_, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_TABLE_INFO_KEY))
if err != nil {
return nil, err
}
tables := make([]Table, 0)
tables := make([]dbi.Table, 0)
for _, re := range res {
tables = append(tables, Table{
tables = append(tables, dbi.Table{
TableName: re["tableName"].(string),
TableComment: anyx.ConvString(re["tableComment"]),
CreateTime: anyx.ConvString(re["createTime"]),
@@ -94,7 +70,7 @@ func (md *MysqlDialect) GetTables() ([]Table, error) {
}
// 获取列元信息, 如列名等
func (md *MysqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
func (md *MysqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
tableName := ""
for i := 0; i < len(tableNames); i++ {
if i != 0 {
@@ -103,14 +79,14 @@ func (md *MysqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
tableName = tableName + "'" + tableNames[i] + "'"
}
_, res, err := md.dc.Query(fmt.Sprintf(GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
if err != nil {
return nil, err
}
columns := make([]Column, 0)
columns := make([]dbi.Column, 0)
for _, re := range res {
columns = append(columns, Column{
columns = append(columns, dbi.Column{
TableName: re["tableName"].(string),
ColumnName: re["columnName"].(string),
ColumnType: anyx.ConvString(re["columnType"]),
@@ -144,15 +120,15 @@ func (md *MysqlDialect) GetPrimaryKey(tablename string) (string, error) {
}
// 获取表索引信息
func (md *MysqlDialect) GetTableIndex(tableName string) ([]Index, error) {
_, res, err := md.dc.Query(GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName)
func (md *MysqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
_, res, err := md.dc.Query(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_INDEX_INFO_KEY), tableName)
if err != nil {
return nil, err
}
indexs := make([]Index, 0)
indexs := make([]dbi.Index, 0)
for _, re := range res {
indexs = append(indexs, Index{
indexs = append(indexs, dbi.Index{
IndexName: re["indexName"].(string),
ColumnName: anyx.ConvString(re["columnName"]),
IndexType: anyx.ConvString(re["indexType"]),
@@ -162,7 +138,7 @@ func (md *MysqlDialect) GetTableIndex(tableName string) ([]Index, error) {
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
result := make([]Index, 0)
result := make([]dbi.Index, 0)
key := ""
for _, v := range indexs {
// 当前的索引名
@@ -189,7 +165,7 @@ func (md *MysqlDialect) GetTableDDL(tableName string) (string, error) {
return res[0]["Create Table"].(string) + ";", nil
}
func (md *MysqlDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error {
func (md *MysqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error {
return md.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
}
@@ -198,27 +174,27 @@ func (md *MysqlDialect) GetSchemas() ([]string, error) {
}
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
func (md *MysqlDialect) GetDbProgram() DbProgram {
func (md *MysqlDialect) GetDbProgram() dbi.DbProgram {
return NewDbProgramMysql(md.dc)
}
func (md *MysqlDialect) GetDataType(dbColumnType string) DataType {
func (md *MysqlDialect) GetDataType(dbColumnType string) dbi.DataType {
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
return DataTypeNumber
return dbi.DataTypeNumber
}
// 日期时间类型
if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
return DataTypeDateTime
return dbi.DataTypeDateTime
}
// 日期类型
if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
return DataTypeDate
return dbi.DataTypeDate
}
// 时间类型
if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
return DataTypeTime
return dbi.DataTypeTime
}
return DataTypeString
return dbi.DataTypeString
}
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
@@ -246,7 +222,7 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
return md.dc.TxExec(tx, sqlStr, args...)
}
func (md *MysqlDialect) FormatStrData(dbColumnValue string, dataType DataType) string {
func (md *MysqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string {
// mysql不需要格式化时间日期等
return dbColumnValue
}

View File

@@ -0,0 +1,52 @@
package mysql
import (
"context"
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
machineapp "mayfly-go/internal/machine/application"
"net"
"sync"
"github.com/go-sql-driver/mysql"
)
var (
meta dbi.Meta
once sync.Once
)
func GetMeta() dbi.Meta {
once.Do(func() {
meta = new(MysqlMeta)
})
return meta
}
type MysqlMeta struct {
}
func (md *MysqlMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
// SSH Conect
if d.SshTunnelMachineId > 0 {
sshTunnelMachine, err := machineapp.GetMachineApp().GetSshTunnelMachine(d.SshTunnelMachineId)
if err != nil {
return nil, err
}
mysql.RegisterDialContext(d.Network, func(ctx context.Context, addr string) (net.Conn, error) {
return sshTunnelMachine.GetDialConn("tcp", addr)
})
}
// 设置dataSourceName -> 更多参数参考https://github.com/go-sql-driver/mysql#dsn-data-source-name
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
if d.Params != "" {
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
}
const driverName = "mysql"
return sql.Open(driverName, dsn)
}
func (md *MysqlMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
return &MysqlDialect{conn}
}

View File

@@ -1,4 +1,4 @@
package dbm
package mysql
import (
"bufio"
@@ -19,27 +19,28 @@ import (
"golang.org/x/sync/singleflight"
"mayfly-go/internal/db/config"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/domain/entity"
"mayfly-go/pkg/logx"
)
var _ DbProgram = (*DbProgramMysql)(nil)
var _ dbi.DbProgram = (*DbProgramMysql)(nil)
type DbProgramMysql struct {
dbConn *DbConn
dbConn *dbi.DbConn
// mysqlBin 用于集成测试
mysqlBin *config.MysqlBin
// backupPath 用于集成测试
backupPath string
}
func NewDbProgramMysql(dbConn *DbConn) *DbProgramMysql {
func NewDbProgramMysql(dbConn *dbi.DbConn) *DbProgramMysql {
return &DbProgramMysql{
dbConn: dbConn,
}
}
func (svc *DbProgramMysql) dbInfo() *DbInfo {
func (svc *DbProgramMysql) dbInfo() *dbi.DbInfo {
dbInfo := svc.dbConn.Info
err := dbInfo.IfUseSshTunnelChangeIpPort()
if err != nil {
@@ -55,9 +56,9 @@ func (svc *DbProgramMysql) getMysqlBin() *config.MysqlBin {
dbInfo := svc.dbInfo()
var mysqlBin *config.MysqlBin
switch dbInfo.Type {
case DbTypeMariadb:
case dbi.DbTypeMariadb:
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMariadbBin)
case DbTypeMysql:
case dbi.DbTypeMysql:
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMysqlBin)
default:
panic(fmt.Sprintf("不兼容 MySQL 的数据库类型: %v", dbInfo.Type))
@@ -488,7 +489,7 @@ func (svc *DbProgramMysql) GetBinlogEventPositionAtOrAfterTime(ctx context.Conte
}
// ReplayBinlog replays the binlog for `originDatabase` from `startBinlogInfo.Position` to `targetTs`, read binlog from `binlogDir`.
func (svc *DbProgramMysql) ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *RestoreInfo) (replayErr error) {
func (svc *DbProgramMysql) ReplayBinlog(ctx context.Context, originalDatabase, targetDatabase string, restoreInfo *dbi.RestoreInfo) (replayErr error) {
const (
// Variable lower_case_table_names related.

View File

@@ -1,6 +1,6 @@
//go:build e2e
package dbm
package mysql
import (
"context"

View File

@@ -1,10 +1,11 @@
package dbm
package mysql
import (
"github.com/stretchr/testify/require"
"mayfly-go/internal/db/domain/entity"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func Test_readBinlogInfoFromBackup(t *testing.T) {

View File

@@ -1,99 +1,17 @@
package dbm
package postgres
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/netx"
"net"
"regexp"
"strings"
"time"
pq "gitee.com/liuzongyang/libpq"
)
func getPgsqlDB(d *DbInfo) (*sql.DB, error) {
driverName := string(d.Type)
// SSH Conect
if d.SshTunnelMachineId > 0 {
// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
if !collx.ArrayContains(sql.Drivers(), driverName) {
sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId})
}
sql.Drivers()
}
db := d.Database
var dbParam string
exsitSchema := false
if db != "" {
// postgres database可以使用db/schema表示方便连接指定schema, 若不存在schema则使用默认schema
ss := strings.Split(db, "/")
if len(ss) > 1 {
exsitSchema = true
dbParam = fmt.Sprintf("dbname=%s search_path=%s", ss[0], ss[len(ss)-1])
} else {
dbParam = "dbname=" + db
}
}
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s %s sslmode=disable connect_timeout=8", d.Host, d.Port, d.Username, d.Password, dbParam)
// 存在额外指定参数,则拼接该连接参数
if d.Params != "" {
// 存在指定的db则需要将dbInstance配置中的parmas排除掉dbname和search_path
if db != "" {
paramArr := strings.Split(d.Params, "&")
paramArr = collx.ArrayRemoveFunc(paramArr, func(param string) bool {
if strings.HasPrefix(param, "dbname=") {
return true
}
if exsitSchema && strings.HasPrefix(param, "search_path") {
return true
}
return false
})
d.Params = strings.Join(paramArr, " ")
}
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
}
return sql.Open(driverName, dsn)
}
// pgsql dialer
type PqSqlDialer struct {
sshTunnelMachineId int
}
func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
return pq.DialOpen(d, name)
}
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId)
if err != nil {
return nil, err
}
if sshConn, err := sshTunnel.GetDialConn("tcp", address); err == nil {
// 将ssh conn包装否则会返回错误: ssh: tcpChan: deadline not supported
return &netx.WrapSshConn{Conn: sshConn}, nil
} else {
return nil, err
}
}
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
return pd.Dial(network, address)
}
// ---------------------------------- pgsql元数据 -----------------------------------
const (
PGSQL_META_FILE = "metasql/pgsql_meta.sql"
PGSQL_DB_SCHEMAS = "PGSQL_DB_SCHEMAS"
@@ -104,15 +22,15 @@ const (
)
type PgsqlDialect struct {
dc *DbConn
dc *dbi.DbConn
}
func (pd *PgsqlDialect) GetDbServer() (*DbServer, error) {
func (pd *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) {
_, res, err := pd.dc.Query("SHOW server_version")
if err != nil {
return nil, err
}
ds := &DbServer{
ds := &dbi.DbServer{
Version: anyx.ConvString(res[0]["server_version"]),
}
return ds, nil
@@ -133,15 +51,15 @@ func (pd *PgsqlDialect) GetDbNames() ([]string, error) {
}
// 获取表基础元信息, 如表名等
func (pd *PgsqlDialect) GetTables() ([]Table, error) {
_, res, err := pd.dc.Query(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
func (pd *PgsqlDialect) GetTables() ([]dbi.Table, error) {
_, res, err := pd.dc.Query(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
if err != nil {
return nil, err
}
tables := make([]Table, 0)
tables := make([]dbi.Table, 0)
for _, re := range res {
tables = append(tables, Table{
tables = append(tables, dbi.Table{
TableName: re["tableName"].(string),
TableComment: anyx.ConvString(re["tableComment"]),
CreateTime: anyx.ConvString(re["createTime"]),
@@ -154,7 +72,7 @@ func (pd *PgsqlDialect) GetTables() ([]Table, error) {
}
// 获取列元信息, 如列名等
func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
tableName := ""
for i := 0; i < len(tableNames); i++ {
if i != 0 {
@@ -163,14 +81,14 @@ func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]Column, error) {
tableName = tableName + "'" + tableNames[i] + "'"
}
_, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
if err != nil {
return nil, err
}
columns := make([]Column, 0)
columns := make([]dbi.Column, 0)
for _, re := range res {
columns = append(columns, Column{
columns = append(columns, dbi.Column{
TableName: re["tableName"].(string),
ColumnName: re["columnName"].(string),
ColumnType: anyx.ConvString(re["columnType"]),
@@ -202,15 +120,15 @@ func (pd *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) {
}
// 获取表索引信息
func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) {
_, res, err := pd.dc.Query(fmt.Sprintf(GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
if err != nil {
return nil, err
}
indexs := make([]Index, 0)
indexs := make([]dbi.Index, 0)
for _, re := range res {
indexs = append(indexs, Index{
indexs = append(indexs, dbi.Index{
IndexName: re["indexName"].(string),
ColumnName: anyx.ConvString(re["columnName"]),
IndexType: anyx.ConvString(re["IndexType"]),
@@ -220,7 +138,7 @@ func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) {
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
result := make([]Index, 0)
result := make([]dbi.Index, 0)
key := ""
for _, v := range indexs {
// 当前的索引名
@@ -240,7 +158,7 @@ func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]Index, error) {
// 获取建表ddl
func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
_, err := pd.dc.Exec(GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
_, err := pd.dc.Exec(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
if err != nil {
return "", err
}
@@ -257,13 +175,13 @@ func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
return res[0]["sql"].(string), nil
}
func (pd *PgsqlDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error {
func (pd *PgsqlDialect) WalkTableRecord(tableName string, walkFn dbi.WalkQueryRowsFunc) error {
return pd.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
}
// 获取pgsql当前连接的库可访问的schemaNames
func (pd *PgsqlDialect) GetSchemas() ([]string, error) {
sql := GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS)
sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS)
_, res, err := pd.dc.Query(sql)
if err != nil {
return nil, err
@@ -276,27 +194,27 @@ func (pd *PgsqlDialect) GetSchemas() ([]string, error) {
}
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
func (pd *PgsqlDialect) GetDbProgram() DbProgram {
func (pd *PgsqlDialect) GetDbProgram() dbi.DbProgram {
panic("implement me")
}
func (pd *PgsqlDialect) GetDataType(dbColumnType string) DataType {
func (pd *PgsqlDialect) GetDataType(dbColumnType string) dbi.DataType {
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
return DataTypeNumber
return dbi.DataTypeNumber
}
// 日期时间类型
if regexp.MustCompile(`(?i)datetime|timestamp`).MatchString(dbColumnType) {
return DataTypeDateTime
return dbi.DataTypeDateTime
}
// 日期类型
if regexp.MustCompile(`(?i)date`).MatchString(dbColumnType) {
return DataTypeDate
return dbi.DataTypeDate
}
// 时间类型
if regexp.MustCompile(`(?i)time`).MatchString(dbColumnType) {
return DataTypeTime
return dbi.DataTypeTime
}
return DataTypeString
return dbi.DataTypeString
}
func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
@@ -325,15 +243,15 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
return pd.dc.TxExec(tx, sqlStr, args...)
}
func (pd *PgsqlDialect) FormatStrData(dbColumnValue string, dataType DataType) string {
func (pd *PgsqlDialect) FormatStrData(dbColumnValue string, dataType dbi.DataType) string {
switch dataType {
case DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00"
case dbi.DataTypeDateTime: // "2024-01-02T22:16:28.545377+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue)
return res.Format(time.DateTime)
case DataTypeDate: // "2024-01-02T00:00:00Z"
case dbi.DataTypeDate: // "2024-01-02T00:00:00Z"
res, _ := time.Parse(time.RFC3339, dbColumnValue)
return res.Format(time.DateOnly)
case DataTypeTime: // "0000-01-01T22:16:28.545075+08:00"
case dbi.DataTypeTime: // "0000-01-01T22:16:28.545075+08:00"
res, _ := time.Parse(time.RFC3339, dbColumnValue)
return res.Format(time.TimeOnly)
}

View File

@@ -0,0 +1,111 @@
package postgres
import (
"database/sql"
"database/sql/driver"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
machineapp "mayfly-go/internal/machine/application"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/netx"
"net"
"strings"
"sync"
"time"
pq "gitee.com/liuzongyang/libpq"
)
var (
meta dbi.Meta
once sync.Once
)
func GetMeta() dbi.Meta {
once.Do(func() {
meta = new(PostgresMeta)
})
return meta
}
type PostgresMeta struct {
}
func (md *PostgresMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
driverName := string(d.Type)
// SSH Conect
if d.SshTunnelMachineId > 0 {
// 如果使用了隧道,则使用`postgres:ssh:隧道机器id`注册名
driverName = fmt.Sprintf("postgres:ssh:%d", d.SshTunnelMachineId)
if !collx.ArrayContains(sql.Drivers(), driverName) {
sql.Register(driverName, &PqSqlDialer{sshTunnelMachineId: d.SshTunnelMachineId})
}
sql.Drivers()
}
db := d.Database
var dbParam string
exsitSchema := false
if db != "" {
// postgres database可以使用db/schema表示方便连接指定schema, 若不存在schema则使用默认schema
ss := strings.Split(db, "/")
if len(ss) > 1 {
exsitSchema = true
dbParam = fmt.Sprintf("dbname=%s search_path=%s", ss[0], ss[len(ss)-1])
} else {
dbParam = "dbname=" + db
}
}
dsn := fmt.Sprintf("host=%s port=%d user=%s password=%s %s sslmode=disable connect_timeout=8", d.Host, d.Port, d.Username, d.Password, dbParam)
// 存在额外指定参数,则拼接该连接参数
if d.Params != "" {
// 存在指定的db则需要将dbInstance配置中的parmas排除掉dbname和search_path
if db != "" {
paramArr := strings.Split(d.Params, "&")
paramArr = collx.ArrayRemoveFunc(paramArr, func(param string) bool {
if strings.HasPrefix(param, "dbname=") {
return true
}
if exsitSchema && strings.HasPrefix(param, "search_path") {
return true
}
return false
})
d.Params = strings.Join(paramArr, " ")
}
dsn = fmt.Sprintf("%s %s", dsn, strings.Join(strings.Split(d.Params, "&"), " "))
}
return sql.Open(driverName, dsn)
}
func (md *PostgresMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
return &PgsqlDialect{conn}
}
// pgsql dialer
type PqSqlDialer struct {
sshTunnelMachineId int
}
func (d *PqSqlDialer) Open(name string) (driver.Conn, error) {
return pq.DialOpen(d, name)
}
func (pd *PqSqlDialer) Dial(network, address string) (net.Conn, error) {
sshTunnel, err := machineapp.GetMachineApp().GetSshTunnelMachine(pd.sshTunnelMachineId)
if err != nil {
return nil, err
}
if sshConn, err := sshTunnel.GetDialConn("tcp", address); err == nil {
// 将ssh conn包装否则会返回错误: ssh: tcpChan: deadline not supported
return &netx.WrapSshConn{Conn: sshConn}, nil
} else {
return nil, err
}
}
func (pd *PqSqlDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
return pd.Dial(network, address)
}

View File

@@ -90,7 +90,6 @@ func (m *MachineFile) ReadFileContent(rc *req.Ctx) {
g := rc.GinCtx
fid := GetMachineFileId(g)
readPath := g.Query("path")
readType := g.Query("type")
sftpFile, mi, err := m.MachineFileApp.ReadFile(fid, readPath)
rc.ReqParam = collx.Kvs("machine", mi, "path", readPath)
@@ -99,20 +98,27 @@ func (m *MachineFile) ReadFileContent(rc *req.Ctx) {
fileInfo, _ := sftpFile.Stat()
filesize := fileInfo.Size()
// 如果是读取文件内容,则校验文件大小
if readType != "1" {
biz.IsTrue(filesize < max_read_size, "文件超过1m请使用下载查看")
}
// 如果读取类型为下载,则下载文件,否则获取文件内容
if readType == "1" {
// 截取文件名,如/usr/local/test.java -》 test.java
path := strings.Split(readPath, "/")
rc.Download(sftpFile, path[len(path)-1])
} else {
datas, err := io.ReadAll(sftpFile)
biz.ErrIsNilAppendErr(err, "读取文件内容失败: %s")
rc.ResData = string(datas)
}
biz.IsTrue(filesize < max_read_size, "文件超过1m请使用下载查看")
datas, err := io.ReadAll(sftpFile)
biz.ErrIsNilAppendErr(err, "读取文件内容失败: %s")
rc.ResData = string(datas)
}
func (m *MachineFile) DownloadFile(rc *req.Ctx) {
g := rc.GinCtx
fid := GetMachineFileId(g)
readPath := g.Query("path")
sftpFile, mi, err := m.MachineFileApp.ReadFile(fid, readPath)
rc.ReqParam = collx.Kvs("machine", mi, "path", readPath)
biz.ErrIsNilAppendErr(err, "打开文件失败: %s")
defer sftpFile.Close()
// 截取文件名,如/usr/local/test.java -》 test.java
path := strings.Split(readPath, "/")
rc.Download(sftpFile, path[len(path)-1])
}
func (m *MachineFile) GetDirEntry(rc *req.Ctx) {

View File

@@ -25,7 +25,9 @@ func InitMachineFileRouter(router *gin.RouterGroup) {
req.NewDelete(":machineId/files/:fileId", mf.DeleteFile).Log(req.NewLogSave("机器-删除文件配置")).RequiredPermissionCode("machine:file:del"),
req.NewGet(":machineId/files/:fileId/read", mf.ReadFileContent).Log(req.NewLogSave("机器-取文件内容")),
req.NewGet(":machineId/files/:fileId/read", mf.ReadFileContent).Log(req.NewLogSave("机器-取文件内容")),
req.NewGet(":machineId/files/:fileId/download", mf.DownloadFile).NoRes().Log(req.NewLogSave("机器-文件下载")),
req.NewGet(":machineId/files/:fileId/read-dir", mf.GetDirEntry),