From a1aa2b9224f8c6f688dc708409d9262c36bdfb86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Wed, 29 Sep 2021 09:19:45 +0800 Subject: [PATCH] =?UTF-8?q?Block=E5=8A=A8=E4=BD=9C=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E9=BB=98=E8=AE=A4=E6=97=B6=E9=97=B460=E7=A7=92?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/nodes/http_request_waf.go | 9 +++++++++ internal/nodes/listener_http.go | 10 ++++++++++ internal/nodes/traffic_conn.go | 8 +++++++- internal/nodes/traffic_listener.go | 2 +- internal/waf/action_block.go | 8 +++++--- internal/waf/action_record_ip.go | 2 +- internal/waf/ip_list.go | 2 +- 7 files changed, 34 insertions(+), 7 deletions(-) diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index bab4d2c..ae7c6f9 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -16,6 +16,15 @@ import ( // 调用WAF func (this *HTTPRequest) doWAFRequest() (blocked bool) { + // 当前连接是否已关闭 + var conn = this.RawReq.Context().Value(HTTPConnContextKey) + if conn != nil { + trafficConn, ok := conn.(*TrafficConn) + if ok && trafficConn.IsClosed() { + return true + } + } + // 当前服务的独立设置 if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { blocked, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy) diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index c5744c4..ba05424 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -1,6 +1,7 @@ package nodes import ( + "context" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "golang.org/x/net/http2" @@ -18,6 +19,12 @@ var httpErrorLogger = log.New(io.Discard, "", 0) var metricNewConnMap = map[string]bool{} // remoteAddr => bool var metricNewConnMapLocker = &sync.Mutex{} +type contextKey struct { + key string +} + +var HTTPConnContextKey = &contextKey{key: "http-conn"} + type HTTPListener struct { BaseListener @@ -65,6 +72,9 @@ func (this *HTTPListener) Serve() error { metricNewConnMapLocker.Unlock() } }, + ConnContext: func(ctx context.Context, c net.Conn) context.Context { + return context.WithValue(ctx, HTTPConnContextKey, c) + }, } this.httpServer.SetKeepAlivesEnabled(true) diff --git a/internal/nodes/traffic_conn.go b/internal/nodes/traffic_conn.go index 9ced9dd..97b151e 100644 --- a/internal/nodes/traffic_conn.go +++ b/internal/nodes/traffic_conn.go @@ -44,7 +44,8 @@ func init() { // TrafficConn 用于统计流量的连接 type TrafficConn struct { - rawConn net.Conn + rawConn net.Conn + isClosed bool } func NewTrafficConn(conn net.Conn) net.Conn { @@ -68,6 +69,7 @@ func (this *TrafficConn) Write(b []byte) (n int, err error) { } func (this *TrafficConn) Close() error { + this.isClosed = true return this.rawConn.Close() } @@ -90,3 +92,7 @@ func (this *TrafficConn) SetReadDeadline(t time.Time) error { func (this *TrafficConn) SetWriteDeadline(t time.Time) error { return this.rawConn.SetWriteDeadline(t) } + +func (this *TrafficConn) IsClosed() bool { + return this.isClosed +} diff --git a/internal/nodes/traffic_listener.go b/internal/nodes/traffic_listener.go index dd0ffd0..c0be765 100644 --- a/internal/nodes/traffic_listener.go +++ b/internal/nodes/traffic_listener.go @@ -24,7 +24,7 @@ func (this *TrafficListener) Accept() (net.Conn, error) { // 是否在WAF名单中 ip, _, err := net.SplitHostPort(conn.RemoteAddr().String()) if err == nil { - if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackLIst.Contains(waf.IPTypeAll, ip) { + if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackList.Contains(waf.IPTypeAll, ip) { defer func() { _ = conn.Close() }() diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go index 4b91fe4..5421172 100644 --- a/internal/waf/action_block.go +++ b/internal/waf/action_block.go @@ -57,10 +57,12 @@ func (this *BlockAction) WillChange() bool { } func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { - if this.Timeout > 0 { - // 加入到黑名单 - SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(this.Timeout)) + // 加入到黑名单 + var timeout = this.Timeout + if timeout <= 0 { + timeout = 60 // 默认封锁60秒 } + SharedIPBlackList.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(timeout)) if writer != nil { // close the connection diff --git a/internal/waf/action_record_ip.go b/internal/waf/action_record_ip.go index 8a34906..353b6f9 100644 --- a/internal/waf/action_record_ip.go +++ b/internal/waf/action_record_ip.go @@ -92,7 +92,7 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re if this.Type == "black" { _ = this.CloseConn(writer) - SharedIPBlackLIst.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt) + SharedIPBlackList.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt) } else { // 加入本地白名单 timeout := this.Timeout diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index 5c53624..5ade6ef 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -9,7 +9,7 @@ import ( ) var SharedIPWhiteList = NewIPList() -var SharedIPBlackLIst = NewIPList() +var SharedIPBlackList = NewIPList() const IPTypeAll = "*"