refactor: 数据同步编辑页优化等

This commit is contained in:
meilin.huang
2024-01-11 12:35:44 +08:00
parent bbec3eca0d
commit bc811cbd49
15 changed files with 178 additions and 112 deletions

View File

@@ -277,7 +277,7 @@ func (app *dataSyncAppImpl) srcData2TargetDb(srcRes []map[string]any, fieldMap [
for _, item := range fieldMap {
targetField := item["target"]
srcField := item["target"]
targetWrapColumns = append(targetWrapColumns, targetDbConn.Info.Type.WrapName(targetField))
targetWrapColumns = append(targetWrapColumns, targetDbConn.Info.Type.QuoteIdentifier(targetField))
srcColumns = append(srcColumns, srcField)
}

View File

@@ -25,37 +25,25 @@ func (dbType DbType) Equal(typ string) bool {
return ToDbType(typ) == dbType
}
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
// used as part of an SQL statement. For example:
//
// tblname := "my_table"
// data := "my_data"
// quoted := quoteIdentifier(tblname, '"')
// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Any double quotes in name will be escaped. The quoted identifier will be
// case sensitive when used in a query. If the input string contains a zero
// byte, the result will be truncated immediately before it.
func (dbType DbType) QuoteIdentifier(name string) string {
switch dbType {
case DbTypeMysql, DbTypeMariadb:
return quoteIdentifier(name, "`")
case DbTypePostgres:
return pq.QuoteIdentifier(name)
return quoteIdentifier(name, `"`)
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
}
}
func (dbType DbType) MetaDbName() string {
switch dbType {
case DbTypeMysql, DbTypeMariadb:
return ""
case DbTypePostgres:
return "postgres"
case DbTypeDM:
return ""
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
}
}
// 包装字段名,防止使用了数据库保留关键字
func (dbType DbType) WrapName(name string) string {
switch dbType {
case DbTypeMysql, DbTypeMariadb:
return fmt.Sprintf("`%s`", name)
default:
return fmt.Sprintf(`"%s"`, name)
return quoteIdentifier(name, `"`)
}
}
@@ -68,7 +56,20 @@ func (dbType DbType) QuoteLiteral(literal string) string {
case DbTypePostgres:
return pq.QuoteLiteral(literal)
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
return pq.QuoteLiteral(literal)
}
}
func (dbType DbType) MetaDbName() string {
switch dbType {
case DbTypeMysql, DbTypeMariadb:
return ""
case DbTypePostgres:
return "postgres"
case DbTypeDM:
return ""
default:
return ""
}
}
@@ -78,24 +79,11 @@ func (dbType DbType) Dialect() sqlparser.Dialect {
return sqlparser.MysqlDialect{}
case DbTypePostgres:
return sqlparser.PostgresDialect{}
case DbTypeDM:
return sqlparser.PostgresDialect{}
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
return sqlparser.PostgresDialect{}
}
}
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
// used as part of an SQL statement. For example:
//
// tblname := "my_table"
// data := "my_data"
// quoted := pq.QuoteIdentifier(tblname)
// err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Any double quotes in name will be escaped. The quoted identifier will be
// case sensitive when used in a query. If the input string contains a zero
// byte, the result will be truncated immediately before it.
func quoteIdentifier(name, quoter string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
@@ -116,7 +104,7 @@ func (dbType DbType) StmtSetForeignKeyChecks(check bool) string {
// not currently supported postgres
return ""
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
return ""
}
}
@@ -128,6 +116,6 @@ func (dbType DbType) StmtUseDatabase(dbName string) string {
// not currently supported postgres
return ""
default:
panic(fmt.Sprintf("invalid database type: %s", dbType))
return ""
}
}

View File

@@ -50,3 +50,34 @@ func Test_QuoteLiteral(t *testing.T) {
})
}
}
func Test_quoteIdentifier(t *testing.T) {
tests := []struct {
dbType DbType
sql string
want string
}{
{
dbType: DbTypeMysql,
sql: "`a`",
},
{
dbType: DbTypeMysql,
sql: "select table",
},
{
dbType: DbTypePostgres,
sql: "a",
},
{
dbType: DbTypePostgres,
sql: "table",
},
}
for _, tt := range tests {
t.Run(string(tt.dbType)+"_"+tt.sql, func(t *testing.T) {
got := tt.dbType.QuoteIdentifier(tt.sql)
require.Equal(t, tt.want, got)
})
}
}

View File

@@ -307,7 +307,7 @@ func (dd *DMDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string,
// 去除最后一个逗号,占位符由括号包裹
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
sqlTemp := fmt.Sprintf("insert into %s (%s) values %s", dd.dc.Info.Type.WrapName(tableName), strings.Join(columns, ","), placeholder)
sqlTemp := fmt.Sprintf("insert into %s (%s) values %s", dd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
effRows := 0
for _, value := range values {
// 达梦数据库只能一条条的执行insert

View File

@@ -236,7 +236,7 @@ func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
// 去除最后一个逗号
placeholder = strings.TrimSuffix(repeated, ",")
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", md.dc.Info.Type.WrapName(tableName), strings.Join(columns, ","), placeholder)
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", md.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
// 执行批量insert sql
// 把二维数组转为一维数组
var args []any

View File

@@ -319,7 +319,7 @@ func (pd *PgsqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []stri
placeholders = append(placeholders, "("+strings.Join(placeholder, ", ")+")")
}
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", pd.dc.Info.Type.WrapName(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "))
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", pd.dc.Info.Type.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholders, ", "))
// 执行批量insert sql
return pd.dc.TxExec(tx, sqlStr, args...)