refactor: dbm

This commit is contained in:
meilin.huang
2024-12-08 13:04:23 +08:00
parent ebc89e056f
commit e56788af3e
152 changed files with 4273 additions and 3715 deletions

View File

@@ -161,39 +161,39 @@ func (ai *AppImpl[T, R]) CountByCond(cond any) int64 {
// Tx 执行事务操作
func (ai *AppImpl[T, R]) Tx(ctx context.Context, funcs ...func(context.Context) error) (err error) {
tx := GetTxFromCtx(ctx)
dbCtx := ctx
var txDb *gorm.DB
isCreateTx := false
txDb := GetDbFromCtx(ctx)
if tx == nil {
if txDb == nil {
txDb = global.Db.Begin()
dbCtx, tx = NewCtxWithTxDb(ctx, txDb)
} else {
txDb = tx.DB
tx.Count++
if txDb.Error != nil {
return txDb.Error
}
dbCtx = NewCtxWithDb(ctx, txDb)
// 只有创建事务的方法,才允许其提交或回滚
isCreateTx = true
}
defer func() {
if r := recover(); r != nil {
tx.Count = 0
txDb.Rollback()
err = fmt.Errorf("%v", r)
return
}
tx.Count--
// Make sure to rollback when panic, Block error or Commit error
if isCreateTx && err != nil {
txDb.Rollback()
}
}()
for _, f := range funcs {
err = f(dbCtx)
if err != nil && tx.Count > 0 {
tx.Count = 0
txDb.Rollback()
if err = f(dbCtx); err != nil {
return
}
}
if tx.Count == 1 {
if isCreateTx {
err = txDb.Commit().Error
}
return

View File

@@ -12,34 +12,19 @@ const (
DbKey CtxKey = "db"
)
// Tx 事务上下文信息
type Tx struct {
Count int
DB *gorm.DB
}
// NewCtxWithTxDb 将事务db放置context中
func NewCtxWithTxDb(ctx context.Context, db *gorm.DB) (context.Context, *Tx) {
if tx := GetTxFromCtx(ctx); tx != nil {
return ctx, tx
// NewCtxWithDb 将事务db放置context中若已存在则直接返回ctx
func NewCtxWithDb(ctx context.Context, db *gorm.DB) context.Context {
if tx := GetDbFromCtx(ctx); tx != nil {
return ctx
}
tx := &Tx{Count: 1, DB: db}
return context.WithValue(ctx, DbKey, tx), tx
return context.WithValue(ctx, DbKey, db)
}
// GetDbFromCtx 获取ctx中的事务db
func GetDbFromCtx(ctx context.Context) *gorm.DB {
if tx := GetTxFromCtx(ctx); tx != nil {
return tx.DB
}
return nil
}
// GetTxFromCtx 获取当前ctx事务
func GetTxFromCtx(ctx context.Context) *Tx {
if tx, ok := ctx.Value(DbKey).(*Tx); ok {
return tx
if txdb, ok := ctx.Value(DbKey).(*gorm.DB); ok {
return txdb
}
return nil
}

View File

@@ -15,16 +15,16 @@ import (
// 基础repo接口
type Repo[T model.ModelI] interface {
// 新增一个实体
// Insert 新增一个实体
Insert(ctx context.Context, e T) error
// 使用指定gorm db执行主要用于事务执行
// InsertWithDb 使用指定gorm db执行主要用于事务执行
InsertWithDb(ctx context.Context, db *gorm.DB, e T) error
// 批量新增实体
// BatchInsert 批量新增实体
BatchInsert(ctx context.Context, models []T) error
// 使用指定gorm db执行主要用于事务执行
// BatchInsertWithDb 使用指定gorm db执行主要用于事务执行
BatchInsertWithDb(ctx context.Context, db *gorm.DB, models []T) error
// 根据实体id更新实体信息
@@ -42,31 +42,32 @@ type Repo[T model.ModelI] interface {
// @param values 需要模型结构体或map
UpdateByCondWithDb(ctx context.Context, db *gorm.DB, values any, cond any) error
// 保存实体实体IsCreate返回true则新增否则更新
// Save 保存实体实体IsCreate返回true则新增否则更新
Save(ctx context.Context, e T) error
// 保存实体实体IsCreate返回true则新增否则更新。
// SaveWithDb 保存实体实体IsCreate返回true则新增否则更新。
// 使用指定gorm db执行主要用于事务执行
SaveWithDb(ctx context.Context, db *gorm.DB, e T) error
// 根据实体主键删除实体
// DeleteById 根据实体主键删除实体
DeleteById(ctx context.Context, id ...uint64) error
// 使用指定gorm db执行主要用于事务执行
// DeleteByIdWithDb 使用指定gorm db执行主要用于事务执行
DeleteByIdWithDb(ctx context.Context, db *gorm.DB, id ...uint64) error
// 根据实体条件删除实体
// DeleteByCond 根据实体条件删除实体
DeleteByCond(ctx context.Context, cond any) error
// 使用指定gorm db执行主要用于事务执行
// DeleteByCondWithDb 使用指定gorm db执行主要用于事务执行
DeleteByCondWithDb(ctx context.Context, db *gorm.DB, cond any) error
// ExecBySql 执行原生sql
ExecBySql(sql string, params ...any) error
// 根据实体id查询
// GetById 根据实体id查询
GetById(id uint64, cols ...string) (T, error)
// GetByIds 根据实体ids查询
GetByIds(ids []uint64, cols ...string) ([]T, error)
// GetByCond 根据实体条件查询实体信息(单个结果集)
@@ -88,13 +89,13 @@ type Repo[T model.ModelI] interface {
// SelectBySql 根据sql语句查询数据
SelectBySql(sql string, res any, params ...any) error
// 根据指定条件统计model表的数量
// CountByCond 根据指定条件统计model表的数量
CountByCond(cond any) int64
}
// 基础repo接口
type RepoImpl[T model.ModelI] struct {
M T // 模型实例
model any // 模型实例
modelType reflect.Type // 模型类型
}
@@ -120,7 +121,6 @@ func (br *RepoImpl[T]) BatchInsert(ctx context.Context, es []T) error {
return gormx.BatchInsert[T](es)
}
// 使用指定gorm db执行主要用于事务执行
func (br *RepoImpl[T]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []T) error {
for _, e := range es {
br.fillBaseInfo(ctx, e)
@@ -281,7 +281,12 @@ func (br *RepoImpl[T]) NewModel() T {
// getModel 获取表的模型实例
func (br *RepoImpl[T]) getModel() T {
return br.M
if br.model != nil {
return br.model.(T)
}
br.model = br.NewModel()
return br.model.(T)
}
// getModelType 获取模型类型(非指针模型)
@@ -290,7 +295,8 @@ func (br *RepoImpl[T]) getModelType() reflect.Type {
return br.modelType
}
modelType := reflect.TypeOf(br.M)
var model T
modelType := reflect.TypeOf(model)
// 检查 model 是否为指针类型
if modelType.Kind() == reflect.Ptr {
// 获取指针指向的类型

View File

@@ -1,111 +0,0 @@
package base
import (
"embed"
"io/fs"
"path"
"path/filepath"
"strings"
)
// SQLStatement 结构体用于存储解析后的 SQL 语句及其注释
type SQLStatement struct {
Comment string
SQL string
}
var sqlMap = make(map[string]string)
func RegisterSql(fs embed.FS) error {
return walkDir(fs, ".", func(fp string, data []byte) error {
if filepath.Ext(fp) != ".sql" {
return nil
}
fileNameWithExt := path.Base(fp)
sqls, err := parseSQL(string(data))
if err != nil {
return err
}
filename := strings.TrimSuffix(fileNameWithExt, path.Ext(fileNameWithExt))
for _, sql := range sqls {
sqlMap[filename+"."+strings.TrimSpace(sql.Comment)] = strings.TrimSpace(sql.SQL)
}
return nil
})
}
func GetSQL(filename, stmt string) string {
return sqlMap[filename+"."+stmt]
}
// walkDir 递归遍历目录
func walkDir(fsys fs.FS, path string, callback func(filePath string, data []byte) error) error {
entries, err := fs.ReadDir(fsys, path)
if err != nil {
return err
}
for _, entry := range entries {
entryPath := filepath.Join(path, entry.Name())
if entry.IsDir() {
// 递归遍历子目录
if err := walkDir(fsys, entryPath, callback); err != nil {
return err
}
} else {
// 读取文件内容
data, err := fs.ReadFile(fsys, entryPath)
if err != nil {
return err
}
if err := callback(entryPath, data); err != nil {
return err
}
}
}
return nil
}
// parseSQL 解析带有注释的 SQL 语句
func parseSQL(sql string) ([]SQLStatement, error) {
var statements []SQLStatement
lines := strings.Split(sql, "\n")
var currentComment string
var currentSQL string
for _, line := range lines {
trimmedLine := strings.TrimSpace(line)
if strings.HasPrefix(trimmedLine, "--") {
// 处理单行注释
if currentSQL != "" {
statements = append(statements, SQLStatement{Comment: currentComment, SQL: strings.TrimRight(currentSQL, " ")})
currentComment = ""
currentSQL = ""
}
currentComment += strings.TrimPrefix(trimmedLine, "--") + "\n"
continue
}
if trimmedLine == "" {
continue
}
currentSQL += line + " "
if strings.HasSuffix(trimmedLine, ";") {
statements = append(statements, SQLStatement{Comment: currentComment, SQL: strings.TrimRight(currentSQL, " ")})
currentComment = ""
currentSQL = ""
}
}
// 处理最后一段未结束的 SQL 语句
if currentSQL != "" {
statements = append(statements, SQLStatement{Comment: currentComment, SQL: strings.TrimRight(currentSQL, " ")})
}
return statements, nil
}

View File

@@ -1,28 +0,0 @@
package base
import (
"fmt"
"testing"
)
func TestParserSql(t *testing.T) {
sql := `-- selectByCond
Select * from tdb where id > 10;
-- another comment
Select * from another_table where name = 'test'
and age = ?;
-- multi-line comment
-- continues here
Select * from yet_another_table
Where id = ?;`
statements, err := parseSQL(sql)
if err != nil {
fmt.Println("Error:", err)
return
}
for _, stmt := range statements {
fmt.Printf("Comment: %s\nSQL: %s\n\n", stmt.Comment, stmt.SQL)
}
}

View File

@@ -1,67 +0,0 @@
package captcha
import (
"mayfly-go/pkg/rediscli"
"time"
"github.com/mojocn/base64Captcha"
)
var store base64Captcha.Store
var driver base64Captcha.Driver = base64Captcha.DefaultDriverDigit
// 生成验证码
func Generate() (string, string, error) {
if store == nil {
if rediscli.GetCli() != nil {
store = new(RedisStore)
} else {
store = base64Captcha.DefaultMemStore
}
}
c := base64Captcha.NewCaptcha(driver, store)
// 获取
id, b64s, _, err := c.Generate()
return id, b64s, err
}
// 验证验证码
func Verify(id string, val string) bool {
if store == nil || id == "" || val == "" {
return false
}
// 同时清理掉这个图片
return store.Verify(id, val, true)
}
type RedisStore struct {
}
const CAPTCHA = "mayfly:captcha:"
// 实现设置captcha的方法
func (r RedisStore) Set(id string, value string) error {
//time.Minute*2有效时间2分钟
rediscli.Set(CAPTCHA+id, value, time.Minute*2)
return nil
}
// 实现获取captcha的方法
func (r RedisStore) Get(id string, clear bool) string {
key := CAPTCHA + id
val, err := rediscli.Get(key)
if err != nil {
return ""
}
if clear {
//clear为true验证通过删除这个验证码
rediscli.Del(key)
}
return val
}
// 实现验证captcha的方法
func (r RedisStore) Verify(id, answer string, clear bool) bool {
return r.Get(id, clear) == answer
}

View File

@@ -32,7 +32,7 @@ func (e *Enum[T]) Valid(value T) error {
return nil
}
errMsg := fmt.Sprintf("%s可选值为: ", e.name)
errMsg := fmt.Sprintf("%s the optional value is: ", e.name)
for val, desc := range e.values {
errMsg = fmt.Sprintf("%s [%v->%s]", errMsg, val, desc)
}

View File

@@ -109,7 +109,7 @@ func UpdateByIdWithDb(db *gorm.DB, model model.ModelI, columns ...string) error
return db.Model(model).Select(columns).Updates(model).Error
}
// UpdateByCondWithDb 使用指定gorm.DB更新满足条件的数据(model的主键值需为空否则会带上主键条件)
// UpdateByCond 使用默认global.Dd更新满足条件的数据(model的主键值需为空否则会带上主键条件)
func UpdateByCond(dbModel model.ModelI, values any, cond *model.QueryCond) error {
return UpdateByCondWithDb(global.Db, dbModel, values, cond)
}

View File

@@ -1,25 +0,0 @@
package otp
import (
"time"
otp_t "github.com/pquerna/otp"
totp_t "github.com/pquerna/otp/totp"
)
type GenerateOpts totp_t.GenerateOpts
func NewTOTP(opt GenerateOpts) (*otp_t.Key, error) {
return totp_t.Generate(totp_t.GenerateOpts(opt))
}
func Validate(code string, secret string) bool {
if secret == "" {
return true
}
return totp_t.Validate(code, secret)
}
func GenTotpCode(code string, secret string) (string, error) {
return totp_t.GenerateCode(secret, time.Now())
}

View File

@@ -71,7 +71,7 @@ func ArrayToMap[T any, K comparable](arr []T, keyFunc func(val T) K) map[K]T {
}
// 数组映射,即将一数组元素通过映射函数转换为另一数组
func ArrayMap[T any, K comparable](arr []T, mapFunc func(val T) K) []K {
func ArrayMap[T any, K any](arr []T, mapFunc func(val T) K) []K {
res := make([]K, len(arr))
for i, val := range arr {
res[i] = mapFunc(val)

View File

@@ -9,6 +9,9 @@ import (
// json字符串转map
func ToMap(jsonStr string) map[string]any {
if jsonStr == "" {
return map[string]any{}
}
return ToMapByBytes([]byte(jsonStr))
}
@@ -22,7 +25,7 @@ func ToMapByBytes(bytes []byte) map[string]any {
var res map[string]any
err := json.Unmarshal(bytes, &res)
if err != nil {
logx.Errorf("json字符串转map失败: %s", err.Error())
logx.ErrorTrace("json字符串转map失败", err)
}
return res
}

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"strings"
"text/template"
"unicode/utf8"
)
// 逻辑空字符串由于gorm更新结构体只更新非零值所以使用该值最为逻辑空字符串方便更新结构体
@@ -124,16 +125,36 @@ func ReverStrTemplate(temp, str string, res map[string]any) {
}
}
func TruncateStr(s string, length int) string {
if length >= len(s) {
// Truncate 截断字符串并在中间部分显示指定的替换字符串
func Truncate(s string, length int, prefixLen int, replace string) string {
totalRunes := utf8.RuneCountInString(s)
// 如果字符串长度小于或等于指定的 length直接返回原字符串
if totalRunes <= length {
return s
}
var last int
for i := range s {
if i > length {
break
}
last = i
// 如果字符串长度小于或等于 prefixLen直接返回原字符串
if totalRunes <= prefixLen {
return s
}
return s[:last]
// 计算 suffixLen
suffixLen := length - prefixLen
// 确保 suffixLen 不会越界
if suffixLen <= 0 {
runes := []rune(s)
return string(runes[:length]) + replace
}
// 获取前 prefixLen 个字符
runes := []rune(s)
prefix := string(runes[:prefixLen])
// 获取后 suffixLen 个字符
suffix := string(runes[len(runes)-suffixLen:])
// 返回格式化后的字符串
return prefix + replace + suffix
}

View File

@@ -1,9 +1,10 @@
package stringx
import (
"github.com/stretchr/testify/require"
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
func TestTruncateStr(t *testing.T) {
@@ -12,20 +13,12 @@ func TestTruncateStr(t *testing.T) {
length int
want string
}{
{"123一二三", 0, ""},
{"123一二三", 1, "1"},
{"123一二三", 3, "123"},
{"123一二三", 4, "123"},
{"123一二三", 5, "123"},
{"123一二三", 6, "123一"},
{"123一二三", 7, "123一"},
{"123一二三", 11, "123一二"},
{"123一二三", 12, "123一二三"},
{"123一二三", 13, "123一二三"},
{"123一二三", 4, "123...三"},
{"123一二三", 5, "123...二三"},
}
for _, tc := range testCases {
t.Run(strconv.Itoa(tc.length), func(t *testing.T) {
got := TruncateStr(tc.data, tc.length)
got := Truncate(tc.data, tc.length, 3, "...")
require.Equal(t, tc.want, got)
})
}

View File

@@ -8,12 +8,18 @@ import (
)
const DefaultDateTimeFormat = "2006-01-02 15:04:05"
const DefaultDateFormat = "2006-01-02"
// DefaultFormat 使用默认格式进行格式化: 2006-01-02 15:04:05
func DefaultFormat(time time.Time) string {
return time.Format(DefaultDateTimeFormat)
}
// DefaultFormatDate 使用默认格式进行格式化: 2006-01-02
func DefaultFormatDate(time time.Time) string {
return time.Format(DefaultDateFormat)
}
// TimeNo 获取当前时间编号格式为20060102150405
func TimeNo() string {
return time.Now().Format("20060102150405")