!130 fix: 数据迁移、数据同步bug修复

* fix: 数据迁移、数据同步bug修复
This commit is contained in:
zongyangleo
2025-01-17 03:53:15 +00:00
committed by Coder慌
parent 8d24c2a4fa
commit 5a6e9d81a7
13 changed files with 139 additions and 42 deletions

View File

@@ -199,7 +199,7 @@ const onClearFontIcon = () => {
// 获取 input 的宽度 // 获取 input 的宽度
const getInputWidth = () => { const getInputWidth = () => {
nextTick(() => { nextTick(() => {
state.fontIconWidth = inputWidthRef.value.$el.offsetWidth; state.fontIconWidth = inputWidthRef.value?.$el.offsetWidth;
}); });
}; };
// 监听页面宽度改变 // 监听页面宽度改变

View File

@@ -365,11 +365,15 @@ func (d *dbAppImpl) DumpDb(ctx context.Context, reqParam *dto.DumpDb) error {
if len(rows) > 0 { if len(rows) > 0 {
beforeInsert := dumpHelper.BeforeInsertSql(quoteSchema, quoteTableName) beforeInsert := dumpHelper.BeforeInsertSql(quoteSchema, quoteTableName)
writer.WriteString(beforeInsert) writer.WriteString(beforeInsert)
insertSql := targetSqlGenerator.GenInsert(tableName, columns, rows, dbi.DuplicateStrategyNone) sqls := targetSqlGenerator.GenInsert(tableName, columns, rows, dbi.DuplicateStrategyNone)
if _, err := writer.WriteString(strings.Join(insertSql, ";\n") + ";\n"); err != nil {
for _, sqlStr := range sqls {
_, err := writer.WriteString(sqlStr)
if err != nil {
return err return err
} }
} }
}
dumpHelper.AfterInsert(writer, tableName, columns) dumpHelper.AfterInsert(writer, tableName, columns)
progress(tableName, dbi.StmtTypeInsert, dataCount, true) progress(tableName, dbi.StmtTypeInsert, dataCount, true)

View File

@@ -299,20 +299,20 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
targetData = append(targetData, data) targetData = append(targetData, data)
} }
tragetValues := make([][]any, 0) targetValues := make([][]any, 0)
for _, item := range targetData { for _, item := range targetData {
var values = make([]any, 0) var values = make([]any, 0)
for _, column := range targetInsertColumns { for _, column := range targetInsertColumns {
values = append(values, item[column.ColumnName]) values = append(values, item[column.ColumnName])
} }
tragetValues = append(tragetValues, values) targetValues = append(targetValues, values)
} }
// 执行插入 // 执行插入
targetDialect := targetDbConn.GetDialect() targetDialect := targetDbConn.GetDialect()
// 生成目标数据库批量插入sql并执行 // 生成目标数据库批量插入sql并执行
sqls := targetDialect.GetSQLGenerator().GenInsert(task.TargetTableName, targetInsertColumns, tragetValues, cmp.Or(task.DuplicateStrategy, dbi.DuplicateStrategyNone)) sqls := targetDialect.GetSQLGenerator().GenInsert(task.TargetTableName, targetInsertColumns, targetValues, cmp.Or(task.DuplicateStrategy, dbi.DuplicateStrategyNone))
// 开启本批次执行事务 // 开启本批次执行事务
targetDbTx, err := targetDbConn.Begin() targetDbTx, err := targetDbConn.Begin()

View File

@@ -227,6 +227,7 @@ func (d *dbSqlExecAppImpl) ExecReader(ctx context.Context, execReader *dto.SqlRe
}).WithCategory(progressCategory)) }).WithCategory(progressCategory))
} }
tx, _ := dbConn.Begin()
err := sqlparser.SQLSplit(execReader.Reader, func(sql string) error { err := sqlparser.SQLSplit(execReader.Reader, func(sql string) error {
if executedStatements%50 == 0 { if executedStatements%50 == 0 {
if needSendMsg { if needSendMsg {
@@ -240,15 +241,17 @@ func (d *dbSqlExecAppImpl) ExecReader(ctx context.Context, execReader *dto.SqlRe
} }
executedStatements++ executedStatements++
if _, err := dbConn.Exec(sql); err != nil { if _, err := dbConn.TxExec(tx, sql); err != nil {
return err return err
} }
return nil return nil
}) })
if err != nil { if err != nil {
_ = tx.Rollback()
return err return err
} }
_ = tx.Commit()
if needSendMsg { if needSendMsg {
d.msgApp.CreateAndSend(la, msgdto.SuccessSysMsg(i18n.T(imsg.SqlScriptRunSuccess), "execution success").WithClientId(clientId)) d.msgApp.CreateAndSend(la, msgdto.SuccessSysMsg(i18n.T(imsg.SqlScriptRunSuccess), "execution success").WithClientId(clientId))
} }

View File

@@ -300,15 +300,14 @@ func (app *dbTransferAppImpl) transfer2Db(ctx context.Context, taskId uint64, lo
} }
}() }()
tx, err := targetConn.Begin()
if err != nil { if err != nil {
pw.CloseWithError(err) pw.CloseWithError(err)
app.EndTransfer(ctx, logId, taskId, "transfer table failed", err, nil) app.EndTransfer(ctx, logId, taskId, "transfer table failed", err, nil)
return err return err
} }
tx, _ := targetConn.Begin()
err = sqlparser.SQLSplit(pr, func(stmt string) error { err = sqlparser.SQLSplit(pr, func(stmt string) error {
if _, err := targetConn.TxExecContext(ctx, tx, stmt); err != nil { if _, err := targetConn.TxExec(tx, stmt); err != nil {
app.EndTransfer(ctx, logId, taskId, fmt.Sprintf("执行sql出错: %s", stmt), err, nil) app.EndTransfer(ctx, logId, taskId, fmt.Sprintf("执行sql出错: %s", stmt), err, nil)
pw.CloseWithError(err) pw.CloseWithError(err)
return err return err

View File

@@ -59,7 +59,7 @@ func (c *Column) GetColumnType() string {
} }
} }
return string(c.DataType) return c.DataType
} }
// 数据库对应的数据类型 // 数据库对应的数据类型

View File

@@ -45,38 +45,38 @@ type DbInfo struct {
} }
// 获取记录日志的描述 // 获取记录日志的描述
func (d *DbInfo) GetLogDesc() string { func (di *DbInfo) GetLogDesc() string {
return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", d.Id, d.CodePath, d.Name, d.Host, d.Port, d.Database) return fmt.Sprintf("DB[id=%d, tag=%s, name=%s, ip=%s:%d, database=%s]", di.Id, di.CodePath, di.Name, di.Host, di.Port, di.Database)
} }
// 连接数据库 // 连接数据库
func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) { func (di *DbInfo) Conn(meta Meta) (*DbConn, error) {
if meta == nil { if meta == nil {
return nil, errorx.NewBiz("the database meta information interface cannot be empty") return nil, errorx.NewBiz("the database meta information interface cannot be empty")
} }
// 赋值Meta方便后续获取dialect等 // 赋值Meta方便后续获取dialect等
dbInfo.Meta = meta di.Meta = meta
database := dbInfo.Database database := di.Database
// 如果数据库为空,则使用默认数据库进行连接 // 如果数据库为空,则使用默认数据库进行连接
if database == "" { if database == "" {
database = meta.GetMetadata(&DbConn{Info: dbInfo}).GetDefaultDb() database = meta.GetMetadata(&DbConn{Info: di}).GetDefaultDb()
dbInfo.Database = database di.Database = database
} }
conn, err := meta.GetSqlDb(dbInfo) conn, err := meta.GetSqlDb(di)
if err != nil { if err != nil {
logx.Errorf("db connection failed: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error()) logx.Errorf("db connection failed: %s:%d/%s, err:%s", di.Host, di.Port, database, err.Error())
return nil, errorx.NewBiz("db connection failed: %s", err.Error()) return nil, errorx.NewBiz("db connection failed: %s", err.Error())
} }
err = conn.Ping() err = conn.Ping()
if err != nil { if err != nil {
logx.Errorf("db ping failed: %s:%d/%s, err:%s", dbInfo.Host, dbInfo.Port, database, err.Error()) logx.Errorf("db ping failed: %s:%d/%s, err:%s", di.Host, di.Port, database, err.Error())
return nil, errorx.NewBiz("db connection failed: %s", err.Error()) return nil, errorx.NewBiz("db connection failed: %s", err.Error())
} }
dbc := &DbConn{Id: GetDbConnId(dbInfo.Id, database), Info: dbInfo} dbc := &DbConn{Id: GetDbConnId(di.Id, database), Info: di}
// 最大连接周期超过时间的连接就close // 最大连接周期超过时间的连接就close
// conn.SetConnMaxLifetime(100 * time.Second) // conn.SetConnMaxLifetime(100 * time.Second)
@@ -85,7 +85,7 @@ func (dbInfo *DbInfo) Conn(meta Meta) (*DbConn, error) {
// 设置闲置连接数 // 设置闲置连接数
conn.SetMaxIdleConns(1) conn.SetMaxIdleConns(1)
dbc.db = conn dbc.db = conn
logx.Infof("db connection: %s:%d/%s", dbInfo.Host, dbInfo.Port, database) logx.Infof("db connection: %s:%d/%s", di.Host, di.Port, database)
return dbc, nil return dbc, nil
} }

View File

@@ -2,15 +2,17 @@ package dm
import ( import (
"encoding/hex" "encoding/hex"
"fmt"
"mayfly-go/internal/db/dbm/dbi" "mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx" "mayfly-go/pkg/utils/anyx"
"reflect"
"strings" "strings"
"gitee.com/chunanyong/dm" "gitee.com/chunanyong/dm"
) )
var ( var (
CHAR = dbi.NewDbDataType("CHAR", dbi.DTString).WithCT(dbi.CTChar) CHAR = dbi.NewDbDataType("VARCHAR", dbi.DTString).WithCT(dbi.CTVarchar)
VARCHAR = dbi.NewDbDataType("VARCHAR", dbi.DTString).WithCT(dbi.CTVarchar) VARCHAR = dbi.NewDbDataType("VARCHAR", dbi.DTString).WithCT(dbi.CTVarchar)
TEXT = dbi.NewDbDataType("TEXT", dbi.DTString).WithCT(dbi.CTText) TEXT = dbi.NewDbDataType("TEXT", dbi.DTString).WithCT(dbi.CTText)
LONG = dbi.NewDbDataType("LONG", dbi.DTString).WithCT(dbi.CTText) LONG = dbi.NewDbDataType("LONG", dbi.DTString).WithCT(dbi.CTText)
@@ -36,6 +38,7 @@ var (
TIME = dbi.NewDbDataType("TIME", dbi.DTTime).WithCT(dbi.CTTime).WithFixColumn(dbi.ClearCharMaxLength) TIME = dbi.NewDbDataType("TIME", dbi.DTTime).WithCT(dbi.CTTime).WithFixColumn(dbi.ClearCharMaxLength)
DATE = dbi.NewDbDataType("DATE", dbi.DTDate).WithCT(dbi.CTDate).WithFixColumn(dbi.ClearCharMaxLength) DATE = dbi.NewDbDataType("DATE", dbi.DTDate).WithCT(dbi.CTDate).WithFixColumn(dbi.ClearCharMaxLength)
DATETIME = dbi.NewDbDataType("DATETIME", dbi.DTDateTime).WithCT(dbi.CTDateTime).WithFixColumn(dbi.ClearCharMaxLength)
TIMESTAMP = dbi.NewDbDataType("TIMESTAMP", dbi.DTDateTime).WithCT(dbi.CTTimestamp).WithFixColumn(dbi.ClearCharMaxLength) TIMESTAMP = dbi.NewDbDataType("TIMESTAMP", dbi.DTDateTime).WithCT(dbi.CTTimestamp).WithFixColumn(dbi.ClearCharMaxLength)
ST_CURVE = dbi.NewDbDataType("ST_CURVE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一条曲线,可以是圆弧、抛物线等 ST_CURVE = dbi.NewDbDataType("ST_CURVE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一条曲线,可以是圆弧、抛物线等
@@ -50,6 +53,8 @@ var (
ST_POINT = dbi.NewDbDataType("ST_POINT", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个点 ST_POINT = dbi.NewDbDataType("ST_POINT", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个点
ST_POLYGON = dbi.NewDbDataType("ST_POLYGON", DTDmStruct).WithCT(dbi.CTVarchar) //表示一个多边形 ST_POLYGON = dbi.NewDbDataType("ST_POLYGON", DTDmStruct).WithCT(dbi.CTVarchar) //表示一个多边形
ST_SURFACE = dbi.NewDbDataType("ST_SURFACE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个表面 ST_SURFACE = dbi.NewDbDataType("ST_SURFACE", DTDmStruct).WithCT(dbi.CTVarchar) // 表示一个表面
TABLES = dbi.NewDbDataType("TABLES", DTDmArray).WithCT(dbi.CTVarchar) // 表示一个数组
) )
var DTDmStruct = &dbi.DataType{ var DTDmStruct = &dbi.DataType{
@@ -104,3 +109,60 @@ func ParseDmStruct(dmStruct *dm.DmStruct) string {
arr = append(arr, ")") arr = append(arr, ")")
return strings.Join(arr, "") return strings.Join(arr, "")
} }
var DTDmArray = &dbi.DataType{
Name: "dm_struct",
Valuer: func() dbi.Valuer {
return &dmArrayValuer{
DefaultValuer: new(dbi.DefaultValuer[dm.DmArray]),
}
},
SQLValue: dbi.SQLValueString,
}
type dmArrayValuer struct {
*dbi.DefaultValuer[dm.DmArray]
}
func (s *dmArrayValuer) Value() any {
if !s.ValuePtr.Valid {
return ""
}
return ParseDmArray(s.ValuePtr)
}
func ParseDmArray(dmArray *dm.DmArray) string {
if !dmArray.Valid {
return ""
}
name, err := dmArray.GetBaseTypeName()
if err != nil {
return err.Error()
}
arr, err := dmArray.GetArray()
if err != nil {
return err.Error()
}
// 获取变量的类型和值
t := reflect.TypeOf(arr)
v := reflect.ValueOf(arr)
// 检查类型是否为数组
if t.Kind() != reflect.Array && t.Kind() != reflect.Slice {
return fmt.Sprintf("%s(%s)", name, anyx.ToString(arr))
}
// 获取数组的长度
length := v.Len()
elements := make([]string, length)
// 遍历数组并将每个元素转换为字符串
for i := 0; i < length; i++ {
element := v.Index(i).Interface()
elements[i] = fmt.Sprintf("%v", element)
}
return fmt.Sprintf("%s(%s)", name, strings.Join(elements, ","))
}

View File

@@ -59,9 +59,10 @@ func (sm *Meta) GetDbDataTypes() []*dbi.DbDataType {
return collx.AsArray[*dbi.DbDataType](CHAR, VARCHAR, TEXT, LONG, LONGVARCHAR, IMAGE, LONGVARBINARY, CLOB, return collx.AsArray[*dbi.DbDataType](CHAR, VARCHAR, TEXT, LONG, LONGVARCHAR, IMAGE, LONGVARBINARY, CLOB,
BLOB, BLOB,
NUMERIC, DECIMAL, NUMBER, INTEGER, INT, BIGINT, TINYINT, BYTE, SMALLINT, BIT, DOUBLE, FLOAT, NUMERIC, DECIMAL, NUMBER, INTEGER, INT, BIGINT, TINYINT, BYTE, SMALLINT, BIT, DOUBLE, FLOAT,
TIME, DATE, TIMESTAMP, TIME, DATE, TIMESTAMP, DATETIME,
ST_CURVE, ST_LINESTRING, ST_GEOMCOLLECTION, ST_GEOMETRY, ST_MULTICURVE, ST_MULTILINESTRING, ST_CURVE, ST_LINESTRING, ST_GEOMCOLLECTION, ST_GEOMETRY, ST_MULTICURVE, ST_MULTILINESTRING,
ST_MULTIPOINT, ST_MULTIPOLYGON, ST_MULTISURFACE, ST_POINT, ST_POLYGON, ST_SURFACE, ST_MULTIPOINT, ST_MULTIPOLYGON, ST_MULTISURFACE, ST_POINT, ST_POLYGON, ST_SURFACE,
TABLES,
) )
} }

View File

@@ -88,17 +88,29 @@ func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values
quote := quoter.Quote quote := quoter.Quote
if duplicateStrategy == dbi.DuplicateStrategyNone { if duplicateStrategy == dbi.DuplicateStrategyNone {
identityInsert := "" var res []string
var hasIdentity = false
identityInsertOn := ""
identityInsertOff := ""
// 有自增列的才加上这个语句 // 有自增列的才加上这个语句
if collx.AnyMatch(columns, func(column dbi.Column) bool { return column.IsIdentity }) { if collx.AnyMatch(columns, func(column dbi.Column) bool { return column.IsIdentity }) {
identityInsert = fmt.Sprintf("set identity_insert %s on;", quote(tableName)) identityInsertOn = fmt.Sprintf("set identity_insert %s on;", quote(tableName))
hasIdentity = true
res = append(res, identityInsertOn)
} }
// 达梦数据库只能一条条的执行insert语句所以这里需要将values拆分成多条insert语句 // 达梦数据库只能一条条的执行insert语句所以这里需要将values拆分成多条insert语句
return collx.ArrayMap(values, func(value []any) string { sqls := collx.ArrayMap(values, func(value []any) string {
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeDM, columns, values) columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeDM, columns, [][]any{value})
return fmt.Sprintf("%s insert into %s %s values %s", identityInsert, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n")) return fmt.Sprintf("insert into %s %s values %s ;", quote(tableName), columnStr, valuesStrs[0])
}) })
res = append(res, sqls...)
if hasIdentity {
res = append(res, identityInsertOff)
}
return res
} }
// 查询主键字段 // 查询主键字段

View File

@@ -25,18 +25,23 @@ func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
} }
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return BIT return BIT
} }
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return TINYINT return TINYINT
} }
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return SMALLINT return SMALLINT
} }
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return INTEGER return INTEGER
} }
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return BIGINT return BIGINT
} }
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
@@ -48,28 +53,36 @@ func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
} }
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return BIGINT return BIGINT
} }
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return INT return INT
} }
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return INT return INT
} }
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return INT return INT
} }
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return DATE return DATE
} }
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return TIME return TIME
} }
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
return TIMESTAMP clearLength(col)
return DATETIME
} }
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
clearLength(col)
return TIMESTAMP return TIMESTAMP
} }
@@ -95,3 +108,8 @@ func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType { func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
return VARCHAR return VARCHAR
} }
func clearLength(col *dbi.Column) {
col.CharMaxLength = 0
col.NumPrecision = 0
}

View File

@@ -67,8 +67,8 @@ func (msg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []stri
colNames[0] = fmt.Sprintf("%s(%d)", colNames[0], subPart) colNames[0] = fmt.Sprintf("%s(%d)", colNames[0], subPart)
} }
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING %s" sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE"
sqlStr := fmt.Sprintf(sqlTmp, quoter.Quote(table.TableName), unique, quoter.Quote(index.IndexName), strings.Join(colNames, ","), index.IndexType) sqlStr := fmt.Sprintf(sqlTmp, quoter.Quote(table.TableName), unique, quoter.Quote(index.IndexName), strings.Join(colNames, ","))
comment := dbi.QuoteEscape(index.IndexComment) comment := dbi.QuoteEscape(index.IndexComment)
if comment != "" { if comment != "" {
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment) sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
@@ -130,6 +130,8 @@ func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column)
} }
} }
if mark { if mark {
// 去掉单引号
column.ColumnDefault = strings.Trim(column.ColumnDefault, "'")
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault) defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else { } else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault) defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)

View File

@@ -85,10 +85,6 @@ func (stm *SshTunnelMachine) OpenSshTunnel(id string, ip string, port int) (expo
return "", 0, err return "", 0, err
} }
if err != nil {
return "", 0, err
}
localHost := "0.0.0.0" localHost := "0.0.0.0"
localAddr := fmt.Sprintf("%s:%d", localHost, localPort) localAddr := fmt.Sprintf("%s:%d", localHost, localPort)
listener, err := net.Listen("tcp", localAddr) listener, err := net.Listen("tcp", localAddr)
@@ -223,8 +219,8 @@ func (r *Tunnel) Open(sshClient *ssh.Client) {
r.remoteConnections = append(r.remoteConnections, remoteConn) r.remoteConnections = append(r.remoteConnections, remoteConn)
logx.Debugf("隧道 %v 连接远程主机成功", r.id) logx.Debugf("隧道 %v 连接远程主机成功", r.id)
go copyConn(localConn, remoteConn) go r.copyConn(localConn, remoteConn)
go copyConn(remoteConn, localConn) go r.copyConn(remoteConn, localConn)
logx.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr) logx.Debugf("隧道 %v 开始转发数据 [%v]->[%v]", r.id, localAddr, remoteAddr)
logx.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~") logx.Debug("~~~~~~~~~~~~~~~~~~~~分割线~~~~~~~~~~~~~~~~~~~~~~~~")
} }
@@ -243,6 +239,6 @@ func (r *Tunnel) Close() {
logx.Debugf("隧道 %s 监听器关闭", r.id) logx.Debugf("隧道 %s 监听器关闭", r.id)
} }
func copyConn(writer, reader net.Conn) { func (r *Tunnel) copyConn(writer, reader net.Conn) {
_, _ = io.Copy(writer, reader) _, _ = io.Copy(writer, reader)
} }