diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 1569221..7620b39 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -271,12 +271,12 @@ func (this *HTTPRequest) doEnd() { // TODO 增加是否开启开关 if this.Server != nil { if this.isCached { - stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1, 0, 0) + stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, this.writer.sentBodyBytes, 1, 1, 0, 0, this.Server.ShouldCheckTrafficLimit()) } else { if this.isAttack { - stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 1, this.writer.sentBodyBytes) + stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 1, this.writer.sentBodyBytes, this.Server.ShouldCheckTrafficLimit()) } else { - stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 0, 0) + stats.SharedTrafficStatManager.Add(this.Server.Id, this.Host, this.writer.sentBodyBytes, 0, 1, 0, 0, 0, this.Server.ShouldCheckTrafficLimit()) } } } diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index d7a3ed0..77a0214 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -65,14 +65,14 @@ func (this *TCPListener) handleConn(conn net.Conn) error { var serverName = tlsConn.ConnectionState().ServerName if len(serverName) > 0 { // 统计 - stats.SharedTrafficStatManager.Add(firstServer.Id, serverName, 0, 0, 1, 0, 0, 0) + stats.SharedTrafficStatManager.Add(firstServer.Id, serverName, 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit()) recordStat = true } } // 统计 if !recordStat { - stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0) + stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit()) } originConn, err := this.connectOrigin(firstServer.ReverseProxy, conn.RemoteAddr().String()) @@ -125,7 +125,9 @@ func (this *TCPListener) handleConn(conn net.Conn) error { } // 记录流量 - stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0) + if firstServer != nil { + stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0, firstServer.ShouldCheckTrafficLimit()) + } } if err != nil { closer() diff --git a/internal/nodes/listener_udp.go b/internal/nodes/listener_udp.go index 963e774..cf1c19e 100644 --- a/internal/nodes/listener_udp.go +++ b/internal/nodes/listener_udp.go @@ -64,7 +64,7 @@ func (this *UDPListener) Serve() error { remotelogs.Error("UDP_LISTENER", "unable to find a origin server") continue } - conn = NewUDPConn(firstServer.Id, addr, this.Listener, originConn.(*net.UDPConn)) + conn = NewUDPConn(firstServer, addr, this.Listener, originConn.(*net.UDPConn)) this.connLocker.Lock() this.connMap[addr.String()] = conn this.connLocker.Unlock() @@ -174,7 +174,7 @@ type UDPConn struct { isClosed bool } -func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn { +func NewUDPConn(server *serverconfigs.ServerConfig, addr net.Addr, proxyConn *net.UDPConn, serverConn *net.UDPConn) *UDPConn { conn := &UDPConn{ addr: addr, proxyConn: proxyConn, @@ -184,7 +184,9 @@ func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverCon } // 统计 - stats.SharedTrafficStatManager.Add(serverId, "", 0, 0, 1, 0, 0, 0) + if server != nil { + stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit()) + } go func() { buffer := bytePool32k.Get() @@ -203,7 +205,9 @@ func NewUDPConn(serverId int64, addr net.Addr, proxyConn *net.UDPConn, serverCon } // 记录流量 - stats.SharedTrafficStatManager.Add(serverId, "", int64(n), 0, 0, 0, 0, 0) + if server != nil { + stats.SharedTrafficStatManager.Add(server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit()) + } } if err != nil { conn.isOk = false diff --git a/internal/stats/traffic_stat_manager.go b/internal/stats/traffic_stat_manager.go index dfdedc2..53e389f 100644 --- a/internal/stats/traffic_stat_manager.go +++ b/internal/stats/traffic_stat_manager.go @@ -20,12 +20,13 @@ import ( var SharedTrafficStatManager = NewTrafficStatManager() type TrafficItem struct { - Bytes int64 - CachedBytes int64 - CountRequests int64 - CountCachedRequests int64 - CountAttackRequests int64 - AttackBytes int64 + Bytes int64 + CachedBytes int64 + CountRequests int64 + CountCachedRequests int64 + CountAttackRequests int64 + AttackBytes int64 + CheckingTrafficLimit bool } // TrafficStatManager 区域流量统计 @@ -86,7 +87,7 @@ func (this *TrafficStatManager) Start(configFunc func() *nodeconfigs.NodeConfig) } // Add 添加流量 -func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64) { +func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64, cachedBytes int64, countRequests int64, countCachedRequests int64, countAttacks int64, attackBytes int64, checkingTrafficLimit bool) { if bytes == 0 && countRequests == 0 { return } @@ -110,6 +111,7 @@ func (this *TrafficStatManager) Add(serverId int64, domain string, bytes int64, item.CountCachedRequests += countCachedRequests item.CountAttackRequests += countAttacks item.AttackBytes += attackBytes + item.CheckingTrafficLimit = checkingTrafficLimit // 单个域名流量 var domainKey = strconv.FormatInt(timestamp, 10) + "@" + strconv.FormatInt(serverId, 10) + "@" + domain @@ -160,15 +162,16 @@ func (this *TrafficStatManager) Upload() error { } pbServerStats = append(pbServerStats, &pb.ServerDailyStat{ - ServerId: serverId, - RegionId: config.RegionId, - Bytes: item.Bytes, - CachedBytes: item.CachedBytes, - CountRequests: item.CountRequests, - CountCachedRequests: item.CountCachedRequests, - CountAttackRequests: item.CountAttackRequests, - AttackBytes: item.AttackBytes, - CreatedAt: timestamp, + ServerId: serverId, + RegionId: config.RegionId, + Bytes: item.Bytes, + CachedBytes: item.CachedBytes, + CountRequests: item.CountRequests, + CountCachedRequests: item.CountCachedRequests, + CountAttackRequests: item.CountAttackRequests, + AttackBytes: item.AttackBytes, + CheckTrafficLimiting: item.CheckingTrafficLimit, + CreatedAt: timestamp, }) } if len(pbServerStats) == 0 {