重新实现套餐相关功能

This commit is contained in:
GoEdgeLab
2023-09-06 16:30:47 +08:00
parent 3d5f0a69a8
commit 04d2678221
26 changed files with 2804 additions and 233 deletions

View File

@@ -3,11 +3,13 @@ package models
import (
"encoding/json"
"errors"
"fmt"
teaconst "github.com/TeaOSLab/EdgeAPI/internal/const"
"github.com/TeaOSLab/EdgeAPI/internal/db/models/dns"
dbutils "github.com/TeaOSLab/EdgeAPI/internal/db/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils"
"github.com/TeaOSLab/EdgeAPI/internal/utils/regexputils"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
@@ -782,7 +784,7 @@ func (this *ServerDAO) CountAllEnabledServers(tx *dbs.Tx) (int64, error) {
// 参数:
//
// groupId 分组ID如果为-1则搜索没有分组的服务
func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState, protocolFamilies []string) (int64, error) {
func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, keyword string, userId int64, clusterId int64, auditingFlag configutils.BoolState, protocolFamilies []string, userPlanId int64) (int64, error) {
query := this.Query(tx).
State(ServerStateEnabled)
if groupId > 0 {
@@ -829,6 +831,10 @@ func (this *ServerDAO) CountAllEnabledServersMatch(tx *dbs.Tx, groupId int64, ke
query.Where("(" + strings.Join(protocolConds, " OR ") + ")")
}
if userPlanId > 0 {
query.Attr("userPlanId", userPlanId)
}
return query.Count()
}
@@ -1316,12 +1322,13 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server, ignoreCer
}
// 套餐是否依然有效
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId))
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), cacheMap)
if err != nil {
return nil, err
}
if plan != nil {
config.UserPlan = &serverconfigs.UserPlanConfig{
Id: int64(userPlan.Id),
DayTo: userPlan.DayTo,
Plan: &serverconfigs.PlanConfig{
Id: int64(plan.Id),
@@ -1341,16 +1348,14 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, server *Server, ignoreCer
}
}
if config.TrafficLimit != nil && config.TrafficLimit.IsOn && !config.TrafficLimit.IsEmpty() {
if len(server.TrafficLimitStatus) > 0 {
var status = &serverconfigs.TrafficLimitStatus{}
err := json.Unmarshal(server.TrafficLimitStatus, status)
if err != nil {
return nil, err
}
if status.IsValid() {
config.TrafficLimitStatus = status
}
if len(server.TrafficLimitStatus) > 0 {
var status = &serverconfigs.TrafficLimitStatus{}
err := json.Unmarshal(server.TrafficLimitStatus, status)
if err != nil {
return nil, err
}
if status.IsValid() {
config.TrafficLimitStatus = status
}
}
@@ -1794,6 +1799,7 @@ func (this *ServerDAO) FindServerUserId(tx *dbs.Tx, serverId int64) (userId int6
}
// FindServerUserPlanId 查找服务的套餐ID
// TODO 需要缓存
func (this *ServerDAO) FindServerUserPlanId(tx *dbs.Tx, serverId int64) (userPlanId int64, err error) {
return this.Query(tx).
Pk(serverId).
@@ -2306,94 +2312,17 @@ func (this *ServerDAO) FindServerTrafficLimitConfig(tx *dbs.Tx, serverId int64,
return nil, err
}
var limit = &serverconfigs.TrafficLimitConfig{}
if serverOne == nil {
return limit, nil
}
var trafficLimit = serverOne.(*Server).TrafficLimit
if len(trafficLimit) > 0 {
err = json.Unmarshal([]byte(trafficLimit), limit)
if err != nil {
return nil, err
}
}
if cacheMap != nil {
cacheMap.Put(cacheKey, limit)
}
return limit, nil
}
// CalculateServerTrafficLimitConfig 计算服务的流量限制
// TODO 优化性能
func (this *ServerDAO) CalculateServerTrafficLimitConfig(tx *dbs.Tx, serverId int64, cacheMap *utils.CacheMap) (*serverconfigs.TrafficLimitConfig, error) {
if cacheMap == nil {
cacheMap = utils.NewCacheMap()
}
var cacheKey = this.Table + ":FindServerTrafficLimitConfig:" + types.String(serverId)
result, ok := cacheMap.Get(cacheKey)
if ok {
return result.(*serverconfigs.TrafficLimitConfig), nil
}
serverOne, err := this.Query(tx).
Pk(serverId).
Result("trafficLimit", "userPlanId").
Find()
if err != nil {
return nil, err
}
var limitConfig = &serverconfigs.TrafficLimitConfig{}
if serverOne == nil {
return limitConfig, nil
}
var trafficLimit = serverOne.(*Server).TrafficLimit
var userPlanId = int64(serverOne.(*Server).UserPlanId)
var trafficLimitJSON = serverOne.(*Server).TrafficLimit
if len(trafficLimit) == 0 {
if userPlanId > 0 {
userPlan, err := SharedUserPlanDAO.FindEnabledUserPlan(tx, userPlanId, cacheMap)
if err != nil {
return nil, err
}
if userPlan != nil {
planLimit, err := SharedPlanDAO.FindEnabledPlanTrafficLimit(tx, int64(userPlan.PlanId), cacheMap)
if err != nil {
return nil, err
}
if planLimit != nil {
return planLimit, nil
}
}
}
return limitConfig, nil
}
err = json.Unmarshal(trafficLimit, limitConfig)
if err != nil {
return nil, err
}
if !limitConfig.IsOn {
if userPlanId > 0 {
userPlan, err := SharedUserPlanDAO.FindEnabledUserPlan(tx, userPlanId, cacheMap)
if err != nil {
return nil, err
}
if userPlan != nil {
planLimit, err := SharedPlanDAO.FindEnabledPlanTrafficLimit(tx, int64(userPlan.PlanId), cacheMap)
if err != nil {
return nil, err
}
if planLimit != nil {
return planLimit, nil
}
}
if len(trafficLimitJSON) > 0 {
err = json.Unmarshal(trafficLimitJSON, limitConfig)
if err != nil {
return nil, err
}
}
@@ -2423,11 +2352,11 @@ func (this *ServerDAO) UpdateServerTrafficLimitConfig(tx *dbs.Tx, serverId int64
}
// 更新状态
return this.UpdateServerTrafficLimitStatus(tx, trafficLimitConfig, serverId, true)
return this.RenewServerTrafficLimitStatus(tx, trafficLimitConfig, serverId, true)
}
// UpdateServerTrafficLimitStatus 修改服务的流量限制状态
func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitConfig *serverconfigs.TrafficLimitConfig, serverId int64, isUpdatingConfig bool) error {
// RenewServerTrafficLimitStatus 根据限流配置更新网站的流量限制状态
func (this *ServerDAO) RenewServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitConfig *serverconfigs.TrafficLimitConfig, serverId int64, isUpdatingConfig bool) error {
if !trafficLimitConfig.IsOn {
if isUpdatingConfig {
return this.NotifyUpdate(tx, serverId)
@@ -2464,9 +2393,11 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
var untilDay = ""
// daily
var dateType = ""
if trafficLimitConfig.DailyBytes() > 0 {
if server.TrafficDay == timeutil.Format("Ymd") && server.TotalDailyTraffic >= float64(trafficLimitConfig.DailyBytes())/(1<<30) {
untilDay = timeutil.Format("Ymd")
dateType = "day"
}
}
@@ -2474,6 +2405,7 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
if server.TrafficMonth == timeutil.Format("Ym") && trafficLimitConfig.MonthlyBytes() > 0 {
if server.TotalMonthlyTraffic >= float64(trafficLimitConfig.MonthlyBytes())/(1<<30) {
untilDay = timeutil.Format("Ym32")
dateType = "month"
}
}
@@ -2481,12 +2413,16 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
if trafficLimitConfig.TotalBytes() > 0 {
if server.TotalTraffic >= float64(trafficLimitConfig.TotalBytes())/(1<<30) {
untilDay = "30000101"
dateType = "total"
}
}
var isChanged = oldStatus.UntilDay != untilDay
if isChanged {
statusJSON, err := json.Marshal(&serverconfigs.TrafficLimitStatus{UntilDay: untilDay})
statusJSON, err := json.Marshal(&serverconfigs.TrafficLimitStatus{
UntilDay: untilDay,
DateType: dateType,
})
if err != nil {
return err
}
@@ -2507,6 +2443,90 @@ func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, trafficLimitCo
return nil
}
// UpdateServerTrafficLimitStatus 修改网站的流量限制状态
func (this *ServerDAO) UpdateServerTrafficLimitStatus(tx *dbs.Tx, serverId int64, day string, planId int64, dateType string) error {
if !regexputils.YYYYMMDD.MatchString(day) {
return errors.New("invalid 'day' format")
}
if serverId <= 0 {
return nil
}
// lookup old status
statusJSON, err := this.Query(tx).
Pk(serverId).
Result(ServerField_TrafficLimitStatus).
FindJSONCol()
if err != nil {
return err
}
if IsNotNull(statusJSON) {
var oldStatus = &serverconfigs.TrafficLimitStatus{}
err = json.Unmarshal(statusJSON, oldStatus)
if err != nil {
return err
}
if len(oldStatus.UntilDay) > 0 && oldStatus.UntilDay >= day /** 如果已经限制,且比当前日期长,则无需重复 **/ {
// no need to change
return nil
}
}
var status = &serverconfigs.TrafficLimitStatus{
UntilDay: day,
PlanId: planId,
DateType: dateType,
}
statusJSON, err = json.Marshal(status)
if err != nil {
return err
}
err = this.Query(tx).
Pk(serverId).
Set(ServerField_TrafficLimitStatus, statusJSON).
UpdateQuickly()
if err != nil {
return err
}
return this.NotifyUpdate(tx, serverId)
}
// UpdateServersTrafficLimitStatusWithUserPlanId 修改某个套餐下的网站的流量限制状态
func (this *ServerDAO) UpdateServersTrafficLimitStatusWithUserPlanId(tx *dbs.Tx, userPlanId int64, day string, planId int64, dateType string) error {
if userPlanId <= 0 {
return nil
}
servers, err := this.Query(tx).
State(ServerStateEnabled).
Attr("userPlanId", userPlanId).
ResultPk().
FindAll()
if err != nil {
return err
}
for _, server := range servers {
var serverId = int64(server.(*Server).Id)
err = this.UpdateServerTrafficLimitStatus(tx, serverId, day, planId, dateType)
if err != nil {
return err
}
}
return nil
}
// ResetServersTrafficLimitStatusWithPlanId 重置网站限流状态
func (this *ServerDAO) ResetServersTrafficLimitStatusWithPlanId(tx *dbs.Tx, planId int64) error {
return this.Query(tx).
Where("JSON_EXTRACT(trafficLimitStatus, '$.planId')=:planId").
Param("planId", planId).
Set("trafficLimitStatus", dbs.SQL("NULL")).
UpdateQuickly()
}
// IncreaseServerTotalTraffic 增加服务的总流量
func (this *ServerDAO) IncreaseServerTotalTraffic(tx *dbs.Tx, serverId int64, bytes int64) error {
if serverId <= 0 {
@@ -2548,17 +2568,16 @@ func (this *ServerDAO) FindEnabledServerIdWithUserPlanId(tx *dbs.Tx, userPlanId
FindInt64Col(0)
}
// FindEnabledServerWithUserPlanId 查找使用某个套餐的服务
func (this *ServerDAO) FindEnabledServerWithUserPlanId(tx *dbs.Tx, userPlanId int64) (*Server, error) {
one, err := this.Query(tx).
// FindEnabledServersWithUserPlanId 查找使用某个套餐的网站
func (this *ServerDAO) FindEnabledServersWithUserPlanId(tx *dbs.Tx, userPlanId int64) (result []*Server, err error) {
_, err = this.Query(tx).
State(ServerStateEnabled).
Attr("userPlanId", userPlanId).
Result("id", "name", "serverNames", "type").
Find()
if err != nil || one == nil {
return nil, err
}
return one.(*Server), nil
AscPk().
Slice(&result).
FindAll()
return
}
// UpdateServersClusterIdWithPlanId 修改套餐所在集群
@@ -2643,7 +2662,7 @@ func (this *ServerDAO) UpdateServerUserPlanId(tx *dbs.Tx, serverId int64, userPl
return errors.New("can not find user plan with id '" + types.String(userPlanId) + "'")
}
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId))
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), nil)
if err != nil {
return err
}
@@ -2881,6 +2900,89 @@ func (this *ServerDAO) FindEnabledServersWithIds(tx *dbs.Tx, serverIds []int64)
return
}
// CountAllServerNamesWithUserId 计算某个用户下的所有域名数
func (this *ServerDAO) CountAllServerNamesWithUserId(tx *dbs.Tx, userId int64, userPlanId int64) (int64, error) {
if userId <= 0 {
return 0, nil
}
var query = this.Query(tx).
Attr("userId", userId).
State(ServerStateEnabled).
Where("JSON_TYPE(plainServerNames)='ARRAY'")
if userPlanId > 0 {
query.Attr("userPlanId", userPlanId)
}
return query.
SumInt64("JSON_LENGTH(plainServerNames)", 0)
}
// CountServerNames 计算某个网站下的所有域名数
func (this *ServerDAO) CountServerNames(tx *dbs.Tx, serverId int64) (int64, error) {
if serverId <= 0 {
return 0, nil
}
return this.Query(tx).
Result("JSON_LENGTH(plainServerNames)").
Pk(serverId).
State(ServerStateEnabled).
Where("JSON_TYPE(plainServerNames)='ARRAY'").
FindInt64Col(0)
}
// CheckServerPlanQuota 检查网站套餐限制
func (this *ServerDAO) CheckServerPlanQuota(tx *dbs.Tx, serverId int64, countServerNames int) error {
if serverId <= 0 {
return errors.New("invalid 'serverId'")
}
if countServerNames <= 0 {
return nil
}
userPlanId, err := this.FindServerUserPlanId(tx, serverId)
if err != nil {
return err
}
if userPlanId <= 0 {
return nil
}
userPlan, err := SharedUserPlanDAO.FindEnabledUserPlan(tx, userPlanId, nil)
if err != nil {
return err
}
if userPlan == nil {
return fmt.Errorf("invalid user plan with id %q", types.String(userPlanId))
}
if userPlan.IsExpired() {
return errors.New("the user plan has been expired")
}
if userPlan.UserId == 0 {
return nil
}
plan, err := SharedPlanDAO.FindEnabledPlan(tx, int64(userPlan.PlanId), nil)
if err != nil {
return err
}
if plan == nil {
return fmt.Errorf("invalid plan with id %q", types.String(userPlan.PlanId))
}
if plan.TotalServerNames > 0 {
totalServerNames, err := this.CountAllServerNamesWithUserId(tx, int64(userPlan.UserId), userPlanId)
if err != nil {
return err
}
if totalServerNames+int64(countServerNames) > int64(plan.TotalServerNames) {
return errors.New("server names over plan quota")
}
}
if plan.TotalServerNamesPerServer > 0 {
if countServerNames > types.Int(plan.TotalServerNamesPerServer) {
return errors.New("server names per server over plan quota")
}
}
return nil
}
// NotifyUpdate 同步服务所在的集群
func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error {
if serverId <= 0 {