优化IP黑名单检测

This commit is contained in:
GoEdgeLab
2023-04-05 09:25:33 +08:00
parent b9bfcee79c
commit 9ba3dc8172
7 changed files with 36 additions and 29 deletions

View File

@@ -8,8 +8,9 @@ import (
// DropTemporaryTo 使用本地防火墙临时拦截IP数据包 // DropTemporaryTo 使用本地防火墙临时拦截IP数据包
func DropTemporaryTo(ip string, expiresAt int64) { func DropTemporaryTo(ip string, expiresAt int64) {
if expiresAt <= 1 { // 如果为0则表示是长期有效
return if expiresAt <= 0 {
expiresAt = time.Now().Unix() + 3600
} }
var timeout = expiresAt - time.Now().Unix() var timeout = expiresAt - time.Now().Unix()

View File

@@ -4,7 +4,6 @@ package nodes
import ( import (
"crypto/tls" "crypto/tls"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeNode/internal/firewalls" "github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"net" "net"
@@ -51,7 +50,20 @@ func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerS
} }
// SetServerId 设置服务ID // SetServerId 设置服务ID
func (this *BaseClientConn) SetServerId(serverId int64) { func (this *BaseClientConn) SetServerId(serverId int64) (goNext bool) {
goNext = true
// 检查服务相关IP黑名单
if serverId > 0 && len(this.rawIP) > 0 {
// 是否在白名单中
ok, _, expiresAt := iplibrary.AllowIP(this.rawIP, serverId)
if !ok {
_ = this.rawConn.Close()
firewalls.DropTemporaryTo(this.rawIP, expiresAt)
return false
}
}
this.serverId = serverId this.serverId = serverId
// 设置包装前连接 // 设置包装前连接
@@ -65,19 +77,7 @@ func (this *BaseClientConn) SetServerId(serverId int64) {
conn.SetServerId(serverId) conn.SetServerId(serverId)
} }
// 检查服务相关IP黑名单 return true
if serverId > 0 && len(this.rawIP) > 0 {
var list = iplibrary.SharedServerListManager.FindBlackList(serverId, false)
if list != nil {
expiresAt, ok := list.ContainsExpires(configutils.IPString2Long(this.rawIP))
if ok {
_ = this.rawConn.Close()
if expiresAt > 0 {
firewalls.DropTemporaryTo(this.rawIP, expiresAt)
}
}
}
}
} }
// ServerId 读取当前连接绑定的服务ID // ServerId 读取当前连接绑定的服务ID

View File

@@ -16,7 +16,7 @@ type ClientConnInterface interface {
ServerId() int64 ServerId() int64
// SetServerId 设置服务ID // SetServerId 设置服务ID
SetServerId(serverId int64) SetServerId(serverId int64) (goNext bool)
// SetUserId 设置所属服务的用户ID // SetUserId 设置所属服务的用户ID
SetUserId(userId int64) SetUserId(userId int64)

View File

@@ -45,17 +45,14 @@ func (this *ClientListener) Accept() (net.Conn, error) {
canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0) canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0)
isInAllowList = inAllowList isInAllowList = inAllowList
if !canGoNext { if !canGoNext {
if expiresAt > 0 { firewalls.DropTemporaryTo(ip, expiresAt)
firewalls.DropTemporaryTo(ip, expiresAt)
}
} else { } else {
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) { if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) var ok = false
expiresAt, ok = waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
if ok { if ok {
canGoNext = false canGoNext = false
if expiresAt > 0 { firewalls.DropTemporaryTo(ip, expiresAt)
firewalls.DropTemporaryTo(ip, expiresAt)
}
} }
} }
} }

View File

@@ -179,7 +179,10 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
if requestConn != nil { if requestConn != nil {
clientConn, ok := requestConn.(ClientConnInterface) clientConn, ok := requestConn.(ClientConnInterface)
if ok { if ok {
clientConn.SetServerId(server.Id) var goNext = clientConn.SetServerId(server.Id)
if !goNext {
return
}
clientConn.SetUserId(server.UserId) clientConn.SetUserId(server.UserId)
} }
} }

View File

@@ -75,7 +75,10 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
// 绑定连接和服务 // 绑定连接和服务
clientConn, ok := conn.(ClientConnInterface) clientConn, ok := conn.(ClientConnInterface)
if ok { if ok {
clientConn.SetServerId(server.Id) var goNext = clientConn.SetServerId(server.Id)
if !goNext {
return nil
}
clientConn.SetUserId(server.UserId) clientConn.SetUserId(server.UserId)
} else { } else {
tlsConn, ok := conn.(*tls.Conn) tlsConn, ok := conn.(*tls.Conn)
@@ -84,7 +87,10 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
if internalConn != nil { if internalConn != nil {
clientConn, ok = internalConn.(ClientConnInterface) clientConn, ok = internalConn.(ClientConnInterface)
if ok { if ok {
clientConn.SetServerId(server.Id) var goNext = clientConn.SetServerId(server.Id)
if !goNext {
return nil
}
clientConn.SetUserId(server.UserId) clientConn.SetUserId(server.UserId)
} }
} }

View File

@@ -125,7 +125,7 @@ func (this *IPList) RecordIP(ipType string,
} }
// 使用本地防火墙 // 使用本地防火墙
if useLocalFirewall && expiresAt > 0 { if useLocalFirewall {
firewalls.DropTemporaryTo(ip, expiresAt) firewalls.DropTemporaryTo(ip, expiresAt)
} }
} }