Files
mayfly-go/base/model/model.go
2021-01-08 15:37:32 +08:00

302 lines
8.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package model
import (
"errors"
"mayfly-go/base/ctx"
"mayfly-go/base/utils"
"reflect"
"strconv"
"strings"
"time"
"github.com/beego/beego/v2/client/orm"
"github.com/siddontang/go/log"
)
// Model groups a primary key with creation and modification audit columns.
// SetBaseInfo fills the audit fields and treats a zero Id as a record that is
// being created.
type Model struct {
// Id maps to column id and carries the orm "auto" option.
Id uint64 `orm:"column(id);auto" json:"id"`
// CreateTime maps to the nullable datetime column create_time.
CreateTime *time.Time `orm:"column(create_time);type(datetime);null" json:"createTime"`
// CreatorId maps to column creator_id.
CreatorId uint64 `orm:"column(creator_id)" json:"creatorId"`
// Creator maps to column creator.
Creator string `orm:"column(creator)" json:"creator"`
// UpdateTime maps to the nullable datetime column update_time.
UpdateTime *time.Time `orm:"column(update_time);type(datetime);null" json:"updateTime"`
// ModifierId maps to column modifier_id.
ModifierId uint64 `orm:"column(modifier_id)" json:"modifierId"`
// Modifier maps to column modifier.
Modifier string `orm:"column(modifier)" json:"modifier"`
}
// SetBaseInfo stamps the audit fields of m.
//
// UpdateTime is always set to the current time. When m.Id is zero the model is
// treated as newly created and CreateTime is set as well. If account is nil
// only the timestamps are touched; otherwise the modifier fields are filled
// from the account, and the creator fields too when the model is new.
func (m *Model) SetBaseInfo(account *ctx.LoginAccount) {
	now := time.Now()
	creating := m.Id == 0

	m.UpdateTime = &now
	if creating {
		m.CreateTime = &now
	}

	// Without an account there is nobody to record as creator or modifier.
	if account == nil {
		return
	}

	m.ModifierId, m.Modifier = account.Id, account.Username
	if creating {
		m.CreatorId, m.Creator = account.Id, account.Username
	}
}
// QuerySetter returns an orm.QuerySeter for table, obtained from
// getOrm().QueryTable. table is passed to QueryTable unchanged.
func QuerySetter(table interface{}) orm.QuerySeter {
return getOrm().QueryTable(table)
}
// GetById loads the record whose Id equals id into model. model must be a
// pointer, because the queried values are assigned to it. Optional cols are
// forwarded to QuerySeter.One.
//
// Callers are expected to treat a non-nil error as "no such record".
func GetById(model interface{}, id uint64, cols ...string) error {
return QuerySetter(model).Filter("Id", id).One(model, cols...)
}
// UpdateById updates the row identified by model's Id field. Only fields that
// utils.IsBlank reports as non-blank are written (per the original author:
// non-zero integers, non-nil pointers and the like).
//
// It returns the value reported by QuerySeter.Update together with any error.
// An error is returned, and no query is issued, when the field walk fails or
// when the model carries no non-zero Id.
func UpdateById(model interface{}) (int64, error) {
	var id uint64
	params := orm.Params{}
	err := utils.DoWithFields(model, func(ft reflect.StructField, fv reflect.Value) error {
		if utils.IsBlank(fv) {
			return nil
		}
		if ft.Name == "Id" {
			id = fv.Uint()
			return nil
		}
		params[ft.Name] = fv.Interface()
		return nil
	})
	if err != nil {
		return 0, err
	}
	// Validate after the walk: a zero Id is skipped above as a blank field (and
	// a model without an Id field never reaches the Id branch), so checking
	// inside the callback alone would let the update run with Filter("Id", 0).
	if id == 0 {
		return 0, errors.New("根据id更新model时Id不能为0")
	}
	return QuerySetter(model).Filter("Id", id).Update(params)
}
// DeleteById deletes the rows of model's table whose Id equals id and returns
// the result of QuerySeter.Delete.
func DeleteById(model interface{}, id uint64) (int64, error) {
return QuerySetter(model).Filter("Id", id).Delete()
}
// Insert inserts model through a fresh Ormer and returns the result of
// Ormer.Insert unchanged.
func Insert(model interface{}) (int64, error) {
return getOrm().Insert(model)
}
// ListByCondition loads into list every row that matches the non-blank fields
// of model: each field for which utils.IsBlank is false becomes an equality
// filter on the field name.
//
// list must be a pointer to a slice, e.g. `var users []*User` passed as &users.
//
// The function has no error return, so failures are logged instead of being
// dropped silently. When the filters cannot be built the query is skipped, so
// that a partially filtered (and therefore too broad) result is never returned.
func ListByCondition(model interface{}, list interface{}) {
	qs := QuerySetter(model)
	err := utils.DoWithFields(model, func(ft reflect.StructField, fv reflect.Value) error {
		if !utils.IsBlank(fv) {
			qs = qs.Filter(ft.Name, fv.Interface())
		}
		return nil
	})
	if err != nil {
		log.Error("ListByCondition: build filters failed: %s", err.Error())
		return
	}
	if _, err = qs.All(list); err != nil {
		log.Error("ListByCondition: query failed: %s", err.Error())
	}
}
// GetByCondition loads into model the single row that matches model's
// non-blank fields: each field for which utils.IsBlank is false becomes an
// equality filter on the field name. model must be a pointer, because the
// queried values are assigned to it. Optional cols are forwarded to
// QuerySeter.One.
//
// Callers are expected to treat a non-nil error as "no such record". An error
// raised while building the filters is returned as well, rather than running
// the query with an incomplete filter set.
func GetByCondition(model interface{}, cols ...string) error {
	qs := QuerySetter(model)
	err := utils.DoWithFields(model, func(ft reflect.StructField, fv reflect.Value) error {
		if !utils.IsBlank(fv) {
			qs = qs.Filter(ft.Name, fv.Interface())
		}
		return nil
	})
	if err != nil {
		return err
	}
	return qs.One(model, cols...)
}
// GetPage runs a paged query on seter and returns the total row count together
// with the requested page.
//
// models receives the queried rows; only the columns named by
// getFieldNames(toModels) are selected. The rows are then copied into toModels
// with utils.Copy, and toModels is returned as PageResult.List.
//
// pageParam.PageNum is 1-based; values below 1 are treated as the first page.
// Count, query and copy errors are reported through BizErrIsNil.
func GetPage(seter orm.QuerySeter, pageParam *PageParam, models interface{}, toModels interface{}) PageResult {
	count, cerr := seter.Count()
	BizErrIsNil(cerr, "查询错误")
	if count == 0 {
		return PageResult{Total: 0, List: nil}
	}

	pageNum := pageParam.PageNum
	if pageNum < 1 {
		pageNum = 1
	}
	// Limit takes a row offset, not a page index, so convert the page number.
	offset := (pageNum - 1) * pageParam.PageSize

	_, qerr := seter.Limit(pageParam.PageSize, offset).All(models, getFieldNames(toModels)...)
	BizErrIsNil(qerr, "查询错误")
	err := utils.Copy(toModels, models)
	BizErrIsNil(err, "实体转换错误")
	return PageResult{Total: count, List: toModels}
}
// GetPageBySql runs a raw SQL statement page by page.
//
// sql must contain the upper-case markers "SELECT " and " FROM"; the text
// between the first occurrence of each is swapped for "COUNT(*) AS total " to
// obtain the total row count. The page itself is fetched by appending
// " LIMIT offset, size" to sql. args are bound to the placeholders of both
// statements. toModel must be a pointer to a slice; it is filled through
// ormParams2Struct and returned as PageResult.List.
//
// param.PageNum is 1-based; values below 1 are treated as the first page.
// The function panics when sql lacks the markers, when a query fails, or when
// the rows cannot be converted.
func GetPageBySql(sql string, toModel interface{}, param *PageParam, args ...interface{}) PageResult {
	const selectKeyword = "SELECT "
	selectIndex := strings.Index(sql, selectKeyword)
	fromIndex := strings.Index(sql, " FROM")
	colStart := selectIndex + len(selectKeyword)
	if selectIndex < 0 || fromIndex < colStart {
		panic(errors.New("查询错误 : sql must contain \"SELECT \" followed by \" FROM\""))
	}
	// Splice by position rather than by text replacement, so that a column list
	// whose text also appears earlier in the statement cannot be rewritten in
	// the wrong place.
	countSql := sql[:colStart] + "COUNT(*) AS total " + sql[fromIndex:]

	// Count query.
	o := getOrm()
	type TotalRes struct {
		Total int64
	}
	var totalRes TotalRes
	// A missing count row is treated as an empty result; anything else is a
	// real query failure.
	if err := o.Raw(countSql, args...).QueryRow(&totalRes); err != nil && err != orm.ErrNoRows {
		panic(errors.New("查询错误 : " + err.Error()))
	}
	total := totalRes.Total
	if total == 0 {
		return PageResult{Total: 0, List: nil}
	}

	// Page query. LIMIT expects a row offset, not a page index.
	pageNum := param.PageNum
	if pageNum < 1 {
		pageNum = 1
	}
	offset := (pageNum - 1) * param.PageSize
	limitSql := sql + " LIMIT " + strconv.Itoa(offset) + ", " + strconv.Itoa(param.PageSize)
	var maps []orm.Params
	if _, err := o.Raw(limitSql, args...).Values(&maps); err != nil {
		panic(errors.New("查询错误 : " + err.Error()))
	}
	if err := ormParams2Struct(maps, toModel); err != nil {
		panic(err)
	}
	return PageResult{Total: total, List: toModel}
}
// GetListBySql runs the raw statement sql with params bound to its
// placeholders and returns the rows as a pointer to a slice of orm.Params.
//
// A query failure is logged, not returned; in that case the result points to
// whatever the driver left in the slice.
func GetListBySql(sql string, params ...interface{}) *[]orm.Params {
	var maps []orm.Params
	// Spread the variadic arguments so each one is bound individually instead
	// of being handed over as a single nested slice.
	_, err := getOrm().Raw(sql, params...).Values(&maps)
	if err != nil {
		log.Error("根据sql查询数据列表失败%s", err.Error())
	}
	return &maps
}
// GetList loads all rows of seter into model, selecting only the columns named
// by getFieldNames(toModel), and then copies them into toModel with utils.Copy.
//
// model is a pointer to a slice, e.g. `var users []*User` passed as &users.
// Query and copy errors are reported through BizErrIsNil, as GetPage does for
// the same calls.
func GetList(seter orm.QuerySeter, model interface{}, toModel interface{}) {
	_, qerr := seter.All(model, getFieldNames(toModel)...)
	BizErrIsNil(qerr, "查询错误")
	err := utils.Copy(toModel, model)
	BizErrIsNil(err, "实体转换错误")
}
// GetOne fetches a single row from seter into model, restricted to the columns
// named by getFieldNames(toModel), and copies it into toModel.
//
// The error of QuerySeter.One is returned as is. A failure of the copy step is
// reported through BizErrIsNil.
func GetOne(seter orm.QuerySeter, model interface{}, toModel interface{}) error {
	columns := getFieldNames(toModel)
	if queryErr := seter.One(model, columns...); queryErr != nil {
		return queryErr
	}
	BizErrIsNil(utils.Copy(toModel, model), "实体转换错误")
	return nil
}
// GetBy reads model through Ormer.Read, matching on the fields named in fs.
// Per the original author, an empty fs means the lookup is done by id.
//
// orm.ErrNoRows is mapped to a "record does not exist" error; any other
// failure is mapped to a generic "query failed" error.
func GetBy(model interface{}, fs ...string) error {
	switch readErr := getOrm().Read(model, fs...); readErr {
	case nil:
		return nil
	case orm.ErrNoRows:
		return errors.New("该数据不存在")
	default:
		return errors.New("查询失败")
	}
}
// Update writes model through Ormer.Update, passing fs along as the optional
// column list. Any failure is replaced by a generic "update failed" error.
func Update(model interface{}, fs ...string) error {
	if _, updateErr := getOrm().Update(model, fs...); updateErr != nil {
		return errors.New("数据更新失败")
	}
	return nil
}
// Delete removes model through Ormer.Delete, passing fs along as the optional
// column list. Any failure is replaced by a generic "delete failed" error.
func Delete(model interface{}, fs ...string) error {
	if _, deleteErr := getOrm().Delete(model, fs...); deleteErr != nil {
		return errors.New("数据删除失败")
	}
	return nil
}
// getOrm returns a new Ormer from orm.NewOrm. Every helper in this file calls
// it to obtain its own Ormer.
func getOrm() orm.Ormer {
return orm.NewOrm()
}
// resultModelCache memoizes the field-name lists computed by getFieldNames,
// keyed by "<package path>.<type name>" of the element type.
//
// NOTE(review): this is a plain map with no lock; if getFieldNames can be
// reached from several goroutines at once it needs synchronization — confirm.
var resultModelCache = make(map[string][]string)

// getFieldNames returns the field names of obj's underlying struct type, as
// produced by getFieldNamesByType. Pointers and slices are unwrapped first, so
// a value, a pointer, or a slice of either all resolve to the same entry.
// Non-nil results are cached per type.
func getFieldNames(obj interface{}) []string {
	elemType := indirectType(reflect.TypeOf(obj))
	key := elemType.PkgPath() + "." + elemType.Name()

	if names := resultModelCache[key]; names != nil {
		return names
	}

	names := getFieldNamesByType("", reflect.TypeOf(obj))
	resultModelCache[key] = names
	return names
}
// indirectType strips every pointer and slice layer from reflectType and
// returns the first type that is neither. Other kinds, including arrays and
// maps, are returned unchanged.
func indirectType(reflectType reflect.Type) reflect.Type {
	for {
		switch reflectType.Kind() {
		case reflect.Ptr, reflect.Slice:
			reflectType = reflectType.Elem()
		default:
			return reflectType
		}
	}
}
// getFieldNamesByType lists the field names of the struct underlying
// reflectType (pointers and slices are unwrapped). It returns nil when the
// underlying type is not a struct.
//
// For each field:
//   - a struct-typed field whose type name contains none of "BaseModel",
//     "Time" or "time" is expanded recursively, with "<FieldName>__" as the
//     prefix of its own fields;
//   - any other anonymous (embedded) field is expanded recursively with an
//     empty prefix;
//   - everything else contributes namePrefix + field name.
//
// NOTE(review): both recursive calls discard the incoming namePrefix, so names
// nested more than one level deep carry only the innermost prefix — confirm
// that this is intended.
func getFieldNamesByType(namePrefix string, reflectType reflect.Type) []string {
	structType := indirectType(reflectType)
	if structType.Kind() != reflect.Struct {
		return nil
	}

	var names []string
	for idx, total := 0, structType.NumField(); idx < total; idx++ {
		field := structType.Field(idx)
		elemType := indirectType(field.Type)

		// Struct types named like a base model or a time value are kept as a
		// single column instead of being expanded into "<Field>__<Sub>" names.
		typeName := elemType.Name()
		keptWhole := strings.Contains(typeName, "BaseModel") ||
			strings.Contains(typeName, "Time") ||
			strings.Contains(typeName, "time")

		switch {
		case elemType.Kind() == reflect.Struct && !keptWhole:
			names = append(names, getFieldNamesByType(field.Name+"__", elemType)...)
		case field.Anonymous:
			names = append(names, getFieldNamesByType("", field.Type)...)
		default:
			names = append(names, namePrefix+field.Name)
		}
	}
	return names
}
// ormParams2Struct converts each entry of maps into one element of the slice
// that structs points to, using utils.Map2Struct on the element's address.
//
// structs must be a non-nil pointer to a slice; anything else yields an error
// instead of a reflect panic. A nil destination slice, or one shorter than
// maps, is replaced by a fresh slice of len(maps) elements. A destination that
// is already long enough is filled in place and keeps its length.
//
// The first conversion error is returned and the destination is left
// unassigned in that case.
func ormParams2Struct(maps []orm.Params, structs interface{}) error {
	ptr := reflect.ValueOf(structs)
	if ptr.Kind() != reflect.Ptr || ptr.IsNil() || ptr.Elem().Kind() != reflect.Slice {
		return errors.New("ormParams2Struct: structs must be a non-nil pointer to a slice")
	}
	structsV := ptr.Elem()

	length := len(maps)
	valSlice := structsV
	// Allocate whenever the destination cannot hold every row; indexing past
	// the end of a short, non-nil slice would panic.
	if valSlice.IsNil() || valSlice.Len() < length {
		valSlice = reflect.MakeSlice(structsV.Type(), length, length)
	}
	for i := 0; i < length; i++ {
		if err := utils.Map2Struct(maps[i], valSlice.Index(i).Addr().Interface()); err != nil {
			return err
		}
	}
	structsV.Set(valSlice)
	return nil
}