diff --git a/internal/firewalls/utils.go b/internal/firewalls/utils.go index ea38919..b3368f8 100644 --- a/internal/firewalls/utils.go +++ b/internal/firewalls/utils.go @@ -8,8 +8,9 @@ import ( // DropTemporaryTo 使用本地防火墙临时拦截IP数据包 func DropTemporaryTo(ip string, expiresAt int64) { - if expiresAt <= 1 { - return + // 如果为0,则表示是长期有效 + if expiresAt <= 0 { + expiresAt = time.Now().Unix() + 3600 } var timeout = expiresAt - time.Now().Unix() diff --git a/internal/nodes/client_conn_base.go b/internal/nodes/client_conn_base.go index e7bcb45..35d770d 100644 --- a/internal/nodes/client_conn_base.go +++ b/internal/nodes/client_conn_base.go @@ -4,7 +4,6 @@ package nodes import ( "crypto/tls" - "github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeNode/internal/firewalls" "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "net" @@ -51,7 +50,20 @@ func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerS } // SetServerId 设置服务ID -func (this *BaseClientConn) SetServerId(serverId int64) { +func (this *BaseClientConn) SetServerId(serverId int64) (goNext bool) { + goNext = true + + // 检查服务相关IP黑名单 + if serverId > 0 && len(this.rawIP) > 0 { + // 是否在白名单中 + ok, _, expiresAt := iplibrary.AllowIP(this.rawIP, serverId) + if !ok { + _ = this.rawConn.Close() + firewalls.DropTemporaryTo(this.rawIP, expiresAt) + return false + } + } + this.serverId = serverId // 设置包装前连接 @@ -65,19 +77,7 @@ func (this *BaseClientConn) SetServerId(serverId int64) { conn.SetServerId(serverId) } - // 检查服务相关IP黑名单 - if serverId > 0 && len(this.rawIP) > 0 { - var list = iplibrary.SharedServerListManager.FindBlackList(serverId, false) - if list != nil { - expiresAt, ok := list.ContainsExpires(configutils.IPString2Long(this.rawIP)) - if ok { - _ = this.rawConn.Close() - if expiresAt > 0 { - firewalls.DropTemporaryTo(this.rawIP, expiresAt) - } - } - } - } + return true } // ServerId 读取当前连接绑定的服务ID diff --git a/internal/nodes/client_conn_interface.go b/internal/nodes/client_conn_interface.go index 75c9926..738602e 100644 --- a/internal/nodes/client_conn_interface.go +++ b/internal/nodes/client_conn_interface.go @@ -16,7 +16,7 @@ type ClientConnInterface interface { ServerId() int64 // SetServerId 设置服务ID - SetServerId(serverId int64) + SetServerId(serverId int64) (goNext bool) // SetUserId 设置所属服务的用户ID SetUserId(userId int64) diff --git a/internal/nodes/client_listener.go b/internal/nodes/client_listener.go index 18c6b29..a76ba63 100644 --- a/internal/nodes/client_listener.go +++ b/internal/nodes/client_listener.go @@ -45,17 +45,14 @@ func (this *ClientListener) Accept() (net.Conn, error) { canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0) isInAllowList = inAllowList if !canGoNext { - if expiresAt > 0 { - firewalls.DropTemporaryTo(ip, expiresAt) - } + firewalls.DropTemporaryTo(ip, expiresAt) } else { if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) { - expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) + var ok = false + expiresAt, ok = waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) if ok { canGoNext = false - if expiresAt > 0 { - firewalls.DropTemporaryTo(ip, expiresAt) - } + firewalls.DropTemporaryTo(ip, expiresAt) } } } diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index 071032e..861e2ac 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -179,7 +179,10 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http. if requestConn != nil { clientConn, ok := requestConn.(ClientConnInterface) if ok { - clientConn.SetServerId(server.Id) + var goNext = clientConn.SetServerId(server.Id) + if !goNext { + return + } clientConn.SetUserId(server.UserId) } } diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index a41ed74..306a14a 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -75,7 +75,10 @@ func (this *TCPListener) handleConn(conn net.Conn) error { // 绑定连接和服务 clientConn, ok := conn.(ClientConnInterface) if ok { - clientConn.SetServerId(server.Id) + var goNext = clientConn.SetServerId(server.Id) + if !goNext { + return nil + } clientConn.SetUserId(server.UserId) } else { tlsConn, ok := conn.(*tls.Conn) @@ -84,7 +87,10 @@ func (this *TCPListener) handleConn(conn net.Conn) error { if internalConn != nil { clientConn, ok = internalConn.(ClientConnInterface) if ok { - clientConn.SetServerId(server.Id) + var goNext = clientConn.SetServerId(server.Id) + if !goNext { + return nil + } clientConn.SetUserId(server.UserId) } } diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index 067dab8..d1acaa6 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -125,7 +125,7 @@ func (this *IPList) RecordIP(ipType string, } // 使用本地防火墙 - if useLocalFirewall && expiresAt > 0 { + if useLocalFirewall { firewalls.DropTemporaryTo(ip, expiresAt) } }