From 4dfa57154799cb592c8a243e3e2ffb5c2fc743fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Thu, 7 Dec 2023 11:42:59 +0800 Subject: [PATCH] =?UTF-8?q?WAF=E6=93=8D=E4=BD=9C=E7=AC=A6=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E5=8C=85=E5=90=AB=E4=BB=BB=E4=B8=80=E5=8D=95=E8=AF=8D?= =?UTF-8?q?=E3=80=81=E5=8C=85=E5=90=AB=E6=89=80=E6=9C=89=E5=8D=95=E8=AF=8D?= =?UTF-8?q?=E3=80=81=E4=B8=8D=E5=8C=85=E5=90=AB=E4=BB=BB=E4=B8=80=E5=8D=95?= =?UTF-8?q?=E8=AF=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/re/rune_tree.go | 2 +- internal/utils/runes/runes.go | 110 ++++++++++++++ internal/utils/runes/runes_test.go | 97 +++++++++++++ internal/waf/rule.go | 9 +- internal/waf/rule_operator.go | 225 ++++------------------------- internal/waf/template_test.go | 115 ++++++++++++--- 6 files changed, 341 insertions(+), 217 deletions(-) create mode 100644 internal/utils/runes/runes.go create mode 100644 internal/utils/runes/runes_test.go diff --git a/internal/re/rune_tree.go b/internal/re/rune_tree.go index c72104a..262f31e 100644 --- a/internal/re/rune_tree.go +++ b/internal/re/rune_tree.go @@ -4,7 +4,7 @@ package re type RuneMap map[rune]*RuneTree -func (this *RuneMap) Lookup(s string, caseInsensitive bool) bool { +func (this RuneMap) Lookup(s string, caseInsensitive bool) bool { return this.lookup([]rune(s), caseInsensitive, 0) } diff --git a/internal/utils/runes/runes.go b/internal/utils/runes/runes.go new file mode 100644 index 0000000..d82824b --- /dev/null +++ b/internal/utils/runes/runes.go @@ -0,0 +1,110 @@ +// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package runes + +// ContainsAnyWord 检查字符串是否包含任一单词 +func ContainsAnyWord(s string, words []string, isCaseInsensitive bool) bool { + var allRunes = []rune(s) + if len(allRunes) == 0 || len(words) == 0 { + return false + } + + for _, word := range words { + if ContainsWordRunes(allRunes, []rune(word), isCaseInsensitive) { + return true + } + } + return false +} + +// ContainsAllWords 检查字符串是否包含所有单词 +func ContainsAllWords(s string, words []string, isCaseInsensitive bool) bool { + var allRunes = []rune(s) + if len(allRunes) == 0 || len(words) == 0 { + return false + } + + for _, word := range words { + if !ContainsWordRunes(allRunes, []rune(word), isCaseInsensitive) { + return false + } + } + return true +} + +// ContainsWordRunes 检查字符列表是否包含某个单词子字符列表 +func ContainsWordRunes(allRunes []rune, subRunes []rune, isCaseInsensitive bool) bool { + var l = len(subRunes) + if l == 0 { + return false + } + + var al = len(allRunes) + + for index, r := range allRunes { + if EqualRune(r, subRunes[0], isCaseInsensitive) && (index == 0 || !isChar(allRunes[index-1]) /**boundary check **/) { + var found = true + if l > 1 { + for i := 1; i < l; i++ { + var subIndex = index + i + if subIndex > al-1 || !EqualRune(allRunes[subIndex], subRunes[i], isCaseInsensitive) { + found = false + break + } + } + } + + // check after charset + if found && (al <= index+l || !isChar(allRunes[index+l]) /**boundary check **/) { + return true + } + } + } + + return false +} + +// ContainsSubRunes 检查字符列表是否包含某个子子字符列表 +// 与 ContainsWordRunes 不同,这里不需要检查边界符号 +func ContainsSubRunes(allRunes []rune, subRunes []rune, isCaseInsensitive bool) bool { + var l = len(subRunes) + if l == 0 { + return false + } + + var al = len(allRunes) + + for index, r := range allRunes { + if EqualRune(r, subRunes[0], isCaseInsensitive) { + var found = true + if l > 1 { + for i := 1; i < l; i++ { + var subIndex = index + i + if subIndex > al-1 || !EqualRune(allRunes[subIndex], subRunes[i], isCaseInsensitive) { + found = false + break + } + } + } + + // check after charset + if found { + return true + } + } + } + + return false +} + +// EqualRune 判断两个rune是否相同 +func EqualRune(r1 rune, r2 rune, isCaseInsensitive bool) bool { + const d = 'a' - 'A' + return r1 == r2 || + (isCaseInsensitive && r1 >= 'a' && r1 <= 'z' && r1-r2 == d) || + (isCaseInsensitive && r1 >= 'A' && r1 <= 'Z' && r1-r2 == -d) +} + +func isChar(r rune) bool { + return r >= 'a' && r <= 'z' || r >= 'A' && r <= 'Z' || r >= '0' && r <= '9' +} diff --git a/internal/utils/runes/runes_test.go b/internal/utils/runes/runes_test.go new file mode 100644 index 0000000..da2fd4e --- /dev/null +++ b/internal/utils/runes/runes_test.go @@ -0,0 +1,97 @@ +// Copyright 2023 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package runes_test + +import ( + "github.com/TeaOSLab/EdgeNode/internal/utils/runes" + "github.com/iwind/TeaGo/assert" + "runtime" + "testing" +) + +func TestContainsAllWords(t *testing.T) { + var a = assert.NewAssertion(t) + a.IsTrue(runes.ContainsAllWords("How are you?", []string{"are", "you"}, false)) + a.IsFalse(runes.ContainsAllWords("How are you?", []string{"how", "are", "you"}, false)) + a.IsTrue(runes.ContainsAllWords("How are you?", []string{"how", "are", "you"}, true)) +} + +func TestContainsAnyWord(t *testing.T) { + var a = assert.NewAssertion(t) + a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"are", "you"}, false)) + a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"are", "you", "ok"}, false)) + a.IsFalse(runes.ContainsAnyWord("How are you?", []string{"how", "ok"}, false)) + a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"how"}, true)) + a.IsTrue(runes.ContainsAnyWord("How are you?", []string{"how", "ok"}, true)) +} + +func TestContainsWordRunes(t *testing.T) { + var a = assert.NewAssertion(t) + a.IsFalse(runes.ContainsWordRunes([]rune(""), []rune("How"), true)) + a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune(""), true)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("How"), true)) + a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune("how"), false)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("you"), false)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("are"), false)) + a.IsFalse(runes.ContainsWordRunes([]rune("How are you?"), []rune("re"), false)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are you w?"), []rune("w"), false)) + a.IsTrue(runes.ContainsWordRunes([]rune("w How are you?"), []rune("w"), false)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are w you?"), []rune("w"), false)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are how you?"), []rune("how"), false)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("how"), true)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("ARE"), true)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are you"), []rune("you"), false)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are you"), []rune("YOU"), true)) + a.IsTrue(runes.ContainsWordRunes([]rune("How are you?"), []rune("YOU"), true)) + a.IsFalse(runes.ContainsWordRunes([]rune("How are you1?"), []rune("YOU"), true)) + a.IsFalse(runes.ContainsWordRunes([]rune("How are you1?"), []rune("YOU YOU YOU YOU YOU YOU YOU"), true)) +} + +func TestContainsSubRunes(t *testing.T) { + var a = assert.NewAssertion(t) + a.IsFalse(runes.ContainsSubRunes([]rune(""), []rune("How"), true)) + a.IsFalse(runes.ContainsSubRunes([]rune("How are you?"), []rune(""), true)) + a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("YOU"), true)) + a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("ow"), false)) + a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("H"), false)) + a.IsTrue(runes.ContainsSubRunes([]rune("How are you1?"), []rune("How"), false)) + a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("oi"), false)) + a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("g"), false)) + a.IsTrue(runes.ContainsSubRunes([]rune("How are you doing"), []rune("ing"), false)) + a.IsFalse(runes.ContainsSubRunes([]rune("How are you doing"), []rune("int"), false)) +} + +func TestEqualRune(t *testing.T) { + var a = assert.NewAssertion(t) + a.IsTrue(runes.EqualRune('a', 'a', false)) + a.IsTrue(runes.EqualRune('a', 'a', true)) + a.IsFalse(runes.EqualRune('a', 'A', false)) + a.IsTrue(runes.EqualRune('a', 'A', true)) + a.IsFalse(runes.EqualRune('c', 'C', false)) + a.IsTrue(runes.EqualRune('c', 'C', true)) + a.IsTrue(runes.EqualRune('C', 'C', true)) + a.IsTrue(runes.EqualRune('C', 'c', true)) + a.IsTrue(runes.EqualRune('Z', 'z', true)) + a.IsTrue(runes.EqualRune('z', 'Z', true)) + a.IsFalse(runes.EqualRune('z', 'z'+('a'-'A'), true)) +} + +func BenchmarkContainsWordRunes(b *testing.B) { + runtime.GOMAXPROCS(4) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = runes.ContainsWordRunes([]rune("How are you"), []rune("YOU"), true) + } + }) +} + +func BenchmarkContainsSubRunes(b *testing.B) { + runtime.GOMAXPROCS(4) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = runes.ContainsSubRunes([]rune("How are you"), []rune("YOU"), true) + } + }) +} diff --git a/internal/waf/rule.go b/internal/waf/rule.go index 7f582b8..38622a4 100644 --- a/internal/waf/rule.go +++ b/internal/waf/rule.go @@ -9,6 +9,7 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/filterconfigs" "github.com/TeaOSLab/EdgeNode/internal/re" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" + "github.com/TeaOSLab/EdgeNode/internal/utils/runes" "github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/TeaOSLab/EdgeNode/internal/waf/utils" @@ -77,7 +78,7 @@ func (this *Rule) Init() error { this.floatValue = types.Float64(this.Value) case RuleOperatorNeq: this.floatValue = types.Float64(this.Value) - case RuleOperatorContainsAny, RuleOperatorContainsAll: + case RuleOperatorContainsAny, RuleOperatorContainsAll, RuleOperatorContainsAnyWord, RuleOperatorContainsAllWords, RuleOperatorNotContainsAnyWord: this.stringValues = []string{} if len(this.Value) > 0 { var lines = strings.Split(this.Value, "\n") @@ -546,6 +547,12 @@ func (this *Rule) Test(value any) bool { return true } return false + case RuleOperatorContainsAnyWord: + return runes.ContainsAnyWord(this.stringifyValue(value), this.stringValues, this.IsCaseInsensitive) + case RuleOperatorContainsAllWords: + return runes.ContainsAllWords(this.stringifyValue(value), this.stringValues, this.IsCaseInsensitive) + case RuleOperatorNotContainsAnyWord: + return !runes.ContainsAnyWord(this.stringifyValue(value), this.stringValues, this.IsCaseInsensitive) case RuleOperatorContainsBinary: data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value)) if this.IsCaseInsensitive { diff --git a/internal/waf/rule_operator.go b/internal/waf/rule_operator.go index 36afc30..77c00ff 100644 --- a/internal/waf/rule_operator.go +++ b/internal/waf/rule_operator.go @@ -4,34 +4,38 @@ type RuleOperator = string type RuleCaseInsensitive = string const ( - RuleOperatorGt RuleOperator = "gt" - RuleOperatorGte RuleOperator = "gte" - RuleOperatorLt RuleOperator = "lt" - RuleOperatorLte RuleOperator = "lte" - RuleOperatorEq RuleOperator = "eq" - RuleOperatorNeq RuleOperator = "neq" - RuleOperatorEqString RuleOperator = "eq string" - RuleOperatorNeqString RuleOperator = "neq string" - RuleOperatorMatch RuleOperator = "match" - RuleOperatorNotMatch RuleOperator = "not match" - RuleOperatorWildcardMatch RuleOperator = "wildcard match" - RuleOperatorWildcardNotMatch RuleOperator = "wildcard not match" - RuleOperatorContains RuleOperator = "contains" - RuleOperatorNotContains RuleOperator = "not contains" - RuleOperatorPrefix RuleOperator = "prefix" - RuleOperatorSuffix RuleOperator = "suffix" - RuleOperatorContainsAny RuleOperator = "contains any" - RuleOperatorContainsAll RuleOperator = "contains all" - RuleOperatorInIPList RuleOperator = "in ip list" - RuleOperatorHasKey RuleOperator = "has key" // has key in slice or map - RuleOperatorVersionGt RuleOperator = "version gt" - RuleOperatorVersionLt RuleOperator = "version lt" - RuleOperatorVersionRange RuleOperator = "version range" + RuleOperatorGt RuleOperator = "gt" + RuleOperatorGte RuleOperator = "gte" + RuleOperatorLt RuleOperator = "lt" + RuleOperatorLte RuleOperator = "lte" + RuleOperatorEq RuleOperator = "eq" + RuleOperatorNeq RuleOperator = "neq" + RuleOperatorEqString RuleOperator = "eq string" + RuleOperatorNeqString RuleOperator = "neq string" + RuleOperatorMatch RuleOperator = "match" + RuleOperatorNotMatch RuleOperator = "not match" + RuleOperatorWildcardMatch RuleOperator = "wildcard match" + RuleOperatorWildcardNotMatch RuleOperator = "wildcard not match" + RuleOperatorContains RuleOperator = "contains" + RuleOperatorNotContains RuleOperator = "not contains" + RuleOperatorPrefix RuleOperator = "prefix" + RuleOperatorSuffix RuleOperator = "suffix" + RuleOperatorContainsAny RuleOperator = "contains any" + RuleOperatorContainsAll RuleOperator = "contains all" + RuleOperatorContainsAnyWord RuleOperator = "contains any word" + RuleOperatorContainsAllWords RuleOperator = "contains all word" + RuleOperatorNotContainsAnyWord RuleOperator = "not contains any word" + RuleOperatorInIPList RuleOperator = "in ip list" + RuleOperatorHasKey RuleOperator = "has key" // has key in slice or map + RuleOperatorVersionGt RuleOperator = "version gt" + RuleOperatorVersionLt RuleOperator = "version lt" + RuleOperatorVersionRange RuleOperator = "version range" RuleOperatorContainsBinary RuleOperator = "contains binary" // contains binary RuleOperatorNotContainsBinary RuleOperator = "not contains binary" // not contains binary // ip + RuleOperatorEqIP RuleOperator = "eq ip" RuleOperatorGtIP RuleOperator = "gt ip" RuleOperatorGteIP RuleOperator = "gte ip" @@ -42,10 +46,6 @@ const ( RuleOperatorIPMod10 RuleOperator = "ip mod 10" RuleOperatorIPMod100 RuleOperator = "ip mod 100" RuleOperatorIPMod RuleOperator = "ip mod" - - RuleCaseInsensitiveNone = "none" - RuleCaseInsensitiveYes = "yes" - RuleCaseInsensitiveNo = "no" ) type RuleOperatorDefinition struct { @@ -54,174 +54,3 @@ type RuleOperatorDefinition struct { Description string CaseInsensitive RuleCaseInsensitive // default caseInsensitive setting } - -var AllRuleOperators = []*RuleOperatorDefinition{ - { - Name: "数值大于", - Code: RuleOperatorGt, - Description: "使用数值对比大于", - CaseInsensitive: RuleCaseInsensitiveNone, - }, - { - Name: "数值大于等于", - Code: RuleOperatorGte, - Description: "使用数值对比大于等于", - CaseInsensitive: RuleCaseInsensitiveNone, - }, - { - Name: "数值小于", - Code: RuleOperatorLt, - Description: "使用数值对比小于", - CaseInsensitive: RuleCaseInsensitiveNone, - }, - { - Name: "数值小于等于", - Code: RuleOperatorLte, - Description: "使用数值对比小于等于", - CaseInsensitive: RuleCaseInsensitiveNone, - }, - { - Name: "数值等于", - Code: RuleOperatorEq, - Description: "使用数值对比等于", - CaseInsensitive: RuleCaseInsensitiveNone, - }, - { - Name: "数值不等于", - Code: RuleOperatorNeq, - Description: "使用数值对比不等于", - CaseInsensitive: RuleCaseInsensitiveNone, - }, - { - Name: "字符串等于", - Code: RuleOperatorEqString, - Description: "使用字符串对比等于", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "字符串不等于", - Code: RuleOperatorNeqString, - Description: "使用字符串对比不等于", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "正则匹配", - Code: RuleOperatorMatch, - Description: "使用正则表达式匹配,在头部使用(?i)表示不区分大小写,正则表达式语法 »", - CaseInsensitive: RuleCaseInsensitiveYes, - }, - { - Name: "正则不匹配", - Code: RuleOperatorNotMatch, - Description: "使用正则表达式不匹配,在头部使用(?i)表示不区分大小写,正则表达式语法 »", - CaseInsensitive: RuleCaseInsensitiveYes, - }, - { - Name: "包含字符串", - Code: RuleOperatorContains, - Description: "包含某个字符串", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "不包含字符串", - Code: RuleOperatorNotContains, - Description: "不包含某个字符串", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "包含前缀", - Code: RuleOperatorPrefix, - Description: "包含某个前缀", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "包含后缀", - Code: RuleOperatorSuffix, - Description: "包含某个后缀", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "包含索引", - Code: RuleOperatorHasKey, - Description: "对于一组数据拥有某个键值或者索引", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "版本号大于", - Code: RuleOperatorVersionGt, - Description: "对比版本号大于", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "版本号小于", - Code: RuleOperatorVersionLt, - Description: "对比版本号小于", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "版本号范围", - Code: RuleOperatorVersionRange, - Description: "判断版本号在某个范围内,格式为version1,version2", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "IP等于", - Code: RuleOperatorEqIP, - Description: "将参数转换为IP进行对比", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "IP大于", - Code: RuleOperatorGtIP, - Description: "将参数转换为IP进行对比", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "IP大于等于", - Code: RuleOperatorGteIP, - Description: "将参数转换为IP进行对比", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "IP小于", - Code: RuleOperatorLtIP, - Description: "将参数转换为IP进行对比", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "IP小于等于", - Code: RuleOperatorLteIP, - Description: "将参数转换为IP进行对比", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "IP范围", - Code: RuleOperatorIPRange, - Description: "IP在某个范围之内,范围格式可以是英文逗号分隔的ip1,ip2,或者CIDR格式的ip/bits", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "不在IP范围", - Code: RuleOperatorNotIPRange, - Description: "IP不在某个范围之内,范围格式可以是英文逗号分隔的ip1,ip2,或者CIDR格式的ip/bits", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "IP取模10", - Code: RuleOperatorIPMod10, - Description: "对IP参数值取模,除数为10,对比值为余数", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "IP取模100", - Code: RuleOperatorIPMod100, - Description: "对IP参数值取模,除数为100,对比值为余数", - CaseInsensitive: RuleCaseInsensitiveNo, - }, - { - Name: "IP取模", - Code: RuleOperatorIPMod, - Description: "对IP参数值取模,对比值格式为:除数,余数,比如10,1", - CaseInsensitive: RuleCaseInsensitiveNo, - }, -} diff --git a/internal/waf/template_test.go b/internal/waf/template_test.go index b4b1518..ca21b9b 100644 --- a/internal/waf/template_test.go +++ b/internal/waf/template_test.go @@ -16,29 +16,38 @@ import ( ) func Test_Template(t *testing.T) { - a := assert.NewAssertion(t) + var a = assert.NewAssertion(t) - template := Template() - err := template.Init() + var waf = Template() + + for _, group := range waf.Inbound { + group.IsOn = true + + for _, set := range group.RuleSets { + set.IsOn = true + } + } + + err := waf.Init() if err != nil { t.Fatal(err) } - testTemplate1001(a, t, template) - testTemplate1002(a, t, template) - testTemplate1003(a, t, template) - testTemplate2001(a, t, template) - testTemplate3001(a, t, template) - testTemplate4001(a, t, template) - testTemplate5001(a, t, template) - testTemplate6001(a, t, template) - testTemplate7001(a, t, template) - testTemplate20001(a, t, template) + testTemplate1001(a, t, waf) + testTemplate1002(a, t, waf) + testTemplate1003(a, t, waf) + testTemplate2001(a, t, waf) + testTemplate3001(a, t, waf) + testTemplate4001(a, t, waf) + testTemplate5001(a, t, waf) + testTemplate6001(a, t, waf) + testTemplate7001(a, t, waf) + testTemplate20001(a, t, waf) } func Test_Template2(t *testing.T) { reader := bytes.NewReader([]byte(strings.Repeat("HELLO", 1024))) - req, err := http.NewRequest(http.MethodGet, "https://example.com/index.php?id=123", reader) + req, err := http.NewRequest(http.MethodPost, "https://example.com/index.php?id=123", reader) if err != nil { t.Fatal(err) } @@ -65,15 +74,25 @@ func Test_Template2(t *testing.T) { } func BenchmarkTemplate(b *testing.B) { - waf := Template() + var waf = Template() + + for _, group := range waf.Inbound { + group.IsOn = true + + for _, set := range group.RuleSets { + set.IsOn = true + } + } + err := waf.Init() if err != nil { b.Fatal(err) } + b.ResetTimer() + for i := 0; i < b.N; i++ { - reader := bytes.NewReader([]byte(strings.Repeat("Hello", 1024))) - req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=123", reader) + req, err := http.NewRequest(http.MethodGet, "https://example.com/index.php?id=123", nil) if err != nil { b.Fatal(err) } @@ -312,6 +331,68 @@ func testTemplate7001(a *assert.Assertion, t *testing.T, template *WAF) { } } +func TestTemplateSQLInjection(t *testing.T) { + var template = Template() + errs := template.Init() + if len(errs) > 0 { + t.Fatal(errs) + return + } + var group = template.FindRuleGroupWithCode("sqlInjection") + if group == nil { + t.Fatal("group not found") + return + } + // + //for _, set := range group.RuleSets { + // for _, rule := range set.Rules { + // t.Logf("%#v", rule.singleCheckpoint) + // } + //} + + req, err := http.NewRequest(http.MethodPost, "https://example.com/?id=1234", nil) + if err != nil { + t.Fatal(err) + } + _, _, result, err := group.MatchRequest(requests.NewTestRequest(req)) + if err != nil { + t.Fatal(err) + } + if result != nil { + t.Log(result) + } +} + +func BenchmarkTemplateSQLInjection(b *testing.B) { + var template = Template() + errs := template.Init() + if len(errs) > 0 { + b.Fatal(errs) + return + } + var group = template.FindRuleGroupWithCode("sqlInjection") + if group == nil { + b.Fatal("group not found") + return + } + + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + req, err := http.NewRequest(http.MethodPost, "https://example.com/?id=1234", nil) + if err != nil { + b.Fatal(err) + } + _, _, result, err := group.MatchRequest(requests.NewTestRequest(req)) + if err != nil { + b.Fatal(err) + } + _ = result + } + }) +} + func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) { // enable bot rule set for _, g := range template.Inbound {