Files
mayfly-go/server/internal/db/dbm/mysql/dialect.go
zongyangleo 2acc295259 !110 feat: 支持各源数据库导出sql,数据库迁移部分bug修复
* feat: 各源数据库导出
* fix: 数据库迁移 bug修复
2024-03-26 09:05:28 +00:00

126 lines
3.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package mysql
import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"strings"
"time"
)
// MysqlDialect holds the MySQL-specific dialect operations defined in this
// file: batch insert, table copy, column type mapping and DDL execution.
// It embeds dbi.DefaultDialect; NOTE(review): presumably this supplies
// default implementations for dialect methods not overridden here — confirm.
type MysqlDialect struct {
dbi.DefaultDialect
// dc is the connection every statement issued by this dialect runs on.
dc *dbi.DbConn
}
// GetDbProgram returns the MySQL database program module used for database
// backup and restore, bound to this dialect's connection.
// The returned error is always nil.
func (md *MysqlDialect) GetDbProgram() (dbi.DbProgram, error) {
	program := NewDbProgramMysql(md.dc)
	return program, nil
}
// BatchInsert inserts every row of values into tableName using one multi-row
// MySQL INSERT executed inside tx:
//
//	<verb> `table` (col1,col2,...) values (?,?,...),(?,?,...),...
//
// The table name is quoted through the connection metadata's QuoteIdentifier.
// The column names are joined verbatim and are NOT quoted here.
//
// duplicateStrategy selects the statement verb:
//   - 1: "insert ignore into"
//   - 2: "replace into"
//   - anything else: plain "insert into"
//
// Each row in values is expected to hold len(columns) values in column order.
// An empty values slice is a no-op and returns (0, nil): a statement with no
// value tuples would be invalid SQL. Otherwise the result of md.dc.TxExec is
// returned unchanged.
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
	if len(values) == 0 {
		return 0, nil
	}

	// Placeholder for a single row, e.g. "(?,?)" for two columns.
	rowPlaceholder := "(" + strings.TrimSuffix(strings.Repeat("?,", len(columns)), ",") + ")"
	// One row placeholder per row, comma separated: "(?,?),(?,?),...".
	allPlaceholders := strings.TrimSuffix(strings.Repeat(rowPlaceholder+",", len(values)), ",")

	var prefix string
	switch duplicateStrategy {
	case 1:
		prefix = "insert ignore into"
	case 2:
		prefix = "replace into"
	default:
		prefix = "insert into"
	}

	sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, md.dc.GetMetaData().QuoteIdentifier(tableName), strings.Join(columns, ","), allPlaceholders)

	// Flatten the 2-D rows into one positional argument list matching the
	// placeholder order; pre-size it to avoid repeated regrowth.
	args := make([]any, 0, len(values)*len(columns))
	for _, row := range values {
		args = append(args, row...)
	}
	return md.dc.TxExec(tx, sqlStr, args...)
}
// GetDataConverter returns the package-level MySQL data converter
// (the `converter` value declared elsewhere in this package).
func (md *MysqlDialect) GetDataConverter() dbi.DataConverter {
	var conv dbi.DataConverter = converter
	return conv
}
// CopyTable creates a new table named "<table>_copy_<yyyyMMddHHmmss>" with the
// same structure as copy.TableName, using MySQL's "create table ... like".
//
// If copy.CopyData is set, the rows are copied afterwards by a background
// goroutine with "insert into ... select * from ...". CopyTable does not wait
// for that copy, and any error it produces is discarded.
//
// It returns the error from creating the new table, or nil.
func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
	meta := md.dc.GetMetaData()
	tableName := copy.TableName
	// New table name: old table name + "_copy_" + timestamp.
	newTableName := tableName + "_copy_" + time.Now().Format("20060102150405")

	// Quote both identifiers the same way BatchInsert quotes its table name,
	// so names that are reserved words or contain special characters still
	// yield valid SQL instead of being spliced in raw.
	srcTable := meta.QuoteIdentifier(tableName)
	dstTable := meta.QuoteIdentifier(newTableName)

	// Copy the table structure.
	if _, err := md.dc.Exec(fmt.Sprintf("create table %s like %s", dstTable, srcTable)); err != nil {
		return err
	}

	// Copy the data.
	if copy.CopyData {
		go func() {
			// Deliberately best-effort: the copy runs after CopyTable has already
			// returned, so there is no caller left to report an error to.
			_, _ = md.dc.Exec(fmt.Sprintf("insert into %s select * from %s", dstTable, srcTable))
		}()
	}
	return nil
}
// ToCommonColumn rewrites column.DataType in place from a MySQL type to the
// common (database-neutral) type found in commonColumnTypeMap.
// Types with no mapping (or an empty mapping) fall back to dbi.CommonTypeVarchar.
func (md *MysqlDialect) ToCommonColumn(column *dbi.Column) {
	converted := dbi.CommonTypeVarchar
	if mapped := commonColumnTypeMap[string(column.DataType)]; mapped != "" {
		converted = mapped
	}
	column.DataType = converted
}
// ToColumn rewrites column in place from a common type to a MySQL type using
// mysqlColumnTypeMap.
//
// When a mapping exists, the MySQL type is applied and the connection
// metadata's FixColumn adjusts the column further. Otherwise the column falls
// back to varchar with a character length of 1000.
func (md *MysqlDialect) ToColumn(column *dbi.Column) {
	mysqlType := mysqlColumnTypeMap[column.DataType]
	if mysqlType != "" {
		column.DataType = dbi.ColumnDataType(mysqlType)
		md.dc.GetMetaData().FixColumn(column)
		return
	}
	column.DataType = "varchar"
	column.CharMaxLength = 1000
}
// CreateTable generates table DDL through the connection metadata's
// GenerateTableDDL and executes each statement in order.
//
// columns, tableInfo and dropOldTable are passed straight to GenerateTableDDL.
// It returns the number of statements executed on success. On the first
// failing statement it returns 0 and that error; statements that already ran
// are not undone here.
func (md *MysqlDialect) CreateTable(columns []dbi.Column, tableInfo dbi.Table, dropOldTable bool) (int, error) {
	stmts := md.dc.GetMetaData().GenerateTableDDL(columns, tableInfo, dropOldTable)
	for i := range stmts {
		if _, err := md.dc.Exec(stmts[i]); err != nil {
			return 0, err
		}
	}
	return len(stmts), nil
}
// CreateIndex generates index DDL for tableInfo through the connection
// metadata's GenerateIndexDDL and executes each statement in order.
// It stops and returns the error of the first failing statement; statements
// that already ran are not undone here.
func (md *MysqlDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error {
	stmts := md.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo)
	for i := range stmts {
		if _, err := md.dc.Exec(stmts[i]); err != nil {
			return err
		}
	}
	return nil
}