!106 feat:数据同步支持唯一键冲突策略

* refactor:sql同步
* fix: 表格右键导出菜单换行符修复
* feat:数据同步支持唯一键冲突策略
This commit is contained in:
zongyangleo
2024-03-01 04:03:03 +00:00
committed by Coder慌
parent f93231da61
commit 76475e807e
38 changed files with 1117 additions and 257 deletions

View File

@@ -15,11 +15,12 @@ type DataSyncTaskForm struct {
UpdField string `binding:"required" json:"updField"`
UpdFieldVal string `binding:"required" json:"updFieldVal"`
TargetDbId int64 `binding:"required" json:"targetDbId"`
TargetDbName string `binding:"required" json:"targetDbName"`
TargetTagPath string `binding:"required" json:"targetTagPath"`
TargetTableName string `binding:"required" json:"targetTableName"`
FieldMap string `binding:"required" json:"fieldMap"`
TargetDbId int64 `binding:"required" json:"targetDbId"`
TargetDbName string `binding:"required" json:"targetDbName"`
TargetTagPath string `binding:"required" json:"targetTagPath"`
TargetTableName string `binding:"required" json:"targetTableName"`
FieldMap string `binding:"required" json:"fieldMap"`
DuplicateStrategy int `json:"duplicateStrategy"`
}
type DataSyncTaskStatusForm struct {

View File

@@ -53,6 +53,7 @@ type dataSyncAppImpl struct {
var (
dateTimeReg = regexp.MustCompile(`^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$`)
whereReg = regexp.MustCompile(`(?i)where`)
)
func (app *dataSyncAppImpl) InjectDbDataSyncTaskRepo(repo repository.DataSyncTask) {
@@ -143,11 +144,15 @@ func (app *dataSyncAppImpl) RunCronJob(id uint64) error {
updSql := ""
orderSql := ""
if task.UpdFieldVal != "0" && task.UpdFieldVal != "" && task.UpdField != "" {
srcConn, _ := app.dbApp.GetDbConn(uint64(task.SrcDbId), task.SrcDbName)
srcConn, err := app.dbApp.GetDbConn(uint64(task.SrcDbId), task.SrcDbName)
if err != nil {
logx.Errorf("数据源连接不可用, %s", err.Error())
return
}
task.UpdFieldVal = strings.Trim(task.UpdFieldVal, " ")
// 把UpdFieldVal尝试转为int如果可以转为int则不添加引号否则添加引号
if _, err := strconv.Atoi(task.UpdFieldVal); err != nil {
if _, err = strconv.Atoi(task.UpdFieldVal); err != nil {
updSql = fmt.Sprintf("and %s > '%s'", task.UpdField, task.UpdFieldVal)
} else {
updSql = fmt.Sprintf("and %s > %s", task.UpdField, task.UpdFieldVal)
@@ -162,8 +167,14 @@ func (app *dataSyncAppImpl) RunCronJob(id uint64) error {
}
orderSql = "order by " + task.UpdField + " asc "
}
// 正则判断DataSql是否以where .*结尾如果是则不添加where 1 = 1
var where = "where 1=1"
if whereReg.MatchString(task.DataSql) {
where = ""
}
// 组装查询sql
sql := fmt.Sprintf("select * from (%s) t where 1 = 1 %s %s", task.DataSql, updSql, orderSql)
sql := fmt.Sprintf("%s %s %s %s", task.DataSql, where, updSql, orderSql)
log, err := app.doDataSync(sql, task)
if err != nil {
@@ -223,6 +234,12 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
result := make([]map[string]any, 0)
var queryColumns []*dbi.QueryColumn
// 如果有数据库别名则从UpdField中去掉数据库别名, 如a.id => id用于获取字段具体名称
updFieldName := task.UpdField
if task.UpdField != "" && strings.Contains(task.UpdField, ".") {
updFieldName = strings.Split(task.UpdField, ".")[1]
}
err = srcConn.WalkQueryRows(context.Background(), sql, func(row map[string]any, columns []*dbi.QueryColumn) error {
if len(queryColumns) == 0 {
queryColumns = columns
@@ -230,7 +247,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
// 遍历columns 取task.UpdField的字段类型
updFieldType = dbi.DataTypeString
for _, column := range columns {
if strings.EqualFold(strings.ToLower(column.Name), strings.ToLower(task.UpdField)) {
if strings.EqualFold(column.Name, updFieldName) {
updFieldType = srcDialect.GetDataConverter().GetDataType(column.Type)
break
}
@@ -240,7 +257,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
total++
result = append(result, row)
if total%batchSize == 0 {
if err := app.srcData2TargetDb(result, fieldMap, columns, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil {
if err := app.srcData2TargetDb(result, fieldMap, columns, updFieldType, updFieldName, task, srcDialect, targetConn, targetDbTx); err != nil {
return err
}
@@ -262,15 +279,19 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
// 处理剩余的数据
if len(result) > 0 {
if err := app.srcData2TargetDb(result, fieldMap, queryColumns, updFieldType, task, srcDialect, targetConn, targetDbTx); err != nil {
if err := app.srcData2TargetDb(result, fieldMap, queryColumns, updFieldType, updFieldName, task, srcDialect, targetConn, targetDbTx); err != nil {
targetDbTx.Rollback()
return syncLog, err
}
}
// 如果是mssql暂不手动提交事务否则报错 mssql: The COMMIT TRANSACTION request has no corresponding BEGIN TRANSACTION.
if err := targetDbTx.Commit(); err != nil {
return syncLog, errorx.NewBiz("数据同步-目标数据库事务提交失败: %s", err.Error())
if targetConn.Info.Type != dbi.DbTypeMssql {
return syncLog, errorx.NewBiz("数据同步-目标数据库事务提交失败: %s", err.Error())
}
}
logx.Infof("同步任务:[%s],执行完毕,保存记录成功:[%d]条", task.TaskName, total)
// 保存执行成功日志
@@ -282,7 +303,7 @@ func (app *dataSyncAppImpl) doDataSync(sql string, task *entity.DataSyncTask) (*
return syncLog, nil
}
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, columns []*dbi.QueryColumn, updFieldType dbi.DataType, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error {
func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap []map[string]string, columns []*dbi.QueryColumn, updFieldType dbi.DataType, updFieldName string, task *entity.DataSyncTask, srcDialect dbi.Dialect, targetDbConn *dbi.DbConn, targetDbTx *sql.Tx) error {
// 遍历src字段列表取出字段对应的类型
var srcColumnTypes = make(map[string]string)
@@ -305,9 +326,9 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
data = append(data, rowData)
}
// 解决字段大小写问题
updFieldVal := srcRes[len(srcRes)-1][strings.ToUpper(task.UpdField)]
updFieldVal := srcRes[len(srcRes)-1][strings.ToUpper(updFieldName)]
if updFieldVal == "" || updFieldVal == nil {
updFieldVal = srcRes[len(srcRes)-1][strings.ToLower(task.UpdField)]
updFieldVal = srcRes[len(srcRes)-1][strings.ToLower(updFieldName)]
}
task.UpdFieldVal = srcDialect.GetDataConverter().FormatData(updFieldVal, updFieldType)
@@ -338,7 +359,7 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
}
// 目标数据库执行sql批量插入
_, err := targetDbConn.GetDialect().BatchInsert(targetDbTx, task.TargetTableName, targetWrapColumns, values)
_, err := targetDbConn.GetDialect().BatchInsert(targetDbTx, task.TargetTableName, targetWrapColumns, values, task.DuplicateStrategy)
if err != nil {
return err
}

View File

@@ -61,6 +61,8 @@ func (dbType DbType) RemoveQuote(name string) string {
return removeQuote(name, "`")
case DbTypePostgres, DbTypeGauss, DbTypeKingbaseEs, DbTypeVastbase:
return removeQuote(name, `"`)
case DbTypeMssql:
return strings.Trim(name, "[]")
default:
return removeQuote(name, `"`)
}

View File

@@ -20,6 +20,15 @@ const (
DataTypeDateTime DataType = "datetime"
)
const (
// -1. 无操作
DuplicateStrategyNone = -1
// 1. 忽略
DuplicateStrategyIgnore = 1
// 2. 更新
DuplicateStrategyUpdate = 2
)
// 数据库服务实例信息
type DbServer struct {
Version string `json:"version"` // 版本信息
@@ -111,7 +120,7 @@ type Dialect interface {
GetDbProgram() (DbProgram, error)
// 批量保存数据
BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error)
BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error)
// 获取数据转换器用于解析格式化列数据等
GetDataConverter() DataConverter

View File

@@ -3,7 +3,7 @@ package dbi
import "database/sql"
var (
metas map[DbType]Meta = make(map[DbType]Meta)
metas = make(map[DbType]Meta)
)
// 注册数据库类型与dbmeta

View File

@@ -261,35 +261,114 @@ var (
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
converter = new(DataConverter)
)
func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
// 执行批量insert sql
// insert into "table_name" ("column1", "column2", ...) values (value1, value2, ...)
dbType := dd.dc.Info.Type
// 无需处理重复数据直接执行批量insert
if duplicateStrategy == dbi.DuplicateStrategyNone || duplicateStrategy == 0 {
return dd.batchInsertSimple(dbType, tx, tableName, columns, values)
} else { // 执行MERGE INTO语句
return dd.batchInsertMerge(dbType, tx, tableName, columns, values)
}
}
func (dd *DMDialect) batchInsertSimple(dbType dbi.DbType, tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
// 生成占位符字符串:如:(?,?)
// 重复字符串并用逗号连接
repeated := strings.Repeat("?,", len(columns))
// 去除最后一个逗号,占位符由括号包裹
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
sqlTemp := fmt.Sprintf("insert into %s (%s) values %s", dd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
sqlTemp := fmt.Sprintf("insert into %s (%s) values %s", dbType.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
effRows := 0
for _, value := range values {
// 达梦数据库只能一条条的执行insert
er, err := dd.dc.TxExec(tx, sqlTemp, value...)
res, err := dd.dc.TxExec(tx, sqlTemp, value...)
if err != nil {
logx.Errorf("执行sql失败%s", err.Error())
return int64(effRows), err
logx.Errorf("执行sql失败%s, sql: [ %s ]", err.Error(), sqlTemp)
}
effRows += int(er)
effRows += int(res)
}
// 执行批量insert sql
return int64(effRows), nil
}
func (dd *DMDialect) batchInsertMerge(dbType dbi.DbType, tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
// 查询主键字段
uniqueCols := make([]string, 0)
caseSqls := make([]string, 0)
tableCols, _ := dd.GetColumns(tableName)
identityCols := make([]string, 0)
for _, col := range tableCols {
if col.IsPrimaryKey {
uniqueCols = append(uniqueCols, col.ColumnName)
caseSqls = append(caseSqls, fmt.Sprintf("( T1.%s = T2.%s )", dbType.QuoteIdentifier(col.ColumnName), dbType.QuoteIdentifier(col.ColumnName)))
}
if col.IsIdentity {
// 自增字段不放入insert内即使是设置了identity_insert on也不起作用
identityCols = append(identityCols, dbType.QuoteIdentifier(col.ColumnName))
}
}
// 查询唯一索引涉及到的字段并组装到match条件内
indexs, _ := dd.GetTableIndex(tableName)
if indexs != nil {
for _, index := range indexs {
if index.IsUnique {
cols := strings.Split(index.ColumnName, ",")
tmp := make([]string, 0)
for _, col := range cols {
uniqueCols = append(uniqueCols, col)
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", dbType.QuoteIdentifier(col), dbType.QuoteIdentifier(col)))
}
caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
}
}
}
// 重复数据处理策略
phs := make([]string, 0)
insertVals := make([]string, 0)
upds := make([]string, 0)
insertCols := make([]string, 0)
for _, column := range columns {
phs = append(phs, fmt.Sprintf("? %s", column))
if !collx.ArrayContains(uniqueCols, dbType.RemoveQuote(column)) {
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
}
if !collx.ArrayContains(identityCols, column) {
insertCols = append(insertCols, fmt.Sprintf("%s", column))
insertVals = append(insertVals, fmt.Sprintf("T2.%s", column))
}
}
t2s := make([]string, 0)
for i := 0; i < len(values); i++ {
t2s = append(t2s, fmt.Sprintf("SELECT %s FROM dual", strings.Join(phs, ",")))
}
t2 := strings.Join(t2s, " UNION ALL ")
sqlTemp := "MERGE INTO " + dbType.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " OR ")
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ")"
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
// 把二维数组转为一维数组
var args []any
for _, v := range values {
args = append(args, v...)
}
return dd.dc.TxExec(tx, sqlTemp, args...)
}
func (dd *DMDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
return converter
}
type DataConverter struct {
@@ -328,6 +407,22 @@ func (dd *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st
}
func (dd *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// 如果dataType是datetime而dbColumnValue是string类型则需要转换为time.Time类型
_, ok := dbColumnValue.(string)
if ok {
if dataType == dbi.DataTypeDateTime {
res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeDate {
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeTime {
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
return res
}
}
return dbColumnValue
}

View File

@@ -135,13 +135,11 @@ func (md *MssqlDialect) GetPrimaryKey(tablename string) (string, error) {
return columns[0].ColumnName, nil
}
// 获取表索引信息
func (md *MssqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
func (md *MssqlDialect) getTableIndexWithPK(tableName string) ([]dbi.Index, error) {
_, res, err := md.dc.Query(dbi.GetLocalSql(MSSQL_META_FILE, MSSQL_INDEX_INFO_KEY), md.currentSchema(), tableName)
if err != nil {
return nil, err
}
indexs := make([]dbi.Index, 0)
for _, re := range res {
indexs = append(indexs, dbi.Index{
@@ -153,18 +151,14 @@ func (md *MssqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
SeqInIndex: anyx.ConvInt(re["seqInIndex"]),
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
// 把查询结果以索引名分组,多个索引字段以逗号连接
result := make([]dbi.Index, 0)
key := ""
for _, v := range indexs {
// 当前的索引名
in := v.IndexName
// 过滤掉主键索引,主键索引名为PK__开头的
if strings.HasPrefix(in, "PK__") {
continue
}
if key == in {
// 索引字段已根据名称和顺序排序,故取最后一个即可
// 索引字段已根据名称和字段顺序排序,故取最后一个即可
i := len(result) - 1
// 同索引字段以逗号连接
result[i].ColumnName = result[i].ColumnName + "," + v.ColumnName
@@ -173,6 +167,20 @@ func (md *MssqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
result = append(result, v)
}
}
return indexs, nil
}
// 获取表索引信息
func (md *MssqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
indexs, _ := md.getTableIndexWithPK(tableName)
result := make([]dbi.Index, 0)
// 过滤掉主键索引,主键索引名为PK__开头的
for _, v := range indexs {
in := v.IndexName
if strings.HasPrefix(in, "PK__") {
continue
}
}
return result, nil
}
@@ -289,8 +297,39 @@ func (md *MssqlDialect) GetDbProgram() (dbi.DbProgram, error) {
return nil, fmt.Errorf("该数据库类型不支持数据库备份与恢复: %v", md.dc.Info.Type)
}
func (md *MssqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
func (md *MssqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
if duplicateStrategy == dbi.DuplicateStrategyUpdate {
return md.batchInsertMerge(tx, tableName, columns, values, duplicateStrategy)
}
return md.batchInsertSimple(tx, tableName, columns, values, duplicateStrategy)
}
func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
schema := md.currentSchema()
ignoreDupSql := ""
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
// ALTER TABLE dbo.TEST ADD CONSTRAINT uniqueRows UNIQUE (ColA, ColB, ColC, ColD) WITH (IGNORE_DUP_KEY = ON)
indexs, _ := md.getTableIndexWithPK(tableName)
// 收集唯一索引涉及到的字段
uniqueColumns := make([]string, 0)
for _, index := range indexs {
if index.IsUnique {
cols := strings.Split(index.ColumnName, ",")
for _, col := range cols {
if !collx.ArrayContains(uniqueColumns, col) {
uniqueColumns = append(uniqueColumns, col)
}
}
}
}
if len(uniqueColumns) > 0 {
// 设置忽略重复键
ignoreDupSql = fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT uniqueRows UNIQUE (%s) WITH (IGNORE_DUP_KEY = {sign})", schema, tableName, strings.Join(uniqueColumns, ","))
_, _ = md.dc.TxExec(tx, strings.ReplaceAll(ignoreDupSql, "{sign}", "ON"))
}
}
// 生成占位符字符串:如:(?,?)
// 重复字符串并用逗号连接
@@ -312,18 +351,86 @@ func (md *MssqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
for _, v := range values {
args = append(args, v...)
}
// 设置允许填充自增列之后,显示指定列名可以插入自增列
_, _ = md.dc.Exec(fmt.Sprintf("set identity_insert \"%s\" on", tableName))
identityInsertOn := fmt.Sprintf("SET IDENTITY_INSERT [%s].[%s] ON", schema, tableName)
exec, err := md.dc.TxExec(tx, sqlStr, args...)
res, err := md.dc.TxExec(tx, fmt.Sprintf("%s %s", identityInsertOn, sqlStr), args...)
_, _ = md.dc.Exec(fmt.Sprintf("set identity_insert \"%s\" off", tableName))
// 执行完之后,设置忽略重复键
if ignoreDupSql != "" {
_, _ = md.dc.TxExec(tx, strings.ReplaceAll(ignoreDupSql, "{sign}", "OFF"))
}
return res, err
}
return exec, err
func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
schema := md.currentSchema()
dbType := md.dc.Info.Type
// 收集MERGE 语句的 ON 子句条件
caseSqls := make([]string, 0)
pkCols := make([]string, 0)
// 查询取出自增列字段, merge update不能修改自增列
identityCols := make([]string, 0)
cols, err := md.GetColumns(tableName)
for _, col := range cols {
if col.IsIdentity {
identityCols = append(identityCols, col.ColumnName)
}
if col.IsPrimaryKey {
pkCols = append(pkCols, col.ColumnName)
name := dbType.QuoteIdentifier(col.ColumnName)
caseSqls = append(caseSqls, fmt.Sprintf(" T1.%s = T2.%s ", name, name))
}
}
if len(pkCols) == 0 {
return md.batchInsertSimple(tx, tableName, columns, values, duplicateStrategy)
}
// 重复数据处理策略
insertVals := make([]string, 0)
upds := make([]string, 0)
insertCols := make([]string, 0)
// 源数据占位sql
phs := make([]string, 0)
for _, column := range columns {
if !collx.ArrayContains(identityCols, dbType.RemoveQuote(column)) {
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
}
insertCols = append(insertCols, fmt.Sprintf("%s", column))
insertVals = append(insertVals, fmt.Sprintf("T2.%s", column))
phs = append(phs, fmt.Sprintf("? %s", column))
}
// 把二维数组转为一维数组
var args []any
tmp := fmt.Sprintf("select %s", strings.Join(phs, ","))
t2s := make([]string, 0)
for _, v := range values {
args = append(args, v...)
t2s = append(t2s, tmp)
}
t2 := strings.Join(t2s, " UNION ALL ")
sqlTemp := "MERGE INTO " + dbType.QuoteIdentifier(schema) + "." + dbType.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ")
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") "
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
// 设置允许填充自增列之后,显示指定列名可以插入自增列
identityInsertOn := fmt.Sprintf("SET IDENTITY_INSERT [%s].[%s] ON", schema, tableName)
// 执行merge sql,必须要以分号结尾
res, err := md.dc.TxExec(tx, fmt.Sprintf("%s %s;", identityInsertOn, sqlTemp), args...)
if err != nil {
logx.Errorf("执行sql失败%s, sql: [ %s ]", err.Error(), sqlTemp)
}
return res, err
}
func (md *MssqlDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
return converter
}
var (
@@ -335,6 +442,8 @@ var (
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
converter = new(DataConverter)
)
type DataConverter struct {
@@ -364,6 +473,20 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st
}
func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// 如果dataType是datetime而dbColumnValue是string类型则需要转换为time.Time类型
_, ok := dbColumnValue.(string)
if dataType == dbi.DataTypeDateTime && ok {
res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeDate && ok {
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeTime && ok {
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
return res
}
return dbColumnValue
}
@@ -409,12 +532,10 @@ func (md *MssqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
// 复制数据
// 设置允许填充自增列之后,显示指定列名可以插入自增列
identityInsertOn := ""
identityInsertOff := ""
if hasIdentity {
identityInsertOn = fmt.Sprintf("SET IDENTITY_INSERT [%s].[%s] ON", schema, newTableName)
identityInsertOff = fmt.Sprintf("SET IDENTITY_INSERT [%s].[%s] OFF", schema, newTableName)
}
_, err = md.dc.Exec(fmt.Sprintf(" %s INSERT INTO [%s].[%s] (%s) SELECT * FROM [%s].[%s] %s", identityInsertOn, schema, newTableName, columnsSql, schema, copy.TableName, identityInsertOff))
_, err = md.dc.Exec(fmt.Sprintf(" %s INSERT INTO [%s].[%s] (%s) SELECT * FROM [%s].[%s]", identityInsertOn, schema, newTableName, columnsSql, schema, copy.TableName))
if err != nil {
logx.Warnf("复制表[%s]数据失败: %s", copy.TableName, err.Error())
}

View File

@@ -10,14 +10,14 @@ import (
)
func init() {
meta := new(Meta)
meta := new(MssqlMeta)
dbi.Register(dbi.DbTypeMssql, meta)
}
type Meta struct {
type MssqlMeta struct {
}
func (md *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
func (md *MssqlMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
err := d.IfUseSshTunnelChangeIpPort()
if err != nil {
return nil, err
@@ -52,6 +52,6 @@ func (md *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
return sql.Open(driverName, dsn)
}
func (md *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
func (md *MssqlMeta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
return &MssqlDialect{conn}
}

View File

@@ -45,7 +45,6 @@ func (md *MysqlDialect) GetDbNames() ([]string, error) {
for _, re := range res {
databases = append(databases, anyx.ConvString(re["dbname"]))
}
return databases, nil
}
@@ -173,7 +172,7 @@ func (md *MysqlDialect) GetDbProgram() (dbi.DbProgram, error) {
return NewDbProgramMysql(md.dc), nil
}
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
// 生成占位符字符串:如:(?,?)
// 重复字符串并用逗号连接
repeated := strings.Repeat("?,", len(columns))
@@ -188,7 +187,14 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
// 去除最后一个逗号
placeholder = strings.TrimSuffix(repeated, ",")
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", md.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
prefix := "insert into"
if duplicateStrategy == 1 {
prefix = "insert ignore into"
} else if duplicateStrategy == 2 {
prefix = "replace into"
}
sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, md.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
// 执行批量insert sql
// 把二维数组转为一维数组
var args []any
@@ -199,7 +205,7 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
}
func (md *MysqlDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
return converter
}
var (
@@ -211,6 +217,8 @@ var (
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
converter = new(DataConverter)
)
type DataConverter struct {
@@ -240,6 +248,22 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st
}
func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// 如果dataType是datetime而dbColumnValue是string类型则需要转换为time.Time类型
_, ok := dbColumnValue.(string)
if ok {
if dataType == dbi.DataTypeDateTime {
res, _ := time.Parse(time.DateTime, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeDate {
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeTime {
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
return res
}
}
return dbColumnValue
}

View File

@@ -5,6 +5,7 @@ import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/anyx"
"mayfly-go/pkg/utils/collx"
"regexp"
@@ -252,13 +253,7 @@ func (od *OracleDialect) GetDbProgram() (dbi.DbProgram, error) {
return nil, fmt.Errorf("该数据库类型不支持数据库备份与恢复: %v", od.dc.Info.Type)
}
func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
//INSERT ALL
//INTO my_table(field_1,field_2) VALUES (value_1,value_2)
//INTO my_table(field_1,field_2) VALUES (value_3,value_4)
//INTO my_table(field_1,field_2) VALUES (value_5,value_6)
//SELECT 1 FROM DUAL;
func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
if len(values) <= 0 {
return 0, nil
}
@@ -269,27 +264,122 @@ func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str
args = append(args, v...)
}
// 拼接oracle批量插入语句
sqlArr := make([]string, 0)
sqlArr = append(sqlArr, "INSERT ALL")
if duplicateStrategy == dbi.DuplicateStrategyNone || duplicateStrategy == 0 || duplicateStrategy == dbi.DuplicateStrategyIgnore {
return od.batchInsertSimple(od.dc.Info.Type, tableName, columns, values, duplicateStrategy, tx)
} else {
return od.batchInsertMergeSql(od.dc.Info.Type, tableName, columns, values, args, tx)
}
}
// 简单批量插入sql无需判断键冲突策略
func (od *OracleDialect) batchInsertSimple(dbType dbi.DbType, tableName string, columns []string, values [][]any, duplicateStrategy int, tx *sql.Tx) (int64, error) {
// 忽略键冲突策略
ignore := ""
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
// 查出唯一索引涉及的字段
indexs, _ := od.GetTableIndex(tableName)
if indexs != nil {
arr := make([]string, 0)
for _, index := range indexs {
if index.IsUnique {
cols := strings.Split(index.ColumnName, ",")
for _, col := range cols {
if !collx.ArrayContains(arr, col) {
arr = append(arr, col)
}
}
}
}
ignore = fmt.Sprintf("/*+ IGNORE_ROW_ON_DUPKEY_INDEX(%s(%s)) */", tableName, strings.Join(arr, ","))
}
}
effRows := 0
for _, value := range values {
// 拼接带占位符的sql oracle的占位符是:1,:2,:3....
var placeholder []string
for i := 0; i < len(value); i++ {
placeholder = append(placeholder, fmt.Sprintf(":%d", i+1))
}
sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, dbType.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ","))
// oracle数据库为了兼容ignore主键冲突只能一条条的执行insert
res, err := od.dc.TxExec(tx, sqlTemp, value...)
if err != nil {
logx.Errorf("执行sql失败%s, sql: [ %s ]", err.Error(), sqlTemp)
}
effRows += int(res)
}
return int64(effRows), nil
}
func (od *OracleDialect) batchInsertMergeSql(dbType dbi.DbType, tableName string, columns []string, values [][]any, args []any, tx *sql.Tx) (int64, error) {
// 查询主键字段
uniqueCols := make([]string, 0)
caseSqls := make([]string, 0)
// 查询唯一索引涉及到的字段并组装到match条件内
indexs, _ := od.GetTableIndex(tableName)
if indexs != nil {
for _, index := range indexs {
if index.IsUnique {
cols := strings.Split(index.ColumnName, ",")
tmp := make([]string, 0)
for _, col := range cols {
if !collx.ArrayContains(uniqueCols, col) {
uniqueCols = append(uniqueCols, col)
}
tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", dbType.QuoteIdentifier(col), dbType.QuoteIdentifier(col)))
}
caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
}
}
}
// 如果caseSqls为空则说明没有唯一键直接使用简单批量插入
if len(caseSqls) == 0 {
return od.batchInsertSimple(dbType, tableName, columns, values, dbi.DuplicateStrategyNone, tx)
}
// 重复数据处理策略
insertVals := make([]string, 0)
upds := make([]string, 0)
insertCols := make([]string, 0)
for _, column := range columns {
if !collx.ArrayContains(uniqueCols, dbType.RemoveQuote(column)) {
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
}
insertCols = append(insertCols, fmt.Sprintf("T1.%s", column))
insertVals = append(insertVals, fmt.Sprintf("T2.%s", column))
}
// 生成源数据占位sql
t2s := make([]string, 0)
// 拼接带占位符的sql oracle的占位符是:1,:2,:3....
for i := 0; i < len(args); i += len(columns) {
var placeholder []string
for j := 0; j < len(columns); j++ {
placeholder = append(placeholder, fmt.Sprintf(":%d", i+j+1))
col := columns[j]
placeholder = append(placeholder, fmt.Sprintf(":%d %s", i+j+1, col))
}
sqlArr = append(sqlArr, fmt.Sprintf("INTO %s (%s) VALUES (%s)", od.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ",")))
t2s = append(t2s, fmt.Sprintf("SELECT %s FROM dual", strings.Join(placeholder, ", ")))
}
sqlArr = append(sqlArr, "SELECT 1 FROM DUAL")
t2 := strings.Join(t2s, " UNION ALL ")
sqlTemp := "MERGE INTO " + dbType.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") "
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") "
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
// 执行批量insert sql
res, err := od.dc.TxExec(tx, strings.Join(sqlArr, " "), args...)
res, err := od.dc.TxExec(tx, sqlTemp, args...)
if err != nil {
logx.Errorf("执行sql失败%s, sql: [ %s ]", err.Error(), sqlTemp)
}
return res, err
}
func (od *OracleDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
return converter
}
var (
@@ -297,6 +387,8 @@ var (
numberTypeRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
// 日期时间类型
datetimeTypeRegexp = regexp.MustCompile(`(?i)date|timestamp`)
converter = new(DataConverter)
)
type DataConverter struct {

View File

@@ -50,7 +50,7 @@ func (md *OraMeta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
}
}
}
urlOptions["TIMEOUT"] = "10"
urlOptions["TIMEOUT"] = "1000"
connStr := go_ora.BuildUrl(d.Host, d.Port, d.Sid, d.Username, d.Password, urlOptions)
conn, err := sql.Open(driverName, connStr)
if err != nil {

View File

@@ -25,8 +25,8 @@ type PgsqlDialect struct {
dc *dbi.DbConn
}
func (md *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) {
_, res, err := md.dc.Query("SELECT version() as server_version")
func (pd *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) {
_, res, err := pd.dc.Query("SELECT version() as server_version")
if err != nil {
return nil, err
}
@@ -36,8 +36,8 @@ func (md *PgsqlDialect) GetDbServer() (*dbi.DbServer, error) {
return ds, nil
}
func (md *PgsqlDialect) GetDbNames() ([]string, error) {
_, res, err := md.dc.Query("SELECT datname AS dbname FROM pg_database WHERE datistemplate = false AND has_database_privilege(datname, 'CONNECT')")
func (pd *PgsqlDialect) GetDbNames() ([]string, error) {
_, res, err := pd.dc.Query("SELECT datname AS dbname FROM pg_database WHERE datistemplate = false AND has_database_privilege(datname, 'CONNECT')")
if err != nil {
return nil, err
}
@@ -51,8 +51,8 @@ func (md *PgsqlDialect) GetDbNames() ([]string, error) {
}
// 获取表基础元信息, 如表名等
func (md *PgsqlDialect) GetTables() ([]dbi.Table, error) {
_, res, err := md.dc.Query(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
func (pd *PgsqlDialect) GetTables() ([]dbi.Table, error) {
_, res, err := pd.dc.Query(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_INFO_KEY))
if err != nil {
return nil, err
}
@@ -72,13 +72,13 @@ func (md *PgsqlDialect) GetTables() ([]dbi.Table, error) {
}
// 获取列元信息, 如列名等
func (md *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dbType := md.dc.Info.Type
func (pd *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dbType := pd.dc.Info.Type
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dbType.RemoveQuote(val))
}), ",")
_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_COLUMN_MA_KEY), tableName))
if err != nil {
return nil, err
}
@@ -100,8 +100,8 @@ func (md *PgsqlDialect) GetColumns(tableNames ...string) ([]dbi.Column, error) {
return columns, nil
}
func (md *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) {
columns, err := md.GetColumns(tablename)
func (pd *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) {
columns, err := pd.GetColumns(tablename)
if err != nil {
return "", err
}
@@ -118,8 +118,8 @@ func (md *PgsqlDialect) GetPrimaryKey(tablename string) (string, error) {
}
// 获取表索引信息
func (md *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
func (pd *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
_, res, err := pd.dc.Query(fmt.Sprintf(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_INDEX_INFO_KEY), tableName))
if err != nil {
return nil, err
}
@@ -155,17 +155,14 @@ func (md *PgsqlDialect) GetTableIndex(tableName string) ([]dbi.Index, error) {
}
// 获取建表ddl
func (md *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
_, err := md.dc.Exec(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
func (pd *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
_, err := pd.dc.Exec(dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_TABLE_DDL_KEY))
if err != nil {
return "", err
}
_, schemaRes, _ := md.dc.Query("select current_schema() as schema")
schemaName := schemaRes[0]["schema"].(string)
ddlSql := fmt.Sprintf("select showcreatetable('%s','%s') as sql", schemaName, tableName)
_, res, err := md.dc.Query(ddlSql)
ddlSql := fmt.Sprintf("select showcreatetable('%s','%s') as sql", pd.currentSchema(), tableName)
_, res, err := pd.dc.Query(ddlSql)
if err != nil {
return "", err
}
@@ -174,9 +171,9 @@ func (md *PgsqlDialect) GetTableDDL(tableName string) (string, error) {
}
// 获取pgsql当前连接的库可访问的schemaNames
func (md *PgsqlDialect) GetSchemas() ([]string, error) {
func (pd *PgsqlDialect) GetSchemas() ([]string, error) {
sql := dbi.GetLocalSql(PGSQL_META_FILE, PGSQL_DB_SCHEMAS)
_, res, err := md.dc.Query(sql)
_, res, err := pd.dc.Query(sql)
if err != nil {
return nil, err
}
@@ -188,11 +185,11 @@ func (md *PgsqlDialect) GetSchemas() ([]string, error) {
}
// GetDbProgram 获取数据库程序模块,用于数据库备份与恢复
func (md *PgsqlDialect) GetDbProgram() (dbi.DbProgram, error) {
return nil, fmt.Errorf("该数据库类型不支持数据库备份与恢复: %v", md.dc.Info.Type)
func (pd *PgsqlDialect) GetDbProgram() (dbi.DbProgram, error) {
return nil, fmt.Errorf("该数据库类型不支持数据库备份与恢复: %v", pd.dc.Info.Type)
}
func (md *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
// 执行批量insert sql跟mysql一样 pg或高斯支持批量insert语法
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
@@ -212,14 +209,99 @@ func (md *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
placeholders = append(placeholders, "("+strings.Join(placeholder, ", ")+")")
}
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", md.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "))
// 根据冲突策略生成后缀
suffix := ""
if pd.dc.Info.Type == dbi.DbTypeGauss {
// 高斯db使用ON DUPLICATE KEY UPDATE 语法参考 https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
suffix = pd.gaussOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
} else {
// pgsql 默认使用 on conflict 语法参考 http://www.postgres.cn/docs/12/sql-insert.html
// vastbase语法参考 https://docs.vastdata.com.cn/zh/docs/VastbaseE100Ver3.0.0/doc/SQL%E8%AF%AD%E6%B3%95/INSERT.html
// kingbase语法参考 https://help.kingbase.com.cn/v8/development/sql-plsql/sql/SQL_Statements_9.html#insert
suffix = pd.pgsqlOnDuplicateStrategySql(duplicateStrategy, tableName, columns)
}
sqlStr := fmt.Sprintf("insert into %s (%s) values %s %s", pd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "), suffix)
// 执行批量insert sql
return md.dc.TxExec(tx, sqlStr, args...)
return pd.dc.TxExec(tx, sqlStr, args...)
}
func (md *PgsqlDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
// pgsql默认唯一键冲突策略
func (pd PgsqlDialect) pgsqlOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []string) string {
suffix := ""
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
suffix = " \n on conflict do nothing"
} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
// 生成 on conflict () do update set column1 = excluded.column1, column2 = excluded.column2, ...
var updateColumns []string
for _, col := range columns {
updateColumns = append(updateColumns, fmt.Sprintf("%s = excluded.%s", col, col))
}
// 查询唯一键名,拼接冲突sql
_, keyRes, _ := pd.dc.Query("SELECT constraint_name FROM information_schema.table_constraints WHERE constraint_schema = $1 AND table_name = $2 AND constraint_type in ('PRIMARY KEY', 'UNIQUE') ", pd.currentSchema(), tableName)
if len(keyRes) > 0 {
for _, re := range keyRes {
key := anyx.ToString(re["constraint_name"])
if key != "" {
suffix += fmt.Sprintf(" \n on conflict on constraint %s do update set %s \n", key, strings.Join(updateColumns, ", "))
}
}
}
}
return suffix
}
// 高斯db唯一键冲突策略,使用ON DUPLICATE KEY UPDATE 参考https://support.huaweicloud.com/distributed-devg-v3-gaussdb/gaussdb-12-0607.html#ZH-CN_TOPIC_0000001633948138
func (pd PgsqlDialect) gaussOnDuplicateStrategySql(duplicateStrategy int, tableName string, columns []string) string {
suffix := ""
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
suffix = " \n ON DUPLICATE KEY UPDATE NOTHING"
} else if duplicateStrategy == dbi.DuplicateStrategyUpdate {
// 查出表里的唯一键涉及的字段
var uniqueColumns []string
indexs, err := pd.GetTableIndex(tableName)
if err == nil {
for _, index := range indexs {
if index.IsUnique {
cols := strings.Split(index.ColumnName, ",")
for _, col := range cols {
if !collx.ArrayContains(uniqueColumns, strings.ToLower(col)) {
uniqueColumns = append(uniqueColumns, strings.ToLower(col))
}
}
}
}
}
suffix = " \n ON DUPLICATE KEY UPDATE "
for i, col := range columns {
// ON DUPLICATE KEY UPDATE语句不支持更新唯一键字段所以得去掉
if !collx.ArrayContains(uniqueColumns, pd.dc.Info.Type.RemoveQuote(strings.ToLower(col))) {
suffix += fmt.Sprintf("%s = excluded.%s", col, col)
if i < len(columns)-1 {
suffix += ", "
}
}
}
}
return suffix
}
// 从连接信息中获取数据库和schema信息
func (pd *PgsqlDialect) currentSchema() string {
dbName := pd.dc.Info.Database
schema := ""
arr := strings.Split(dbName, "/")
if len(arr) == 2 {
schema = arr[1]
}
return schema
}
func (pd *PgsqlDialect) GetDataConverter() dbi.DataConverter {
return converter
}
var (
@@ -231,6 +313,8 @@ var (
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
converter = new(DataConverter)
)
type DataConverter struct {
@@ -272,19 +356,33 @@ func (dc *DataConverter) FormatData(dbColumnValue any, dataType dbi.DataType) st
}
func (dc *DataConverter) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// 如果dataType是datetime而dbColumnValue是string类型则需要转换为time.Time类型
_, ok := dbColumnValue.(string)
if dataType == dbi.DataTypeDateTime && ok {
res, _ := time.Parse(time.RFC3339, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeDate && ok {
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeTime && ok {
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
return res
}
return dbColumnValue
}
func (md *PgsqlDialect) IsGauss() bool {
return strings.Contains(md.dc.Info.Params, "gauss")
func (pd *PgsqlDialect) IsGauss() bool {
return strings.Contains(pd.dc.Info.Params, "gauss")
}
func (md *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
func (pd *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
tableName := copy.TableName
// 生成新表名,为老表明+_copy_时间戳
newTableName := tableName + "_copy_" + time.Now().Format("20060102150405")
// 执行根据旧表创建新表
_, err := md.dc.Exec(fmt.Sprintf("create table %s (like %s)", newTableName, tableName))
_, err := pd.dc.Exec(fmt.Sprintf("create table %s (like %s)", newTableName, tableName))
if err != nil {
return err
}
@@ -292,12 +390,12 @@ func (md *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
// 复制数据
if copy.CopyData {
go func() {
_, _ = md.dc.Exec(fmt.Sprintf("insert into %s select * from %s", newTableName, tableName))
_, _ = pd.dc.Exec(fmt.Sprintf("insert into %s select * from %s", newTableName, tableName))
}()
}
// 查询旧表的自增字段名 重新设置新表的序列序列器
_, res, err := md.dc.Query(fmt.Sprintf("select column_name from information_schema.columns where table_name = '%s' and column_default like 'nextval%%'", tableName))
_, res, err := pd.dc.Query(fmt.Sprintf("select column_name from information_schema.columns where table_name = '%s' and column_default like 'nextval%%'", tableName))
if err != nil {
return err
}
@@ -307,7 +405,7 @@ func (md *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
if colName != "" {
// 查询自增列当前最大值
_, maxRes, err := md.dc.Query(fmt.Sprintf("select max(%s) max_val from %s", colName, tableName))
_, maxRes, err := pd.dc.Query(fmt.Sprintf("select max(%s) max_val from %s", colName, tableName))
if err != nil {
return err
}
@@ -323,12 +421,12 @@ func (md *PgsqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
newSeqName := fmt.Sprintf("%s_%s_copy_seq", newTableName, colName)
// 创建自增序列,当前最大值为旧表最大值
_, err = md.dc.Exec(fmt.Sprintf("CREATE SEQUENCE %s START %d INCREMENT 1", newSeqName, maxVal))
_, err = pd.dc.Exec(fmt.Sprintf("CREATE SEQUENCE %s START %d INCREMENT 1", newSeqName, maxVal))
if err != nil {
return err
}
// 将新表的自增主键序列与主键列相关联
_, err = md.dc.Exec(fmt.Sprintf("alter table %s alter column %s set default nextval('%s')", newTableName, colName, newSeqName))
_, err = pd.dc.Exec(fmt.Sprintf("alter table %s alter column %s set default nextval('%s')", newTableName, colName, newSeqName))
if err != nil {
return err
}

View File

@@ -20,8 +20,9 @@ func init() {
dbi.Register(dbi.DbTypeKingbaseEs, meta)
dbi.Register(dbi.DbTypeVastbase, meta)
gauss := new(PostgresMeta)
gauss.Param = "dbtype=gauss"
gauss := &PostgresMeta{
Param: "dbtype=gauss",
}
dbi.Register(dbi.DbTypeGauss, gauss)
}

View File

@@ -184,7 +184,7 @@ func (sd *SqliteDialect) GetDbProgram() (dbi.DbProgram, error) {
return nil, fmt.Errorf("该数据库类型不支持数据库备份与恢复: %v", sd.dc.Info.Type)
}
func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
// 执行批量insert sql跟mysql一样 支持批量insert语法
// 生成占位符字符串:如:(?,?)
// 重复字符串并用逗号连接
@@ -197,7 +197,14 @@ func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str
// 去除最后一个逗号
placeholder = strings.TrimSuffix(repeated, ",")
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", sd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
prefix := "insert into"
if duplicateStrategy == 1 {
prefix = "insert or ignore into"
} else if duplicateStrategy == 2 {
prefix = "insert or replace into"
}
sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, sd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
// 把二维数组转为一维数组
var args []any
@@ -210,7 +217,7 @@ func (sd *SqliteDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str
}
func (sd *SqliteDialect) GetDataConverter() dbi.DataConverter {
return new(DataConverter)
return converter
}
var (
@@ -218,6 +225,8 @@ var (
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit|real`)
// 日期时间类型
datetimeRegexp = regexp.MustCompile(`(?i)datetime`)
converter = new(DataConverter)
)
type DataConverter struct {

View File

@@ -26,11 +26,12 @@ type DataSyncTask struct {
UpdFieldVal string `orm:"column(upd_field_val)" json:"updFieldVal"` // 更新字段当前值
// 目标数据库信息
TargetDbId int64 `orm:"column(target_db_id)" json:"targetDbId"`
TargetDbName string `orm:"column(target_db_name)" json:"targetDbName"`
TargetTagPath string `orm:"column(target_tag_path)" json:"targetTagPath"`
TargetTableName string `orm:"column(target_table_name)" json:"targetTableName"`
FieldMap string `orm:"column(field_map)" json:"fieldMap"` // 字段映射json
TargetDbId int64 `orm:"column(target_db_id)" json:"targetDbId"`
TargetDbName string `orm:"column(target_db_name)" json:"targetDbName"`
TargetTagPath string `orm:"column(target_tag_path)" json:"targetTagPath"`
TargetTableName string `orm:"column(target_table_name)" json:"targetTableName"`
FieldMap string `orm:"column(field_map)" json:"fieldMap"` // 字段映射json
DuplicateStrategy int `orm:"column(duplicate_strategy)" json:"duplicateStrategy"` // 冲突策略 -11忽略2覆盖
}
func (d *DataSyncTask) TableName() string {