使用泛型优化计数器内存

This commit is contained in:
刘祥超
2023-11-15 15:57:41 +08:00
parent 768384dcf0
commit 59f27215d3
6 changed files with 59 additions and 50 deletions

View File

@@ -305,7 +305,7 @@ func (this *ClientConn) increaseSYNFlood(synFloodConfig *firewallconfigs.SYNFloo
// 非TLS设置为两倍防止误封
minAttempts = 2 * minAttempts
}
if result >= types.Uint64(minAttempts) {
if result >= types.Uint32(minAttempts) {
var timeout = synFloodConfig.TimeoutSeconds
if timeout <= 0 {
timeout = 600

View File

@@ -13,12 +13,16 @@ import (
const maxItemsPerGroup = 50_000
var SharedCounter = NewCounter().WithGC()
var SharedCounter = NewCounter[uint32]().WithGC()
type Counter struct {
type SupportedUIntType interface {
uint32 | uint64
}
type Counter[T SupportedUIntType] struct {
countMaps uint64
locker *syncutils.RWMutex
itemMaps []map[uint64]*Item
itemMaps []map[uint64]*Item[T]
gcTicker *time.Ticker
gcIndex int
@@ -26,18 +30,18 @@ type Counter struct {
}
// NewCounter create new counter
func NewCounter() *Counter {
func NewCounter[T SupportedUIntType]() *Counter[T] {
var count = utils.SystemMemoryGB() * 8
if count < 8 {
count = 8
}
var itemMaps = []map[uint64]*Item{}
var itemMaps = []map[uint64]*Item[T]{}
for i := 0; i < count; i++ {
itemMaps = append(itemMaps, map[uint64]*Item{})
itemMaps = append(itemMaps, map[uint64]*Item[T]{})
}
var counter = &Counter{
var counter = &Counter[T]{
countMaps: uint64(count),
locker: syncutils.NewRWMutex(count),
itemMaps: itemMaps,
@@ -47,7 +51,7 @@ func NewCounter() *Counter {
}
// WithGC start the counter with gc automatically
func (this *Counter) WithGC() *Counter {
func (this *Counter[T]) WithGC() *Counter[T] {
if this.gcTicker != nil {
return this
}
@@ -62,7 +66,7 @@ func (this *Counter) WithGC() *Counter {
}
// Increase key
func (this *Counter) Increase(key uint64, lifeSeconds int) uint64 {
func (this *Counter[T]) Increase(key uint64, lifeSeconds int) T {
var index = int(key % this.countMaps)
this.locker.RLock(index)
var item = this.itemMaps[index][key]
@@ -70,7 +74,7 @@ func (this *Counter) Increase(key uint64, lifeSeconds int) uint64 {
if item == nil {
// no need to care about duplication
// always insert new item even when itemMap is full
item = NewItem(lifeSeconds)
item = NewItem[T](lifeSeconds)
this.locker.Lock(index)
this.itemMaps[index][key] = item
this.locker.Unlock(index)
@@ -83,12 +87,12 @@ func (this *Counter) Increase(key uint64, lifeSeconds int) uint64 {
}
// IncreaseKey increase string key
func (this *Counter) IncreaseKey(key string, lifeSeconds int) uint64 {
func (this *Counter[T]) IncreaseKey(key string, lifeSeconds int) T {
return this.Increase(this.hash(key), lifeSeconds)
}
// Get value of key
func (this *Counter) Get(key uint64) uint64 {
func (this *Counter[T]) Get(key uint64) T {
var index = int(key % this.countMaps)
this.locker.RLock(index)
defer this.locker.RUnlock(index)
@@ -100,12 +104,12 @@ func (this *Counter) Get(key uint64) uint64 {
}
// GetKey get value of string key
func (this *Counter) GetKey(key string) uint64 {
func (this *Counter[T]) GetKey(key string) T {
return this.Get(this.hash(key))
}
// Reset key
func (this *Counter) Reset(key uint64) {
func (this *Counter[T]) Reset(key uint64) {
var index = int(key % this.countMaps)
this.locker.RLock(index)
var item = this.itemMaps[index][key]
@@ -119,12 +123,12 @@ func (this *Counter) Reset(key uint64) {
}
// ResetKey string key
func (this *Counter) ResetKey(key string) {
func (this *Counter[T]) ResetKey(key string) {
this.Reset(this.hash(key))
}
// TotalItems get items count
func (this *Counter) TotalItems() int {
func (this *Counter[T]) TotalItems() int {
var total = 0
for i := 0; i < int(this.countMaps); i++ {
@@ -137,7 +141,7 @@ func (this *Counter) TotalItems() int {
}
// GC garbage expired items
func (this *Counter) GC() {
func (this *Counter[T]) GC() {
this.gcLocker.Lock()
var gcIndex = this.gcIndex
@@ -186,11 +190,11 @@ func (this *Counter) GC() {
}
}
func (this *Counter) CountMaps() int {
func (this *Counter[T]) CountMaps() int {
return int(this.countMaps)
}
// calculate hash of the key
func (this *Counter) hash(key string) uint64 {
func (this *Counter[T]) hash(key string) uint64 {
return xxhash.Sum64String(key)
}

View File

@@ -19,7 +19,7 @@ import (
func TestCounter_Increase(t *testing.T) {
var a = assert.NewAssertion(t)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
a.IsTrue(counter.Increase(1, 10) == 1)
a.IsTrue(counter.Increase(1, 10) == 2)
a.IsTrue(counter.Increase(2, 10) == 1)
@@ -32,7 +32,7 @@ func TestCounter_Increase(t *testing.T) {
func TestCounter_IncreaseKey(t *testing.T) {
var a = assert.NewAssertion(t)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
a.IsTrue(counter.IncreaseKey("1", 10) == 1)
a.IsTrue(counter.IncreaseKey("1", 10) == 2)
a.IsTrue(counter.IncreaseKey("2", 10) == 1)
@@ -47,7 +47,7 @@ func TestCounter_GC(t *testing.T) {
return
}
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
counter.Increase(1, 20)
time.Sleep(1 * time.Second)
counter.Increase(1, 20)
@@ -61,7 +61,7 @@ func TestCounter_GC2(t *testing.T) {
return
}
var counter = counters.NewCounter().WithGC()
var counter = counters.NewCounter[uint32]().WithGC()
for i := 0; i < 1e5; i++ {
counter.Increase(uint64(i), rands.Int(10, 300))
}
@@ -79,7 +79,7 @@ func TestCounterMemory(t *testing.T) {
var stat = &runtime.MemStats{}
runtime.ReadMemStats(stat)
var counter = counters.NewCounter().WithGC()
var counter = counters.NewCounter[uint32]()
for i := 0; i < 1_000_000; i++ {
counter.Increase(uint64(i), rands.Int(10, 300))
}
@@ -98,7 +98,7 @@ func TestCounterMemory(t *testing.T) {
func BenchmarkCounter_Increase(b *testing.B) {
runtime.GOMAXPROCS(4)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
b.ResetTimer()
var i uint64
@@ -114,7 +114,7 @@ func BenchmarkCounter_Increase(b *testing.B) {
func BenchmarkCounter_IncreaseKey(b *testing.B) {
runtime.GOMAXPROCS(4)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
go func() {
var ticker = time.NewTicker(100 * time.Millisecond)
@@ -138,7 +138,7 @@ func BenchmarkCounter_IncreaseKey(b *testing.B) {
func BenchmarkCounter_IncreaseKey2(b *testing.B) {
runtime.GOMAXPROCS(4)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
go func() {
var ticker = time.NewTicker(1 * time.Millisecond)
@@ -162,7 +162,7 @@ func BenchmarkCounter_IncreaseKey2(b *testing.B) {
func BenchmarkCounter_GC(b *testing.B) {
runtime.GOMAXPROCS(4)
var counter = counters.NewCounter()
var counter = counters.NewCounter[uint32]()
for i := uint64(0); i < 1e5; i++ {
counter.IncreaseKey(types.String(i), 20)

View File

@@ -6,16 +6,16 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
)
type Item struct {
lifeSeconds int64
spanSeconds int64
spans []uint64
const spanMaxValue = 10_000_000
type Item[T SupportedUIntType] struct {
spans []T
lastUpdateTime int64
lifeSeconds int64
spanSeconds int64
}
func NewItem(lifeSeconds int) *Item {
func NewItem[T SupportedUIntType](lifeSeconds int) *Item[T] {
if lifeSeconds <= 0 {
lifeSeconds = 60
}
@@ -25,21 +25,23 @@ func NewItem(lifeSeconds int) *Item {
}
var countSpans = lifeSeconds/spanSeconds + 1 /** prevent index out of bounds **/
return &Item{
return &Item[T]{
lifeSeconds: int64(lifeSeconds),
spanSeconds: int64(spanSeconds),
spans: make([]uint64, countSpans),
spans: make([]T, countSpans),
lastUpdateTime: fasttime.Now().Unix(),
}
}
func (this *Item) Increase() (result uint64) {
func (this *Item[T]) Increase() (result T) {
var currentTime = fasttime.Now().Unix()
var currentSpanIndex = this.calculateSpanIndex(currentTime)
// return quickly
if this.lastUpdateTime == currentTime {
if this.spans[currentSpanIndex] < spanMaxValue {
this.spans[currentSpanIndex]++
}
for _, count := range this.spans {
result += count
}
@@ -69,7 +71,9 @@ func (this *Item) Increase() (result uint64) {
}
}
if this.spans[currentSpanIndex] < spanMaxValue {
this.spans[currentSpanIndex]++
}
this.lastUpdateTime = currentTime
for _, count := range this.spans {
@@ -79,7 +83,7 @@ func (this *Item) Increase() (result uint64) {
return
}
func (this *Item) Sum() (result uint64) {
func (this *Item[T]) Sum() (result T) {
if this.lastUpdateTime == 0 {
return 0
}
@@ -104,16 +108,16 @@ func (this *Item) Sum() (result uint64) {
return result
}
func (this *Item) Reset() {
func (this *Item[T]) Reset() {
for index := range this.spans {
this.spans[index] = 0
}
}
func (this *Item) IsExpired(currentTime int64) bool {
func (this *Item[T]) IsExpired(currentTime int64) bool {
return this.lastUpdateTime < currentTime-this.lifeSeconds-this.spanSeconds
}
func (this *Item) calculateSpanIndex(timestamp int64) int {
func (this *Item[T]) calculateSpanIndex(timestamp int64) int {
return int(timestamp % this.lifeSeconds / this.spanSeconds)
}

View File

@@ -17,7 +17,7 @@ func TestItem_Increase(t *testing.T) {
return
}
var item = counters.NewItem(10)
var item = counters.NewItem[uint32](10)
t.Log(item.Increase(), item.Sum())
time.Sleep(1 * time.Second)
t.Log(item.Increase(), item.Sum())
@@ -41,7 +41,7 @@ func TestItem_Increase2(t *testing.T) {
var a = assert.NewAssertion(t)
var item = counters.NewItem(20)
var item = counters.NewItem[uint32](20)
for i := 0; i < 100; i++ {
t.Log(item.Increase(), item.Sum(), timeutil.Format("H:i:s"))
time.Sleep(2 * time.Second)
@@ -58,7 +58,7 @@ func TestItem_IsExpired(t *testing.T) {
var currentTime = time.Now().Unix()
var item = counters.NewItem(10)
var item = counters.NewItem[uint32](10)
t.Log(item.IsExpired(currentTime))
time.Sleep(10 * time.Second)
t.Log(item.IsExpired(currentTime))
@@ -73,7 +73,7 @@ func BenchmarkItem_Increase(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
var item = counters.NewItem(60)
var item = counters.NewItem[uint32](60)
item.Increase()
item.Sum()
}

View File

@@ -76,7 +76,8 @@ func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, opti
}
var ccKey = "WAF-CC-" + types.String(ruleId) + "-" + strings.Join(keyValues, "@")
value = counters.SharedCounter.IncreaseKey(ccKey, period)
var ccValue = counters.SharedCounter.IncreaseKey(ccKey, period)
value = ccValue
// 基于指纹统计
var enableFingerprint = true
@@ -96,7 +97,7 @@ func (this *CC2Checkpoint) RequestValue(req requests.Request, param string, opti
}
var fpCCKey = "WAF-CC-" + types.String(ruleId) + "-" + strings.Join(fpKeyValues, "@")
var fpValue = counters.SharedCounter.IncreaseKey(fpCCKey, period)
if fpValue > value.(uint64) {
if fpValue > ccValue {
value = fpValue
}
}