重新实现套餐

This commit is contained in:
GoEdgeLab
2023-09-06 16:34:11 +08:00
parent 28bf65e88a
commit ca95ea5c3d
10 changed files with 71 additions and 20 deletions

View File

@@ -198,9 +198,9 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
var cost = time.Since(before).Seconds() var cost = time.Since(before).Seconds()
if cost > 1 { if cost > 1 {
stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.serverId, int64(float64(n)/cost), int64(n)) stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.userPlanId, this.serverId, int64(float64(n)/cost), int64(n))
} else { } else {
stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.serverId, int64(n), int64(n)) stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.userPlanId, this.serverId, int64(n), int64(n))
} }
} }
} }

View File

@@ -16,6 +16,7 @@ type BaseClientConn struct {
isBound bool isBound bool
userId int64 userId int64
userPlanId int64
serverId int64 serverId int64
remoteAddr string remoteAddr string
hasLimit bool hasLimit bool
@@ -106,11 +107,31 @@ func (this *BaseClientConn) SetUserId(userId int64) {
} }
} }
func (this *BaseClientConn) SetUserPlanId(userPlanId int64) {
this.userPlanId = userPlanId
// 设置包装前连接
switch conn := this.rawConn.(type) {
case *tls.Conn:
nativeConn, ok := conn.NetConn().(ClientConnInterface)
if ok {
nativeConn.SetUserPlanId(userPlanId)
}
case *ClientConn:
conn.SetUserPlanId(userPlanId)
}
}
// UserId 获取当前连接所属服务的用户ID // UserId 获取当前连接所属服务的用户ID
func (this *BaseClientConn) UserId() int64 { func (this *BaseClientConn) UserId() int64 {
return this.userId return this.userId
} }
// UserPlanId 用户套餐ID
func (this *BaseClientConn) UserPlanId() int64 {
return this.userPlanId
}
// RawIP 原本IP // RawIP 原本IP
func (this *BaseClientConn) RawIP() string { func (this *BaseClientConn) RawIP() string {
if len(this.rawIP) > 0 { if len(this.rawIP) > 0 {

View File

@@ -18,9 +18,12 @@ type ClientConnInterface interface {
// SetServerId 设置服务ID // SetServerId 设置服务ID
SetServerId(serverId int64) (goNext bool) SetServerId(serverId int64) (goNext bool)
// SetUserId 设置所属服务的用户ID // SetUserId 设置所属网站的用户ID
SetUserId(userId int64) SetUserId(userId int64)
// SetUserPlanId 设置
SetUserPlanId(userPlanId int64)
// UserId 获取当前连接所属服务的用户ID // UserId 获取当前连接所属服务的用户ID
UserId() int64 UserId() int64

View File

@@ -198,7 +198,7 @@ func (this *HTTPRequest) Do() {
} }
// 流量限制 // 流量限制
if this.ReqServer.TrafficLimit != nil && this.ReqServer.TrafficLimit.IsOn && !this.ReqServer.TrafficLimit.IsEmpty() && this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() { if this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() {
this.doTrafficLimit() this.doTrafficLimit()
this.doEnd() this.doEnd()
return return

View File

@@ -8,15 +8,17 @@ import (
// 流量限制 // 流量限制
func (this *HTTPRequest) doTrafficLimit() { func (this *HTTPRequest) doTrafficLimit() {
var config = this.ReqServer.TrafficLimit this.tags = append(this.tags, "trafficLimit")
this.tags = append(this.tags, "bandwidth")
var statusCode = 509 var statusCode = 509
this.writer.statusCode = statusCode
this.ProcessResponseHeaders(this.writer.Header(), statusCode) this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
this.writer.WriteHeader(statusCode) this.writer.WriteHeader(statusCode)
if len(config.NoticePageBody) != 0 {
var config = this.ReqServer.TrafficLimit
if config != nil && len(config.NoticePageBody) != 0 {
_, _ = this.writer.WriteString(this.Format(config.NoticePageBody)) _, _ = this.writer.WriteString(this.Format(config.NoticePageBody))
} else { } else {
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultTrafficLimitNoticePageBody)) _, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultTrafficLimitNoticePageBody))

View File

@@ -177,6 +177,12 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
return return
} }
clientConn.SetUserId(server.UserId) clientConn.SetUserId(server.UserId)
var userPlanId int64
if server.UserPlan != nil && server.UserPlan.Id > 0 {
userPlanId = server.UserPlan.Id
}
clientConn.SetUserPlanId(userPlanId)
} }
} }
} }

View File

@@ -80,6 +80,12 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
return nil return nil
} }
clientConn.SetUserId(server.UserId) clientConn.SetUserId(server.UserId)
var userPlanId int64
if server.UserPlan != nil && server.UserPlan.Id > 0 {
userPlanId = server.UserPlan.Id
}
clientConn.SetUserPlanId(userPlanId)
} else { } else {
tlsConn, ok := conn.(*tls.Conn) tlsConn, ok := conn.(*tls.Conn)
if ok { if ok {
@@ -92,6 +98,12 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
return nil return nil
} }
clientConn.SetUserId(server.UserId) clientConn.SetUserId(server.UserId)
var userPlanId int64
if server.UserPlan != nil && server.UserPlan.Id > 0 {
userPlanId = server.UserPlan.Id
}
clientConn.SetUserPlanId(userPlanId)
} }
} }
} }

View File

@@ -404,7 +404,11 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener
stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
// 带宽 // 带宽
stats.SharedBandwidthStatManager.AddBandwidth(server.UserId, server.Id, int64(n), int64(n)) var userPlanId int64
if server.UserPlan != nil && server.UserPlan.Id > 0 {
userPlanId = server.UserPlan.Id
}
stats.SharedBandwidthStatManager.AddBandwidth(server.UserId, userPlanId, server.Id, int64(n), int64(n))
} }
} }
if err != nil { if err != nil {

View File

@@ -62,6 +62,7 @@ type BandwidthStat struct {
CountRequests int64 `json:"countRequests"` CountRequests int64 `json:"countRequests"`
CountCachedRequests int64 `json:"countCachedRequests"` CountCachedRequests int64 `json:"countCachedRequests"`
CountAttackRequests int64 `json:"countAttackRequests"` CountAttackRequests int64 `json:"countAttackRequests"`
UserPlanId int64 `json:"userPlanId"`
} }
// BandwidthStatManager 服务带宽统计 // BandwidthStatManager 服务带宽统计
@@ -153,6 +154,7 @@ func (this *BandwidthStatManager) Loop() error {
CountRequests: stat.CountRequests, CountRequests: stat.CountRequests,
CountCachedRequests: stat.CountCachedRequests, CountCachedRequests: stat.CountCachedRequests,
CountAttackRequests: stat.CountAttackRequests, CountAttackRequests: stat.CountAttackRequests,
UserPlanId: stat.UserPlanId,
NodeRegionId: regionId, NodeRegionId: regionId,
}) })
delete(this.m, key) delete(this.m, key)
@@ -178,7 +180,7 @@ func (this *BandwidthStatManager) Loop() error {
} }
// AddBandwidth 添加带宽数据 // AddBandwidth 添加带宽数据
func (this *BandwidthStatManager) AddBandwidth(userId int64, serverId int64, peekBytes int64, totalBytes int64) { func (this *BandwidthStatManager) AddBandwidth(userId int64, userPlanId int64, serverId int64, peekBytes int64, totalBytes int64) {
if serverId <= 0 || (peekBytes == 0 && totalBytes == 0) { if serverId <= 0 || (peekBytes == 0 && totalBytes == 0) {
return return
} }
@@ -217,6 +219,7 @@ func (this *BandwidthStatManager) AddBandwidth(userId int64, serverId int64, pee
Day: day, Day: day,
TimeAt: timeAt, TimeAt: timeAt,
UserId: userId, UserId: userId,
UserPlanId: userPlanId,
ServerId: serverId, ServerId: serverId,
CurrentBytes: peekBytes, CurrentBytes: peekBytes,
MaxBytes: peekBytes, MaxBytes: peekBytes,

View File

@@ -12,22 +12,22 @@ import (
func TestBandwidthStatManager_Add(t *testing.T) { func TestBandwidthStatManager_Add(t *testing.T) {
var manager = stats.NewBandwidthStatManager() var manager = stats.NewBandwidthStatManager()
manager.AddBandwidth(1, 1, 10, 10) manager.AddBandwidth(1, 0, 1, 10, 10)
manager.AddBandwidth(1, 1, 10, 10) manager.AddBandwidth(1, 0, 1, 10, 10)
manager.AddBandwidth(1, 1, 10, 10) manager.AddBandwidth(1, 0, 1, 10, 10)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
manager.AddBandwidth(1, 1, 85, 85) manager.AddBandwidth(1, 0, 1, 85, 85)
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
manager.AddBandwidth(1, 1, 25, 25) manager.AddBandwidth(1, 0, 1, 25, 25)
manager.AddBandwidth(1, 1, 75, 75) manager.AddBandwidth(1, 0, 1, 75, 75)
manager.Inspect() manager.Inspect()
} }
func TestBandwidthStatManager_Loop(t *testing.T) { func TestBandwidthStatManager_Loop(t *testing.T) {
var manager = stats.NewBandwidthStatManager() var manager = stats.NewBandwidthStatManager()
manager.AddBandwidth(1, 1, 10, 10) manager.AddBandwidth(1, 0, 1, 10, 10)
manager.AddBandwidth(1, 1, 10, 10) manager.AddBandwidth(1, 0, 1, 10, 10)
manager.AddBandwidth(1, 1, 10, 10) manager.AddBandwidth(1, 0, 1, 10, 10)
err := manager.Loop() err := manager.Loop()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -40,7 +40,7 @@ func BenchmarkBandwidthStatManager_Add(b *testing.B) {
var i int var i int
for pb.Next() { for pb.Next() {
i++ i++
manager.AddBandwidth(1, int64(i%100), 10, 10) manager.AddBandwidth(1, 0, int64(i%100), 10, 10)
} }
}) })
} }