!130 fix: 数据迁移、数据同步bug修复

* fix: 数据迁移、数据同步bug修复
This commit is contained in:
zongyangleo
2025-01-17 03:53:15 +00:00
committed by Coder慌
parent 8d24c2a4fa
commit 5a6e9d81a7
13 changed files with 139 additions and 42 deletions

View File

@@ -365,9 +365,13 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error {
if len(rows) > 0 {
beforeInsert := dumpHelper.BeforeInsertSql(quoteSchema, quoteTableName)
writer.WriteString(beforeInsert)
insertSql := targetSqlGenerator.GenInsert(tableName, columns, rows, dbi.DuplicateStrategyNone)
if _, err := writer.WriteString(strings.Join(insertSql, ";\n") + ";\n"); err != nil {
return err
sqls := targetSqlGenerator.GenInsert(tableName, columns, rows, dbi.DuplicateStrategyNone)
for _, sqlStr := range sqls {
_, err := writer.WriteString(sqlStr)
if err != nil {
return err
}
}
}

View File

@@ -299,20 +299,20 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
targetData = append(targetData, data)
}
tragetValues := make([][]any, 0)
targetValues := make([][]any, 0)
for _, item := range targetData {
var values = make([]any, 0)
for _, column := range targetInsertColumns {
values = append(values, item[column.ColumnName])
}
tragetValues = append(tragetValues, values)
targetValues = append(targetValues, values)
}
// 执行插入
targetDialect := targetDbConn.GetDialect()
// 生成目标数据库批量插入sql并执行
sqls := targetDialect.GetSQLGenerator().GenInsert(task.TargetTableName, targetInsertColumns, tragetValues, cmp.Or(task.DuplicateStrategy, dbi.DuplicateStrategyNone))
sqls := targetDialect.GetSQLGenerator().GenInsert(task.TargetTableName, targetInsertColumns, targetValues, cmp.Or(task.DuplicateStrategy, dbi.DuplicateStrategyNone))
// 开启本批次执行事务
targetDbTx, err := targetDbConn.Begin()

View File

@@ -227,6 +227,7 @@ func (d *dbSqlExecAppImpl) ExecReader(ctx context.Context, execReader *dto.SqlRe
}).WithCategory(progressCategory))
}
tx, _ := dbConn.Begin()
err := sqlparser.SQLSplit(execReader.Reader, func(sql string) error {
if executedStatements%50 == 0 {
if needSendMsg {
@@ -240,15 +241,17 @@ func (d *dbSqlExecAppImpl) ExecReader(ctx context.Context, execReader *dto.SqlRe
}
executedStatements++
if _, err := dbConn.Exec(sql); err != nil {
if _, err := dbConn.TxExec(tx, sql); err != nil {
return err
}
return nil
})
if err != nil {
_ = tx.Rollback()
return err
}
_ = tx.Commit()
if needSendMsg {
d.msgApp.CreateAndSend(la, msgdto.SuccessSysMsg(i18n.T(imsg.SqlScriptRunSuccess), "execution success").WithClientId(clientId))
}

View File

@@ -300,15 +300,14 @@ func (app *dbTransferAppImpl) transfer2Db(ctx context.Context, taskId uint64, lo
}
}()
tx, err := targetConn.Begin()
if err != nil {
pw.CloseWithError(err)
app.EndTransfer(ctx, logId, taskId, "transfer table failed", err, nil)
return err
}
tx, _ := targetConn.Begin()
err = sqlparser.SQLSplit(pr, func(stmt string) error {
if _, err := targetConn.TxExecContext(ctx, tx, stmt); err != nil {
if _, err := targetConn.TxExec(tx, stmt); err != nil {
app.EndTransfer(ctx, logId, taskId, fmt.Sprintf("执行sql出错: %s", stmt), err, nil)
pw.CloseWithError(err)
return err

View File

@@ -59,7 +59,7 @@ func (c *Column) GetColumnType() string {
}
}
return string(c.DataType)
return c.DataType
}
// 数据库对应的数据类型

View File

@@ -45,38 +45,38 @@ type DbInfo struct {
}
// 获取记录日志的描述
func (d *DbInfo) GetLogDesc() string {
return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.CodePath, d.Name, d.Host, d.Port, d.Database)
func (di *DbInfo) GetLogDesc() string {
return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", di.Id, di.CodePath, di.Name, di.Host, di.Port, di.Database)
}
// 连接数据库
func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) {
func (di *DbInfo) Conn(meta Meta) (*DbConn, error) {
if meta == nil {
return nil, errorx.NewBiz("the database meta information interface cannot be empty")
}
// 赋值Meta方便后续获取dialect等
dbInfo.Meta = meta
database := dbInfo.Database
di.Meta = meta
database := di.Database
// 如果数据库为空,则使用默认数据库进行连接
if database == "" {
database = meta.GetMetadata(&DbConn{Info: dbInfo}).GetDefaultDb()
dbInfo.Database = database
database = meta.GetMetadata(&DbConn{Info: di}).GetDefaultDb()
di.Database = database
}
conn, err := meta.GetSqlDb(dbInfo)
conn, err := meta.GetSqlDb(di)
if err != nil {
logx.Errorf("db connection failed: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error())
logx.Errorf("db connection failed: %s:%d/%s, err:%s", di.Host, di.Port, database, err.Error())
return nil, errorx.NewBiz("db connection failed: %s", err.Error())
}
err = conn.Ping()
if err != nil {
logx.Errorf("db ping failed: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error())
logx.Errorf("db ping failed: %s:%d/%s, err:%s", di.Host, di.Port, database, err.Error())
return nil, errorx.NewBiz("db connection failed: %s", err.Error())
}
dbc := &DbConn{Id: GetDbConnId(dbInfo.Id, database), Info: dbInfo}
dbc := &DbConn{Id: GetDbConnId(di.Id, database), Info: di}
// 最大连接周期超过时间的连接就close
// conn.SetConnMaxLifetime(100 * time.Second)
@@ -85,7 +85,7 @@ func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) {
// 设置闲置连接数
conn.SetMaxIdleConns(1)
dbc.db = conn
logx.Infof("db connection: %s:%d/%s", dbInfo.Host, dbInfo.Port, database)
logx.Infof("db connection: %s:%d/%s", di.Host, di.Port, database)
return dbc, nil
}

View File

@@ -2,15 +2,17 @@ package dm
import (
"encoding/hex"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"reflect"
"strings"
"gitee.com/chunanyong/dm"
)
var (
CHAR = dbi.NewDbDataType("CHAR", dbi.DTString).WithCT(dbi.CTChar)
CHAR = dbi.NewDbDataType("VARCHAR", dbi.DTString).WithCT(dbi.CTVarchar)
VARCHAR = dbi.NewDbDataType("VARCHAR", dbi.DTString).WithCT(dbi.CTVarchar)
TEXT = dbi.NewDbDataType("TEXT", dbi.DTString).WithCT(dbi.CTText)
LONG = dbi.NewDbDataType("LONG", dbi.DTString).WithCT(dbi.CTText)
@@ -36,6 +38,7 @@ var (
TIME = dbi.NewDbDataType("TIME", dbi.DTTime).WithCT(dbi.CTTime).WithFixColumn(dbi.ClearCharMaxLength)
DATE = dbi.NewDbDataType("DATE", dbi.DTDate).WithCT(dbi.CTDate).WithFixColumn(dbi.ClearCharMaxLength)
DATETIME = dbi.NewDbDataType("DATETIME", dbi.DTDateTime).WithCT(dbi.CTDateTime).WithFixColumn(dbi.ClearCharMaxLength)
TIMESTAMP = dbi.NewDbDataType("TIMESTAMP", dbi.DTDateTime).WithCT(dbi.CTTimestamp).WithFixColumn(dbi.ClearCharMaxLength)
ST_CURVE = dbi.NewDbDataType("ST_CURVE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一条曲线,可以是圆弧、抛物线等
@@ -50,6 +53,8 @@ var (
ST_POINT = dbi.NewDbDataType("ST_POINT", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个点
ST_POLYGON = dbi.NewDbDataType("ST_POLYGON", DTDmStruct).WithCT(dbi.CTVarchar) //表示一个多边形
ST_SURFACE = dbi.NewDbDataType("ST_SURFACE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个表面
TABLES = dbi.NewDbDataType("TABLES", DTDmArray).WithCT(dbi.CTVarchar) // 表示一个数组
)
var DTDmStruct = &dbi.DataType{
@@ -104,3 +109,60 @@ func ParseDmStruct(dmStruct *dm.DmStruct) string {
arr = append(arr, ")")
return strings.Join(arr, "")
}
var DTDmArray = &dbi.DataType{
Name: "dm_struct",
Valuer: func() dbi.Valuer {
return &dmArrayValuer{
DefaultValuer: new(dbi.DefaultValuer[dm.DmArray]),
}
},
SQLValue: dbi.SQLValueString,
}
type dmArrayValuer struct {
*dbi.DefaultValuer[dm.DmArray]
}
func (s *dmArrayValuer) Value() any {
if !s.ValuePtr.Valid {
return ""
}
return ParseDmArray(s.ValuePtr)
}
func ParseDmArray(dmArray *dm.DmArray) string {
if !dmArray.Valid {
return ""
}
name, err := dmArray.GetBaseTypeName()
if err != nil {
return err.Error()
}
arr, err := dmArray.GetArray()
if err != nil {
return err.Error()
}
// 获取变量的类型和值
t := reflect.TypeOf(arr)
v := reflect.ValueOf(arr)
// 检查类型是否为数组
if t.Kind() != reflect.Array && t.Kind() != reflect.Slice {
return fmt.Sprintf("%s(%s)", name, anyx.ToString(arr))
}
// 获取数组的长度
length := v.Len()
elements := make([]string, length)
// 遍历数组并将每个元素转换为字符串
for i := 0; i < length; i++ {
element := v.Index(i).Interface()
elements[i] = fmt.Sprintf("%v", element)
}
return fmt.Sprintf("%s(%s)", name, strings.Join(elements, ","))
}

View File

@@ -59,9 +59,10 @@ func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType {
return collx.AsArray[*dbi.DbDataType](CHAR, VARCHAR, TEXT, LONG, LONGVARCHAR, IMAGE, LONGVARBINARY, CLOB,
BLOB,
NUMERIC, DECIMAL, NUMBER, INTEGER, INT, BIGINT, TINYINT, BYTE, SMALLINT, BIT, DOUBLE, FLOAT,
TIME, DATE, TIMESTAMP,
TIME, DATE, TIMESTAMP, DATETIME,
ST_CURVE, ST_LINESTRING, ST_GEOMCOLLECTION, ST_GEOMETRY, ST_MULTICURVE, ST_MULTILINESTRING,
ST_MULTIPOINT, ST_MULTIPOLYGON, ST_MULTISURFACE, ST_POINT, ST_POLYGON, ST_SURFACE,
TABLES,
)
}

View File

@@ -88,17 +88,29 @@ func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values
quote := quoter.Quote
if duplicateStrategy == dbi.DuplicateStrategyNone {
identityInsert := ""
var res []string
var hasIdentity = false
identityInsertOn := ""
identityInsertOff := ""
// 有自增列的才加上这个语句
if collx.AnyMatch(columns, func(column dbi.Column) bool { return column.IsIdentity }) {
identityInsert = fmt.Sprintf("set identity_insert %s on;", quote(tableName))
identityInsertOn = fmt.Sprintf("set identity_insert %s on;", quote(tableName))
hasIdentity = true
res = append(res, identityInsertOn)
}
// 达梦数据库只能一条条的执行insert语句所以这里需要将values拆分成多条insert语句
return collx.ArrayMap(values, func(value []any) string {
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeDM, columns, values)
return fmt.Sprintf("%s insert into %s %s values %s", identityInsert, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n"))
sqls := collx.ArrayMap(values, func(value []any) string {
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeDM, columns, [][]any{value})
return fmt.Sprintf("insert into %s %s values %s ;", quote(tableName), columnStr, valuesStrs[0])
})
res = append(res, sqls...)
if hasIdentity {
res = append(res, identityInsertOff)
}
return res
}
// 查询主键字段

View File

@@ -25,18 +25,23 @@ func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
}
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return BIT
}
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return TINYINT
}
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return SMALLINT
}
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return INTEGER
}
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return BIGINT
}
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
@@ -48,28 +53,36 @@ func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
}
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return BIGINT
}
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return INT
}
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return INT
}
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return INT
}
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return DATE
}
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return TIME
}
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
return TIMESTAMP
clearLength(col)
return DATETIME
}
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return TIMESTAMP
}
@@ -95,3 +108,8 @@ func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
return VARCHAR
}
func clearLength(col *dbi.Column) {
col.CharMaxLength = 0
col.NumPrecision = 0
}

View File

@@ -67,8 +67,8 @@ func (msg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []stri
colNames[0] = fmt.Sprintf("%s(%d)", colNames[0], subPart)
}
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING %s"
sqlStr := fmt.Sprintf(sqlTmp, quoter.Quote(table.TableName), unique, quoter.Quote(index.IndexName), strings.Join(colNames, ","), index.IndexType)
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE"
sqlStr := fmt.Sprintf(sqlTmp, quoter.Quote(table.TableName), unique, quoter.Quote(index.IndexName), strings.Join(colNames, ","))
comment := dbi.QuoteEscape(index.IndexComment)
if comment != "" {
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
@@ -130,6 +130,8 @@ func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column)
}
}
if mark {
// 去掉单引号
column.ColumnDefault = strings.Trim(column.ColumnDefault, "'")
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)

View File

@@ -85,10 +85,6 @@ func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (expo
return "", 0, err
}
if err != nil {
return "", 0, err
}
localHost := "0.0.0.0"
localAddr := fmt.Sprintf("%s:%d", localHost, localPort)
listener, err := net.Listen("tcp", localAddr)
@@ -223,8 +219,8 @@ func (r *Tunnel) Open(sshClient *ssh.Client) {
r.remoteConnections = append(r.remoteConnections, remoteConn)
logx.Debugf("隧道 %v 连接远程主机成功", r.id)
go copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn)
go r.copyConn(localConn, remoteConn)
go r.copyConn(remoteConn, localConn)
logx.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr)
logx.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~")
}
@@ -243,6 +239,6 @@ func (r *Tunnel) Close() {
logx.Debugf("隧道 %s 监听器关闭", r.id)
}
func copyConn(writer, reader net.Conn) {
func (r *Tunnel) copyConn(writer, reader net.Conn) {
_, _ = io.Copy(writer, reader)
}