TCP负载均衡实现流量限制,达到限制后,关闭连接

This commit is contained in:
GoEdgeLab
2022-06-17 21:49:15 +08:00
parent 8c06d648bf
commit 3dc540d8a8
3 changed files with 55 additions and 17 deletions

View File

@@ -0,0 +1,7 @@
// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package nodes
type LingerConn interface {
SetLinger(sec int) error
}

View File

@@ -210,6 +210,7 @@ func (this *HTTPCacheTaskManager) processKey(key *pb.HTTPCacheTaskKey) error {
} }
// TODO 增加失败重试 // TODO 增加失败重试
// TODO 使用并发操作
func (this *HTTPCacheTaskManager) fetchKey(key *pb.HTTPCacheTaskKey) error { func (this *HTTPCacheTaskManager) fetchKey(key *pb.HTTPCacheTaskKey) error {
var fullKey = key.Key var fullKey = key.Key
if !this.protocolReg.MatchString(fullKey) { if !this.protocolReg.MatchString(fullKey) {

View File

@@ -21,7 +21,7 @@ type TCPListener struct {
} }
func (this *TCPListener) Serve() error { func (this *TCPListener) Serve() error {
listener := this.Listener var listener = this.Listener
if this.Group.IsTLS() { if this.Group.IsTLS() {
listener = tls.NewListener(listener, this.buildTLSConfig()) listener = tls.NewListener(listener, this.buildTLSConfig())
} }
@@ -52,14 +52,29 @@ func (this *TCPListener) Reload(group *serverconfigs.ServerAddressGroup) {
} }
func (this *TCPListener) handleConn(conn net.Conn) error { func (this *TCPListener) handleConn(conn net.Conn) error {
firstServer := this.Group.FirstServer() var server = this.Group.FirstServer()
if firstServer == nil { if server == nil {
return errors.New("no server available") return errors.New("no server available")
} }
if firstServer.ReverseProxy == nil { if server.ReverseProxy == nil {
return errors.New("no ReverseProxy configured for the server") return errors.New("no ReverseProxy configured for the server")
} }
// 是否已达到流量限制
if this.reachedTrafficLimit() {
// 关闭连接
tcpConn, ok := conn.(LingerConn)
if ok {
_ = tcpConn.SetLinger(0)
}
_ = conn.Close()
// TODO 使用系统防火墙drop当前端口的数据包一段时间1分钟
// 不能使用阻止IP的方法因为边缘节点只上有可能还有别的代理服务
return nil
}
// 记录域名排行 // 记录域名排行
tlsConn, ok := conn.(*tls.Conn) tlsConn, ok := conn.(*tls.Conn)
var recordStat = false var recordStat = false
@@ -67,17 +82,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
var serverName = tlsConn.ConnectionState().ServerName var serverName = tlsConn.ConnectionState().ServerName
if len(serverName) > 0 { if len(serverName) > 0 {
// 统计 // 统计
stats.SharedTrafficStatManager.Add(firstServer.Id, serverName, 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId()) stats.SharedTrafficStatManager.Add(server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
recordStat = true recordStat = true
} }
} }
// 统计 // 统计
if !recordStat { if !recordStat {
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId()) stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
} }
originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String()) originConn, err := this.connectOrigin(server.Id, server.ReverseProxy, conn.RemoteAddr().String())
if err != nil { if err != nil {
return err return err
} }
@@ -88,17 +103,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
} }
// PROXY Protocol // PROXY Protocol
if firstServer.ReverseProxy != nil && if server.ReverseProxy != nil &&
firstServer.ReverseProxy.ProxyProtocol != nil && server.ReverseProxy.ProxyProtocol != nil &&
firstServer.ReverseProxy.ProxyProtocol.IsOn && server.ReverseProxy.ProxyProtocol.IsOn &&
(firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || firstServer.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) { (server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion1 || server.ReverseProxy.ProxyProtocol.Version == serverconfigs.ProxyProtocolVersion2) {
var remoteAddr = conn.RemoteAddr() var remoteAddr = conn.RemoteAddr()
var transportProtocol = proxyproto.TCPv4 var transportProtocol = proxyproto.TCPv4
if strings.Contains(remoteAddr.String(), "[") { if strings.Contains(remoteAddr.String(), "[") {
transportProtocol = proxyproto.TCPv6 transportProtocol = proxyproto.TCPv6
} }
header := proxyproto.Header{ var header = proxyproto.Header{
Version: byte(firstServer.ReverseProxy.ProxyProtocol.Version), Version: byte(server.ReverseProxy.ProxyProtocol.Version),
Command: proxyproto.PROXY, Command: proxyproto.PROXY,
TransportProtocol: transportProtocol, TransportProtocol: transportProtocol,
SourceAddr: remoteAddr, SourceAddr: remoteAddr,
@@ -113,7 +128,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
// 从源站读取 // 从源站读取
goman.New(func() { goman.New(func() {
originBuffer := utils.BytePool16k.Get() var originBuffer = utils.BytePool16k.Get()
defer func() { defer func() {
utils.BytePool16k.Put(originBuffer) utils.BytePool16k.Put(originBuffer)
}() }()
@@ -127,8 +142,8 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
} }
// 记录流量 // 记录流量
if firstServer != nil { if server != nil {
stats.SharedTrafficStatManager.Add(firstServer.Id, "", int64(n), 0, 0, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId()) stats.SharedTrafficStatManager.Add(server.Id, "", int64(n), 0, 0, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
} }
} }
if err != nil { if err != nil {
@@ -139,11 +154,17 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
}) })
// 从客户端读取 // 从客户端读取
clientBuffer := utils.BytePool16k.Get() var clientBuffer = utils.BytePool16k.Get()
defer func() { defer func() {
utils.BytePool16k.Put(clientBuffer) utils.BytePool16k.Put(clientBuffer)
}() }()
for { for {
// 是否已达到流量限制
if this.reachedTrafficLimit() {
closer()
return nil
}
n, err := conn.Read(clientBuffer) n, err := conn.Read(clientBuffer)
if n > 0 { if n > 0 {
_, err = originConn.Write(clientBuffer[:n]) _, err = originConn.Write(clientBuffer[:n])
@@ -188,3 +209,12 @@ func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfi
err = errors.New("no origin can be used") err = errors.New("no origin can be used")
return return
} }
// 检查是否已经达到流量限制
func (this *TCPListener) reachedTrafficLimit() bool {
var server = this.Group.FirstServer()
if server == nil {
return true
}
return server.TrafficLimitStatus != nil && server.TrafficLimitStatus.IsValid()
}