mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 07:40:56 +08:00 
			
		
		
		
	重新实现套餐
This commit is contained in:
		@@ -198,9 +198,9 @@ func (this *ClientConn) Write(b []byte) (n int, err error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
				var cost = time.Since(before).Seconds()
 | 
									var cost = time.Since(before).Seconds()
 | 
				
			||||||
				if cost > 1 {
 | 
									if cost > 1 {
 | 
				
			||||||
					stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.serverId, int64(float64(n)/cost), int64(n))
 | 
										stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.userPlanId, this.serverId, int64(float64(n)/cost), int64(n))
 | 
				
			||||||
				} else {
 | 
									} else {
 | 
				
			||||||
					stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.serverId, int64(n), int64(n))
 | 
										stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.userPlanId, this.serverId, int64(n), int64(n))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -16,6 +16,7 @@ type BaseClientConn struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	isBound    bool
 | 
						isBound    bool
 | 
				
			||||||
	userId     int64
 | 
						userId     int64
 | 
				
			||||||
 | 
						userPlanId int64
 | 
				
			||||||
	serverId   int64
 | 
						serverId   int64
 | 
				
			||||||
	remoteAddr string
 | 
						remoteAddr string
 | 
				
			||||||
	hasLimit   bool
 | 
						hasLimit   bool
 | 
				
			||||||
@@ -106,11 +107,31 @@ func (this *BaseClientConn) SetUserId(userId int64) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *BaseClientConn) SetUserPlanId(userPlanId int64) {
 | 
				
			||||||
 | 
						this.userPlanId = userPlanId
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 设置包装前连接
 | 
				
			||||||
 | 
						switch conn := this.rawConn.(type) {
 | 
				
			||||||
 | 
						case *tls.Conn:
 | 
				
			||||||
 | 
							nativeConn, ok := conn.NetConn().(ClientConnInterface)
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								nativeConn.SetUserPlanId(userPlanId)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						case *ClientConn:
 | 
				
			||||||
 | 
							conn.SetUserPlanId(userPlanId)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// UserId 获取当前连接所属服务的用户ID
 | 
					// UserId 获取当前连接所属服务的用户ID
 | 
				
			||||||
func (this *BaseClientConn) UserId() int64 {
 | 
					func (this *BaseClientConn) UserId() int64 {
 | 
				
			||||||
	return this.userId
 | 
						return this.userId
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// UserPlanId 用户套餐ID
 | 
				
			||||||
 | 
					func (this *BaseClientConn) UserPlanId() int64 {
 | 
				
			||||||
 | 
						return this.userPlanId
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// RawIP 原本IP
 | 
					// RawIP 原本IP
 | 
				
			||||||
func (this *BaseClientConn) RawIP() string {
 | 
					func (this *BaseClientConn) RawIP() string {
 | 
				
			||||||
	if len(this.rawIP) > 0 {
 | 
						if len(this.rawIP) > 0 {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -18,9 +18,12 @@ type ClientConnInterface interface {
 | 
				
			|||||||
	// SetServerId 设置服务ID
 | 
						// SetServerId 设置服务ID
 | 
				
			||||||
	SetServerId(serverId int64) (goNext bool)
 | 
						SetServerId(serverId int64) (goNext bool)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// SetUserId 设置所属服务的用户ID
 | 
						// SetUserId 设置所属网站的用户ID
 | 
				
			||||||
	SetUserId(userId int64)
 | 
						SetUserId(userId int64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// SetUserPlanId 设置
 | 
				
			||||||
 | 
						SetUserPlanId(userPlanId int64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// UserId 获取当前连接所属服务的用户ID
 | 
						// UserId 获取当前连接所属服务的用户ID
 | 
				
			||||||
	UserId() int64
 | 
						UserId() int64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -198,7 +198,7 @@ func (this *HTTPRequest) Do() {
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 流量限制
 | 
							// 流量限制
 | 
				
			||||||
		if this.ReqServer.TrafficLimit != nil && this.ReqServer.TrafficLimit.IsOn && !this.ReqServer.TrafficLimit.IsEmpty() && this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() {
 | 
							if this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() {
 | 
				
			||||||
			this.doTrafficLimit()
 | 
								this.doTrafficLimit()
 | 
				
			||||||
			this.doEnd()
 | 
								this.doEnd()
 | 
				
			||||||
			return
 | 
								return
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,15 +8,17 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// 流量限制
 | 
					// 流量限制
 | 
				
			||||||
func (this *HTTPRequest) doTrafficLimit() {
 | 
					func (this *HTTPRequest) doTrafficLimit() {
 | 
				
			||||||
	var config = this.ReqServer.TrafficLimit
 | 
						this.tags = append(this.tags, "trafficLimit")
 | 
				
			||||||
 | 
					 | 
				
			||||||
	this.tags = append(this.tags, "bandwidth")
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var statusCode = 509
 | 
						var statusCode = 509
 | 
				
			||||||
 | 
						this.writer.statusCode = statusCode
 | 
				
			||||||
	this.ProcessResponseHeaders(this.writer.Header(), statusCode)
 | 
						this.ProcessResponseHeaders(this.writer.Header(), statusCode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.writer.Header().Set("Content-Type", "text/html; charset=utf-8")
 | 
				
			||||||
	this.writer.WriteHeader(statusCode)
 | 
						this.writer.WriteHeader(statusCode)
 | 
				
			||||||
	if len(config.NoticePageBody) != 0 {
 | 
					
 | 
				
			||||||
 | 
						var config = this.ReqServer.TrafficLimit
 | 
				
			||||||
 | 
						if config != nil && len(config.NoticePageBody) != 0 {
 | 
				
			||||||
		_, _ = this.writer.WriteString(this.Format(config.NoticePageBody))
 | 
							_, _ = this.writer.WriteString(this.Format(config.NoticePageBody))
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultTrafficLimitNoticePageBody))
 | 
							_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultTrafficLimitNoticePageBody))
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -177,6 +177,12 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
 | 
				
			|||||||
					return
 | 
										return
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
				clientConn.SetUserId(server.UserId)
 | 
									clientConn.SetUserId(server.UserId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
									var userPlanId int64
 | 
				
			||||||
 | 
									if server.UserPlan != nil && server.UserPlan.Id > 0 {
 | 
				
			||||||
 | 
										userPlanId = server.UserPlan.Id
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									clientConn.SetUserPlanId(userPlanId)
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -80,6 +80,12 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
 | 
				
			|||||||
			return nil
 | 
								return nil
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		clientConn.SetUserId(server.UserId)
 | 
							clientConn.SetUserId(server.UserId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							var userPlanId int64
 | 
				
			||||||
 | 
							if server.UserPlan != nil && server.UserPlan.Id > 0 {
 | 
				
			||||||
 | 
								userPlanId = server.UserPlan.Id
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							clientConn.SetUserPlanId(userPlanId)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		tlsConn, ok := conn.(*tls.Conn)
 | 
							tlsConn, ok := conn.(*tls.Conn)
 | 
				
			||||||
		if ok {
 | 
							if ok {
 | 
				
			||||||
@@ -92,6 +98,12 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
 | 
				
			|||||||
						return nil
 | 
											return nil
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
					clientConn.SetUserId(server.UserId)
 | 
										clientConn.SetUserId(server.UserId)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										var userPlanId int64
 | 
				
			||||||
 | 
										if server.UserPlan != nil && server.UserPlan.Id > 0 {
 | 
				
			||||||
 | 
											userPlanId = server.UserPlan.Id
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										clientConn.SetUserPlanId(userPlanId)
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -404,7 +404,11 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener
 | 
				
			|||||||
					stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
 | 
										stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
					// 带宽
 | 
										// 带宽
 | 
				
			||||||
					stats.SharedBandwidthStatManager.AddBandwidth(server.UserId, server.Id, int64(n), int64(n))
 | 
										var userPlanId int64
 | 
				
			||||||
 | 
										if server.UserPlan != nil && server.UserPlan.Id > 0 {
 | 
				
			||||||
 | 
											userPlanId = server.UserPlan.Id
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
										stats.SharedBandwidthStatManager.AddBandwidth(server.UserId, userPlanId, server.Id, int64(n), int64(n))
 | 
				
			||||||
				}
 | 
									}
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
			if err != nil {
 | 
								if err != nil {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -62,6 +62,7 @@ type BandwidthStat struct {
 | 
				
			|||||||
	CountRequests       int64 `json:"countRequests"`
 | 
						CountRequests       int64 `json:"countRequests"`
 | 
				
			||||||
	CountCachedRequests int64 `json:"countCachedRequests"`
 | 
						CountCachedRequests int64 `json:"countCachedRequests"`
 | 
				
			||||||
	CountAttackRequests int64 `json:"countAttackRequests"`
 | 
						CountAttackRequests int64 `json:"countAttackRequests"`
 | 
				
			||||||
 | 
						UserPlanId          int64 `json:"userPlanId"`
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// BandwidthStatManager 服务带宽统计
 | 
					// BandwidthStatManager 服务带宽统计
 | 
				
			||||||
@@ -153,6 +154,7 @@ func (this *BandwidthStatManager) Loop() error {
 | 
				
			|||||||
				CountRequests:       stat.CountRequests,
 | 
									CountRequests:       stat.CountRequests,
 | 
				
			||||||
				CountCachedRequests: stat.CountCachedRequests,
 | 
									CountCachedRequests: stat.CountCachedRequests,
 | 
				
			||||||
				CountAttackRequests: stat.CountAttackRequests,
 | 
									CountAttackRequests: stat.CountAttackRequests,
 | 
				
			||||||
 | 
									UserPlanId:          stat.UserPlanId,
 | 
				
			||||||
				NodeRegionId:        regionId,
 | 
									NodeRegionId:        regionId,
 | 
				
			||||||
			})
 | 
								})
 | 
				
			||||||
			delete(this.m, key)
 | 
								delete(this.m, key)
 | 
				
			||||||
@@ -178,7 +180,7 @@ func (this *BandwidthStatManager) Loop() error {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// AddBandwidth 添加带宽数据
 | 
					// AddBandwidth 添加带宽数据
 | 
				
			||||||
func (this *BandwidthStatManager) AddBandwidth(userId int64, serverId int64, peekBytes int64, totalBytes int64) {
 | 
					func (this *BandwidthStatManager) AddBandwidth(userId int64, userPlanId int64, serverId int64, peekBytes int64, totalBytes int64) {
 | 
				
			||||||
	if serverId <= 0 || (peekBytes == 0 && totalBytes == 0) {
 | 
						if serverId <= 0 || (peekBytes == 0 && totalBytes == 0) {
 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -217,6 +219,7 @@ func (this *BandwidthStatManager) AddBandwidth(userId int64, serverId int64, pee
 | 
				
			|||||||
			Day:              day,
 | 
								Day:              day,
 | 
				
			||||||
			TimeAt:           timeAt,
 | 
								TimeAt:           timeAt,
 | 
				
			||||||
			UserId:           userId,
 | 
								UserId:           userId,
 | 
				
			||||||
 | 
								UserPlanId:       userPlanId,
 | 
				
			||||||
			ServerId:         serverId,
 | 
								ServerId:         serverId,
 | 
				
			||||||
			CurrentBytes:     peekBytes,
 | 
								CurrentBytes:     peekBytes,
 | 
				
			||||||
			MaxBytes:         peekBytes,
 | 
								MaxBytes:         peekBytes,
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,22 +12,22 @@ import (
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func TestBandwidthStatManager_Add(t *testing.T) {
 | 
					func TestBandwidthStatManager_Add(t *testing.T) {
 | 
				
			||||||
	var manager = stats.NewBandwidthStatManager()
 | 
						var manager = stats.NewBandwidthStatManager()
 | 
				
			||||||
	manager.AddBandwidth(1, 1, 10, 10)
 | 
						manager.AddBandwidth(1, 0, 1, 10, 10)
 | 
				
			||||||
	manager.AddBandwidth(1, 1, 10, 10)
 | 
						manager.AddBandwidth(1, 0, 1, 10, 10)
 | 
				
			||||||
	manager.AddBandwidth(1, 1, 10, 10)
 | 
						manager.AddBandwidth(1, 0, 1, 10, 10)
 | 
				
			||||||
	time.Sleep(1 * time.Second)
 | 
						time.Sleep(1 * time.Second)
 | 
				
			||||||
	manager.AddBandwidth(1, 1, 85, 85)
 | 
						manager.AddBandwidth(1, 0, 1, 85, 85)
 | 
				
			||||||
	time.Sleep(1 * time.Second)
 | 
						time.Sleep(1 * time.Second)
 | 
				
			||||||
	manager.AddBandwidth(1, 1, 25, 25)
 | 
						manager.AddBandwidth(1, 0, 1, 25, 25)
 | 
				
			||||||
	manager.AddBandwidth(1, 1, 75, 75)
 | 
						manager.AddBandwidth(1, 0, 1, 75, 75)
 | 
				
			||||||
	manager.Inspect()
 | 
						manager.Inspect()
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestBandwidthStatManager_Loop(t *testing.T) {
 | 
					func TestBandwidthStatManager_Loop(t *testing.T) {
 | 
				
			||||||
	var manager = stats.NewBandwidthStatManager()
 | 
						var manager = stats.NewBandwidthStatManager()
 | 
				
			||||||
	manager.AddBandwidth(1, 1, 10, 10)
 | 
						manager.AddBandwidth(1, 0, 1, 10, 10)
 | 
				
			||||||
	manager.AddBandwidth(1, 1, 10, 10)
 | 
						manager.AddBandwidth(1, 0, 1, 10, 10)
 | 
				
			||||||
	manager.AddBandwidth(1, 1, 10, 10)
 | 
						manager.AddBandwidth(1, 0, 1, 10, 10)
 | 
				
			||||||
	err := manager.Loop()
 | 
						err := manager.Loop()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
@@ -40,7 +40,7 @@ func BenchmarkBandwidthStatManager_Add(b *testing.B) {
 | 
				
			|||||||
		var i int
 | 
							var i int
 | 
				
			||||||
		for pb.Next() {
 | 
							for pb.Next() {
 | 
				
			||||||
			i++
 | 
								i++
 | 
				
			||||||
			manager.AddBandwidth(1, int64(i%100), 10, 10)
 | 
								manager.AddBandwidth(1, 0, int64(i%100), 10, 10)
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	})
 | 
						})
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user