diff --git a/internal/nodes/client_conn.go b/internal/nodes/client_conn.go index 41b7b60..b11e459 100644 --- a/internal/nodes/client_conn.go +++ b/internal/nodes/client_conn.go @@ -305,7 +305,7 @@ func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloo // 非TLS,设置为两倍,防止误封 minAttempts = 2 * minAttempts } - if result >= types.Uint64(minAttempts) { + if result >= types.Uint32(minAttempts) { var timeout = synFloodConfig.TimeoutSeconds if timeout <= 0 { timeout = 600 diff --git a/internal/utils/counters/counter.go b/internal/utils/counters/counter.go index a711f2e..f8207c1 100644 --- a/internal/utils/counters/counter.go +++ b/internal/utils/counters/counter.go @@ -13,12 +13,16 @@ import ( const maxItemsPerGroup = 50_000 -var SharedCounter = NewCounter().WithGC() +var SharedCounter = NewCounter[uint32]().WithGC() -type Counter struct { +type SupportedUIntType interface { + uint32 | uint64 +} + +type Counter[T SupportedUIntType] struct { countMaps uint64 locker *syncutils.RWMutex - itemMaps []map[uint64]*Item + itemMaps []map[uint64]*Item[T] gcTicker *time.Ticker gcIndex int @@ -26,18 +30,18 @@ type Counter struct { } // NewCounter create new counter -func NewCounter() *Counter { +func NewCounter[T SupportedUIntType]() *Counter[T] { var count = utils.SystemMemoryGB() * 8 if count < 8 { count = 8 } - var itemMaps = []map[uint64]*Item{} + var itemMaps = []map[uint64]*Item[T]{} for i := 0; i < count; i++ { - itemMaps = append(itemMaps, map[uint64]*Item{}) + itemMaps = append(itemMaps, map[uint64]*Item[T]{}) } - var counter = &Counter{ + var counter = &Counter[T]{ countMaps: uint64(count), locker: syncutils.NewRWMutex(count), itemMaps: itemMaps, @@ -47,7 +51,7 @@ func NewCounter() *Counter { } // WithGC start the counter with gc automatically -func (this *Counter) WithGC() *Counter { +func (this *Counter[T]) WithGC() *Counter[T] { if this.gcTicker != nil { return this } @@ -62,7 +66,7 @@ func (this *Counter) WithGC() *Counter { } // Increase key -func (this *Counter) Increase(key uint64, lifeSeconds int) uint64 { +func (this *Counter[T]) Increase(key uint64, lifeSeconds int) T { var index = int(key % this.countMaps) this.locker.RLock(index) var item = this.itemMaps[index][key] @@ -70,7 +74,7 @@ func (this *Counter) Increase(key uint64, lifeSeconds int) uint64 { if item == nil { // no need to care about duplication // always insert new item even when itemMap is full - item = NewItem(lifeSeconds) + item = NewItem[T](lifeSeconds) this.locker.Lock(index) this.itemMaps[index][key] = item this.locker.Unlock(index) @@ -83,12 +87,12 @@ func (this *Counter) Increase(key uint64, lifeSeconds int) uint64 { } // IncreaseKey increase string key -func (this *Counter) IncreaseKey(key string, lifeSeconds int) uint64 { +func (this *Counter[T]) IncreaseKey(key string, lifeSeconds int) T { return this.Increase(this.hash(key), lifeSeconds) } // Get value of key -func (this *Counter) Get(key uint64) uint64 { +func (this *Counter[T]) Get(key uint64) T { var index = int(key % this.countMaps) this.locker.RLock(index) defer this.locker.RUnlock(index) @@ -100,12 +104,12 @@ func (this *Counter) Get(key uint64) uint64 { } // GetKey get value of string key -func (this *Counter) GetKey(key string) uint64 { +func (this *Counter[T]) GetKey(key string) T { return this.Get(this.hash(key)) } // Reset key -func (this *Counter) Reset(key uint64) { +func (this *Counter[T]) Reset(key uint64) { var index = int(key % this.countMaps) this.locker.RLock(index) var item = this.itemMaps[index][key] @@ -119,12 +123,12 @@ func (this *Counter) Reset(key uint64) { } // ResetKey string key -func (this *Counter) ResetKey(key string) { +func (this *Counter[T]) ResetKey(key string) { this.Reset(this.hash(key)) } // TotalItems get items count -func (this *Counter) TotalItems() int { +func (this *Counter[T]) TotalItems() int { var total = 0 for i := 0; i < int(this.countMaps); i++ { @@ -137,7 +141,7 @@ func (this *Counter) TotalItems() int { } // GC garbage expired items -func (this *Counter) GC() { +func (this *Counter[T]) GC() { this.gcLocker.Lock() var gcIndex = this.gcIndex @@ -186,11 +190,11 @@ func (this *Counter) GC() { } } -func (this *Counter) CountMaps() int { +func (this *Counter[T]) CountMaps() int { return int(this.countMaps) } // calculate hash of the key -func (this *Counter) hash(key string) uint64 { +func (this *Counter[T]) hash(key string) uint64 { return xxhash.Sum64String(key) } diff --git a/internal/utils/counters/counter_test.go b/internal/utils/counters/counter_test.go index d3cce97..8083711 100644 --- a/internal/utils/counters/counter_test.go +++ b/internal/utils/counters/counter_test.go @@ -19,7 +19,7 @@ import ( func TestCounter_Increase(t *testing.T) { var a = assert.NewAssertion(t) - var counter = counters.NewCounter() + var counter = counters.NewCounter[uint32]() a.IsTrue(counter.Increase(1, 10) == 1) a.IsTrue(counter.Increase(1, 10) == 2) a.IsTrue(counter.Increase(2, 10) == 1) @@ -32,7 +32,7 @@ func TestCounter_Increase(t *testing.T) { func TestCounter_IncreaseKey(t *testing.T) { var a = assert.NewAssertion(t) - var counter = counters.NewCounter() + var counter = counters.NewCounter[uint32]() a.IsTrue(counter.IncreaseKey("1", 10) == 1) a.IsTrue(counter.IncreaseKey("1", 10) == 2) a.IsTrue(counter.IncreaseKey("2", 10) == 1) @@ -47,7 +47,7 @@ func TestCounter_GC(t *testing.T) { return } - var counter = counters.NewCounter() + var counter = counters.NewCounter[uint32]() counter.Increase(1, 20) time.Sleep(1 * time.Second) counter.Increase(1, 20) @@ -61,7 +61,7 @@ func TestCounter_GC2(t *testing.T) { return } - var counter = counters.NewCounter().WithGC() + var counter = counters.NewCounter[uint32]().WithGC() for i := 0; i < 1e5; i++ { counter.Increase(uint64(i), rands.Int(10, 300)) } @@ -79,7 +79,7 @@ func TestCounterMemory(t *testing.T) { var stat = &runtime.MemStats{} runtime.ReadMemStats(stat) - var counter = counters.NewCounter().WithGC() + var counter = counters.NewCounter[uint32]() for i := 0; i < 1_000_000; i++ { counter.Increase(uint64(i), rands.Int(10, 300)) } @@ -98,7 +98,7 @@ func TestCounterMemory(t *testing.T) { func BenchmarkCounter_Increase(b *testing.B) { runtime.GOMAXPROCS(4) - var counter = counters.NewCounter() + var counter = counters.NewCounter[uint32]() b.ResetTimer() var i uint64 @@ -114,7 +114,7 @@ func BenchmarkCounter_Increase(b *testing.B) { func BenchmarkCounter_IncreaseKey(b *testing.B) { runtime.GOMAXPROCS(4) - var counter = counters.NewCounter() + var counter = counters.NewCounter[uint32]() go func() { var ticker = time.NewTicker(100 * time.Millisecond) @@ -138,7 +138,7 @@ func BenchmarkCounter_IncreaseKey(b *testing.B) { func BenchmarkCounter_IncreaseKey2(b *testing.B) { runtime.GOMAXPROCS(4) - var counter = counters.NewCounter() + var counter = counters.NewCounter[uint32]() go func() { var ticker = time.NewTicker(1 * time.Millisecond) @@ -162,7 +162,7 @@ func BenchmarkCounter_IncreaseKey2(b *testing.B) { func BenchmarkCounter_GC(b *testing.B) { runtime.GOMAXPROCS(4) - var counter = counters.NewCounter() + var counter = counters.NewCounter[uint32]() for i := uint64(0); i < 1e5; i++ { counter.IncreaseKey(types.String(i), 20) diff --git a/internal/utils/counters/item.go b/internal/utils/counters/item.go index c19578b..9dd4d42 100644 --- a/internal/utils/counters/item.go +++ b/internal/utils/counters/item.go @@ -6,16 +6,16 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" ) -type Item struct { - lifeSeconds int64 - - spanSeconds int64 - spans []uint64 +const spanMaxValue = 10_000_000 +type Item[T SupportedUIntType] struct { + spans []T lastUpdateTime int64 + lifeSeconds int64 + spanSeconds int64 } -func NewItem(lifeSeconds int) *Item { +func NewItem[T SupportedUIntType](lifeSeconds int) *Item[T] { if lifeSeconds <= 0 { lifeSeconds = 60 } @@ -25,21 +25,23 @@ func NewItem(lifeSeconds int) *Item { } var countSpans = lifeSeconds/spanSeconds + 1 /** prevent index out of bounds **/ - return &Item{ + return &Item[T]{ lifeSeconds: int64(lifeSeconds), spanSeconds: int64(spanSeconds), - spans: make([]uint64, countSpans), + spans: make([]T, countSpans), lastUpdateTime: fasttime.Now().Unix(), } } -func (this *Item) Increase() (result uint64) { +func (this *Item[T]) Increase() (result T) { var currentTime = fasttime.Now().Unix() var currentSpanIndex = this.calculateSpanIndex(currentTime) // return quickly if this.lastUpdateTime == currentTime { - this.spans[currentSpanIndex]++ + if this.spans[currentSpanIndex] < spanMaxValue { + this.spans[currentSpanIndex]++ + } for _, count := range this.spans { result += count } @@ -69,7 +71,9 @@ func (this *Item) Increase() (result uint64) { } } - this.spans[currentSpanIndex]++ + if this.spans[currentSpanIndex] < spanMaxValue { + this.spans[currentSpanIndex]++ + } this.lastUpdateTime = currentTime for _, count := range this.spans { @@ -79,7 +83,7 @@ func (this *Item) Increase() (result uint64) { return } -func (this *Item) Sum() (result uint64) { +func (this *Item[T]) Sum() (result T) { if this.lastUpdateTime == 0 { return 0 } @@ -104,16 +108,16 @@ func (this *Item) Sum() (result uint64) { return result } -func (this *Item) Reset() { +func (this *Item[T]) Reset() { for index := range this.spans { this.spans[index] = 0 } } -func (this *Item) IsExpired(currentTime int64) bool { +func (this *Item[T]) IsExpired(currentTime int64) bool { return this.lastUpdateTime < currentTime-this.lifeSeconds-this.spanSeconds } -func (this *Item) calculateSpanIndex(timestamp int64) int { +func (this *Item[T]) calculateSpanIndex(timestamp int64) int { return int(timestamp % this.lifeSeconds / this.spanSeconds) } diff --git a/internal/utils/counters/item_test.go b/internal/utils/counters/item_test.go index 2a0c1c9..098ed1f 100644 --- a/internal/utils/counters/item_test.go +++ b/internal/utils/counters/item_test.go @@ -17,7 +17,7 @@ func TestItem_Increase(t *testing.T) { return } - var item = counters.NewItem(10) + var item = counters.NewItem[uint32](10) t.Log(item.Increase(), item.Sum()) time.Sleep(1 * time.Second) t.Log(item.Increase(), item.Sum()) @@ -41,7 +41,7 @@ func TestItem_Increase2(t *testing.T) { var a = assert.NewAssertion(t) - var item = counters.NewItem(20) + var item = counters.NewItem[uint32](20) for i := 0; i < 100; i++ { t.Log(item.Increase(), item.Sum(), timeutil.Format("H:i:s")) time.Sleep(2 * time.Second) @@ -58,7 +58,7 @@ func TestItem_IsExpired(t *testing.T) { var currentTime = time.Now().Unix() - var item = counters.NewItem(10) + var item = counters.NewItem[uint32](10) t.Log(item.IsExpired(currentTime)) time.Sleep(10 * time.Second) t.Log(item.IsExpired(currentTime)) @@ -73,7 +73,7 @@ func BenchmarkItem_Increase(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - var item = counters.NewItem(60) + var item = counters.NewItem[uint32](60) item.Increase() item.Sum() } diff --git a/internal/waf/checkpoints/cc2.go b/internal/waf/checkpoints/cc2.go index 915ad8b..03857b3 100644 --- a/internal/waf/checkpoints/cc2.go +++ b/internal/waf/checkpoints/cc2.go @@ -76,7 +76,8 @@ func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, opti } var ccKey = "WAF-CC-" + types.String(ruleId) + "-" + strings.Join(keyValues, "@") - value = counters.SharedCounter.IncreaseKey(ccKey, period) + var ccValue = counters.SharedCounter.IncreaseKey(ccKey, period) + value = ccValue // 基于指纹统计 var enableFingerprint = true @@ -96,7 +97,7 @@ func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, opti } var fpCCKey = "WAF-CC-" + types.String(ruleId) + "-" + strings.Join(fpKeyValues, "@") var fpValue = counters.SharedCounter.IncreaseKey(fpCCKey, period) - if fpValue > value.(uint64) { + if fpValue > ccValue { value = fpValue } }