mirror of
				https://gitee.com/dromara/mayfly-go
				synced 2025-11-04 00:10:25 +08:00 
			
		
		
		
	refactor: dbm重构、调整metadata与dialect接口
This commit is contained in:
		@@ -38,7 +38,7 @@ func (od *OracleDialect) BatchInsert(tx *sql.Tx, tableName string, columns []str
 | 
			
		||||
 | 
			
		||||
// 简单批量插入sql,无需判断键冲突策略
 | 
			
		||||
func (od *OracleDialect) batchInsertSimple(tableName string, columns []string, values [][]any, duplicateStrategy int, tx *sql.Tx) (int64, error) {
 | 
			
		||||
	metadata := od.dc.GetMetaData()
 | 
			
		||||
	metadata := od.dc.GetMetadata()
 | 
			
		||||
	// 忽略键冲突策略
 | 
			
		||||
	ignore := ""
 | 
			
		||||
	if duplicateStrategy == dbi.DuplicateStrategyIgnore {
 | 
			
		||||
@@ -66,7 +66,7 @@ func (od *OracleDialect) batchInsertSimple(tableName string, columns []string, v
 | 
			
		||||
		for i := 0; i < len(value); i++ {
 | 
			
		||||
			placeholder = append(placeholder, fmt.Sprintf(":%d", i+1))
 | 
			
		||||
		}
 | 
			
		||||
		sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, metadata.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ","))
 | 
			
		||||
		sqlTemp := fmt.Sprintf("INSERT %s INTO %s (%s) VALUES (%s)", ignore, od.QuoteIdentifier(tableName), strings.Join(columns, ","), strings.Join(placeholder, ","))
 | 
			
		||||
 | 
			
		||||
		// oracle数据库为了兼容ignore主键冲突,只能一条条的执行insert
 | 
			
		||||
		res, err := od.dc.TxExec(tx, sqlTemp, value...)
 | 
			
		||||
@@ -82,7 +82,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
 | 
			
		||||
	// 查询主键字段
 | 
			
		||||
	uniqueCols := make([]string, 0)
 | 
			
		||||
	caseSqls := make([]string, 0)
 | 
			
		||||
	metadata := od.dc.GetMetaData()
 | 
			
		||||
	metadata := od.dc.GetMetadata()
 | 
			
		||||
	// 查询唯一索引涉及到的字段,并组装到match条件内
 | 
			
		||||
	indexs, _ := metadata.GetTableIndex(tableName)
 | 
			
		||||
	if indexs != nil {
 | 
			
		||||
@@ -94,7 +94,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
 | 
			
		||||
					if !collx.ArrayContains(uniqueCols, col) {
 | 
			
		||||
						uniqueCols = append(uniqueCols, col)
 | 
			
		||||
					}
 | 
			
		||||
					tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", metadata.QuoteIdentifier(col), metadata.QuoteIdentifier(col)))
 | 
			
		||||
					tmp = append(tmp, fmt.Sprintf(" T1.%s = T2.%s ", od.QuoteIdentifier(col), od.QuoteIdentifier(col)))
 | 
			
		||||
				}
 | 
			
		||||
				caseSqls = append(caseSqls, fmt.Sprintf("( %s )", strings.Join(tmp, " AND ")))
 | 
			
		||||
			}
 | 
			
		||||
@@ -111,7 +111,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
 | 
			
		||||
	upds := make([]string, 0)
 | 
			
		||||
	insertCols := make([]string, 0)
 | 
			
		||||
	for _, column := range columns {
 | 
			
		||||
		if !collx.ArrayContains(uniqueCols, metadata.RemoveQuote(column)) {
 | 
			
		||||
		if !collx.ArrayContains(uniqueCols, od.RemoveQuote(column)) {
 | 
			
		||||
			upds = append(upds, fmt.Sprintf("T1.%s = T2.%s", column, column))
 | 
			
		||||
		}
 | 
			
		||||
		insertCols = append(insertCols, fmt.Sprintf("T1.%s", column))
 | 
			
		||||
@@ -132,7 +132,7 @@ func (od *OracleDialect) batchInsertMergeSql(tableName string, columns []string,
 | 
			
		||||
 | 
			
		||||
	t2 := strings.Join(t2s, " UNION ALL ")
 | 
			
		||||
 | 
			
		||||
	sqlTemp := "MERGE INTO " + metadata.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") "
 | 
			
		||||
	sqlTemp := "MERGE INTO " + od.QuoteIdentifier(tableName) + " T1 USING (" + t2 + ") T2 ON (" + strings.Join(caseSqls, " OR ") + ") "
 | 
			
		||||
	sqlTemp += "WHEN NOT MATCHED THEN INSERT (" + strings.Join(insertCols, ",") + ") VALUES (" + strings.Join(insertVals, ",") + ") "
 | 
			
		||||
	sqlTemp += "WHEN MATCHED THEN UPDATE SET " + strings.Join(upds, ",")
 | 
			
		||||
 | 
			
		||||
@@ -155,22 +155,132 @@ func (od *OracleDialect) CopyTable(copy *dbi.DbCopyTable) error {
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (od *OracleDialect) CreateTable(commonColumns []dbi.Column, tableInfo dbi.Table, dropOldTable bool) (int, error) {
 | 
			
		||||
	meta := od.dc.GetMetaData()
 | 
			
		||||
	sqlArr := meta.GenerateTableDDL(commonColumns, tableInfo, dropOldTable)
 | 
			
		||||
	// 需要分开执行sql
 | 
			
		||||
	for _, sqlStr := range sqlArr {
 | 
			
		||||
		_, err := od.dc.Exec(sqlStr)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return 0, err
 | 
			
		||||
// 获取建表ddl
 | 
			
		||||
func (od *OracleDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
 | 
			
		||||
	quoteTableName := od.QuoteIdentifier(tableInfo.TableName)
 | 
			
		||||
	sqlArr := make([]string, 0)
 | 
			
		||||
 | 
			
		||||
	if dropBeforeCreate {
 | 
			
		||||
		dropSqlTmp := `
 | 
			
		||||
declare
 | 
			
		||||
      num number;
 | 
			
		||||
begin
 | 
			
		||||
    select count(1) into num from user_tables where table_name = '%s' and owner = (SELECT sys_context('USERENV', 'CURRENT_SCHEMA') FROM dual) ;
 | 
			
		||||
    if num > 0 then
 | 
			
		||||
        execute immediate 'drop table "%s"' ;
 | 
			
		||||
    end if;
 | 
			
		||||
end`
 | 
			
		||||
		sqlArr = append(sqlArr, fmt.Sprintf(dropSqlTmp, tableInfo.TableName, tableInfo.TableName))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 组装建表语句
 | 
			
		||||
	createSql := fmt.Sprintf("CREATE TABLE %s ( \n", quoteTableName)
 | 
			
		||||
	fields := make([]string, 0)
 | 
			
		||||
	pks := make([]string, 0)
 | 
			
		||||
	columnComments := make([]string, 0)
 | 
			
		||||
	// 把通用类型转换为达梦类型
 | 
			
		||||
	for _, column := range columns {
 | 
			
		||||
		if column.IsPrimaryKey {
 | 
			
		||||
			pks = append(pks, od.QuoteIdentifier(column.ColumnName))
 | 
			
		||||
		}
 | 
			
		||||
		fields = append(fields, od.genColumnBasicSql(column))
 | 
			
		||||
		// 防止注释内含有特殊字符串导致sql出错
 | 
			
		||||
		if column.ColumnComment != "" {
 | 
			
		||||
			comment := od.QuoteEscape(column.ColumnComment)
 | 
			
		||||
			columnComments = append(columnComments, fmt.Sprintf("COMMENT ON COLUMN %s.%s IS '%s'", quoteTableName, od.QuoteIdentifier(column.ColumnName), comment))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return len(sqlArr), nil
 | 
			
		||||
 | 
			
		||||
	// 建表
 | 
			
		||||
	createSql += strings.Join(fields, ",\n")
 | 
			
		||||
	if len(pks) > 0 {
 | 
			
		||||
		createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
 | 
			
		||||
	}
 | 
			
		||||
	createSql += "\n)"
 | 
			
		||||
	sqlArr = append(sqlArr, createSql)
 | 
			
		||||
 | 
			
		||||
	// 表注释
 | 
			
		||||
	tableCommentSql := ""
 | 
			
		||||
	if tableInfo.TableComment != "" {
 | 
			
		||||
		tableCommentSql = fmt.Sprintf("COMMENT ON TABLE %s is '%s'", od.QuoteIdentifier(tableInfo.TableName), od.QuoteEscape(tableInfo.TableComment))
 | 
			
		||||
		sqlArr = append(sqlArr, tableCommentSql)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 列注释
 | 
			
		||||
	if len(columnComments) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, columnComments...)
 | 
			
		||||
	}
 | 
			
		||||
	otherSql := od.GenerateTableOtherDDL(tableInfo, quoteTableName, columns)
 | 
			
		||||
	if len(otherSql) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, otherSql...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sqlArr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (od *OracleDialect) CreateIndex(tableInfo dbi.Table, indexs []dbi.Index) error {
 | 
			
		||||
	meta := od.dc.GetMetaData()
 | 
			
		||||
	sqlArr := meta.GenerateIndexDDL(indexs, tableInfo)
 | 
			
		||||
	_, err := od.dc.Exec(strings.Join(sqlArr, ";"))
 | 
			
		||||
	return err
 | 
			
		||||
// 获取建索引ddl
 | 
			
		||||
func (od *OracleDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
 | 
			
		||||
	sqls := make([]string, 0)
 | 
			
		||||
	comments := make([]string, 0)
 | 
			
		||||
 | 
			
		||||
	for _, index := range indexs {
 | 
			
		||||
		unique := ""
 | 
			
		||||
		if index.IsUnique {
 | 
			
		||||
			unique = "unique"
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 取出列名,添加引号
 | 
			
		||||
		cols := strings.Split(index.ColumnName, ",")
 | 
			
		||||
		colNames := make([]string, len(cols))
 | 
			
		||||
		for i, name := range cols {
 | 
			
		||||
			colNames[i] = od.QuoteIdentifier(name)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		sqls = append(sqls, fmt.Sprintf("CREATE %s INDEX %s ON %s(%s)", unique, od.QuoteIdentifier(index.IndexName), od.QuoteIdentifier(tableInfo.TableName), strings.Join(colNames, ",")))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sqlArr := make([]string, 0)
 | 
			
		||||
 | 
			
		||||
	sqlArr = append(sqlArr, sqls...)
 | 
			
		||||
 | 
			
		||||
	if len(comments) > 0 {
 | 
			
		||||
		sqlArr = append(sqlArr, comments...)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return sqlArr
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 11g及以下版本会设置自增序列
 | 
			
		||||
func (od *OracleDialect) GenerateTableOtherDDL(tableInfo dbi.Table, quoteTableName string, columns []dbi.Column) []string {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (od *OracleDialect) GetDataHelper() dbi.DataHelper {
 | 
			
		||||
	return new(DataHelper)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (od *OracleDialect) GetColumnHelper() dbi.ColumnHelper {
 | 
			
		||||
	return new(ColumnHelper)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (od *OracleDialect) genColumnBasicSql(column dbi.Column) string {
 | 
			
		||||
	colName := od.QuoteIdentifier(column.ColumnName)
 | 
			
		||||
 | 
			
		||||
	if column.IsIdentity {
 | 
			
		||||
		// 如果是自增,不需要设置默认值和空值,自增列数据类型必须是number
 | 
			
		||||
		return fmt.Sprintf(" %s NUMBER generated by default as IDENTITY", colName)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	nullAble := ""
 | 
			
		||||
	if !column.Nullable {
 | 
			
		||||
		nullAble = " NOT NULL"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defVal := ""
 | 
			
		||||
	if column.ColumnDefault != "" {
 | 
			
		||||
		defVal = fmt.Sprintf(" DEFAULT %v", column.ColumnDefault)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	columnSql := fmt.Sprintf(" %s %s%s%s", colName, column.GetColumnType(), defVal, nullAble)
 | 
			
		||||
	return columnSql
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user