WAF SQL注入和XSS检测增加缓存/优化部分WAF相关测试用例

This commit is contained in:
GoEdgeLab
2023-12-09 11:46:50 +08:00
parent 86ab242f68
commit 5a7247b8be
12 changed files with 325 additions and 567 deletions

View File

@@ -1,6 +1,7 @@
package waf package waf_test
import ( import (
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/maps"
@@ -11,22 +12,22 @@ import (
func TestFindActionInstance(t *testing.T) { func TestFindActionInstance(t *testing.T) {
a := assert.NewAssertion(t) a := assert.NewAssertion(t)
t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil)) t.Logf("ActionBlock: %p", waf.FindActionInstance(waf.ActionBlock, nil))
t.Logf("ActionBlock: %p", FindActionInstance(ActionBlock, nil)) t.Logf("ActionBlock: %p", waf.FindActionInstance(waf.ActionBlock, nil))
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil)) t.Logf("ActionGoGroup: %p", waf.FindActionInstance(waf.ActionGoGroup, nil))
t.Logf("ActionGoGroup: %p", FindActionInstance(ActionGoGroup, nil)) t.Logf("ActionGoGroup: %p", waf.FindActionInstance(waf.ActionGoGroup, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil)) t.Logf("ActionGoSet: %p", waf.FindActionInstance(waf.ActionGoSet, nil))
t.Logf("ActionGoSet: %p", FindActionInstance(ActionGoSet, nil)) t.Logf("ActionGoSet: %p", waf.FindActionInstance(waf.ActionGoSet, nil))
t.Logf("ActionGoSet: %#v", FindActionInstance(ActionGoSet, maps.Map{"groupId": "a", "setId": "b"})) t.Logf("ActionGoSet: %#v", waf.FindActionInstance(waf.ActionGoSet, maps.Map{"groupId": "a", "setId": "b"}))
a.IsTrue(FindActionInstance(ActionGoSet, nil) != FindActionInstance(ActionGoSet, nil)) a.IsTrue(waf.FindActionInstance(waf.ActionGoSet, nil) != waf.FindActionInstance(waf.ActionGoSet, nil))
} }
func TestFindActionInstance_Options(t *testing.T) { func TestFindActionInstance_Options(t *testing.T) {
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{})) //t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
//t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{})) //t.Logf("%p", FindActionInstance(ActionBlock, maps.Map{}))
//logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t) //logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{}), t)
logs.PrintAsJSON(FindActionInstance(ActionBlock, maps.Map{ logs.PrintAsJSON(waf.FindActionInstance(waf.ActionBlock, maps.Map{
"timeout": 3600, "timeout": 3600,
}), t) }), t)
} }
@@ -34,6 +35,6 @@ func TestFindActionInstance_Options(t *testing.T) {
func BenchmarkFindActionInstance(b *testing.B) { func BenchmarkFindActionInstance(b *testing.B) {
runtime.GOMAXPROCS(1) runtime.GOMAXPROCS(1)
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
FindActionInstance(ActionGoSet, nil) waf.FindActionInstance(waf.ActionGoSet, nil)
} }
} }

View File

@@ -10,11 +10,43 @@ package injectionutils
*/ */
import "C" import "C"
import ( import (
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/cespare/xxhash/v2"
"net/url" "net/url"
"strconv"
"strings" "strings"
"unsafe" "unsafe"
) )
// DetectSQLInjectionCache detect sql injection in string with cache
func DetectSQLInjectionCache(input string, cacheLife utils.CacheLife) bool {
var l = len(input)
if l == 0 {
return false
}
if cacheLife <= 0 || l < 128 || l > utils.MaxCacheDataSize {
return DetectSQLInjection(input)
}
var hash = xxhash.Sum64String(input)
var key = "WAF@SQLI@" + strconv.FormatUint(hash, 10)
var item = utils.SharedCache.Read(key)
if item != nil {
return item.Value == 1
}
var result = DetectSQLInjection(input)
if result {
utils.SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
} else {
utils.SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
}
return result
}
// DetectSQLInjection detect sql injection in string // DetectSQLInjection detect sql injection in string
func DetectSQLInjection(input string) bool { func DetectSQLInjection(input string) bool {
if len(input) == 0 { if len(input) == 0 {
@@ -26,7 +58,7 @@ func DetectSQLInjection(input string) bool {
} }
// 兼容 /PATH?URI // 兼容 /PATH?URI
if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 4096 { if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 1024 {
var argsIndex = strings.Index(input, "?") var argsIndex = strings.Index(input, "?")
if argsIndex > 0 { if argsIndex > 0 {
var args = input[argsIndex+1:] var args = input[argsIndex+1:]

View File

@@ -4,8 +4,12 @@ package injectionutils_test
import ( import (
"github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils" "github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/rands"
"github.com/iwind/TeaGo/types"
"runtime" "runtime"
"strings"
"testing" "testing"
) )
@@ -23,6 +27,7 @@ func TestDetectSQLInjection(t *testing.T) {
a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1")) a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123 or 1=1"))
a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1")) a.IsTrue(injectionutils.DetectSQLInjection("/sql/injection?id=123%20or%201=1"))
a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/sql/injection?id=123%20or%201=1")) a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/sql/injection?id=123%20or%201=1"))
a.IsTrue(injectionutils.DetectSQLInjection("https://example.com/' or 1=1"))
} }
func BenchmarkDetectSQLInjection(b *testing.B) { func BenchmarkDetectSQLInjection(b *testing.B) {
@@ -45,6 +50,71 @@ func BenchmarkDetectSQLInjection_URL(b *testing.B) {
}) })
} }
func BenchmarkDetectSQLInjection_Normal_Small(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("a/sql/injection?id=1234")
}
})
}
func BenchmarkDetectSQLInjection_URL_Normal_Small(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("/sql/injection?id=" + types.String(rands.Int64()%10000))
}
})
}
func BenchmarkDetectSQLInjection_URL_Normal_Middle(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("/search?q=libinjection+fingerprint&newwindow=1&sca_esv=589290862&sxsrf=AMwHvKnxuLoejn2XlNniffC12E_xc35M7Q%3A1702090118361&ei=htvzzebfFZfo1e8PvLGggAk&ved=0ahUKEwjTsYmnq4GDAxUWdPOHHbwkCJAQ4ddDCBA&uact=5&oq=libinjection+fingerprint&gs_lp=Egxnd3Mtd2l6LXNlcnAiGIxpYmluamVjdGlvbmBmaW5nKXJwcmludTIEEAAYHjIGVAAYCBgeSiEaUPkRWKFZcAJ4AZABAJgBHgGgAfoEqgwDMC40uAEGyAEA-AEBwgIKEAFYTxjWMuiwA-IDBBgAVteIBgGQBgI&sclient=gws-wiz-serp#ip=1")
}
})
}
func BenchmarkDetectSQLInjection_URL_Normal_Small_Cache(b *testing.B) {
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjectionCache("/sql/injection?id="+types.String(rands.Int64()%10000), utils.CacheMiddleLife)
}
})
}
func BenchmarkDetectSQLInjection_Normal_Large(b *testing.B) {
runtime.GOMAXPROCS(4)
var s = strings.Repeat("A", 512)
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjection("a/sql/injection?id=" + types.String(rands.Int64()%10000) + "&s=" + s)
}
})
}
func BenchmarkDetectSQLInjection_Normal_Large_Cache(b *testing.B) {
runtime.GOMAXPROCS(4)
var s = strings.Repeat("A", 512)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectSQLInjectionCache("a/sql/injection?id="+types.String(rands.Int64()%10000)+"&s="+s, utils.CacheMiddleLife)
}
})
}
func BenchmarkDetectSQLInjection_URL_Unescape(b *testing.B) { func BenchmarkDetectSQLInjection_URL_Unescape(b *testing.B) {
runtime.GOMAXPROCS(4) runtime.GOMAXPROCS(4)

View File

@@ -10,11 +10,42 @@ package injectionutils
*/ */
import "C" import "C"
import ( import (
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/cespare/xxhash/v2"
"net/url" "net/url"
"strconv"
"strings" "strings"
"unsafe" "unsafe"
) )
func DetectXSSCache(input string, cacheLife utils.CacheLife) bool {
var l = len(input)
if l == 0 {
return false
}
if cacheLife <= 0 || l < 512 || l > utils.MaxCacheDataSize {
return DetectXSS(input)
}
var hash = xxhash.Sum64String(input)
var key = "WAF@XSS@" + strconv.FormatUint(hash, 10)
var item = utils.SharedCache.Read(key)
if item != nil {
return item.Value == 1
}
var result = DetectXSS(input)
if result {
utils.SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
} else {
utils.SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
}
return result
}
// DetectXSS detect XSS in string // DetectXSS detect XSS in string
func DetectXSS(input string) bool { func DetectXSS(input string) bool {
if len(input) == 0 { if len(input) == 0 {
@@ -26,7 +57,7 @@ func DetectXSS(input string) bool {
} }
// 兼容 /PATH?URI // 兼容 /PATH?URI
if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 4096 { if (input[0] == '/' || strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")) && len(input) < 1024 {
var argsIndex = strings.Index(input, "?") var argsIndex = strings.Index(input, "?")
if argsIndex > 0 { if argsIndex > 0 {
var args = input[argsIndex+1:] var args = input[argsIndex+1:]

View File

@@ -4,6 +4,7 @@ package injectionutils_test
import ( import (
"github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils" "github.com/TeaOSLab/EdgeNode/internal/waf/injectionutils"
"github.com/TeaOSLab/EdgeNode/internal/waf/utils"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"runtime" "runtime"
"testing" "testing"
@@ -25,7 +26,10 @@ func TestDetectXSS(t *testing.T) {
} }
func BenchmarkDetectXSS_MISS(b *testing.B) { func BenchmarkDetectXSS_MISS(b *testing.B) {
b.Log(injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")) var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
if result {
b.Fatal("'result' should not be 'true'")
}
runtime.GOMAXPROCS(4) runtime.GOMAXPROCS(4)
@@ -36,8 +40,26 @@ func BenchmarkDetectXSS_MISS(b *testing.B) {
}) })
} }
func BenchmarkDetectXSS_MISS_Cache(b *testing.B) {
var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span></body></html>")
if result {
b.Fatal("'result' should not be 'true'")
}
runtime.GOMAXPROCS(4)
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_ = injectionutils.DetectXSSCache("<html><body><span>RequestId: 1234567890</span></body></html>", utils.CacheMiddleLife)
}
})
}
func BenchmarkDetectXSS_HIT(b *testing.B) { func BenchmarkDetectXSS_HIT(b *testing.B) {
b.Log(injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span><script src=\"\"></script></body></html>")) var result = injectionutils.DetectXSS("<html><body><span>RequestId: 1234567890</span><script src=\"\"></script></body></html>")
if !result {
b.Fatal("'result' should not be 'false'")
}
runtime.GOMAXPROCS(4) runtime.GOMAXPROCS(4)

View File

@@ -22,6 +22,7 @@ import (
"net" "net"
"reflect" "reflect"
"regexp" "regexp"
"sort"
"strings" "strings"
) )
@@ -57,7 +58,7 @@ type Rule struct {
floatValue float64 floatValue float64
reg *re.Regexp reg *re.Regexp
regCacheLife utils.CacheLife cacheLife utils.CacheLife
} }
func NewRule() *Rule { func NewRule() *Rule {
@@ -93,6 +94,9 @@ func (this *Rule) Init() error {
} }
} }
} }
if this.Operator == RuleOperatorContainsAnyWord || this.Operator == RuleOperatorContainsAllWords || this.Operator == RuleOperatorNotContainsAnyWord {
sort.Strings(this.stringValues)
}
} }
case RuleOperatorMatch: case RuleOperatorMatch:
var v = this.Value var v = this.Value
@@ -166,7 +170,7 @@ func (this *Rule) Init() error {
this.singleCheckpoint = checkpoint this.singleCheckpoint = checkpoint
this.Priority = checkpoint.Priority() this.Priority = checkpoint.Priority()
this.regCacheLife = checkpoint.CacheLife() this.cacheLife = checkpoint.CacheLife()
} else { } else {
var checkpoint = checkpoints.FindCheckpoint(prefix) var checkpoint = checkpoints.FindCheckpoint(prefix)
if checkpoint == nil { if checkpoint == nil {
@@ -176,7 +180,7 @@ func (this *Rule) Init() error {
this.singleCheckpoint = checkpoint this.singleCheckpoint = checkpoint
this.Priority = checkpoint.Priority() this.Priority = checkpoint.Priority()
this.regCacheLife = checkpoint.CacheLife() this.cacheLife = checkpoint.CacheLife()
} }
return nil return nil
@@ -195,8 +199,8 @@ func (this *Rule) Init() error {
this.multipleCheckpoints[prefix] = checkpoint this.multipleCheckpoints[prefix] = checkpoint
this.Priority = checkpoint.Priority() this.Priority = checkpoint.Priority()
if this.regCacheLife <= 0 || checkpoint.CacheLife() < this.regCacheLife { if this.cacheLife <= 0 || checkpoint.CacheLife() < this.cacheLife {
this.regCacheLife = checkpoint.CacheLife() this.cacheLife = checkpoint.CacheLife()
} }
} }
} else { } else {
@@ -208,7 +212,7 @@ func (this *Rule) Init() error {
this.multipleCheckpoints[prefix] = checkpoint this.multipleCheckpoints[prefix] = checkpoint
this.Priority = checkpoint.Priority() this.Priority = checkpoint.Priority()
this.regCacheLife = checkpoint.CacheLife() this.cacheLife = checkpoint.CacheLife()
} }
} }
@@ -408,7 +412,7 @@ func (this *Rule) Test(value any) bool {
stringList, ok := value.([]string) stringList, ok := value.([]string)
if ok { if ok {
for _, s := range stringList { for _, s := range stringList {
if utils.MatchStringCache(this.reg, s, this.regCacheLife) { if utils.MatchStringCache(this.reg, s, this.cacheLife) {
return true return true
} }
} }
@@ -419,7 +423,7 @@ func (this *Rule) Test(value any) bool {
byteSlices, ok := value.([][]byte) byteSlices, ok := value.([][]byte)
if ok { if ok {
for _, byteSlice := range byteSlices { for _, byteSlice := range byteSlices {
if utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) { if utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife) {
return true return true
} }
} }
@@ -429,11 +433,11 @@ func (this *Rule) Test(value any) bool {
// bytes // bytes
byteSlice, ok := value.([]byte) byteSlice, ok := value.([]byte)
if ok { if ok {
return utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) return utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife)
} }
// string // string
return utils.MatchStringCache(this.reg, this.stringifyValue(value), this.regCacheLife) return utils.MatchStringCache(this.reg, this.stringifyValue(value), this.cacheLife)
case RuleOperatorNotMatch, RuleOperatorWildcardNotMatch: case RuleOperatorNotMatch, RuleOperatorWildcardNotMatch:
if value == nil { if value == nil {
value = "" value = ""
@@ -441,7 +445,7 @@ func (this *Rule) Test(value any) bool {
stringList, ok := value.([]string) stringList, ok := value.([]string)
if ok { if ok {
for _, s := range stringList { for _, s := range stringList {
if utils.MatchStringCache(this.reg, s, this.regCacheLife) { if utils.MatchStringCache(this.reg, s, this.cacheLife) {
return false return false
} }
} }
@@ -452,7 +456,7 @@ func (this *Rule) Test(value any) bool {
byteSlices, ok := value.([][]byte) byteSlices, ok := value.([][]byte)
if ok { if ok {
for _, byteSlice := range byteSlices { for _, byteSlice := range byteSlices {
if utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) { if utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife) {
return false return false
} }
} }
@@ -462,10 +466,10 @@ func (this *Rule) Test(value any) bool {
// bytes // bytes
byteSlice, ok := value.([]byte) byteSlice, ok := value.([]byte)
if ok { if ok {
return !utils.MatchBytesCache(this.reg, byteSlice, this.regCacheLife) return !utils.MatchBytesCache(this.reg, byteSlice, this.cacheLife)
} }
return !utils.MatchStringCache(this.reg, this.stringifyValue(value), this.regCacheLife) return !utils.MatchStringCache(this.reg, this.stringifyValue(value), this.cacheLife)
case RuleOperatorContains: case RuleOperatorContains:
if types.IsSlice(value) { if types.IsSlice(value) {
_, isBytes := value.([]byte) _, isBytes := value.([]byte)
@@ -575,20 +579,20 @@ func (this *Rule) Test(value any) bool {
switch xValue := value.(type) { switch xValue := value.(type) {
case []string: case []string:
for _, v := range xValue { for _, v := range xValue {
if injectionutils.DetectSQLInjection(v) { if injectionutils.DetectSQLInjectionCache(v, this.cacheLife) {
return true return true
} }
} }
return false return false
case [][]byte: case [][]byte:
for _, v := range xValue { for _, v := range xValue {
if injectionutils.DetectSQLInjection(string(v)) { if injectionutils.DetectSQLInjectionCache(string(v), this.cacheLife) {
return true return true
} }
} }
return false return false
default: default:
return injectionutils.DetectSQLInjection(this.stringifyValue(value)) return injectionutils.DetectSQLInjectionCache(this.stringifyValue(value), this.cacheLife)
} }
case RuleOperatorContainsXSS: case RuleOperatorContainsXSS:
if value == nil { if value == nil {
@@ -597,20 +601,20 @@ func (this *Rule) Test(value any) bool {
switch xValue := value.(type) { switch xValue := value.(type) {
case []string: case []string:
for _, v := range xValue { for _, v := range xValue {
if injectionutils.DetectXSS(v) { if injectionutils.DetectXSSCache(v, this.cacheLife) {
return true return true
} }
} }
return false return false
case [][]byte: case [][]byte:
for _, v := range xValue { for _, v := range xValue {
if injectionutils.DetectXSS(string(v)) { if injectionutils.DetectXSSCache(string(v), this.cacheLife) {
return true return true
} }
} }
return false return false
default: default:
return injectionutils.DetectXSS(this.stringifyValue(value)) return injectionutils.DetectXSSCache(this.stringifyValue(value), this.cacheLife)
} }
case RuleOperatorContainsBinary: case RuleOperatorContainsBinary:
data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value)) data, _ := base64.StdEncoding.DecodeString(this.stringifyValue(this.Value))

View File

@@ -1,7 +1,8 @@
package waf package waf_test
import ( import (
"bytes" "bytes"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/cespare/xxhash" "github.com/cespare/xxhash"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
@@ -12,18 +13,18 @@ import (
) )
func TestRuleSet_MatchRequest(t *testing.T) { func TestRuleSet_MatchRequest(t *testing.T) {
set := NewRuleSet() var set = waf.NewRuleSet()
set.Connector = RuleConnectorAnd set.Connector = waf.RuleConnectorAnd
set.Rules = []*Rule{ set.Rules = []*waf.Rule{
{ {
Param: "${arg.name}", Param: "${arg.name}",
Operator: RuleOperatorEqString, Operator: waf.RuleOperatorEqString,
Value: "lu", Value: "lu",
}, },
{ {
Param: "${arg.age}", Param: "${arg.age}",
Operator: RuleOperatorEq, Operator: waf.RuleOperatorEq,
Value: "20", Value: "20",
}, },
} }
@@ -42,20 +43,20 @@ func TestRuleSet_MatchRequest(t *testing.T) {
} }
func TestRuleSet_MatchRequest2(t *testing.T) { func TestRuleSet_MatchRequest2(t *testing.T) {
a := assert.NewAssertion(t) var a = assert.NewAssertion(t)
set := NewRuleSet() var set = waf.NewRuleSet()
set.Connector = RuleConnectorOr set.Connector = waf.RuleConnectorOr
set.Rules = []*Rule{ set.Rules = []*waf.Rule{
{ {
Param: "${arg.name}", Param: "${arg.name}",
Operator: RuleOperatorEqString, Operator: waf.RuleOperatorEqString,
Value: "lu", Value: "lu",
}, },
{ {
Param: "${arg.age}", Param: "${arg.age}",
Operator: RuleOperatorEq, Operator: waf.RuleOperatorEq,
Value: "21", Value: "21",
}, },
} }
@@ -76,28 +77,28 @@ func TestRuleSet_MatchRequest2(t *testing.T) {
func BenchmarkRuleSet_MatchRequest(b *testing.B) { func BenchmarkRuleSet_MatchRequest(b *testing.B) {
runtime.GOMAXPROCS(1) runtime.GOMAXPROCS(1)
set := NewRuleSet() var set = waf.NewRuleSet()
set.Connector = RuleConnectorOr set.Connector = waf.RuleConnectorOr
set.Rules = []*Rule{ set.Rules = []*waf.Rule{
{ {
Param: "${requestAll}", Param: "${requestAll}",
Operator: RuleOperatorMatch, Operator: waf.RuleOperatorMatch,
Value: `(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)\s*=`, Value: `(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)\s*=`,
}, },
{ {
Param: "${requestAll}", Param: "${requestAll}",
Operator: RuleOperatorMatch, Operator: waf.RuleOperatorMatch,
Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`, Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`,
}, },
{ {
Param: "${arg.name}", Param: "${arg.name}",
Operator: RuleOperatorEqString, Operator: waf.RuleOperatorEqString,
Value: "lu", Value: "lu",
}, },
{ {
Param: "${arg.age}", Param: "${arg.age}",
Operator: RuleOperatorEq, Operator: waf.RuleOperatorEq,
Value: "21", Value: "21",
}, },
} }
@@ -120,13 +121,13 @@ func BenchmarkRuleSet_MatchRequest(b *testing.B) {
func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) { func BenchmarkRuleSet_MatchRequest_Regexp(b *testing.B) {
runtime.GOMAXPROCS(1) runtime.GOMAXPROCS(1)
set := NewRuleSet() var set = waf.NewRuleSet()
set.Connector = RuleConnectorOr set.Connector = waf.RuleConnectorOr
set.Rules = []*Rule{ set.Rules = []*waf.Rule{
{ {
Param: "${requestBody}", Param: "${requestBody}",
Operator: RuleOperatorMatch, Operator: waf.RuleOperatorMatch,
Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`, Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`,
IsCaseInsensitive: false, IsCaseInsensitive: false,
}, },

View File

@@ -1,434 +1,40 @@
package waf package waf
func Template() *WAF { import (
waf := NewWAF() "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
waf.Id = 0 "github.com/TeaOSLab/EdgeNode/internal/waf/utils"
waf.IsOn = true )
// xss func Template() (*WAF, error) {
{ var config = firewallconfigs.HTTPFirewallTemplate()
group := NewRuleGroup() if config.Inbound != nil {
config.Inbound.IsOn = true
}
for _, group := range config.AllRuleGroups() {
if group.Code == "cc" || group.Code == "cc2" {
continue
}
group.IsOn = true group.IsOn = true
group.IsInbound = true
group.Name = "XSS"
group.Code = "xss"
group.Description = "防跨站脚本攻击Cross Site Scripting"
{ for _, set := range group.Sets {
set := NewRuleSet()
set.IsOn = true set.IsOn = true
set.Name = "Javascript事件" }
set.Code = "1001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
Value: `(onmouseover|onmousemove|onmousedown|onmouseup|onerror|onload|onclick|ondblclick|onkeydown|onkeyup|onkeypress)\s*=`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
} }
{ instance, err := SharedWAFManager.ConvertWAF(config)
set := NewRuleSet() if err != nil {
set.IsOn = true return nil, err
set.Name = "Javascript函数"
set.Code = "1002"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
Value: `(alert|eval|prompt|confirm)\s*\(`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
} }
{ for _, group := range instance.Inbound {
set := NewRuleSet() for _, set := range group.RuleSets {
set.IsOn = true for _, rule := range set.Rules {
set.Name = "HTML标签" rule.cacheLife = utils.CacheDisabled // for performance test
set.Code = "1003" _ = rule
set.Connector = RuleConnectorOr }
set.AddAction(ActionBlock, nil) }
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
Value: `<(script|iframe|link)`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
} }
waf.AddRuleGroup(group) return instance, nil
}
// upload
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "文件上传"
group.Code = "upload"
group.Description = "防止上传可执行脚本文件到服务器"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "上传文件扩展名"
set.Code = "2001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestUpload.ext}",
Operator: RuleOperatorMatch,
Value: `\.(php|jsp|aspx|asp|exe|asa|rb|py)\b`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// web shell
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "Web Shell"
group.Code = "webShell"
group.Description = "防止远程执行服务器命令"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "Web Shell"
set.Code = "3001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `\b(eval|system|exec|execute|passthru|shell_exec|phpinfo)\s*\(`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// command injection
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "命令注入"
group.Code = "commandInjection"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "命令注入"
set.Code = "4001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
Value: `\b(pwd|ls|ll|whoami|id|net\s+user)\s*$`, // TODO more keywords here
IsCaseInsensitive: false,
})
set.AddRule(&Rule{
Param: "${requestBody}",
Operator: RuleOperatorMatch,
Value: `\b(pwd|ls|ll|whoami|id|net\s+user)\s*$`, // TODO more keywords here
IsCaseInsensitive: false,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// path traversal
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "路径穿越"
group.Code = "pathTraversal"
group.Description = "防止读取网站目录之外的其他系统文件"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "路径穿越"
set.Code = "5001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestURI}",
Operator: RuleOperatorMatch,
Value: `((\.+)(/+)){2,}`, // TODO more keywords here
IsCaseInsensitive: false,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// special dirs
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "特殊目录"
group.Code = "denyDirs"
group.Description = "防止通过Web访问到一些特殊目录"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "特殊目录"
set.Code = "6001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestPath}",
Operator: RuleOperatorMatch,
Value: `/\.(git|svn|htaccess|idea)\b`, // TODO more keywords here
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// sql injection
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "SQL注入"
group.Code = "sqlInjection"
group.Description = "防止SQL注入漏洞"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "Union SQL Injection"
set.Code = "7001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `union[\s/\*]+select`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
{
set := NewRuleSet()
set.IsOn = true
set.Name = "SQL注释"
set.Code = "7002"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `/\*(!|\x00)`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
{
set := NewRuleSet()
set.IsOn = true
set.Name = "SQL条件"
set.Code = "7003"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `\s(and|or|rlike)\s+(if|updatexml)\s*\(`,
IsCaseInsensitive: true,
})
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `\s+(and|or|rlike)\s+(select|case)\s+`,
IsCaseInsensitive: true,
})
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `\s+(and|or|procedure)\s+[\w\p{L}]+\s*=\s*[\w\p{L}]+(\s|$|--|#)`,
IsCaseInsensitive: true,
})
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `\(\s*case\s+when\s+[\w\p{L}]+\s*=\s*[\w\p{L}]+\s+then\s+`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
{
set := NewRuleSet()
set.IsOn = true
set.Name = "SQL函数"
set.Code = "7004"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `(updatexml|extractvalue|ascii|ord|char|chr|count|concat|rand|floor|substr|length|len|user|database|benchmark|analyse)\s*\(`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
{
set := NewRuleSet()
set.IsOn = true
set.Name = "SQL附加语句"
set.Code = "7005"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${requestAll}",
Operator: RuleOperatorMatch,
Value: `;\s*(declare|use|drop|create|exec|delete|update|insert)\s`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// bot
{
group := NewRuleGroup()
group.IsOn = false
group.IsInbound = true
group.Name = "网络爬虫"
group.Code = "bot"
group.Description = "禁止一些网络爬虫"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "常见网络爬虫"
set.Code = "20001"
set.Connector = RuleConnectorOr
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${userAgent}",
Operator: RuleOperatorMatch,
Value: `Googlebot|AdsBot|bingbot|BingPreview|facebookexternalhit|Slurp|Sogou|proximic|Baiduspider|yandex|twitterbot|spider|python`,
IsCaseInsensitive: true,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// cc
{
group := NewRuleGroup()
group.IsOn = false
group.IsInbound = true
group.Name = "CC攻击"
group.Description = "Challenge Collapsar防止短时间大量请求涌入请谨慎开启和设置"
group.Code = "cc2"
{
set := NewRuleSet()
set.IsOn = true
set.Name = "CC请求数"
set.Description = "限制单IP在一定时间内的请求数"
set.Code = "8001"
set.Connector = RuleConnectorAnd
set.AddAction(ActionBlock, nil)
set.AddRule(&Rule{
Param: "${cc2}",
Operator: RuleOperatorGt,
Value: "1000",
CheckpointOptions: map[string]interface{}{
"period": "60",
"threshold": 1000,
"keys": []string{"${remoteAddr}", "${requestPath}"},
},
IsCaseInsensitive: false,
})
set.AddRule(&Rule{
Param: "${remoteAddr}",
Operator: RuleOperatorNotIPRange,
Value: `127.0.0.1/8`,
IsCaseInsensitive: false,
})
set.AddRule(&Rule{
Param: "${remoteAddr}",
Operator: RuleOperatorNotIPRange,
Value: `192.168.0.1/16`,
IsCaseInsensitive: false,
})
set.AddRule(&Rule{
Param: "${remoteAddr}",
Operator: RuleOperatorNotIPRange,
Value: `10.0.0.1/8`,
IsCaseInsensitive: false,
})
set.AddRule(&Rule{
Param: "${remoteAddr}",
Operator: RuleOperatorNotIPRange,
Value: `172.16.0.1/12`,
IsCaseInsensitive: false,
})
group.AddRuleSet(set)
}
waf.AddRuleGroup(group)
}
// custom
{
group := NewRuleGroup()
group.IsOn = true
group.IsInbound = true
group.Name = "自定义规则分组"
group.Description = "我的自定义规则分组,可以将自定义的规则放在这个分组下"
group.Code = "custom"
waf.AddRuleGroup(group)
}
return waf
} }

View File

@@ -1,12 +1,15 @@
package waf package waf_test
import ( import (
"bytes" "bytes"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
"github.com/iwind/TeaGo/types"
"math/rand"
"mime/multipart" "mime/multipart"
"net/http" "net/http"
"net/url" "net/url"
@@ -15,34 +18,26 @@ import (
"time" "time"
) )
const testUserAgent = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_0_0) AppleWebKit/500.00 (KHTML, like Gecko) Chrome/100.0.0.0"
func Test_Template(t *testing.T) { func Test_Template(t *testing.T) {
var a = assert.NewAssertion(t) var a = assert.NewAssertion(t)
var waf = Template() wafInstance, err := waf.Template()
for _, group := range waf.Inbound {
group.IsOn = true
for _, set := range group.RuleSets {
set.IsOn = true
}
}
err := waf.Init()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
testTemplate1001(a, t, waf) testTemplate1001(a, t, wafInstance)
testTemplate1002(a, t, waf) testTemplate1002(a, t, wafInstance)
testTemplate1003(a, t, waf) testTemplate1003(a, t, wafInstance)
testTemplate2001(a, t, waf) testTemplate2001(a, t, wafInstance)
testTemplate3001(a, t, waf) testTemplate3001(a, t, wafInstance)
testTemplate4001(a, t, waf) testTemplate4001(a, t, wafInstance)
testTemplate5001(a, t, waf) testTemplate5001(a, t, wafInstance)
testTemplate6001(a, t, waf) testTemplate6001(a, t, wafInstance)
testTemplate7001(a, t, waf) testTemplate7001(a, t, wafInstance)
testTemplate20001(a, t, waf) testTemplate20001(a, t, wafInstance)
} }
func Test_Template2(t *testing.T) { func Test_Template2(t *testing.T) {
@@ -52,14 +47,13 @@ func Test_Template2(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
waf := Template() wafInstance, err := waf.Template()
var errs = waf.Init() if err != nil {
if len(errs) > 0 { t.Fatal(err)
t.Fatal(errs[0])
} }
now := time.Now() now := time.Now()
goNext, _, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -74,17 +68,7 @@ func Test_Template2(t *testing.T) {
} }
func BenchmarkTemplate(b *testing.B) { func BenchmarkTemplate(b *testing.B) {
var waf = Template() wafInstance, err := waf.Template()
for _, group := range waf.Inbound {
group.IsOn = true
for _, set := range group.RuleSets {
set.IsOn = true
}
}
err := waf.Init()
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@@ -96,16 +80,18 @@ func BenchmarkTemplate(b *testing.B) {
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
req.Header.Set("User-Agent", testUserAgent)
_, _, _, _, _ = waf.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) _, _, _, _, _ = wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
} }
} }
func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) { func testTemplate1001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=onmousedown%3D123", nil) req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=onmousedown%3D123", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -116,7 +102,7 @@ func testTemplate1001(a *assert.Assertion, t *testing.T, template *WAF) {
} }
} }
func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) { func testTemplate1002(a *assert.Assertion, t *testing.T, template *waf.WAF) {
req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=eval%28", nil) req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=eval%28", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -131,7 +117,7 @@ func testTemplate1002(a *assert.Assertion, t *testing.T, template *WAF) {
} }
} }
func testTemplate1003(a *assert.Assertion, t *testing.T, template *WAF) { func testTemplate1003(a *assert.Assertion, t *testing.T, template *waf.WAF) {
req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=<script src=\"123.js\">", nil) req, err := http.NewRequest(http.MethodGet, "http://example.com/index.php?id=<script src=\"123.js\">", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -146,7 +132,7 @@ func testTemplate1003(a *assert.Assertion, t *testing.T, template *WAF) {
} }
} }
func testTemplate2001(a *assert.Assertion, t *testing.T, template *WAF) { func testTemplate2001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
body := bytes.NewBuffer([]byte{}) body := bytes.NewBuffer([]byte{})
writer := multipart.NewWriter(body) writer := multipart.NewWriter(body)
@@ -212,7 +198,7 @@ func testTemplate2001(a *assert.Assertion, t *testing.T, template *WAF) {
} }
} }
func testTemplate3001(a *assert.Assertion, t *testing.T, template *WAF) { func testTemplate3001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
req, err := http.NewRequest(http.MethodPost, "http://example.com/index.php?exec1+(", bytes.NewReader([]byte("exec('rm -rf /hello');"))) req, err := http.NewRequest(http.MethodPost, "http://example.com/index.php?exec1+(", bytes.NewReader([]byte("exec('rm -rf /hello');")))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -227,7 +213,7 @@ func testTemplate3001(a *assert.Assertion, t *testing.T, template *WAF) {
} }
} }
func testTemplate4001(a *assert.Assertion, t *testing.T, template *WAF) { func testTemplate4001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
req, err := http.NewRequest(http.MethodPost, "http://example.com/index.php?whoami", nil) req, err := http.NewRequest(http.MethodPost, "http://example.com/index.php?whoami", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -242,7 +228,7 @@ func testTemplate4001(a *assert.Assertion, t *testing.T, template *WAF) {
} }
} }
func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) { func testTemplate5001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
{ {
req, err := http.NewRequest(http.MethodPost, "http://example.com/.././..", nil) req, err := http.NewRequest(http.MethodPost, "http://example.com/.././..", nil)
if err != nil { if err != nil {
@@ -274,12 +260,13 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *WAF) {
} }
} }
func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) { func testTemplate6001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
{ {
req, err := http.NewRequest(http.MethodPost, "http://example.com/.svn/123.txt", nil) req, err := http.NewRequest(http.MethodPost, "http://example.com/.svn/123.txt", nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -299,11 +286,11 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *WAF) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNil(result) a.IsNotNil(result)
} }
} }
func testTemplate7001(a *assert.Assertion, t *testing.T, template *WAF) { func testTemplate7001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
for _, id := range []string{ for _, id := range []string{
"union select", "union select",
" and if(", " and if(",
@@ -311,13 +298,14 @@ func testTemplate7001(a *assert.Assertion, t *testing.T, template *WAF) {
" and select ", " and select ",
" and id=123 ", " and id=123 ",
"(case when a=1 then ", "(case when a=1 then ",
"updatexml (", " and updatexml (",
"; delete from table", "; delete from table",
} { } {
req, err := http.NewRequest(http.MethodPost, "http://example.com/?id="+url.QueryEscape(id), nil) req, err := http.NewRequest(http.MethodPost, "http://example.com/?id="+url.QueryEscape(id), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) _, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -332,11 +320,9 @@ func testTemplate7001(a *assert.Assertion, t *testing.T, template *WAF) {
} }
func TestTemplateSQLInjection(t *testing.T) { func TestTemplateSQLInjection(t *testing.T) {
var template = Template() template, err := waf.Template()
errs := template.Init() if err != nil {
if len(errs) > 0 { t.Fatal(err)
t.Fatal(errs)
return
} }
var group = template.FindRuleGroupWithCode("sqlInjection") var group = template.FindRuleGroupWithCode("sqlInjection")
if group == nil { if group == nil {
@@ -354,6 +340,7 @@ func TestTemplateSQLInjection(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
req.Header.Set("User-Agent", testUserAgent)
_, _, result, err := group.MatchRequest(requests.NewTestRequest(req)) _, _, result, err := group.MatchRequest(requests.NewTestRequest(req))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -364,11 +351,9 @@ func TestTemplateSQLInjection(t *testing.T) {
} }
func BenchmarkTemplateSQLInjection(b *testing.B) { func BenchmarkTemplateSQLInjection(b *testing.B) {
var template = Template() template, err := waf.Template()
errs := template.Init() if err != nil {
if len(errs) > 0 { b.Fatal(err)
b.Fatal(errs)
return
} }
var group = template.FindRuleGroupWithCode("sqlInjection") var group = template.FindRuleGroupWithCode("sqlInjection")
if group == nil { if group == nil {
@@ -380,10 +365,12 @@ func BenchmarkTemplateSQLInjection(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
req, err := http.NewRequest(http.MethodPost, "https://example.com/?id=1234", nil) req, err := http.NewRequest(http.MethodPost, "https://example.com/?id=1234" + types.String(rand.Int()%10000), nil)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
req.Header.Set("User-Agent", testUserAgent)
_, _, result, err := group.MatchRequest(requests.NewTestRequest(req)) _, _, result, err := group.MatchRequest(requests.NewTestRequest(req))
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
@@ -393,7 +380,7 @@ func BenchmarkTemplateSQLInjection(b *testing.B) {
}) })
} }
func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) { func testTemplate20001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
// enable bot rule set // enable bot rule set
for _, g := range template.Inbound { for _, g := range template.Inbound {
if g.Code == "bot" { if g.Code == "bot" {
@@ -404,7 +391,7 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *WAF) {
for _, bot := range []string{ for _, bot := range []string{
"Googlebot", "Googlebot",
"AdsBot", "AdsBot-Google",
"bingbot", "bingbot",
"BingPreview", "BingPreview",
"facebookexternalhit", "facebookexternalhit",

View File

@@ -7,13 +7,13 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils/cachehits" "github.com/TeaOSLab/EdgeNode/internal/utils/cachehits"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/cespare/xxhash" "github.com/cespare/xxhash/v2"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
stringutil "github.com/iwind/TeaGo/utils/string" stringutil "github.com/iwind/TeaGo/utils/string"
"strconv" "strconv"
) )
var cache = ttlcache.NewCache[int8]() var SharedCache = ttlcache.NewCache[int8]()
var cacheHits *cachehits.Stat var cacheHits *cachehits.Stat
func init() { func init() {
@@ -24,7 +24,7 @@ func init() {
} }
const ( const (
maxCacheDataSize = 1024 MaxCacheDataSize = 1024
) )
type CacheLife = int64 type CacheLife = int64
@@ -45,22 +45,22 @@ func MatchStringCache(regex *re.Regexp, s string, cacheLife CacheLife) bool {
var regIdString = regex.IdString() var regIdString = regex.IdString()
// 如果长度超过一定数量,大概率是不能重用的 // 如果长度超过一定数量,大概率是不能重用的
if cacheLife <= 0 || len(s) > maxCacheDataSize || !cacheHits.IsGood(regIdString) { if cacheLife <= 0 || len(s) > MaxCacheDataSize || !cacheHits.IsGood(regIdString) {
return regex.MatchString(s) return regex.MatchString(s)
} }
var hash = xxhash.Sum64String(s) var hash = xxhash.Sum64String(s)
var key = regIdString + "@" + strconv.FormatUint(hash, 10) var key = regIdString + "@" + strconv.FormatUint(hash, 10)
var item = cache.Read(key) var item = SharedCache.Read(key)
if item != nil { if item != nil {
cacheHits.IncreaseHit(regIdString) cacheHits.IncreaseHit(regIdString)
return item.Value == 1 return item.Value == 1
} }
var b = regex.MatchString(s) var b = regex.MatchString(s)
if b { if b {
cache.Write(key, 1, fasttime.Now().Unix()+cacheLife) SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
} else { } else {
cache.Write(key, 0, fasttime.Now().Unix()+cacheLife) SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
} }
cacheHits.IncreaseCached(regIdString) cacheHits.IncreaseCached(regIdString)
return b return b
@@ -75,22 +75,22 @@ func MatchBytesCache(regex *re.Regexp, byteSlice []byte, cacheLife CacheLife) bo
var regIdString = regex.IdString() var regIdString = regex.IdString()
// 如果长度超过一定数量,大概率是不能重用的 // 如果长度超过一定数量,大概率是不能重用的
if cacheLife <= 0 || len(byteSlice) > maxCacheDataSize || !cacheHits.IsGood(regIdString) { if cacheLife <= 0 || len(byteSlice) > MaxCacheDataSize || !cacheHits.IsGood(regIdString) {
return regex.Match(byteSlice) return regex.Match(byteSlice)
} }
var hash = xxhash.Sum64(byteSlice) var hash = xxhash.Sum64(byteSlice)
var key = regIdString + "@" + strconv.FormatUint(hash, 10) var key = regIdString + "@" + strconv.FormatUint(hash, 10)
var item = cache.Read(key) var item = SharedCache.Read(key)
if item != nil { if item != nil {
cacheHits.IncreaseHit(regIdString) cacheHits.IncreaseHit(regIdString)
return item.Value == 1 return item.Value == 1
} }
var b = regex.Match(byteSlice) var b = regex.Match(byteSlice)
if b { if b {
cache.Write(key, 1, fasttime.Now().Unix()+cacheLife) SharedCache.Write(key, 1, fasttime.Now().Unix()+cacheLife)
} else { } else {
cache.Write(key, 0, fasttime.Now().Unix()+cacheLife) SharedCache.Write(key, 0, fasttime.Now().Unix()+cacheLife)
} }
cacheHits.IncreaseCached(regIdString) cacheHits.IncreaseCached(regIdString)
return b return b

View File

@@ -402,7 +402,7 @@ func (this *WAF) Stop() {
} }
// MergeTemplate merge with template // MergeTemplate merge with template
func (this *WAF) MergeTemplate() (changedItems []string) { func (this *WAF) MergeTemplate() (changedItems []string, err error) {
changedItems = []string{} changedItems = []string{}
// compare versions // compare versions
@@ -411,7 +411,10 @@ func (this *WAF) MergeTemplate() (changedItems []string) {
} }
this.CreatedVersion = teaconst.Version this.CreatedVersion = teaconst.Version
template := Template() template, err := Template()
if err != nil {
return nil, err
}
groups := []*RuleGroup{} groups := []*RuleGroup{}
groups = append(groups, template.Inbound...) groups = append(groups, template.Inbound...)
groups = append(groups, template.Outbound...) groups = append(groups, template.Outbound...)

View File

@@ -1,7 +1,8 @@
package waf package waf_test
import ( import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"net/http" "net/http"
@@ -9,32 +10,32 @@ import (
) )
func TestWAF_MatchRequest(t *testing.T) { func TestWAF_MatchRequest(t *testing.T) {
a := assert.NewAssertion(t) var a = assert.NewAssertion(t)
set := NewRuleSet() var set = waf.NewRuleSet()
set.Name = "Name_Age" set.Name = "Name_Age"
set.Connector = RuleConnectorAnd set.Connector = waf.RuleConnectorAnd
set.Rules = []*Rule{ set.Rules = []*waf.Rule{
{ {
Param: "${arg.name}", Param: "${arg.name}",
Operator: RuleOperatorEqString, Operator: waf.RuleOperatorEqString,
Value: "lu", Value: "lu",
}, },
{ {
Param: "${arg.age}", Param: "${arg.age}",
Operator: RuleOperatorEq, Operator: waf.RuleOperatorEq,
Value: "20", Value: "20",
}, },
} }
set.AddAction(ActionBlock, nil) set.AddAction(waf.ActionBlock, nil)
group := NewRuleGroup() var group = waf.NewRuleGroup()
group.AddRuleSet(set) group.AddRuleSet(set)
group.IsInbound = true group.IsInbound = true
waf := NewWAF() var wafInstance = waf.NewWAF()
waf.AddRuleGroup(group) wafInstance.AddRuleGroup(group)
errs := waf.Init() errs := wafInstance.Init()
if len(errs) > 0 { if len(errs) > 0 {
t.Fatal(errs[0]) t.Fatal(errs[0])
} }
@@ -43,7 +44,7 @@ func TestWAF_MatchRequest(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
goNext, _, _, set, err := waf.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }