From ca95ea5c3de10dfaab5e2aee205ffba996ccb53b Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Wed, 6 Sep 2023 16:34:11 +0800 Subject: [PATCH] =?UTF-8?q?=E9=87=8D=E6=96=B0=E5=AE=9E=E7=8E=B0=E5=A5=97?= =?UTF-8?q?=E9=A4=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/nodes/client_conn.go | 4 ++-- internal/nodes/client_conn_base.go | 21 +++++++++++++++++++ internal/nodes/client_conn_interface.go | 5 ++++- internal/nodes/http_request.go | 2 +- internal/nodes/http_request_traffic_limit.go | 10 +++++---- internal/nodes/listener_http.go | 6 ++++++ internal/nodes/listener_tcp.go | 12 +++++++++++ internal/nodes/listener_udp.go | 6 +++++- internal/stats/bandwidth_stat_manager.go | 5 ++++- internal/stats/bandwidth_stat_manager_test.go | 20 +++++++++--------- 10 files changed, 71 insertions(+), 20 deletions(-) diff --git a/internal/nodes/client_conn.go b/internal/nodes/client_conn.go index e6d88ec..95a3154 100644 --- a/internal/nodes/client_conn.go +++ b/internal/nodes/client_conn.go @@ -198,9 +198,9 @@ func (this *ClientConn) Write(b []byte) (n int, err error) { var cost = time.Since(before).Seconds() if cost > 1 { - stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.serverId, int64(float64(n)/cost), int64(n)) + stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.userPlanId, this.serverId, int64(float64(n)/cost), int64(n)) } else { - stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.serverId, int64(n), int64(n)) + stats.SharedBandwidthStatManager.AddBandwidth(this.userId, this.userPlanId, this.serverId, int64(n), int64(n)) } } } diff --git a/internal/nodes/client_conn_base.go b/internal/nodes/client_conn_base.go index 0dc85b4..0c9d33a 100644 --- a/internal/nodes/client_conn_base.go +++ b/internal/nodes/client_conn_base.go @@ -16,6 +16,7 @@ type BaseClientConn struct { isBound bool userId int64 + userPlanId int64 serverId int64 remoteAddr string hasLimit bool @@ -106,11 +107,31 @@ func (this *BaseClientConn) SetUserId(userId int64) { } } +func (this *BaseClientConn) SetUserPlanId(userPlanId int64) { + this.userPlanId = userPlanId + + // 设置包装前连接 + switch conn := this.rawConn.(type) { + case *tls.Conn: + nativeConn, ok := conn.NetConn().(ClientConnInterface) + if ok { + nativeConn.SetUserPlanId(userPlanId) + } + case *ClientConn: + conn.SetUserPlanId(userPlanId) + } +} + // UserId 获取当前连接所属服务的用户ID func (this *BaseClientConn) UserId() int64 { return this.userId } +// UserPlanId 用户套餐ID +func (this *BaseClientConn) UserPlanId() int64 { + return this.userPlanId +} + // RawIP 原本IP func (this *BaseClientConn) RawIP() string { if len(this.rawIP) > 0 { diff --git a/internal/nodes/client_conn_interface.go b/internal/nodes/client_conn_interface.go index 651f4b0..3554c40 100644 --- a/internal/nodes/client_conn_interface.go +++ b/internal/nodes/client_conn_interface.go @@ -18,9 +18,12 @@ type ClientConnInterface interface { // SetServerId 设置服务ID SetServerId(serverId int64) (goNext bool) - // SetUserId 设置所属服务的用户ID + // SetUserId 设置所属网站的用户ID SetUserId(userId int64) + // SetUserPlanId 设置 + SetUserPlanId(userPlanId int64) + // UserId 获取当前连接所属服务的用户ID UserId() int64 diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 4ceb41d..650e191 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -198,7 +198,7 @@ func (this *HTTPRequest) Do() { } // 流量限制 - if this.ReqServer.TrafficLimit != nil && this.ReqServer.TrafficLimit.IsOn && !this.ReqServer.TrafficLimit.IsEmpty() && this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() { + if this.ReqServer.TrafficLimitStatus != nil && this.ReqServer.TrafficLimitStatus.IsValid() { this.doTrafficLimit() this.doEnd() return diff --git a/internal/nodes/http_request_traffic_limit.go b/internal/nodes/http_request_traffic_limit.go index 7a5e90b..95fb15b 100644 --- a/internal/nodes/http_request_traffic_limit.go +++ b/internal/nodes/http_request_traffic_limit.go @@ -8,15 +8,17 @@ import ( // 流量限制 func (this *HTTPRequest) doTrafficLimit() { - var config = this.ReqServer.TrafficLimit - - this.tags = append(this.tags, "bandwidth") + this.tags = append(this.tags, "trafficLimit") var statusCode = 509 + this.writer.statusCode = statusCode this.ProcessResponseHeaders(this.writer.Header(), statusCode) + this.writer.Header().Set("Content-Type", "text/html; charset=utf-8") this.writer.WriteHeader(statusCode) - if len(config.NoticePageBody) != 0 { + + var config = this.ReqServer.TrafficLimit + if config != nil && len(config.NoticePageBody) != 0 { _, _ = this.writer.WriteString(this.Format(config.NoticePageBody)) } else { _, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultTrafficLimitNoticePageBody)) diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index e22f1eb..3a8f366 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -177,6 +177,12 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http. return } clientConn.SetUserId(server.UserId) + + var userPlanId int64 + if server.UserPlan != nil && server.UserPlan.Id > 0 { + userPlanId = server.UserPlan.Id + } + clientConn.SetUserPlanId(userPlanId) } } } diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index 306a14a..7ea7dee 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -80,6 +80,12 @@ func (this *TCPListener) handleConn(conn net.Conn) error { return nil } clientConn.SetUserId(server.UserId) + + var userPlanId int64 + if server.UserPlan != nil && server.UserPlan.Id > 0 { + userPlanId = server.UserPlan.Id + } + clientConn.SetUserPlanId(userPlanId) } else { tlsConn, ok := conn.(*tls.Conn) if ok { @@ -92,6 +98,12 @@ func (this *TCPListener) handleConn(conn net.Conn) error { return nil } clientConn.SetUserId(server.UserId) + + var userPlanId int64 + if server.UserPlan != nil && server.UserPlan.Id > 0 { + userPlanId = server.UserPlan.Id + } + clientConn.SetUserPlanId(userPlanId) } } } diff --git a/internal/nodes/listener_udp.go b/internal/nodes/listener_udp.go index 3a8a408..d293247 100644 --- a/internal/nodes/listener_udp.go +++ b/internal/nodes/listener_udp.go @@ -404,7 +404,11 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) // 带宽 - stats.SharedBandwidthStatManager.AddBandwidth(server.UserId, server.Id, int64(n), int64(n)) + var userPlanId int64 + if server.UserPlan != nil && server.UserPlan.Id > 0 { + userPlanId = server.UserPlan.Id + } + stats.SharedBandwidthStatManager.AddBandwidth(server.UserId, userPlanId, server.Id, int64(n), int64(n)) } } if err != nil { diff --git a/internal/stats/bandwidth_stat_manager.go b/internal/stats/bandwidth_stat_manager.go index b385963..e4bd7e8 100644 --- a/internal/stats/bandwidth_stat_manager.go +++ b/internal/stats/bandwidth_stat_manager.go @@ -62,6 +62,7 @@ type BandwidthStat struct { CountRequests int64 `json:"countRequests"` CountCachedRequests int64 `json:"countCachedRequests"` CountAttackRequests int64 `json:"countAttackRequests"` + UserPlanId int64 `json:"userPlanId"` } // BandwidthStatManager 服务带宽统计 @@ -153,6 +154,7 @@ func (this *BandwidthStatManager) Loop() error { CountRequests: stat.CountRequests, CountCachedRequests: stat.CountCachedRequests, CountAttackRequests: stat.CountAttackRequests, + UserPlanId: stat.UserPlanId, NodeRegionId: regionId, }) delete(this.m, key) @@ -178,7 +180,7 @@ func (this *BandwidthStatManager) Loop() error { } // AddBandwidth 添加带宽数据 -func (this *BandwidthStatManager) AddBandwidth(userId int64, serverId int64, peekBytes int64, totalBytes int64) { +func (this *BandwidthStatManager) AddBandwidth(userId int64, userPlanId int64, serverId int64, peekBytes int64, totalBytes int64) { if serverId <= 0 || (peekBytes == 0 && totalBytes == 0) { return } @@ -217,6 +219,7 @@ func (this *BandwidthStatManager) AddBandwidth(userId int64, serverId int64, pee Day: day, TimeAt: timeAt, UserId: userId, + UserPlanId: userPlanId, ServerId: serverId, CurrentBytes: peekBytes, MaxBytes: peekBytes, diff --git a/internal/stats/bandwidth_stat_manager_test.go b/internal/stats/bandwidth_stat_manager_test.go index 665d6c0..b81efe9 100644 --- a/internal/stats/bandwidth_stat_manager_test.go +++ b/internal/stats/bandwidth_stat_manager_test.go @@ -12,22 +12,22 @@ import ( func TestBandwidthStatManager_Add(t *testing.T) { var manager = stats.NewBandwidthStatManager() - manager.AddBandwidth(1, 1, 10, 10) - manager.AddBandwidth(1, 1, 10, 10) - manager.AddBandwidth(1, 1, 10, 10) + manager.AddBandwidth(1, 0, 1, 10, 10) + manager.AddBandwidth(1, 0, 1, 10, 10) + manager.AddBandwidth(1, 0, 1, 10, 10) time.Sleep(1 * time.Second) - manager.AddBandwidth(1, 1, 85, 85) + manager.AddBandwidth(1, 0, 1, 85, 85) time.Sleep(1 * time.Second) - manager.AddBandwidth(1, 1, 25, 25) - manager.AddBandwidth(1, 1, 75, 75) + manager.AddBandwidth(1, 0, 1, 25, 25) + manager.AddBandwidth(1, 0, 1, 75, 75) manager.Inspect() } func TestBandwidthStatManager_Loop(t *testing.T) { var manager = stats.NewBandwidthStatManager() - manager.AddBandwidth(1, 1, 10, 10) - manager.AddBandwidth(1, 1, 10, 10) - manager.AddBandwidth(1, 1, 10, 10) + manager.AddBandwidth(1, 0, 1, 10, 10) + manager.AddBandwidth(1, 0, 1, 10, 10) + manager.AddBandwidth(1, 0, 1, 10, 10) err := manager.Loop() if err != nil { t.Fatal(err) @@ -40,7 +40,7 @@ func BenchmarkBandwidthStatManager_Add(b *testing.B) { var i int for pb.Next() { i++ - manager.AddBandwidth(1, int64(i%100), 10, 10) + manager.AddBandwidth(1, 0, int64(i%100), 10, 10) } }) }