mirror of
https://gitee.com/dromara/mayfly-go
synced 2026-02-16 09:15:39 +08:00
refactor: dbm
This commit is contained in:
48
server/internal/db/dbm/mysql/column.go
Normal file
48
server/internal/db/dbm/mysql/column.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
)
|
||||
|
||||
const (
|
||||
IndexSubPartKey = "subPart"
|
||||
)
|
||||
|
||||
var (
|
||||
Bit = dbi.NewDbDataType("bit", dbi.DTBit).WithCT(dbi.CTBit)
|
||||
Tinyint = dbi.NewDbDataType("tinyint", dbi.DTInt8).WithCT(dbi.CTInt1).WithFixColumn(dbi.ClearNumScale)
|
||||
Smallint = dbi.NewDbDataType("smallint", dbi.DTInt16).WithCT(dbi.CTInt2).WithFixColumn(dbi.ClearNumScale)
|
||||
Mediumint = dbi.NewDbDataType("mediumint", dbi.DTInt32).WithCT(dbi.CTInt4).WithFixColumn(dbi.ClearNumScale)
|
||||
Int = dbi.NewDbDataType("int", dbi.DTInt32).WithCT(dbi.CTInt4).WithFixColumn(dbi.ClearNumScale)
|
||||
Bigint = dbi.NewDbDataType("bigint", dbi.DTInt64).WithCT(dbi.CTInt8).WithFixColumn(dbi.ClearNumScale)
|
||||
|
||||
UnsignedBigint = dbi.NewDbDataType("unsigned bigint", dbi.DTUint64).WithCT(dbi.CTUnsignedInt8).WithFixColumn(dbi.ClearNumScale)
|
||||
UnsignedInt = dbi.NewDbDataType("unsigned int", dbi.DTUint64).WithCT(dbi.CTUnsignedInt4).WithFixColumn(dbi.ClearNumScale)
|
||||
UnsignedSmallint = dbi.NewDbDataType("unsigned smallint", dbi.DTInt32).WithCT(dbi.CTUnsignedInt2).WithFixColumn(dbi.ClearNumScale)
|
||||
UnsignedMediumint = dbi.NewDbDataType("unsigned mediumint", dbi.DTInt64).WithCT(dbi.CTUnsignedInt4).WithFixColumn(dbi.ClearNumScale)
|
||||
|
||||
Decimal = dbi.NewDbDataType("decimal", dbi.DTDecimal).WithCT(dbi.CTDecimal)
|
||||
Double = dbi.NewDbDataType("double", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
Float = dbi.NewDbDataType("float", dbi.DTNumeric).WithCT(dbi.CTNumeric)
|
||||
|
||||
Varchar = dbi.NewDbDataType("varchar", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
Char = dbi.NewDbDataType("char", dbi.DTString).WithCT(dbi.CTChar)
|
||||
Text = dbi.NewDbDataType("text", dbi.DTString).WithCT(dbi.CTText).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
Mediumtext = dbi.NewDbDataType("mediumtext", dbi.DTString).WithCT(dbi.CTMediumtext).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
Longtext = dbi.NewDbDataType("longtext", dbi.DTString).WithCT(dbi.CTLongtext).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
JSON = dbi.NewDbDataType("json", dbi.DTString).WithCT(dbi.CTJSON).WithFixColumn(dbi.ClearCharMaxLength)
|
||||
|
||||
Datetime = dbi.NewDbDataType("datetime", dbi.DTDateTime).WithCT(dbi.CTDateTime)
|
||||
Date = dbi.NewDbDataType("date", dbi.DTDate).WithCT(dbi.CTDate)
|
||||
Time = dbi.NewDbDataType("time", dbi.DTTime).WithCT(dbi.CTTime)
|
||||
Timestamp = dbi.NewDbDataType("timestamp", dbi.DTDateTime).WithCT(dbi.CTTimestamp)
|
||||
|
||||
Enum = dbi.NewDbDataType("enum", dbi.DTString).WithCT(dbi.CTEnum)
|
||||
Set = dbi.NewDbDataType("set", dbi.DTString).WithCT(dbi.CTVarchar)
|
||||
|
||||
Blob = dbi.NewDbDataType("blob", dbi.DTBytes).WithCT(dbi.CTBlob)
|
||||
Mediumblob = dbi.NewDbDataType("mediumblob", dbi.DTBytes).WithCT(dbi.CTMediumblob)
|
||||
Longblob = dbi.NewDbDataType("longblob", dbi.DTBytes).WithCT(dbi.CTLongblob)
|
||||
Binary = dbi.NewDbDataType("binary", dbi.DTBytes).WithCT(dbi.CTBinary)
|
||||
Varbinary = dbi.NewDbDataType("varbinary", dbi.DTBytes).WithCT(dbi.CTVarbinary)
|
||||
)
|
||||
@@ -1,17 +1,20 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/internal/db/dbm/sqlparser"
|
||||
"mayfly-go/internal/db/dbm/sqlparser/mysql"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const Quoter = "`"
|
||||
var (
|
||||
mysqlQuoter = dbi.Quoter{
|
||||
Prefix: '`',
|
||||
Suffix: '`',
|
||||
IsReserved: dbi.AlwaysReserve,
|
||||
}
|
||||
)
|
||||
|
||||
type MysqlDialect struct {
|
||||
dbi.DefaultDialect
|
||||
@@ -24,38 +27,6 @@ func (md *MysqlDialect) GetDbProgram() (dbi.DbProgram, error) {
|
||||
return NewDbProgramMysql(md.dc), nil
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
|
||||
// 生成占位符字符串:如:(?,?)
|
||||
// 重复字符串并用逗号连接
|
||||
repeated := strings.Repeat("?,", len(columns))
|
||||
// 去除最后一个逗号,占位符由括号包裹
|
||||
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
|
||||
|
||||
// 执行批量insert sql,mysql支持批量insert语法
|
||||
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
|
||||
|
||||
// 重复占位符字符串n遍
|
||||
repeated = strings.Repeat(placeholder+",", len(values))
|
||||
// 去除最后一个逗号
|
||||
placeholder = strings.TrimSuffix(repeated, ",")
|
||||
|
||||
prefix := "insert into"
|
||||
if duplicateStrategy == 1 {
|
||||
prefix = "insert ignore into"
|
||||
} else if duplicateStrategy == 2 {
|
||||
prefix = "replace into"
|
||||
}
|
||||
|
||||
sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, md.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
|
||||
// 执行批量insert sql
|
||||
// 把二维数组转为一维数组
|
||||
var args []any
|
||||
for _, v := range values {
|
||||
args = append(args, v...)
|
||||
}
|
||||
return md.dc.TxExec(tx, sqlStr, args...)
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
tableName := copy.TableName
|
||||
|
||||
@@ -77,145 +48,14 @@ func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取建表ddl
|
||||
func (md *MysqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
|
||||
sqlArr := make([]string, 0)
|
||||
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", md.QuoteIdentifier(tableInfo.TableName)))
|
||||
}
|
||||
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s (\n", md.QuoteIdentifier(tableInfo.TableName))
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, column.ColumnName)
|
||||
}
|
||||
fields = append(fields, md.genColumnBasicSql(column))
|
||||
}
|
||||
|
||||
// 建表ddl
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
|
||||
// 表注释
|
||||
if tableInfo.TableComment != "" {
|
||||
createSql += fmt.Sprintf(" COMMENT '%s'", md.QuoteEscape(tableInfo.TableComment))
|
||||
}
|
||||
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
// 获取建索引ddl
|
||||
func (md *MysqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
|
||||
sqlArr := make([]string, 0)
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
// 取出列名,添加引号
|
||||
cols := strings.Split(index.ColumnName, ",")
|
||||
colNames := make([]string, len(cols))
|
||||
for i, name := range cols {
|
||||
colNames[i] = md.QuoteIdentifier(name)
|
||||
}
|
||||
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE"
|
||||
sqlStr := fmt.Sprintf(sqlTmp, md.QuoteIdentifier(tableInfo.TableName), unique, md.QuoteIdentifier(index.IndexName), strings.Join(colNames, ","))
|
||||
comment := md.QuoteEscape(index.IndexComment)
|
||||
if comment != "" {
|
||||
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
|
||||
}
|
||||
sqlArr = append(sqlArr, sqlStr)
|
||||
}
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) genColumnBasicSql(column dbi.Column) string {
|
||||
dataType := string(column.DataType)
|
||||
|
||||
incr := ""
|
||||
if column.IsIdentity {
|
||||
incr = " AUTO_INCREMENT"
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
columnType := column.GetColumnType()
|
||||
if nullAble == "" && strings.Contains(columnType, "timestamp") {
|
||||
nullAble = " NULL"
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的
|
||||
if column.ColumnDefault != "" &&
|
||||
// 当默认值是字符串'NULL'时,不需要设置默认值
|
||||
column.ColumnDefault != "NULL" &&
|
||||
// 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
!strings.Contains(column.ColumnDefault, "(") {
|
||||
// 哪些字段类型默认值需要加引号
|
||||
mark := false
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
|
||||
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
comment := ""
|
||||
if column.ColumnComment != "" {
|
||||
// 防止注释内含有特殊字符串导致sql出错
|
||||
commentStr := md.QuoteEscape(column.ColumnComment)
|
||||
comment = fmt.Sprintf(" COMMENT '%s'", commentStr)
|
||||
}
|
||||
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s%s%s", md.QuoteIdentifier(column.ColumnName), columnType, nullAble, incr, defVal, comment)
|
||||
return columnSql
|
||||
}
|
||||
|
||||
func (dx *MysqlDialect) QuoteIdentifier(name string) string {
|
||||
end := strings.IndexRune(name, 0)
|
||||
if end > -1 {
|
||||
name = name[:end]
|
||||
}
|
||||
return Quoter + strings.Replace(name, Quoter, Quoter+Quoter, -1) + Quoter
|
||||
}
|
||||
|
||||
func (dx *MysqlDialect) RemoveQuote(name string) string {
|
||||
return strings.ReplaceAll(name, Quoter, "")
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) QuoteLiteral(literal string) string {
|
||||
literal = strings.ReplaceAll(literal, `\`, `\\`)
|
||||
literal = strings.ReplaceAll(literal, `'`, `''`)
|
||||
return "'" + literal + "'"
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) GetDataHelper() dbi.DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) GetColumnHelper() dbi.ColumnHelper {
|
||||
return columnHelper
|
||||
func (md *MysqlDialect) Quoter() dbi.Quoter {
|
||||
return mysqlQuoter
|
||||
}
|
||||
|
||||
func (pd *MysqlDialect) GetSQLParser() sqlparser.SqlParser {
|
||||
return new(mysql.MysqlParser)
|
||||
}
|
||||
|
||||
func (md *MysqlDialect) GetSQLGenerator() dbi.SQLGenerator {
|
||||
return &SQLGenerator{Dialect: md}
|
||||
}
|
||||
|
||||
@@ -1,218 +1 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/anyx"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
// 数字类型
|
||||
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
|
||||
// 日期时间类型
|
||||
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
|
||||
// 日期类型
|
||||
dateRegexp = regexp.MustCompile(`(?i)date`)
|
||||
// 时间类型
|
||||
timeRegexp = regexp.MustCompile(`(?i)time`)
|
||||
|
||||
blobRegexp = regexp.MustCompile(`(?i)blob`)
|
||||
|
||||
// mysql数据类型 映射 公共数据类型
|
||||
commonColumnTypeMap = map[string]dbi.ColumnDataType{
|
||||
"bigint": dbi.CommonTypeBigint,
|
||||
"binary": dbi.CommonTypeBinary,
|
||||
"blob": dbi.CommonTypeBlob,
|
||||
"char": dbi.CommonTypeChar,
|
||||
"datetime": dbi.CommonTypeDatetime,
|
||||
"date": dbi.CommonTypeDate,
|
||||
"decimal": dbi.CommonTypeNumber,
|
||||
"double": dbi.CommonTypeNumber,
|
||||
"enum": dbi.CommonTypeEnum,
|
||||
"float": dbi.CommonTypeNumber,
|
||||
"int": dbi.CommonTypeInt,
|
||||
"json": dbi.CommonTypeJSON,
|
||||
"longblob": dbi.CommonTypeLongblob,
|
||||
"longtext": dbi.CommonTypeLongtext,
|
||||
"mediumblob": dbi.CommonTypeBlob,
|
||||
"mediumtext": dbi.CommonTypeMediumtext,
|
||||
"bit": dbi.CommonTypeBit,
|
||||
"set": dbi.CommonTypeVarchar,
|
||||
"smallint": dbi.CommonTypeSmallint,
|
||||
"text": dbi.CommonTypeText,
|
||||
"time": dbi.CommonTypeTime,
|
||||
"timestamp": dbi.CommonTypeTimestamp,
|
||||
"tinyint": dbi.CommonTypeTinyint,
|
||||
"varbinary": dbi.CommonTypeVarbinary,
|
||||
"varchar": dbi.CommonTypeVarchar,
|
||||
}
|
||||
|
||||
// 公共数据类型 映射 mysql数据类型
|
||||
mysqlColumnTypeMap = map[dbi.ColumnDataType]string{
|
||||
dbi.CommonTypeVarchar: "varchar",
|
||||
dbi.CommonTypeChar: "char",
|
||||
dbi.CommonTypeText: "text",
|
||||
dbi.CommonTypeBlob: "blob",
|
||||
dbi.CommonTypeLongblob: "longblob",
|
||||
dbi.CommonTypeLongtext: "longtext",
|
||||
dbi.CommonTypeBinary: "binary",
|
||||
dbi.CommonTypeMediumblob: "blob",
|
||||
dbi.CommonTypeMediumtext: "mediumtext",
|
||||
dbi.CommonTypeVarbinary: "varbinary",
|
||||
dbi.CommonTypeInt: "int",
|
||||
dbi.CommonTypeBit: "bit",
|
||||
dbi.CommonTypeSmallint: "smallint",
|
||||
dbi.CommonTypeTinyint: "tinyint",
|
||||
dbi.CommonTypeNumber: "decimal",
|
||||
dbi.CommonTypeBigint: "bigint",
|
||||
dbi.CommonTypeDatetime: "datetime",
|
||||
dbi.CommonTypeDate: "date",
|
||||
dbi.CommonTypeTime: "time",
|
||||
dbi.CommonTypeTimestamp: "timestamp",
|
||||
dbi.CommonTypeEnum: "enum",
|
||||
dbi.CommonTypeJSON: "json",
|
||||
}
|
||||
dataHelper = &DataHelper{}
|
||||
columnHelper = &ColumnHelper{}
|
||||
)
|
||||
|
||||
func GetDataHelper() *DataHelper {
|
||||
return dataHelper
|
||||
}
|
||||
|
||||
type DataHelper struct {
|
||||
}
|
||||
|
||||
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
|
||||
if numberRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeNumber
|
||||
}
|
||||
// 日期时间类型
|
||||
if datetimeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDateTime
|
||||
}
|
||||
// 日期类型
|
||||
if dateRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeDate
|
||||
}
|
||||
// 时间类型
|
||||
if timeRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeTime
|
||||
}
|
||||
// blob类型
|
||||
if blobRegexp.MatchString(dbColumnType) {
|
||||
return dbi.DataTypeBlob
|
||||
}
|
||||
return dbi.DataTypeString
|
||||
}
|
||||
|
||||
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
|
||||
// 如果dataType是datetime而dbColumnValue是string类型,则需要根据类型格式化
|
||||
str, ok := dbColumnValue.(string)
|
||||
if dataType == dbi.DataTypeDateTime && ok {
|
||||
// 尝试用时间格式解析
|
||||
res, err := time.Parse(time.DateTime, str)
|
||||
if err == nil {
|
||||
return str
|
||||
}
|
||||
res, _ = time.Parse(time.RFC3339, str)
|
||||
return res.Format(time.DateTime)
|
||||
}
|
||||
if dataType == dbi.DataTypeDate && ok {
|
||||
res, _ := time.Parse(time.DateOnly, str)
|
||||
return res.Format(time.DateOnly)
|
||||
}
|
||||
if dataType == dbi.DataTypeTime && ok {
|
||||
res, _ := time.Parse(time.TimeOnly, str)
|
||||
return res.Format(time.TimeOnly)
|
||||
}
|
||||
return anyx.ToString(dbColumnValue)
|
||||
}
|
||||
|
||||
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
|
||||
// 如果dataType是datetime而dbColumnValue是string类型,则需要转换为time.Time类型
|
||||
_, ok := dbColumnValue.(string)
|
||||
if ok {
|
||||
if dataType == dbi.DataTypeDateTime {
|
||||
res, _ := time.Parse(time.DateTime, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
if dataType == dbi.DataTypeDate {
|
||||
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
if dataType == dbi.DataTypeTime {
|
||||
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
|
||||
return res
|
||||
}
|
||||
}
|
||||
return dbColumnValue
|
||||
}
|
||||
|
||||
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
|
||||
if dbColumnValue == nil {
|
||||
return "NULL"
|
||||
}
|
||||
switch dataType {
|
||||
case dbi.DataTypeNumber:
|
||||
return fmt.Sprintf("%v", dbColumnValue)
|
||||
case dbi.DataTypeString:
|
||||
val := fmt.Sprintf("%v", dbColumnValue)
|
||||
// 转义单引号
|
||||
val = strings.Replace(val, `'`, `''`, -1)
|
||||
val = strings.Replace(val, `\''`, `\'`, -1)
|
||||
// 转义换行符
|
||||
val = strings.Replace(val, "\n", "\\n", -1)
|
||||
return fmt.Sprintf("'%s'", val)
|
||||
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
|
||||
// mysql时间类型无需格式化
|
||||
return fmt.Sprintf("'%s'", dbColumnValue)
|
||||
case dbi.DataTypeBlob:
|
||||
return fmt.Sprintf("unhex('%s')", dbColumnValue)
|
||||
}
|
||||
return fmt.Sprintf("'%s'", dbColumnValue)
|
||||
}
|
||||
|
||||
type ColumnHelper struct {
|
||||
dbi.DefaultColumnHelper
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
|
||||
dataType := dialectColumn.DataType
|
||||
|
||||
t1 := commonColumnTypeMap[string(dataType)]
|
||||
commonColumnType := dbi.CommonTypeVarchar
|
||||
|
||||
if t1 != "" {
|
||||
commonColumnType = t1
|
||||
}
|
||||
|
||||
dialectColumn.DataType = commonColumnType
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) ToColumn(column *dbi.Column) {
|
||||
ctype := mysqlColumnTypeMap[column.DataType]
|
||||
if ctype == "" {
|
||||
column.DataType = "varchar"
|
||||
column.CharMaxLength = 1000
|
||||
} else {
|
||||
column.DataType = dbi.ColumnDataType(ctype)
|
||||
ch.FixColumn(column)
|
||||
}
|
||||
}
|
||||
|
||||
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {
|
||||
// 如果是int整型,删除精度
|
||||
if strings.Contains(strings.ToLower(string(column.DataType)), "int") {
|
||||
column.NumScale = 0
|
||||
column.CharMaxLength = 0
|
||||
} else
|
||||
// 如果是text,删除长度
|
||||
if strings.Contains(strings.ToLower(string(column.DataType)), "text") {
|
||||
column.CharMaxLength = 0
|
||||
column.NumPrecision = 0
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"net"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
@@ -12,10 +13,15 @@ import (
|
||||
|
||||
func init() {
|
||||
meta := new(Meta)
|
||||
dbi.Register(dbi.DbTypeMysql, meta)
|
||||
dbi.Register(dbi.DbTypeMariadb, meta)
|
||||
dbi.Register(DbTypeMysql, meta)
|
||||
dbi.Register(DbTypeMariadb, meta)
|
||||
}
|
||||
|
||||
const (
|
||||
DbTypeMysql dbi.DbType = "mysql"
|
||||
DbTypeMariadb dbi.DbType = "mariadb"
|
||||
)
|
||||
|
||||
type Meta struct {
|
||||
}
|
||||
|
||||
@@ -31,7 +37,7 @@ func (mm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
|
||||
})
|
||||
}
|
||||
// 设置dataSourceName -> 更多参数参考:https://github.com/go-sql-driver/mysql#dsn-data-source-name
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
|
||||
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?parseTime=true&timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
|
||||
if d.Params != "" {
|
||||
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
|
||||
}
|
||||
@@ -46,3 +52,17 @@ func (mm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
|
||||
func (mm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
|
||||
return &MysqlMetadata{dc: conn}
|
||||
}
|
||||
|
||||
func (mm *Meta) GetDbDataTypes() []*dbi.DbDataType {
|
||||
return collx.AsArray(
|
||||
UnsignedBigint, Bigint, Tinyint, Smallint, Int, Bit, Float, Double, Decimal,
|
||||
Varchar, Char, Text, Longtext, Mediumtext,
|
||||
Datetime, Date, Time, Timestamp,
|
||||
Enum, JSON, Set,
|
||||
Binary, Blob, Longblob, Mediumblob, Varbinary,
|
||||
)
|
||||
}
|
||||
|
||||
func (mm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
|
||||
return &commonTypeConverter{}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/errorx"
|
||||
"mayfly-go/pkg/logx"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"mayfly-go/pkg/utils/stringx"
|
||||
"strings"
|
||||
@@ -54,7 +53,7 @@ func (md *MysqlMetadata) GetDbNames() ([]string, error) {
|
||||
func (md *MysqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
dialect := md.dc.GetDialect()
|
||||
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
var res []map[string]any
|
||||
@@ -87,9 +86,8 @@ func (md *MysqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
|
||||
// 获取列元信息, 如列名等
|
||||
func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
|
||||
dialect := md.dc.GetDialect()
|
||||
columnHelper := dialect.GetColumnHelper()
|
||||
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
|
||||
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
|
||||
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
|
||||
}), ",")
|
||||
|
||||
_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
|
||||
@@ -103,7 +101,7 @@ func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
column := dbi.Column{
|
||||
TableName: cast.ToString(re["tableName"]),
|
||||
ColumnName: cast.ToString(re["columnName"]),
|
||||
DataType: dbi.ColumnDataType(cast.ToString(re["dataType"])),
|
||||
DataType: cast.ToString(re["dataType"]),
|
||||
ColumnComment: cast.ToString(re["columnComment"]),
|
||||
Nullable: cast.ToString(re["nullable"]) == "YES",
|
||||
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
|
||||
@@ -114,7 +112,7 @@ func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
|
||||
NumScale: cast.ToInt(re["numScale"]),
|
||||
}
|
||||
|
||||
columnHelper.FixColumn(&column)
|
||||
md.dc.GetDbDataType(column.DataType).FixColumn(&column)
|
||||
columns = append(columns, column)
|
||||
}
|
||||
return columns, nil
|
||||
@@ -156,6 +154,7 @@ func (md *MysqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
IsUnique: cast.ToInt(re["isUnique"]) == 1,
|
||||
SeqInIndex: cast.ToInt(re["seqInIndex"]),
|
||||
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
|
||||
Extra: collx.Kvs(IndexSubPartKey, cast.ToInt(re[IndexSubPartKey])),
|
||||
})
|
||||
}
|
||||
// 把查询结果以索引名分组,索引字段以逗号连接
|
||||
@@ -179,34 +178,7 @@ func (md *MysqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
|
||||
|
||||
// 获取建表ddl
|
||||
func (md *MysqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
|
||||
// 1.获取表信息
|
||||
tbs, err := md.GetTables(tableName)
|
||||
tableInfo := &dbi.Table{}
|
||||
if err != nil || tbs == nil || len(tbs) <= 0 {
|
||||
logx.Errorf("获取表信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
tableInfo.TableName = tbs[0].TableName
|
||||
tableInfo.TableComment = tbs[0].TableComment
|
||||
|
||||
// 2.获取列信息
|
||||
columns, err := md.GetColumns(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取列信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
|
||||
dialect := md.dc.GetDialect()
|
||||
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
|
||||
// 3.获取索引信息
|
||||
indexs, err := md.GetTableIndex(tableName)
|
||||
if err != nil {
|
||||
logx.Errorf("获取索引信息失败, %s", tableName)
|
||||
return "", err
|
||||
}
|
||||
// 组装返回
|
||||
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
|
||||
return strings.Join(tableDDLArr, ";\n"), nil
|
||||
return dbi.GenTableDDL(md.dc.GetDialect(), md, tableName, dropBeforeCreate)
|
||||
}
|
||||
|
||||
func (md *MysqlMetadata) GetSchemas() ([]string, error) {
|
||||
|
||||
@@ -56,9 +56,9 @@ func (svc *DbProgramMysql) getMysqlBin() *config.MysqlBin {
|
||||
dbInfo := svc.dbInfo()
|
||||
var mysqlBin *config.MysqlBin
|
||||
switch dbInfo.Type {
|
||||
case dbi.DbTypeMariadb:
|
||||
case DbTypeMariadb:
|
||||
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMariadbBin)
|
||||
case dbi.DbTypeMysql:
|
||||
case DbTypeMysql:
|
||||
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMysqlBin)
|
||||
default:
|
||||
panic(fmt.Sprintf("不兼容 MySQL 的数据库类型: %v", dbInfo.Type))
|
||||
|
||||
147
server/internal/db/dbm/mysql/sqlgen.go
Normal file
147
server/internal/db/dbm/mysql/sqlgen.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"mayfly-go/internal/db/dbm/dbi"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
|
||||
"github.com/may-fly/cast"
|
||||
)
|
||||
|
||||
type SQLGenerator struct {
|
||||
Dialect dbi.Dialect
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
|
||||
sqlArr := make([]string, 0)
|
||||
quoter := msg.Dialect.Quoter()
|
||||
|
||||
if dropBeforeCreate {
|
||||
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoter.Quote(table.TableName)))
|
||||
}
|
||||
|
||||
// 组装建表语句
|
||||
createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoter.Quote(table.TableName))
|
||||
fields := make([]string, 0)
|
||||
pks := make([]string, 0)
|
||||
|
||||
for _, column := range columns {
|
||||
if column.IsPrimaryKey {
|
||||
pks = append(pks, column.ColumnName)
|
||||
}
|
||||
fields = append(fields, msg.genColumnBasicSql(quoter, column))
|
||||
}
|
||||
|
||||
// 建表ddl
|
||||
createSql += strings.Join(fields, ",\n")
|
||||
if len(pks) > 0 {
|
||||
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
|
||||
}
|
||||
createSql += "\n)"
|
||||
|
||||
// 表注释
|
||||
if table.TableComment != "" {
|
||||
createSql += fmt.Sprintf(" COMMENT '%s'", dbi.QuoteEscape(table.TableComment))
|
||||
}
|
||||
|
||||
sqlArr = append(sqlArr, createSql)
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
|
||||
sqlArr := make([]string, 0)
|
||||
quoter := msg.Dialect.Quoter()
|
||||
|
||||
for _, index := range indexs {
|
||||
unique := ""
|
||||
if index.IsUnique {
|
||||
unique = "unique"
|
||||
}
|
||||
// 取出列名,添加引号
|
||||
colNames := quoter.Quotes(strings.Split(index.ColumnName, ","))
|
||||
|
||||
// 暂时先处理单个索引的情况,多个涉及获取索引时的合并等,以及前端调整等,后续完善
|
||||
if subPart := cast.ToInt(index.Extra[IndexSubPartKey]); subPart > 0 && len(colNames) == 1 {
|
||||
colNames[0] = fmt.Sprintf("%s(%d)", colNames[0], subPart)
|
||||
}
|
||||
|
||||
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING %s"
|
||||
sqlStr := fmt.Sprintf(sqlTmp, quoter.Quote(table.TableName), unique, quoter.Quote(index.IndexName), strings.Join(colNames, ","), index.IndexType)
|
||||
comment := dbi.QuoteEscape(index.IndexComment)
|
||||
if comment != "" {
|
||||
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
|
||||
}
|
||||
sqlArr = append(sqlArr, sqlStr)
|
||||
}
|
||||
|
||||
return sqlArr
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
|
||||
if duplicateStrategy == dbi.DuplicateStrategyNone {
|
||||
return collx.AsArray(dbi.GenCommonInsert(msg.Dialect, DbTypeMysql, tableName, columns, values))
|
||||
}
|
||||
|
||||
prefix := "insert ignore into"
|
||||
if duplicateStrategy == dbi.DuplicateStrategyUpdate {
|
||||
prefix = "replace into"
|
||||
}
|
||||
|
||||
quote := msg.Dialect.Quoter().Quote
|
||||
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(msg.Dialect, DbTypeMysql, columns, values)
|
||||
|
||||
return collx.AsArray[string](fmt.Sprintf("%s %s %s VALUES \n%s", prefix, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n")))
|
||||
}
|
||||
|
||||
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
|
||||
dataType := string(column.DataType)
|
||||
|
||||
incr := ""
|
||||
if column.IsIdentity {
|
||||
incr = " AUTO_INCREMENT"
|
||||
}
|
||||
|
||||
nullAble := ""
|
||||
if !column.Nullable {
|
||||
nullAble = " NOT NULL"
|
||||
}
|
||||
columnType := column.GetColumnType()
|
||||
if nullAble == "" && strings.Contains(columnType, "timestamp") {
|
||||
nullAble = " NULL"
|
||||
}
|
||||
|
||||
defVal := "" // 默认值需要判断引号,如函数是不需要引号的
|
||||
if column.ColumnDefault != "" &&
|
||||
// 当默认值是字符串'NULL'时,不需要设置默认值
|
||||
column.ColumnDefault != "NULL" &&
|
||||
// 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
|
||||
!strings.Contains(column.ColumnDefault, "(") {
|
||||
// 哪些字段类型默认值需要加引号
|
||||
mark := false
|
||||
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
|
||||
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
|
||||
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
|
||||
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
|
||||
mark = false
|
||||
} else {
|
||||
mark = true
|
||||
}
|
||||
}
|
||||
if mark {
|
||||
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
|
||||
} else {
|
||||
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
|
||||
}
|
||||
}
|
||||
comment := ""
|
||||
if column.ColumnComment != "" {
|
||||
// 防止注释内含有特殊字符串导致sql出错
|
||||
commentStr := dbi.QuoteEscape(column.ColumnComment)
|
||||
comment = fmt.Sprintf(" COMMENT '%s'", commentStr)
|
||||
}
|
||||
|
||||
columnSql := fmt.Sprintf(" %s %s%s%s%s%s", quoter.Quote(column.ColumnName), columnType, nullAble, incr, defVal, comment)
|
||||
return columnSql
|
||||
}
|
||||
108
server/internal/db/dbm/mysql/transfer.go
Normal file
108
server/internal/db/dbm/mysql/transfer.go
Normal file
@@ -0,0 +1,108 @@
|
||||
package mysql
|
||||
|
||||
import "mayfly-go/internal/db/dbm/dbi"
|
||||
|
||||
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
|
||||
|
||||
type commonTypeConverter struct {
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
|
||||
// 如果字符长度大于16383,则转为text类型
|
||||
if col.CharMaxLength > 16383 {
|
||||
col.CharMaxLength = 0
|
||||
return Text
|
||||
}
|
||||
return Varchar
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
|
||||
return Char
|
||||
}
|
||||
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
|
||||
col.CharMaxLength = 0
|
||||
col.NumPrecision = 0
|
||||
return Text
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
|
||||
col.CharMaxLength = 0
|
||||
col.NumPrecision = 0
|
||||
return Mediumtext
|
||||
}
|
||||
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
|
||||
col.CharMaxLength = 0
|
||||
col.NumPrecision = 0
|
||||
return Longtext
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bit
|
||||
}
|
||||
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
|
||||
return Tinyint
|
||||
}
|
||||
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
|
||||
return Smallint
|
||||
}
|
||||
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
|
||||
return Int
|
||||
}
|
||||
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
|
||||
return Bigint
|
||||
}
|
||||
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
|
||||
return Double
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
|
||||
return Decimal
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
|
||||
return UnsignedBigint
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
|
||||
return UnsignedInt
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
|
||||
return UnsignedMediumint
|
||||
}
|
||||
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
|
||||
return UnsignedSmallint
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
|
||||
return Date
|
||||
}
|
||||
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
|
||||
return Time
|
||||
}
|
||||
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
|
||||
return Datetime
|
||||
}
|
||||
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
|
||||
return Timestamp
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
|
||||
return Binary
|
||||
}
|
||||
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
|
||||
return Varbinary
|
||||
}
|
||||
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Mediumblob
|
||||
}
|
||||
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Blob
|
||||
}
|
||||
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
|
||||
return Longblob
|
||||
}
|
||||
|
||||
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
|
||||
return Enum
|
||||
}
|
||||
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
|
||||
return JSON
|
||||
}
|
||||
Reference in New Issue
Block a user