From c32959a6c896c382c2f611c8547ea9075c259548 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Sat, 21 Nov 2020 20:44:19 +0800 Subject: [PATCH] =?UTF-8?q?[waf]=E6=94=AF=E6=8C=81=E5=8C=85=E5=90=AB?= =?UTF-8?q?=E4=BA=8C=E8=BF=9B=E5=88=B6=E3=80=81=E4=B8=8D=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E4=BA=8C=E8=BF=9B=E5=88=B6=E7=AD=89=E6=93=8D=E4=BD=9C=E7=AC=A6?= =?UTF-8?q?=EF=BC=9B=E6=94=AF=E6=8C=81=E5=AF=B9=E5=8F=82=E6=95=B0=E5=80=BC?= =?UTF-8?q?=E7=BC=96=E8=A7=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/cache/cache.go | 104 ++++++++++++++++++++++++++++++++++ internal/cache/cache_test.go | 103 +++++++++++++++++++++++++++++++++ internal/cache/item.go | 6 ++ internal/cache/option.go | 12 ++++ internal/cache/piece.go | 57 +++++++++++++++++++ internal/cache/piece_test.go | 52 +++++++++++++++++ internal/cache/utils.go | 7 +++ internal/cache/utils_test.go | 13 +++++ internal/nodes/waf_manager.go | 9 +++ internal/waf/param_filter.go | 8 +++ internal/waf/rule.go | 53 ++++++++++++++++- internal/waf/rule_operator.go | 3 + 12 files changed, 425 insertions(+), 2 deletions(-) create mode 100644 internal/cache/cache.go create mode 100644 internal/cache/cache_test.go create mode 100644 internal/cache/item.go create mode 100644 internal/cache/option.go create mode 100644 internal/cache/piece.go create mode 100644 internal/cache/piece_test.go create mode 100644 internal/cache/utils.go create mode 100644 internal/cache/utils_test.go create mode 100644 internal/waf/param_filter.go diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..4ec1fc2 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,104 @@ +package cache + +import ( + "time" +) + +// TTL缓存 +// 最大的缓存时间为30 * 86400 +// Piece数据结构: +// Piece1 | Piece2 | Piece3 | ... +// [ Item1, Item2, ... | ... +// KeyMap列表数据结构 +// { timestamp1 => [key1, key2, ...] }, ... +type Cache struct { + pieces []*Piece + countPieces uint64 + + gcPieceIndex int +} + +func NewCache(opt ...OptionInterface) *Cache { + countPieces := 128 + for _, option := range opt { + if option == nil { + continue + } + switch o := option.(type) { + case *PiecesOption: + if o.Count > 0 { + countPieces = o.Count + } + } + } + + cache := &Cache{ + countPieces: uint64(countPieces), + } + + for i := 0; i < countPieces; i++ { + cache.pieces = append(cache.pieces, NewPiece()) + } + + // start timer + go func() { + ticker := time.NewTicker(1 * time.Second) + for range ticker.C { + cache.GC() + } + }() + + return cache +} + +func (this *Cache) Add(key string, value interface{}, expiredAt int64) { + currentTimestamp := time.Now().Unix() + if expiredAt <= currentTimestamp { + return + } + + maxExpiredAt := currentTimestamp + 30*86400 + if expiredAt > maxExpiredAt { + expiredAt = maxExpiredAt + } + uint64Key := HashKey([]byte(key)) + pieceIndex := uint64Key % this.countPieces + this.pieces[pieceIndex].Add(uint64Key, &Item{ + value: value, + expiredAt: expiredAt, + }) +} + +func (this *Cache) Read(key string) (value *Item) { + uint64Key := HashKey([]byte(key)) + return this.pieces[uint64Key%this.countPieces].Read(uint64Key) +} + +func (this *Cache) readIntKey(key uint64) (value *Item) { + return this.pieces[key%this.countPieces].Read(key) +} + +func (this *Cache) Delete(key string) { + uint64Key := HashKey([]byte(key)) + this.pieces[uint64Key%this.countPieces].Delete(uint64Key) +} + +func (this *Cache) deleteIntKey(key uint64) { + this.pieces[key%this.countPieces].Delete(key) +} + +func (this *Cache) Count() (count int) { + for _, piece := range this.pieces { + count += piece.Count() + } + return +} + +func (this *Cache) GC() { + this.pieces[this.gcPieceIndex].GC() + newIndex := this.gcPieceIndex + 1 + if newIndex >= int(this.countPieces) { + newIndex = 0 + } + this.gcPieceIndex = newIndex +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..0d88d69 --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,103 @@ +package cache + +import ( + "github.com/iwind/TeaGo/rands" + "runtime" + "strconv" + "testing" + "time" +) + +func TestNewCache(t *testing.T) { + cache := NewCache() + cache.Add("a", 1, time.Now().Unix()+3600) + cache.Add("b", 2, time.Now().Unix()+3601) + cache.Add("a", 1, time.Now().Unix()+3602) + cache.Add("d", 1, time.Now().Unix()+1) + + for _, piece := range cache.pieces { + if len(piece.m) > 0 { + for k, item := range piece.m { + t.Log(k, "=>", item.value, item.expiredAt) + } + } + } + t.Log(cache.Read("a")) + time.Sleep(2 * time.Second) + t.Log(cache.Read("d")) +} + +func BenchmarkCache_Add(b *testing.B) { + runtime.GOMAXPROCS(1) + + cache := NewCache() + for i := 0; i < b.N; i++ { + cache.Add(strconv.Itoa(i), i, time.Now().Unix()+int64(i%1024)) + } +} + +func TestCache_Read(t *testing.T) { + runtime.GOMAXPROCS(1) + + var cache = NewCache(PiecesOption{Count: 512}) + + for i := 0; i < 10_000_000; i++ { + cache.Add("HELLO_WORLD_"+strconv.Itoa(i), i, time.Now().Unix()+int64(i%10240)+1) + } + + /**total := 0 + for _, piece := range cache.pieces { + t.Log(len(piece.m), "keys") + total += len(piece.m) + } + t.Log(total, "total keys")**/ + + before := time.Now() + for i := 0; i < 10_240; i++ { + _ = cache.Read("HELLO_WORLD_" + strconv.Itoa(i)) + } + t.Log(time.Since(before).Seconds()*1000, "ms") +} + +func TestCache_GC(t *testing.T) { + var cache = NewCache(&PiecesOption{Count: 5}) + cache.Add("a", 1, time.Now().Unix()+1) + cache.Add("b", 2, time.Now().Unix()+2) + cache.Add("c", 3, time.Now().Unix()+3) + cache.Add("d", 4, time.Now().Unix()+4) + cache.Add("e", 5, time.Now().Unix()+10) + + go func() { + for i := 0; i < 1000; i++ { + cache.Add("f", 1, time.Now().Unix()+1) + time.Sleep(10 * time.Millisecond) + } + }() + + for i := 0; i < 20; i++ { + cache.GC() + t.Log("items:", cache.Count()) + time.Sleep(1 * time.Second) + } + + t.Log("now:", time.Now().Unix()) + for _, p := range cache.pieces { + for k, v := range p.m { + t.Log(k, v.value, v.expiredAt) + } + } +} + +func TestCache_GC2(t *testing.T) { + runtime.GOMAXPROCS(1) + + cache := NewCache() + for i := 0; i < 1_000_000; i++ { + cache.Add(strconv.Itoa(i), i, time.Now().Unix()+int64(rands.Int(0, 100))) + } + + for i := 0; i < 100; i++ { + t.Log(cache.Count(), "items") + time.Sleep(1 * time.Second) + } +} diff --git a/internal/cache/item.go b/internal/cache/item.go new file mode 100644 index 0000000..276da46 --- /dev/null +++ b/internal/cache/item.go @@ -0,0 +1,6 @@ +package cache + +type Item struct { + value interface{} + expiredAt int64 +} diff --git a/internal/cache/option.go b/internal/cache/option.go new file mode 100644 index 0000000..286810d --- /dev/null +++ b/internal/cache/option.go @@ -0,0 +1,12 @@ +package cache + +type OptionInterface interface { +} + +type PiecesOption struct { + Count int +} + +func NewPiecesOption(count int) *PiecesOption { + return &PiecesOption{Count: count} +} diff --git a/internal/cache/piece.go b/internal/cache/piece.go new file mode 100644 index 0000000..6f186e1 --- /dev/null +++ b/internal/cache/piece.go @@ -0,0 +1,57 @@ +package cache + +import ( + "github.com/TeaOSLab/EdgeNode/internal/utils" + "sync" + "time" +) + +type Piece struct { + m map[uint64]*Item + locker sync.RWMutex +} + +func NewPiece() *Piece { + return &Piece{m: map[uint64]*Item{}} +} + +func (this *Piece) Add(key uint64, item *Item) () { + this.locker.Lock() + this.m[key] = item + this.locker.Unlock() +} + +func (this *Piece) Delete(key uint64) { + this.locker.Lock() + delete(this.m, key) + this.locker.Unlock() +} + +func (this *Piece) Read(key uint64) (item *Item) { + this.locker.RLock() + item = this.m[key] + if item != nil && item.expiredAt < utils.UnixTime() { + item = nil + } + this.locker.RUnlock() + + return +} + +func (this *Piece) Count() (count int) { + this.locker.RLock() + count = len(this.m) + this.locker.RUnlock() + return +} + +func (this *Piece) GC() { + this.locker.Lock() + timestamp := time.Now().Unix() + for k, item := range this.m { + if item.expiredAt <= timestamp { + delete(this.m, k) + } + } + this.locker.Unlock() +} diff --git a/internal/cache/piece_test.go b/internal/cache/piece_test.go new file mode 100644 index 0000000..1f4d4e0 --- /dev/null +++ b/internal/cache/piece_test.go @@ -0,0 +1,52 @@ +package cache + +import ( + "github.com/iwind/TeaGo/rands" + "testing" + "time" +) + +func TestPiece_Add(t *testing.T) { + piece := NewPiece() + piece.Add(1, &Item{expiredAt: time.Now().Unix() + 3600}) + piece.Add(2, &Item{}) + piece.Add(3, &Item{}) + piece.Delete(3) + for key, item := range piece.m { + t.Log(key, item.value) + } + t.Log(piece.Read(1)) +} + +func TestPiece_GC(t *testing.T) { + piece := NewPiece() + piece.Add(1, &Item{value: 1, expiredAt: time.Now().Unix() + 1}) + piece.Add(2, &Item{value: 2, expiredAt: time.Now().Unix() + 1}) + piece.Add(3, &Item{value: 3, expiredAt: time.Now().Unix() + 1}) + t.Log("before gc ===") + for key, item := range piece.m { + t.Log(key, item.value) + } + + time.Sleep(1 * time.Second) + piece.GC() + + t.Log("after gc ===") + for key, item := range piece.m { + t.Log(key, item.value) + } +} + +func TestPiece_GC2(t *testing.T) { + piece := NewPiece() + for i := 0; i < 10_000; i++ { + piece.Add(uint64(i), &Item{value: 1, expiredAt: time.Now().Unix() + int64(rands.Int(1, 10))}) + } + + time.Sleep(1 * time.Second) + + before := time.Now() + piece.GC() + t.Log(time.Since(before).Seconds()*1000, "ms") + t.Log(piece.Count()) +} diff --git a/internal/cache/utils.go b/internal/cache/utils.go new file mode 100644 index 0000000..1ea27b0 --- /dev/null +++ b/internal/cache/utils.go @@ -0,0 +1,7 @@ +package cache + +import "github.com/dchest/siphash" + +func HashKey(key []byte) uint64 { + return siphash.Hash(0, 0, key) +} diff --git a/internal/cache/utils_test.go b/internal/cache/utils_test.go new file mode 100644 index 0000000..83b2da5 --- /dev/null +++ b/internal/cache/utils_test.go @@ -0,0 +1,13 @@ +package cache + +import ( + "runtime" + "testing" +) + +func BenchmarkHashKey(b *testing.B) { + runtime.GOMAXPROCS(1) + for i := 0; i < b.N; i++ { + HashKey([]byte("HELLO,WORLDHELLO,WORLDHELLO,WORLDHELLO,WORLDHELLO,WORLDHELLO,WORLD")) + } +} diff --git a/internal/nodes/waf_manager.go b/internal/nodes/waf_manager.go index 49309d2..9423851 100644 --- a/internal/nodes/waf_manager.go +++ b/internal/nodes/waf_manager.go @@ -93,11 +93,20 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) ( r := &waf.Rule{ Description: rule.Description, Param: rule.Param, + ParamFilters: []*waf.ParamFilter{}, Operator: rule.Operator, Value: rule.Value, IsCaseInsensitive: rule.IsCaseInsensitive, CheckpointOptions: rule.CheckpointOptions, } + + for _, paramFilter := range rule.ParamFilters { + r.ParamFilters = append(r.ParamFilters, &waf.ParamFilter{ + Code: paramFilter.Code, + Options: paramFilter.Options, + }) + } + s.Rules = append(s.Rules, r) } diff --git a/internal/waf/param_filter.go b/internal/waf/param_filter.go new file mode 100644 index 0000000..45e7b7d --- /dev/null +++ b/internal/waf/param_filter.go @@ -0,0 +1,8 @@ +package waf + +import "github.com/iwind/TeaGo/maps" + +type ParamFilter struct { + Code string `yaml:"code" json:"code"` + Options maps.Map `yaml:"options" json:"options"` +} diff --git a/internal/waf/rule.go b/internal/waf/rule.go index a06c769..9633dbe 100644 --- a/internal/waf/rule.go +++ b/internal/waf/rule.go @@ -2,9 +2,12 @@ package waf import ( "bytes" + "encoding/base64" "encoding/binary" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/filterconfigs" + "github.com/TeaOSLab/EdgeNode/internal/logs" "github.com/TeaOSLab/EdgeNode/internal/waf/checkpoints" "github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/TeaOSLab/EdgeNode/internal/waf/utils" @@ -23,7 +26,8 @@ var singleParamRegexp = regexp.MustCompile("^\\${[\\w.-]+}$") // rule type Rule struct { Description string `yaml:"description" json:"description"` - Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName} + Param string `yaml:"param" json:"param"` // such as ${arg.name} or ${args}, can be composite as ${arg.firstName}${arg.lastName} + ParamFilters []*ParamFilter `yaml:"paramFilters" json:"paramFilters"` Operator RuleOperator `yaml:"operator" json:"operator"` // such as contains, gt, ... Value string `yaml:"value" json:"value"` // compared value IsCaseInsensitive bool `yaml:"isCaseInsensitive" json:"isCaseInsensitive"` @@ -122,7 +126,6 @@ func (this *Rule) Init() error { } else { return errors.New("invalid ip range") } - } if singleParamRegexp.MatchString(this.Param) { @@ -187,6 +190,11 @@ func (this *Rule) MatchRequest(req *requests.Request) (b bool, err error) { return false, err } + // execute filters + if len(this.ParamFilters) > 0 { + value = this.execFilter(value) + } + // if is composed checkpoint, we just returns true or false if this.singleCheckpoint.IsComposed() { return types.Bool(value), nil @@ -233,6 +241,12 @@ func (this *Rule) MatchResponse(req *requests.Request, resp *requests.Response) if err != nil { return false, err } + + // execute filters + if len(this.ParamFilters) > 0 { + value = this.execFilter(value) + } + return this.Test(value), nil } @@ -420,6 +434,20 @@ func (this *Rule) Test(value interface{}) bool { } else { return strings.HasSuffix(types.String(value), this.Value) } + case RuleOperatorContainsBinary: + data, _ := base64.StdEncoding.DecodeString(types.String(this.Value)) + if this.IsCaseInsensitive { + return bytes.Contains(bytes.ToUpper([]byte(types.String(value))), bytes.ToUpper(data)) + } else { + return bytes.Contains([]byte(types.String(value)), data) + } + case RuleOperatorNotContainsBinary: + data, _ := base64.StdEncoding.DecodeString(types.String(this.Value)) + if this.IsCaseInsensitive { + return !bytes.Contains(bytes.ToUpper([]byte(types.String(value))), bytes.ToUpper(data)) + } else { + return !bytes.Contains([]byte(types.String(value)), data) + } case RuleOperatorHasKey: if types.IsSlice(value) { index := types.Int(this.Value) @@ -594,3 +622,24 @@ func (this *Rule) ipToInt64(ip net.IP) int64 { } return int64(binary.BigEndian.Uint32(ip)) } + +func (this *Rule) execFilter(value interface{}) interface{} { + var goNext bool + var err error + + for _, filter := range this.ParamFilters { + filterInstance := filterconfigs.FindFilter(filter.Code) + if filterInstance == nil { + continue + } + value, goNext, err = filterInstance.Do(value, filter.Options) + if err != nil { + logs.Println("WAF", "filter error: "+err.Error()) + break + } + if !goNext { + break + } + } + return value +} diff --git a/internal/waf/rule_operator.go b/internal/waf/rule_operator.go index d1fc0a6..879f8db 100644 --- a/internal/waf/rule_operator.go +++ b/internal/waf/rule_operator.go @@ -23,6 +23,9 @@ const ( RuleOperatorVersionLt RuleOperator = "version lt" RuleOperatorVersionRange RuleOperator = "version range" + RuleOperatorContainsBinary RuleOperator = "contains binary" // contains binary + RuleOperatorNotContainsBinary RuleOperator = "not contains binary" // not contains binary + // ip RuleOperatorEqIP RuleOperator = "eq ip" RuleOperatorGtIP RuleOperator = "gt ip"