TCP负载均衡实现流量限制,达到限制后,关闭连接

This commit is contained in:
GoEdgeLab
2022-06-17 21:49:15 +08:00
parent 8c06d648bf
commit 3dc540d8a8
3 changed files with 55 additions and 17 deletions

View File

@@ -0,0 +1,7 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
type LingerConn interface {
SetLinger(sec int) error
}

View File

@@ -210,6 +210,7 @@ func (this *HTTPCacheTaskManager) processKey(key *pb.HTTPCacheTaskKey) error {
}
// TODO 增加失败重试
// TODO 使用并发操作
func (this *HTTPCacheTaskManager) fetchKey(key *pb.HTTPCacheTaskKey) error {
var fullKey = key.Key
if !this.protocolReg.MatchString(fullKey) {

View File

@@ -21,7 +21,7 @@ type TCPListener struct {
}
func (this *TCPListener) Serve() error {
listener := this.Listener
var listener = this.Listener
if this.Group.IsTLS() {
listener = tls.NewListener(listener, this.buildTLSConfig())
}
@@ -52,14 +52,29 @@ func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
}
func (this *TCPListener) handleConn(conn net.Conn) error {
firstServer := this.Group.FirstServer()
if firstServer == nil {
var server = this.Group.FirstServer()
if server == nil {
return errors.New("no server available")
}
if firstServer.ReverseProxy == nil {
if server.ReverseProxy == nil {
return errors.New("no ReverseProxy configured for the server")
}
// 是否已达到流量限制
if this.reachedTrafficLimit() {
// 关闭连接
tcpConn, ok := conn.(LingerConn)
if ok {
_ = tcpConn.SetLinger(0)
}
_ = conn.Close()
// TODO 使用系统防火墙drop当前端口的数据包一段时间1分钟
// 不能使用阻止IP的方法因为边缘节点只上有可能还有别的代理服务
return nil
}
// 记录域名排行
tlsConn, ok := conn.(*tls.Conn)
var recordStat = false
@@ -67,17 +82,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
var serverName = tlsConn.ConnectionState().ServerName
if len(serverName) > 0 {
// 统计
stats.SharedTrafficStatManager.Add(firstServer.Id, serverName, 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
stats.SharedTrafficStatManager.Add(server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
recordStat = true
}
}
// 统计
if !recordStat {
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String())
originConn, err := this.connectOrigin(server.Id, server.ReverseProxy, conn.RemoteAddr().String())
if err != nil {
return err
}
@@ -88,17 +103,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
}
// PROXY Protocol
if firstServer.ReverseProxy != nil &&
firstServer.ReverseProxy.ProxyProtocol != nil &&
firstServer.ReverseProxy.ProxyProtocol.IsOn &&
(firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
if server.ReverseProxy != nil &&
server.ReverseProxy.ProxyProtocol != nil &&
server.ReverseProxy.ProxyProtocol.IsOn &&
(server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
var remoteAddr = conn.RemoteAddr()
var transportProtocol = proxyproto.TCPv4
if strings.Contains(remoteAddr.String(), "[") {
transportProtocol = proxyproto.TCPv6
}
header := proxyproto.Header{
Version: byte(firstServer.ReverseProxy.ProxyProtocol.Version),
var header = proxyproto.Header{
Version: byte(server.ReverseProxy.ProxyProtocol.Version),
Command: proxyproto.PROXY,
TransportProtocol: transportProtocol,
SourceAddr: remoteAddr,
@@ -113,7 +128,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
// 从源站读取
goman.New(func() {
originBuffer := utils.BytePool16k.Get()
var originBuffer = utils.BytePool16k.Get()
defer func() {
utils.BytePool16k.Put(originBuffer)
}()
@@ -127,8 +142,8 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
}
// 记录流量
if firstServer != nil {
stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
if server != nil {
stats.SharedTrafficStatManager.Add(server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
}
}
if err != nil {
@@ -139,11 +154,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
})
// 从客户端读取
clientBuffer := utils.BytePool16k.Get()
var clientBuffer = utils.BytePool16k.Get()
defer func() {
utils.BytePool16k.Put(clientBuffer)
}()
for {
// 是否已达到流量限制
if this.reachedTrafficLimit() {
closer()
return nil
}
n, err := conn.Read(clientBuffer)
if n > 0 {
_, err = originConn.Write(clientBuffer[:n])
@@ -188,3 +209,12 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
err = errors.New("no origin can be used")
return
}
// 检查是否已经达到流量限制
func (this *TCPListener) reachedTrafficLimit() bool {
var server = this.Group.FirstServer()
if server == nil {
return true
}
return server.TrafficLimitStatus != nil && server.TrafficLimitStatus.IsValid()
}