mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2026-02-19 20:15:37 +08:00
用户端可以添加WAF 黑白名单
This commit is contained in:
@@ -2,6 +2,7 @@ package iplibrary
|
||||
|
||||
import "github.com/TeaOSLab/EdgeNode/internal/utils"
|
||||
|
||||
// IP条目
|
||||
type IPItem struct {
|
||||
Id int64
|
||||
IPFrom uint32
|
||||
@@ -9,6 +10,7 @@ type IPItem struct {
|
||||
ExpiredAt int64
|
||||
}
|
||||
|
||||
// 检查是否包含某个IP
|
||||
func (this *IPItem) Contains(ip uint32) bool {
|
||||
if this.IPTo == 0 {
|
||||
if this.IPFrom != ip {
|
||||
|
||||
@@ -1,45 +1,152 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// IP名单
|
||||
type IPList struct {
|
||||
itemsMap map[int64]*IPItem // id => item
|
||||
itemsMap map[int64]*IPItem // id => item
|
||||
ipMap map[uint32][]int64 // ip => itemIds
|
||||
expireList *expires.List
|
||||
|
||||
locker sync.RWMutex
|
||||
}
|
||||
|
||||
func NewIPList() *IPList {
|
||||
return &IPList{
|
||||
list := &IPList{
|
||||
itemsMap: map[int64]*IPItem{},
|
||||
ipMap: map[uint32][]int64{},
|
||||
}
|
||||
|
||||
expireList := expires.NewList()
|
||||
go func() {
|
||||
expireList.StartGC(func(itemId int64) {
|
||||
list.Delete(itemId)
|
||||
})
|
||||
}()
|
||||
list.expireList = expireList
|
||||
return list
|
||||
}
|
||||
|
||||
func (this *IPList) Add(item *IPItem) {
|
||||
if item == nil || (item.IPFrom == 0 && item.IPTo == 0) {
|
||||
return
|
||||
}
|
||||
|
||||
this.locker.Lock()
|
||||
|
||||
// 是否已经存在
|
||||
_, ok := this.itemsMap[item.Id]
|
||||
if ok {
|
||||
this.deleteItem(item.Id)
|
||||
}
|
||||
|
||||
this.itemsMap[item.Id] = item
|
||||
|
||||
// 展开
|
||||
if item.IPFrom > 0 {
|
||||
if item.IPTo == 0 {
|
||||
this.addIP(item.IPFrom, item.Id)
|
||||
} else {
|
||||
if item.IPFrom > item.IPTo {
|
||||
item.IPTo, item.IPFrom = item.IPFrom, item.IPTo
|
||||
}
|
||||
|
||||
for i := item.IPFrom; i <= item.IPTo; i++ {
|
||||
// 最多不能超过65535,防止整个系统内存爆掉
|
||||
if i >= item.IPFrom+65535 {
|
||||
break
|
||||
}
|
||||
this.addIP(i, item.Id)
|
||||
}
|
||||
}
|
||||
} else if item.IPTo > 0 {
|
||||
this.addIP(item.IPTo, item.Id)
|
||||
}
|
||||
|
||||
if item.ExpiredAt > 0 {
|
||||
this.expireList.Add(item.Id, item.ExpiredAt)
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
func (this *IPList) Delete(itemId int64) {
|
||||
this.locker.Lock()
|
||||
delete(this.itemsMap, itemId)
|
||||
this.locker.Unlock()
|
||||
defer this.locker.Unlock()
|
||||
this.deleteItem(itemId)
|
||||
}
|
||||
|
||||
// 判断是否包含某个IP
|
||||
func (this *IPList) Contains(ip uint32) bool {
|
||||
// TODO 优化查询速度,可能需要把items分成两组,一组是单个的,一组是按照范围的,按照范围的再进行二分法查找
|
||||
this.locker.RLock()
|
||||
for _, item := range this.itemsMap {
|
||||
if item.Contains(ip) {
|
||||
this.locker.RUnlock()
|
||||
return true
|
||||
}
|
||||
}
|
||||
_, ok := this.ipMap[ip]
|
||||
this.locker.RUnlock()
|
||||
|
||||
return false
|
||||
return ok
|
||||
}
|
||||
|
||||
// 在不加锁的情况下删除某个Item
|
||||
// 将会被别的方法引用,切记不能加锁
|
||||
func (this *IPList) deleteItem(itemId int64) {
|
||||
item, ok := this.itemsMap[itemId]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
delete(this.itemsMap, itemId)
|
||||
|
||||
// 展开
|
||||
if item.IPFrom > 0 {
|
||||
if item.IPTo == 0 {
|
||||
this.deleteIP(item.IPFrom, item.Id)
|
||||
} else {
|
||||
if item.IPFrom > item.IPTo {
|
||||
item.IPTo, item.IPFrom = item.IPFrom, item.IPTo
|
||||
}
|
||||
|
||||
for i := item.IPFrom; i <= item.IPTo; i++ {
|
||||
// 最多不能超过65535,防止整个系统内存爆掉
|
||||
if i >= item.IPFrom+65535 {
|
||||
break
|
||||
}
|
||||
this.deleteIP(i, item.Id)
|
||||
}
|
||||
}
|
||||
} else if item.IPTo > 0 {
|
||||
this.deleteIP(item.IPTo, item.Id)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加单个IP
|
||||
func (this *IPList) addIP(ip uint32, itemId int64) {
|
||||
itemIds, ok := this.ipMap[ip]
|
||||
if ok {
|
||||
itemIds = append(itemIds, itemId)
|
||||
} else {
|
||||
itemIds = []int64{itemId}
|
||||
}
|
||||
this.ipMap[ip] = itemIds
|
||||
}
|
||||
|
||||
// 删除单个IP
|
||||
func (this *IPList) deleteIP(ip uint32, itemId int64) {
|
||||
itemIds, ok := this.ipMap[ip]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
newItemIds := []int64{}
|
||||
for _, oldItemId := range itemIds {
|
||||
if oldItemId == itemId {
|
||||
continue
|
||||
}
|
||||
newItemIds = append(newItemIds, oldItemId)
|
||||
}
|
||||
if len(newItemIds) > 0 {
|
||||
this.ipMap[ip] = newItemIds
|
||||
} else {
|
||||
delete(this.ipMap, ip)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,12 +1,84 @@
|
||||
package iplibrary
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIPList_Add_Empty(t *testing.T) {
|
||||
ipList := NewIPList()
|
||||
ipList.Add(&IPItem{
|
||||
Id: 1,
|
||||
})
|
||||
logs.PrintAsJSON(ipList.itemsMap, t)
|
||||
logs.PrintAsJSON(ipList.ipMap, t)
|
||||
}
|
||||
|
||||
func TestIPList_Add_One(t *testing.T) {
|
||||
ipList := NewIPList()
|
||||
ipList.Add(&IPItem{
|
||||
Id: 1,
|
||||
IPFrom: IP2Long("192.168.1.1"),
|
||||
})
|
||||
ipList.Add(&IPItem{
|
||||
Id: 2,
|
||||
IPTo: IP2Long("192.168.1.2"),
|
||||
})
|
||||
logs.PrintAsJSON(ipList.itemsMap, t)
|
||||
logs.PrintAsJSON(ipList.ipMap, t)
|
||||
}
|
||||
|
||||
func TestIPList_Update(t *testing.T) {
|
||||
ipList := NewIPList()
|
||||
ipList.Add(&IPItem{
|
||||
Id: 1,
|
||||
IPFrom: IP2Long("192.168.1.1"),
|
||||
})
|
||||
/**ipList.Add(&IPItem{
|
||||
Id: 2,
|
||||
IPFrom: IP2Long("192.168.1.1"),
|
||||
})**/
|
||||
ipList.Add(&IPItem{
|
||||
Id: 1,
|
||||
IPTo: IP2Long("192.168.1.2"),
|
||||
})
|
||||
logs.PrintAsJSON(ipList.itemsMap, t)
|
||||
logs.PrintAsJSON(ipList.ipMap, t)
|
||||
}
|
||||
|
||||
func TestIPList_Add_Range(t *testing.T) {
|
||||
ipList := NewIPList()
|
||||
ipList.Add(&IPItem{
|
||||
Id: 1,
|
||||
IPFrom: IP2Long("192.168.1.1"),
|
||||
IPTo: IP2Long("192.168.2.1"),
|
||||
})
|
||||
ipList.Add(&IPItem{
|
||||
Id: 2,
|
||||
IPTo: IP2Long("192.168.1.2"),
|
||||
})
|
||||
t.Log(len(ipList.ipMap), "ips")
|
||||
logs.PrintAsJSON(ipList.itemsMap, t)
|
||||
logs.PrintAsJSON(ipList.ipMap, t)
|
||||
}
|
||||
|
||||
func TestIPList_Add_Overflow(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
|
||||
ipList := NewIPList()
|
||||
ipList.Add(&IPItem{
|
||||
Id: 1,
|
||||
IPFrom: IP2Long("192.168.1.1"),
|
||||
IPTo: IP2Long("192.169.255.1"),
|
||||
})
|
||||
t.Log(len(ipList.ipMap), "ips")
|
||||
a.IsTrue(len(ipList.ipMap) <= 65535)
|
||||
}
|
||||
|
||||
func TestNewIPList_Memory(t *testing.T) {
|
||||
list := NewIPList()
|
||||
|
||||
@@ -26,27 +98,78 @@ func TestIPList_Contains(t *testing.T) {
|
||||
for i := 0; i < 255; i++ {
|
||||
list.Add(&IPItem{
|
||||
Id: int64(i),
|
||||
IPFrom: IP2Long("192.168.1." + strconv.Itoa(i)),
|
||||
IPTo: 0,
|
||||
IPFrom: IP2Long(strconv.Itoa(i) + ".168.0.1"),
|
||||
IPTo: IP2Long(strconv.Itoa(i) + ".168.255.1"),
|
||||
ExpiredAt: 0,
|
||||
})
|
||||
}
|
||||
t.Log(len(list.ipMap))
|
||||
|
||||
before := time.Now()
|
||||
t.Log(list.Contains(IP2Long("192.168.1.100")))
|
||||
t.Log(list.Contains(IP2Long("192.168.2.100")))
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}
|
||||
|
||||
func TestIPList_Delete(t *testing.T) {
|
||||
list := NewIPList()
|
||||
list.Add(&IPItem{
|
||||
Id: 1,
|
||||
IPFrom: IP2Long("192.168.0.1"),
|
||||
ExpiredAt: 0,
|
||||
})
|
||||
list.Add(&IPItem{
|
||||
Id: 2,
|
||||
IPFrom: IP2Long("192.168.0.1"),
|
||||
ExpiredAt: 0,
|
||||
})
|
||||
t.Log("===BEFORE===")
|
||||
logs.PrintAsJSON(list.itemsMap, t)
|
||||
logs.PrintAsJSON(list.ipMap, t)
|
||||
|
||||
list.Delete(1)
|
||||
|
||||
t.Log("===AFTER===")
|
||||
logs.PrintAsJSON(list.itemsMap, t)
|
||||
logs.PrintAsJSON(list.ipMap, t)
|
||||
}
|
||||
|
||||
func TestGC(t *testing.T) {
|
||||
list := NewIPList()
|
||||
list.Add(&IPItem{
|
||||
Id: 1,
|
||||
IPFrom: IP2Long("192.168.1.100"),
|
||||
IPTo: IP2Long("192.168.1.101"),
|
||||
ExpiredAt: time.Now().Unix() + 1,
|
||||
})
|
||||
list.Add(&IPItem{
|
||||
Id: 2,
|
||||
IPFrom: IP2Long("192.168.1.102"),
|
||||
IPTo: IP2Long("192.168.1.103"),
|
||||
ExpiredAt: 0,
|
||||
})
|
||||
logs.PrintAsJSON(list.itemsMap, t)
|
||||
logs.PrintAsJSON(list.ipMap, t)
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
t.Log("===AFTER GC===")
|
||||
logs.PrintAsJSON(list.itemsMap, t)
|
||||
logs.PrintAsJSON(list.ipMap, t)
|
||||
}
|
||||
|
||||
func BenchmarkIPList_Contains(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
list := NewIPList()
|
||||
for i := 0; i < 10_000; i++ {
|
||||
for i := 192; i < 194; i++ {
|
||||
list.Add(&IPItem{
|
||||
Id: int64(i),
|
||||
IPFrom: IP2Long("192.168.1." + strconv.Itoa(i)),
|
||||
IPTo: 0,
|
||||
Id: int64(1),
|
||||
IPFrom: IP2Long(strconv.Itoa(i) + ".1.0.1"),
|
||||
IPTo: IP2Long(strconv.Itoa(i) + ".2.0.1"),
|
||||
ExpiredAt: time.Now().Unix() + 60,
|
||||
})
|
||||
}
|
||||
b.Log(len(list.ipMap), "ip")
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = list.Contains(IP2Long("192.168.1.100"))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user