diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index ca4274c..a89df76 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -89,7 +89,9 @@ type HTTPRequest struct { firewallRuleSetId int64 firewallRuleId int64 firewallActions []string - tags []string + wafHasRequestBody bool + + tags []string logAttrs map[string]string diff --git a/internal/nodes/http_request_log.go b/internal/nodes/http_request_log.go index c6d74ac..a2f71b5 100644 --- a/internal/nodes/http_request_log.go +++ b/internal/nodes/http_request_log.go @@ -149,8 +149,7 @@ func (this *HTTPRequest) log() { } // 请求Body - // TODO 考虑在被攻击时记录攻击的requestBody(如果requestBody匹配规则的话),但要考虑请求尺寸、数据库容量,避免因为日志而导致服务不稳定 - if ref != nil && ref.ContainsField(serverconfigs.HTTPAccessLogFieldRequestBody) { + if (ref != nil && ref.ContainsField(serverconfigs.HTTPAccessLogFieldRequestBody)) || this.wafHasRequestBody { accessLog.RequestBody = this.requestBodyData if len(accessLog.RequestBody) > AccessLogMaxRequestBodySize { diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 1e4855c..4f126ee 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -53,11 +53,19 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) { return true } - var forceLog = this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn && this.ReqServer.HTTPFirewallPolicy.Log != nil && this.ReqServer.HTTPFirewallPolicy.Log.IsOn + var forceLog = false + var forceLogRequestBody = false + if this.ReqServer.HTTPFirewallPolicy != nil && + this.ReqServer.HTTPFirewallPolicy.IsOn && + this.ReqServer.HTTPFirewallPolicy.Log != nil && + this.ReqServer.HTTPFirewallPolicy.Log.IsOn { + forceLog = true + forceLogRequestBody = this.ReqServer.HTTPFirewallPolicy.Log.RequestBody + } // 当前服务的独立设置 if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { - blocked, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog) + blocked, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody) if blocked { return true } @@ -68,7 +76,7 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) { // 公用的防火墙设置 if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn { - blocked, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog) + blocked, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody) if blocked { return true } @@ -80,7 +88,7 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) { return } -func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, forceLog bool) (blocked bool, breakChecking bool) { +func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, forceLog bool, logRequestBody bool) (blocked bool, breakChecking bool) { // 检查配置是否为空 if firewallPolicy == nil || !firewallPolicy.IsOn || firewallPolicy.Inbound == nil || !firewallPolicy.Inbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass { return @@ -199,7 +207,10 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir return } - goNext, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer) + goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer) + if forceLog && logRequestBody && hasRequestBody { + this.wafHasRequestBody = true + } if err != nil { if !this.canIgnore(err) { remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error()) @@ -238,9 +249,15 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) { } // 当前服务的独立设置 - var forceLog = this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn && this.ReqServer.HTTPFirewallPolicy.Log != nil && this.ReqServer.HTTPFirewallPolicy.Log.IsOn + var forceLog = false + var forceLogRequestBody = false + if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn && this.ReqServer.HTTPFirewallPolicy.Log != nil && this.ReqServer.HTTPFirewallPolicy.Log.IsOn { + forceLog = true + forceLogRequestBody = this.ReqServer.HTTPFirewallPolicy.Log.RequestBody + } + if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { - blocked := this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog) + blocked := this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody) if blocked { return true } @@ -248,7 +265,7 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) { // 公用的防火墙设置 if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn { - blocked := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog) + blocked := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody) if blocked { return true } @@ -256,7 +273,7 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) { return } -func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool) (blocked bool) { +func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool) (blocked bool) { if firewallPolicy == nil || !firewallPolicy.IsOn || !firewallPolicy.Outbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass { return } @@ -266,7 +283,10 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi return } - goNext, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer) + goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer) + if forceLog && logRequestBody && hasRequestBody { + this.wafHasRequestBody = true + } if err != nil { if !this.canIgnore(err) { remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error()) diff --git a/internal/waf/action_go_group.go b/internal/waf/action_go_group.go index fe8afc3..252fd35 100644 --- a/internal/waf/action_go_group.go +++ b/internal/waf/action_go_group.go @@ -35,7 +35,7 @@ func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req return true } - b, nextSet, err := nextGroup.MatchRequest(request) + b, _, nextSet, err := nextGroup.MatchRequest(request) if err != nil { logs.Error(err) return true diff --git a/internal/waf/action_go_set.go b/internal/waf/action_go_set.go index ea998e9..76cf84a 100644 --- a/internal/waf/action_go_set.go +++ b/internal/waf/action_go_set.go @@ -40,7 +40,7 @@ func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque return true } - b, err := nextSet.MatchRequest(request) + b, _, err := nextSet.MatchRequest(request) if err != nil { logs.Error(err) return true diff --git a/internal/waf/checkpoints/cc.go b/internal/waf/checkpoints/cc.go index f1b9722..bdef428 100644 --- a/internal/waf/checkpoints/cc.go +++ b/internal/waf/checkpoints/cc.go @@ -30,7 +30,7 @@ func (this *CCCheckpoint) Start() { this.cache = ttlcache.NewCache() } -func (this *CCCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *CCCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = 0 if this.cache == nil { @@ -120,7 +120,7 @@ func (this *CCCheckpoint) RequestValue(req requests.Request, param string, optio return } -func (this *CCCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *CCCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/cc2.go b/internal/waf/checkpoints/cc2.go index ead9c64..f247d06 100644 --- a/internal/waf/checkpoints/cc2.go +++ b/internal/waf/checkpoints/cc2.go @@ -32,7 +32,7 @@ type CC2Checkpoint struct { Checkpoint } -func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { var keys = options.GetSlice("keys") var keyValues = []string{} for _, key := range keys { @@ -71,6 +71,6 @@ func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, opti return } -func (this *CC2Checkpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *CC2Checkpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { return } diff --git a/internal/waf/checkpoints/checkpoint_interface.go b/internal/waf/checkpoints/checkpoint_interface.go index 532ae62..535e9e5 100644 --- a/internal/waf/checkpoints/checkpoint_interface.go +++ b/internal/waf/checkpoints/checkpoint_interface.go @@ -17,10 +17,10 @@ type CheckpointInterface interface { IsComposed() bool // RequestValue get request value - RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) + RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) // ResponseValue get response value - ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) + ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) // ParamOptions param option list ParamOptions() *ParamOptions diff --git a/internal/waf/checkpoints/request_all.go b/internal/waf/checkpoints/request_all.go index b347568..213ab0b 100644 --- a/internal/waf/checkpoints/request_all.go +++ b/internal/waf/checkpoints/request_all.go @@ -11,7 +11,7 @@ type RequestAllCheckpoint struct { Checkpoint } -func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { valueBytes := []byte{} if len(req.WAFRaw().RequestURI) > 0 { valueBytes = append(valueBytes, req.WAFRaw().RequestURI...) @@ -28,10 +28,11 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin valueBytes = append(valueBytes, ' ') var bodyData = req.WAFGetCacheBody() + hasRequestBody = true if len(bodyData) == 0 { data, err := req.WAFReadBody(utils.MaxBodySize) // read body if err != nil { - return "", err, nil + return "", hasRequestBody, err, nil } bodyData = data @@ -46,7 +47,7 @@ func (this *RequestAllCheckpoint) RequestValue(req requests.Request, param strin return } -func (this *RequestAllCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestAllCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = "" if this.IsRequest() { return this.RequestValue(req, param, options) diff --git a/internal/waf/checkpoints/request_all_test.go b/internal/waf/checkpoints/request_all_test.go index 5ee81d5..6f7ebec 100644 --- a/internal/waf/checkpoints/request_all_test.go +++ b/internal/waf/checkpoints/request_all_test.go @@ -18,7 +18,7 @@ func TestRequestAllCheckpoint_RequestValue(t *testing.T) { } checkpoint := new(RequestAllCheckpoint) - v, sysErr, userErr := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) + v, _, sysErr, userErr := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) if sysErr != nil { t.Fatal(sysErr) } @@ -42,7 +42,7 @@ func TestRequestAllCheckpoint_RequestValue_Max(t *testing.T) { } checkpoint := new(RequestBodyCheckpoint) - value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) + value, _, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) if err != nil { t.Fatal(err) } @@ -65,6 +65,6 @@ func BenchmarkRequestAllCheckpoint_RequestValue(b *testing.B) { checkpoint := new(RequestAllCheckpoint) for i := 0; i < b.N; i++ { - _, _, _ = checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) + _, _, _, _ = checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) } } diff --git a/internal/waf/checkpoints/request_arg.go b/internal/waf/checkpoints/request_arg.go index a9c51a5..4026be9 100644 --- a/internal/waf/checkpoints/request_arg.go +++ b/internal/waf/checkpoints/request_arg.go @@ -9,11 +9,11 @@ type RequestArgCheckpoint struct { Checkpoint } -func (this *RequestArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - return req.WAFRaw().URL.Query().Get(param), nil, nil +func (this *RequestArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { + return req.WAFRaw().URL.Query().Get(param), hasRequestBody, nil, nil } -func (this *RequestArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_args.go b/internal/waf/checkpoints/request_args.go index a83dc3f..237c443 100644 --- a/internal/waf/checkpoints/request_args.go +++ b/internal/waf/checkpoints/request_args.go @@ -9,12 +9,12 @@ type RequestArgsCheckpoint struct { Checkpoint } -func (this *RequestArgsCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestArgsCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.WAFRaw().URL.RawQuery return } -func (this *RequestArgsCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestArgsCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_body.go b/internal/waf/checkpoints/request_body.go index 90884d6..cdc1ac2 100644 --- a/internal/waf/checkpoints/request_body.go +++ b/internal/waf/checkpoints/request_body.go @@ -11,7 +11,7 @@ type RequestBodyCheckpoint struct { Checkpoint } -func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.RequestBodyIsEmpty(req) { value = "" return @@ -23,10 +23,11 @@ func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param stri } var bodyData = req.WAFGetCacheBody() + hasRequestBody = true if len(bodyData) == 0 { data, err := req.WAFReadBody(utils.MaxBodySize) // read body if err != nil { - return "", err, nil + return "", hasRequestBody, err, nil } bodyData = data @@ -34,10 +35,10 @@ func (this *RequestBodyCheckpoint) RequestValue(req requests.Request, param stri req.WAFRestoreBody(data) } - return bodyData, nil, nil + return bodyData, hasRequestBody, nil, nil } -func (this *RequestBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_body_test.go b/internal/waf/checkpoints/request_body_test.go index 8bdb0d2..63e00e5 100644 --- a/internal/waf/checkpoints/request_body_test.go +++ b/internal/waf/checkpoints/request_body_test.go @@ -34,7 +34,7 @@ func TestRequestBodyCheckpoint_RequestValue_Max(t *testing.T) { } checkpoint := new(RequestBodyCheckpoint) - value, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) + value, _, err, _ := checkpoint.RequestValue(requests.NewTestRequest(req), "", nil) if err != nil { t.Fatal(err) } diff --git a/internal/waf/checkpoints/request_content_type.go b/internal/waf/checkpoints/request_content_type.go index 6ff04cd..3a5fa7f 100644 --- a/internal/waf/checkpoints/request_content_type.go +++ b/internal/waf/checkpoints/request_content_type.go @@ -9,12 +9,12 @@ type RequestContentTypeCheckpoint struct { Checkpoint } -func (this *RequestContentTypeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestContentTypeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.WAFRaw().Header.Get("Content-Type") return } -func (this *RequestContentTypeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestContentTypeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_cookie.go b/internal/waf/checkpoints/request_cookie.go index 33fd968..5986e61 100644 --- a/internal/waf/checkpoints/request_cookie.go +++ b/internal/waf/checkpoints/request_cookie.go @@ -9,7 +9,7 @@ type RequestCookieCheckpoint struct { Checkpoint } -func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { cookie, err := req.WAFRaw().Cookie(param) if err != nil { value = "" @@ -20,7 +20,7 @@ func (this *RequestCookieCheckpoint) RequestValue(req requests.Request, param st return } -func (this *RequestCookieCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestCookieCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_cookies.go b/internal/waf/checkpoints/request_cookies.go index a9f1035..d499b6a 100644 --- a/internal/waf/checkpoints/request_cookies.go +++ b/internal/waf/checkpoints/request_cookies.go @@ -11,7 +11,7 @@ type RequestCookiesCheckpoint struct { Checkpoint } -func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { var cookies = []string{} for _, cookie := range req.WAFRaw().Cookies() { cookies = append(cookies, url.QueryEscape(cookie.Name)+"="+url.QueryEscape(cookie.Value)) @@ -20,7 +20,7 @@ func (this *RequestCookiesCheckpoint) RequestValue(req requests.Request, param s return } -func (this *RequestCookiesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestCookiesCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_form_arg.go b/internal/waf/checkpoints/request_form_arg.go index 8cc3a37..4d0a3da 100644 --- a/internal/waf/checkpoints/request_form_arg.go +++ b/internal/waf/checkpoints/request_form_arg.go @@ -12,7 +12,9 @@ type RequestFormArgCheckpoint struct { Checkpoint } -func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { + hasRequestBody = true + if this.RequestBodyIsEmpty(req) { value = "" return @@ -27,7 +29,7 @@ func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param s if len(bodyData) == 0 { data, err := req.WAFReadBody(utils.MaxBodySize) // read body if err != nil { - return "", err, nil + return "", hasRequestBody, err, nil } bodyData = data @@ -37,10 +39,10 @@ func (this *RequestFormArgCheckpoint) RequestValue(req requests.Request, param s // TODO improve performance values, _ := url.ParseQuery(string(bodyData)) - return values.Get(param), nil, nil + return values.Get(param), hasRequestBody, nil, nil } -func (this *RequestFormArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestFormArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_general_header_length.go b/internal/waf/checkpoints/request_general_header_length.go index 50e8251..55de1dc 100644 --- a/internal/waf/checkpoints/request_general_header_length.go +++ b/internal/waf/checkpoints/request_general_header_length.go @@ -14,15 +14,15 @@ func (this *RequestGeneralHeaderLengthCheckpoint) IsComposed() bool { return true } -func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = false - headers := options.GetSlice("headers") + var headers = options.GetSlice("headers") if len(headers) == 0 { return } - length := options.GetInt("length") + var length = options.GetInt("length") for _, header := range headers { v := req.WAFRaw().Header.Get(types.String(header)) @@ -35,6 +35,6 @@ func (this *RequestGeneralHeaderLengthCheckpoint) RequestValue(req requests.Requ return } -func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { return } diff --git a/internal/waf/checkpoints/request_geo_city_name.go b/internal/waf/checkpoints/request_geo_city_name.go index e79ab87..bb38b8f 100644 --- a/internal/waf/checkpoints/request_geo_city_name.go +++ b/internal/waf/checkpoints/request_geo_city_name.go @@ -15,11 +15,11 @@ func (this *RequestGeoCityNameCheckpoint) IsComposed() bool { return false } -func (this *RequestGeoCityNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestGeoCityNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.Format("${geo.city.name}") return } -func (this *RequestGeoCityNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestGeoCityNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_geo_country_name.go b/internal/waf/checkpoints/request_geo_country_name.go index 438310f..b11c317 100644 --- a/internal/waf/checkpoints/request_geo_country_name.go +++ b/internal/waf/checkpoints/request_geo_country_name.go @@ -15,11 +15,11 @@ func (this *RequestGeoCountryNameCheckpoint) IsComposed() bool { return false } -func (this *RequestGeoCountryNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestGeoCountryNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.Format("${geo.country.name}") return } -func (this *RequestGeoCountryNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestGeoCountryNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_geo_province_name.go b/internal/waf/checkpoints/request_geo_province_name.go index 37d7b88..460c1d8 100644 --- a/internal/waf/checkpoints/request_geo_province_name.go +++ b/internal/waf/checkpoints/request_geo_province_name.go @@ -15,11 +15,11 @@ func (this *RequestGeoProvinceNameCheckpoint) IsComposed() bool { return false } -func (this *RequestGeoProvinceNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestGeoProvinceNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.Format("${geo.province.name}") return } -func (this *RequestGeoProvinceNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestGeoProvinceNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_header.go b/internal/waf/checkpoints/request_header.go index 8b206d0..e73e9ce 100644 --- a/internal/waf/checkpoints/request_header.go +++ b/internal/waf/checkpoints/request_header.go @@ -10,7 +10,7 @@ type RequestHeaderCheckpoint struct { Checkpoint } -func (this *RequestHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { v, found := req.WAFRaw().Header[param] if !found { value = "" @@ -20,7 +20,7 @@ func (this *RequestHeaderCheckpoint) RequestValue(req requests.Request, param st return } -func (this *RequestHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_headers.go b/internal/waf/checkpoints/request_headers.go index 0fdb225..cfee677 100644 --- a/internal/waf/checkpoints/request_headers.go +++ b/internal/waf/checkpoints/request_headers.go @@ -11,7 +11,7 @@ type RequestHeadersCheckpoint struct { Checkpoint } -func (this *RequestHeadersCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestHeadersCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { var headers = []string{} for k, v := range req.WAFRaw().Header { for _, subV := range v { @@ -23,7 +23,7 @@ func (this *RequestHeadersCheckpoint) RequestValue(req requests.Request, param s return } -func (this *RequestHeadersCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestHeadersCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_host.go b/internal/waf/checkpoints/request_host.go index 105f4a7..c45c091 100644 --- a/internal/waf/checkpoints/request_host.go +++ b/internal/waf/checkpoints/request_host.go @@ -9,12 +9,12 @@ type RequestHostCheckpoint struct { Checkpoint } -func (this *RequestHostCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestHostCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.WAFRaw().Host return } -func (this *RequestHostCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestHostCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_isp_name.go b/internal/waf/checkpoints/request_isp_name.go index 68c6b7d..2f9f214 100644 --- a/internal/waf/checkpoints/request_isp_name.go +++ b/internal/waf/checkpoints/request_isp_name.go @@ -15,11 +15,11 @@ func (this *RequestISPNameCheckpoint) IsComposed() bool { return false } -func (this *RequestISPNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestISPNameCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.Format("${isp.name}") return } -func (this *RequestISPNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestISPNameCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_json_arg.go b/internal/waf/checkpoints/request_json_arg.go index 578ab18..374a122 100644 --- a/internal/waf/checkpoints/request_json_arg.go +++ b/internal/waf/checkpoints/request_json_arg.go @@ -14,12 +14,13 @@ type RequestJSONArgCheckpoint struct { Checkpoint } -func (this *RequestJSONArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestJSONArgCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { var bodyData = req.WAFGetCacheBody() + hasRequestBody = true if len(bodyData) == 0 { data, err := req.WAFReadBody(wafutils.MaxBodySize) // read body if err != nil { - return "", err, nil + return "", hasRequestBody, err, nil } bodyData = data @@ -31,17 +32,17 @@ func (this *RequestJSONArgCheckpoint) RequestValue(req requests.Request, param s var m interface{} = nil err := json.Unmarshal(bodyData, &m) if err != nil || m == nil { - return "", nil, err + return "", hasRequestBody, nil, err } value = utils.Get(m, strings.Split(param, ".")) if value != nil { - return value, nil, err + return value, hasRequestBody, nil, err } - return "", nil, nil + return "", hasRequestBody, nil, nil } -func (this *RequestJSONArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestJSONArgCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_length.go b/internal/waf/checkpoints/request_length.go index e26a18b..e1f5b4a 100644 --- a/internal/waf/checkpoints/request_length.go +++ b/internal/waf/checkpoints/request_length.go @@ -9,12 +9,12 @@ type RequestLengthCheckpoint struct { Checkpoint } -func (this *RequestLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.WAFRaw().ContentLength return } -func (this *RequestLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_method.go b/internal/waf/checkpoints/request_method.go index 3b85fc0..6887b00 100644 --- a/internal/waf/checkpoints/request_method.go +++ b/internal/waf/checkpoints/request_method.go @@ -9,12 +9,12 @@ type RequestMethodCheckpoint struct { Checkpoint } -func (this *RequestMethodCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestMethodCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.WAFRaw().Method return } -func (this *RequestMethodCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestMethodCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_path.go b/internal/waf/checkpoints/request_path.go index 5e757bb..8468b07 100644 --- a/internal/waf/checkpoints/request_path.go +++ b/internal/waf/checkpoints/request_path.go @@ -9,11 +9,11 @@ type RequestPathCheckpoint struct { Checkpoint } -func (this *RequestPathCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - return req.WAFRaw().URL.Path, nil, nil +func (this *RequestPathCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { + return req.WAFRaw().URL.Path, false, nil, nil } -func (this *RequestPathCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestPathCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_proto.go b/internal/waf/checkpoints/request_proto.go index 235b2db..b2e94f4 100644 --- a/internal/waf/checkpoints/request_proto.go +++ b/internal/waf/checkpoints/request_proto.go @@ -9,12 +9,12 @@ type RequestProtoCheckpoint struct { Checkpoint } -func (this *RequestProtoCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestProtoCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.WAFRaw().Proto return } -func (this *RequestProtoCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestProtoCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_raw_remote_addr.go b/internal/waf/checkpoints/request_raw_remote_addr.go index 7886c44..e6ee90a 100644 --- a/internal/waf/checkpoints/request_raw_remote_addr.go +++ b/internal/waf/checkpoints/request_raw_remote_addr.go @@ -10,7 +10,7 @@ type RequestRawRemoteAddrCheckpoint struct { Checkpoint } -func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { host, _, err := net.SplitHostPort(req.WAFRaw().RemoteAddr) if err == nil { value = host @@ -20,7 +20,7 @@ func (this *RequestRawRemoteAddrCheckpoint) RequestValue(req requests.Request, p return } -func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRawRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_referer.go b/internal/waf/checkpoints/request_referer.go index 775c084..a4da2b7 100644 --- a/internal/waf/checkpoints/request_referer.go +++ b/internal/waf/checkpoints/request_referer.go @@ -9,12 +9,12 @@ type RequestRefererCheckpoint struct { Checkpoint } -func (this *RequestRefererCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRefererCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.WAFRaw().Referer() return } -func (this *RequestRefererCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRefererCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_referer_block.go b/internal/waf/checkpoints/request_referer_block.go index b03f1b9..732d68a 100644 --- a/internal/waf/checkpoints/request_referer_block.go +++ b/internal/waf/checkpoints/request_referer_block.go @@ -17,7 +17,7 @@ type RequestRefererBlockCheckpoint struct { // RequestValue 计算checkpoint值 // 选项:allowEmpty, allowSameDomain, allowDomains -func (this *RequestRefererBlockCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRefererBlockCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { var referer = req.WAFRaw().Referer() if len(referer) == 0 { @@ -61,6 +61,6 @@ func (this *RequestRefererBlockCheckpoint) RequestValue(req requests.Request, pa return } -func (this *RequestRefererBlockCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRefererBlockCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { return } diff --git a/internal/waf/checkpoints/request_remote_addr.go b/internal/waf/checkpoints/request_remote_addr.go index dc26a10..b88e6ef 100644 --- a/internal/waf/checkpoints/request_remote_addr.go +++ b/internal/waf/checkpoints/request_remote_addr.go @@ -9,12 +9,12 @@ type RequestRemoteAddrCheckpoint struct { Checkpoint } -func (this *RequestRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRemoteAddrCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.WAFRemoteIP() return } -func (this *RequestRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRemoteAddrCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_remote_port.go b/internal/waf/checkpoints/request_remote_port.go index f5aa158..9279fb9 100644 --- a/internal/waf/checkpoints/request_remote_port.go +++ b/internal/waf/checkpoints/request_remote_port.go @@ -11,7 +11,7 @@ type RequestRemotePortCheckpoint struct { Checkpoint } -func (this *RequestRemotePortCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRemotePortCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { _, port, err := net.SplitHostPort(req.WAFRaw().RemoteAddr) if err == nil { value = types.Int(port) @@ -21,7 +21,7 @@ func (this *RequestRemotePortCheckpoint) RequestValue(req requests.Request, para return } -func (this *RequestRemotePortCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRemotePortCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_remote_user.go b/internal/waf/checkpoints/request_remote_user.go index a2d1e20..705c385 100644 --- a/internal/waf/checkpoints/request_remote_user.go +++ b/internal/waf/checkpoints/request_remote_user.go @@ -9,7 +9,7 @@ type RequestRemoteUserCheckpoint struct { Checkpoint } -func (this *RequestRemoteUserCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRemoteUserCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { username, _, ok := req.WAFRaw().BasicAuth() if !ok { value = "" @@ -19,7 +19,7 @@ func (this *RequestRemoteUserCheckpoint) RequestValue(req requests.Request, para return } -func (this *RequestRemoteUserCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestRemoteUserCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_scheme.go b/internal/waf/checkpoints/request_scheme.go index c01fb62..65b116b 100644 --- a/internal/waf/checkpoints/request_scheme.go +++ b/internal/waf/checkpoints/request_scheme.go @@ -9,12 +9,12 @@ type RequestSchemeCheckpoint struct { Checkpoint } -func (this *RequestSchemeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestSchemeCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.Format("${scheme}") return } -func (this *RequestSchemeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestSchemeCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_upload.go b/internal/waf/checkpoints/request_upload.go index 474bb81..6df97d3 100644 --- a/internal/waf/checkpoints/request_upload.go +++ b/internal/waf/checkpoints/request_upload.go @@ -17,7 +17,7 @@ type RequestUploadCheckpoint struct { Checkpoint } -func (this *RequestUploadCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestUploadCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.RequestBodyIsEmpty(req) { value = "" return @@ -36,6 +36,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req requests.Request, param st return } + hasRequestBody = true if req.WAFRaw().MultipartForm == nil { var bodyData = req.WAFGetCacheBody() if len(bodyData) == 0 { @@ -121,7 +122,7 @@ func (this *RequestUploadCheckpoint) RequestValue(req requests.Request, param st return } -func (this *RequestUploadCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestUploadCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_uri.go b/internal/waf/checkpoints/request_uri.go index bfe72fd..49417ff 100644 --- a/internal/waf/checkpoints/request_uri.go +++ b/internal/waf/checkpoints/request_uri.go @@ -9,7 +9,7 @@ type RequestURICheckpoint struct { Checkpoint } -func (this *RequestURICheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestURICheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if len(req.WAFRaw().RequestURI) > 0 { value = req.WAFRaw().RequestURI } else if req.WAFRaw().URL != nil { @@ -18,7 +18,7 @@ func (this *RequestURICheckpoint) RequestValue(req requests.Request, param strin return } -func (this *RequestURICheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestURICheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_url.go b/internal/waf/checkpoints/request_url.go index 88c8727..dfc5287 100644 --- a/internal/waf/checkpoints/request_url.go +++ b/internal/waf/checkpoints/request_url.go @@ -9,11 +9,11 @@ type RequestURLCheckpoint struct { Checkpoint } -func (this *RequestURLCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { - return req.Format("${requestURL}"), nil, nil +func (this *RequestURLCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { + return req.Format("${requestURL}"), hasRequestBody, nil, nil } -func (this *RequestURLCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestURLCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/request_user_agent.go b/internal/waf/checkpoints/request_user_agent.go index a9c1bec..ee894d3 100644 --- a/internal/waf/checkpoints/request_user_agent.go +++ b/internal/waf/checkpoints/request_user_agent.go @@ -9,12 +9,12 @@ type RequestUserAgentCheckpoint struct { Checkpoint } -func (this *RequestUserAgentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestUserAgentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = req.WAFRaw().UserAgent() return } -func (this *RequestUserAgentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *RequestUserAgentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/checkpoints/response_body.go b/internal/waf/checkpoints/response_body.go index beffed2..02e4c35 100644 --- a/internal/waf/checkpoints/response_body.go +++ b/internal/waf/checkpoints/response_body.go @@ -7,7 +7,7 @@ import ( "io/ioutil" ) -// ${responseBody} +// ResponseBodyCheckpoint ${responseBody} type ResponseBodyCheckpoint struct { Checkpoint } @@ -16,12 +16,12 @@ func (this *ResponseBodyCheckpoint) IsRequest() bool { return false } -func (this *ResponseBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseBodyCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = "" return } -func (this *ResponseBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseBodyCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if resp.ContentLength == 0 { value = "" return diff --git a/internal/waf/checkpoints/response_bytes_sent.go b/internal/waf/checkpoints/response_bytes_sent.go index 75a719a..5e017a9 100644 --- a/internal/waf/checkpoints/response_bytes_sent.go +++ b/internal/waf/checkpoints/response_bytes_sent.go @@ -5,7 +5,7 @@ import ( "github.com/iwind/TeaGo/maps" ) -// ${bytesSent} +// ResponseBytesSentCheckpoint ${bytesSent} type ResponseBytesSentCheckpoint struct { Checkpoint } @@ -14,12 +14,12 @@ func (this *ResponseBytesSentCheckpoint) IsRequest() bool { return false } -func (this *ResponseBytesSentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseBytesSentCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = 0 return } -func (this *ResponseBytesSentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseBytesSentCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = 0 if resp != nil { value = resp.ContentLength diff --git a/internal/waf/checkpoints/response_general_header_length.go b/internal/waf/checkpoints/response_general_header_length.go index 00404fd..11376bb 100644 --- a/internal/waf/checkpoints/response_general_header_length.go +++ b/internal/waf/checkpoints/response_general_header_length.go @@ -18,12 +18,12 @@ func (this *ResponseGeneralHeaderLengthCheckpoint) IsComposed() bool { return true } -func (this *ResponseGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseGeneralHeaderLengthCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { return } -func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseGeneralHeaderLengthCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = false headers := options.GetSlice("headers") diff --git a/internal/waf/checkpoints/response_header.go b/internal/waf/checkpoints/response_header.go index 839e657..cd321df 100644 --- a/internal/waf/checkpoints/response_header.go +++ b/internal/waf/checkpoints/response_header.go @@ -5,7 +5,7 @@ import ( "github.com/iwind/TeaGo/maps" ) -// ${responseHeader.arg} +// ResponseHeaderCheckpoint ${responseHeader.arg} type ResponseHeaderCheckpoint struct { Checkpoint } @@ -14,12 +14,12 @@ func (this *ResponseHeaderCheckpoint) IsRequest() bool { return false } -func (this *ResponseHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseHeaderCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = "" return } -func (this *ResponseHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseHeaderCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if resp != nil && resp.Header != nil { value = resp.Header.Get(param) } else { diff --git a/internal/waf/checkpoints/response_status.go b/internal/waf/checkpoints/response_status.go index eb9a9bd..b17ca02 100644 --- a/internal/waf/checkpoints/response_status.go +++ b/internal/waf/checkpoints/response_status.go @@ -14,12 +14,12 @@ func (this *ResponseStatusCheckpoint) IsRequest() bool { return false } -func (this *ResponseStatusCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseStatusCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { value = 0 return } -func (this *ResponseStatusCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *ResponseStatusCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if resp != nil { value = resp.StatusCode } diff --git a/internal/waf/checkpoints/sample_request.go b/internal/waf/checkpoints/sample_request.go index 1aa1197..33357c5 100644 --- a/internal/waf/checkpoints/sample_request.go +++ b/internal/waf/checkpoints/sample_request.go @@ -10,11 +10,11 @@ type SampleRequestCheckpoint struct { Checkpoint } -func (this *SampleRequestCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *SampleRequestCheckpoint) RequestValue(req requests.Request, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { return } -func (this *SampleRequestCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, sysErr error, userErr error) { +func (this *SampleRequestCheckpoint) ResponseValue(req requests.Request, resp *requests.Response, param string, options maps.Map) (value interface{}, hasRequestBody bool, sysErr error, userErr error) { if this.IsRequest() { return this.RequestValue(req, param, options) } diff --git a/internal/waf/rule.go b/internal/waf/rule.go index 06beebe..c21c8e0 100644 --- a/internal/waf/rule.go +++ b/internal/waf/rule.go @@ -184,11 +184,14 @@ func (this *Rule) Init() error { return err } -func (this *Rule) MatchRequest(req requests.Request) (b bool, err error) { +func (this *Rule) MatchRequest(req requests.Request) (b bool, hasRequestBody bool, err error) { if this.singleCheckpoint != nil { - value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions) + value, hasCheckedRequestBody, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions) + if hasCheckedRequestBody { + hasRequestBody = true + } if err != nil { - return false, err + return false, hasRequestBody, err } // execute filters @@ -198,10 +201,10 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, err error) { // if is composed checkpoint, we just returns true or false if this.singleCheckpoint.IsComposed() { - return types.Bool(value), nil + return types.Bool(value), hasRequestBody, nil } - return this.Test(value), nil + return this.Test(value), hasRequestBody, nil } value := configutils.ParseVariables(this.Param, func(varName string) (value string) { @@ -213,14 +216,20 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, err error) { } if len(pieces) == 1 { - value1, err1, _ := point.RequestValue(req, "", this.CheckpointOptions) + value1, hasCheckRequestBody, err1, _ := point.RequestValue(req, "", this.CheckpointOptions) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { err = err1 } return types.String(value1) } - value1, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions) + value1, hasCheckRequestBody, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { err = err1 } @@ -228,19 +237,22 @@ func (this *Rule) MatchRequest(req requests.Request) (b bool, err error) { }) if err != nil { - return false, err + return false, hasRequestBody, err } - return this.Test(value), nil + return this.Test(value), hasRequestBody, nil } -func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (b bool, err error) { +func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) (b bool, hasRequestBody bool, err error) { if this.singleCheckpoint != nil { // if is request param if this.singleCheckpoint.IsRequest() { - value, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions) + value, hasCheckRequestBody, err, _ := this.singleCheckpoint.RequestValue(req, this.singleParam, this.CheckpointOptions) + if hasCheckRequestBody { + hasRequestBody = true + } if err != nil { - return false, err + return false, hasRequestBody, err } // execute filters @@ -248,21 +260,24 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) ( value = this.execFilter(value) } - return this.Test(value), nil + return this.Test(value), hasRequestBody, nil } // response param - value, err, _ := this.singleCheckpoint.ResponseValue(req, resp, this.singleParam, this.CheckpointOptions) + value, hasCheckRequestBody, err, _ := this.singleCheckpoint.ResponseValue(req, resp, this.singleParam, this.CheckpointOptions) + if hasCheckRequestBody { + hasRequestBody = true + } if err != nil { - return false, err + return false, hasRequestBody, err } // if is composed checkpoint, we just returns true or false if this.singleCheckpoint.IsComposed() { - return types.Bool(value), nil + return types.Bool(value), hasRequestBody, nil } - return this.Test(value), nil + return this.Test(value), hasRequestBody, nil } value := configutils.ParseVariables(this.Param, func(varName string) (value string) { @@ -275,13 +290,19 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) ( if len(pieces) == 1 { if point.IsRequest() { - value1, err1, _ := point.RequestValue(req, "", this.CheckpointOptions) + value1, hasCheckRequestBody, err1, _ := point.RequestValue(req, "", this.CheckpointOptions) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { err = err1 } return types.String(value1) } else { - value1, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions) + value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, "", this.CheckpointOptions) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { err = err1 } @@ -290,13 +311,19 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) ( } if point.IsRequest() { - value1, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions) + value1, hasCheckRequestBody, err1, _ := point.RequestValue(req, pieces[1], this.CheckpointOptions) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { err = err1 } return types.String(value1) } else { - value1, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions) + value1, hasCheckRequestBody, err1, _ := point.ResponseValue(req, resp, pieces[1], this.CheckpointOptions) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { err = err1 } @@ -305,10 +332,10 @@ func (this *Rule) MatchResponse(req requests.Request, resp *requests.Response) ( }) if err != nil { - return false, err + return false, hasRequestBody, err } - return this.Test(value), nil + return this.Test(value), hasRequestBody, nil } func (this *Rule) Test(value interface{}) bool { diff --git a/internal/waf/rule_group.go b/internal/waf/rule_group.go index 5d54ace..a0e8926 100644 --- a/internal/waf/rule_group.go +++ b/internal/waf/rule_group.go @@ -75,7 +75,7 @@ func (this *RuleGroup) RemoveRuleSet(id int64) { this.RuleSets = result } -func (this *RuleGroup) MatchRequest(req requests.Request) (b bool, set *RuleSet, err error) { +func (this *RuleGroup) MatchRequest(req requests.Request) (b bool, hasRequestBody bool, set *RuleSet, err error) { if !this.hasRuleSets { return } @@ -83,18 +83,18 @@ func (this *RuleGroup) MatchRequest(req requests.Request) (b bool, set *RuleSet, if !set.IsOn { continue } - b, err = set.MatchRequest(req) + b, hasRequestBody, err = set.MatchRequest(req) if err != nil { - return false, nil, err + return false, hasRequestBody, nil, err } if b { - return true, set, nil + return true, hasRequestBody, set, nil } } return } -func (this *RuleGroup) MatchResponse(req requests.Request, resp *requests.Response) (b bool, set *RuleSet, err error) { +func (this *RuleGroup) MatchResponse(req requests.Request, resp *requests.Response) (b bool, hasRequestBody bool, set *RuleSet, err error) { if !this.hasRuleSets { return } @@ -102,12 +102,12 @@ func (this *RuleGroup) MatchResponse(req requests.Request, resp *requests.Respon if !set.IsOn { continue } - b, err = set.MatchResponse(req, resp) + b, hasRequestBody, err = set.MatchResponse(req, resp) if err != nil { - return false, nil, err + return false, hasRequestBody, nil, err } if b { - return true, set, nil + return true, hasRequestBody, set, nil } } return diff --git a/internal/waf/rule_set.go b/internal/waf/rule_set.go index 817710b..4946251 100644 --- a/internal/waf/rule_set.go +++ b/internal/waf/rule_set.go @@ -167,89 +167,105 @@ func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Req return true } -func (this *RuleSet) MatchRequest(req requests.Request) (b bool, err error) { +func (this *RuleSet) MatchRequest(req requests.Request) (b bool, hasRequestBody bool, err error) { // 是否忽略局域网IP if this.IgnoreLocal && utils.IsLocalIP(req.WAFRemoteIP()) { - return false, nil + return false, hasRequestBody, nil } if !this.hasRules { - return false, nil + return false, hasRequestBody, nil } switch this.Connector { case RuleConnectorAnd: for _, rule := range this.Rules { - b1, err1 := rule.MatchRequest(req) + b1, hasCheckRequestBody, err1 := rule.MatchRequest(req) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { - return false, err1 + return false, hasRequestBody, err1 } if !b1 { - return false, nil + return false, hasRequestBody, nil } } - return true, nil + return true, hasRequestBody, nil case RuleConnectorOr: for _, rule := range this.Rules { - b1, err1 := rule.MatchRequest(req) + b1, hasCheckRequestBody, err1 := rule.MatchRequest(req) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { - return false, err1 + return false, hasRequestBody, err1 } if b1 { - return true, nil + return true, hasRequestBody, nil } } default: // same as And for _, rule := range this.Rules { - b1, err1 := rule.MatchRequest(req) + b1, hasCheckRequestBody, err1 := rule.MatchRequest(req) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { - return false, err1 + return false, hasRequestBody, err1 } if !b1 { - return false, nil + return false, hasRequestBody, nil } } - return true, nil + return true, hasRequestBody, nil } return } -func (this *RuleSet) MatchResponse(req requests.Request, resp *requests.Response) (b bool, err error) { +func (this *RuleSet) MatchResponse(req requests.Request, resp *requests.Response) (b bool, hasRequestBody bool, err error) { if !this.hasRules { - return false, nil + return false, hasRequestBody, nil } switch this.Connector { case RuleConnectorAnd: for _, rule := range this.Rules { - b1, err1 := rule.MatchResponse(req, resp) + b1, hasCheckRequestBody, err1 := rule.MatchResponse(req, resp) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { - return false, err1 + return false, hasRequestBody, err1 } if !b1 { - return false, nil + return false, hasRequestBody, nil } } - return true, nil + return true, hasRequestBody, nil case RuleConnectorOr: for _, rule := range this.Rules { - b1, err1 := rule.MatchResponse(req, resp) + // 对于OR连接符,只需要判断最先匹配的一条规则中的hasRequestBody即可 + b1, hasCheckRequestBody, err1 := rule.MatchResponse(req, resp) if err1 != nil { - return false, err1 + return false, hasCheckRequestBody, err1 } if b1 { - return true, nil + return true, hasCheckRequestBody, nil } } default: // same as And for _, rule := range this.Rules { - b1, err1 := rule.MatchResponse(req, resp) + b1, hasCheckRequestBody, err1 := rule.MatchResponse(req, resp) + if hasCheckRequestBody { + hasRequestBody = true + } if err1 != nil { - return false, err1 + return false, hasRequestBody, err1 } if !b1 { - return false, nil + return false, hasRequestBody, nil } } - return true, nil + return true, hasRequestBody, nil } return } diff --git a/internal/waf/rule_set_test.go b/internal/waf/rule_set_test.go index 5317643..023bd02 100644 --- a/internal/waf/rule_set_test.go +++ b/internal/waf/rule_set_test.go @@ -113,7 +113,7 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) { } req := requests.NewTestRequest(rawReq) for i := 0; i < b.N; i++ { - _, _ = set.MatchRequest(req) + _, _, _ = set.MatchRequest(req) } } @@ -143,7 +143,7 @@ func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) { } req := requests.NewTestRequest(rawReq) for i := 0; i < b.N; i++ { - _, _ = set.MatchRequest(req) + _, _, _ = set.MatchRequest(req) } } diff --git a/internal/waf/template_test.go b/internal/waf/template_test.go index 4fddcda..cb458de 100644 --- a/internal/waf/template_test.go +++ b/internal/waf/template_test.go @@ -49,7 +49,7 @@ func Test_Template2(t *testing.T) { } now := time.Now() - goNext, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil) + goNext, _, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -77,7 +77,7 @@ func BenchmarkTemplate(b *testing.B) { b.Fatal(err) } - _, _, _, _ = waf.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, _, _ = waf.MatchRequest(requests.NewTestRequest(req), nil) } } @@ -86,7 +86,7 @@ func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -101,7 +101,7 @@ func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -116,7 +116,7 @@ func testTemplate1003(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -182,7 +182,7 @@ func testTemplate2001(a *assert.Assertion, t *testing.T, template *WAF) { req.Header.Add("Content-Type", writer.FormDataContentType()) - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -197,7 +197,7 @@ func testTemplate3001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -212,7 +212,7 @@ func testTemplate4001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -228,7 +228,7 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -243,7 +243,7 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -260,7 +260,7 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -275,7 +275,7 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -298,7 +298,7 @@ func testTemplate7001(a *assert.Assertion, t *testing.T, template *WAF) { if err != nil { t.Fatal(err) } - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } @@ -335,7 +335,7 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) { t.Fatal(err) } req.Header.Set("User-Agent", bot) - _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) + _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) } diff --git a/internal/waf/waf.go b/internal/waf/waf.go index 9db27c1..0936598 100644 --- a/internal/waf/waf.go +++ b/internal/waf/waf.go @@ -241,9 +241,9 @@ func (this *WAF) MoveOutboundRuleGroup(fromIndex int, toIndex int) { this.Outbound = result } -func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) { +func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter) (goNext bool, hasRequestBody bool, group *RuleGroup, set *RuleSet, err error) { if !this.hasInboundRules { - return true, nil, nil, nil + return true, hasRequestBody, nil, nil, nil } // validate captcha @@ -264,37 +264,43 @@ func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter) if !group.IsOn { continue } - b, set, err := group.MatchRequest(req) + b, hasCheckedRequestBody, set, err := group.MatchRequest(req) + if hasCheckedRequestBody { + hasRequestBody = true + } if err != nil { - return true, nil, nil, err + return true, hasRequestBody, nil, nil, err } if b { - goNext := set.PerformActions(this, group, req, writer) - return goNext, group, set, nil + goNext = set.PerformActions(this, group, req, writer) + return goNext, hasRequestBody, group, set, nil } } - return true, nil, nil, nil + return true, hasRequestBody, nil, nil, nil } -func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, group *RuleGroup, set *RuleSet, err error) { +func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, hasRequestBody bool, group *RuleGroup, set *RuleSet, err error) { if !this.hasOutboundRules { - return true, nil, nil, nil + return true, hasRequestBody, nil, nil, nil } resp := requests.NewResponse(rawResp) for _, group := range this.Outbound { if !group.IsOn { continue } - b, set, err := group.MatchResponse(req, resp) + b, hasCheckedRequestBody, set, err := group.MatchResponse(req, resp) + if hasCheckedRequestBody { + hasRequestBody = true + } if err != nil { - return true, nil, nil, err + return true, hasRequestBody, nil, nil, err } if b { - goNext := set.PerformActions(this, group, req, writer) - return goNext, group, set, nil + goNext = set.PerformActions(this, group, req, writer) + return goNext, hasRequestBody, group, set, nil } } - return true, nil, nil, nil + return true, hasRequestBody, nil, nil, nil } // Save save to file path diff --git a/internal/waf/waf_test.go b/internal/waf/waf_test.go index 19780fc..2eb3f67 100644 --- a/internal/waf/waf_test.go +++ b/internal/waf/waf_test.go @@ -42,7 +42,7 @@ func TestWAF_MatchRequest(t *testing.T) { if err != nil { t.Fatal(err) } - goNext, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil) + goNext, _, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil) if err != nil { t.Fatal(err) }