diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 94bc172..9204e2e 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -190,14 +190,6 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir return } - w.OnAction(func(action waf.ActionInterface) (goNext bool) { - switch action.Code() { - case waf.ActionTag: - this.tags = action.(*waf.TagAction).Tags - } - return true - }) - goNext, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer) if err != nil { remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error()) @@ -254,14 +246,6 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi return } - w.OnAction(func(action waf.ActionInterface) (goNext bool) { - switch action.Code() { - case waf.ActionTag: - this.tags = action.(*waf.TagAction).Tags - } - return true - }) - goNext, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer) if err != nil { remotelogs.Error("HTTP_REQUEST_WAF", this.rawURI+": "+err.Error()) @@ -344,3 +328,20 @@ func (this *HTTPRequest) WAFClose() { } return } + +func (this *HTTPRequest) WAFOnAction(action interface{}) (goNext bool) { + if action == nil { + return true + } + + instance, ok := action.(waf.ActionInterface) + if !ok { + return true + } + + switch instance.Code() { + case waf.ActionTag: + this.tags = append(this.tags, action.(*waf.TagAction).Tags...) + } + return true +} diff --git a/internal/waf/ip_list_test.go b/internal/waf/ip_list_test.go index 0877aa4..01e570a 100644 --- a/internal/waf/ip_list_test.go +++ b/internal/waf/ip_list_test.go @@ -13,7 +13,7 @@ import ( ) func TestNewIPList(t *testing.T) { - list := NewIPList() + list := NewIPList(IPListTypeDeny) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2) @@ -34,7 +34,7 @@ func TestNewIPList(t *testing.T) { func TestIPList_Contains(t *testing.T) { a := assert.NewAssertion(t) - list := NewIPList() + list := NewIPList(IPListTypeDeny) for i := 0; i < 1_0000; i++ { list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) @@ -47,7 +47,7 @@ func TestIPList_Contains(t *testing.T) { func BenchmarkIPList_Add(b *testing.B) { runtime.GOMAXPROCS(1) - list := NewIPList() + list := NewIPList(IPListTypeDeny) for i := 0; i < b.N; i++ { list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) } @@ -57,7 +57,7 @@ func BenchmarkIPList_Add(b *testing.B) { func BenchmarkIPList_Has(b *testing.B) { runtime.GOMAXPROCS(1) - list := NewIPList() + list := NewIPList(IPListTypeDeny) for i := 0; i < 1_0000; i++ { list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) diff --git a/internal/waf/requests/request.go b/internal/waf/requests/request.go index 9c4b698..f7aac5a 100644 --- a/internal/waf/requests/request.go +++ b/internal/waf/requests/request.go @@ -29,6 +29,9 @@ type Request interface { // WAFClose 关闭当前请求所在的连接 WAFClose() + // WAFOnAction 动作回调 + WAFOnAction(action interface{}) (goNext bool) + // Format 格式化变量 Format(string) string } diff --git a/internal/waf/requests/test_request.go b/internal/waf/requests/test_request.go index 114f462..8b67b09 100644 --- a/internal/waf/requests/test_request.go +++ b/internal/waf/requests/test_request.go @@ -73,3 +73,7 @@ func (this *TestRequest) WAFClose() { func (this *TestRequest) Format(s string) string { return s } + +func (this *TestRequest) WAFOnAction(action interface{}) bool { + return true +} diff --git a/internal/waf/rule_set.go b/internal/waf/rule_set.go index 14227a4..405f10a 100644 --- a/internal/waf/rule_set.go +++ b/internal/waf/rule_set.go @@ -139,11 +139,9 @@ func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Req // 先执行allow for _, instance := range this.actionInstances { if !instance.WillChange() { - if waf.onActionCallback != nil { - goNext := waf.onActionCallback(instance) - if !goNext { - return false - } + goNext := req.WAFOnAction(instance) + if !goNext { + return false } instance.Perform(waf, group, this, req, writer) } @@ -153,11 +151,9 @@ func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Req for _, instance := range this.actionInstances { // 只执行第一个可能改变请求的动作,其余的都会被忽略 if instance.WillChange() { - if waf.onActionCallback != nil { - goNext := waf.onActionCallback(instance) - if !goNext { - return false - } + goNext := req.WAFOnAction(instance) + if !goNext { + return false } return instance.Perform(waf, group, this, req, writer) } diff --git a/internal/waf/template_test.go b/internal/waf/template_test.go index e10d691..0ac5a3c 100644 --- a/internal/waf/template_test.go +++ b/internal/waf/template_test.go @@ -23,10 +23,6 @@ func Test_Template(t *testing.T) { t.Fatal(err) } - template.OnAction(func(action ActionInterface) (goNext bool) { - return action.Code() != ActionBlock - }) - testTemplate1001(a, t, template) testTemplate1002(a, t, template) testTemplate1003(a, t, template) diff --git a/internal/waf/waf.go b/internal/waf/waf.go index 3752f24..cfbb96d 100644 --- a/internal/waf/waf.go +++ b/internal/waf/waf.go @@ -27,7 +27,6 @@ type WAF struct { hasInboundRules bool hasOutboundRules bool - onActionCallback func(action ActionInterface) (goNext bool) checkpointsMap map[string]checkpoints.CheckpointInterface // prefix => checkpoint } @@ -347,10 +346,6 @@ func (this *WAF) CountOutboundRuleSets() int { return count } -func (this *WAF) OnAction(onActionCallback func(action ActionInterface) (goNext bool)) { - this.onActionCallback = onActionCallback -} - func (this *WAF) FindCheckpointInstance(prefix string) checkpoints.CheckpointInterface { instance, ok := this.checkpointsMap[prefix] if ok { diff --git a/internal/waf/waf_test.go b/internal/waf/waf_test.go index 5395eb8..6d6b316 100644 --- a/internal/waf/waf_test.go +++ b/internal/waf/waf_test.go @@ -38,10 +38,6 @@ func TestWAF_MatchRequest(t *testing.T) { t.Fatal(err) } - waf.OnAction(func(action ActionInterface) (goNext bool) { - return action.Code() != ActionBlock - }) - req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil) if err != nil { t.Fatal(err)