diff --git a/server/internal/db/application/db_data_sync.go b/server/internal/db/application/db_data_sync.go index a53606a4..171f0f73 100644 --- a/server/internal/db/application/db_data_sync.go +++ b/server/internal/db/application/db_data_sync.go @@ -175,8 +175,6 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* }() srcDialect := srcConn.GetDialect() - // 记录更新字段最新值 - targetDialect := targetConn.GetDialect() // task.FieldMap为json数组字符串 [{"src":"id","target":"id"}],转为map var fieldMap []map[string]string @@ -209,7 +207,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* total++ result = append(result, row) if total%batchSize == 0 { - if err := app.srcData2TargetDb(result, fieldMap, updFieldType, task, srcDialect, targetDialect, targetDbTx); err != nil { + if err := app.srcData2TargetDb(result, fieldMap, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil { return err } @@ -231,7 +229,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* // 处理剩余的数据 if len(result) > 0 { - if err := app.srcData2TargetDb(result, fieldMap, updFieldType, task, srcDialect, targetDialect, targetDbTx); err != nil { + if err := app.srcData2TargetDb(result, fieldMap, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil { targetDbTx.Rollback() return syncLog, err } @@ -251,7 +249,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (* return syncLog, nil } -func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbm.DataType, task *entity.DataSyncTask, srcDialect dbm.DbDialect, targetDialect dbm.DbDialect, targetDbTx *sql.Tx) error { +func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, updFieldType dbm.DataType, task *entity.DataSyncTask, srcDialect dbm.DbDialect, targetDbConn *dbm.DbConn, targetDbTx *sql.Tx) error { var data = make([]map[string]any, 0) // 遍历res,组装插入sql @@ -279,7 +277,7 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ for _, item := range fieldMap { targetField := item["target"] srcField := item["target"] - targetWrapColumns = append(targetWrapColumns, targetDialect.WrapName(targetField)) + targetWrapColumns = append(targetWrapColumns, targetDbConn.Info.Type.WrapName(targetField)) srcColumns = append(srcColumns, srcField) } @@ -294,7 +292,7 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ } // 目标数据库执行sql批量插入 - _, err := targetDialect.BatchInsert(targetDbTx, task.TargetTableName, targetWrapColumns, values) + _, err := targetDbConn.GetDialect().BatchInsert(targetDbTx, task.TargetTableName, targetWrapColumns, values) if err != nil { return err } diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index 77ff23b1..ced35eb6 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -2,7 +2,6 @@ package application import ( "context" - "encoding/json" "fmt" "mayfly-go/internal/db/config" "mayfly-go/internal/db/dbm" @@ -11,6 +10,7 @@ import ( "mayfly-go/pkg/contextx" "mayfly-go/pkg/errorx" "mayfly-go/pkg/model" + "mayfly-go/pkg/utils/jsonx" "strconv" "strings" @@ -226,8 +226,7 @@ func doUpdate(ctx context.Context, update *sqlparser.Update, execSqlReq *DbSqlEx selectSql := fmt.Sprintf("SELECT %s FROM %s %s LIMIT 200", updateColumnsAndPrimaryKey, tableStr, where) _, res, err := dbConn.QueryContext(ctx, selectSql) if err == nil { - bytes, _ := json.Marshal(res) - dbSqlExec.OldValue = string(bytes) + dbSqlExec.OldValue = jsonx.ToStr(res) } else { dbSqlExec.OldValue = err.Error() } @@ -253,8 +252,7 @@ func doDelete(ctx context.Context, delete *sqlparser.Delete, execSqlReq *DbSqlEx selectSql := fmt.Sprintf("SELECT * FROM %s %s LIMIT 200", tableStr, where) _, res, _ := dbConn.QueryContext(ctx, selectSql) - bytes, _ := json.Marshal(res) - dbSqlExec.OldValue = string(bytes) + dbSqlExec.OldValue = jsonx.ToStr(res) dbSqlExec.Table = table dbSqlExec.Type = entity.DbSqlExecTypeDelete diff --git a/server/internal/db/dbm/db_type.go b/server/internal/db/dbm/db_type.go index 1aa1cf96..f31d8079 100644 --- a/server/internal/db/dbm/db_type.go +++ b/server/internal/db/dbm/db_type.go @@ -25,6 +25,17 @@ func (dbType DbType) Equal(typ string) bool { return ToDbType(typ) == dbType } +func (dbType DbType) QuoteIdentifier(name string) string { + switch dbType { + case DbTypeMysql, DbTypeMariadb: + return quoteIdentifier(name, "`") + case DbTypePostgres: + return pq.QuoteIdentifier(name) + default: + panic(fmt.Sprintf("invalid database type: %s", dbType)) + } +} + func (dbType DbType) MetaDbName() string { switch dbType { case DbTypeMysql, DbTypeMariadb: @@ -38,14 +49,13 @@ func (dbType DbType) MetaDbName() string { } } -func (dbType DbType) QuoteIdentifier(name string) string { +// 包装字段名,防止使用了数据库保留关键字 +func (dbType DbType) WrapName(name string) string { switch dbType { case DbTypeMysql, DbTypeMariadb: - return quoteIdentifier(name, "`") - case DbTypePostgres: - return pq.QuoteIdentifier(name) + return fmt.Sprintf("`%s`", name) default: - panic(fmt.Sprintf("invalid database type: %s", dbType)) + return fmt.Sprintf(`"%s"`, name) } } diff --git a/server/internal/db/dbm/dialect.go b/server/internal/db/dbm/dialect.go index 657020f9..020c16f9 100644 --- a/server/internal/db/dbm/dialect.go +++ b/server/internal/db/dbm/dialect.go @@ -90,9 +90,6 @@ type DbDialect interface { // GetDbProgram 获取数据库程序模块,用于数据库备份与恢复 GetDbProgram() DbProgram - // 封装名字,如,mysql: `table_name`, dm: "table_name" - WrapName(name string) string - // 批量保存数据 BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) diff --git a/server/internal/db/dbm/dialect_dm.go b/server/internal/db/dbm/dialect_dm.go index eb044fe3..998d66d9 100644 --- a/server/internal/db/dbm/dialect_dm.go +++ b/server/internal/db/dbm/dialect_dm.go @@ -278,10 +278,6 @@ func (dd *DMDialect) GetDbProgram() DbProgram { panic("implement me") } -func (dd *DMDialect) WrapName(name string) string { - return "\"" + name + "\"" -} - func (dd *DMDialect) GetDataType(dbColumnType string) DataType { if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { return DataTypeNumber @@ -311,7 +307,7 @@ func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, // 去除最后一个逗号,占位符由括号包裹 placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ",")) - sqlTemp := fmt.Sprintf("insert into %s (%s) values %s", dd.WrapName(tableName), strings.Join(columns, ","), placeholder) + sqlTemp := fmt.Sprintf("insert into %s (%s) values %s", dd.dc.Info.Type.WrapName(tableName), strings.Join(columns, ","), placeholder) effRows := 0 for _, value := range values { // 达梦数据库只能一条条的执行insert diff --git a/server/internal/db/dbm/dialect_mysql.go b/server/internal/db/dbm/dialect_mysql.go index ffaf8200..56a6d2fe 100644 --- a/server/internal/db/dbm/dialect_mysql.go +++ b/server/internal/db/dbm/dialect_mysql.go @@ -202,10 +202,6 @@ func (md *MysqlDialect) GetDbProgram() DbProgram { return NewDbProgramMysql(md.dc) } -func (md *MysqlDialect) WrapName(name string) string { - return "`" + name + "`" -} - func (md *MysqlDialect) GetDataType(dbColumnType string) DataType { if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { return DataTypeNumber @@ -240,7 +236,7 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri // 去除最后一个逗号 placeholder = strings.TrimSuffix(repeated, ",") - sqlStr := fmt.Sprintf("insert into %s (%s) values %s", md.WrapName(tableName), strings.Join(columns, ","), placeholder) + sqlStr := fmt.Sprintf("insert into %s (%s) values %s", md.dc.Info.Type.WrapName(tableName), strings.Join(columns, ","), placeholder) // 执行批量insert sql // 把二维数组转为一维数组 var args []any diff --git a/server/internal/db/dbm/dialect_pgsql.go b/server/internal/db/dbm/dialect_pgsql.go index ca8abcc4..8d19197a 100644 --- a/server/internal/db/dbm/dialect_pgsql.go +++ b/server/internal/db/dbm/dialect_pgsql.go @@ -280,10 +280,6 @@ func (pd *PgsqlDialect) GetDbProgram() DbProgram { panic("implement me") } -func (pd *PgsqlDialect) WrapName(name string) string { - return fmt.Sprintf(`"%s"`, name) -} - func (pd *PgsqlDialect) GetDataType(dbColumnType string) DataType { if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) { return DataTypeNumber @@ -323,7 +319,7 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri placeholders = append(placeholders, "("+strings.Join(placeholder, ", ")+")") } - sqlStr := fmt.Sprintf("insert into %s (%s) values %s", pd.WrapName(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", ")) + sqlStr := fmt.Sprintf("insert into %s (%s) values %s", pd.dc.Info.Type.WrapName(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", ")) // 执行批量insert sql return pd.dc.TxExec(tx, sqlStr, args...) diff --git a/server/internal/machine/domain/entity/machine_cronjob.go b/server/internal/machine/domain/entity/machine_cronjob.go index c14cd49d..6a028996 100644 --- a/server/internal/machine/domain/entity/machine_cronjob.go +++ b/server/internal/machine/domain/entity/machine_cronjob.go @@ -21,20 +21,10 @@ type MachineCronJob struct { // 计划任务与机器关联信息 type MachineCronJobRelate struct { - model.DeletedModel + model.CreateModel - CronJobId uint64 - MachineId uint64 - Creator string - CreatorId uint64 - CreateTime *time.Time -} - -func (m *MachineCronJobRelate) SetBaseInfo(gt model.IdGenType, la *model.LoginAccount) { - now := time.Now() - m.CreateTime = &now - m.Creator = la.Username - m.CreatorId = la.Id + CronJobId uint64 + MachineId uint64 } // 机器任务执行记录 diff --git a/server/internal/sys/domain/entity/resource.go b/server/internal/sys/domain/entity/resource.go index e1babe2e..66748d0a 100644 --- a/server/internal/sys/domain/entity/resource.go +++ b/server/internal/sys/domain/entity/resource.go @@ -18,9 +18,9 @@ func (a *Resource) TableName() string { return "t_sys_resource" } -func (m *Resource) SetBaseInfo(idGenType model.IdGenType, la *model.LoginAccount) { +func (m *Resource) FillBaseInfo(idGenType model.IdGenType, la *model.LoginAccount) { // id使用时间戳,减少id冲突概率 - m.Model.SetBaseInfo(model.IdGenTypeTimestamp, la) + m.Model.FillBaseInfo(model.IdGenTypeTimestamp, la) } const ( diff --git a/server/pkg/base/repo.go b/server/pkg/base/repo.go index b7b3a9ff..feb6ee51 100644 --- a/server/pkg/base/repo.go +++ b/server/pkg/base/repo.go @@ -199,7 +199,7 @@ func (br *RepoImpl[T]) GetModel() T { // 从上下文获取登录账号信息,并赋值至实体 func (br *RepoImpl[T]) fillBaseInfo(ctx context.Context, e T) T { if la := contextx.GetLoginAccount(ctx); la != nil { - // 默认使用数据库id策略, 若要改变则实体结构体自行覆盖SetBaseInfo方法。可参考 sys/entity.Resource + // 默认使用数据库id策略, 若要改变则实体结构体自行覆盖FillBaseInfo方法。可参考 sys/entity.Resource e.FillBaseInfo(model.IdGenTypeNone, la) } return e diff --git a/server/pkg/model/model.go b/server/pkg/model/model.go index dcce3ca2..9995bd13 100644 --- a/server/pkg/model/model.go +++ b/server/pkg/model/model.go @@ -68,7 +68,7 @@ type CreateModel struct { Creator string `json:"creator"` } -func (m *CreateModel) SetBaseInfo(idGenType IdGenType, account *LoginAccount) { +func (m *CreateModel) FillBaseInfo(idGenType IdGenType, account *LoginAccount) { if !m.IsCreate() { return } @@ -95,7 +95,7 @@ type Model struct { } // 设置基础信息. 如创建时间,修改时间,创建者,修改者信息 -func (m *Model) SetBaseInfo(idGenType IdGenType, account *LoginAccount) { +func (m *Model) FillBaseInfo(idGenType IdGenType, account *LoginAccount) { nowTime := time.Now() isCreate := m.IsCreate() if isCreate {