mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-09 03:50:27 +08:00
优化IP黑名单检测
This commit is contained in:
@@ -8,8 +8,9 @@ import (
|
|||||||
|
|
||||||
// DropTemporaryTo 使用本地防火墙临时拦截IP数据包
|
// DropTemporaryTo 使用本地防火墙临时拦截IP数据包
|
||||||
func DropTemporaryTo(ip string, expiresAt int64) {
|
func DropTemporaryTo(ip string, expiresAt int64) {
|
||||||
if expiresAt <= 1 {
|
// 如果为0,则表示是长期有效
|
||||||
return
|
if expiresAt <= 0 {
|
||||||
|
expiresAt = time.Now().Unix() + 3600
|
||||||
}
|
}
|
||||||
|
|
||||||
var timeout = expiresAt - time.Now().Unix()
|
var timeout = expiresAt - time.Now().Unix()
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ package nodes
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||||
"net"
|
"net"
|
||||||
@@ -51,7 +50,20 @@ func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerS
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetServerId 设置服务ID
|
// SetServerId 设置服务ID
|
||||||
func (this *BaseClientConn) SetServerId(serverId int64) {
|
func (this *BaseClientConn) SetServerId(serverId int64) (goNext bool) {
|
||||||
|
goNext = true
|
||||||
|
|
||||||
|
// 检查服务相关IP黑名单
|
||||||
|
if serverId > 0 && len(this.rawIP) > 0 {
|
||||||
|
// 是否在白名单中
|
||||||
|
ok, _, expiresAt := iplibrary.AllowIP(this.rawIP, serverId)
|
||||||
|
if !ok {
|
||||||
|
_ = this.rawConn.Close()
|
||||||
|
firewalls.DropTemporaryTo(this.rawIP, expiresAt)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
this.serverId = serverId
|
this.serverId = serverId
|
||||||
|
|
||||||
// 设置包装前连接
|
// 设置包装前连接
|
||||||
@@ -65,19 +77,7 @@ func (this *BaseClientConn) SetServerId(serverId int64) {
|
|||||||
conn.SetServerId(serverId)
|
conn.SetServerId(serverId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查服务相关IP黑名单
|
return true
|
||||||
if serverId > 0 && len(this.rawIP) > 0 {
|
|
||||||
var list = iplibrary.SharedServerListManager.FindBlackList(serverId, false)
|
|
||||||
if list != nil {
|
|
||||||
expiresAt, ok := list.ContainsExpires(configutils.IPString2Long(this.rawIP))
|
|
||||||
if ok {
|
|
||||||
_ = this.rawConn.Close()
|
|
||||||
if expiresAt > 0 {
|
|
||||||
firewalls.DropTemporaryTo(this.rawIP, expiresAt)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerId 读取当前连接绑定的服务ID
|
// ServerId 读取当前连接绑定的服务ID
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ type ClientConnInterface interface {
|
|||||||
ServerId() int64
|
ServerId() int64
|
||||||
|
|
||||||
// SetServerId 设置服务ID
|
// SetServerId 设置服务ID
|
||||||
SetServerId(serverId int64)
|
SetServerId(serverId int64) (goNext bool)
|
||||||
|
|
||||||
// SetUserId 设置所属服务的用户ID
|
// SetUserId 设置所属服务的用户ID
|
||||||
SetUserId(userId int64)
|
SetUserId(userId int64)
|
||||||
|
|||||||
@@ -45,20 +45,17 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
|||||||
canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0)
|
canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0)
|
||||||
isInAllowList = inAllowList
|
isInAllowList = inAllowList
|
||||||
if !canGoNext {
|
if !canGoNext {
|
||||||
if expiresAt > 0 {
|
|
||||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||||
expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
var ok = false
|
||||||
|
expiresAt, ok = waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||||
if ok {
|
if ok {
|
||||||
canGoNext = false
|
canGoNext = false
|
||||||
if expiresAt > 0 {
|
|
||||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if !canGoNext {
|
if !canGoNext {
|
||||||
tcpConn, ok := conn.(*net.TCPConn)
|
tcpConn, ok := conn.(*net.TCPConn)
|
||||||
|
|||||||
@@ -179,7 +179,10 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
|||||||
if requestConn != nil {
|
if requestConn != nil {
|
||||||
clientConn, ok := requestConn.(ClientConnInterface)
|
clientConn, ok := requestConn.(ClientConnInterface)
|
||||||
if ok {
|
if ok {
|
||||||
clientConn.SetServerId(server.Id)
|
var goNext = clientConn.SetServerId(server.Id)
|
||||||
|
if !goNext {
|
||||||
|
return
|
||||||
|
}
|
||||||
clientConn.SetUserId(server.UserId)
|
clientConn.SetUserId(server.UserId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,7 +75,10 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
|||||||
// 绑定连接和服务
|
// 绑定连接和服务
|
||||||
clientConn, ok := conn.(ClientConnInterface)
|
clientConn, ok := conn.(ClientConnInterface)
|
||||||
if ok {
|
if ok {
|
||||||
clientConn.SetServerId(server.Id)
|
var goNext = clientConn.SetServerId(server.Id)
|
||||||
|
if !goNext {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
clientConn.SetUserId(server.UserId)
|
clientConn.SetUserId(server.UserId)
|
||||||
} else {
|
} else {
|
||||||
tlsConn, ok := conn.(*tls.Conn)
|
tlsConn, ok := conn.(*tls.Conn)
|
||||||
@@ -84,7 +87,10 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
|||||||
if internalConn != nil {
|
if internalConn != nil {
|
||||||
clientConn, ok = internalConn.(ClientConnInterface)
|
clientConn, ok = internalConn.(ClientConnInterface)
|
||||||
if ok {
|
if ok {
|
||||||
clientConn.SetServerId(server.Id)
|
var goNext = clientConn.SetServerId(server.Id)
|
||||||
|
if !goNext {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
clientConn.SetUserId(server.UserId)
|
clientConn.SetUserId(server.UserId)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -125,7 +125,7 @@ func (this *IPList) RecordIP(ipType string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 使用本地防火墙
|
// 使用本地防火墙
|
||||||
if useLocalFirewall && expiresAt > 0 {
|
if useLocalFirewall {
|
||||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user