!129 fix: db 相关bug

* fix: db 相关bug
This commit is contained in:
zongyangleo
2024-12-26 04:11:28 +00:00
committed by Coder慌
parent 68f553f4b0
commit 3f6fb5afef
9 changed files with 240 additions and 202 deletions

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"regexp"
"strings"
)
@@ -87,12 +88,16 @@ func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values
quote := quoter.Quote
if duplicateStrategy == dbi.DuplicateStrategyNone {
identityInsert := fmt.Sprintf("set identity_insert %s on;", quote(tableName))
identityInsert := ""
// 有自增列的才加上这个语句
if collx.AnyMatch(columns, func(column dbi.Column) bool { return column.IsIdentity }) {
identityInsert = fmt.Sprintf("set identity_insert %s on;", quote(tableName))
}
// 达梦数据库只能一条条的执行insert语句所以这里需要将values拆分成多条insert语句
return collx.ArrayMap(values, func(value []any) string {
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeDM, columns, [][]any{value})
return fmt.Sprintf("%s insert into %s (%s) values %s", identityInsert, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n"))
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(sg.Dialect, DbTypeDM, columns, values)
return fmt.Sprintf("%s insert into %s %s values %s", identityInsert, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n"))
})
}
@@ -156,9 +161,9 @@ func (sg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values
return collx.AsArray(sqlTemp)
}
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
func (sg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
colName := quoter.Quote(column.ColumnName)
dataType := string(column.DataType)
dataType := column.DataType
incr := ""
if column.IsIdentity {
@@ -174,7 +179,10 @@ func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column)
if column.ColumnDefault != "" && !strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
if regexp.MustCompile(`'.*'`).MatchString(column.ColumnDefault) {
// 字符串默认值
mark = false
} else if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
@@ -182,6 +190,10 @@ func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column)
} else {
mark = true
}
// 空
if column.ColumnDefault == "NULL" {
mark = false
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)