重新实现套餐相关功能

This commit is contained in:
GoEdgeLab
2023-09-06 16:30:47 +08:00
parent 3d5f0a69a8
commit 04d2678221
26 changed files with 2804 additions and 233 deletions

View File

@@ -61,15 +61,28 @@ func (this *PlanDAO) DisablePlan(tx *dbs.Tx, id int64) error {
} }
// FindEnabledPlan 查找启用中的条目 // FindEnabledPlan 查找启用中的条目
func (this *PlanDAO) FindEnabledPlan(tx *dbs.Tx, id int64) (*Plan, error) { func (this *PlanDAO) FindEnabledPlan(tx *dbs.Tx, planId int64, cacheMap *utils.CacheMap) (*Plan, error) {
var cacheKey = this.Table + ":FindEnabledPlan:" + types.String(planId)
if cacheMap != nil {
cache, _ := cacheMap.Get(cacheKey)
if cache != nil {
return cache.(*Plan), nil
}
}
result, err := this.Query(tx). result, err := this.Query(tx).
Pk(id). Pk(planId).
Attr("state", PlanStateEnabled). Attr("state", PlanStateEnabled).
Find() Find()
if result == nil { if result == nil {
return nil, err return nil, err
} }
return result.(*Plan), err
if cacheMap != nil {
cacheMap.Put(cacheKey, result)
}
return result.(*Plan), nil
} }
// FindPlanName 根据主键查找名称 // FindPlanName 根据主键查找名称

View File

@@ -0,0 +1,6 @@
package models_test
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -2,6 +2,26 @@ package models
import "github.com/iwind/TeaGo/dbs" import "github.com/iwind/TeaGo/dbs"
const (
PlanField_Id dbs.FieldName = "id" // ID
PlanField_IsOn dbs.FieldName = "isOn" // 是否启用
PlanField_Name dbs.FieldName = "name" // 套餐名
PlanField_ClusterId dbs.FieldName = "clusterId" // 集群ID
PlanField_TrafficLimit dbs.FieldName = "trafficLimit" // 流量限制
PlanField_Features dbs.FieldName = "features" // 允许的功能
PlanField_TrafficPrice dbs.FieldName = "trafficPrice" // 流量价格设定
PlanField_BandwidthPrice dbs.FieldName = "bandwidthPrice" // 带宽价格
PlanField_MonthlyPrice dbs.FieldName = "monthlyPrice" // 月付
PlanField_SeasonallyPrice dbs.FieldName = "seasonallyPrice" // 季付
PlanField_YearlyPrice dbs.FieldName = "yearlyPrice" // 年付
PlanField_PriceType dbs.FieldName = "priceType" // 价格类型
PlanField_Order dbs.FieldName = "order" // 排序
PlanField_State dbs.FieldName = "state" // 状态
PlanField_TotalServers dbs.FieldName = "totalServers" // 可以绑定的网站数量
PlanField_TotalServerNamesPerServer dbs.FieldName = "totalServerNamesPerServer" // 每个网站可以绑定的域名数量
PlanField_TotalServerNames dbs.FieldName = "totalServerNames" // 总域名数量
)
// Plan 用户套餐 // Plan 用户套餐
type Plan struct { type Plan struct {
Id uint32 `field:"id"` // ID Id uint32 `field:"id"` // ID
@@ -18,23 +38,29 @@ type Plan struct {
PriceType string `field:"priceType"` // 价格类型 PriceType string `field:"priceType"` // 价格类型
Order uint32 `field:"order"` // 排序 Order uint32 `field:"order"` // 排序
State uint8 `field:"state"` // 状态 State uint8 `field:"state"` // 状态
TotalServers uint32 `field:"totalServers"` // 可以绑定的网站数量
TotalServerNamesPerServer uint32 `field:"totalServerNamesPerServer"` // 每个网站可以绑定的域名数量
TotalServerNames uint32 `field:"totalServerNames"` // 总域名数量
} }
type PlanOperator struct { type PlanOperator struct {
Id interface{} // ID Id any // ID
IsOn interface{} // 是否启用 IsOn any // 是否启用
Name interface{} // 套餐名 Name any // 套餐名
ClusterId interface{} // 集群ID ClusterId any // 集群ID
TrafficLimit interface{} // 流量限制 TrafficLimit any // 流量限制
Features interface{} // 允许的功能 Features any // 允许的功能
TrafficPrice interface{} // 流量价格设定 TrafficPrice any // 流量价格设定
BandwidthPrice interface{} // 带宽价格 BandwidthPrice any // 带宽价格
MonthlyPrice interface{} // 月付 MonthlyPrice any // 月付
SeasonallyPrice interface{} // 季付 SeasonallyPrice any // 季付
YearlyPrice interface{} // 年付 YearlyPrice any // 年付
PriceType interface{} // 价格类型 PriceType any // 价格类型
Order interface{} // 排序 Order any // 排序
State interface{} // 状态 State any // 状态
TotalServers any // 可以绑定的网站数量
TotalServerNamesPerServer any // 每个网站可以绑定的域名数量
TotalServerNames any // 总域名数量
} }
func NewPlanOperator() *PlanOperator { func NewPlanOperator() *PlanOperator {

View File

@@ -25,7 +25,7 @@ import (
type ServerBandwidthStatDAO dbs.DAO type ServerBandwidthStatDAO dbs.DAO
const ( const (
ServerBandwidthStatTablePartials = 20 // 分表数量 ServerBandwidthStatTablePartitions = 20 // 分表数量
) )
func init() { func init() {
@@ -63,15 +63,15 @@ func init() {
} }
// UpdateServerBandwidth 写入数据 // UpdateServerBandwidth 写入数据
// 暂时不使用region区分 // 现在不需要把 userPlanId 加入到数据表unique key中因为只会影响5分钟统计影响非常有限
func (this *ServerBandwidthStatDAO) UpdateServerBandwidth(tx *dbs.Tx, userId int64, serverId int64, regionId int64, day string, timeAt string, bytes int64, totalBytes int64, cachedBytes int64, attackBytes int64, countRequests int64, countCachedRequests int64, countAttackRequests int64) error { func (this *ServerBandwidthStatDAO) UpdateServerBandwidth(tx *dbs.Tx, userId int64, serverId int64, regionId int64, userPlanId int64, day string, timeAt string, bandwidthBytes int64, totalBytes int64, cachedBytes int64, attackBytes int64, countRequests int64, countCachedRequests int64, countAttackRequests int64) error {
if serverId <= 0 { if serverId <= 0 {
return errors.New("invalid server id '" + types.String(serverId) + "'") return errors.New("invalid server id '" + types.String(serverId) + "'")
} }
return this.Query(tx). return this.Query(tx).
Table(this.partialTable(serverId)). Table(this.partialTable(serverId)).
Param("bytes", bytes). Param("bytes", bandwidthBytes).
Param("totalBytes", totalBytes). Param("totalBytes", totalBytes).
Param("cachedBytes", cachedBytes). Param("cachedBytes", cachedBytes).
Param("attackBytes", attackBytes). Param("attackBytes", attackBytes).
@@ -84,7 +84,7 @@ func (this *ServerBandwidthStatDAO) UpdateServerBandwidth(tx *dbs.Tx, userId int
"regionId": regionId, "regionId": regionId,
"day": day, "day": day,
"timeAt": timeAt, "timeAt": timeAt,
"bytes": bytes, "bytes": bandwidthBytes,
"totalBytes": totalBytes, "totalBytes": totalBytes,
"avgBytes": totalBytes / 300, "avgBytes": totalBytes / 300,
"cachedBytes": cachedBytes, "cachedBytes": cachedBytes,
@@ -92,6 +92,7 @@ func (this *ServerBandwidthStatDAO) UpdateServerBandwidth(tx *dbs.Tx, userId int
"countRequests": countRequests, "countRequests": countRequests,
"countCachedRequests": countCachedRequests, "countCachedRequests": countCachedRequests,
"countAttackRequests": countAttackRequests, "countAttackRequests": countAttackRequests,
"userPlanId": userPlanId,
}, maps.Map{ }, maps.Map{
"bytes": dbs.SQL("bytes+:bytes"), "bytes": dbs.SQL("bytes+:bytes"),
"avgBytes": dbs.SQL("(totalBytes+:totalBytes)/300"), // 因为生成SQL语句时会自动将avgBytes排在totalBytes之前所以这里不用担心先后顺序的问题 "avgBytes": dbs.SQL("(totalBytes+:totalBytes)/300"), // 因为生成SQL语句时会自动将avgBytes排在totalBytes之前所以这里不用担心先后顺序的问题
@@ -379,14 +380,18 @@ func (this *ServerBandwidthStatDAO) FindAllServerStatsWithMonth(tx *dbs.Tx, serv
} }
// FindMonthlyPercentile 获取某月内百分位 // FindMonthlyPercentile 获取某月内百分位
func (this *ServerBandwidthStatDAO) FindMonthlyPercentile(tx *dbs.Tx, serverId int64, month string, percentile int, useAvg bool) (result int64, err error) { func (this *ServerBandwidthStatDAO) FindMonthlyPercentile(tx *dbs.Tx, serverId int64, month string, percentile int, useAvg bool, noPlan bool) (result int64, err error) {
if percentile <= 0 { if percentile <= 0 {
percentile = 95 percentile = 95
} }
// 如果是100%以上,则快速返回 // 如果是100%以上,则快速返回
if percentile >= 100 { if percentile >= 100 {
result, err = this.Query(tx). var query = this.Query(tx)
if noPlan {
query.Attr("userPlanId", 0)
}
result, err = query.
Table(this.partialTable(serverId)). Table(this.partialTable(serverId)).
Attr("serverId", serverId). Attr("serverId", serverId).
Result(this.bytesField(useAvg)). Result(this.bytesField(useAvg)).
@@ -398,7 +403,11 @@ func (this *ServerBandwidthStatDAO) FindMonthlyPercentile(tx *dbs.Tx, serverId i
} }
// 总数量 // 总数量
total, err := this.Query(tx). var totalQuery = this.Query(tx)
if noPlan {
totalQuery.Attr("userPlanId", 0)
}
total, err := totalQuery.
Table(this.partialTable(serverId)). Table(this.partialTable(serverId)).
Attr("serverId", serverId). Attr("serverId", serverId).
Between("day", month+"01", month+"31"). Between("day", month+"01", month+"31").
@@ -417,7 +426,11 @@ func (this *ServerBandwidthStatDAO) FindMonthlyPercentile(tx *dbs.Tx, serverId i
} }
// 查询 nth 位置 // 查询 nth 位置
result, err = this.Query(tx). var query = this.Query(tx)
if noPlan {
query.Attr("userPlanId", 0)
}
result, err = query.
Table(this.partialTable(serverId)). Table(this.partialTable(serverId)).
Attr("serverId", serverId). Attr("serverId", serverId).
Result(this.bytesField(useAvg)). Result(this.bytesField(useAvg)).
@@ -745,6 +758,74 @@ func (this *ServerBandwidthStatDAO) SumDailyStat(tx *dbs.Tx, serverId int64, reg
return return
} }
// SumMonthlyBytes 统计某个网站单月总流量
func (this *ServerBandwidthStatDAO) SumMonthlyBytes(tx *dbs.Tx, serverId int64, month string, noPlan bool) (int64, error) {
if !regexputils.YYYYMM.MatchString(month) {
return 0, errors.New("invalid month '" + month + "'")
}
// 兼容以往版本
hasFullData, err := this.HasFullData(tx, serverId, month)
if err != nil {
return 0, err
}
if !hasFullData {
return SharedServerDailyStatDAO.SumMonthlyBytes(tx, serverId, month)
}
var query = this.Query(tx)
if noPlan {
query.Attr("userPlanId", 0)
}
return query.
Table(this.partialTable(serverId)).
Between("day", month+"01", month+"31").
Attr("serverId", serverId).
SumInt64("totalBytes", 0)
}
// SumServerMonthlyWithRegion 根据服务计算某月合计
// month 格式为YYYYMM
func (this *ServerBandwidthStatDAO) SumServerMonthlyWithRegion(tx *dbs.Tx, serverId int64, regionId int64, month string, noPlan bool) (int64, error) {
var query = this.Query(tx)
query.Table(this.partialTable(serverId))
if regionId > 0 {
query.Attr("regionId", regionId)
}
if noPlan {
query.Attr("userPlanId", 0)
}
return query.Between("day", month+"01", month+"31").
Attr("serverId", serverId).
SumInt64("totalBytes", 0)
}
// FindDistinctServerIdsWithoutPlanAtPartition 查找没有绑定套餐的有流量网站
func (this *ServerBandwidthStatDAO) FindDistinctServerIdsWithoutPlanAtPartition(tx *dbs.Tx, partitionIndex int, month string) (serverIds []int64, err error) {
ones, err := this.Query(tx).
Table(this.partialTable(int64(partitionIndex))).
Between("day", month+"01", month+"31").
Attr("userPlanId", 0). // 没有绑定套餐
Result("DISTINCT serverId").
FindAll()
if err != nil {
return nil, err
}
for _, one := range ones {
var serverId = int64(one.(*ServerBandwidthStat).ServerId)
if serverId <= 0 {
continue
}
serverIds = append(serverIds, serverId)
}
return
}
// CountPartitions 查看分区数量
func (this *ServerBandwidthStatDAO) CountPartitions() int {
return ServerBandwidthStatTablePartitions
}
// CleanDays 清理过期数据 // CleanDays 清理过期数据
func (this *ServerBandwidthStatDAO) CleanDays(tx *dbs.Tx, days int) error { func (this *ServerBandwidthStatDAO) CleanDays(tx *dbs.Tx, days int) error {
var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -days)) // 保留大约3个月的数据 var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -days)) // 保留大约3个月的数据
@@ -777,9 +858,9 @@ func (this *ServerBandwidthStatDAO) CleanDefaultDays(tx *dbs.Tx, defaultDays int
func (this *ServerBandwidthStatDAO) runBatch(f func(table string, locker *sync.Mutex) error) error { func (this *ServerBandwidthStatDAO) runBatch(f func(table string, locker *sync.Mutex) error) error {
var locker = &sync.Mutex{} var locker = &sync.Mutex{}
var wg = sync.WaitGroup{} var wg = sync.WaitGroup{}
wg.Add(ServerBandwidthStatTablePartials) wg.Add(ServerBandwidthStatTablePartitions)
var resultErr error var resultErr error
for i := 0; i < ServerBandwidthStatTablePartials; i++ { for i := 0; i < ServerBandwidthStatTablePartitions; i++ {
var table = this.partialTable(int64(i)) var table = this.partialTable(int64(i))
go func(table string) { go func(table string) {
defer wg.Done() defer wg.Done()
@@ -796,7 +877,7 @@ func (this *ServerBandwidthStatDAO) runBatch(f func(table string, locker *sync.M
// 获取分区表 // 获取分区表
func (this *ServerBandwidthStatDAO) partialTable(serverId int64) string { func (this *ServerBandwidthStatDAO) partialTable(serverId int64) string {
return this.Table + "_" + types.String(serverId%int64(ServerBandwidthStatTablePartials)) return this.Table + "_" + types.String(serverId%int64(ServerBandwidthStatTablePartitions))
} }
// 获取字节字段 // 获取字节字段

View File

@@ -16,7 +16,7 @@ import (
func TestServerBandwidthStatDAO_UpdateServerBandwidth(t *testing.T) { func TestServerBandwidthStatDAO_UpdateServerBandwidth(t *testing.T) {
var dao = models.NewServerBandwidthStatDAO() var dao = models.NewServerBandwidthStatDAO()
var tx *dbs.Tx var tx *dbs.Tx
err := dao.UpdateServerBandwidth(tx, 1, 1, 0, timeutil.Format("Ymd"), timeutil.FormatTime("Hi", time.Now().Unix()/300*300), 1024, 300, 0, 0, 0, 0, 0) err := dao.UpdateServerBandwidth(tx, 1, 1, 0, 0, timeutil.Format("Ymd"), timeutil.FormatTime("Hi", time.Now().Unix()/300*300), 1024, 300, 0, 0, 0, 0, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -33,7 +33,7 @@ func TestSeverBandwidthStatDAO_InsertManyStats(t *testing.T) {
} }
var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -rands.Int(0, 200))) var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -rands.Int(0, 200)))
var minute = fmt.Sprintf("%02d%02d", rands.Int(0, 23), rands.Int(0, 59)) var minute = fmt.Sprintf("%02d%02d", rands.Int(0, 23), rands.Int(0, 59))
err := dao.UpdateServerBandwidth(tx, 1, int64(rands.Int(1, 10000)), 0, day, minute, 1024, 300, 0, 0, 0, 0, 0) err := dao.UpdateServerBandwidth(tx, 1, int64(rands.Int(1, 10000)), 0, 0, day, minute, 1024, 300, 0, 0, 0, 0, 0)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -44,8 +44,9 @@ func TestSeverBandwidthStatDAO_InsertManyStats(t *testing.T) {
func TestServerBandwidthStatDAO_FindMonthlyPercentile(t *testing.T) { func TestServerBandwidthStatDAO_FindMonthlyPercentile(t *testing.T) {
var dao = models.NewServerBandwidthStatDAO() var dao = models.NewServerBandwidthStatDAO()
var tx *dbs.Tx var tx *dbs.Tx
t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, false)) t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, false, false))
t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, true)) t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, true, false))
t.Log(dao.FindMonthlyPercentile(tx, 23, timeutil.Format("Ym"), 95, true, true))
} }
func TestServerBandwidthStatDAO_FindAllServerStatsWithMonth(t *testing.T) { func TestServerBandwidthStatDAO_FindAllServerStatsWithMonth(t *testing.T) {
@@ -114,3 +115,32 @@ func TestServerBandwidthStatDAO_FindBandwidthStatsBetweenDays(t *testing.T) {
t.Log(stat.Day, stat.TimeAt, "bytes:", stat.Bytes, "bits:", stat.Bits) t.Log(stat.Day, stat.TimeAt, "bytes:", stat.Bytes, "bits:", stat.Bits)
} }
} }
func TestServerBandwidthStatDAO_SumServerMonthlyWithRegion(t *testing.T) {
var dao = models.NewServerBandwidthStatDAO()
var tx *dbs.Tx
{
totalBytes, err := dao.SumServerMonthlyWithRegion(tx, 23, 0, timeutil.Format("Ym"), false)
if err != nil {
t.Fatal(err)
}
t.Log("with plan:", totalBytes)
}
{
totalBytes, err := dao.SumServerMonthlyWithRegion(tx, 23, 0, timeutil.Format("Ym"), true)
if err != nil {
t.Fatal(err)
}
t.Log("without plan:", totalBytes)
}
}
func TestServerBandwidthStatDAO_SumMonthlyBytes(t *testing.T) {
var dao = models.NewServerBandwidthStatDAO()
var tx *dbs.Tx
totalBytes, err := dao.SumMonthlyBytes(tx, 23, timeutil.Format("Ym"), false)
if err != nil {
t.Fatal(err)
}
t.Log("total bytes:", totalBytes)
}

View File

@@ -1,11 +1,32 @@
package models package models
import "github.com/iwind/TeaGo/dbs"
const (
ServerBandwidthStatField_Id dbs.FieldName = "id" // ID
ServerBandwidthStatField_UserId dbs.FieldName = "userId" // 用户ID
ServerBandwidthStatField_ServerId dbs.FieldName = "serverId" // 服务ID
ServerBandwidthStatField_RegionId dbs.FieldName = "regionId" // 区域ID
ServerBandwidthStatField_UserPlanId dbs.FieldName = "userPlanId" // 用户套餐ID
ServerBandwidthStatField_Day dbs.FieldName = "day" // 日期YYYYMMDD
ServerBandwidthStatField_TimeAt dbs.FieldName = "timeAt" // 时间点HHMM
ServerBandwidthStatField_Bytes dbs.FieldName = "bytes" // 带宽字节
ServerBandwidthStatField_AvgBytes dbs.FieldName = "avgBytes" // 平均流量
ServerBandwidthStatField_CachedBytes dbs.FieldName = "cachedBytes" // 缓存的流量
ServerBandwidthStatField_AttackBytes dbs.FieldName = "attackBytes" // 攻击流量
ServerBandwidthStatField_CountRequests dbs.FieldName = "countRequests" // 请求数
ServerBandwidthStatField_CountCachedRequests dbs.FieldName = "countCachedRequests" // 缓存的请求数
ServerBandwidthStatField_CountAttackRequests dbs.FieldName = "countAttackRequests" // 攻击请求数
ServerBandwidthStatField_TotalBytes dbs.FieldName = "totalBytes" // 总流量
)
// ServerBandwidthStat 服务峰值带宽统计 // ServerBandwidthStat 服务峰值带宽统计
type ServerBandwidthStat struct { type ServerBandwidthStat struct {
Id uint64 `field:"id"` // ID Id uint64 `field:"id"` // ID
UserId uint64 `field:"userId"` // 用户ID UserId uint64 `field:"userId"` // 用户ID
ServerId uint64 `field:"serverId"` // 服务ID ServerId uint64 `field:"serverId"` // 服务ID
RegionId uint32 `field:"regionId"` // 区域ID RegionId uint32 `field:"regionId"` // 区域ID
UserPlanId uint64 `field:"userPlanId"` // 用户套餐ID
Day string `field:"day"` // 日期YYYYMMDD Day string `field:"day"` // 日期YYYYMMDD
TimeAt string `field:"timeAt"` // 时间点HHMM TimeAt string `field:"timeAt"` // 时间点HHMM
Bytes uint64 `field:"bytes"` // 带宽字节 Bytes uint64 `field:"bytes"` // 带宽字节
@@ -23,6 +44,7 @@ type ServerBandwidthStatOperator struct {
UserId any // 用户ID UserId any // 用户ID
ServerId any // 服务ID ServerId any // 服务ID
RegionId any // 区域ID RegionId any // 区域ID
UserPlanId any // 用户套餐ID
Day any // 日期YYYYMMDD Day any // 日期YYYYMMDD
TimeAt any // 时间点HHMM TimeAt any // 时间点HHMM
Bytes any // 带宽字节 Bytes any // 带宽字节

View File

@@ -119,7 +119,7 @@ func (this *ServerDailyStatDAO) SaveStats(tx *dbs.Tx, stats []*pb.ServerDailySta
// 更新流量限制状态 // 更新流量限制状态
if stat.CheckTrafficLimiting { if stat.CheckTrafficLimiting {
trafficLimitConfig, err := SharedServerDAO.CalculateServerTrafficLimitConfig(tx, stat.ServerId, cacheMap) trafficLimitConfig, err := SharedServerDAO.FindServerTrafficLimitConfig(tx, stat.ServerId, cacheMap)
if err != nil { if err != nil {
return err return err
} }
@@ -129,7 +129,7 @@ func (this *ServerDailyStatDAO) SaveStats(tx *dbs.Tx, stats []*pb.ServerDailySta
return err return err
} }
err = SharedServerDAO.UpdateServerTrafficLimitStatus(tx, trafficLimitConfig, stat.ServerId, false) err = SharedServerDAO.RenewServerTrafficLimitStatus(tx, trafficLimitConfig, stat.ServerId, false)
if err != nil { if err != nil {
return err return err
} }
@@ -140,6 +140,7 @@ func (this *ServerDailyStatDAO) SaveStats(tx *dbs.Tx, stats []*pb.ServerDailySta
return nil return nil
} }
// SumCurrentDailyStat 查找当前时刻的数据统计 // SumCurrentDailyStat 查找当前时刻的数据统计
func (this *ServerDailyStatDAO) SumCurrentDailyStat(tx *dbs.Tx, serverId int64) (*ServerDailyStat, error) { func (this *ServerDailyStatDAO) SumCurrentDailyStat(tx *dbs.Tx, serverId int64) (*ServerDailyStat, error) {
var day = timeutil.Format("Ymd") var day = timeutil.Format("Ymd")
@@ -164,7 +165,7 @@ func (this *ServerDailyStatDAO) SumServerMonthlyWithRegion(tx *dbs.Tx, serverId
if regionId > 0 { if regionId > 0 {
query.Attr("regionId", regionId) query.Attr("regionId", regionId)
} }
return query.Between("day", month+"01", month+"32"). return query.Between("day", month+"01", month+"31").
Attr("serverId", serverId). Attr("serverId", serverId).
SumInt64("bytes", 0) SumInt64("bytes", 0)
} }
@@ -178,7 +179,7 @@ func (this *ServerDailyStatDAO) SumUserMonthlyWithoutPlan(tx *dbs.Tx, userId int
} }
return query. return query.
Attr("planId", 0). Attr("planId", 0).
Between("day", month+"01", month+"32"). Between("day", month+"01", month+"31").
Attr("userId", userId). Attr("userId", userId).
SumInt64("bytes", 0) SumInt64("bytes", 0)
} }
@@ -190,7 +191,7 @@ func (this *ServerDailyStatDAO) SumUserMonthlyPeek(tx *dbs.Tx, userId int64, reg
if regionId > 0 { if regionId > 0 {
query.Attr("regionId", regionId) query.Attr("regionId", regionId)
} }
max, err := query.Between("day", month+"01", month+"32"). max, err := query.Between("day", month+"01", month+"31").
Attr("userId", userId). Attr("userId", userId).
Max("bytes", 0) Max("bytes", 0)
if err != nil { if err != nil {
@@ -644,7 +645,7 @@ func (this *ServerDailyStatDAO) FindStatsBetweenDays(tx *dbs.Tx, userId int64, s
// month YYYYMM // month YYYYMM
func (this *ServerDailyStatDAO) FindMonthlyStatsWithPlan(tx *dbs.Tx, month string) (result []*ServerDailyStat, err error) { func (this *ServerDailyStatDAO) FindMonthlyStatsWithPlan(tx *dbs.Tx, month string) (result []*ServerDailyStat, err error) {
_, err = this.Query(tx). _, err = this.Query(tx).
Between("day", month+"01", month+"32"). Between("day", month+"01", month+"31").
Gt("planId", 0). Gt("planId", 0).
Slice(&result). Slice(&result).
FindAll() FindAll()

View File

@@ -3,11 +3,13 @@ package models
import ( import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const" teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
dbutils "github.com/TeaOSLab/EdgeAPI/internal/db/utils" dbutils "github.com/TeaOSLab/EdgeAPI/internal/db/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils" "github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/regexputils"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
@@ -782,7 +784,7 @@ func (this *ServerDAO) CountAllEnabledServers(tx *dbs.Tx) (int64, error) {
// 参数: // 参数:
// //
// groupId 分组ID如果为-1则搜索没有分组的服务 // groupId 分组ID如果为-1则搜索没有分组的服务
func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState, protocolFamilies []string) (int64, error) { func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState, protocolFamilies []string, userPlanId int64) (int64, error) {
query := this.Query(tx). query := this.Query(tx).
State(ServerStateEnabled) State(ServerStateEnabled)
if groupId > 0 { if groupId > 0 {
@@ -829,6 +831,10 @@ func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, ke
query.Where("(" + strings.Join(protocolConds, " OR ") + ")") query.Where("(" + strings.Join(protocolConds, " OR ") + ")")
} }
if userPlanId > 0 {
query.Attr("userPlanId", userPlanId)
}
return query.Count() return query.Count()
} }
@@ -1316,12 +1322,13 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server, ignoreCer
} }
// 套餐是否依然有效 // 套餐是否依然有效
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId)) plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if plan != nil { if plan != nil {
config.UserPlan = &serverconfigs.UserPlanConfig{ config.UserPlan = &serverconfigs.UserPlanConfig{
Id: int64(userPlan.Id),
DayTo: userPlan.DayTo, DayTo: userPlan.DayTo,
Plan: &serverconfigs.PlanConfig{ Plan: &serverconfigs.PlanConfig{
Id: int64(plan.Id), Id: int64(plan.Id),
@@ -1341,7 +1348,6 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server, ignoreCer
} }
} }
if config.TrafficLimit != nil && config.TrafficLimit.IsOn && !config.TrafficLimit.IsEmpty() {
if len(server.TrafficLimitStatus) > 0 { if len(server.TrafficLimitStatus) > 0 {
var status = &serverconfigs.TrafficLimitStatus{} var status = &serverconfigs.TrafficLimitStatus{}
err := json.Unmarshal(server.TrafficLimitStatus, status) err := json.Unmarshal(server.TrafficLimitStatus, status)
@@ -1352,7 +1358,6 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server, ignoreCer
config.TrafficLimitStatus = status config.TrafficLimitStatus = status
} }
} }
}
// UAM // UAM
if !forList { if !forList {
@@ -1794,6 +1799,7 @@ func (this *ServerDAO) FindServerUserId(tx *dbs.Tx, serverId int64) (userId int6
} }
// FindServerUserPlanId 查找服务的套餐ID // FindServerUserPlanId 查找服务的套餐ID
// TODO 需要缓存
func (this *ServerDAO) FindServerUserPlanId(tx *dbs.Tx, serverId int64) (userPlanId int64, err error) { func (this *ServerDAO) FindServerUserPlanId(tx *dbs.Tx, serverId int64) (userPlanId int64, err error) {
return this.Query(tx). return this.Query(tx).
Pk(serverId). Pk(serverId).
@@ -2306,95 +2312,18 @@ func (this *ServerDAO) FindServerTrafficLimitConfig(tx *dbs.Tx, serverId int64,
return nil, err return nil, err
} }
var limit = &serverconfigs.TrafficLimitConfig{}
if serverOne == nil {
return limit, nil
}
var trafficLimit = serverOne.(*Server).TrafficLimit
if len(trafficLimit) > 0 {
err = json.Unmarshal([]byte(trafficLimit), limit)
if err != nil {
return nil, err
}
}
if cacheMap != nil {
cacheMap.Put(cacheKey, limit)
}
return limit, nil
}
// CalculateServerTrafficLimitConfig 计算服务的流量限制
// TODO 优化性能
func (this *ServerDAO) CalculateServerTrafficLimitConfig(tx *dbs.Tx, serverId int64, cacheMap *utils.CacheMap) (*serverconfigs.TrafficLimitConfig, error) {
if cacheMap == nil {
cacheMap = utils.NewCacheMap()
}
var cacheKey = this.Table + ":FindServerTrafficLimitConfig:" + types.String(serverId)
result, ok := cacheMap.Get(cacheKey)
if ok {
return result.(*serverconfigs.TrafficLimitConfig), nil
}
serverOne, err := this.Query(tx).
Pk(serverId).
Result("trafficLimit", "userPlanId").
Find()
if err != nil {
return nil, err
}
var limitConfig = &serverconfigs.TrafficLimitConfig{} var limitConfig = &serverconfigs.TrafficLimitConfig{}
if serverOne == nil { if serverOne == nil {
return limitConfig, nil return limitConfig, nil
} }
var trafficLimit = serverOne.(*Server).TrafficLimit var trafficLimitJSON = serverOne.(*Server).TrafficLimit
var userPlanId = int64(serverOne.(*Server).UserPlanId)
if len(trafficLimit) == 0 { if len(trafficLimitJSON) > 0 {
if userPlanId > 0 { err = json.Unmarshal(trafficLimitJSON, limitConfig)
userPlan, err := SharedUserPlanDAO.FindEnabledUserPlan(tx, userPlanId, cacheMap)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if userPlan != nil {
planLimit, err := SharedPlanDAO.FindEnabledPlanTrafficLimit(tx, int64(userPlan.PlanId), cacheMap)
if err != nil {
return nil, err
}
if planLimit != nil {
return planLimit, nil
}
}
}
return limitConfig, nil
}
err = json.Unmarshal(trafficLimit, limitConfig)
if err != nil {
return nil, err
}
if !limitConfig.IsOn {
if userPlanId > 0 {
userPlan, err := SharedUserPlanDAO.FindEnabledUserPlan(tx, userPlanId, cacheMap)
if err != nil {
return nil, err
}
if userPlan != nil {
planLimit, err := SharedPlanDAO.FindEnabledPlanTrafficLimit(tx, int64(userPlan.PlanId), cacheMap)
if err != nil {
return nil, err
}
if planLimit != nil {
return planLimit, nil
}
}
}
} }
if cacheMap != nil { if cacheMap != nil {
@@ -2423,11 +2352,11 @@ func (this *ServerDAO) UpdateServerTrafficLimitConfig(tx *dbs.Tx, serverId int64
} }
// 更新状态 // 更新状态
return this.UpdateServerTrafficLimitStatus(tx, trafficLimitConfig, serverId, true) return this.RenewServerTrafficLimitStatus(tx, trafficLimitConfig, serverId, true)
} }
// UpdateServerTrafficLimitStatus 修改服务的流量限制状态 // RenewServerTrafficLimitStatus 根据限流配置更新网站的流量限制状态
func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitConfig *serverconfigs.TrafficLimitConfig, serverId int64, isUpdatingConfig bool) error { func (this *ServerDAO) RenewServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitConfig *serverconfigs.TrafficLimitConfig, serverId int64, isUpdatingConfig bool) error {
if !trafficLimitConfig.IsOn { if !trafficLimitConfig.IsOn {
if isUpdatingConfig { if isUpdatingConfig {
return this.NotifyUpdate(tx, serverId) return this.NotifyUpdate(tx, serverId)
@@ -2464,9 +2393,11 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
var untilDay = "" var untilDay = ""
// daily // daily
var dateType = ""
if trafficLimitConfig.DailyBytes() > 0 { if trafficLimitConfig.DailyBytes() > 0 {
if server.TrafficDay == timeutil.Format("Ymd") && server.TotalDailyTraffic >= float64(trafficLimitConfig.DailyBytes())/(1<<30) { if server.TrafficDay == timeutil.Format("Ymd") && server.TotalDailyTraffic >= float64(trafficLimitConfig.DailyBytes())/(1<<30) {
untilDay = timeutil.Format("Ymd") untilDay = timeutil.Format("Ymd")
dateType = "day"
} }
} }
@@ -2474,6 +2405,7 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
if server.TrafficMonth == timeutil.Format("Ym") && trafficLimitConfig.MonthlyBytes() > 0 { if server.TrafficMonth == timeutil.Format("Ym") && trafficLimitConfig.MonthlyBytes() > 0 {
if server.TotalMonthlyTraffic >= float64(trafficLimitConfig.MonthlyBytes())/(1<<30) { if server.TotalMonthlyTraffic >= float64(trafficLimitConfig.MonthlyBytes())/(1<<30) {
untilDay = timeutil.Format("Ym32") untilDay = timeutil.Format("Ym32")
dateType = "month"
} }
} }
@@ -2481,12 +2413,16 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
if trafficLimitConfig.TotalBytes() > 0 { if trafficLimitConfig.TotalBytes() > 0 {
if server.TotalTraffic >= float64(trafficLimitConfig.TotalBytes())/(1<<30) { if server.TotalTraffic >= float64(trafficLimitConfig.TotalBytes())/(1<<30) {
untilDay = "30000101" untilDay = "30000101"
dateType = "total"
} }
} }
var isChanged = oldStatus.UntilDay != untilDay var isChanged = oldStatus.UntilDay != untilDay
if isChanged { if isChanged {
statusJSON, err := json.Marshal(&serverconfigs.TrafficLimitStatus{UntilDay: untilDay}) statusJSON, err := json.Marshal(&serverconfigs.TrafficLimitStatus{
UntilDay: untilDay,
DateType: dateType,
})
if err != nil { if err != nil {
return err return err
} }
@@ -2507,6 +2443,90 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
return nil return nil
} }
// UpdateServerTrafficLimitStatus 修改网站的流量限制状态
func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, serverId int64, day string, planId int64, dateType string) error {
if !regexputils.YYYYMMDD.MatchString(day) {
return errors.New("invalid 'day' format")
}
if serverId <= 0 {
return nil
}
// lookup old status
statusJSON, err := this.Query(tx).
Pk(serverId).
Result(ServerField_TrafficLimitStatus).
FindJSONCol()
if err != nil {
return err
}
if IsNotNull(statusJSON) {
var oldStatus = &serverconfigs.TrafficLimitStatus{}
err = json.Unmarshal(statusJSON, oldStatus)
if err != nil {
return err
}
if len(oldStatus.UntilDay) > 0 && oldStatus.UntilDay >= day /** 如果已经限制,且比当前日期长,则无需重复 **/ {
// no need to change
return nil
}
}
var status = &serverconfigs.TrafficLimitStatus{
UntilDay: day,
PlanId: planId,
DateType: dateType,
}
statusJSON, err = json.Marshal(status)
if err != nil {
return err
}
err = this.Query(tx).
Pk(serverId).
Set(ServerField_TrafficLimitStatus, statusJSON).
UpdateQuickly()
if err != nil {
return err
}
return this.NotifyUpdate(tx, serverId)
}
// UpdateServersTrafficLimitStatusWithUserPlanId 修改某个套餐下的网站的流量限制状态
func (this *ServerDAO) UpdateServersTrafficLimitStatusWithUserPlanId(tx *dbs.Tx, userPlanId int64, day string, planId int64, dateType string) error {
if userPlanId <= 0 {
return nil
}
servers, err := this.Query(tx).
State(ServerStateEnabled).
Attr("userPlanId", userPlanId).
ResultPk().
FindAll()
if err != nil {
return err
}
for _, server := range servers {
var serverId = int64(server.(*Server).Id)
err = this.UpdateServerTrafficLimitStatus(tx, serverId, day, planId, dateType)
if err != nil {
return err
}
}
return nil
}
// ResetServersTrafficLimitStatusWithPlanId 重置网站限流状态
func (this *ServerDAO) ResetServersTrafficLimitStatusWithPlanId(tx *dbs.Tx, planId int64) error {
return this.Query(tx).
Where("JSON_EXTRACT(trafficLimitStatus, '$.planId')=:planId").
Param("planId", planId).
Set("trafficLimitStatus", dbs.SQL("NULL")).
UpdateQuickly()
}
// IncreaseServerTotalTraffic 增加服务的总流量 // IncreaseServerTotalTraffic 增加服务的总流量
func (this *ServerDAO) IncreaseServerTotalTraffic(tx *dbs.Tx, serverId int64, bytes int64) error { func (this *ServerDAO) IncreaseServerTotalTraffic(tx *dbs.Tx, serverId int64, bytes int64) error {
if serverId <= 0 { if serverId <= 0 {
@@ -2548,17 +2568,16 @@ func (this *ServerDAO) FindEnabledServerIdWithUserPlanId(tx *dbs.Tx, userPlanId
FindInt64Col(0) FindInt64Col(0)
} }
// FindEnabledServerWithUserPlanId 查找使用某个套餐的服务 // FindEnabledServersWithUserPlanId 查找使用某个套餐的网站
func (this *ServerDAO) FindEnabledServerWithUserPlanId(tx *dbs.Tx, userPlanId int64) (*Server, error) { func (this *ServerDAO) FindEnabledServersWithUserPlanId(tx *dbs.Tx, userPlanId int64) (result []*Server, err error) {
one, err := this.Query(tx). _, err = this.Query(tx).
State(ServerStateEnabled). State(ServerStateEnabled).
Attr("userPlanId", userPlanId). Attr("userPlanId", userPlanId).
Result("id", "name", "serverNames", "type"). Result("id", "name", "serverNames", "type").
Find() AscPk().
if err != nil || one == nil { Slice(&result).
return nil, err FindAll()
} return
return one.(*Server), nil
} }
// UpdateServersClusterIdWithPlanId 修改套餐所在集群 // UpdateServersClusterIdWithPlanId 修改套餐所在集群
@@ -2643,7 +2662,7 @@ func (this *ServerDAO) UpdateServerUserPlanId(tx *dbs.Tx, serverId int64, userPl
return errors.New("can not find user plan with id '" + types.String(userPlanId) + "'") return errors.New("can not find user plan with id '" + types.String(userPlanId) + "'")
} }
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId)) plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), nil)
if err != nil { if err != nil {
return err return err
} }
@@ -2881,6 +2900,89 @@ func (this *ServerDAO) FindEnabledServersWithIds(tx *dbs.Tx, serverIds []int64)
return return
} }
// CountAllServerNamesWithUserId 计算某个用户下的所有域名数
func (this *ServerDAO) CountAllServerNamesWithUserId(tx *dbs.Tx, userId int64, userPlanId int64) (int64, error) {
if userId <= 0 {
return 0, nil
}
var query = this.Query(tx).
Attr("userId", userId).
State(ServerStateEnabled).
Where("JSON_TYPE(plainServerNames)='ARRAY'")
if userPlanId > 0 {
query.Attr("userPlanId", userPlanId)
}
return query.
SumInt64("JSON_LENGTH(plainServerNames)", 0)
}
// CountServerNames 计算某个网站下的所有域名数
func (this *ServerDAO) CountServerNames(tx *dbs.Tx, serverId int64) (int64, error) {
if serverId <= 0 {
return 0, nil
}
return this.Query(tx).
Result("JSON_LENGTH(plainServerNames)").
Pk(serverId).
State(ServerStateEnabled).
Where("JSON_TYPE(plainServerNames)='ARRAY'").
FindInt64Col(0)
}
// CheckServerPlanQuota 检查网站套餐限制
func (this *ServerDAO) CheckServerPlanQuota(tx *dbs.Tx, serverId int64, countServerNames int) error {
if serverId <= 0 {
return errors.New("invalid 'serverId'")
}
if countServerNames <= 0 {
return nil
}
userPlanId, err := this.FindServerUserPlanId(tx, serverId)
if err != nil {
return err
}
if userPlanId <= 0 {
return nil
}
userPlan, err := SharedUserPlanDAO.FindEnabledUserPlan(tx, userPlanId, nil)
if err != nil {
return err
}
if userPlan == nil {
return fmt.Errorf("invalid user plan with id %q", types.String(userPlanId))
}
if userPlan.IsExpired() {
return errors.New("the user plan has been expired")
}
if userPlan.UserId == 0 {
return nil
}
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), nil)
if err != nil {
return err
}
if plan == nil {
return fmt.Errorf("invalid plan with id %q", types.String(userPlan.PlanId))
}
if plan.TotalServerNames > 0 {
totalServerNames, err := this.CountAllServerNamesWithUserId(tx, int64(userPlan.UserId), userPlanId)
if err != nil {
return err
}
if totalServerNames+int64(countServerNames) > int64(plan.TotalServerNames) {
return errors.New("server names over plan quota")
}
}
if plan.TotalServerNamesPerServer > 0 {
if countServerNames > types.Int(plan.TotalServerNamesPerServer) {
return errors.New("server names per server over plan quota")
}
}
return nil
}
// NotifyUpdate 同步服务所在的集群 // NotifyUpdate 同步服务所在的集群
func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error { func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error {
if serverId <= 0 { if serverId <= 0 {

View File

@@ -242,7 +242,7 @@ func TestServerDAO_FindEnabledServerWithDomain(t *testing.T) {
} }
} }
func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) { func TestServerDAO_RenewServerTrafficLimitStatus(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
var tx *dbs.Tx var tx *dbs.Tx
@@ -250,7 +250,7 @@ func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) {
defer func() { defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms") t.Log(time.Since(before).Seconds()*1000, "ms")
}() }()
err := models.NewServerDAO().UpdateServerTrafficLimitStatus(tx, &serverconfigs.TrafficLimitConfig{ err := models.NewServerDAO().RenewServerTrafficLimitStatus(tx, &serverconfigs.TrafficLimitConfig{
IsOn: true, IsOn: true,
DailySize: &shared.SizeCapacity{Count: 1, Unit: "mb"}, DailySize: &shared.SizeCapacity{Count: 1, Unit: "mb"},
MonthlySize: &shared.SizeCapacity{Count: 10, Unit: "mb"}, MonthlySize: &shared.SizeCapacity{Count: 10, Unit: "mb"},
@@ -263,40 +263,15 @@ func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) {
t.Log("ok") t.Log("ok")
} }
func TestServerDAO_CalculateServerTrafficLimitConfig(t *testing.T) { func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
var dao = models.NewServerDAO()
var tx *dbs.Tx var tx *dbs.Tx
before := time.Now() err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day")
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
var cacheMap = utils.NewCacheMap()
config, err := models.SharedServerDAO.CalculateServerTrafficLimitConfig(tx, 23, cacheMap)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
logs.PrintAsJSON(config, t)
}
func TestServerDAO_CalculateServerTrafficLimitConfig_Cache(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
before := time.Now()
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
var cacheMap = utils.NewCacheMap()
for i := 0; i < 10; i++ {
config, err := models.SharedServerDAO.CalculateServerTrafficLimitConfig(tx, 23, cacheMap)
if err != nil {
t.Fatal(err)
}
_ = config
}
} }
func TestServerDAO_FindBytes(t *testing.T) { func TestServerDAO_FindBytes(t *testing.T) {

View File

@@ -0,0 +1,239 @@
package models
import (
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/goman"
"github.com/TeaOSLab/EdgeAPI/internal/remotelogs"
"github.com/TeaOSLab/EdgeAPI/internal/utils/regexputils"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
timeutil "github.com/iwind/TeaGo/utils/time"
"math"
"sync"
"time"
)
type UserPlanBandwidthStatDAO dbs.DAO
const (
UserPlanBandwidthStatTablePartitions = 20 // 分表数量
)
func init() {
dbs.OnReadyDone(func() {
// 清理数据任务
var ticker = time.NewTicker(time.Duration(rands.Int(24, 48)) * time.Hour)
goman.New(func() {
for range ticker.C {
err := SharedUserPlanBandwidthStatDAO.CleanDefaultDays(nil, 100)
if err != nil {
remotelogs.Error("SharedUserPlanBandwidthStatDAO", "clean expired data failed: "+err.Error())
}
}
})
})
}
func NewUserPlanBandwidthStatDAO() *UserPlanBandwidthStatDAO {
return dbs.NewDAO(&UserPlanBandwidthStatDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeUserPlanBandwidthStats",
Model: new(UserPlanBandwidthStat),
PkName: "id",
},
}).(*UserPlanBandwidthStatDAO)
}
var SharedUserPlanBandwidthStatDAO *UserPlanBandwidthStatDAO
func init() {
dbs.OnReady(func() {
SharedUserPlanBandwidthStatDAO = NewUserPlanBandwidthStatDAO()
})
}
// UpdateUserPlanBandwidth 写入数据
// 暂时不使用region区分
func (this *UserPlanBandwidthStatDAO) UpdateUserPlanBandwidth(tx *dbs.Tx, userId int64, userPlanId int64, regionId int64, day string, timeAt string, bandwidthBytes int64, totalBytes int64, cachedBytes int64, attackBytes int64, countRequests int64, countCachedRequests int64, countAttackRequests int64) error {
if userId <= 0 || userPlanId <= 0 {
return nil
}
return this.Query(tx).
Table(this.partialTable(userPlanId)).
Param("bytes", bandwidthBytes).
Param("totalBytes", totalBytes).
Param("cachedBytes", cachedBytes).
Param("attackBytes", attackBytes).
Param("countRequests", countRequests).
Param("countCachedRequests", countCachedRequests).
Param("countAttackRequests", countAttackRequests).
InsertOrUpdateQuickly(maps.Map{
"userId": userId,
"userPlanId": userPlanId,
"regionId": regionId,
"day": day,
"timeAt": timeAt,
"bytes": bandwidthBytes,
"totalBytes": totalBytes,
"avgBytes": totalBytes / 300,
"cachedBytes": cachedBytes,
"attackBytes": attackBytes,
"countRequests": countRequests,
"countCachedRequests": countCachedRequests,
"countAttackRequests": countAttackRequests,
}, maps.Map{
"bytes": dbs.SQL("bytes+:bytes"),
"avgBytes": dbs.SQL("(totalBytes+:totalBytes)/300"), // 因为生成SQL语句时会自动将avgBytes排在totalBytes之前所以这里不用担心先后顺序的问题
"totalBytes": dbs.SQL("totalBytes+:totalBytes"),
"cachedBytes": dbs.SQL("cachedBytes+:cachedBytes"),
"attackBytes": dbs.SQL("attackBytes+:attackBytes"),
"countRequests": dbs.SQL("countRequests+:countRequests"),
"countCachedRequests": dbs.SQL("countCachedRequests+:countCachedRequests"),
"countAttackRequests": dbs.SQL("countAttackRequests+:countAttackRequests"),
})
}
// FindMonthlyPercentile 获取某月内百分位
func (this *UserPlanBandwidthStatDAO) FindMonthlyPercentile(tx *dbs.Tx, userPlanId int64, month string, percentile int, useAvg bool) (result int64, err error) {
if percentile <= 0 {
percentile = 95
}
// 如果是100%以上,则快速返回
if percentile >= 100 {
result, err = this.Query(tx).
Table(this.partialTable(userPlanId)).
Attr("userPlanId", userPlanId).
Result(this.sumBytesField(useAvg)).
Between("day", month+"01", month+"31").
Group("day").
Group("timeAt").
Desc("bytes").
Limit(1).
FindInt64Col(0)
return
}
// 总数量
total, err := this.Query(tx).
Table(this.partialTable(userPlanId)).
Attr("userPlanId", userPlanId).
Between("day", month+"01", month+"31").
CountAttr("DISTINCT day, timeAt")
if err != nil {
return 0, err
}
if total == 0 {
return 0, nil
}
var offset int64
if total > 1 {
offset = int64(math.Ceil(float64(total) * float64(100-percentile) / 100))
}
// 查询 nth 位置
result, err = this.Query(tx).
Table(this.partialTable(userPlanId)).
Attr("userPlanId", userPlanId).
Result(this.sumBytesField(useAvg)).
Between("day", month+"01", month+"31").
Group("day").
Group("timeAt").
Desc("bytes").
Offset(offset).
Limit(1).
FindInt64Col(0)
return
}
// SumMonthlyBytes 读取单月总流量
func (this *UserPlanBandwidthStatDAO) SumMonthlyBytes(tx *dbs.Tx, userPlanId int64, month string) (int64, error) {
if !regexputils.YYYYMM.MatchString(month) {
return 0, errors.New("invalid ")
}
return this.Query(tx).
Table(this.partialTable(userPlanId)).
Attr("userPlanId", userPlanId).
Between("day", month+"01", month+"31").
SumInt64("totalBytes", 0)
}
// CleanDefaultDays 清理过期数据
func (this *UserPlanBandwidthStatDAO) CleanDefaultDays(tx *dbs.Tx, defaultDays int) error {
databaseConfig, err := SharedSysSettingDAO.ReadDatabaseConfig(tx)
if err != nil {
return err
}
if databaseConfig != nil && databaseConfig.UserPlanBandwidthStat.Clean.Days > 0 {
defaultDays = databaseConfig.UserPlanBandwidthStat.Clean.Days
}
if defaultDays <= 0 {
defaultDays = 100
}
return this.CleanDays(tx, defaultDays)
}
// CleanDays 清理过期数据
func (this *UserPlanBandwidthStatDAO) CleanDays(tx *dbs.Tx, days int) error {
var day = timeutil.Format("Ymd", time.Now().AddDate(0, 0, -days)) // 保留大约3个月的数据
return this.runBatch(func(table string, locker *sync.Mutex) error {
_, err := this.Query(tx).
Table(table).
Lt("day", day).
Delete()
return err
})
}
// 获取字节字段
func (this *UserPlanBandwidthStatDAO) bytesField(useAvg bool) string {
if useAvg {
return "avgBytes AS bytes"
}
return "bytes"
}
func (this *UserPlanBandwidthStatDAO) sumBytesField(useAvg bool) string {
if useAvg {
return "SUM(avgBytes) AS bytes"
}
return "SUM(bytes) AS bytes"
}
// 批量执行
func (this *UserPlanBandwidthStatDAO) runBatch(f func(table string, locker *sync.Mutex) error) error {
var locker = &sync.Mutex{}
var wg = sync.WaitGroup{}
wg.Add(UserPlanBandwidthStatTablePartitions)
var resultErr error
for i := 0; i < UserPlanBandwidthStatTablePartitions; i++ {
var table = this.partialTable(int64(i))
go func(table string) {
defer wg.Done()
err := f(table, locker)
if err != nil {
resultErr = err
}
}(table)
}
wg.Wait()
return resultErr
}
// 获取分区表
func (this *UserPlanBandwidthStatDAO) partialTable(userPlanId int64) string {
return this.Table + "_" + types.String(userPlanId%int64(UserPlanBandwidthStatTablePartitions))
}

View File

@@ -0,0 +1,39 @@
package models_test
import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
"github.com/iwind/TeaGo/dbs"
timeutil "github.com/iwind/TeaGo/utils/time"
"testing"
)
func TestUserPlanBandwidthStatDAO_FindMonthlyPercentile(t *testing.T) {
var dao = models.NewUserPlanBandwidthStatDAO()
var tx *dbs.Tx
{
resultBytes, err := dao.FindMonthlyPercentile(tx, 20, timeutil.Format("Ym"), 100, false)
if err != nil {
t.Fatal(err)
}
t.Log("result bytes0:", resultBytes)
}
{
resultBytes, err := dao.FindMonthlyPercentile(tx, 20, timeutil.Format("Ym"), 95, false)
if err != nil {
t.Fatal(err)
}
t.Log("result bytes1:", resultBytes)
}
{
resultBytes, err := dao.FindMonthlyPercentile(tx, 20, timeutil.Format("Ym"), 95, true)
if err != nil {
t.Fatal(err)
}
t.Log("result bytes2:", resultBytes)
}
}

View File

@@ -0,0 +1,59 @@
package models
import "github.com/iwind/TeaGo/dbs"
const (
UserPlanBandwidthStatField_Id dbs.FieldName = "id" // ID
UserPlanBandwidthStatField_UserId dbs.FieldName = "userId" // 用户ID
UserPlanBandwidthStatField_UserPlanId dbs.FieldName = "userPlanId" // 用户套餐ID
UserPlanBandwidthStatField_Day dbs.FieldName = "day" // 日期YYYYMMDD
UserPlanBandwidthStatField_TimeAt dbs.FieldName = "timeAt" // 时间点HHII
UserPlanBandwidthStatField_Bytes dbs.FieldName = "bytes" // 带宽
UserPlanBandwidthStatField_RegionId dbs.FieldName = "regionId" // 区域ID
UserPlanBandwidthStatField_TotalBytes dbs.FieldName = "totalBytes" // 总流量
UserPlanBandwidthStatField_AvgBytes dbs.FieldName = "avgBytes" // 平均流量
UserPlanBandwidthStatField_CachedBytes dbs.FieldName = "cachedBytes" // 缓存的流量
UserPlanBandwidthStatField_AttackBytes dbs.FieldName = "attackBytes" // 攻击流量
UserPlanBandwidthStatField_CountRequests dbs.FieldName = "countRequests" // 请求数
UserPlanBandwidthStatField_CountCachedRequests dbs.FieldName = "countCachedRequests" // 缓存的请求数
UserPlanBandwidthStatField_CountAttackRequests dbs.FieldName = "countAttackRequests" // 攻击请求数
)
// UserPlanBandwidthStat 用户套餐带宽峰值
type UserPlanBandwidthStat struct {
Id uint64 `field:"id"` // ID
UserId uint64 `field:"userId"` // 用户ID
UserPlanId uint64 `field:"userPlanId"` // 用户套餐ID
Day string `field:"day"` // 日期YYYYMMDD
TimeAt string `field:"timeAt"` // 时间点HHII
Bytes uint64 `field:"bytes"` // 带宽
RegionId uint32 `field:"regionId"` // 区域ID
TotalBytes uint64 `field:"totalBytes"` // 总流量
AvgBytes uint64 `field:"avgBytes"` // 平均流量
CachedBytes uint64 `field:"cachedBytes"` // 缓存的流量
AttackBytes uint64 `field:"attackBytes"` // 攻击流量
CountRequests uint64 `field:"countRequests"` // 请求数
CountCachedRequests uint64 `field:"countCachedRequests"` // 缓存的请求数
CountAttackRequests uint64 `field:"countAttackRequests"` // 攻击请求数
}
type UserPlanBandwidthStatOperator struct {
Id any // ID
UserId any // 用户ID
UserPlanId any // 用户套餐ID
Day any // 日期YYYYMMDD
TimeAt any // 时间点HHII
Bytes any // 带宽
RegionId any // 区域ID
TotalBytes any // 总流量
AvgBytes any // 平均流量
CachedBytes any // 缓存的流量
AttackBytes any // 攻击流量
CountRequests any // 请求数
CountCachedRequests any // 缓存的请求数
CountAttackRequests any // 攻击请求数
}
func NewUserPlanBandwidthStatOperator() *UserPlanBandwidthStatOperator {
return &UserPlanBandwidthStatOperator{}
}

View File

@@ -0,0 +1 @@
package models

View File

@@ -1 +1,8 @@
package models package models
import timeutil "github.com/iwind/TeaGo/utils/time"
// IsExpired 判断套餐是否过期
func (this *UserPlan) IsExpired() bool {
return len(this.DayTo) == 0 || this.DayTo < timeutil.Format("Y-m-d")
}

View File

@@ -0,0 +1,28 @@
package models
import (
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
)
type UserPlanStatDAO dbs.DAO
func NewUserPlanStatDAO() *UserPlanStatDAO {
return dbs.NewDAO(&UserPlanStatDAO{
DAOObject: dbs.DAOObject{
DB: Tea.Env,
Table: "edgeUserPlanStats",
Model: new(UserPlanStat),
PkName: "id",
},
}).(*UserPlanStatDAO)
}
var SharedUserPlanStatDAO *UserPlanStatDAO
func init() {
dbs.OnReady(func() {
SharedUserPlanStatDAO = NewUserPlanStatDAO()
})
}

View File

@@ -0,0 +1,10 @@
// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
//go:build !plus
package models
import "github.com/iwind/TeaGo/dbs"
func (this *UserPlanStatDAO) IncreaseUserPlanStat(tx *dbs.Tx, userPlanId int64, trafficBytes int64, countRequests int64) error {
return nil
}

View File

@@ -0,0 +1,6 @@
package models_test
import (
_ "github.com/go-sql-driver/mysql"
_ "github.com/iwind/TeaGo/bootstrap"
)

View File

@@ -0,0 +1,38 @@
package models
import "github.com/iwind/TeaGo/dbs"
const (
UserPlanStatField_Id dbs.FieldName = "id" // ID
UserPlanStatField_UserPlanId dbs.FieldName = "userPlanId" // 用户套餐ID
UserPlanStatField_Date dbs.FieldName = "date" // 日期YYYYMMDD或YYYYMM
UserPlanStatField_DateType dbs.FieldName = "dateType" // 日期类型day|month
UserPlanStatField_TrafficBytes dbs.FieldName = "trafficBytes" // 流量
UserPlanStatField_CountRequests dbs.FieldName = "countRequests" // 总请求数
UserPlanStatField_IsProcessed dbs.FieldName = "isProcessed" // 是否已处理
)
// UserPlanStat 用户套餐统计
type UserPlanStat struct {
Id uint64 `field:"id"` // ID
UserPlanId uint64 `field:"userPlanId"` // 用户套餐ID
Date string `field:"date"` // 日期YYYYMMDD或YYYYMM
DateType string `field:"dateType"` // 日期类型day|month
TrafficBytes uint64 `field:"trafficBytes"` // 流量
CountRequests uint64 `field:"countRequests"` // 总请求数
IsProcessed bool `field:"isProcessed"` // 是否已处理
}
type UserPlanStatOperator struct {
Id any // ID
UserPlanId any // 用户套餐ID
Date any // 日期YYYYMMDD或YYYYMM
DateType any // 日期类型day|month
TrafficBytes any // 流量
CountRequests any // 总请求数
IsProcessed any // 是否已处理
}
func NewUserPlanStatOperator() *UserPlanStatOperator {
return &UserPlanStatOperator{}
}

View File

@@ -0,0 +1 @@
package models

View File

@@ -549,7 +549,7 @@ func (this *AdminService) ComposeAdminDashboard(ctx context.Context, req *pb.Com
result.CountServers = countServers result.CountServers = countServers
this.BeginTag(ctx, "SharedServerDAO.CountAllEnabledServersMatch") this.BeginTag(ctx, "SharedServerDAO.CountAllEnabledServersMatch")
countAuditingServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", 0, 0, configutils.BoolStateYes, nil) countAuditingServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", 0, 0, configutils.BoolStateYes, nil, 0)
this.EndTag(ctx, "SharedServerDAO.CountAllEnabledServersMatch") this.EndTag(ctx, "SharedServerDAO.CountAllEnabledServersMatch")
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -10,6 +10,7 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
"github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/domainutils" "github.com/TeaOSLab/EdgeAPI/internal/utils/domainutils"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
@@ -147,7 +148,7 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe
} }
// 套餐 // 套餐
plan, err := models.SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId)) plan, err := models.SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1076,6 +1077,12 @@ func (this *ServerService) UpdateServerNames(ctx context.Context, req *pb.Update
} }
} }
// 套餐额度限制
err = models.SharedServerDAO.CheckServerPlanQuota(tx, req.ServerId, len(serverconfigs.PlainServerNames(serverNameConfigs)))
if err != nil {
return nil, err
}
// 检查用户 // 检查用户
if userId > 0 { if userId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId)
@@ -1278,7 +1285,7 @@ func (this *ServerService) CountAllEnabledServersMatch(ctx context.Context, req
var tx = this.NullTx() var tx = this.NullTx()
count, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, req.ServerGroupId, req.Keyword, req.UserId, req.NodeClusterId, types.Int8(req.AuditingFlag), utils.SplitStrings(req.ProtocolFamily, ",")) count, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, req.ServerGroupId, req.Keyword, req.UserId, req.NodeClusterId, types.Int8(req.AuditingFlag), utils.SplitStrings(req.ProtocolFamily, ","), req.UserPlanId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -2019,6 +2026,52 @@ func (this *ServerService) FindAllEnabledServerNamesWithUserId(ctx context.Conte
return &pb.FindAllEnabledServerNamesWithUserIdResponse{ServerNames: serverNames}, nil return &pb.FindAllEnabledServerNamesWithUserIdResponse{ServerNames: serverNames}, nil
} }
// CountAllServerNamesWithUserId 计算一个用户下的所有域名数量
func (this *ServerService) CountAllServerNamesWithUserId(ctx context.Context, req *pb.CountAllServerNamesWithUserIdRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
count, err := models.SharedServerDAO.CountAllServerNamesWithUserId(tx, req.UserId, req.UserPlanId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// CountServerNames 计算某个网站下的域名数量
func (this *ServerService) CountServerNames(ctx context.Context, req *pb.CountServerNamesRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if req.ServerId <= 0 {
return nil, errors.New("invalid 'serverId'")
}
var tx = this.NullTx()
if userId > 0 {
err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId)
if err != nil {
return nil, err
}
}
count, err := models.SharedServerDAO.CountServerNames(tx, req.ServerId)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// FindAllUserServers 查找一个用户下的所有服务 // FindAllUserServers 查找一个用户下的所有服务
func (this *ServerService) FindAllUserServers(ctx context.Context, req *pb.FindAllUserServersRequest) (*pb.FindAllUserServersResponse, error) { func (this *ServerService) FindAllUserServers(ctx context.Context, req *pb.FindAllUserServersRequest) (*pb.FindAllUserServersResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true) _, userId, err := this.ValidateAdminAndUser(ctx, true)
@@ -2051,6 +2104,26 @@ func (this *ServerService) FindAllUserServers(ctx context.Context, req *pb.FindA
}, nil }, nil
} }
// CountAllUserServers 计算一个用户下的所有网站数量
func (this *ServerService) CountAllUserServers(ctx context.Context, req *pb.CountAllUserServersRequest) (*pb.RPCCountResponse, error) {
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if userId > 0 {
req.UserId = userId
}
var tx = this.NullTx()
countServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", req.UserId, 0, configutils.BoolStateAll, nil, req.UserPlanId)
if err != nil {
return nil, err
}
return this.SuccessCount(countServers)
}
// ComposeAllUserServersConfig 查找某个用户下的服务配置 // ComposeAllUserServersConfig 查找某个用户下的服务配置
func (this *ServerService) ComposeAllUserServersConfig(ctx context.Context, req *pb.ComposeAllUserServersConfigRequest) (*pb.ComposeAllUserServersConfigResponse, error) { func (this *ServerService) ComposeAllUserServersConfig(ctx context.Context, req *pb.ComposeAllUserServersConfigRequest) (*pb.ComposeAllUserServersConfigResponse, error) {
_, err := this.ValidateNode(ctx) _, err := this.ValidateNode(ctx)
@@ -2644,7 +2717,7 @@ func (this *ServerService) UpdateServerUserPlan(ctx context.Context, req *pb.Upd
} }
if req.UserPlanId > 0 { if req.UserPlanId > 0 {
userId, err := models.SharedServerDAO.FindServerUserId(tx, req.ServerId) userId, err = models.SharedServerDAO.FindServerUserId(tx, req.ServerId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -2662,14 +2735,47 @@ func (this *ServerService) UpdateServerUserPlan(ctx context.Context, req *pb.Upd
if int64(userPlan.UserId) != userId { if int64(userPlan.UserId) != userId {
return nil, errors.New("can not find user plan with id '" + types.String(req.UserPlanId) + "'") return nil, errors.New("can not find user plan with id '" + types.String(req.UserPlanId) + "'")
} }
if userPlan.IsExpired() {
return nil, fmt.Errorf("the user plan %q has been expired", types.String(req.UserPlanId))
}
// 检查是否已经被别的服务所使用 // 检查限制
serverId, err := models.SharedServerDAO.FindEnabledServerIdWithUserPlanId(tx, req.UserPlanId) plan, err := models.SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if serverId > 0 && serverId != req.ServerId { if plan == nil {
return nil, errors.New("the user plan is used by other server") return nil, errors.New("can not find plan with id '" + types.String(userPlan.PlanId) + "'")
}
if plan.TotalServers > 0 {
countServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", userId, 0, configutils.BoolStateAll, nil, req.UserPlanId)
if err != nil {
return nil, err
}
if countServers+1 > int64(plan.TotalServers) {
return nil, errors.New("total servers over quota")
}
}
countServerNames, err := models.SharedServerDAO.CountServerNames(tx, req.ServerId)
if err != nil {
return nil, err
}
if plan.TotalServerNamesPerServer > 0 {
if countServerNames > int64(plan.TotalServerNamesPerServer) {
return nil, errors.New("total server names per server over quota")
}
}
totalServerNames, err := models.SharedServerDAO.CountAllServerNamesWithUserId(tx, userId, req.UserPlanId)
if err != nil {
return nil, err
}
if plan.TotalServerNames > 0 {
if totalServerNames+countServerNames > int64(plan.TotalServerNames) {
return nil, errors.New("total server names over quota")
}
} }
} }
@@ -2714,7 +2820,7 @@ func (this *ServerService) FindServerUserPlan(ctx context.Context, req *pb.FindS
return &pb.FindServerUserPlanResponse{UserPlan: nil}, nil return &pb.FindServerUserPlanResponse{UserPlan: nil}, nil
} }
plan, err := models.SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId)) plan, err := models.SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -2737,6 +2843,9 @@ func (this *ServerService) FindServerUserPlan(ctx context.Context, req *pb.FindS
PriceType: plan.PriceType, PriceType: plan.PriceType,
TrafficPriceJSON: plan.TrafficPrice, TrafficPriceJSON: plan.TrafficPrice,
TrafficLimitJSON: plan.TrafficLimit, TrafficLimitJSON: plan.TrafficLimit,
TotalServers: types.Int32(plan.TotalServers),
TotalServerNames: types.Int32(plan.TotalServerNames),
TotalServerNamesPerServer: types.Int32(plan.TotalServerNamesPerServer),
}, },
}, },
}, nil }, nil

View File

@@ -63,17 +63,31 @@ func init() {
} }
for _, stat := range m { for _, stat := range m {
// 更新服务的带宽峰值 // 更新网站的带宽峰值
if stat.ServerId > 0 { if stat.ServerId > 0 {
err = models.SharedServerBandwidthStatDAO.UpdateServerBandwidth(tx, stat.UserId, stat.ServerId, stat.NodeRegionId, stat.Day, stat.TimeAt, stat.Bytes, stat.TotalBytes, stat.CachedBytes, stat.AttackBytes, stat.CountRequests, stat.CountCachedRequests, stat.CountAttackRequests) // 更新带宽统计
err = models.SharedServerBandwidthStatDAO.UpdateServerBandwidth(tx, stat.UserId, stat.ServerId, stat.NodeRegionId, stat.UserPlanId, stat.Day, stat.TimeAt, stat.Bytes, stat.TotalBytes, stat.CachedBytes, stat.AttackBytes, stat.CountRequests, stat.CountCachedRequests, stat.CountAttackRequests)
if err != nil { if err != nil {
remotelogs.Error("ServerBandwidthStatService", "dump bandwidth stats failed: "+err.Error()) remotelogs.Error("ServerBandwidthStatService", "dump bandwidth stats failed: "+err.Error())
} }
// 更新网站的bandwidth字段方便快速排序
err = models.SharedServerDAO.UpdateServerBandwidth(tx, stat.ServerId, stat.Day+stat.TimeAt, stat.Bytes, stat.CountRequests, stat.CountAttackRequests) err = models.SharedServerDAO.UpdateServerBandwidth(tx, stat.ServerId, stat.Day+stat.TimeAt, stat.Bytes, stat.CountRequests, stat.CountAttackRequests)
if err != nil { if err != nil {
remotelogs.Error("ServerBandwidthStatService", "update server bandwidth failed: "+err.Error()) remotelogs.Error("ServerBandwidthStatService", "update server bandwidth failed: "+err.Error())
} }
// 套餐统计
if stat.UserPlanId > 0 {
// 总体统计
err = models.SharedUserPlanStatDAO.IncreaseUserPlanStat(tx, stat.UserPlanId, stat.TotalBytes, stat.CountRequests)
if err != nil {
remotelogs.Error("ServerBandwidthStatService", "IncreaseUserPlanStat: "+err.Error())
}
// 分时统计
err = models.SharedUserPlanBandwidthStatDAO.UpdateUserPlanBandwidth(tx, stat.UserId, stat.UserPlanId, stat.NodeRegionId, stat.Day, stat.TimeAt, stat.Bytes, stat.TotalBytes, stat.CachedBytes, stat.AttackBytes, stat.CountRequests, stat.CountCachedRequests, stat.CountAttackRequests)
}
} }
// 更新用户的带宽峰值 // 更新用户的带宽峰值
@@ -147,6 +161,7 @@ func (this *ServerBandwidthStatService) UploadServerBandwidthStats(ctx context.C
CountRequests: stat.CountRequests, CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests, CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests, CountAttackRequests: stat.CountAttackRequests,
UserPlanId: stat.UserPlanId,
} }
} }
serverBandwidthStatsLocker.Unlock() serverBandwidthStatsLocker.Unlock()

View File

@@ -412,7 +412,7 @@ func (this *UserService) ComposeUserDashboard(ctx context.Context, req *pb.Compo
} }
// 网站数量 // 网站数量
countServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", req.UserId, 0, configutils.BoolStateAll, []string{}) countServers, err := models.SharedServerDAO.CountAllEnabledServersMatch(tx, 0, "", req.UserId, 0, configutils.BoolStateAll, []string{}, 0)
if err != nil { if err != nil {
return nil, err return nil, err
} }

File diff suppressed because it is too large Load Diff

View File

@@ -95,6 +95,9 @@ var upgradeFuncs = []*upgradeVersion{
{ {
"1.2.1", upgradeV1_2_1, "1.2.1", upgradeV1_2_1,
}, },
{
"1.2.9", upgradeV1_2_9,
},
} }
// UpgradeSQLData 升级SQL数据 // UpgradeSQLData 升级SQL数据
@@ -746,3 +749,24 @@ func upgradeV1_2_1(db *dbs.DB) error {
} }
return nil return nil
} }
// v1.2.9
func upgradeV1_2_9(db *dbs.DB) error {
// 升级套餐网站数限制
{
_, err := db.Exec("UPDATE edgePlans SET totalServers=1 WHERE totalServers=0")
if err != nil {
return err
}
}
// 升级网站流量限制状态
{
_, err := db.Exec("UPDATE edgeServers SET trafficLimitStatus=NULL WHERE trafficLimitStatus IS NOT NULL")
if err != nil {
return err
}
}
return nil
}

View File

@@ -271,4 +271,21 @@ func TestUpgradeSQLData_v1_2_1(t *testing.T) {
t.Log("ok") t.Log("ok")
} }
func TestUpgradeSQLData_v1_2_9(t *testing.T) {
db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{
Driver: "mysql",
Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s",
Prefix: "edge",
})
if err != nil {
t.Fatal(err)
}
defer func() {
_ = db.Close()
}()
err = upgradeV1_2_9(db)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}