diff --git a/frontend/src/components/iconSelector/index.vue b/frontend/src/components/iconSelector/index.vue index 8593e7b3..6864bdc1 100644 --- a/frontend/src/components/iconSelector/index.vue +++ b/frontend/src/components/iconSelector/index.vue @@ -199,7 +199,7 @@ const onClearFontIcon = () => { // 获取 input 的宽度 const getInputWidth = () => { nextTick(() => { - state.fontIconWidth = inputWidthRef.value.$el.offsetWidth; + state.fontIconWidth = inputWidthRef.value?.$el.offsetWidth; }); }; // 监听页面宽度改变 diff --git a/server/internal/db/application/db.go b/server/internal/db/application/db.go index 71f140fa..33309833 100644 --- a/server/internal/db/application/db.go +++ b/server/internal/db/application/db.go @@ -365,9 +365,13 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error { if len(rows) > 0 { beforeInsert := dumpHelper.BeforeInsertSql(quoteSchema, quoteTableName) writer.WriteString(beforeInsert) - insertSql := targetSqlGenerator.GenInsert(tableName, columns, rows, dbi.DuplicateStrategyNone) - if _, err := writer.WriteString(strings.Join(insertSql, ";\n") + ";\n"); err != nil { - return err + sqls := targetSqlGenerator.GenInsert(tableName, columns, rows, dbi.DuplicateStrategyNone) + + for _, sqlStr := range sqls { + _, err := writer.WriteString(sqlStr) + if err != nil { + return err + } } } diff --git a/server/internal/db/application/db_data_sync.go b/server/internal/db/application/db_data_sync.go index a32fbbea..789ba265 100644 --- a/server/internal/db/application/db_data_sync.go +++ b/server/internal/db/application/db_data_sync.go @@ -299,20 +299,20 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [ targetData = append(targetData, data) } - tragetValues := make([][]any, 0) + targetValues := make([][]any, 0) for _, item := range targetData { var values = make([]any, 0) for _, column := range targetInsertColumns { values = append(values, item[column.ColumnName]) } - tragetValues = append(tragetValues, values) + targetValues = append(targetValues, values) } // 执行插入 targetDialect := targetDbConn.GetDialect() // 生成目标数据库批量插入sql,并执行 - sqls := targetDialect.GetSQLGenerator().GenInsert(task.TargetTableName, targetInsertColumns, tragetValues, cmp.Or(task.DuplicateStrategy, dbi.DuplicateStrategyNone)) + sqls := targetDialect.GetSQLGenerator().GenInsert(task.TargetTableName, targetInsertColumns, targetValues, cmp.Or(task.DuplicateStrategy, dbi.DuplicateStrategyNone)) // 开启本批次执行事务 targetDbTx, err := targetDbConn.Begin() diff --git a/server/internal/db/application/db_sql_exec.go b/server/internal/db/application/db_sql_exec.go index 3b50e2e1..62e14884 100644 --- a/server/internal/db/application/db_sql_exec.go +++ b/server/internal/db/application/db_sql_exec.go @@ -227,6 +227,7 @@ func (d *dbSqlExecAppImpl) ExecReader(ctx context.Context, execReader *dto.SqlRe }).WithCategory(progressCategory)) } + tx, _ := dbConn.Begin() err := sqlparser.SQLSplit(execReader.Reader, func(sql string) error { if executedStatements%50 == 0 { if needSendMsg { @@ -240,15 +241,17 @@ func (d *dbSqlExecAppImpl) ExecReader(ctx context.Context, execReader *dto.SqlRe } executedStatements++ - if _, err := dbConn.Exec(sql); err != nil { + if _, err := dbConn.TxExec(tx, sql); err != nil { return err } return nil }) - if err != nil { + _ = tx.Rollback() return err } + _ = tx.Commit() + if needSendMsg { d.msgApp.CreateAndSend(la, msgdto.SuccessSysMsg(i18n.T(imsg.SqlScriptRunSuccess), "execution success").WithClientId(clientId)) } diff --git a/server/internal/db/application/db_transfer.go b/server/internal/db/application/db_transfer.go index 1541c714..9af648d2 100644 --- a/server/internal/db/application/db_transfer.go +++ b/server/internal/db/application/db_transfer.go @@ -300,15 +300,14 @@ func (app *dbTransferAppImpl) transfer2Db(ctx context.Context, taskId uint64, lo } }() - tx, err := targetConn.Begin() if err != nil { pw.CloseWithError(err) app.EndTransfer(ctx, logId, taskId, "transfer table failed", err, nil) return err } - + tx, _ := targetConn.Begin() err = sqlparser.SQLSplit(pr, func(stmt string) error { - if _, err := targetConn.TxExecContext(ctx, tx, stmt); err != nil { + if _, err := targetConn.TxExec(tx, stmt); err != nil { app.EndTransfer(ctx, logId, taskId, fmt.Sprintf("执行sql出错: %s", stmt), err, nil) pw.CloseWithError(err) return err diff --git a/server/internal/db/dbm/dbi/column.go b/server/internal/db/dbm/dbi/column.go index 419c657c..469ec6a4 100644 --- a/server/internal/db/dbm/dbi/column.go +++ b/server/internal/db/dbm/dbi/column.go @@ -59,7 +59,7 @@ func (c *Column) GetColumnType() string { } } - return string(c.DataType) + return c.DataType } // 数据库对应的数据类型 diff --git a/server/internal/db/dbm/dbi/db_info.go b/server/internal/db/dbm/dbi/db_info.go index 88a2f501..5733ec5c 100644 --- a/server/internal/db/dbm/dbi/db_info.go +++ b/server/internal/db/dbm/dbi/db_info.go @@ -45,38 +45,38 @@ type DbInfo struct { } // 获取记录日志的描述 -func (d *DbInfo) GetLogDesc() string { - return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.CodePath, d.Name, d.Host, d.Port, d.Database) +func (di *DbInfo) GetLogDesc() string { + return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", di.Id, di.CodePath, di.Name, di.Host, di.Port, di.Database) } // 连接数据库 -func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) { +func (di *DbInfo) Conn(meta Meta) (*DbConn, error) { if meta == nil { return nil, errorx.NewBiz("the database meta information interface cannot be empty") } // 赋值Meta,方便后续获取dialect等 - dbInfo.Meta = meta - database := dbInfo.Database + di.Meta = meta + database := di.Database // 如果数据库为空,则使用默认数据库进行连接 if database == "" { - database = meta.GetMetadata(&DbConn{Info: dbInfo}).GetDefaultDb() - dbInfo.Database = database + database = meta.GetMetadata(&DbConn{Info: di}).GetDefaultDb() + di.Database = database } - conn, err := meta.GetSqlDb(dbInfo) + conn, err := meta.GetSqlDb(di) if err != nil { - logx.Errorf("db connection failed: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error()) + logx.Errorf("db connection failed: %s:%d/%s, err:%s", di.Host, di.Port, database, err.Error()) return nil, errorx.NewBiz("db connection failed: %s", err.Error()) } err = conn.Ping() if err != nil { - logx.Errorf("db ping failed: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error()) + logx.Errorf("db ping failed: %s:%d/%s, err:%s", di.Host, di.Port, database, err.Error()) return nil, errorx.NewBiz("db connection failed: %s", err.Error()) } - dbc := &DbConn{Id: GetDbConnId(dbInfo.Id, database), Info: dbInfo} + dbc := &DbConn{Id: GetDbConnId(di.Id, database), Info: di} // 最大连接周期,超过时间的连接就close // conn.SetConnMaxLifetime(100 * time.Second) @@ -85,7 +85,7 @@ func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) { // 设置闲置连接数 conn.SetMaxIdleConns(1) dbc.db = conn - logx.Infof("db connection: %s:%d/%s", dbInfo.Host, dbInfo.Port, database) + logx.Infof("db connection: %s:%d/%s", di.Host, di.Port, database) return dbc, nil } diff --git a/server/internal/db/dbm/dm/column.go b/server/internal/db/dbm/dm/column.go index 38e6ba9b..6c39535d 100644 --- a/server/internal/db/dbm/dm/column.go +++ b/server/internal/db/dbm/dm/column.go @@ -2,15 +2,17 @@ package dm import ( "encoding/hex" + "fmt" "mayfly-go/internal/db/dbm/dbi" "mayfly-go/pkg/utils/anyx" + "reflect" "strings" "gitee.com/chunanyong/dm" ) var ( - CHAR = dbi.NewDbDataType("CHAR", dbi.DTString).WithCT(dbi.CTChar) + CHAR = dbi.NewDbDataType("VARCHAR", dbi.DTString).WithCT(dbi.CTVarchar) VARCHAR = dbi.NewDbDataType("VARCHAR", dbi.DTString).WithCT(dbi.CTVarchar) TEXT = dbi.NewDbDataType("TEXT", dbi.DTString).WithCT(dbi.CTText) LONG = dbi.NewDbDataType("LONG", dbi.DTString).WithCT(dbi.CTText) @@ -36,6 +38,7 @@ var ( TIME = dbi.NewDbDataType("TIME", dbi.DTTime).WithCT(dbi.CTTime).WithFixColumn(dbi.ClearCharMaxLength) DATE = dbi.NewDbDataType("DATE", dbi.DTDate).WithCT(dbi.CTDate).WithFixColumn(dbi.ClearCharMaxLength) + DATETIME = dbi.NewDbDataType("DATETIME", dbi.DTDateTime).WithCT(dbi.CTDateTime).WithFixColumn(dbi.ClearCharMaxLength) TIMESTAMP = dbi.NewDbDataType("TIMESTAMP", dbi.DTDateTime).WithCT(dbi.CTTimestamp).WithFixColumn(dbi.ClearCharMaxLength) ST_CURVE = dbi.NewDbDataType("ST_CURVE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一条曲线,可以是圆弧、抛物线等 @@ -50,6 +53,8 @@ var ( ST_POINT = dbi.NewDbDataType("ST_POINT", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个点 ST_POLYGON = dbi.NewDbDataType("ST_POLYGON", DTDmStruct).WithCT(dbi.CTVarchar) //表示一个多边形 ST_SURFACE = dbi.NewDbDataType("ST_SURFACE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个表面 + + TABLES = dbi.NewDbDataType("TABLES", DTDmArray).WithCT(dbi.CTVarchar) // 表示一个数组 ) var DTDmStruct = &dbi.DataType{ @@ -104,3 +109,60 @@ func ParseDmStruct(dmStruct *dm.DmStruct) string { arr = append(arr, ")") return strings.Join(arr, "") } + +var DTDmArray = &dbi.DataType{ + Name: "dm_struct", + Valuer: func() dbi.Valuer { + return &dmArrayValuer{ + DefaultValuer: new(dbi.DefaultValuer[dm.DmArray]), + } + }, + SQLValue: dbi.SQLValueString, +} + +type dmArrayValuer struct { + *dbi.DefaultValuer[dm.DmArray] +} + +func (s *dmArrayValuer) Value() any { + if !s.ValuePtr.Valid { + return "" + } + return ParseDmArray(s.ValuePtr) +} + +func ParseDmArray(dmArray *dm.DmArray) string { + if !dmArray.Valid { + return "" + } + + name, err := dmArray.GetBaseTypeName() + if err != nil { + return err.Error() + } + + arr, err := dmArray.GetArray() + if err != nil { + return err.Error() + } + + // 获取变量的类型和值 + t := reflect.TypeOf(arr) + v := reflect.ValueOf(arr) + + // 检查类型是否为数组 + if t.Kind() != reflect.Array && t.Kind() != reflect.Slice { + return fmt.Sprintf("%s(%s)", name, anyx.ToString(arr)) + } + // 获取数组的长度 + length := v.Len() + elements := make([]string, length) + + // 遍历数组并将每个元素转换为字符串 + for i := 0; i < length; i++ { + element := v.Index(i).Interface() + elements[i] = fmt.Sprintf("%v", element) + } + + return fmt.Sprintf("%s(%s)", name, strings.Join(elements, ",")) +} diff --git a/server/internal/db/dbm/dm/meta.go b/server/internal/db/dbm/dm/meta.go index 3b63401a..0442715f 100644 --- a/server/internal/db/dbm/dm/meta.go +++ b/server/internal/db/dbm/dm/meta.go @@ -59,9 +59,10 @@ func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType { return collx.AsArray[*dbi.DbDataType](CHAR, VARCHAR, TEXT, LONG, LONGVARCHAR, IMAGE, LONGVARBINARY, CLOB, BLOB, NUMERIC, DECIMAL, NUMBER, INTEGER, INT, BIGINT, TINYINT, BYTE, SMALLINT, BIT, DOUBLE, FLOAT, - TIME, DATE, TIMESTAMP, + TIME, DATE, TIMESTAMP, DATETIME, ST_CURVE, ST_LINESTRING, ST_GEOMCOLLECTION, ST_GEOMETRY, ST_MULTICURVE, ST_MULTILINESTRING, ST_MULTIPOINT, ST_MULTIPOLYGON, ST_MULTISURFACE, ST_POINT, ST_POLYGON, ST_SURFACE, + TABLES, ) } diff --git a/server/internal/db/dbm/dm/sqlgen.go b/server/internal/db/dbm/dm/sqlgen.go index 02714a30..e8457d2a 100644 --- a/server/internal/db/dbm/dm/sqlgen.go +++ b/server/internal/db/dbm/dm/sqlgen.go @@ -88,17 +88,29 @@ func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values quote := quoter.Quote if duplicateStrategy == dbi.DuplicateStrategyNone { - identityInsert := "" + var res []string + var hasIdentity = false + identityInsertOn := "" + identityInsertOff := "" // 有自增列的才加上这个语句 if collx.AnyMatch(columns, func(column dbi.Column) bool { return column.IsIdentity }) { - identityInsert = fmt.Sprintf("set identity_insert %s on;", quote(tableName)) + identityInsertOn = fmt.Sprintf("set identity_insert %s on;", quote(tableName)) + hasIdentity = true + res = append(res, identityInsertOn) } // 达梦数据库只能一条条的执行insert语句,所以这里需要将values拆分成多条insert语句 - return collx.ArrayMap(values, func(value []any) string { - columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeDM, columns, values) - return fmt.Sprintf("%s insert into %s %s values %s", identityInsert, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n")) + sqls := collx.ArrayMap(values, func(value []any) string { + columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeDM, columns, [][]any{value}) + return fmt.Sprintf("insert into %s %s values %s ;", quote(tableName), columnStr, valuesStrs[0]) }) + + res = append(res, sqls...) + + if hasIdentity { + res = append(res, identityInsertOff) + } + return res } // 查询主键字段 diff --git a/server/internal/db/dbm/dm/transfer.go b/server/internal/db/dbm/dm/transfer.go index f780bb99..954d0e7d 100644 --- a/server/internal/db/dbm/dm/transfer.go +++ b/server/internal/db/dbm/dm/transfer.go @@ -25,18 +25,23 @@ func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType { } func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return BIT } func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return TINYINT } func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return SMALLINT } func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return INTEGER } func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return BIGINT } func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType { @@ -48,28 +53,36 @@ func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType { } func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return BIGINT } func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return INT } func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return INT } func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return INT } func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return DATE } func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return TIME } func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType { - return TIMESTAMP + clearLength(col) + return DATETIME } func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType { + clearLength(col) return TIMESTAMP } @@ -95,3 +108,8 @@ func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType { return VARCHAR } + +func clearLength(col *dbi.Column) { + col.CharMaxLength = 0 + col.NumPrecision = 0 +} diff --git a/server/internal/db/dbm/mysql/sqlgen.go b/server/internal/db/dbm/mysql/sqlgen.go index bf7d674f..bbef7743 100644 --- a/server/internal/db/dbm/mysql/sqlgen.go +++ b/server/internal/db/dbm/mysql/sqlgen.go @@ -67,8 +67,8 @@ func (msg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []stri colNames[0] = fmt.Sprintf("%s(%d)", colNames[0], subPart) } - sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING %s" - sqlStr := fmt.Sprintf(sqlTmp, quoter.Quote(table.TableName), unique, quoter.Quote(index.IndexName), strings.Join(colNames, ","), index.IndexType) + sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE" + sqlStr := fmt.Sprintf(sqlTmp, quoter.Quote(table.TableName), unique, quoter.Quote(index.IndexName), strings.Join(colNames, ",")) comment := dbi.QuoteEscape(index.IndexComment) if comment != "" { sqlStr += fmt.Sprintf(" COMMENT '%s'", comment) @@ -130,6 +130,8 @@ func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) } } if mark { + // 去掉单引号 + column.ColumnDefault = strings.Trim(column.ColumnDefault, "'") defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) } else { defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) diff --git a/server/internal/machine/mcm/sshtunnel.go b/server/internal/machine/mcm/sshtunnel.go index bebe5bc0..3a232900 100644 --- a/server/internal/machine/mcm/sshtunnel.go +++ b/server/internal/machine/mcm/sshtunnel.go @@ -85,10 +85,6 @@ func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (expo return "", 0, err } - if err != nil { - return "", 0, err - } - localHost := "0.0.0.0" localAddr := fmt.Sprintf("%s:%d", localHost, localPort) listener, err := net.Listen("tcp", localAddr) @@ -223,8 +219,8 @@ func (r *Tunnel) Open(sshClient *ssh.Client) { r.remoteConnections = append(r.remoteConnections, remoteConn) logx.Debugf("隧道 %v 连接远程主机成功", r.id) - go copyConn(localConn, remoteConn) - go copyConn(remoteConn, localConn) + go r.copyConn(localConn, remoteConn) + go r.copyConn(remoteConn, localConn) logx.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr) logx.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~") } @@ -243,6 +239,6 @@ func (r *Tunnel) Close() { logx.Debugf("隧道 %s 监听器关闭", r.id) } -func copyConn(writer, reader net.Conn) { +func (r *Tunnel) copyConn(writer, reader net.Conn) { _, _ = io.Copy(writer, reader) }