diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index ae7c6f9..e2fb0b0 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -11,6 +11,7 @@ import ( "github.com/iwind/TeaGo/types" "io" "io/ioutil" + "net" "net/http" ) @@ -305,3 +306,17 @@ func (this *HTTPRequest) WAFRestoreBody(data []byte) { func (this *HTTPRequest) WAFServerId() int64 { return this.Server.Id } + +// WAFClose 关闭连接 +func (this *HTTPRequest) WAFClose() { + requestConn := this.RawReq.Context().Value(HTTPConnContextKey) + if requestConn == nil { + return + } + conn, ok := requestConn.(net.Conn) + if ok { + _ = conn.Close() + return + } + return +} diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go index 5421172..e9ebdba 100644 --- a/internal/waf/action_block.go +++ b/internal/waf/action_block.go @@ -66,16 +66,7 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque if writer != nil { // close the connection - defer func() { - hijack, ok := writer.(http.Hijacker) - if ok { - conn, _, _ := hijack.Hijack() - if conn != nil { - _ = conn.Close() - return - } - } - }() + defer request.WAFClose() // output response if this.StatusCode > 0 { @@ -128,5 +119,6 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque _, _ = writer.Write([]byte("The request is blocked by " + teaconst.ProductName)) } } + return false } diff --git a/internal/waf/requests/request.go b/internal/waf/requests/request.go index 5fb63fc..9c4b698 100644 --- a/internal/waf/requests/request.go +++ b/internal/waf/requests/request.go @@ -26,6 +26,9 @@ type Request interface { // WAFServerId 服务ID WAFServerId() int64 + // WAFClose 关闭当前请求所在的连接 + WAFClose() + // Format 格式化变量 Format(string) string } diff --git a/internal/waf/requests/test_request.go b/internal/waf/requests/test_request.go index 2a9ca48..114f462 100644 --- a/internal/waf/requests/test_request.go +++ b/internal/waf/requests/test_request.go @@ -66,6 +66,10 @@ func (this *TestRequest) WAFServerId() int64 { return 0 } +// WAFClose 关闭当前请求所在的连接 +func (this *TestRequest) WAFClose() { +} + func (this *TestRequest) Format(s string) string { return s }