mirror of
https://gitee.com/dromara/mayfly-go
synced 2025-11-03 16:00:25 +08:00
refactor: dbm
This commit is contained in:
@@ -161,39 +161,39 @@ func (ai *AppImpl[T, R]) CountByCond(cond any) int64 {
|
||||
|
||||
// Tx 执行事务操作
|
||||
func (ai *AppImpl[T, R]) Tx(ctx context.Context, funcs ...func(context.Context) error) (err error) {
|
||||
tx := GetTxFromCtx(ctx)
|
||||
dbCtx := ctx
|
||||
var txDb *gorm.DB
|
||||
isCreateTx := false
|
||||
txDb := GetDbFromCtx(ctx)
|
||||
|
||||
if tx == nil {
|
||||
if txDb == nil {
|
||||
txDb = global.Db.Begin()
|
||||
dbCtx, tx = NewCtxWithTxDb(ctx, txDb)
|
||||
} else {
|
||||
txDb = tx.DB
|
||||
tx.Count++
|
||||
if txDb.Error != nil {
|
||||
return txDb.Error
|
||||
}
|
||||
|
||||
dbCtx = NewCtxWithDb(ctx, txDb)
|
||||
// 只有创建事务的方法,才允许其提交或回滚
|
||||
isCreateTx = true
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Count = 0
|
||||
txDb.Rollback()
|
||||
err = fmt.Errorf("%v", r)
|
||||
return
|
||||
}
|
||||
|
||||
tx.Count--
|
||||
// Make sure to rollback when panic, Block error or Commit error
|
||||
if isCreateTx && err != nil {
|
||||
txDb.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
for _, f := range funcs {
|
||||
err = f(dbCtx)
|
||||
if err != nil && tx.Count > 0 {
|
||||
tx.Count = 0
|
||||
txDb.Rollback()
|
||||
if err = f(dbCtx); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if tx.Count == 1 {
|
||||
if isCreateTx {
|
||||
err = txDb.Commit().Error
|
||||
}
|
||||
return
|
||||
|
||||
@@ -12,34 +12,19 @@ const (
|
||||
DbKey CtxKey = "db"
|
||||
)
|
||||
|
||||
// Tx 事务上下文信息
|
||||
type Tx struct {
|
||||
Count int
|
||||
DB *gorm.DB
|
||||
}
|
||||
|
||||
// NewCtxWithTxDb 将事务db放置context中
|
||||
func NewCtxWithTxDb(ctx context.Context, db *gorm.DB) (context.Context, *Tx) {
|
||||
if tx := GetTxFromCtx(ctx); tx != nil {
|
||||
return ctx, tx
|
||||
// NewCtxWithDb 将事务db放置context中,若已存在,则直接返回ctx
|
||||
func NewCtxWithDb(ctx context.Context, db *gorm.DB) context.Context {
|
||||
if tx := GetDbFromCtx(ctx); tx != nil {
|
||||
return ctx
|
||||
}
|
||||
|
||||
tx := &Tx{Count: 1, DB: db}
|
||||
return context.WithValue(ctx, DbKey, tx), tx
|
||||
return context.WithValue(ctx, DbKey, db)
|
||||
}
|
||||
|
||||
// GetDbFromCtx 获取ctx中的事务db
|
||||
func GetDbFromCtx(ctx context.Context) *gorm.DB {
|
||||
if tx := GetTxFromCtx(ctx); tx != nil {
|
||||
return tx.DB
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTxFromCtx 获取当前ctx事务
|
||||
func GetTxFromCtx(ctx context.Context) *Tx {
|
||||
if tx, ok := ctx.Value(DbKey).(*Tx); ok {
|
||||
return tx
|
||||
if txdb, ok := ctx.Value(DbKey).(*gorm.DB); ok {
|
||||
return txdb
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -15,16 +15,16 @@ import (
|
||||
// 基础repo接口
|
||||
type Repo[T model.ModelI] interface {
|
||||
|
||||
// 新增一个实体
|
||||
// Insert 新增一个实体
|
||||
Insert(ctx context.Context, e T) error
|
||||
|
||||
// 使用指定gorm db执行,主要用于事务执行
|
||||
// InsertWithDb 使用指定gorm db执行,主要用于事务执行
|
||||
InsertWithDb(ctx context.Context, db *gorm.DB, e T) error
|
||||
|
||||
// 批量新增实体
|
||||
// BatchInsert 批量新增实体
|
||||
BatchInsert(ctx context.Context, models []T) error
|
||||
|
||||
// 使用指定gorm db执行,主要用于事务执行
|
||||
// BatchInsertWithDb 使用指定gorm db执行,主要用于事务执行
|
||||
BatchInsertWithDb(ctx context.Context, db *gorm.DB, models []T) error
|
||||
|
||||
// 根据实体id更新实体信息
|
||||
@@ -42,31 +42,32 @@ type Repo[T model.ModelI] interface {
|
||||
// @param values 需要模型结构体或map
|
||||
UpdateByCondWithDb(ctx context.Context, db *gorm.DB, values any, cond any) error
|
||||
|
||||
// 保存实体,实体IsCreate返回true则新增,否则更新
|
||||
// Save 保存实体,实体IsCreate返回true则新增,否则更新
|
||||
Save(ctx context.Context, e T) error
|
||||
|
||||
// 保存实体,实体IsCreate返回true则新增,否则更新。
|
||||
// SaveWithDb 保存实体,实体IsCreate返回true则新增,否则更新。
|
||||
// 使用指定gorm db执行,主要用于事务执行
|
||||
SaveWithDb(ctx context.Context, db *gorm.DB, e T) error
|
||||
|
||||
// 根据实体主键删除实体
|
||||
// DeleteById 根据实体主键删除实体
|
||||
DeleteById(ctx context.Context, id ...uint64) error
|
||||
|
||||
// 使用指定gorm db执行,主要用于事务执行
|
||||
// DeleteByIdWithDb 使用指定gorm db执行,主要用于事务执行
|
||||
DeleteByIdWithDb(ctx context.Context, db *gorm.DB, id ...uint64) error
|
||||
|
||||
// 根据实体条件删除实体
|
||||
// DeleteByCond 根据实体条件删除实体
|
||||
DeleteByCond(ctx context.Context, cond any) error
|
||||
|
||||
// 使用指定gorm db执行,主要用于事务执行
|
||||
// DeleteByCondWithDb 使用指定gorm db执行,主要用于事务执行
|
||||
DeleteByCondWithDb(ctx context.Context, db *gorm.DB, cond any) error
|
||||
|
||||
// ExecBySql 执行原生sql
|
||||
ExecBySql(sql string, params ...any) error
|
||||
|
||||
// 根据实体id查询
|
||||
// GetById 根据实体id查询
|
||||
GetById(id uint64, cols ...string) (T, error)
|
||||
|
||||
// GetByIds 根据实体ids查询
|
||||
GetByIds(ids []uint64, cols ...string) ([]T, error)
|
||||
|
||||
// GetByCond 根据实体条件查询实体信息(单个结果集)
|
||||
@@ -88,13 +89,13 @@ type Repo[T model.ModelI] interface {
|
||||
// SelectBySql 根据sql语句查询数据
|
||||
SelectBySql(sql string, res any, params ...any) error
|
||||
|
||||
// 根据指定条件统计model表的数量
|
||||
// CountByCond 根据指定条件统计model表的数量
|
||||
CountByCond(cond any) int64
|
||||
}
|
||||
|
||||
// 基础repo接口
|
||||
type RepoImpl[T model.ModelI] struct {
|
||||
M T // 模型实例
|
||||
model any // 模型实例
|
||||
|
||||
modelType reflect.Type // 模型类型
|
||||
}
|
||||
@@ -120,7 +121,6 @@ func (br *RepoImpl[T]) BatchInsert(ctx context.Context, es []T) error {
|
||||
return gormx.BatchInsert[T](es)
|
||||
}
|
||||
|
||||
// 使用指定gorm db执行,主要用于事务执行
|
||||
func (br *RepoImpl[T]) BatchInsertWithDb(ctx context.Context, db *gorm.DB, es []T) error {
|
||||
for _, e := range es {
|
||||
br.fillBaseInfo(ctx, e)
|
||||
@@ -281,7 +281,12 @@ func (br *RepoImpl[T]) NewModel() T {
|
||||
|
||||
// getModel 获取表的模型实例
|
||||
func (br *RepoImpl[T]) getModel() T {
|
||||
return br.M
|
||||
if br.model != nil {
|
||||
return br.model.(T)
|
||||
}
|
||||
|
||||
br.model = br.NewModel()
|
||||
return br.model.(T)
|
||||
}
|
||||
|
||||
// getModelType 获取模型类型(非指针模型)
|
||||
@@ -290,7 +295,8 @@ func (br *RepoImpl[T]) getModelType() reflect.Type {
|
||||
return br.modelType
|
||||
}
|
||||
|
||||
modelType := reflect.TypeOf(br.M)
|
||||
var model T
|
||||
modelType := reflect.TypeOf(model)
|
||||
// 检查 model 是否为指针类型
|
||||
if modelType.Kind() == reflect.Ptr {
|
||||
// 获取指针指向的类型
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SQLStatement 结构体用于存储解析后的 SQL 语句及其注释
|
||||
type SQLStatement struct {
|
||||
Comment string
|
||||
SQL string
|
||||
}
|
||||
|
||||
var sqlMap = make(map[string]string)
|
||||
|
||||
func RegisterSql(fs embed.FS) error {
|
||||
return walkDir(fs, ".", func(fp string, data []byte) error {
|
||||
if filepath.Ext(fp) != ".sql" {
|
||||
return nil
|
||||
}
|
||||
|
||||
fileNameWithExt := path.Base(fp)
|
||||
sqls, err := parseSQL(string(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filename := strings.TrimSuffix(fileNameWithExt, path.Ext(fileNameWithExt))
|
||||
for _, sql := range sqls {
|
||||
sqlMap[filename+"."+strings.TrimSpace(sql.Comment)] = strings.TrimSpace(sql.SQL)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func GetSQL(filename, stmt string) string {
|
||||
return sqlMap[filename+"."+stmt]
|
||||
}
|
||||
|
||||
// walkDir 递归遍历目录
|
||||
func walkDir(fsys fs.FS, path string, callback func(filePath string, data []byte) error) error {
|
||||
entries, err := fs.ReadDir(fsys, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, entry := range entries {
|
||||
entryPath := filepath.Join(path, entry.Name())
|
||||
if entry.IsDir() {
|
||||
// 递归遍历子目录
|
||||
if err := walkDir(fsys, entryPath, callback); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// 读取文件内容
|
||||
data, err := fs.ReadFile(fsys, entryPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := callback(entryPath, data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseSQL 解析带有注释的 SQL 语句
|
||||
func parseSQL(sql string) ([]SQLStatement, error) {
|
||||
var statements []SQLStatement
|
||||
lines := strings.Split(sql, "\n")
|
||||
|
||||
var currentComment string
|
||||
var currentSQL string
|
||||
|
||||
for _, line := range lines {
|
||||
trimmedLine := strings.TrimSpace(line)
|
||||
if strings.HasPrefix(trimmedLine, "--") {
|
||||
// 处理单行注释
|
||||
if currentSQL != "" {
|
||||
statements = append(statements, SQLStatement{Comment: currentComment, SQL: strings.TrimRight(currentSQL, " ")})
|
||||
currentComment = ""
|
||||
currentSQL = ""
|
||||
}
|
||||
currentComment += strings.TrimPrefix(trimmedLine, "--") + "\n"
|
||||
continue
|
||||
}
|
||||
|
||||
if trimmedLine == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
currentSQL += line + " "
|
||||
if strings.HasSuffix(trimmedLine, ";") {
|
||||
statements = append(statements, SQLStatement{Comment: currentComment, SQL: strings.TrimRight(currentSQL, " ")})
|
||||
currentComment = ""
|
||||
currentSQL = ""
|
||||
}
|
||||
}
|
||||
|
||||
// 处理最后一段未结束的 SQL 语句
|
||||
if currentSQL != "" {
|
||||
statements = append(statements, SQLStatement{Comment: currentComment, SQL: strings.TrimRight(currentSQL, " ")})
|
||||
}
|
||||
|
||||
return statements, nil
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
package base
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParserSql(t *testing.T) {
|
||||
sql := `-- selectByCond
|
||||
Select * from tdb where id > 10;
|
||||
-- another comment
|
||||
Select * from another_table where name = 'test'
|
||||
and age = ?;
|
||||
-- multi-line comment
|
||||
-- continues here
|
||||
Select * from yet_another_table
|
||||
Where id = ?;`
|
||||
|
||||
statements, err := parseSQL(sql)
|
||||
if err != nil {
|
||||
fmt.Println("Error:", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, stmt := range statements {
|
||||
fmt.Printf("Comment: %s\nSQL: %s\n\n", stmt.Comment, stmt.SQL)
|
||||
}
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
package captcha
|
||||
|
||||
import (
|
||||
"mayfly-go/pkg/rediscli"
|
||||
"time"
|
||||
|
||||
"github.com/mojocn/base64Captcha"
|
||||
)
|
||||
|
||||
var store base64Captcha.Store
|
||||
var driver base64Captcha.Driver = base64Captcha.DefaultDriverDigit
|
||||
|
||||
// 生成验证码
|
||||
func Generate() (string, string, error) {
|
||||
if store == nil {
|
||||
if rediscli.GetCli() != nil {
|
||||
store = new(RedisStore)
|
||||
} else {
|
||||
store = base64Captcha.DefaultMemStore
|
||||
}
|
||||
}
|
||||
|
||||
c := base64Captcha.NewCaptcha(driver, store)
|
||||
// 获取
|
||||
id, b64s, _, err := c.Generate()
|
||||
return id, b64s, err
|
||||
}
|
||||
|
||||
// 验证验证码
|
||||
func Verify(id string, val string) bool {
|
||||
if store == nil || id == "" || val == "" {
|
||||
return false
|
||||
}
|
||||
// 同时清理掉这个图片
|
||||
return store.Verify(id, val, true)
|
||||
}
|
||||
|
||||
type RedisStore struct {
|
||||
}
|
||||
|
||||
const CAPTCHA = "mayfly:captcha:"
|
||||
|
||||
// 实现设置captcha的方法
|
||||
func (r RedisStore) Set(id string, value string) error {
|
||||
//time.Minute*2:有效时间2分钟
|
||||
rediscli.Set(CAPTCHA+id, value, time.Minute*2)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 实现获取captcha的方法
|
||||
func (r RedisStore) Get(id string, clear bool) string {
|
||||
key := CAPTCHA + id
|
||||
val, err := rediscli.Get(key)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if clear {
|
||||
//clear为true,验证通过,删除这个验证码
|
||||
rediscli.Del(key)
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
// 实现验证captcha的方法
|
||||
func (r RedisStore) Verify(id, answer string, clear bool) bool {
|
||||
return r.Get(id, clear) == answer
|
||||
}
|
||||
@@ -32,7 +32,7 @@ func (e *Enum[T]) Valid(value T) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
errMsg := fmt.Sprintf("%s可选值为: ", e.name)
|
||||
errMsg := fmt.Sprintf("%s the optional value is: ", e.name)
|
||||
for val, desc := range e.values {
|
||||
errMsg = fmt.Sprintf("%s [%v->%s]", errMsg, val, desc)
|
||||
}
|
||||
|
||||
@@ -109,7 +109,7 @@ func UpdateByIdWithDb(db *gorm.DB, model model.ModelI, columns ...string) error
|
||||
return db.Model(model).Select(columns).Updates(model).Error
|
||||
}
|
||||
|
||||
// UpdateByCondWithDb 使用指定gorm.DB更新满足条件的数据(model的主键值需为空,否则会带上主键条件)
|
||||
// UpdateByCond 使用默认global.Dd更新满足条件的数据(model的主键值需为空,否则会带上主键条件)
|
||||
func UpdateByCond(dbModel model.ModelI, values any, cond *model.QueryCond) error {
|
||||
return UpdateByCondWithDb(global.Db, dbModel, values, cond)
|
||||
}
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
package otp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
otp_t "github.com/pquerna/otp"
|
||||
totp_t "github.com/pquerna/otp/totp"
|
||||
)
|
||||
|
||||
type GenerateOpts totp_t.GenerateOpts
|
||||
|
||||
func NewTOTP(opt GenerateOpts) (*otp_t.Key, error) {
|
||||
return totp_t.Generate(totp_t.GenerateOpts(opt))
|
||||
}
|
||||
|
||||
func Validate(code string, secret string) bool {
|
||||
if secret == "" {
|
||||
return true
|
||||
}
|
||||
return totp_t.Validate(code, secret)
|
||||
}
|
||||
|
||||
func GenTotpCode(code string, secret string) (string, error) {
|
||||
return totp_t.GenerateCode(secret, time.Now())
|
||||
}
|
||||
@@ -71,7 +71,7 @@ func ArrayToMap[T any, K comparable](arr []T, keyFunc func(val T) K) map[K]T {
|
||||
}
|
||||
|
||||
// 数组映射,即将一数组元素通过映射函数转换为另一数组
|
||||
func ArrayMap[T any, K comparable](arr []T, mapFunc func(val T) K) []K {
|
||||
func ArrayMap[T any, K any](arr []T, mapFunc func(val T) K) []K {
|
||||
res := make([]K, len(arr))
|
||||
for i, val := range arr {
|
||||
res[i] = mapFunc(val)
|
||||
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
|
||||
// json字符串转map
|
||||
func ToMap(jsonStr string) map[string]any {
|
||||
if jsonStr == "" {
|
||||
return map[string]any{}
|
||||
}
|
||||
return ToMapByBytes([]byte(jsonStr))
|
||||
}
|
||||
|
||||
@@ -22,7 +25,7 @@ func ToMapByBytes(bytes []byte) map[string]any {
|
||||
var res map[string]any
|
||||
err := json.Unmarshal(bytes, &res)
|
||||
if err != nil {
|
||||
logx.Errorf("json字符串转map失败: %s", err.Error())
|
||||
logx.ErrorTrace("json字符串转map失败", err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"strings"
|
||||
"text/template"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
// 逻辑空字符串(由于gorm更新结构体只更新非零值,所以使用该值最为逻辑空字符串,方便更新结构体)
|
||||
@@ -124,16 +125,36 @@ func ReverStrTemplate(temp, str string, res map[string]any) {
|
||||
}
|
||||
}
|
||||
|
||||
func TruncateStr(s string, length int) string {
|
||||
if length >= len(s) {
|
||||
// Truncate 截断字符串并在中间部分显示指定的替换字符串
|
||||
func Truncate(s string, length int, prefixLen int, replace string) string {
|
||||
totalRunes := utf8.RuneCountInString(s)
|
||||
|
||||
// 如果字符串长度小于或等于指定的 length,直接返回原字符串
|
||||
if totalRunes <= length {
|
||||
return s
|
||||
}
|
||||
var last int
|
||||
for i := range s {
|
||||
if i > length {
|
||||
break
|
||||
}
|
||||
last = i
|
||||
|
||||
// 如果字符串长度小于或等于 prefixLen,直接返回原字符串
|
||||
if totalRunes <= prefixLen {
|
||||
return s
|
||||
}
|
||||
return s[:last]
|
||||
|
||||
// 计算 suffixLen
|
||||
suffixLen := length - prefixLen
|
||||
|
||||
// 确保 suffixLen 不会越界
|
||||
if suffixLen <= 0 {
|
||||
runes := []rune(s)
|
||||
return string(runes[:length]) + replace
|
||||
}
|
||||
|
||||
// 获取前 prefixLen 个字符
|
||||
runes := []rune(s)
|
||||
prefix := string(runes[:prefixLen])
|
||||
|
||||
// 获取后 suffixLen 个字符
|
||||
suffix := string(runes[len(runes)-suffixLen:])
|
||||
|
||||
// 返回格式化后的字符串
|
||||
return prefix + replace + suffix
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
package stringx
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestTruncateStr(t *testing.T) {
|
||||
@@ -12,20 +13,12 @@ func TestTruncateStr(t *testing.T) {
|
||||
length int
|
||||
want string
|
||||
}{
|
||||
{"123一二三", 0, ""},
|
||||
{"123一二三", 1, "1"},
|
||||
{"123一二三", 3, "123"},
|
||||
{"123一二三", 4, "123"},
|
||||
{"123一二三", 5, "123"},
|
||||
{"123一二三", 6, "123一"},
|
||||
{"123一二三", 7, "123一"},
|
||||
{"123一二三", 11, "123一二"},
|
||||
{"123一二三", 12, "123一二三"},
|
||||
{"123一二三", 13, "123一二三"},
|
||||
{"123一二三", 4, "123...三"},
|
||||
{"123一二三", 5, "123...二三"},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
t.Run(strconv.Itoa(tc.length), func(t *testing.T) {
|
||||
got := TruncateStr(tc.data, tc.length)
|
||||
got := Truncate(tc.data, tc.length, 3, "...")
|
||||
require.Equal(t, tc.want, got)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -8,12 +8,18 @@ import (
|
||||
)
|
||||
|
||||
const DefaultDateTimeFormat = "2006-01-02 15:04:05"
|
||||
const DefaultDateFormat = "2006-01-02"
|
||||
|
||||
// DefaultFormat 使用默认格式进行格式化: 2006-01-02 15:04:05
|
||||
func DefaultFormat(time time.Time) string {
|
||||
return time.Format(DefaultDateTimeFormat)
|
||||
}
|
||||
|
||||
// DefaultFormatDate 使用默认格式进行格式化: 2006-01-02
|
||||
func DefaultFormatDate(time time.Time) string {
|
||||
return time.Format(DefaultDateFormat)
|
||||
}
|
||||
|
||||
// TimeNo 获取当前时间编号,格式为20060102150405
|
||||
func TimeNo() string {
|
||||
return time.Now().Format("20060102150405")
|
||||
|
||||
Reference in New Issue
Block a user