refactor: code review

This commit is contained in:
meilin.huang
2024-01-06 22:36:50 +08:00
parent e158422091
commit eea759e10e
12 changed files with 288 additions and 248 deletions

View File

@@ -189,12 +189,8 @@ func (md *MysqlDialect) GetTableDDL(tableName string) (string, error) {
return res[0]["Create Table"].(string) + ";", nil
}
func (md *MysqlDialect) GetTableRecord(tableName string, pageNum, pageSize int) ([]*QueryColumn, []map[string]any, error) {
return md.dc.Query(fmt.Sprintf("SELECT * FROM %s LIMIT %d, %d", tableName, (pageNum-1)*pageSize, pageSize))
}
func (md *MysqlDialect) WalkTableRecord(tableName string, walk func(record map[string]any, columns []*QueryColumn)) error {
return md.dc.WalkTableRecord(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walk)
func (md *MysqlDialect) WalkTableRecord(tableName string, walkFn WalkQueryRowsFunc) error {
return md.dc.WalkQueryRows(context.Background(), fmt.Sprintf("SELECT * FROM %s", tableName), walkFn)
}
func (md *MysqlDialect) GetSchemas() ([]string, error) {
@@ -210,10 +206,6 @@ func (pd *MysqlDialect) WrapName(name string) string {
return "`" + name + "`"
}
func (pd *MysqlDialect) PageSql(pageNum int, pageSize int) string {
return fmt.Sprintf("limit %d, %d", (pageNum-1)*pageSize, pageSize)
}
func (pd *MysqlDialect) GetDataType(dbColumnType string) DataType {
if regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`).MatchString(dbColumnType) {
return DataTypeNumber
@@ -233,24 +225,29 @@ func (pd *MysqlDialect) GetDataType(dbColumnType string) DataType {
return DataTypeString
}
func (pd *MysqlDialect) SaveBatch(conn *DbConn, tableName string, columns string, placeholder string, values [][]any) error {
func (pd *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any) (int64, error) {
// 生成占位符字符串:如:(?,?)
// 重复字符串并用逗号连接
repeated := strings.Repeat("?,", len(columns))
// 去除最后一个逗号,占位符由括号包裹
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
// 执行批量insert sqlmysql支持批量insert语法
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
// 重复占位符字符串n遍
repeated := strings.Repeat(placeholder+",", len(values))
repeated = strings.Repeat(placeholder+",", len(values))
// 去除最后一个逗号
placeholder = strings.TrimSuffix(repeated, ",")
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", pd.WrapName(tableName), columns, placeholder)
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", pd.WrapName(tableName), strings.Join(columns, ","), placeholder)
// 执行批量insert sql
// 把二维数组转为一维数组
var args []any
for _, v := range values {
args = append(args, v...)
}
_, err := conn.Exec(sqlStr, args...)
return err
return pd.dc.TxExec(tx, sqlStr, args...)
}
func (pd *MysqlDialect) FormatStrData(dbColumnValue string, dataType DataType) string {