diff --git a/internal/nodes/conn_linger.go b/internal/nodes/conn_linger.go new file mode 100644 index 0000000..ccf59fe --- /dev/null +++ b/internal/nodes/conn_linger.go @@ -0,0 +1,7 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package nodes + +type LingerConn interface { + SetLinger(sec int) error +} diff --git a/internal/nodes/http_cache_task_manager.go b/internal/nodes/http_cache_task_manager.go index 1b3c084..3b0d837 100644 --- a/internal/nodes/http_cache_task_manager.go +++ b/internal/nodes/http_cache_task_manager.go @@ -210,6 +210,7 @@ func (this *HTTPCacheTaskManager) processKey(key *pb.HTTPCacheTaskKey) error { } // TODO 增加失败重试 +// TODO 使用并发操作 func (this *HTTPCacheTaskManager) fetchKey(key *pb.HTTPCacheTaskKey) error { var fullKey = key.Key if !this.protocolReg.MatchString(fullKey) { diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index 089060b..8ed6a89 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -21,7 +21,7 @@ type TCPListener struct { } func (this *TCPListener) Serve() error { - listener := this.Listener + var listener = this.Listener if this.Group.IsTLS() { listener = tls.NewListener(listener, this.buildTLSConfig()) } @@ -52,14 +52,29 @@ func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) { } func (this *TCPListener) handleConn(conn net.Conn) error { - firstServer := this.Group.FirstServer() - if firstServer == nil { + var server = this.Group.FirstServer() + if server == nil { return errors.New("no server available") } - if firstServer.ReverseProxy == nil { + if server.ReverseProxy == nil { return errors.New("no ReverseProxy configured for the server") } + // 是否已达到流量限制 + if this.reachedTrafficLimit() { + // 关闭连接 + tcpConn, ok := conn.(LingerConn) + if ok { + _ = tcpConn.SetLinger(0) + } + _ = conn.Close() + + // TODO 使用系统防火墙drop当前端口的数据包一段时间(1分钟) + // 不能使用阻止IP的方法,因为边缘节点只上有可能还有别的代理服务 + + return nil + } + // 记录域名排行 tlsConn, ok := conn.(*tls.Conn) var recordStat = false @@ -67,17 +82,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error { var serverName = tlsConn.ConnectionState().ServerName if len(serverName) > 0 { // 统计 - stats.SharedTrafficStatManager.Add(firstServer.Id, serverName, 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId()) + stats.SharedTrafficStatManager.Add(server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) recordStat = true } } // 统计 if !recordStat { - stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId()) + stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) } - originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String()) + originConn, err := this.connectOrigin(server.Id, server.ReverseProxy, conn.RemoteAddr().String()) if err != nil { return err } @@ -88,17 +103,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error { } // PROXY Protocol - if firstServer.ReverseProxy != nil && - firstServer.ReverseProxy.ProxyProtocol != nil && - firstServer.ReverseProxy.ProxyProtocol.IsOn && - (firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) { + if server.ReverseProxy != nil && + server.ReverseProxy.ProxyProtocol != nil && + server.ReverseProxy.ProxyProtocol.IsOn && + (server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) { var remoteAddr = conn.RemoteAddr() var transportProtocol = proxyproto.TCPv4 if strings.Contains(remoteAddr.String(), "[") { transportProtocol = proxyproto.TCPv6 } - header := proxyproto.Header{ - Version: byte(firstServer.ReverseProxy.ProxyProtocol.Version), + var header = proxyproto.Header{ + Version: byte(server.ReverseProxy.ProxyProtocol.Version), Command: proxyproto.PROXY, TransportProtocol: transportProtocol, SourceAddr: remoteAddr, @@ -113,7 +128,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error { // 从源站读取 goman.New(func() { - originBuffer := utils.BytePool16k.Get() + var originBuffer = utils.BytePool16k.Get() defer func() { utils.BytePool16k.Put(originBuffer) }() @@ -127,8 +142,8 @@ func (this *TCPListener) handleConn(conn net.Conn) error { } // 记录流量 - if firstServer != nil { - stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId()) + if server != nil { + stats.SharedTrafficStatManager.Add(server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) } } if err != nil { @@ -139,11 +154,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error { }) // 从客户端读取 - clientBuffer := utils.BytePool16k.Get() + var clientBuffer = utils.BytePool16k.Get() defer func() { utils.BytePool16k.Put(clientBuffer) }() for { + // 是否已达到流量限制 + if this.reachedTrafficLimit() { + closer() + return nil + } + n, err := conn.Read(clientBuffer) if n > 0 { _, err = originConn.Write(clientBuffer[:n]) @@ -188,3 +209,12 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi err = errors.New("no origin can be used") return } + +// 检查是否已经达到流量限制 +func (this *TCPListener) reachedTrafficLimit() bool { + var server = this.Group.FirstServer() + if server == nil { + return true + } + return server.TrafficLimitStatus != nil && server.TrafficLimitStatus.IsValid() +}