diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 4e9825f..e00a204 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -85,6 +85,8 @@ type HTTPRequest struct { isAttack bool // 是否是攻击请求 requestBodyData []byte // 读取的Body内容 + isWebsocketResponse bool // 是否为Websocket响应(非请求) + // WAF相关 firewallPolicyId int64 firewallRuleGroupId int64 @@ -410,6 +412,8 @@ func (this *HTTPRequest) doEnd() { var countAttacks int64 = 0 var attackBytes int64 = 0 + var countWebsocketConnections int64 = 0 + if this.isCached { countCached = 1 cachedBytes = totalBytes @@ -421,8 +425,11 @@ func (this *HTTPRequest) doEnd() { attackBytes = totalBytes } } + if this.isWebsocketResponse { + countWebsocketConnections = 1 + } - stats.SharedTrafficStatManager.Add(this.ReqServer.UserId, this.ReqServer.Id, this.ReqHost, totalBytes, cachedBytes, 1, countCached, countAttacks, attackBytes, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId()) + stats.SharedTrafficStatManager.Add(this.ReqServer.UserId, this.ReqServer.Id, this.ReqHost, totalBytes, cachedBytes, 1, countCached, countAttacks, attackBytes, countWebsocketConnections, this.ReqServer.ShouldCheckTrafficLimit(), this.ReqServer.PlanId()) // 指标 if metrics.SharedManager.HasHTTPMetrics() { diff --git a/internal/nodes/http_request_websocket.go b/internal/nodes/http_request_websocket.go index cd40c59..6d58aa3 100644 --- a/internal/nodes/http_request_websocket.go +++ b/internal/nodes/http_request_websocket.go @@ -61,6 +61,9 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou } } + // 标记 + this.isWebsocketResponse = true + // 设置指定的来源域 if !this.web.Websocket.RequestSameOrigin && len(this.web.Websocket.RequestOrigin) > 0 { var newRequestOrigin = this.web.Websocket.RequestOrigin @@ -77,7 +80,6 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou } // 连接源站 - // TODO 增加N次错误重试,重试的时候需要尝试不同的源站 originConn, _, err := OriginConnect(this.origin, this.requestServerPort(), this.RawReq.RemoteAddr, requestHost) if err != nil { if isLastRetry { diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index 5db0de6..ddc90d2 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -135,14 +135,14 @@ func (this *TCPListener) handleConn(server *serverconfigs.ServerConfig, conn net serverName = tlsConn.ConnectionState().ServerName if len(serverName) > 0 { // 统计 - stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) + stats.SharedTrafficStatManager.Add(server.UserId, server.Id, serverName, 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) recordStat = true } } // 统计 if !recordStat { - stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) + stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) } originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String()) @@ -197,7 +197,7 @@ func (this *TCPListener) handleConn(server *serverconfigs.ServerConfig, conn net // 记录流量 if server != nil { - stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) + stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) } } if err != nil { diff --git a/internal/nodes/listener_udp.go b/internal/nodes/listener_udp.go index d293247..c413f05 100644 --- a/internal/nodes/listener_udp.go +++ b/internal/nodes/listener_udp.go @@ -370,7 +370,7 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener // 统计 if server != nil { - stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) + stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", 0, 0, 1, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) } // 处理ControlMessage @@ -401,7 +401,7 @@ func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyListener // 记录流量和带宽 if server != nil { // 流量 - stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) + stats.SharedTrafficStatManager.Add(server.UserId, server.Id, "", int64(n), 0, 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) // 带宽 var userPlanId int64 diff --git a/internal/stats/bandwidth_stat_manager.go b/internal/stats/bandwidth_stat_manager.go index e4bd7e8..f9e30a4 100644 --- a/internal/stats/bandwidth_stat_manager.go +++ b/internal/stats/bandwidth_stat_manager.go @@ -57,12 +57,13 @@ type BandwidthStat struct { MaxBytes int64 `json:"maxBytes"` TotalBytes int64 `json:"totalBytes"` - CachedBytes int64 `json:"cachedBytes"` - AttackBytes int64 `json:"attackBytes"` - CountRequests int64 `json:"countRequests"` - CountCachedRequests int64 `json:"countCachedRequests"` - CountAttackRequests int64 `json:"countAttackRequests"` - UserPlanId int64 `json:"userPlanId"` + CachedBytes int64 `json:"cachedBytes"` + AttackBytes int64 `json:"attackBytes"` + CountRequests int64 `json:"countRequests"` + CountCachedRequests int64 `json:"countCachedRequests"` + CountAttackRequests int64 `json:"countAttackRequests"` + CountWebsocketConnections int64 `json:"countWebsocketConnections"` + UserPlanId int64 `json:"userPlanId"` } // BandwidthStatManager 服务带宽统计 @@ -142,20 +143,21 @@ func (this *BandwidthStatManager) Loop() error { } pbStats = append(pbStats, &pb.ServerBandwidthStat{ - Id: 0, - UserId: stat.UserId, - ServerId: stat.ServerId, - Day: stat.Day, - TimeAt: stat.TimeAt, - Bytes: stat.MaxBytes / bandwidthTimestampDelim, - TotalBytes: stat.TotalBytes, - CachedBytes: stat.CachedBytes, - AttackBytes: stat.AttackBytes, - CountRequests: stat.CountRequests, - CountCachedRequests: stat.CountCachedRequests, - CountAttackRequests: stat.CountAttackRequests, - UserPlanId: stat.UserPlanId, - NodeRegionId: regionId, + Id: 0, + UserId: stat.UserId, + ServerId: stat.ServerId, + Day: stat.Day, + TimeAt: stat.TimeAt, + Bytes: stat.MaxBytes / bandwidthTimestampDelim, + TotalBytes: stat.TotalBytes, + CachedBytes: stat.CachedBytes, + AttackBytes: stat.AttackBytes, + CountRequests: stat.CountRequests, + CountCachedRequests: stat.CountCachedRequests, + CountAttackRequests: stat.CountAttackRequests, + CountWebsocketConnections: stat.CountWebsocketConnections, + UserPlanId: stat.UserPlanId, + NodeRegionId: regionId, }) delete(this.m, key) } @@ -231,7 +233,7 @@ func (this *BandwidthStatManager) AddBandwidth(userId int64, userPlanId int64, s } // AddTraffic 添加请求数据 -func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64) { +func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, countWebsocketConnections int64) { var now = fasttime.Now() var day = now.Ymd() var timeAt = now.Round5Hi() @@ -245,6 +247,7 @@ func (this *BandwidthStatManager) AddTraffic(serverId int64, cachedBytes int64, stat.CountCachedRequests += countCachedRequests stat.CountAttackRequests += countAttacks stat.AttackBytes += attackBytes + stat.CountWebsocketConnections += countWebsocketConnections } this.locker.Unlock() } diff --git a/internal/stats/bandwidth_stat_manager_test.go b/internal/stats/bandwidth_stat_manager_test.go index b81efe9..91067ed 100644 --- a/internal/stats/bandwidth_stat_manager_test.go +++ b/internal/stats/bandwidth_stat_manager_test.go @@ -53,19 +53,20 @@ func BenchmarkBandwidthStatManager_Slice(b *testing.B) { for j := 0; j < 100; j++ { var stat = &stats.BandwidthStat{} pbStats = append(pbStats, &pb.ServerBandwidthStat{ - Id: 0, - UserId: stat.UserId, - ServerId: stat.ServerId, - Day: stat.Day, - TimeAt: stat.TimeAt, - Bytes: stat.MaxBytes / 2, - TotalBytes: stat.TotalBytes, - CachedBytes: stat.CachedBytes, - AttackBytes: stat.AttackBytes, - CountRequests: stat.CountRequests, - CountCachedRequests: stat.CountCachedRequests, - CountAttackRequests: stat.CountAttackRequests, - NodeRegionId: 1, + Id: 0, + UserId: stat.UserId, + ServerId: stat.ServerId, + Day: stat.Day, + TimeAt: stat.TimeAt, + Bytes: stat.MaxBytes / 2, + TotalBytes: stat.TotalBytes, + CachedBytes: stat.CachedBytes, + AttackBytes: stat.AttackBytes, + CountRequests: stat.CountRequests, + CountCachedRequests: stat.CountCachedRequests, + CountAttackRequests: stat.CountAttackRequests, + CountWebsocketConnections: stat.CountWebsocketConnections, + NodeRegionId: 1, }) } _ = pbStats diff --git a/internal/stats/traffic_stat_manager.go b/internal/stats/traffic_stat_manager.go index 8de74fc..9e3caf1 100644 --- a/internal/stats/traffic_stat_manager.go +++ b/internal/stats/traffic_stat_manager.go @@ -106,13 +106,13 @@ func (this *TrafficStatManager) Start() { } // Add 添加流量 -func (this *TrafficStatManager) Add(userId int64, serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, checkingTrafficLimit bool, planId int64) { +func (this *TrafficStatManager) Add(userId int64, serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, countWebsocketConnections int64, checkingTrafficLimit bool, planId int64) { if serverId == 0 { return } // 添加到带宽 - SharedBandwidthStatManager.AddTraffic(serverId, cachedBytes, countRequests, countCachedRequests, countAttacks, attackBytes) + SharedBandwidthStatManager.AddTraffic(serverId, cachedBytes, countRequests, countCachedRequests, countAttacks, attackBytes, countWebsocketConnections) if bytes == 0 && countRequests == 0 { return diff --git a/internal/stats/traffic_stat_manager_test.go b/internal/stats/traffic_stat_manager_test.go index 214af4d..e92ba0a 100644 --- a/internal/stats/traffic_stat_manager_test.go +++ b/internal/stats/traffic_stat_manager_test.go @@ -11,7 +11,7 @@ import ( func TestTrafficStatManager_Add(t *testing.T) { manager := NewTrafficStatManager() for i := 0; i < 100; i++ { - manager.Add(1, 1, "goedge.cn", 1, 0, 0, 0, 0, 0, false, 0) + manager.Add(1, 1, "goedge.cn", 1, 0, 0, 0, 0, 0, 0, false, 0) } t.Log(manager.itemMap) } @@ -19,7 +19,7 @@ func TestTrafficStatManager_Add(t *testing.T) { func TestTrafficStatManager_Upload(t *testing.T) { manager := NewTrafficStatManager() for i := 0; i < 100; i++ { - manager.Add(1, 1, "goedge.cn"+types.String(rands.Int(0, 10)), 1, 0, 1, 0, 0, 0, false, 0) + manager.Add(1, 1, "goedge.cn"+types.String(rands.Int(0, 10)), 1, 0, 1, 0, 0, 0, 0, false, 0) } err := manager.Upload() if err != nil { @@ -36,7 +36,7 @@ func BenchmarkTrafficStatManager_Add(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - manager.Add(1, 1, "goedge.cn"+types.String(rand.Int63()%10), 1024, 1, 0, 0, 0, 0, false, 0) + manager.Add(1, 1, "goedge.cn"+types.String(rand.Int63()%10), 1024, 1, 0, 0, 0, 0, 0, false, 0) } }) }