refactor: dbm重构、调整metadata与dialect接口

This commit is contained in:
meilin.huang
2024-11-01 17:27:22 +08:00
parent af14be9801
commit 74ae031853
36 changed files with 1216 additions and 1384 deletions

View File

@@ -63,12 +63,12 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns
return count, nil
}
msMetadata := md.dc.GetMetaData()
msMetadata := md.dc.GetMetadata()
schema := md.dc.Info.CurrentSchema()
ignoreDupSql := ""
if duplicateStrategy == dbi.DuplicateStrategyIgnore {
// ALTER TABLE dbo.TEST ADD CONSTRAINT uniqueRows UNIQUE (ColA, ColB, ColC, ColD) WITH (IGNORE_DUP_KEY = ON)
indexs, _ := msMetadata.MetaData.(*MssqlMetaData).getTableIndexWithPK(tableName)
indexs, _ := msMetadata.(*MssqlMetadata).getTableIndexWithPK(tableName)
// 收集唯一索引涉及到的字段
uniqueColumns := make([]string, 0)
for _, index := range indexs {
@@ -99,7 +99,7 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns
// 去除最后一个逗号
placeholder = strings.TrimSuffix(repeated, ",")
baseTable := fmt.Sprintf("%s.%s", msMetadata.QuoteIdentifier(schema), msMetadata.QuoteIdentifier(tableName))
baseTable := fmt.Sprintf("%s.%s", md.QuoteIdentifier(schema), md.QuoteIdentifier(tableName))
sqlStr := fmt.Sprintf("insert into %s (%s) values %s", baseTable, strings.Join(columns, ","), placeholder)
// 执行批量insert sql
@@ -117,7 +117,7 @@ func (md *MssqlDialect) batchInsertSimple(tx *sql.Tx, tableName string, columns
}
func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
msMetadata := md.dc.GetMetaData()
msMetadata := md.dc.GetMetadata()
schema := md.dc.Info.CurrentSchema()
// 收集MERGE 语句的 ON 子句条件
@@ -136,7 +136,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [
}
if col.IsPrimaryKey {
pkCols = append(pkCols, col.ColumnName)
name := msMetadata.QuoteIdentifier(col.ColumnName)
name := md.QuoteIdentifier(col.ColumnName)
caseSqls = append(caseSqls, fmt.Sprintf(" T1.%s = T2.%s ", name, name))
}
}
@@ -150,7 +150,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [
// 源数据占位sql
phs := make([]string, 0)
for _, column := range columns {
if !collx.ArrayContains(identityCols, msMetadata.RemoveQuote(column)) {
if !collx.ArrayContains(identityCols, md.RemoveQuote(column)) {
upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
}
insertCols = append(insertCols, column)
@@ -168,7 +168,7 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [
}
t2 := strings.Join(t2s, " UNION ALL ")
sqlTemp := "MERGE INTO " + msMetadata.QuoteIdentifier(schema) + "." + msMetadata.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ")
sqlTemp := "MERGE INTO " + md.QuoteIdentifier(schema) + "." + md.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON " + strings.Join(caseSqls, " AND ")
sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") "
sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
@@ -185,14 +185,14 @@ func (md *MssqlDialect) batchInsertMerge(tx *sql.Tx, tableName string, columns [
}
func (md *MssqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
msMetadata := md.dc.GetMetaData().MetaData.(*MssqlMetaData)
msMetadata := md.dc.GetMetadata()
schema := md.dc.Info.CurrentSchema()
// 生成新表名,为老表明+_copy_时间戳
newTableName := copy.TableName + "_copy_" + time.Now().Format("20060102150405")
// 复制建表语句
ddl, err := msMetadata.CopyTableDDL(copy.TableName, newTableName)
ddl, err := md.CopyTableDDL(copy.TableName, newTableName)
if err != nil {
return err
}
@@ -239,14 +239,180 @@ func (md *MssqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
return err
}
func (md *MssqlDialect) CreateTable(columns []dbi.Column, tableInfo dbi.Table, dropOldTable bool) (int, error) {
sqlArr := md.dc.GetMetaData().GenerateTableDDL(columns, tableInfo, dropOldTable)
_, err := md.dc.Exec(strings.Join(sqlArr, ";"))
return len(sqlArr), err
func (md *MssqlDialect) CopyTableDDL(tableName string, newTableName string) (string, error) {
if newTableName == "" {
newTableName = tableName
}
metadata := md.dc.GetMetadata()
// 查询表名和表注释, 设置表注释
tbs, err := metadata.GetTables(tableName)
if err != nil || len(tbs) < 1 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
tabInfo := &dbi.Table{
TableName: newTableName,
TableComment: tbs[0].TableComment,
}
// 查询列信息
columns, err := metadata.GetColumns(tableName)
if err != nil {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
sqlArr := md.GenerateTableDDL(columns, *tabInfo, true)
// 设置索引
indexs, err := metadata.GetTableIndex(tableName)
if err != nil {
logx.Errorf("获取索引信息失败, %s", tableName)
return strings.Join(sqlArr, ";"), err
}
sqlArr = append(sqlArr, md.GenerateIndexDDL(indexs, *tabInfo)...)
return strings.Join(sqlArr, ";"), nil
}
func (md *MssqlDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error {
sqlArr := md.dc.GetMetaData().GenerateIndexDDL(indexs, tableInfo)
_, err := md.dc.Exec(strings.Join(sqlArr, ";"))
return err
// 获取建表ddl
func (md *MssqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
tbName := tableInfo.TableName
schemaName := md.dc.Info.CurrentSchema()
sqlArr := make([]string, 0)
// 删除表
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s.%s", md.QuoteIdentifier(schemaName), md.QuoteIdentifier(tbName)))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s.%s (\n", md.QuoteIdentifier(schemaName), md.QuoteIdentifier(tbName))
fields := make([]string, 0)
pks := make([]string, 0)
columnComments := make([]string, 0)
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, md.QuoteIdentifier(column.ColumnName))
}
fields = append(fields, md.genColumnBasicSql(column))
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'COLUMN', N'%s'"
// 防止注释内含有特殊字符串导致sql出错
if column.ColumnComment != "" {
comment := md.QuoteEscape(column.ColumnComment)
columnComments = append(columnComments, fmt.Sprintf(commentTmp, comment, md.dc.Info.CurrentSchema(), tbName, column.ColumnName))
}
}
// create
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \n PRIMARY KEY CLUSTERED (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
// comment
tableCommentSql := ""
if tableInfo.TableComment != "" {
commentTmp := "EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s'"
tableCommentSql = fmt.Sprintf(commentTmp, md.QuoteEscape(tableInfo.TableComment), md.dc.Info.CurrentSchema(), tbName)
}
sqlArr = append(sqlArr, createSql)
if tableCommentSql != "" {
sqlArr = append(sqlArr, tableCommentSql)
}
if len(columnComments) > 0 {
sqlArr = append(sqlArr, columnComments...)
}
return sqlArr
}
// 获取建索引ddl
func (md *MssqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
tbName := tableInfo.TableName
sqls := make([]string, 0)
comments := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = md.QuoteIdentifier(name)
}
sqls = append(sqls, fmt.Sprintf("create %s NONCLUSTERED index %s on %s.%s(%s)", unique, md.QuoteIdentifier(index.IndexName), md.QuoteIdentifier(md.dc.Info.CurrentSchema()), md.QuoteIdentifier(tbName), strings.Join(colNames, ",")))
if index.IndexComment != "" {
comment := md.QuoteEscape(index.IndexComment)
comments = append(comments, fmt.Sprintf("EXECUTE sp_addextendedproperty N'MS_Description', N'%s', N'SCHEMA', N'%s', N'TABLE', N'%s', N'INDEX', N'%s'", comment, md.dc.Info.CurrentSchema(), tbName, index.IndexName))
}
}
if len(comments) > 0 {
sqls = append(sqls, comments...)
}
return sqls
}
func (md *MssqlDialect) GetIdentifierQuoteString() string {
return "["
}
func (md *MssqlDialect) GetDataHelper() dbi.DataHelper {
return new(DataHelper)
}
func (md *MssqlDialect) GetColumnHelper() dbi.ColumnHelper {
return new(ColumnHelper)
}
func (md *MssqlDialect) GetDumpHelper() dbi.DumpHelper {
return new(DumpHelper)
}
func (md *MssqlDialect) genColumnBasicSql(column dbi.Column) string {
colName := md.QuoteIdentifier(column.ColumnName)
dataType := string(column.DataType)
incr := ""
if column.IsIdentity {
incr = " IDENTITY(1,1)"
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的 // 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, dataType) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
columnSql := fmt.Sprintf(" %s %s%s%s%s", colName, column.GetColumnType(), incr, nullAble, defVal)
return columnSql
}