mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-03 16:00:25 +08:00
refactor: dbm
This commit is contained in:
594
server/internal/db/dbm/dbi/column.go
Normal file
594
server/internal/db/dbm/dbi/column.go
Normal file
@@ -0,0 +1,594 @@
|
||||
package dbi
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"mayfly-go/pkg/utils/collx"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/may-fly/cast"
|
||||
)
|
||||
|
||||
var (
|
||||
dbDataTypes = make(map[DbType]map[string]*DbDataType) // 列类型
|
||||
)
|
||||
|
||||
// registerColumnDbDataTypes 注册数据库对应的数据类型
|
||||
func registerColumnDbDataTypes(dbType DbType, cts ...*DbDataType) {
|
||||
dbDataTypes[dbType] = collx.ArrayToMap(cts, func(ct *DbDataType) string {
|
||||
return strings.ToLower(string(ct.Name))
|
||||
})
|
||||
}
|
||||
|
||||
func GetDbDataType(dbType DbType, databaseColumnType string) *DbDataType {
|
||||
return cmp.Or(dbDataTypes[dbType][strings.ToLower(databaseColumnType)], DefaultDbDataType)
|
||||
}
|
||||
|
||||
var DefaultDbDataType = NewDbDataType("string", DTString).WithCT(CTVarchar)
|
||||
|
||||
// 表的列信息
|
||||
type Column struct {
|
||||
TableName string `json:"tableName"` // 表名
|
||||
ColumnName string `json:"columnName"` // 列名
|
||||
DataType string `json:"dataType"` // 数据类型
|
||||
ColumnComment string `json:"columnComment"` // 列备注
|
||||
IsPrimaryKey bool `json:"isPrimaryKey"` // 是否为主键
|
||||
IsIdentity bool `json:"isIdentity"` // 是否自增
|
||||
ColumnDefault string `json:"columnDefault"` // 默认值
|
||||
Nullable bool `json:"nullable"` // 是否可为null
|
||||
CharMaxLength int `json:"charMaxLength"` // 字符最大长度
|
||||
NumPrecision int `json:"numPrecision"` // 精度(总数字位数)
|
||||
NumScale int `json:"numScale"` // 小数点位数
|
||||
Extra collx.M `json:"extra"` // 其他额外信息
|
||||
}
|
||||
|
||||
// GetColumnType 获取完整的列类型,拼接数据类型与长度等。如varchar(2000),decimal(20,2)
|
||||
func (c *Column) GetColumnType() string {
|
||||
if c.CharMaxLength > 0 {
|
||||
return fmt.Sprintf("%s(%d)", c.DataType, c.CharMaxLength)
|
||||
}
|
||||
if c.NumPrecision > 0 {
|
||||
if c.NumScale > 0 {
|
||||
return fmt.Sprintf("%s(%d,%d)", c.DataType, c.NumPrecision, c.NumScale)
|
||||
} else {
|
||||
return fmt.Sprintf("%s(%d)", c.DataType, c.NumPrecision)
|
||||
}
|
||||
}
|
||||
|
||||
return string(c.DataType)
|
||||
}
|
||||
|
||||
// 数据库对应的数据类型
|
||||
type DbDataType struct {
|
||||
Name string // 类型名
|
||||
|
||||
DataType *DataType // 数据类型
|
||||
|
||||
fixColumnFunc func(column *Column) // 修复字段长度、精度等, 如mysql text会返回长度,需要将其置为0等
|
||||
|
||||
/** 以下为异构数据迁移同步使用,可不赋值,无值则不支持迁移同步 */
|
||||
|
||||
CommonType CommonDbDataType // 对应的公共类型
|
||||
}
|
||||
|
||||
// WithFixColumn 修复列信息函数,用于修复字段长度、精度等
|
||||
func (ct *DbDataType) WithFixColumn(fixColumnFunc func(column *Column)) *DbDataType {
|
||||
ct.fixColumnFunc = fixColumnFunc
|
||||
return ct
|
||||
}
|
||||
|
||||
// WithCT 对应的公共类型,主要用于异构数据库迁移同步时进行类型转换使用
|
||||
func (ct *DbDataType) WithCT(cct CommonDbDataType) *DbDataType {
|
||||
ct.CommonType = cct
|
||||
return ct
|
||||
}
|
||||
|
||||
// FixColumn 使用修复列信息函数进行列信息修复
|
||||
func (ct *DbDataType) FixColumn(column *Column) {
|
||||
if ct.fixColumnFunc != nil {
|
||||
ct.fixColumnFunc(column)
|
||||
}
|
||||
}
|
||||
|
||||
func NewDbDataType(name string, dataType *DataType) *DbDataType {
|
||||
return &DbDataType{
|
||||
Name: name,
|
||||
DataType: dataType,
|
||||
}
|
||||
}
|
||||
|
||||
func ClearCharMaxLength(column *Column) {
|
||||
column.CharMaxLength = 0
|
||||
column.NumPrecision = 0
|
||||
}
|
||||
|
||||
func ClearNumScale(column *Column) {
|
||||
column.NumScale = 0
|
||||
column.CharMaxLength = 0
|
||||
}
|
||||
|
||||
// DataType 数据类型, 对应于go类型,如int int64等。可自定义其他类型
|
||||
type DataType struct {
|
||||
Name string // 类型名
|
||||
|
||||
Valuer func() Valuer // 获取值对应的处理者,用于sql的scan、解析value等
|
||||
|
||||
SQLValue func(val any) string // 转换为sql字符串值,用于insert等SQL语句的值转换
|
||||
}
|
||||
|
||||
// Copy 拷贝一个同类型的datatype,主要方便用于定制化修改Valuer或ToString
|
||||
func (dt *DataType) Copy() *DataType {
|
||||
return &DataType{
|
||||
Name: dt.Name,
|
||||
Valuer: dt.Valuer,
|
||||
SQLValue: dt.SQLValue,
|
||||
}
|
||||
}
|
||||
|
||||
func (dt *DataType) WithValuer(valuerFunc func() Valuer) *DataType {
|
||||
dt.Valuer = valuerFunc
|
||||
return dt
|
||||
}
|
||||
|
||||
func (dt *DataType) WithSQLValue(sqlvalueFunc func(val any) string) *DataType {
|
||||
dt.SQLValue = sqlvalueFunc
|
||||
return dt
|
||||
}
|
||||
|
||||
const NULL = "NULL"
|
||||
|
||||
// SQLValueDefault 默认使用fmt转string
|
||||
func SQLValueDefault(val any) string {
|
||||
if val == nil {
|
||||
return NULL
|
||||
}
|
||||
return fmt.Sprintf("'%v'", val)
|
||||
}
|
||||
|
||||
// SQLValueNumeric 数字类型转string
|
||||
func SQLValueNumeric(val any) string {
|
||||
if val == nil {
|
||||
return NULL
|
||||
}
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
|
||||
func SQLValueString(val any) string {
|
||||
if val == nil {
|
||||
return NULL
|
||||
}
|
||||
|
||||
strVal, ok := val.(string)
|
||||
if !ok {
|
||||
return fmt.Sprintf("%v", val)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("'%s'", strings.ReplaceAll(strings.ReplaceAll(strVal, "'", "''"), `\`, `\\`))
|
||||
}
|
||||
|
||||
var (
|
||||
DTBit = &DataType{
|
||||
Name: "bit",
|
||||
Valuer: ValuerBit,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTByte = &DataType{
|
||||
Name: "uint8",
|
||||
Valuer: ValuerByte,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTInt8 = &DataType{
|
||||
Name: "int8",
|
||||
Valuer: ValuerInt16,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTInt16 = &DataType{
|
||||
Name: "int16",
|
||||
Valuer: ValuerInt16,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTInt32 = &DataType{
|
||||
Name: "int32",
|
||||
Valuer: ValuerInt32,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTInt64 = &DataType{
|
||||
Name: "int64",
|
||||
Valuer: ValuerInt64,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
// 所有无符号类型,都使用int64存储
|
||||
DTUint64 = &DataType{
|
||||
Name: "uint64",
|
||||
Valuer: ValuerUint64,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTNumeric = &DataType{
|
||||
Name: "numeric",
|
||||
Valuer: ValuerFloat64,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTDecimal = &DataType{
|
||||
Name: "decimal",
|
||||
Valuer: ValuerString,
|
||||
SQLValue: SQLValueNumeric,
|
||||
}
|
||||
|
||||
DTString = &DataType{
|
||||
Name: "string",
|
||||
Valuer: ValuerString,
|
||||
SQLValue: SQLValueString,
|
||||
}
|
||||
|
||||
DTDate = &DataType{
|
||||
Name: "date",
|
||||
Valuer: ValuerDate,
|
||||
SQLValue: SQLValueDefault,
|
||||
}
|
||||
|
||||
DTTime = &DataType{
|
||||
Name: "time",
|
||||
Valuer: ValuerTime,
|
||||
SQLValue: SQLValueDefault,
|
||||
}
|
||||
|
||||
DTDateTime = &DataType{
|
||||
Name: "datetime",
|
||||
Valuer: ValuerDatetime,
|
||||
SQLValue: SQLValueDefault,
|
||||
}
|
||||
|
||||
DTBytes = &DataType{
|
||||
Name: "bytes",
|
||||
Valuer: ValuerBytes,
|
||||
SQLValue: SQLValueDefault,
|
||||
}
|
||||
)
|
||||
|
||||
// Valuer 获取值对应的处理者,用于sql row scan、解析value等
|
||||
type Valuer interface {
|
||||
|
||||
// NewValuePtr 新建值对应的指针,用于sql的row scan
|
||||
NewValuePtr() any
|
||||
|
||||
// Value 获取对应的值(人类可阅读的值),不可原样返回ValuePtr指针类型,需取出具体的值
|
||||
Value() any
|
||||
}
|
||||
|
||||
type DefaultValuer[T any] struct {
|
||||
ValuePtr *T
|
||||
}
|
||||
|
||||
func (s *DefaultValuer[T]) NewValuePtr() any {
|
||||
var t T
|
||||
s.ValuePtr = &t
|
||||
return s.ValuePtr
|
||||
}
|
||||
|
||||
// Valuer工厂函数
|
||||
|
||||
func ValuerString() Valuer {
|
||||
return &stringValuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullString]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerInt64() Valuer {
|
||||
return &int64Valuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullInt64]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerUint64() Valuer {
|
||||
return &uint64Valuer{
|
||||
DefaultValuer: new(DefaultValuer[[]byte]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerInt32() Valuer {
|
||||
return &int32Valuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullInt32]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerInt16() Valuer {
|
||||
return &int16Valuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullInt16]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerByte() Valuer {
|
||||
return &byteValuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullByte]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerBit() Valuer {
|
||||
return &bitValuer{
|
||||
DefaultValuer: new(DefaultValuer[[]byte]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerFloat64() Valuer {
|
||||
return &float64Valuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.NullFloat64]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerDatetime() Valuer {
|
||||
return &datetimeValuer{
|
||||
DefaultValuer: new(DefaultValuer[NullTime]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerDate() Valuer {
|
||||
return &dateValuer{
|
||||
DefaultValuer: new(DefaultValuer[NullTime]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerTime() Valuer {
|
||||
return &timeValuer{
|
||||
DefaultValuer: new(DefaultValuer[NullTime]),
|
||||
}
|
||||
}
|
||||
|
||||
func ValuerBytes() Valuer {
|
||||
return &bytesValuer{
|
||||
DefaultValuer: new(DefaultValuer[sql.RawBytes]),
|
||||
}
|
||||
}
|
||||
|
||||
// 默认 valuer
|
||||
|
||||
// string
|
||||
|
||||
type stringValuer struct {
|
||||
*DefaultValuer[sql.NullString]
|
||||
}
|
||||
|
||||
func (s *stringValuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.String
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// uint64
|
||||
|
||||
type uint64Valuer struct {
|
||||
*DefaultValuer[[]byte]
|
||||
}
|
||||
|
||||
func (s *uint64Valuer) Value() any {
|
||||
valBytes := *s.ValuePtr
|
||||
if valBytes == nil {
|
||||
return nil
|
||||
}
|
||||
val := string(valBytes)
|
||||
// 前端超过16位会丢失精度
|
||||
if len(val) > 16 {
|
||||
return val
|
||||
}
|
||||
return cast.ToUint64(val)
|
||||
}
|
||||
|
||||
// int64
|
||||
|
||||
type int64Valuer struct {
|
||||
*DefaultValuer[sql.NullInt64]
|
||||
}
|
||||
|
||||
func (s *int64Valuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
val := s.ValuePtr.Int64
|
||||
// 前端超过16位会丢失精度
|
||||
if val > 9999999999999999 {
|
||||
return fmt.Sprintf("%d", val)
|
||||
}
|
||||
return val
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// int32
|
||||
|
||||
type int32Valuer struct {
|
||||
*DefaultValuer[sql.NullInt32]
|
||||
}
|
||||
|
||||
func (s *int32Valuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Int32
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// int16
|
||||
|
||||
type int16Valuer struct {
|
||||
*DefaultValuer[sql.NullInt16]
|
||||
}
|
||||
|
||||
func (s *int16Valuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Int16
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// byte(uint8)
|
||||
|
||||
type byteValuer struct {
|
||||
*DefaultValuer[sql.NullByte]
|
||||
}
|
||||
|
||||
func (s *byteValuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Byte
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// bit
|
||||
|
||||
type bitValuer struct {
|
||||
*DefaultValuer[[]byte]
|
||||
}
|
||||
|
||||
func (s *bitValuer) Value() any {
|
||||
valBytes := *s.ValuePtr
|
||||
if valBytes == nil {
|
||||
return nil
|
||||
}
|
||||
return valBytes[0]
|
||||
}
|
||||
|
||||
// float64
|
||||
|
||||
type float64Valuer struct {
|
||||
*DefaultValuer[sql.NullFloat64]
|
||||
}
|
||||
|
||||
func (s *float64Valuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Float64
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// bytes
|
||||
|
||||
type bytesValuer struct {
|
||||
*DefaultValuer[sql.RawBytes]
|
||||
}
|
||||
|
||||
func (s *bytesValuer) Value() any {
|
||||
val := s.ValuePtr
|
||||
if *val == nil {
|
||||
return nil
|
||||
}
|
||||
return hex.EncodeToString(*val)
|
||||
}
|
||||
|
||||
// datetime
|
||||
|
||||
type datetimeValuer struct {
|
||||
*DefaultValuer[NullTime]
|
||||
}
|
||||
|
||||
func (s *datetimeValuer) NewValuePtr() any {
|
||||
s.ValuePtr = &NullTime{
|
||||
Layout: time.DateTime,
|
||||
}
|
||||
return s.ValuePtr
|
||||
}
|
||||
|
||||
func (s *datetimeValuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Time
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// date
|
||||
|
||||
type dateValuer struct {
|
||||
*DefaultValuer[NullTime]
|
||||
}
|
||||
|
||||
func (s *dateValuer) NewValuePtr() any {
|
||||
s.ValuePtr = &NullTime{
|
||||
Layout: time.DateOnly,
|
||||
}
|
||||
return s.ValuePtr
|
||||
}
|
||||
|
||||
func (s *dateValuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Time
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// time
|
||||
|
||||
type timeValuer struct {
|
||||
*DefaultValuer[NullTime]
|
||||
}
|
||||
|
||||
func (s *timeValuer) NewValuePtr() any {
|
||||
s.ValuePtr = &NullTime{
|
||||
Layout: time.TimeOnly,
|
||||
}
|
||||
return s.ValuePtr
|
||||
}
|
||||
|
||||
func (s *timeValuer) Value() any {
|
||||
if s.ValuePtr.Valid {
|
||||
return s.ValuePtr.Time
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NullTime represents a time that may be null.
|
||||
// NullTime implements the [Scanner] interface so
|
||||
// it can be used as a scan destination, similar to [NullString].
|
||||
type NullTime struct {
|
||||
Time string
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
Layout string
|
||||
}
|
||||
|
||||
var (
|
||||
_ driver.Valuer = NullTime{}
|
||||
)
|
||||
|
||||
// Scan implements the [Scanner] interface.
|
||||
func (n *NullTime) Scan(value any) error {
|
||||
if value == nil {
|
||||
n.Time, n.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
|
||||
n.Valid = true
|
||||
time, err := convertTime(value, n.Layout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.Time = time
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n NullTime) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.Time, nil
|
||||
}
|
||||
|
||||
func convertTime(src interface{}, layout string) (string, error) {
|
||||
switch s := src.(type) {
|
||||
case string:
|
||||
return s, nil
|
||||
case []uint8:
|
||||
return string(s), nil
|
||||
case time.Time:
|
||||
return s.Format(layout), nil
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user