mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-07 18:50:27 +08:00
用户端可以添加WAF 黑白名单
This commit is contained in:
@@ -2,6 +2,7 @@ package iplibrary
|
|||||||
|
|
||||||
import "github.com/TeaOSLab/EdgeNode/internal/utils"
|
import "github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||||
|
|
||||||
|
// IP条目
|
||||||
type IPItem struct {
|
type IPItem struct {
|
||||||
Id int64
|
Id int64
|
||||||
IPFrom uint32
|
IPFrom uint32
|
||||||
@@ -9,6 +10,7 @@ type IPItem struct {
|
|||||||
ExpiredAt int64
|
ExpiredAt int64
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 检查是否包含某个IP
|
||||||
func (this *IPItem) Contains(ip uint32) bool {
|
func (this *IPItem) Contains(ip uint32) bool {
|
||||||
if this.IPTo == 0 {
|
if this.IPTo == 0 {
|
||||||
if this.IPFrom != ip {
|
if this.IPFrom != ip {
|
||||||
|
|||||||
@@ -1,45 +1,152 @@
|
|||||||
package iplibrary
|
package iplibrary
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
|
||||||
"sync"
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IP名单
|
// IP名单
|
||||||
type IPList struct {
|
type IPList struct {
|
||||||
itemsMap map[int64]*IPItem // id => item
|
itemsMap map[int64]*IPItem // id => item
|
||||||
|
ipMap map[uint32][]int64 // ip => itemIds
|
||||||
|
expireList *expires.List
|
||||||
|
|
||||||
locker sync.RWMutex
|
locker sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewIPList() *IPList {
|
func NewIPList() *IPList {
|
||||||
return &IPList{
|
list := &IPList{
|
||||||
itemsMap: map[int64]*IPItem{},
|
itemsMap: map[int64]*IPItem{},
|
||||||
|
ipMap: map[uint32][]int64{},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
expireList := expires.NewList()
|
||||||
|
go func() {
|
||||||
|
expireList.StartGC(func(itemId int64) {
|
||||||
|
list.Delete(itemId)
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
list.expireList = expireList
|
||||||
|
return list
|
||||||
}
|
}
|
||||||
|
|
||||||
func (this *IPList) Add(item *IPItem) {
|
func (this *IPList) Add(item *IPItem) {
|
||||||
|
if item == nil || (item.IPFrom == 0 && item.IPTo == 0) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
this.locker.Lock()
|
this.locker.Lock()
|
||||||
|
|
||||||
|
// 是否已经存在
|
||||||
|
_, ok := this.itemsMap[item.Id]
|
||||||
|
if ok {
|
||||||
|
this.deleteItem(item.Id)
|
||||||
|
}
|
||||||
|
|
||||||
this.itemsMap[item.Id] = item
|
this.itemsMap[item.Id] = item
|
||||||
|
|
||||||
|
// 展开
|
||||||
|
if item.IPFrom > 0 {
|
||||||
|
if item.IPTo == 0 {
|
||||||
|
this.addIP(item.IPFrom, item.Id)
|
||||||
|
} else {
|
||||||
|
if item.IPFrom > item.IPTo {
|
||||||
|
item.IPTo, item.IPFrom = item.IPFrom, item.IPTo
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := item.IPFrom; i <= item.IPTo; i++ {
|
||||||
|
// 最多不能超过65535,防止整个系统内存爆掉
|
||||||
|
if i >= item.IPFrom+65535 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
this.addIP(i, item.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if item.IPTo > 0 {
|
||||||
|
this.addIP(item.IPTo, item.Id)
|
||||||
|
}
|
||||||
|
|
||||||
|
if item.ExpiredAt > 0 {
|
||||||
|
this.expireList.Add(item.Id, item.ExpiredAt)
|
||||||
|
}
|
||||||
|
|
||||||
this.locker.Unlock()
|
this.locker.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (this *IPList) Delete(itemId int64) {
|
func (this *IPList) Delete(itemId int64) {
|
||||||
this.locker.Lock()
|
this.locker.Lock()
|
||||||
delete(this.itemsMap, itemId)
|
defer this.locker.Unlock()
|
||||||
this.locker.Unlock()
|
this.deleteItem(itemId)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 判断是否包含某个IP
|
// 判断是否包含某个IP
|
||||||
func (this *IPList) Contains(ip uint32) bool {
|
func (this *IPList) Contains(ip uint32) bool {
|
||||||
// TODO 优化查询速度,可能需要把items分成两组,一组是单个的,一组是按照范围的,按照范围的再进行二分法查找
|
|
||||||
this.locker.RLock()
|
this.locker.RLock()
|
||||||
for _, item := range this.itemsMap {
|
_, ok := this.ipMap[ip]
|
||||||
if item.Contains(ip) {
|
|
||||||
this.locker.RUnlock()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
this.locker.RUnlock()
|
this.locker.RUnlock()
|
||||||
|
|
||||||
return false
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// 在不加锁的情况下删除某个Item
|
||||||
|
// 将会被别的方法引用,切记不能加锁
|
||||||
|
func (this *IPList) deleteItem(itemId int64) {
|
||||||
|
item, ok := this.itemsMap[itemId]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(this.itemsMap, itemId)
|
||||||
|
|
||||||
|
// 展开
|
||||||
|
if item.IPFrom > 0 {
|
||||||
|
if item.IPTo == 0 {
|
||||||
|
this.deleteIP(item.IPFrom, item.Id)
|
||||||
|
} else {
|
||||||
|
if item.IPFrom > item.IPTo {
|
||||||
|
item.IPTo, item.IPFrom = item.IPFrom, item.IPTo
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := item.IPFrom; i <= item.IPTo; i++ {
|
||||||
|
// 最多不能超过65535,防止整个系统内存爆掉
|
||||||
|
if i >= item.IPFrom+65535 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
this.deleteIP(i, item.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if item.IPTo > 0 {
|
||||||
|
this.deleteIP(item.IPTo, item.Id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加单个IP
|
||||||
|
func (this *IPList) addIP(ip uint32, itemId int64) {
|
||||||
|
itemIds, ok := this.ipMap[ip]
|
||||||
|
if ok {
|
||||||
|
itemIds = append(itemIds, itemId)
|
||||||
|
} else {
|
||||||
|
itemIds = []int64{itemId}
|
||||||
|
}
|
||||||
|
this.ipMap[ip] = itemIds
|
||||||
|
}
|
||||||
|
|
||||||
|
// 删除单个IP
|
||||||
|
func (this *IPList) deleteIP(ip uint32, itemId int64) {
|
||||||
|
itemIds, ok := this.ipMap[ip]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
newItemIds := []int64{}
|
||||||
|
for _, oldItemId := range itemIds {
|
||||||
|
if oldItemId == itemId {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newItemIds = append(newItemIds, oldItemId)
|
||||||
|
}
|
||||||
|
if len(newItemIds) > 0 {
|
||||||
|
this.ipMap[ip] = newItemIds
|
||||||
|
} else {
|
||||||
|
delete(this.ipMap, ip)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,84 @@
|
|||||||
package iplibrary
|
package iplibrary
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/iwind/TeaGo/assert"
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func TestIPList_Add_Empty(t *testing.T) {
|
||||||
|
ipList := NewIPList()
|
||||||
|
ipList.Add(&IPItem{
|
||||||
|
Id: 1,
|
||||||
|
})
|
||||||
|
logs.PrintAsJSON(ipList.itemsMap, t)
|
||||||
|
logs.PrintAsJSON(ipList.ipMap, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPList_Add_One(t *testing.T) {
|
||||||
|
ipList := NewIPList()
|
||||||
|
ipList.Add(&IPItem{
|
||||||
|
Id: 1,
|
||||||
|
IPFrom: IP2Long("192.168.1.1"),
|
||||||
|
})
|
||||||
|
ipList.Add(&IPItem{
|
||||||
|
Id: 2,
|
||||||
|
IPTo: IP2Long("192.168.1.2"),
|
||||||
|
})
|
||||||
|
logs.PrintAsJSON(ipList.itemsMap, t)
|
||||||
|
logs.PrintAsJSON(ipList.ipMap, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPList_Update(t *testing.T) {
|
||||||
|
ipList := NewIPList()
|
||||||
|
ipList.Add(&IPItem{
|
||||||
|
Id: 1,
|
||||||
|
IPFrom: IP2Long("192.168.1.1"),
|
||||||
|
})
|
||||||
|
/**ipList.Add(&IPItem{
|
||||||
|
Id: 2,
|
||||||
|
IPFrom: IP2Long("192.168.1.1"),
|
||||||
|
})**/
|
||||||
|
ipList.Add(&IPItem{
|
||||||
|
Id: 1,
|
||||||
|
IPTo: IP2Long("192.168.1.2"),
|
||||||
|
})
|
||||||
|
logs.PrintAsJSON(ipList.itemsMap, t)
|
||||||
|
logs.PrintAsJSON(ipList.ipMap, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPList_Add_Range(t *testing.T) {
|
||||||
|
ipList := NewIPList()
|
||||||
|
ipList.Add(&IPItem{
|
||||||
|
Id: 1,
|
||||||
|
IPFrom: IP2Long("192.168.1.1"),
|
||||||
|
IPTo: IP2Long("192.168.2.1"),
|
||||||
|
})
|
||||||
|
ipList.Add(&IPItem{
|
||||||
|
Id: 2,
|
||||||
|
IPTo: IP2Long("192.168.1.2"),
|
||||||
|
})
|
||||||
|
t.Log(len(ipList.ipMap), "ips")
|
||||||
|
logs.PrintAsJSON(ipList.itemsMap, t)
|
||||||
|
logs.PrintAsJSON(ipList.ipMap, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPList_Add_Overflow(t *testing.T) {
|
||||||
|
a := assert.NewAssertion(t)
|
||||||
|
|
||||||
|
ipList := NewIPList()
|
||||||
|
ipList.Add(&IPItem{
|
||||||
|
Id: 1,
|
||||||
|
IPFrom: IP2Long("192.168.1.1"),
|
||||||
|
IPTo: IP2Long("192.169.255.1"),
|
||||||
|
})
|
||||||
|
t.Log(len(ipList.ipMap), "ips")
|
||||||
|
a.IsTrue(len(ipList.ipMap) <= 65535)
|
||||||
|
}
|
||||||
|
|
||||||
func TestNewIPList_Memory(t *testing.T) {
|
func TestNewIPList_Memory(t *testing.T) {
|
||||||
list := NewIPList()
|
list := NewIPList()
|
||||||
|
|
||||||
@@ -26,27 +98,78 @@ func TestIPList_Contains(t *testing.T) {
|
|||||||
for i := 0; i < 255; i++ {
|
for i := 0; i < 255; i++ {
|
||||||
list.Add(&IPItem{
|
list.Add(&IPItem{
|
||||||
Id: int64(i),
|
Id: int64(i),
|
||||||
IPFrom: IP2Long("192.168.1." + strconv.Itoa(i)),
|
IPFrom: IP2Long(strconv.Itoa(i) + ".168.0.1"),
|
||||||
IPTo: 0,
|
IPTo: IP2Long(strconv.Itoa(i) + ".168.255.1"),
|
||||||
ExpiredAt: 0,
|
ExpiredAt: 0,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
t.Log(len(list.ipMap))
|
||||||
|
|
||||||
|
before := time.Now()
|
||||||
t.Log(list.Contains(IP2Long("192.168.1.100")))
|
t.Log(list.Contains(IP2Long("192.168.1.100")))
|
||||||
t.Log(list.Contains(IP2Long("192.168.2.100")))
|
t.Log(list.Contains(IP2Long("192.168.2.100")))
|
||||||
|
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIPList_Delete(t *testing.T) {
|
||||||
|
list := NewIPList()
|
||||||
|
list.Add(&IPItem{
|
||||||
|
Id: 1,
|
||||||
|
IPFrom: IP2Long("192.168.0.1"),
|
||||||
|
ExpiredAt: 0,
|
||||||
|
})
|
||||||
|
list.Add(&IPItem{
|
||||||
|
Id: 2,
|
||||||
|
IPFrom: IP2Long("192.168.0.1"),
|
||||||
|
ExpiredAt: 0,
|
||||||
|
})
|
||||||
|
t.Log("===BEFORE===")
|
||||||
|
logs.PrintAsJSON(list.itemsMap, t)
|
||||||
|
logs.PrintAsJSON(list.ipMap, t)
|
||||||
|
|
||||||
|
list.Delete(1)
|
||||||
|
|
||||||
|
t.Log("===AFTER===")
|
||||||
|
logs.PrintAsJSON(list.itemsMap, t)
|
||||||
|
logs.PrintAsJSON(list.ipMap, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGC(t *testing.T) {
|
||||||
|
list := NewIPList()
|
||||||
|
list.Add(&IPItem{
|
||||||
|
Id: 1,
|
||||||
|
IPFrom: IP2Long("192.168.1.100"),
|
||||||
|
IPTo: IP2Long("192.168.1.101"),
|
||||||
|
ExpiredAt: time.Now().Unix() + 1,
|
||||||
|
})
|
||||||
|
list.Add(&IPItem{
|
||||||
|
Id: 2,
|
||||||
|
IPFrom: IP2Long("192.168.1.102"),
|
||||||
|
IPTo: IP2Long("192.168.1.103"),
|
||||||
|
ExpiredAt: 0,
|
||||||
|
})
|
||||||
|
logs.PrintAsJSON(list.itemsMap, t)
|
||||||
|
logs.PrintAsJSON(list.ipMap, t)
|
||||||
|
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
t.Log("===AFTER GC===")
|
||||||
|
logs.PrintAsJSON(list.itemsMap, t)
|
||||||
|
logs.PrintAsJSON(list.ipMap, t)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkIPList_Contains(b *testing.B) {
|
func BenchmarkIPList_Contains(b *testing.B) {
|
||||||
runtime.GOMAXPROCS(1)
|
runtime.GOMAXPROCS(1)
|
||||||
|
|
||||||
list := NewIPList()
|
list := NewIPList()
|
||||||
for i := 0; i < 10_000; i++ {
|
for i := 192; i < 194; i++ {
|
||||||
list.Add(&IPItem{
|
list.Add(&IPItem{
|
||||||
Id: int64(i),
|
Id: int64(1),
|
||||||
IPFrom: IP2Long("192.168.1." + strconv.Itoa(i)),
|
IPFrom: IP2Long(strconv.Itoa(i) + ".1.0.1"),
|
||||||
IPTo: 0,
|
IPTo: IP2Long(strconv.Itoa(i) + ".2.0.1"),
|
||||||
ExpiredAt: time.Now().Unix() + 60,
|
ExpiredAt: time.Now().Unix() + 60,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
b.Log(len(list.ipMap), "ip")
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
_ = list.Contains(IP2Long("192.168.1.100"))
|
_ = list.Contains(IP2Long("192.168.1.100"))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ type HTTPRequest struct {
|
|||||||
func (this *HTTPRequest) init() {
|
func (this *HTTPRequest) init() {
|
||||||
this.writer = NewHTTPWriter(this, this.RawWriter)
|
this.writer = NewHTTPWriter(this, this.RawWriter)
|
||||||
this.web = &serverconfigs.HTTPWebConfig{IsOn: true}
|
this.web = &serverconfigs.HTTPWebConfig{IsOn: true}
|
||||||
//this.uri = this.RawReq.URL.RequestURI()
|
// this.uri = this.RawReq.URL.RequestURI()
|
||||||
// 之所以不使用RequestURI(),是不想让URL中的Path被Encode
|
// 之所以不使用RequestURI(),是不想让URL中的Path被Encode
|
||||||
if len(this.RawReq.URL.RawQuery) > 0 {
|
if len(this.RawReq.URL.RawQuery) > 0 {
|
||||||
this.uri = this.RawReq.URL.Path + "?" + this.RawReq.URL.RawQuery
|
this.uri = this.RawReq.URL.Path + "?" + this.RawReq.URL.RawQuery
|
||||||
@@ -82,7 +82,6 @@ func (this *HTTPRequest) init() {
|
|||||||
this.uri = this.RawReq.URL.Path
|
this.uri = this.RawReq.URL.Path
|
||||||
}
|
}
|
||||||
|
|
||||||
this.uri = this.RawReq.URL.Path
|
|
||||||
this.rawURI = this.uri
|
this.rawURI = this.uri
|
||||||
this.varMapping = map[string]string{
|
this.varMapping = map[string]string{
|
||||||
// 缓存相关初始化
|
// 缓存相关初始化
|
||||||
@@ -300,6 +299,9 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo
|
|||||||
// waf
|
// waf
|
||||||
if web.FirewallRef != nil && (web.FirewallRef.IsPrior || isTop) {
|
if web.FirewallRef != nil && (web.FirewallRef.IsPrior || isTop) {
|
||||||
this.web.FirewallRef = web.FirewallRef
|
this.web.FirewallRef = web.FirewallRef
|
||||||
|
if web.FirewallPolicy != nil {
|
||||||
|
this.web.FirewallPolicy = web.FirewallPolicy
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// access log
|
// access log
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package nodes
|
package nodes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||||
@@ -11,8 +12,26 @@ import (
|
|||||||
|
|
||||||
// 调用WAF
|
// 调用WAF
|
||||||
func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
||||||
firewallPolicy := sharedNodeConfig.HTTPFirewallPolicy
|
// 当前服务的独立设置
|
||||||
|
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
|
||||||
|
blocked = this.checkWAFRequest(this.web.FirewallPolicy)
|
||||||
|
if blocked {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 公用的防火墙设置
|
||||||
|
if sharedNodeConfig.HTTPFirewallPolicy != nil {
|
||||||
|
blocked = this.checkWAFRequest(sharedNodeConfig.HTTPFirewallPolicy)
|
||||||
|
if blocked {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFirewallPolicy) (blocked bool) {
|
||||||
// 检查配置是否为空
|
// 检查配置是否为空
|
||||||
if firewallPolicy == nil || !firewallPolicy.IsOn || firewallPolicy.Inbound == nil || !firewallPolicy.Inbound.IsOn {
|
if firewallPolicy == nil || !firewallPolicy.IsOn || firewallPolicy.Inbound == nil || !firewallPolicy.Inbound.IsOn {
|
||||||
return
|
return
|
||||||
@@ -21,16 +40,16 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
|||||||
// 检查IP白名单
|
// 检查IP白名单
|
||||||
remoteAddr := this.requestRemoteAddr()
|
remoteAddr := this.requestRemoteAddr()
|
||||||
inbound := firewallPolicy.Inbound
|
inbound := firewallPolicy.Inbound
|
||||||
if inbound.WhiteListRef != nil && inbound.WhiteListRef.IsOn && inbound.WhiteListRef.ListId > 0 {
|
if inbound.AllowListRef != nil && inbound.AllowListRef.IsOn && inbound.AllowListRef.ListId > 0 {
|
||||||
list := iplibrary.SharedIPListManager.FindList(inbound.WhiteListRef.ListId)
|
list := iplibrary.SharedIPListManager.FindList(inbound.AllowListRef.ListId)
|
||||||
if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) {
|
if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查IP黑名单
|
// 检查IP黑名单
|
||||||
if inbound.BlackListRef != nil && inbound.BlackListRef.IsOn && inbound.BlackListRef.ListId > 0 {
|
if inbound.DenyListRef != nil && inbound.DenyListRef.IsOn && inbound.DenyListRef.ListId > 0 {
|
||||||
list := iplibrary.SharedIPListManager.FindList(inbound.BlackListRef.ListId)
|
list := iplibrary.SharedIPListManager.FindList(inbound.DenyListRef.ListId)
|
||||||
if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) {
|
if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) {
|
||||||
// TODO 可以配置对封禁的处理方式等
|
// TODO 可以配置对封禁的处理方式等
|
||||||
this.writer.WriteHeader(http.StatusForbidden)
|
this.writer.WriteHeader(http.StatusForbidden)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
// +build windows
|
// +build windows
|
||||||
|
|
||||||
package agent
|
package nodes
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|||||||
114
internal/utils/expires/list.go
Normal file
114
internal/utils/expires/list.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package expires
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ItemMap = map[int64]bool
|
||||||
|
|
||||||
|
type List struct {
|
||||||
|
expireMap map[int64]ItemMap // expires timestamp => map[id]bool
|
||||||
|
itemsMap map[int64]int64 // itemId => timestamp
|
||||||
|
|
||||||
|
locker sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewList() *List {
|
||||||
|
return &List{
|
||||||
|
expireMap: map[int64]ItemMap{},
|
||||||
|
itemsMap: map[int64]int64{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) Add(itemId int64, expiredAt int64) {
|
||||||
|
if expiredAt <= time.Now().Unix() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
this.locker.Lock()
|
||||||
|
defer this.locker.Unlock()
|
||||||
|
|
||||||
|
// 是否已经存在
|
||||||
|
_, ok := this.itemsMap[itemId]
|
||||||
|
if ok {
|
||||||
|
this.removeItem(itemId)
|
||||||
|
}
|
||||||
|
|
||||||
|
expireItemMap, ok := this.expireMap[expiredAt]
|
||||||
|
if ok {
|
||||||
|
expireItemMap[itemId] = true
|
||||||
|
} else {
|
||||||
|
expireItemMap = ItemMap{
|
||||||
|
itemId: true,
|
||||||
|
}
|
||||||
|
this.expireMap[expiredAt] = expireItemMap
|
||||||
|
}
|
||||||
|
|
||||||
|
this.itemsMap[itemId] = expiredAt
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) Remove(itemId int64) {
|
||||||
|
this.locker.Lock()
|
||||||
|
defer this.locker.Unlock()
|
||||||
|
this.removeItem(itemId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) GC(timestamp int64, callback func(itemId int64)) {
|
||||||
|
this.locker.Lock()
|
||||||
|
itemMap := this.gcItems(timestamp)
|
||||||
|
this.locker.Unlock()
|
||||||
|
|
||||||
|
for itemId := range itemMap {
|
||||||
|
callback(itemId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) StartGC(callback func(itemId int64)) {
|
||||||
|
ticker := time.NewTicker(1 * time.Second)
|
||||||
|
lastTimestamp := int64(0)
|
||||||
|
for range ticker.C {
|
||||||
|
timestamp := time.Now().Unix()
|
||||||
|
if lastTimestamp == 0 {
|
||||||
|
lastTimestamp = timestamp - 3600
|
||||||
|
}
|
||||||
|
|
||||||
|
// 防止死循环
|
||||||
|
if lastTimestamp > timestamp {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := lastTimestamp; i <= timestamp; i++ {
|
||||||
|
this.GC(timestamp, callback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 这样做是为了防止系统时钟突变
|
||||||
|
lastTimestamp = timestamp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) removeItem(itemId int64) {
|
||||||
|
expiresAt, ok := this.itemsMap[itemId]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(this.itemsMap, itemId)
|
||||||
|
|
||||||
|
expireItemMap, ok := this.expireMap[expiresAt]
|
||||||
|
if ok {
|
||||||
|
delete(expireItemMap, itemId)
|
||||||
|
if len(expireItemMap) == 0 {
|
||||||
|
delete(this.expireMap, expiresAt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *List) gcItems(timestamp int64) ItemMap {
|
||||||
|
expireItemsMap, ok := this.expireMap[timestamp]
|
||||||
|
if ok {
|
||||||
|
for itemId := range expireItemsMap {
|
||||||
|
delete(this.itemsMap, itemId)
|
||||||
|
}
|
||||||
|
delete(this.expireMap, timestamp)
|
||||||
|
}
|
||||||
|
return expireItemsMap
|
||||||
|
}
|
||||||
115
internal/utils/expires/list_test.go
Normal file
115
internal/utils/expires/list_test.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package expires
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/iwind/TeaGo/logs"
|
||||||
|
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||||
|
"math"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestList_Add(t *testing.T) {
|
||||||
|
list := NewList()
|
||||||
|
list.Add(1, time.Now().Unix())
|
||||||
|
t.Log("===BEFORE===")
|
||||||
|
logs.PrintAsJSON(list.expireMap, t)
|
||||||
|
logs.PrintAsJSON(list.itemsMap, t)
|
||||||
|
|
||||||
|
list.Add(1, time.Now().Unix()+1)
|
||||||
|
list.Add(2, time.Now().Unix()+1)
|
||||||
|
list.Add(3, time.Now().Unix()+2)
|
||||||
|
t.Log("===AFTER===")
|
||||||
|
logs.PrintAsJSON(list.expireMap, t)
|
||||||
|
logs.PrintAsJSON(list.itemsMap, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestList_Add_Overwrite(t *testing.T) {
|
||||||
|
list := NewList()
|
||||||
|
list.Add(1, time.Now().Unix()+1)
|
||||||
|
list.Add(1, time.Now().Unix()+1)
|
||||||
|
list.Add(1, time.Now().Unix()+2)
|
||||||
|
logs.PrintAsJSON(list.expireMap, t)
|
||||||
|
logs.PrintAsJSON(list.itemsMap, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestList_Remove(t *testing.T) {
|
||||||
|
list := NewList()
|
||||||
|
list.Add(1, time.Now().Unix()+1)
|
||||||
|
list.Remove(1)
|
||||||
|
logs.PrintAsJSON(list.expireMap, t)
|
||||||
|
logs.PrintAsJSON(list.itemsMap, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestList_GC(t *testing.T) {
|
||||||
|
list := NewList()
|
||||||
|
list.Add(1, time.Now().Unix()+1)
|
||||||
|
list.Add(2, time.Now().Unix()+1)
|
||||||
|
list.Add(3, time.Now().Unix()+2)
|
||||||
|
list.GC(time.Now().Unix()+2, func(itemId int64) {
|
||||||
|
t.Log("gc:", itemId)
|
||||||
|
})
|
||||||
|
logs.PrintAsJSON(list.expireMap, t)
|
||||||
|
logs.PrintAsJSON(list.itemsMap, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestList_Start_GC(t *testing.T) {
|
||||||
|
list := NewList()
|
||||||
|
list.Add(1, time.Now().Unix()+1)
|
||||||
|
list.Add(2, time.Now().Unix()+1)
|
||||||
|
list.Add(3, time.Now().Unix()+2)
|
||||||
|
list.Add(4, time.Now().Unix()+5)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
list.StartGC(func(itemId int64) {
|
||||||
|
t.Log("gc:", itemId, timeutil.Format("H:i:s"))
|
||||||
|
time.Sleep(2 * time.Second)
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestList_ManyItems(t *testing.T) {
|
||||||
|
list := NewList()
|
||||||
|
for i := 0; i < 100_000; i++ {
|
||||||
|
list.Add(int64(i), time.Now().Unix()+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
count := 0
|
||||||
|
list.GC(time.Now().Unix()+1, func(itemId int64) {
|
||||||
|
count++
|
||||||
|
})
|
||||||
|
t.Log("gc", count, "items")
|
||||||
|
t.Log(time.Since(now).Seconds()*1000, "ms")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestList_Map_Performance(t *testing.T) {
|
||||||
|
t.Log("max uint32", math.MaxUint32)
|
||||||
|
|
||||||
|
{
|
||||||
|
m := map[int64]int64{}
|
||||||
|
for i := 0; i < 1_000_000; i++ {
|
||||||
|
m[int64(i)] = time.Now().Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < 100_000; i++ {
|
||||||
|
delete(m, int64(i))
|
||||||
|
}
|
||||||
|
t.Log(time.Since(now).Seconds()*1000, "ms")
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
m := map[uint32]int64{}
|
||||||
|
for i := 0; i < 1_000_000; i++ {
|
||||||
|
m[uint32(i)] = time.Now().Unix()
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
for i := 0; i < 100_000; i++ {
|
||||||
|
delete(m, uint32(i))
|
||||||
|
}
|
||||||
|
t.Log(time.Since(now).Seconds()*1000, "ms")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user