diff --git a/internal/waf/action_utils_test.go b/internal/waf/action_utils_test.go index 735fe32..f1aba4d 100644 --- a/internal/waf/action_utils_test.go +++ b/internal/waf/action_utils_test.go @@ -1,6 +1,7 @@ -package waf +package waf_test import ( + "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/maps" @@ -11,22 +12,22 @@ import ( func TestFindActionInstance(t *testing.T) { a := assert.NewAssertion(t) - t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil)) - t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil)) - t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil)) - t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil)) - t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil)) - t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil)) - t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b"})) + t.Logf("ActionBlock: %p", waf.FindActionInstance(waf.ActionBlock, nil)) + t.Logf("ActionBlock: %p", waf.FindActionInstance(waf.ActionBlock, nil)) + t.Logf("ActionGoGroup: %p", waf.FindActionInstance(waf.ActionGoGroup, nil)) + t.Logf("ActionGoGroup: %p", waf.FindActionInstance(waf.ActionGoGroup, nil)) + t.Logf("ActionGoSet: %p", waf.FindActionInstance(waf.ActionGoSet, nil)) + t.Logf("ActionGoSet: %p", waf.FindActionInstance(waf.ActionGoSet, nil)) + t.Logf("ActionGoSet: %#v", waf.FindActionInstance(waf.ActionGoSet, maps.Map{"groupId": "a", "setId": "b"})) - a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil)) + a.IsTrue(waf.FindActionInstance(waf.ActionGoSet, nil) != waf.FindActionInstance(waf.ActionGoSet, nil)) } func TestFindActionInstance_Options(t *testing.T) { //t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{})) //t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{})) //logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t) - logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{ + logs.PrintAsJSON(waf.FindActionInstance(waf.ActionBlock, maps.Map{ "timeout": 3600, }), t) } @@ -34,6 +35,6 @@ func TestFindActionInstance_Options(t *testing.T) { func BenchmarkFindActionInstance(b *testing.B) { runtime.GOMAXPROCS(1) for i := 0; i < b.N; i++ { - FindActionInstance(ActionGoSet, nil) + waf.FindActionInstance(waf.ActionGoSet, nil) } } diff --git a/internal/waf/injectionutils/utils_sqli.go b/internal/waf/injectionutils/utils_sqli.go index 49e87d0..7a44a5d 100644 --- a/internal/waf/injectionutils/utils_sqli.go +++ b/internal/waf/injectionutils/utils_sqli.go @@ -10,11 +10,43 @@ package injectionutils */ import "C" import ( + "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" + "github.com/TeaOSLab/EdgeNode/internal/waf/utils" + "github.com/cespare/xxhash/v2" "net/url" + "strconv" "strings" "unsafe" ) +// DetectSQLInjectionCache detect sql injection in string with cache +func DetectSQLInjectionCache(input string, cacheLife utils.CacheLife) bool { + var l = len(input) + + if l == 0 { + return false + } + + if cacheLife <= 0 || l < 128 || l > utils.MaxCacheDataSize { + return DetectSQLInjection(input) + } + + var hash = xxhash.Sum64String(input) + var key = "WAF@SQLI@" + strconv.FormatUint(hash, 10) + var item = utils.SharedCache.Read(key) + if item != nil { + return item.Value == 1 + } + + var result = DetectSQLInjection(input) + if result { + utils.SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife) + } else { + utils.SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife) + } + return result +} + // DetectSQLInjection detect sql injection in string func DetectSQLInjection(input string) bool { if len(input) == 0 { @@ -26,7 +58,7 @@ func DetectSQLInjection(input string) bool { } // 兼容 /PATH?URI - if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 4096 { + if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 1024 { var argsIndex = strings.Index(input, "?") if argsIndex > 0 { var args = input[argsIndex+1:] diff --git a/internal/waf/injectionutils/utils_sqli_test.go b/internal/waf/injectionutils/utils_sqli_test.go index 10af15a..6e22b20 100644 --- a/internal/waf/injectionutils/utils_sqli_test.go +++ b/internal/waf/injectionutils/utils_sqli_test.go @@ -4,8 +4,12 @@ package injectionutils_test import ( "github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils" + "github.com/TeaOSLab/EdgeNode/internal/waf/utils" "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/rands" + "github.com/iwind/TeaGo/types" "runtime" + "strings" "testing" ) @@ -23,6 +27,7 @@ func TestDetectSQLInjection(t *testing.T) { a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1")) a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1")) a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/sql/injection?id=123%20or%201=1")) + a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/' or 1=1")) } func BenchmarkDetectSQLInjection(b *testing.B) { @@ -45,6 +50,71 @@ func BenchmarkDetectSQLInjection_URL(b *testing.B) { }) } +func BenchmarkDetectSQLInjection_Normal_Small(b *testing.B) { + runtime.GOMAXPROCS(4) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = injectionutils.DetectSQLInjection("a/sql/injection?id=1234") + } + }) +} + +func BenchmarkDetectSQLInjection_URL_Normal_Small(b *testing.B) { + runtime.GOMAXPROCS(4) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = injectionutils.DetectSQLInjection("/sql/injection?id=" + types.String(rands.Int64()%10000)) + } + }) +} + +func BenchmarkDetectSQLInjection_URL_Normal_Middle(b *testing.B) { + runtime.GOMAXPROCS(4) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = injectionutils.DetectSQLInjection("/search?q=libinjection+fingerprint&newwindow=1&sca_esv=589290862&sxsrf=AMwHvKnxuLoejn2XlNniffC12E_xc35M7Q%3A1702090118361&ei=htvzzebfFZfo1e8PvLGggAk&ved=0ahUKEwjTsYmnq4GDAxUWdPOHHbwkCJAQ4ddDCBA&uact=5&oq=libinjection+fingerprint&gs_lp=Egxnd3Mtd2l6LXNlcnAiGIxpYmluamVjdGlvbmBmaW5nKXJwcmludTIEEAAYHjIGVAAYCBgeSiEaUPkRWKFZcAJ4AZABAJgBHgGgAfoEqgwDMC40uAEGyAEA-AEBwgIKEAFYTxjWMuiwA-IDBBgAVteIBgGQBgI&sclient=gws-wiz-serp#ip=1") + } + }) +} + +func BenchmarkDetectSQLInjection_URL_Normal_Small_Cache(b *testing.B) { + runtime.GOMAXPROCS(4) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = injectionutils.DetectSQLInjectionCache("/sql/injection?id="+types.String(rands.Int64()%10000), utils.CacheMiddleLife) + } + }) +} + +func BenchmarkDetectSQLInjection_Normal_Large(b *testing.B) { + runtime.GOMAXPROCS(4) + + var s = strings.Repeat("A", 512) + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = injectionutils.DetectSQLInjection("a/sql/injection?id=" + types.String(rands.Int64()%10000) + "&s=" + s) + } + }) +} + +func BenchmarkDetectSQLInjection_Normal_Large_Cache(b *testing.B) { + runtime.GOMAXPROCS(4) + + var s = strings.Repeat("A", 512) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = injectionutils.DetectSQLInjectionCache("a/sql/injection?id="+types.String(rands.Int64()%10000)+"&s="+s, utils.CacheMiddleLife) + } + }) +} + func BenchmarkDetectSQLInjection_URL_Unescape(b *testing.B) { runtime.GOMAXPROCS(4) diff --git a/internal/waf/injectionutils/utils_xss.go b/internal/waf/injectionutils/utils_xss.go index a4778c1..0808129 100644 --- a/internal/waf/injectionutils/utils_xss.go +++ b/internal/waf/injectionutils/utils_xss.go @@ -10,11 +10,42 @@ package injectionutils */ import "C" import ( + "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" + "github.com/TeaOSLab/EdgeNode/internal/waf/utils" + "github.com/cespare/xxhash/v2" "net/url" + "strconv" "strings" "unsafe" ) +func DetectXSSCache(input string, cacheLife utils.CacheLife) bool { + var l = len(input) + + if l == 0 { + return false + } + + if cacheLife <= 0 || l < 512 || l > utils.MaxCacheDataSize { + return DetectXSS(input) + } + + var hash = xxhash.Sum64String(input) + var key = "WAF@XSS@" + strconv.FormatUint(hash, 10) + var item = utils.SharedCache.Read(key) + if item != nil { + return item.Value == 1 + } + + var result = DetectXSS(input) + if result { + utils.SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife) + } else { + utils.SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife) + } + return result +} + // DetectXSS detect XSS in string func DetectXSS(input string) bool { if len(input) == 0 { @@ -26,7 +57,7 @@ func DetectXSS(input string) bool { } // 兼容 /PATH?URI - if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 4096 { + if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 1024 { var argsIndex = strings.Index(input, "?") if argsIndex > 0 { var args = input[argsIndex+1:] diff --git a/internal/waf/injectionutils/utils_xss_test.go b/internal/waf/injectionutils/utils_xss_test.go index c8b6642..7abc1c7 100644 --- a/internal/waf/injectionutils/utils_xss_test.go +++ b/internal/waf/injectionutils/utils_xss_test.go @@ -4,6 +4,7 @@ package injectionutils_test import ( "github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils" + "github.com/TeaOSLab/EdgeNode/internal/waf/utils" "github.com/iwind/TeaGo/assert" "runtime" "testing" @@ -25,7 +26,10 @@ func TestDetectXSS(t *testing.T) { } func BenchmarkDetectXSS_MISS(b *testing.B) { - b.Log(injectionutils.DetectXSS("RequestId: 1234567890")) + var result = injectionutils.DetectXSS("RequestId: 1234567890") + if result { + b.Fatal("'result' should not be 'true'") + } runtime.GOMAXPROCS(4) @@ -36,8 +40,26 @@ func BenchmarkDetectXSS_MISS(b *testing.B) { }) } +func BenchmarkDetectXSS_MISS_Cache(b *testing.B) { + var result = injectionutils.DetectXSS("RequestId: 1234567890") + if result { + b.Fatal("'result' should not be 'true'") + } + + runtime.GOMAXPROCS(4) + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = injectionutils.DetectXSSCache("RequestId: 1234567890", utils.CacheMiddleLife) + } + }) +} + func BenchmarkDetectXSS_HIT(b *testing.B) { - b.Log(injectionutils.DetectXSS("RequestId: 1234567890")) + var result = injectionutils.DetectXSS("RequestId: 1234567890") + if !result { + b.Fatal("'result' should not be 'false'") + } runtime.GOMAXPROCS(4) diff --git a/internal/waf/rule.go b/internal/waf/rule.go index 13174bc..e33df9e 100644 --- a/internal/waf/rule.go +++ b/internal/waf/rule.go @@ -22,6 +22,7 @@ import ( "net" "reflect" "regexp" + "sort" "strings" ) @@ -56,8 +57,8 @@ type Rule struct { floatValue float64 - reg *re.Regexp - regCacheLife utils.CacheLife + reg *re.Regexp + cacheLife utils.CacheLife } func NewRule() *Rule { @@ -93,6 +94,9 @@ func (this *Rule) Init() error { } } } + if this.Operator == RuleOperatorContainsAnyWord || this.Operator == RuleOperatorContainsAllWords || this.Operator == RuleOperatorNotContainsAnyWord { + sort.Strings(this.stringValues) + } } case RuleOperatorMatch: var v = this.Value @@ -166,7 +170,7 @@ func (this *Rule) Init() error { this.singleCheckpoint = checkpoint this.Priority = checkpoint.Priority() - this.regCacheLife = checkpoint.CacheLife() + this.cacheLife = checkpoint.CacheLife() } else { var checkpoint = checkpoints.FindCheckpoint(prefix) if checkpoint == nil { @@ -176,7 +180,7 @@ func (this *Rule) Init() error { this.singleCheckpoint = checkpoint this.Priority = checkpoint.Priority() - this.regCacheLife = checkpoint.CacheLife() + this.cacheLife = checkpoint.CacheLife() } return nil @@ -195,8 +199,8 @@ func (this *Rule) Init() error { this.multipleCheckpoints[prefix] = checkpoint this.Priority = checkpoint.Priority() - if this.regCacheLife <= 0 || checkpoint.CacheLife() < this.regCacheLife { - this.regCacheLife = checkpoint.CacheLife() + if this.cacheLife <= 0 || checkpoint.CacheLife() < this.cacheLife { + this.cacheLife = checkpoint.CacheLife() } } } else { @@ -208,7 +212,7 @@ func (this *Rule) Init() error { this.multipleCheckpoints[prefix] = checkpoint this.Priority = checkpoint.Priority() - this.regCacheLife = checkpoint.CacheLife() + this.cacheLife = checkpoint.CacheLife() } } @@ -408,7 +412,7 @@ func (this *Rule) Test(value any) bool { stringList, ok := value.([]string) if ok { for _, s := range stringList { - if utils.MatchStringCache(this.reg, s, this.regCacheLife) { + if utils.MatchStringCache(this.reg, s, this.cacheLife) { return true } } @@ -419,7 +423,7 @@ func (this *Rule) Test(value any) bool { byteSlices, ok := value.([][]byte) if ok { for _, byteSlice := range byteSlices { - if utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) { + if utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife) { return true } } @@ -429,11 +433,11 @@ func (this *Rule) Test(value any) bool { // bytes byteSlice, ok := value.([]byte) if ok { - return utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) + return utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife) } // string - return utils.MatchStringCache(this.reg, this.stringifyValue(value), this.regCacheLife) + return utils.MatchStringCache(this.reg, this.stringifyValue(value), this.cacheLife) case RuleOperatorNotMatch, RuleOperatorWildcardNotMatch: if value == nil { value = "" @@ -441,7 +445,7 @@ func (this *Rule) Test(value any) bool { stringList, ok := value.([]string) if ok { for _, s := range stringList { - if utils.MatchStringCache(this.reg, s, this.regCacheLife) { + if utils.MatchStringCache(this.reg, s, this.cacheLife) { return false } } @@ -452,7 +456,7 @@ func (this *Rule) Test(value any) bool { byteSlices, ok := value.([][]byte) if ok { for _, byteSlice := range byteSlices { - if utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) { + if utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife) { return false } } @@ -462,10 +466,10 @@ func (this *Rule) Test(value any) bool { // bytes byteSlice, ok := value.([]byte) if ok { - return !utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) + return !utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife) } - return !utils.MatchStringCache(this.reg, this.stringifyValue(value), this.regCacheLife) + return !utils.MatchStringCache(this.reg, this.stringifyValue(value), this.cacheLife) case RuleOperatorContains: if types.IsSlice(value) { _, isBytes := value.([]byte) @@ -575,20 +579,20 @@ func (this *Rule) Test(value any) bool { switch xValue := value.(type) { case []string: for _, v := range xValue { - if injectionutils.DetectSQLInjection(v) { + if injectionutils.DetectSQLInjectionCache(v, this.cacheLife) { return true } } return false case [][]byte: for _, v := range xValue { - if injectionutils.DetectSQLInjection(string(v)) { + if injectionutils.DetectSQLInjectionCache(string(v), this.cacheLife) { return true } } return false default: - return injectionutils.DetectSQLInjection(this.stringifyValue(value)) + return injectionutils.DetectSQLInjectionCache(this.stringifyValue(value), this.cacheLife) } case RuleOperatorContainsXSS: if value == nil { @@ -597,20 +601,20 @@ func (this *Rule) Test(value any) bool { switch xValue := value.(type) { case []string: for _, v := range xValue { - if injectionutils.DetectXSS(v) { + if injectionutils.DetectXSSCache(v, this.cacheLife) { return true } } return false case [][]byte: for _, v := range xValue { - if injectionutils.DetectXSS(string(v)) { + if injectionutils.DetectXSSCache(string(v), this.cacheLife) { return true } } return false default: - return injectionutils.DetectXSS(this.stringifyValue(value)) + return injectionutils.DetectXSSCache(this.stringifyValue(value), this.cacheLife) } case RuleOperatorContainsBinary: data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value)) diff --git a/internal/waf/rule_set_test.go b/internal/waf/rule_set_test.go index 023bd02..7799058 100644 --- a/internal/waf/rule_set_test.go +++ b/internal/waf/rule_set_test.go @@ -1,7 +1,8 @@ -package waf +package waf_test import ( "bytes" + "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/cespare/xxhash" "github.com/iwind/TeaGo/assert" @@ -12,18 +13,18 @@ import ( ) func TestRuleSet_MatchRequest(t *testing.T) { - set := NewRuleSet() - set.Connector = RuleConnectorAnd + var set = waf.NewRuleSet() + set.Connector = waf.RuleConnectorAnd - set.Rules = []*Rule{ + set.Rules = []*waf.Rule{ { Param: "${arg.name}", - Operator: RuleOperatorEqString, + Operator: waf.RuleOperatorEqString, Value: "lu", }, { Param: "${arg.age}", - Operator: RuleOperatorEq, + Operator: waf.RuleOperatorEq, Value: "20", }, } @@ -42,20 +43,20 @@ func TestRuleSet_MatchRequest(t *testing.T) { } func TestRuleSet_MatchRequest2(t *testing.T) { - a := assert.NewAssertion(t) + var a = assert.NewAssertion(t) - set := NewRuleSet() - set.Connector = RuleConnectorOr + var set = waf.NewRuleSet() + set.Connector = waf.RuleConnectorOr - set.Rules = []*Rule{ + set.Rules = []*waf.Rule{ { Param: "${arg.name}", - Operator: RuleOperatorEqString, + Operator: waf.RuleOperatorEqString, Value: "lu", }, { Param: "${arg.age}", - Operator: RuleOperatorEq, + Operator: waf.RuleOperatorEq, Value: "21", }, } @@ -76,28 +77,28 @@ func TestRuleSet_MatchRequest2(t *testing.T) { func BenchmarkRuleSet_MatchRequest(b *testing.B) { runtime.GOMAXPROCS(1) - set := NewRuleSet() - set.Connector = RuleConnectorOr + var set = waf.NewRuleSet() + set.Connector = waf.RuleConnectorOr - set.Rules = []*Rule{ + set.Rules = []*waf.Rule{ { Param: "${requestAll}", - Operator: RuleOperatorMatch, + Operator: waf.RuleOperatorMatch, Value: `(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)\s*=`, }, { Param: "${requestAll}", - Operator: RuleOperatorMatch, + Operator: waf.RuleOperatorMatch, Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`, }, { Param: "${arg.name}", - Operator: RuleOperatorEqString, + Operator: waf.RuleOperatorEqString, Value: "lu", }, { Param: "${arg.age}", - Operator: RuleOperatorEq, + Operator: waf.RuleOperatorEq, Value: "21", }, } @@ -120,13 +121,13 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) { func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) { runtime.GOMAXPROCS(1) - set := NewRuleSet() - set.Connector = RuleConnectorOr + var set = waf.NewRuleSet() + set.Connector = waf.RuleConnectorOr - set.Rules = []*Rule{ + set.Rules = []*waf.Rule{ { Param: "${requestBody}", - Operator: RuleOperatorMatch, + Operator: waf.RuleOperatorMatch, Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`, IsCaseInsensitive: false, }, diff --git a/internal/waf/template.go b/internal/waf/template.go index 1ebdaff..a2b2aca 100644 --- a/internal/waf/template.go +++ b/internal/waf/template.go @@ -1,434 +1,40 @@ package waf -func Template() *WAF { - waf := NewWAF() - waf.Id = 0 - waf.IsOn = true +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" + "github.com/TeaOSLab/EdgeNode/internal/waf/utils" +) - // xss - { - group := NewRuleGroup() +func Template() (*WAF, error) { + var config = firewallconfigs.HTTPFirewallTemplate() + if config.Inbound != nil { + config.Inbound.IsOn = true + } + + for _, group := range config.AllRuleGroups() { + if group.Code == "cc" || group.Code == "cc2" { + continue + } group.IsOn = true - group.IsInbound = true - group.Name = "XSS" - group.Code = "xss" - group.Description = "防跨站脚本攻击(Cross Site Scripting)" - { - set := NewRuleSet() + for _, set := range group.Sets { set.IsOn = true - set.Name = "Javascript事件" - set.Code = "1001" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - set.AddRule(&Rule{ - Param: "${requestURI}", - Operator: RuleOperatorMatch, - Value: `(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)\s*=`, // TODO more keywords here - IsCaseInsensitive: true, - }) - group.AddRuleSet(set) } - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "Javascript函数" - set.Code = "1002" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - set.AddRule(&Rule{ - Param: "${requestURI}", - Operator: RuleOperatorMatch, - Value: `(alert|eval|prompt|confirm)\s*\(`, // TODO more keywords here - IsCaseInsensitive: true, - }) - group.AddRuleSet(set) - } - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "HTML标签" - set.Code = "1003" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - set.AddRule(&Rule{ - Param: "${requestURI}", - Operator: RuleOperatorMatch, - Value: `<(script|iframe|link)`, // TODO more keywords here - IsCaseInsensitive: true, - }) - group.AddRuleSet(set) - } - - waf.AddRuleGroup(group) } - // upload - { - group := NewRuleGroup() - group.IsOn = true - group.IsInbound = true - group.Name = "文件上传" - group.Code = "upload" - group.Description = "防止上传可执行脚本文件到服务器" - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "上传文件扩展名" - set.Code = "2001" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - set.AddRule(&Rule{ - Param: "${requestUpload.ext}", - Operator: RuleOperatorMatch, - Value: `\.(php|jsp|aspx|asp|exe|asa|rb|py)\b`, // TODO more keywords here - IsCaseInsensitive: true, - }) - group.AddRuleSet(set) - } - - waf.AddRuleGroup(group) + instance, err := SharedWAFManager.ConvertWAF(config) + if err != nil { + return nil, err } - // web shell - { - group := NewRuleGroup() - group.IsOn = true - group.IsInbound = true - group.Name = "Web Shell" - group.Code = "webShell" - group.Description = "防止远程执行服务器命令" - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "Web Shell" - set.Code = "3001" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - set.AddRule(&Rule{ - Param: "${requestAll}", - Operator: RuleOperatorMatch, - Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`, // TODO more keywords here - IsCaseInsensitive: true, - }) - group.AddRuleSet(set) + for _, group := range instance.Inbound { + for _, set := range group.RuleSets { + for _, rule := range set.Rules { + rule.cacheLife = utils.CacheDisabled // for performance test + _ = rule + } } - - waf.AddRuleGroup(group) } - // command injection - { - group := NewRuleGroup() - group.IsOn = true - group.IsInbound = true - group.Name = "命令注入" - group.Code = "commandInjection" - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "命令注入" - set.Code = "4001" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - set.AddRule(&Rule{ - Param: "${requestURI}", - Operator: RuleOperatorMatch, - Value: `\b(pwd|ls|ll|whoami|id|net\s+user)\s*$`, // TODO more keywords here - IsCaseInsensitive: false, - }) - set.AddRule(&Rule{ - Param: "${requestBody}", - Operator: RuleOperatorMatch, - Value: `\b(pwd|ls|ll|whoami|id|net\s+user)\s*$`, // TODO more keywords here - IsCaseInsensitive: false, - }) - group.AddRuleSet(set) - } - - waf.AddRuleGroup(group) - } - - // path traversal - { - group := NewRuleGroup() - group.IsOn = true - group.IsInbound = true - group.Name = "路径穿越" - group.Code = "pathTraversal" - group.Description = "防止读取网站目录之外的其他系统文件" - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "路径穿越" - set.Code = "5001" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - set.AddRule(&Rule{ - Param: "${requestURI}", - Operator: RuleOperatorMatch, - Value: `((\.+)(/+)){2,}`, // TODO more keywords here - IsCaseInsensitive: false, - }) - group.AddRuleSet(set) - } - - waf.AddRuleGroup(group) - } - - // special dirs - { - group := NewRuleGroup() - group.IsOn = true - group.IsInbound = true - group.Name = "特殊目录" - group.Code = "denyDirs" - group.Description = "防止通过Web访问到一些特殊目录" - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "特殊目录" - set.Code = "6001" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - set.AddRule(&Rule{ - Param: "${requestPath}", - Operator: RuleOperatorMatch, - Value: `/\.(git|svn|htaccess|idea)\b`, // TODO more keywords here - IsCaseInsensitive: true, - }) - group.AddRuleSet(set) - } - - waf.AddRuleGroup(group) - } - - // sql injection - { - group := NewRuleGroup() - group.IsOn = true - group.IsInbound = true - group.Name = "SQL注入" - group.Code = "sqlInjection" - group.Description = "防止SQL注入漏洞" - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "Union SQL Injection" - set.Code = "7001" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - - set.AddRule(&Rule{ - Param: "${requestAll}", - Operator: RuleOperatorMatch, - Value: `union[\s/\*]+select`, - IsCaseInsensitive: true, - }) - - group.AddRuleSet(set) - } - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "SQL注释" - set.Code = "7002" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - - set.AddRule(&Rule{ - Param: "${requestAll}", - Operator: RuleOperatorMatch, - Value: `/\*(!|\x00)`, - IsCaseInsensitive: true, - }) - - group.AddRuleSet(set) - } - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "SQL条件" - set.Code = "7003" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - - set.AddRule(&Rule{ - Param: "${requestAll}", - Operator: RuleOperatorMatch, - Value: `\s(and|or|rlike)\s+(if|updatexml)\s*\(`, - IsCaseInsensitive: true, - }) - set.AddRule(&Rule{ - Param: "${requestAll}", - Operator: RuleOperatorMatch, - Value: `\s+(and|or|rlike)\s+(select|case)\s+`, - IsCaseInsensitive: true, - }) - set.AddRule(&Rule{ - Param: "${requestAll}", - Operator: RuleOperatorMatch, - Value: `\s+(and|or|procedure)\s+[\w\p{L}]+\s*=\s*[\w\p{L}]+(\s|$|--|#)`, - IsCaseInsensitive: true, - }) - set.AddRule(&Rule{ - Param: "${requestAll}", - Operator: RuleOperatorMatch, - Value: `\(\s*case\s+when\s+[\w\p{L}]+\s*=\s*[\w\p{L}]+\s+then\s+`, - IsCaseInsensitive: true, - }) - - group.AddRuleSet(set) - } - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "SQL函数" - set.Code = "7004" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - - set.AddRule(&Rule{ - Param: "${requestAll}", - Operator: RuleOperatorMatch, - Value: `(updatexml|extractvalue|ascii|ord|char|chr|count|concat|rand|floor|substr|length|len|user|database|benchmark|analyse)\s*\(`, - IsCaseInsensitive: true, - }) - - group.AddRuleSet(set) - } - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "SQL附加语句" - set.Code = "7005" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - - set.AddRule(&Rule{ - Param: "${requestAll}", - Operator: RuleOperatorMatch, - Value: `;\s*(declare|use|drop|create|exec|delete|update|insert)\s`, - IsCaseInsensitive: true, - }) - - group.AddRuleSet(set) - } - - waf.AddRuleGroup(group) - } - - // bot - { - group := NewRuleGroup() - group.IsOn = false - group.IsInbound = true - group.Name = "网络爬虫" - group.Code = "bot" - group.Description = "禁止一些网络爬虫" - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "常见网络爬虫" - set.Code = "20001" - set.Connector = RuleConnectorOr - set.AddAction(ActionBlock, nil) - - set.AddRule(&Rule{ - Param: "${userAgent}", - Operator: RuleOperatorMatch, - Value: `Googlebot|AdsBot|bingbot|BingPreview|facebookexternalhit|Slurp|Sogou|proximic|Baiduspider|yandex|twitterbot|spider|python`, - IsCaseInsensitive: true, - }) - - group.AddRuleSet(set) - } - - waf.AddRuleGroup(group) - } - - // cc - { - group := NewRuleGroup() - group.IsOn = false - group.IsInbound = true - group.Name = "CC攻击" - group.Description = "Challenge Collapsar,防止短时间大量请求涌入,请谨慎开启和设置" - group.Code = "cc2" - - { - set := NewRuleSet() - set.IsOn = true - set.Name = "CC请求数" - set.Description = "限制单IP在一定时间内的请求数" - set.Code = "8001" - set.Connector = RuleConnectorAnd - set.AddAction(ActionBlock, nil) - set.AddRule(&Rule{ - Param: "${cc2}", - Operator: RuleOperatorGt, - Value: "1000", - CheckpointOptions: map[string]interface{}{ - "period": "60", - "threshold": 1000, - "keys": []string{"${remoteAddr}", "${requestPath}"}, - }, - IsCaseInsensitive: false, - }) - set.AddRule(&Rule{ - Param: "${remoteAddr}", - Operator: RuleOperatorNotIPRange, - Value: `127.0.0.1/8`, - IsCaseInsensitive: false, - }) - set.AddRule(&Rule{ - Param: "${remoteAddr}", - Operator: RuleOperatorNotIPRange, - Value: `192.168.0.1/16`, - IsCaseInsensitive: false, - }) - set.AddRule(&Rule{ - Param: "${remoteAddr}", - Operator: RuleOperatorNotIPRange, - Value: `10.0.0.1/8`, - IsCaseInsensitive: false, - }) - set.AddRule(&Rule{ - Param: "${remoteAddr}", - Operator: RuleOperatorNotIPRange, - Value: `172.16.0.1/12`, - IsCaseInsensitive: false, - }) - - group.AddRuleSet(set) - } - - waf.AddRuleGroup(group) - } - - // custom - { - group := NewRuleGroup() - group.IsOn = true - group.IsInbound = true - group.Name = "自定义规则分组" - group.Description = "我的自定义规则分组,可以将自定义的规则放在这个分组下" - group.Code = "custom" - waf.AddRuleGroup(group) - } - - return waf + return instance, nil } diff --git a/internal/waf/template_test.go b/internal/waf/template_test.go index ca21b9b..7112ea6 100644 --- a/internal/waf/template_test.go +++ b/internal/waf/template_test.go @@ -1,12 +1,15 @@ -package waf +package waf_test import ( "bytes" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" + "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/logs" + "github.com/iwind/TeaGo/types" + "math/rand" "mime/multipart" "net/http" "net/url" @@ -15,34 +18,26 @@ import ( "time" ) +const testUserAgent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0" + func Test_Template(t *testing.T) { var a = assert.NewAssertion(t) - var waf = Template() - - for _, group := range waf.Inbound { - group.IsOn = true - - for _, set := range group.RuleSets { - set.IsOn = true - } - } - - err := waf.Init() + wafInstance, err := waf.Template() if err != nil { t.Fatal(err) } - testTemplate1001(a, t, waf) - testTemplate1002(a, t, waf) - testTemplate1003(a, t, waf) - testTemplate2001(a, t, waf) - testTemplate3001(a, t, waf) - testTemplate4001(a, t, waf) - testTemplate5001(a, t, waf) - testTemplate6001(a, t, waf) - testTemplate7001(a, t, waf) - testTemplate20001(a, t, waf) + testTemplate1001(a, t, wafInstance) + testTemplate1002(a, t, wafInstance) + testTemplate1003(a, t, wafInstance) + testTemplate2001(a, t, wafInstance) + testTemplate3001(a, t, wafInstance) + testTemplate4001(a, t, wafInstance) + testTemplate5001(a, t, wafInstance) + testTemplate6001(a, t, wafInstance) + testTemplate7001(a, t, wafInstance) + testTemplate20001(a, t, wafInstance) } func Test_Template2(t *testing.T) { @@ -52,14 +47,13 @@ func Test_Template2(t *testing.T) { t.Fatal(err) } - waf := Template() - var errs = waf.Init() - if len(errs) > 0 { - t.Fatal(errs[0]) + wafInstance, err := waf.Template() + if err != nil { + t.Fatal(err) } now := time.Now() - goNext, _, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) } @@ -74,17 +68,7 @@ func Test_Template2(t *testing.T) { } func BenchmarkTemplate(b *testing.B) { - var waf = Template() - - for _, group := range waf.Inbound { - group.IsOn = true - - for _, set := range group.RuleSets { - set.IsOn = true - } - } - - err := waf.Init() + wafInstance, err := waf.Template() if err != nil { b.Fatal(err) } @@ -96,16 +80,18 @@ func BenchmarkTemplate(b *testing.B) { if err != nil { b.Fatal(err) } + req.Header.Set("User-Agent", testUserAgent) - _, _, _, _, _ = waf.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) + _, _, _, _, _ = wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) } } -func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) { +func testTemplate1001(a *assert.Assertion, t *testing.T, template *waf.WAF) { req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=onmousedown%3D123", nil) if err != nil { t.Fatal(err) } + req.Header.Set("User-Agent", testUserAgent) _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) if err != nil { t.Fatal(err) @@ -116,7 +102,7 @@ func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) { } } -func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) { +func testTemplate1002(a *assert.Assertion, t *testing.T, template *waf.WAF) { req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=eval%28", nil) if err != nil { t.Fatal(err) @@ -131,7 +117,7 @@ func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) { } } -func testTemplate1003(a *assert.Assertion, t *testing.T, template *WAF) { +func testTemplate1003(a *assert.Assertion, t *testing.T, template *waf.WAF) { req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=