refactor: dbm

This commit is contained in:
meilin.huang
2024-12-08 13:04:23 +08:00
parent ebc89e056f
commit e56788af3e
152 changed files with 4273 additions and 3715 deletions

View File

@@ -0,0 +1,48 @@
package mysql
import (
"mayfly-go/internal/db/dbm/dbi"
)
const (
IndexSubPartKey = "subPart"
)
var (
Bit = dbi.NewDbDataType("bit", dbi.DTBit).WithCT(dbi.CTBit)
Tinyint = dbi.NewDbDataType("tinyint", dbi.DTInt8).WithCT(dbi.CTInt1).WithFixColumn(dbi.ClearNumScale)
Smallint = dbi.NewDbDataType("smallint", dbi.DTInt16).WithCT(dbi.CTInt2).WithFixColumn(dbi.ClearNumScale)
Mediumint = dbi.NewDbDataType("mediumint", dbi.DTInt32).WithCT(dbi.CTInt4).WithFixColumn(dbi.ClearNumScale)
Int = dbi.NewDbDataType("int", dbi.DTInt32).WithCT(dbi.CTInt4).WithFixColumn(dbi.ClearNumScale)
Bigint = dbi.NewDbDataType("bigint", dbi.DTInt64).WithCT(dbi.CTInt8).WithFixColumn(dbi.ClearNumScale)
UnsignedBigint = dbi.NewDbDataType("unsigned bigint", dbi.DTUint64).WithCT(dbi.CTUnsignedInt8).WithFixColumn(dbi.ClearNumScale)
UnsignedInt = dbi.NewDbDataType("unsigned int", dbi.DTUint64).WithCT(dbi.CTUnsignedInt4).WithFixColumn(dbi.ClearNumScale)
UnsignedSmallint = dbi.NewDbDataType("unsigned smallint", dbi.DTInt32).WithCT(dbi.CTUnsignedInt2).WithFixColumn(dbi.ClearNumScale)
UnsignedMediumint = dbi.NewDbDataType("unsigned mediumint", dbi.DTInt64).WithCT(dbi.CTUnsignedInt4).WithFixColumn(dbi.ClearNumScale)
Decimal = dbi.NewDbDataType("decimal", dbi.DTDecimal).WithCT(dbi.CTDecimal)
Double = dbi.NewDbDataType("double", dbi.DTNumeric).WithCT(dbi.CTNumeric)
Float = dbi.NewDbDataType("float", dbi.DTNumeric).WithCT(dbi.CTNumeric)
Varchar = dbi.NewDbDataType("varchar", dbi.DTString).WithCT(dbi.CTVarchar)
Char = dbi.NewDbDataType("char", dbi.DTString).WithCT(dbi.CTChar)
Text = dbi.NewDbDataType("text", dbi.DTString).WithCT(dbi.CTText).WithFixColumn(dbi.ClearCharMaxLength)
Mediumtext = dbi.NewDbDataType("mediumtext", dbi.DTString).WithCT(dbi.CTMediumtext).WithFixColumn(dbi.ClearCharMaxLength)
Longtext = dbi.NewDbDataType("longtext", dbi.DTString).WithCT(dbi.CTLongtext).WithFixColumn(dbi.ClearCharMaxLength)
JSON = dbi.NewDbDataType("json", dbi.DTString).WithCT(dbi.CTJSON).WithFixColumn(dbi.ClearCharMaxLength)
Datetime = dbi.NewDbDataType("datetime", dbi.DTDateTime).WithCT(dbi.CTDateTime)
Date = dbi.NewDbDataType("date", dbi.DTDate).WithCT(dbi.CTDate)
Time = dbi.NewDbDataType("time", dbi.DTTime).WithCT(dbi.CTTime)
Timestamp = dbi.NewDbDataType("timestamp", dbi.DTDateTime).WithCT(dbi.CTTimestamp)
Enum = dbi.NewDbDataType("enum", dbi.DTString).WithCT(dbi.CTEnum)
Set = dbi.NewDbDataType("set", dbi.DTString).WithCT(dbi.CTVarchar)
Blob = dbi.NewDbDataType("blob", dbi.DTBytes).WithCT(dbi.CTBlob)
Mediumblob = dbi.NewDbDataType("mediumblob", dbi.DTBytes).WithCT(dbi.CTMediumblob)
Longblob = dbi.NewDbDataType("longblob", dbi.DTBytes).WithCT(dbi.CTLongblob)
Binary = dbi.NewDbDataType("binary", dbi.DTBytes).WithCT(dbi.CTBinary)
Varbinary = dbi.NewDbDataType("varbinary", dbi.DTBytes).WithCT(dbi.CTVarbinary)
)

View File

@@ -1,17 +1,20 @@
package mysql
import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/internal/db/dbm/sqlparser"
"mayfly-go/internal/db/dbm/sqlparser/mysql"
"mayfly-go/pkg/utils/collx"
"strings"
"time"
)
const Quoter = "`"
var (
mysqlQuoter = dbi.Quoter{
Prefix: '`',
Suffix: '`',
IsReserved: dbi.AlwaysReserve,
}
)
type MysqlDialect struct {
dbi.DefaultDialect
@@ -24,38 +27,6 @@ func (md *MysqlDialect) GetDbProgram() (dbi.DbProgram, error) {
return NewDbProgramMysql(md.dc), nil
}
func (md *MysqlDialect) BatchInsert(tx *sql.Tx, tableName string, columns []string, values [][]any, duplicateStrategy int) (int64, error) {
// 生成占位符字符串:如:(?,?)
// 重复字符串并用逗号连接
repeated := strings.Repeat("?,", len(columns))
// 去除最后一个逗号,占位符由括号包裹
placeholder := fmt.Sprintf("(%s)", strings.TrimSuffix(repeated, ","))
// 执行批量insert sqlmysql支持批量insert语法
// insert into table_name (column1, column2, ...) values (value1, value2, ...), (value1, value2, ...), ...
// 重复占位符字符串n遍
repeated = strings.Repeat(placeholder+",", len(values))
// 去除最后一个逗号
placeholder = strings.TrimSuffix(repeated, ",")
prefix := "insert into"
if duplicateStrategy == 1 {
prefix = "insert ignore into"
} else if duplicateStrategy == 2 {
prefix = "replace into"
}
sqlStr := fmt.Sprintf("%s %s (%s) values %s", prefix, md.QuoteIdentifier(tableName), strings.Join(columns, ","), placeholder)
// 执行批量insert sql
// 把二维数组转为一维数组
var args []any
for _, v := range values {
args = append(args, v...)
}
return md.dc.TxExec(tx, sqlStr, args...)
}
func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
tableName := copy.TableName
@@ -77,145 +48,14 @@ func (md *MysqlDialect) CopyTable(copy *dbi.DbCopyTable) error {
return err
}
// 获取建表ddl
func (md *MysqlDialect) GenerateTableDDL(columns []dbi.Column, tableInfo dbi.Table, dropBeforeCreate bool) []string {
sqlArr := make([]string, 0)
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", md.QuoteIdentifier(tableInfo.TableName)))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s (\n", md.QuoteIdentifier(tableInfo.TableName))
fields := make([]string, 0)
pks := make([]string, 0)
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, column.ColumnName)
}
fields = append(fields, md.genColumnBasicSql(column))
}
// 建表ddl
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
// 表注释
if tableInfo.TableComment != "" {
createSql += fmt.Sprintf(" COMMENT '%s'", md.QuoteEscape(tableInfo.TableComment))
}
sqlArr = append(sqlArr, createSql)
return sqlArr
}
// 获取建索引ddl
func (md *MysqlDialect) GenerateIndexDDL(indexs []dbi.Index, tableInfo dbi.Table) []string {
sqlArr := make([]string, 0)
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
cols := strings.Split(index.ColumnName, ",")
colNames := make([]string, len(cols))
for i, name := range cols {
colNames[i] = md.QuoteIdentifier(name)
}
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING BTREE"
sqlStr := fmt.Sprintf(sqlTmp, md.QuoteIdentifier(tableInfo.TableName), unique, md.QuoteIdentifier(index.IndexName), strings.Join(colNames, ","))
comment := md.QuoteEscape(index.IndexComment)
if comment != "" {
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
}
sqlArr = append(sqlArr, sqlStr)
}
return sqlArr
}
func (md *MysqlDialect) genColumnBasicSql(column dbi.Column) string {
dataType := string(column.DataType)
incr := ""
if column.IsIdentity {
incr = " AUTO_INCREMENT"
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
columnType := column.GetColumnType()
if nullAble == "" && strings.Contains(columnType, "timestamp") {
nullAble = " NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的
if column.ColumnDefault != "" &&
// 当默认值是字符串'NULL'时,不需要设置默认值
column.ColumnDefault != "NULL" &&
// 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
!strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
comment := ""
if column.ColumnComment != "" {
// 防止注释内含有特殊字符串导致sql出错
commentStr := md.QuoteEscape(column.ColumnComment)
comment = fmt.Sprintf(" COMMENT '%s'", commentStr)
}
columnSql := fmt.Sprintf(" %s %s%s%s%s%s", md.QuoteIdentifier(column.ColumnName), columnType, nullAble, incr, defVal, comment)
return columnSql
}
func (dx *MysqlDialect) QuoteIdentifier(name string) string {
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
}
return Quoter + strings.Replace(name, Quoter, Quoter+Quoter, -1) + Quoter
}
func (dx *MysqlDialect) RemoveQuote(name string) string {
return strings.ReplaceAll(name, Quoter, "")
}
func (md *MysqlDialect) QuoteLiteral(literal string) string {
literal = strings.ReplaceAll(literal, `\`, `\\`)
literal = strings.ReplaceAll(literal, `'`, `''`)
return "'" + literal + "'"
}
func (md *MysqlDialect) GetDataHelper() dbi.DataHelper {
return dataHelper
}
func (md *MysqlDialect) GetColumnHelper() dbi.ColumnHelper {
return columnHelper
func (md *MysqlDialect) Quoter() dbi.Quoter {
return mysqlQuoter
}
func (pd *MysqlDialect) GetSQLParser() sqlparser.SqlParser {
return new(mysql.MysqlParser)
}
func (md *MysqlDialect) GetSQLGenerator() dbi.SQLGenerator {
return &SQLGenerator{Dialect: md}
}

View File

@@ -1,218 +1 @@
package mysql
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/anyx"
"regexp"
"strings"
"time"
)
var (
// 数字类型
numberRegexp = regexp.MustCompile(`(?i)int|double|float|number|decimal|byte|bit`)
// 日期时间类型
datetimeRegexp = regexp.MustCompile(`(?i)datetime|timestamp`)
// 日期类型
dateRegexp = regexp.MustCompile(`(?i)date`)
// 时间类型
timeRegexp = regexp.MustCompile(`(?i)time`)
blobRegexp = regexp.MustCompile(`(?i)blob`)
// mysql数据类型 映射 公共数据类型
commonColumnTypeMap = map[string]dbi.ColumnDataType{
"bigint": dbi.CommonTypeBigint,
"binary": dbi.CommonTypeBinary,
"blob": dbi.CommonTypeBlob,
"char": dbi.CommonTypeChar,
"datetime": dbi.CommonTypeDatetime,
"date": dbi.CommonTypeDate,
"decimal": dbi.CommonTypeNumber,
"double": dbi.CommonTypeNumber,
"enum": dbi.CommonTypeEnum,
"float": dbi.CommonTypeNumber,
"int": dbi.CommonTypeInt,
"json": dbi.CommonTypeJSON,
"longblob": dbi.CommonTypeLongblob,
"longtext": dbi.CommonTypeLongtext,
"mediumblob": dbi.CommonTypeBlob,
"mediumtext": dbi.CommonTypeMediumtext,
"bit": dbi.CommonTypeBit,
"set": dbi.CommonTypeVarchar,
"smallint": dbi.CommonTypeSmallint,
"text": dbi.CommonTypeText,
"time": dbi.CommonTypeTime,
"timestamp": dbi.CommonTypeTimestamp,
"tinyint": dbi.CommonTypeTinyint,
"varbinary": dbi.CommonTypeVarbinary,
"varchar": dbi.CommonTypeVarchar,
}
// 公共数据类型 映射 mysql数据类型
mysqlColumnTypeMap = map[dbi.ColumnDataType]string{
dbi.CommonTypeVarchar: "varchar",
dbi.CommonTypeChar: "char",
dbi.CommonTypeText: "text",
dbi.CommonTypeBlob: "blob",
dbi.CommonTypeLongblob: "longblob",
dbi.CommonTypeLongtext: "longtext",
dbi.CommonTypeBinary: "binary",
dbi.CommonTypeMediumblob: "blob",
dbi.CommonTypeMediumtext: "mediumtext",
dbi.CommonTypeVarbinary: "varbinary",
dbi.CommonTypeInt: "int",
dbi.CommonTypeBit: "bit",
dbi.CommonTypeSmallint: "smallint",
dbi.CommonTypeTinyint: "tinyint",
dbi.CommonTypeNumber: "decimal",
dbi.CommonTypeBigint: "bigint",
dbi.CommonTypeDatetime: "datetime",
dbi.CommonTypeDate: "date",
dbi.CommonTypeTime: "time",
dbi.CommonTypeTimestamp: "timestamp",
dbi.CommonTypeEnum: "enum",
dbi.CommonTypeJSON: "json",
}
dataHelper = &DataHelper{}
columnHelper = &ColumnHelper{}
)
func GetDataHelper() *DataHelper {
return dataHelper
}
type DataHelper struct {
}
func (dc *DataHelper) GetDataType(dbColumnType string) dbi.DataType {
if numberRegexp.MatchString(dbColumnType) {
return dbi.DataTypeNumber
}
// 日期时间类型
if datetimeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDateTime
}
// 日期类型
if dateRegexp.MatchString(dbColumnType) {
return dbi.DataTypeDate
}
// 时间类型
if timeRegexp.MatchString(dbColumnType) {
return dbi.DataTypeTime
}
// blob类型
if blobRegexp.MatchString(dbColumnType) {
return dbi.DataTypeBlob
}
return dbi.DataTypeString
}
func (dc *DataHelper) FormatData(dbColumnValue any, dataType dbi.DataType) string {
// 如果dataType是datetime而dbColumnValue是string类型则需要根据类型格式化
str, ok := dbColumnValue.(string)
if dataType == dbi.DataTypeDateTime && ok {
// 尝试用时间格式解析
res, err := time.Parse(time.DateTime, str)
if err == nil {
return str
}
res, _ = time.Parse(time.RFC3339, str)
return res.Format(time.DateTime)
}
if dataType == dbi.DataTypeDate && ok {
res, _ := time.Parse(time.DateOnly, str)
return res.Format(time.DateOnly)
}
if dataType == dbi.DataTypeTime && ok {
res, _ := time.Parse(time.TimeOnly, str)
return res.Format(time.TimeOnly)
}
return anyx.ToString(dbColumnValue)
}
func (dc *DataHelper) ParseData(dbColumnValue any, dataType dbi.DataType) any {
// 如果dataType是datetime而dbColumnValue是string类型则需要转换为time.Time类型
_, ok := dbColumnValue.(string)
if ok {
if dataType == dbi.DataTypeDateTime {
res, _ := time.Parse(time.DateTime, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeDate {
res, _ := time.Parse(time.DateOnly, anyx.ToString(dbColumnValue))
return res
}
if dataType == dbi.DataTypeTime {
res, _ := time.Parse(time.TimeOnly, anyx.ToString(dbColumnValue))
return res
}
}
return dbColumnValue
}
func (dc *DataHelper) WrapValue(dbColumnValue any, dataType dbi.DataType) string {
if dbColumnValue == nil {
return "NULL"
}
switch dataType {
case dbi.DataTypeNumber:
return fmt.Sprintf("%v", dbColumnValue)
case dbi.DataTypeString:
val := fmt.Sprintf("%v", dbColumnValue)
// 转义单引号
val = strings.Replace(val, `'`, `''`, -1)
val = strings.Replace(val, `\''`, `\'`, -1)
// 转义换行符
val = strings.Replace(val, "\n", "\\n", -1)
return fmt.Sprintf("'%s'", val)
case dbi.DataTypeDate, dbi.DataTypeDateTime, dbi.DataTypeTime:
// mysql时间类型无需格式化
return fmt.Sprintf("'%s'", dbColumnValue)
case dbi.DataTypeBlob:
return fmt.Sprintf("unhex('%s')", dbColumnValue)
}
return fmt.Sprintf("'%s'", dbColumnValue)
}
type ColumnHelper struct {
dbi.DefaultColumnHelper
}
func (ch *ColumnHelper) ToCommonColumn(dialectColumn *dbi.Column) {
dataType := dialectColumn.DataType
t1 := commonColumnTypeMap[string(dataType)]
commonColumnType := dbi.CommonTypeVarchar
if t1 != "" {
commonColumnType = t1
}
dialectColumn.DataType = commonColumnType
}
func (ch *ColumnHelper) ToColumn(column *dbi.Column) {
ctype := mysqlColumnTypeMap[column.DataType]
if ctype == "" {
column.DataType = "varchar"
column.CharMaxLength = 1000
} else {
column.DataType = dbi.ColumnDataType(ctype)
ch.FixColumn(column)
}
}
func (ch *ColumnHelper) FixColumn(column *dbi.Column) {
// 如果是int整型删除精度
if strings.Contains(strings.ToLower(string(column.DataType)), "int") {
column.NumScale = 0
column.CharMaxLength = 0
} else
// 如果是text删除长度
if strings.Contains(strings.ToLower(string(column.DataType)), "text") {
column.CharMaxLength = 0
column.NumPrecision = 0
}
}

View File

@@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"net"
"github.com/go-sql-driver/mysql"
@@ -12,10 +13,15 @@ import (
func init() {
meta := new(Meta)
dbi.Register(dbi.DbTypeMysql, meta)
dbi.Register(dbi.DbTypeMariadb, meta)
dbi.Register(DbTypeMysql, meta)
dbi.Register(DbTypeMariadb, meta)
}
const (
DbTypeMysql dbi.DbType = "mysql"
DbTypeMariadb dbi.DbType = "mariadb"
)
type Meta struct {
}
@@ -31,7 +37,7 @@ func (mm *Meta) GetSqlDb(d *dbi.DbInfo) (*sql.DB, error) {
})
}
// 设置dataSourceName -> 更多参数参考https://github.com/go-sql-driver/mysql#dsn-data-source-name
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
dsn := fmt.Sprintf("%s:%s@%s(%s:%d)/%s?parseTime=true&timeout=8s", d.Username, d.Password, d.Network, d.Host, d.Port, d.Database)
if d.Params != "" {
dsn = fmt.Sprintf("%s&%s", dsn, d.Params)
}
@@ -46,3 +52,17 @@ func (mm *Meta) GetDialect(conn *dbi.DbConn) dbi.Dialect {
func (mm *Meta) GetMetadata(conn *dbi.DbConn) dbi.Metadata {
return &MysqlMetadata{dc: conn}
}
func (mm *Meta) GetDbDataTypes() []*dbi.DbDataType {
return collx.AsArray(
UnsignedBigint, Bigint, Tinyint, Smallint, Int, Bit, Float, Double, Decimal,
Varchar, Char, Text, Longtext, Mediumtext,
Datetime, Date, Time, Timestamp,
Enum, JSON, Set,
Binary, Blob, Longblob, Mediumblob, Varbinary,
)
}
func (mm *Meta) GetCommonTypeConverter() dbi.CommonTypeConverter {
return &commonTypeConverter{}
}

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/errorx"
"mayfly-go/pkg/logx"
"mayfly-go/pkg/utils/collx"
"mayfly-go/pkg/utils/stringx"
"strings"
@@ -54,7 +53,7 @@ func (md *MysqlMetadata) GetDbNames() ([]string, error) {
func (md *MysqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
dialect := md.dc.GetDialect()
names := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
var res []map[string]any
@@ -87,9 +86,8 @@ func (md *MysqlMetadata) GetTables(tableNames ...string) ([]dbi.Table, error) {
// 获取列元信息, 如列名等
func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error) {
dialect := md.dc.GetDialect()
columnHelper := dialect.GetColumnHelper()
tableName := strings.Join(collx.ArrayMap[string, string](tableNames, func(val string) string {
return fmt.Sprintf("'%s'", dialect.RemoveQuote(val))
return fmt.Sprintf("'%s'", dialect.Quoter().Trim(val))
}), ",")
_, res, err := md.dc.Query(fmt.Sprintf(dbi.GetLocalSql(MYSQL_META_FILE, MYSQL_COLUMN_MA_KEY), tableName))
@@ -103,7 +101,7 @@ func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
column := dbi.Column{
TableName: cast.ToString(re["tableName"]),
ColumnName: cast.ToString(re["columnName"]),
DataType: dbi.ColumnDataType(cast.ToString(re["dataType"])),
DataType: cast.ToString(re["dataType"]),
ColumnComment: cast.ToString(re["columnComment"]),
Nullable: cast.ToString(re["nullable"]) == "YES",
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
@@ -114,7 +112,7 @@ func (md *MysqlMetadata) GetColumns(tableNames ...string) ([]dbi.Column, error)
NumScale: cast.ToInt(re["numScale"]),
}
columnHelper.FixColumn(&column)
md.dc.GetDbDataType(column.DataType).FixColumn(&column)
columns = append(columns, column)
}
return columns, nil
@@ -156,6 +154,7 @@ func (md *MysqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
IsUnique: cast.ToInt(re["isUnique"]) == 1,
SeqInIndex: cast.ToInt(re["seqInIndex"]),
IsPrimaryKey: cast.ToInt(re["isPrimaryKey"]) == 1,
Extra: collx.Kvs(IndexSubPartKey, cast.ToInt(re[IndexSubPartKey])),
})
}
// 把查询结果以索引名分组,索引字段以逗号连接
@@ -179,34 +178,7 @@ func (md *MysqlMetadata) GetTableIndex(tableName string) ([]dbi.Index, error) {
// 获取建表ddl
func (md *MysqlMetadata) GetTableDDL(tableName string, dropBeforeCreate bool) (string, error) {
// 1.获取表信息
tbs, err := md.GetTables(tableName)
tableInfo := &dbi.Table{}
if err != nil || tbs == nil || len(tbs) <= 0 {
logx.Errorf("获取表信息失败, %s", tableName)
return "", err
}
tableInfo.TableName = tbs[0].TableName
tableInfo.TableComment = tbs[0].TableComment
// 2.获取列信息
columns, err := md.GetColumns(tableName)
if err != nil {
logx.Errorf("获取列信息失败, %s", tableName)
return "", err
}
dialect := md.dc.GetDialect()
tableDDLArr := dialect.GenerateTableDDL(columns, *tableInfo, dropBeforeCreate)
// 3.获取索引信息
indexs, err := md.GetTableIndex(tableName)
if err != nil {
logx.Errorf("获取索引信息失败, %s", tableName)
return "", err
}
// 组装返回
tableDDLArr = append(tableDDLArr, dialect.GenerateIndexDDL(indexs, *tableInfo)...)
return strings.Join(tableDDLArr, ";\n"), nil
return dbi.GenTableDDL(md.dc.GetDialect(), md, tableName, dropBeforeCreate)
}
func (md *MysqlMetadata) GetSchemas() ([]string, error) {

View File

@@ -56,9 +56,9 @@ func (svc *DbProgramMysql) getMysqlBin() *config.MysqlBin {
dbInfo := svc.dbInfo()
var mysqlBin *config.MysqlBin
switch dbInfo.Type {
case dbi.DbTypeMariadb:
case DbTypeMariadb:
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMariadbBin)
case dbi.DbTypeMysql:
case DbTypeMysql:
mysqlBin = config.GetMysqlBin(config.ConfigKeyDbMysqlBin)
default:
panic(fmt.Sprintf("不兼容 MySQL 的数据库类型: %v", dbInfo.Type))

View File

@@ -0,0 +1,147 @@
package mysql
import (
"fmt"
"mayfly-go/internal/db/dbm/dbi"
"mayfly-go/pkg/utils/collx"
"strings"
"github.com/may-fly/cast"
)
type SQLGenerator struct {
Dialect dbi.Dialect
}
func (msg *SQLGenerator) GenTableDDL(table dbi.Table, columns []dbi.Column, dropBeforeCreate bool) []string {
sqlArr := make([]string, 0)
quoter := msg.Dialect.Quoter()
if dropBeforeCreate {
sqlArr = append(sqlArr, fmt.Sprintf("DROP TABLE IF EXISTS %s", quoter.Quote(table.TableName)))
}
// 组装建表语句
createSql := fmt.Sprintf("CREATE TABLE %s (\n", quoter.Quote(table.TableName))
fields := make([]string, 0)
pks := make([]string, 0)
for _, column := range columns {
if column.IsPrimaryKey {
pks = append(pks, column.ColumnName)
}
fields = append(fields, msg.genColumnBasicSql(quoter, column))
}
// 建表ddl
createSql += strings.Join(fields, ",\n")
if len(pks) > 0 {
createSql += fmt.Sprintf(", \nPRIMARY KEY (%s)", strings.Join(pks, ","))
}
createSql += "\n)"
// 表注释
if table.TableComment != "" {
createSql += fmt.Sprintf(" COMMENT '%s'", dbi.QuoteEscape(table.TableComment))
}
sqlArr = append(sqlArr, createSql)
return sqlArr
}
func (msg *SQLGenerator) GenIndexDDL(table dbi.Table, indexs []dbi.Index) []string {
sqlArr := make([]string, 0)
quoter := msg.Dialect.Quoter()
for _, index := range indexs {
unique := ""
if index.IsUnique {
unique = "unique"
}
// 取出列名,添加引号
colNames := quoter.Quotes(strings.Split(index.ColumnName, ","))
// 暂时先处理单个索引的情况,多个涉及获取索引时的合并等,以及前端调整等,后续完善
if subPart := cast.ToInt(index.Extra[IndexSubPartKey]); subPart > 0 && len(colNames) == 1 {
colNames[0] = fmt.Sprintf("%s(%d)", colNames[0], subPart)
}
sqlTmp := "ALTER TABLE %s ADD %s INDEX %s(%s) USING %s"
sqlStr := fmt.Sprintf(sqlTmp, quoter.Quote(table.TableName), unique, quoter.Quote(index.IndexName), strings.Join(colNames, ","), index.IndexType)
comment := dbi.QuoteEscape(index.IndexComment)
if comment != "" {
sqlStr += fmt.Sprintf(" COMMENT '%s'", comment)
}
sqlArr = append(sqlArr, sqlStr)
}
return sqlArr
}
func (msg *SQLGenerator) GenInsert(tableName string, columns []dbi.Column, values [][]any, duplicateStrategy int) []string {
if duplicateStrategy == dbi.DuplicateStrategyNone {
return collx.AsArray(dbi.GenCommonInsert(msg.Dialect, DbTypeMysql, tableName, columns, values))
}
prefix := "insert ignore into"
if duplicateStrategy == dbi.DuplicateStrategyUpdate {
prefix = "replace into"
}
quote := msg.Dialect.Quoter().Quote
columnStr, valuesStrs := dbi.GenInsertSqlColumnAndValues(msg.Dialect, DbTypeMysql, columns, values)
return collx.AsArray[string](fmt.Sprintf("%s %s %s VALUES \n%s", prefix, quote(tableName), columnStr, strings.Join(valuesStrs, ",\n")))
}
func (msg *SQLGenerator) genColumnBasicSql(quoter dbi.Quoter, column dbi.Column) string {
dataType := string(column.DataType)
incr := ""
if column.IsIdentity {
incr = " AUTO_INCREMENT"
}
nullAble := ""
if !column.Nullable {
nullAble = " NOT NULL"
}
columnType := column.GetColumnType()
if nullAble == "" && strings.Contains(columnType, "timestamp") {
nullAble = " NULL"
}
defVal := "" // 默认值需要判断引号,如函数是不需要引号的
if column.ColumnDefault != "" &&
// 当默认值是字符串'NULL'时,不需要设置默认值
column.ColumnDefault != "NULL" &&
// 为了防止跨源函数不支持 当默认值是函数时,不需要设置默认值
!strings.Contains(column.ColumnDefault, "(") {
// 哪些字段类型默认值需要加引号
mark := false
if collx.ArrayAnyMatches([]string{"char", "text", "date", "time", "lob"}, strings.ToLower(dataType)) {
// 当数据类型是日期时间,默认值是日期时间函数时,默认值不需要引号
if collx.ArrayAnyMatches([]string{"date", "time"}, strings.ToLower(dataType)) &&
collx.ArrayAnyMatches([]string{"DATE", "TIME"}, strings.ToUpper(column.ColumnDefault)) {
mark = false
} else {
mark = true
}
}
if mark {
defVal = fmt.Sprintf(" DEFAULT '%s'", column.ColumnDefault)
} else {
defVal = fmt.Sprintf(" DEFAULT %s", column.ColumnDefault)
}
}
comment := ""
if column.ColumnComment != "" {
// 防止注释内含有特殊字符串导致sql出错
commentStr := dbi.QuoteEscape(column.ColumnComment)
comment = fmt.Sprintf(" COMMENT '%s'", commentStr)
}
columnSql := fmt.Sprintf(" %s %s%s%s%s%s", quoter.Quote(column.ColumnName), columnType, nullAble, incr, defVal, comment)
return columnSql
}

View File

@@ -0,0 +1,108 @@
package mysql
import "mayfly-go/internal/db/dbm/dbi"
var _ dbi.CommonTypeConverter = (*commonTypeConverter)(nil)
type commonTypeConverter struct {
}
func (c *commonTypeConverter) Varchar(col *dbi.Column) *dbi.DbDataType {
// 如果字符长度大于16383则转为text类型
if col.CharMaxLength > 16383 {
col.CharMaxLength = 0
return Text
}
return Varchar
}
func (c *commonTypeConverter) Char(col *dbi.Column) *dbi.DbDataType {
return Char
}
func (c *commonTypeConverter) Text(col *dbi.Column) *dbi.DbDataType {
col.CharMaxLength = 0
col.NumPrecision = 0
return Text
}
func (c *commonTypeConverter) Mediumtext(col *dbi.Column) *dbi.DbDataType {
col.CharMaxLength = 0
col.NumPrecision = 0
return Mediumtext
}
func (c *commonTypeConverter) Longtext(col *dbi.Column) *dbi.DbDataType {
col.CharMaxLength = 0
col.NumPrecision = 0
return Longtext
}
func (c *commonTypeConverter) Bit(col *dbi.Column) *dbi.DbDataType {
return Bit
}
func (c *commonTypeConverter) Int1(col *dbi.Column) *dbi.DbDataType {
return Tinyint
}
func (c *commonTypeConverter) Int2(col *dbi.Column) *dbi.DbDataType {
return Smallint
}
func (c *commonTypeConverter) Int4(col *dbi.Column) *dbi.DbDataType {
return Int
}
func (c *commonTypeConverter) Int8(col *dbi.Column) *dbi.DbDataType {
return Bigint
}
func (c *commonTypeConverter) Numeric(col *dbi.Column) *dbi.DbDataType {
return Double
}
func (c *commonTypeConverter) Decimal(col *dbi.Column) *dbi.DbDataType {
return Decimal
}
func (c *commonTypeConverter) UnsignedInt8(col *dbi.Column) *dbi.DbDataType {
return UnsignedBigint
}
func (c *commonTypeConverter) UnsignedInt4(col *dbi.Column) *dbi.DbDataType {
return UnsignedInt
}
func (c *commonTypeConverter) UnsignedInt2(col *dbi.Column) *dbi.DbDataType {
return UnsignedMediumint
}
func (c *commonTypeConverter) UnsignedInt1(col *dbi.Column) *dbi.DbDataType {
return UnsignedSmallint
}
func (c *commonTypeConverter) Date(col *dbi.Column) *dbi.DbDataType {
return Date
}
func (c *commonTypeConverter) Time(col *dbi.Column) *dbi.DbDataType {
return Time
}
func (c *commonTypeConverter) Datetime(col *dbi.Column) *dbi.DbDataType {
return Datetime
}
func (c *commonTypeConverter) Timestamp(col *dbi.Column) *dbi.DbDataType {
return Timestamp
}
func (c *commonTypeConverter) Binary(col *dbi.Column) *dbi.DbDataType {
return Binary
}
func (c *commonTypeConverter) Varbinary(col *dbi.Column) *dbi.DbDataType {
return Varbinary
}
func (c *commonTypeConverter) Mediumblob(col *dbi.Column) *dbi.DbDataType {
return Mediumblob
}
func (c *commonTypeConverter) Blob(col *dbi.Column) *dbi.DbDataType {
return Blob
}
func (c *commonTypeConverter) Longblob(col *dbi.Column) *dbi.DbDataType {
return Longblob
}
func (c *commonTypeConverter) Enum(col *dbi.Column) *dbi.DbDataType {
return Enum
}
func (c *commonTypeConverter) JSON(col *dbi.Column) *dbi.DbDataType {
return JSON
}