diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 68a4456..eba6a7d 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -67,8 +67,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) { // 当前服务的独立设置 if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { - blocked, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false) - if blocked { + blockedRequest, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false) + if blockedRequest { return true } if breakChecking { @@ -78,8 +78,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) { // 公用的防火墙设置 if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn { - blocked, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules) - if blocked { + blockedRequest, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules) + if blockedRequest { return true } if breakChecking { @@ -266,8 +266,11 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir return } - goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType) - if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() { + result, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType) + if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) { + breakChecking = true + } + if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() { this.wafHasRequestBody = true } if err != nil { @@ -277,28 +280,28 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir return } - if ruleSet != nil { + if result.Set != nil { if forceLog { this.forceLog = true } - if ruleSet.HasSpecialActions() { + if result.Set.HasSpecialActions() { this.firewallPolicyId = firewallPolicy.Id - this.firewallRuleGroupId = types.Int64(ruleGroup.Id) - this.firewallRuleSetId = types.Int64(ruleSet.Id) + this.firewallRuleGroupId = types.Int64(result.Group.Id) + this.firewallRuleSetId = types.Int64(result.Set.Id) - if ruleSet.HasAttackActions() { + if result.Set.HasAttackActions() { this.isAttack = true } // 添加统计 - stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions) + stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions) } - this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode) + this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode) } - return !goNext, false + return !result.GoNext, breakChecking } // call response waf @@ -316,23 +319,26 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) { } if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { - blocked = this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false) - if blocked { + blockedRequest, breakChecking := this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false) + if blockedRequest { return true } + if breakChecking { + return + } } // 公用的防火墙设置 if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn { - blocked = this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules) - if blocked { + blockedRequest, _ := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules) + if blockedRequest { return true } } return } -func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool) { +func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool, breakChecking bool) { if firewallPolicy == nil || !firewallPolicy.IsOn || !firewallPolicy.Outbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass { return } @@ -347,8 +353,11 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi return } - goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer) - if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() { + result, err := w.MatchResponse(this, resp, this.writer) + if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) { + breakChecking = true + } + if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() { this.wafHasRequestBody = true } if err != nil { @@ -358,28 +367,28 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi return } - if ruleSet != nil { + if result.Set != nil { if forceLog { this.forceLog = true } - if ruleSet.HasSpecialActions() { + if result.Set.HasSpecialActions() { this.firewallPolicyId = firewallPolicy.Id - this.firewallRuleGroupId = types.Int64(ruleGroup.Id) - this.firewallRuleSetId = types.Int64(ruleSet.Id) + this.firewallRuleGroupId = types.Int64(result.Group.Id) + this.firewallRuleSetId = types.Int64(result.Set.Id) - if ruleSet.HasAttackActions() { + if result.Set.HasAttackActions() { this.isAttack = true } // 添加统计 - stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions) + stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions) } - this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode) + this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode) } - return !goNext + return !result.GoNext, breakChecking } // WAFRaw 原始请求 diff --git a/internal/waf/action_allow.go b/internal/waf/action_allow.go index 165a440..5cec376 100644 --- a/internal/waf/action_allow.go +++ b/internal/waf/action_allow.go @@ -5,8 +5,18 @@ import ( "net/http" ) +type AllowScope = string + +const ( + AllowScopeGroup AllowScope = "group" + AllowScopeServer AllowScope = "server" + AllowScopeGlobal AllowScope = "global" +) + type AllowAction struct { BaseAction + + Scope AllowScope `yaml:"scope" json:"scope"` } func (this *AllowAction) Init(waf *WAF) error { @@ -25,7 +35,12 @@ func (this *AllowAction) WillChange() bool { return true } -func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { // do nothing - return true, false + return PerformResult{ + ContinueRequest: true, + GoNextGroup: this.Scope == AllowScopeGroup, + IsAllowed: true, + AllowScope: this.Scope, + } } diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go index 37fc4b6..89fa1eb 100644 --- a/internal/waf/action_block.go +++ b/internal/waf/action_block.go @@ -61,7 +61,7 @@ func (this *BlockAction) WillChange() bool { return true } -func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { // 加入到黑名单 var timeout = this.Timeout if timeout <= 0 { @@ -93,14 +93,14 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque req, err := http.NewRequest(http.MethodGet, this.URL, nil) if err != nil { logs.Error(err) - return false, false + return PerformResult{} } req.Header.Set("User-Agent", teaconst.GlobalProductName+"/"+teaconst.Version) resp, err := httpClient.Do(req) if err != nil { logs.Error(err) - return false, false + return PerformResult{} } defer func() { _ = resp.Body.Close() @@ -124,11 +124,11 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque data, err := os.ReadFile(path) if err != nil { logs.Error(err) - return false, false + return PerformResult{} } _, _ = writer.Write(data) } - return false, false + return PerformResult{} } if len(this.Body) > 0 { _, _ = writer.Write([]byte(this.Body)) @@ -137,5 +137,5 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque } } - return false, false + return PerformResult{} } diff --git a/internal/waf/action_captcha.go b/internal/waf/action_captcha.go index f05e4ff..a581f8e 100644 --- a/internal/waf/action_captcha.go +++ b/internal/waf/action_captcha.go @@ -123,10 +123,12 @@ func (this *CaptchaAction) WillChange() bool { return true } -func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) PerformResult { // 是否在白名单中 if SharedIPWhiteList.Contains(wafutils.ComposeIPType(set.Id, req), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) { - return true, false + return PerformResult{ + ContinueRequest: true, + } } var refURL = req.WAFRaw().URL.String() @@ -153,7 +155,9 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req info, err := utils.SimpleEncryptMap(captchaConfig) if err != nil { remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error()) - return true, false + return PerformResult{ + ContinueRequest: true, + } } // 占用一次失败次数 @@ -163,5 +167,5 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect) http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect) - return false, false + return PerformResult{} } diff --git a/internal/waf/action_get_302.go b/internal/waf/action_get_302.go index 5a14aa0..40884ec 100644 --- a/internal/waf/action_get_302.go +++ b/internal/waf/action_get_302.go @@ -41,15 +41,19 @@ func (this *Get302Action) WillChange() bool { return true } -func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { // 仅限于Get if request.WAFRaw().Method != http.MethodGet { - return true, false + return PerformResult{ + ContinueRequest: true, + } } // 是否已经在白名单中 if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) { - return true, false + return PerformResult{ + ContinueRequest: true, + } } var m = maps.Map{ @@ -64,7 +68,9 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ info, err := utils.SimpleEncryptMap(m) if err != nil { remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error()) - return true, false + return PerformResult{ + ContinueRequest: true, + } } request.DisableStat() @@ -75,6 +81,6 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ if ok { flusher.Flush() } - - return false, false + + return PerformResult{} } diff --git a/internal/waf/action_go_group.go b/internal/waf/action_go_group.go index 1c61265..a2e28d0 100644 --- a/internal/waf/action_go_group.go +++ b/internal/waf/action_go_group.go @@ -29,20 +29,29 @@ func (this *GoGroupAction) WillChange() bool { return true } -func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { - nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId)) +func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { + var nextGroup = waf.FindRuleGroup(types.Int64(this.GroupId)) if nextGroup == nil || !nextGroup.IsOn { - return true, true + return PerformResult{ + ContinueRequest: true, + GoNextSet: true, + } } b, _, nextSet, err := nextGroup.MatchRequest(request) if err != nil { remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error()) - return true, false + return PerformResult{ + ContinueRequest: true, + GoNextSet: true, + } } if !b { - return true, false + return PerformResult{ + ContinueRequest: true, + GoNextSet: true, + } } return nextSet.PerformActions(waf, nextGroup, request, writer) diff --git a/internal/waf/action_go_set.go b/internal/waf/action_go_set.go index 507554a..38e75e6 100644 --- a/internal/waf/action_go_set.go +++ b/internal/waf/action_go_set.go @@ -30,23 +30,35 @@ func (this *GoSetAction) WillChange() bool { return true } -func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { - nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId)) +func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { + var nextGroup = waf.FindRuleGroup(types.Int64(this.GroupId)) if nextGroup == nil || !nextGroup.IsOn { - return true, true + return PerformResult{ + ContinueRequest: true, + GoNextSet: true, + } } - nextSet := nextGroup.FindRuleSet(types.Int64(this.SetId)) + var nextSet = nextGroup.FindRuleSet(types.Int64(this.SetId)) if nextSet == nil || !nextSet.IsOn { - return true, true + return PerformResult{ + ContinueRequest: true, + GoNextSet: true, + } } b, _, err := nextSet.MatchRequest(request) if err != nil { remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error()) - return true, false + return PerformResult{ + ContinueRequest: true, + GoNextSet: true, + } } if !b { - return true, false + return PerformResult{ + ContinueRequest: true, + GoNextSet: true, + } } return nextSet.PerformActions(waf, nextGroup, request, writer) } diff --git a/internal/waf/action_interface.go b/internal/waf/action_interface.go index 11d9162..7d8463e 100644 --- a/internal/waf/action_interface.go +++ b/internal/waf/action_interface.go @@ -27,5 +27,5 @@ type ActionInterface interface { WillChange() bool // Perform the action - Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) + Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult } diff --git a/internal/waf/action_js_cookie.go b/internal/waf/action_js_cookie.go index 21102c6..b7a1305 100644 --- a/internal/waf/action_js_cookie.go +++ b/internal/waf/action_js_cookie.go @@ -42,15 +42,19 @@ func (this *JSCookieAction) WillChange() bool { return true } -func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) PerformResult { // 是否在白名单中 if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) { - return true, false + return PerformResult{ + ContinueRequest: true, + } } nodeConfig, err := nodeconfigs.SharedNodeConfig() if err != nil { - return true, false + return PerformResult{ + ContinueRequest: true, + } } var life = this.Life @@ -69,7 +73,9 @@ func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re var timestamp = pieces[0] var sum = pieces[2] if types.Int64(timestamp) >= time.Now().Unix()-int64(life) && fmt.Sprintf("%x", md5.Sum([]byte(timestamp+"@"+types.String(set.Id)+"@"+nodeConfig.NodeId))) == sum { - return true, false + return PerformResult{ + ContinueRequest: true, + } } } } @@ -103,7 +109,7 @@ window.location.reload(); // 记录失败次数 this.increaseFails(req, waf.Id, group.Id, set.Id) - return false, false + return PerformResult{} } func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64, groupId int64, setId int64) (goNext bool) { diff --git a/internal/waf/action_log.go b/internal/waf/action_log.go index 6d2f2e9..15f6a03 100644 --- a/internal/waf/action_log.go +++ b/internal/waf/action_log.go @@ -25,6 +25,8 @@ func (this *LogAction) WillChange() bool { return false } -func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { - return true, false +func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { + return PerformResult{ + ContinueRequest: true, + } } diff --git a/internal/waf/action_notify.go b/internal/waf/action_notify.go index b152edc..4a6fc00 100644 --- a/internal/waf/action_notify.go +++ b/internal/waf/action_notify.go @@ -76,7 +76,7 @@ func (this *NotifyAction) WillChange() bool { } // Perform the action -func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { select { case notifyChan <- ¬ifyTask{ ServerId: request.WAFServerId(), @@ -89,5 +89,7 @@ func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ } - return true, false + return PerformResult{ + ContinueRequest: true, + } } diff --git a/internal/waf/action_page.go b/internal/waf/action_page.go index 02fd024..ef92a7a 100644 --- a/internal/waf/action_page.go +++ b/internal/waf/action_page.go @@ -45,9 +45,9 @@ func (this *PageAction) WillChange() bool { } // Perform the action -func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { if writer == nil { - return + return PerformResult{} } request.ProcessResponseHeaders(writer.Header(), this.Status) @@ -73,5 +73,5 @@ func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reques } _, _ = writer.Write([]byte(request.Format(body))) - return false, false + return PerformResult{} } diff --git a/internal/waf/action_post_307.go b/internal/waf/action_post_307.go index ac9fcef..d8e67d0 100644 --- a/internal/waf/action_post_307.go +++ b/internal/waf/action_post_307.go @@ -34,17 +34,21 @@ func (this *Post307Action) WillChange() bool { return true } -func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { var cookieName = "WAF_VALIDATOR_ID" // 仅限于POST if request.WAFRaw().Method != http.MethodPost { - return true, false + return PerformResult{ + ContinueRequest: true, + } } // 是否已经在白名单中 if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) { - return true, false + return PerformResult{ + ContinueRequest: true, + } } // 判断是否有Cookie @@ -58,7 +62,9 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req } var setId = types.String(m.GetInt64("setId")) SharedIPWhiteList.RecordIP("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId"), "") - return true, false + return PerformResult{ + ContinueRequest: true, + } } } @@ -74,7 +80,9 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req info, err := utils.SimpleEncryptMap(m) if err != nil { remotelogs.Error("WAF_POST_307_ACTION", "encode info failed: "+err.Error()) - return true, false + return PerformResult{ + ContinueRequest: true, + } } // 清空请求内容 @@ -101,5 +109,5 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req flusher.Flush() } - return false, false + return PerformResult{} } diff --git a/internal/waf/action_record_ip.go b/internal/waf/action_record_ip.go index 9295bd5..573df36 100644 --- a/internal/waf/action_record_ip.go +++ b/internal/waf/action_record_ip.go @@ -132,7 +132,7 @@ func (this *RecordIPAction) WillChange() bool { return this.Type == "black" } -func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { var ipListId = this.IPListId if ipListId <= 0 { ipListId = firewallconfigs.GlobalListId @@ -143,7 +143,11 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re // 是否在本地白名单中 if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) { - return true, false + return PerformResult{ + ContinueRequest: true, + IsAllowed: true, + AllowScope: AllowScopeGlobal, + } } var timeout = this.Timeout @@ -200,5 +204,10 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re } } - return this.Type != "black", false + var isWhite = this.Type != "black" + return PerformResult{ + ContinueRequest: isWhite, + IsAllowed: isWhite, + AllowScope: AllowScopeGlobal, + } } diff --git a/internal/waf/action_redirect.go b/internal/waf/action_redirect.go index 50dab68..ce64064 100644 --- a/internal/waf/action_redirect.go +++ b/internal/waf/action_redirect.go @@ -35,10 +35,10 @@ func (this *RedirectAction) WillChange() bool { } // Perform the action -func (this *RedirectAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *RedirectAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { request.ProcessResponseHeaders(writer.Header(), this.Status) writer.Header().Set("Location", this.URL) writer.WriteHeader(this.Status) - return false, false + return PerformResult{} } diff --git a/internal/waf/action_tag.go b/internal/waf/action_tag.go index 03ea356..ff0b1a5 100644 --- a/internal/waf/action_tag.go +++ b/internal/waf/action_tag.go @@ -27,6 +27,8 @@ func (this *TagAction) WillChange() bool { return false } -func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { - return true, true +func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult { + return PerformResult{ + ContinueRequest: true, + } } diff --git a/internal/waf/results.go b/internal/waf/results.go new file mode 100644 index 0000000..b773104 --- /dev/null +++ b/internal/waf/results.go @@ -0,0 +1,22 @@ +// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package waf + +// PerformResult action performing result +type PerformResult struct { + ContinueRequest bool + GoNextGroup bool + GoNextSet bool + IsAllowed bool + AllowScope AllowScope +} + +// MatchResult request match result +type MatchResult struct { + GoNext bool + HasRequestBody bool + Group *RuleGroup + Set *RuleSet + IsAllowed bool + AllowScope AllowScope +} diff --git a/internal/waf/rule_set.go b/internal/waf/rule_set.go index d4ad993..9472e81 100644 --- a/internal/waf/rule_set.go +++ b/internal/waf/rule_set.go @@ -34,6 +34,9 @@ type RuleSet struct { actionCodes []string actionInstances []ActionInterface + hasAllowActions bool + allowScope string + hasRules bool } @@ -62,6 +65,12 @@ func (this *RuleSet) Init(waf *WAF) error { // action codes var actionCodes = []string{} for _, action := range this.Actions { + if action.Code == ActionAllow { + this.hasAllowActions = true + if action.Options != nil { + this.allowScope = action.Options.GetString("scope") + } + } if !lists.ContainsString(actionCodes, action.Code) { actionCodes = append(actionCodes, action.Code) } @@ -141,19 +150,37 @@ func (this *RuleSet) ActionCodes() []string { return this.actionCodes } -func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { +func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) PerformResult { if len(waf.Mode) != 0 && waf.Mode != firewallconfigs.FirewallModeDefend { - return true, false + return PerformResult{ + ContinueRequest: true, + } } + var isAllowed = this.hasAllowActions + var allowScope = this.allowScope + var continueRequest bool + var goNextGroup bool + var goNextSet bool + // 先执行allow for _, instance := range this.actionInstances { if !instance.WillChange() { continueRequest = req.WAFOnAction(instance) if !continueRequest { - return false, false + return PerformResult{ + IsAllowed: isAllowed, + AllowScope: allowScope, + } + } + var performResult = instance.Perform(waf, group, this, req, writer) + continueRequest = performResult.ContinueRequest + goNextSet = performResult.GoNextSet + if performResult.IsAllowed { + isAllowed = true + allowScope = performResult.AllowScope + goNextGroup = performResult.GoNextGroup } - _, goNextSet = instance.Perform(waf, group, this, req, writer) } } @@ -163,13 +190,36 @@ func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Req if instance.WillChange() { continueRequest = req.WAFOnAction(instance) if !continueRequest { - return false, false + return PerformResult{ + IsAllowed: isAllowed, + AllowScope: allowScope, + } + } + var performResult = instance.Perform(waf, group, this, req, writer) + continueRequest = performResult.ContinueRequest + goNextSet = performResult.GoNextSet + if performResult.IsAllowed { + isAllowed = true + allowScope = performResult.AllowScope + goNextGroup = performResult.GoNextGroup + } + return PerformResult{ + ContinueRequest: performResult.ContinueRequest, + GoNextGroup: goNextGroup, + GoNextSet: performResult.GoNextSet, + IsAllowed: isAllowed, + AllowScope: allowScope, } - return instance.Perform(waf, group, this, req, writer) } } - return true, goNextSet + return PerformResult{ + ContinueRequest: true, + GoNextGroup: goNextGroup, + GoNextSet: goNextSet, + IsAllowed: isAllowed, + AllowScope: allowScope, + } } func (this *RuleSet) MatchRequest(req requests.Request) (b bool, hasRequestBody bool, err error) { diff --git a/internal/waf/rule_set_test.go b/internal/waf/rule_set_test.go index 7799058..0d6ce76 100644 --- a/internal/waf/rule_set_test.go +++ b/internal/waf/rule_set_test.go @@ -6,6 +6,7 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/cespare/xxhash" "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/maps" "net/http" "regexp" "runtime" @@ -74,6 +75,52 @@ func TestRuleSet_MatchRequest2(t *testing.T) { a.IsTrue(set.MatchRequest(req)) } +func TestRuleSet_MatchRequest_Allow(t *testing.T) { + var a = assert.NewAssertion(t) + + var set = waf.NewRuleSet() + set.Connector = waf.RuleConnectorOr + + set.Rules = []*waf.Rule{ + { + Param: "${requestPath}", + Operator: waf.RuleOperatorMatch, + Value: "hello", + }, + } + + set.Actions = []*waf.ActionConfig{ + { + Code: "allow", + Options: maps.Map{ + "scope": waf.AllowScopeGroup, + }, + }, + } + + var wafInstance = waf.NewWAF() + + err := set.Init(wafInstance) + if err != nil { + t.Fatal(err) + } + + rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil) + if err != nil { + t.Fatal(err) + } + var req = requests.NewTestRequest(rawReq) + b, _, err := set.MatchRequest(req) + if err != nil { + t.Fatal(err) + } + a.IsTrue(b) + + var result = set.PerformActions(wafInstance, &waf.RuleGroup{}, req, nil) + a.IsTrue(result.IsAllowed) + t.Log("scope:", result.AllowScope) +} + func BenchmarkRuleSet_MatchRequest(b *testing.B) { runtime.GOMAXPROCS(1) diff --git a/internal/waf/template_test.go b/internal/waf/template_test.go index 89d7d2b..de0dc96 100644 --- a/internal/waf/template_test.go +++ b/internal/waf/template_test.go @@ -52,18 +52,18 @@ func Test_Template2(t *testing.T) { } now := time.Now() - goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } t.Log(time.Since(now).Seconds()*1000, "ms") - if goNext { + if result.GoNext { t.Log("ok") return } - logs.PrintAsJSON(set, t) + logs.PrintAsJSON(result.Set, t) } func BenchmarkTemplate(b *testing.B) { @@ -84,7 +84,7 @@ func BenchmarkTemplate(b *testing.B) { } req.Header.Set("User-Agent", testUserAgent) - _, _, _, _, _ = wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + _, _ = wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) } }) } @@ -103,13 +103,13 @@ func testTemplate1010(a *assert.Assertion, t *testing.T, template *waf.WAF) { t.Fatal(err) } req.Header.Set("User-Agent", testUserAgent) - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNotNil(result) - if result != nil { - a.IsTrue(result.Code == "1010") + a.IsNotNil(result.Set) + if result.Set != nil { + a.IsTrue(result.Set.Code == "1010") } else { t.Log("break at:", id) } @@ -125,13 +125,13 @@ func testTemplate1010(a *assert.Assertion, t *testing.T, template *waf.WAF) { t.Fatal(err) } req.Header.Set("User-Agent", testUserAgent) - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNil(result) - if result != nil { - a.IsTrue(result.Code == "1010") + a.IsNil(result.Set) + if result.Set != nil { + a.IsTrue(result.Set.Code == "1010") } } } @@ -192,13 +192,13 @@ func testTemplate2001(a *assert.Assertion, t *testing.T, template *waf.WAF) { req.Header.Add("Content-Type", writer.FormDataContentType()) - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNotNil(result) - if result != nil { - a.IsTrue(result.Code == "2001") + a.IsNotNil(result.Set) + if result.Set != nil { + a.IsTrue(result.Set.Code == "2001") } } @@ -207,13 +207,13 @@ func testTemplate3001(a *assert.Assertion, t *testing.T, template *waf.WAF) { if err != nil { t.Fatal(err) } - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNotNil(result) - if result != nil { - a.IsTrue(result.Code == "3001") + a.IsNotNil(result.Set) + if result.Set != nil { + a.IsTrue(result.Set.Code == "3001") } } @@ -222,13 +222,13 @@ func testTemplate4001(a *assert.Assertion, t *testing.T, template *waf.WAF) { if err != nil { t.Fatal(err) } - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNotNil(result) - if result != nil { - a.IsTrue(result.Code == "4001") + a.IsNotNil(result.Set) + if result.Set != nil { + a.IsTrue(result.Set.Code == "4001") } } @@ -238,13 +238,13 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *waf.WAF) { if err != nil { t.Fatal(err) } - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNotNil(result) - if result != nil { - a.IsTrue(result.Code == "5001") + a.IsNotNil(result.Set) + if result.Set != nil { + a.IsTrue(result.Set.Code == "5001") } } @@ -253,13 +253,13 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *waf.WAF) { if err != nil { t.Fatal(err) } - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNotNil(result) - if result != nil { - a.IsTrue(result.Code == "5001") + a.IsNotNil(result.Set) + if result.Set != nil { + a.IsTrue(result.Set.Code == "5001") } } } @@ -271,13 +271,13 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *waf.WAF) { t.Fatal(err) } req.Header.Set("User-Agent", testUserAgent) - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNotNil(result) - if result != nil { - a.IsTrue(result.Code == "6001") + a.IsNotNil(result.Set) + if result.Set != nil { + a.IsTrue(result.Set.Code == "6001") } } @@ -286,11 +286,11 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *waf.WAF) { if err != nil { t.Fatal(err) } - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNotNil(result) + a.IsNotNil(result.Set) } } @@ -325,13 +325,13 @@ func testTemplate7010(a *assert.Assertion, t *testing.T, template *waf.WAF) { t.Fatal(err) } req.Header.Set("User-Agent", testUserAgent) - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNotNil(result) - if result != nil { - a.IsTrue(lists.ContainsAny([]string{"7010"}, result.Code)) + a.IsNotNil(result.Set) + if result.Set != nil { + a.IsTrue(lists.ContainsAny([]string{"7010"}, result.Set.Code)) } else { t.Log("break:", id) } @@ -423,13 +423,13 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *waf.WAF) { t.Fatal(err) } req.Header.Set("User-Agent", bot) - _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } - a.IsNotNil(result) - if result != nil { - a.IsTrue(lists.ContainsAny([]string{"20001"}, result.Code)) + a.IsNotNil(result.Set) + if result.Set != nil { + a.IsTrue(lists.ContainsAny([]string{"20001"}, result.Set.Code)) } else { t.Log("break:", bot) } diff --git a/internal/waf/waf.go b/internal/waf/waf.go index 6d98f5d..17badca 100644 --- a/internal/waf/waf.go +++ b/internal/waf/waf.go @@ -39,7 +39,8 @@ type WAF struct { func NewWAF() *WAF { return &WAF{ - IsOn: true, + IsOn: true, + actionMap: map[int64]ActionInterface{}, } } @@ -243,9 +244,11 @@ func (this *WAF) MoveOutboundRuleGroup(fromIndex int, toIndex int) { this.Outbound = result } -func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter, defaultCaptchaType firewallconfigs.ServerCaptchaType) (goNext bool, hasRequestBody bool, resultGroup *RuleGroup, resultSet *RuleSet, err error) { +func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter, defaultCaptchaType firewallconfigs.ServerCaptchaType) (result MatchResult, err error) { if !this.hasInboundRules { - return true, hasRequestBody, nil, nil, nil + return MatchResult{ + GoNext: true, + }, nil } // validate captcha @@ -266,51 +269,87 @@ func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter, } // match rules + var hasRequestBody bool for _, group := range this.Inbound { if !group.IsOn { continue } - b, hasCheckedRequestBody, set, err := group.MatchRequest(req) + b, hasCheckedRequestBody, set, matchErr := group.MatchRequest(req) if hasCheckedRequestBody { hasRequestBody = true } - if err != nil { - return true, hasRequestBody, nil, nil, err + if matchErr != nil { + return MatchResult{ + GoNext: true, + HasRequestBody: hasRequestBody, + }, matchErr } if b { - continueRequest, goNextSet := set.PerformActions(this, group, req, writer) - if !goNextSet { - return continueRequest, hasRequestBody, group, set, nil + var performResult = set.PerformActions(this, group, req, writer) + if !performResult.GoNextSet { + if performResult.GoNextGroup { + continue + } + return MatchResult{ + GoNext: performResult.ContinueRequest, + HasRequestBody: hasRequestBody, + Group: group, + Set: set, + IsAllowed: performResult.IsAllowed, + AllowScope: performResult.AllowScope, + }, nil } } } - return true, hasRequestBody, nil, nil, nil + return MatchResult{ + GoNext: true, + HasRequestBody: hasRequestBody, + }, nil } -func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, hasRequestBody bool, resultGroup *RuleGroup, resultSet *RuleSet, err error) { +func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (result MatchResult, err error) { if !this.hasOutboundRules { - return true, hasRequestBody, nil, nil, nil + return MatchResult{ + GoNext: true, + }, nil } - resp := requests.NewResponse(rawResp) + var hasRequestBody bool + var resp = requests.NewResponse(rawResp) for _, group := range this.Outbound { if !group.IsOn { continue } - b, hasCheckedRequestBody, set, err := group.MatchResponse(req, resp) + b, hasCheckedRequestBody, set, matchErr := group.MatchResponse(req, resp) if hasCheckedRequestBody { hasRequestBody = true } - if err != nil { - return true, hasRequestBody, nil, nil, err + if matchErr != nil { + return MatchResult{ + GoNext: true, + HasRequestBody: hasRequestBody, + }, matchErr } if b { - continueRequest, goNextSet := set.PerformActions(this, group, req, writer) - if !goNextSet { - return continueRequest, hasRequestBody, group, set, nil + var performResult = set.PerformActions(this, group, req, writer) + if !performResult.GoNextSet { + if performResult.GoNextGroup { + continue + } + return MatchResult{ + GoNext: performResult.ContinueRequest, + HasRequestBody: hasRequestBody, + Group: group, + Set: set, + IsAllowed: performResult.IsAllowed, + AllowScope: performResult.AllowScope, + }, nil } } } - return true, hasRequestBody, nil, nil, nil + return MatchResult{ + GoNext: true, + HasRequestBody: hasRequestBody, + }, nil } // Save to file path diff --git a/internal/waf/waf_test.go b/internal/waf/waf_test.go index 8936257..45af496 100644 --- a/internal/waf/waf_test.go +++ b/internal/waf/waf_test.go @@ -5,6 +5,7 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/maps" "net/http" "testing" ) @@ -44,7 +45,7 @@ func TestWAF_MatchRequest(t *testing.T) { if err != nil { t.Fatal(err) } - goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } @@ -52,6 +53,160 @@ func TestWAF_MatchRequest(t *testing.T) { t.Log("not match") return } - t.Log("goNext:", goNext, "set:", set.Name) - a.IsFalse(goNext) + t.Log("goNext:", result.GoNext, "set:", set.Name) + a.IsFalse(result.GoNext) +} + +func TestWAF_MatchRequest_Allow(t *testing.T) { + var a = assert.NewAssertion(t) + + var wafInstance = waf.NewWAF() + + { + var set = waf.NewRuleSet() + set.Id = 1 + set.Name = "set1" + set.Connector = waf.RuleConnectorAnd + set.Rules = []*waf.Rule{ + { + Param: "${requestPath}", + Operator: waf.RuleOperatorMatch, + Value: "hello", + }, + } + set.AddAction(waf.ActionAllow, maps.Map{ + "scope": "global", + }) + + var group = waf.NewRuleGroup() + group.Id = 1 + group.AddRuleSet(set) + group.IsInbound = true + + wafInstance.AddRuleGroup(group) + } + + { + var set = waf.NewRuleSet() + set.Id = 2 + set.Name = "set2" + set.Connector = waf.RuleConnectorAnd + set.Rules = []*waf.Rule{ + { + Param: "${requestPath}", + Operator: waf.RuleOperatorMatch, + Value: "he", + }, + } + set.AddAction(waf.ActionAllow, maps.Map{ + "scope": "global", + }) + + var group = waf.NewRuleGroup() + group.Id = 2 + group.AddRuleSet(set) + group.IsInbound = true + + wafInstance.AddRuleGroup(group) + } + + errs := wafInstance.Init() + if len(errs) > 0 { + t.Fatal(errs[0]) + } + + req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil) + if err != nil { + t.Fatal(err) + } + result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + if err != nil { + t.Fatal(err) + } + if result.Set == nil { + t.Log("not match") + return + } + t.Log("goNext:", result.GoNext, "set:", result.Set.Name) + a.IsTrue(result.Set.Id == 1) + a.IsTrue(result.GoNext) + a.IsTrue(result.IsAllowed) + a.IsTrue(result.AllowScope == "global") +} + +func TestWAF_MatchRequest_Allow2(t *testing.T) { + var a = assert.NewAssertion(t) + + var wafInstance = waf.NewWAF() + + { + var set = waf.NewRuleSet() + set.Id = 1 + set.Name = "set1" + set.Connector = waf.RuleConnectorAnd + set.Rules = []*waf.Rule{ + { + Param: "${requestPath}", + Operator: waf.RuleOperatorMatch, + Value: "hello", + }, + } + set.AddAction(waf.ActionAllow, maps.Map{ + "scope": "group", + }) + + var group = waf.NewRuleGroup() + group.Id = 1 + group.AddRuleSet(set) + group.IsInbound = true + + wafInstance.AddRuleGroup(group) + } + + { + var set = waf.NewRuleSet() + set.Id = 2 + set.Name = "set2" + set.Connector = waf.RuleConnectorAnd + set.Rules = []*waf.Rule{ + { + Param: "${requestPath}", + Operator: waf.RuleOperatorMatch, + Value: "he", + }, + } + set.AddAction(waf.ActionAllow, maps.Map{ + "scope": "global", + }) + + var group = waf.NewRuleGroup() + group.Id = 2 + group.AddRuleSet(set) + group.IsInbound = true + + wafInstance.AddRuleGroup(group) + } + + errs := wafInstance.Init() + if len(errs) > 0 { + t.Fatal(errs[0]) + } + + req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil) + if err != nil { + t.Fatal(err) + } + result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + if err != nil { + t.Fatal(err) + } + if result.Set == nil { + t.Log("not match") + return + } + t.Log("goNext:", result.GoNext, "set:", result.Set.Name) + a.IsTrue(result.Set.Id == 2) + a.IsTrue(result.GoNext) + a.IsTrue(result.IsAllowed) + a.IsTrue(result.AllowScope == "global") }