diff --git a/internal/iplibrary/ip_item.go b/internal/iplibrary/ip_item.go index 1b7c6bd..257fd00 100644 --- a/internal/iplibrary/ip_item.go +++ b/internal/iplibrary/ip_item.go @@ -2,6 +2,7 @@ package iplibrary import "github.com/TeaOSLab/EdgeNode/internal/utils" +// IP条目 type IPItem struct { Id int64 IPFrom uint32 @@ -9,6 +10,7 @@ type IPItem struct { ExpiredAt int64 } +// 检查是否包含某个IP func (this *IPItem) Contains(ip uint32) bool { if this.IPTo == 0 { if this.IPFrom != ip { diff --git a/internal/iplibrary/ip_list.go b/internal/iplibrary/ip_list.go index 8daeef3..b402190 100644 --- a/internal/iplibrary/ip_list.go +++ b/internal/iplibrary/ip_list.go @@ -1,45 +1,152 @@ package iplibrary import ( + "github.com/TeaOSLab/EdgeNode/internal/utils/expires" "sync" ) // IP名单 type IPList struct { - itemsMap map[int64]*IPItem // id => item + itemsMap map[int64]*IPItem // id => item + ipMap map[uint32][]int64 // ip => itemIds + expireList *expires.List locker sync.RWMutex } func NewIPList() *IPList { - return &IPList{ + list := &IPList{ itemsMap: map[int64]*IPItem{}, + ipMap: map[uint32][]int64{}, } + + expireList := expires.NewList() + go func() { + expireList.StartGC(func(itemId int64) { + list.Delete(itemId) + }) + }() + list.expireList = expireList + return list } func (this *IPList) Add(item *IPItem) { + if item == nil || (item.IPFrom == 0 && item.IPTo == 0) { + return + } + this.locker.Lock() + + // 是否已经存在 + _, ok := this.itemsMap[item.Id] + if ok { + this.deleteItem(item.Id) + } + this.itemsMap[item.Id] = item + + // 展开 + if item.IPFrom > 0 { + if item.IPTo == 0 { + this.addIP(item.IPFrom, item.Id) + } else { + if item.IPFrom > item.IPTo { + item.IPTo, item.IPFrom = item.IPFrom, item.IPTo + } + + for i := item.IPFrom; i <= item.IPTo; i++ { + // 最多不能超过65535,防止整个系统内存爆掉 + if i >= item.IPFrom+65535 { + break + } + this.addIP(i, item.Id) + } + } + } else if item.IPTo > 0 { + this.addIP(item.IPTo, item.Id) + } + + if item.ExpiredAt > 0 { + this.expireList.Add(item.Id, item.ExpiredAt) + } + this.locker.Unlock() } func (this *IPList) Delete(itemId int64) { this.locker.Lock() - delete(this.itemsMap, itemId) - this.locker.Unlock() + defer this.locker.Unlock() + this.deleteItem(itemId) } // 判断是否包含某个IP func (this *IPList) Contains(ip uint32) bool { - // TODO 优化查询速度,可能需要把items分成两组,一组是单个的,一组是按照范围的,按照范围的再进行二分法查找 this.locker.RLock() - for _, item := range this.itemsMap { - if item.Contains(ip) { - this.locker.RUnlock() - return true - } - } + _, ok := this.ipMap[ip] this.locker.RUnlock() - return false + return ok +} + +// 在不加锁的情况下删除某个Item +// 将会被别的方法引用,切记不能加锁 +func (this *IPList) deleteItem(itemId int64) { + item, ok := this.itemsMap[itemId] + if !ok { + return + } + + delete(this.itemsMap, itemId) + + // 展开 + if item.IPFrom > 0 { + if item.IPTo == 0 { + this.deleteIP(item.IPFrom, item.Id) + } else { + if item.IPFrom > item.IPTo { + item.IPTo, item.IPFrom = item.IPFrom, item.IPTo + } + + for i := item.IPFrom; i <= item.IPTo; i++ { + // 最多不能超过65535,防止整个系统内存爆掉 + if i >= item.IPFrom+65535 { + break + } + this.deleteIP(i, item.Id) + } + } + } else if item.IPTo > 0 { + this.deleteIP(item.IPTo, item.Id) + } +} + +// 添加单个IP +func (this *IPList) addIP(ip uint32, itemId int64) { + itemIds, ok := this.ipMap[ip] + if ok { + itemIds = append(itemIds, itemId) + } else { + itemIds = []int64{itemId} + } + this.ipMap[ip] = itemIds +} + +// 删除单个IP +func (this *IPList) deleteIP(ip uint32, itemId int64) { + itemIds, ok := this.ipMap[ip] + if !ok { + return + } + newItemIds := []int64{} + for _, oldItemId := range itemIds { + if oldItemId == itemId { + continue + } + newItemIds = append(newItemIds, oldItemId) + } + if len(newItemIds) > 0 { + this.ipMap[ip] = newItemIds + } else { + delete(this.ipMap, ip) + } } diff --git a/internal/iplibrary/ip_list_test.go b/internal/iplibrary/ip_list_test.go index f5a7315..d0d67e8 100644 --- a/internal/iplibrary/ip_list_test.go +++ b/internal/iplibrary/ip_list_test.go @@ -1,12 +1,84 @@ package iplibrary import ( + "github.com/iwind/TeaGo/assert" + "github.com/iwind/TeaGo/logs" "runtime" "strconv" "testing" "time" ) +func TestIPList_Add_Empty(t *testing.T) { + ipList := NewIPList() + ipList.Add(&IPItem{ + Id: 1, + }) + logs.PrintAsJSON(ipList.itemsMap, t) + logs.PrintAsJSON(ipList.ipMap, t) +} + +func TestIPList_Add_One(t *testing.T) { + ipList := NewIPList() + ipList.Add(&IPItem{ + Id: 1, + IPFrom: IP2Long("192.168.1.1"), + }) + ipList.Add(&IPItem{ + Id: 2, + IPTo: IP2Long("192.168.1.2"), + }) + logs.PrintAsJSON(ipList.itemsMap, t) + logs.PrintAsJSON(ipList.ipMap, t) +} + +func TestIPList_Update(t *testing.T) { + ipList := NewIPList() + ipList.Add(&IPItem{ + Id: 1, + IPFrom: IP2Long("192.168.1.1"), + }) + /**ipList.Add(&IPItem{ + Id: 2, + IPFrom: IP2Long("192.168.1.1"), + })**/ + ipList.Add(&IPItem{ + Id: 1, + IPTo: IP2Long("192.168.1.2"), + }) + logs.PrintAsJSON(ipList.itemsMap, t) + logs.PrintAsJSON(ipList.ipMap, t) +} + +func TestIPList_Add_Range(t *testing.T) { + ipList := NewIPList() + ipList.Add(&IPItem{ + Id: 1, + IPFrom: IP2Long("192.168.1.1"), + IPTo: IP2Long("192.168.2.1"), + }) + ipList.Add(&IPItem{ + Id: 2, + IPTo: IP2Long("192.168.1.2"), + }) + t.Log(len(ipList.ipMap), "ips") + logs.PrintAsJSON(ipList.itemsMap, t) + logs.PrintAsJSON(ipList.ipMap, t) +} + +func TestIPList_Add_Overflow(t *testing.T) { + a := assert.NewAssertion(t) + + ipList := NewIPList() + ipList.Add(&IPItem{ + Id: 1, + IPFrom: IP2Long("192.168.1.1"), + IPTo: IP2Long("192.169.255.1"), + }) + t.Log(len(ipList.ipMap), "ips") + a.IsTrue(len(ipList.ipMap) <= 65535) +} + func TestNewIPList_Memory(t *testing.T) { list := NewIPList() @@ -26,27 +98,78 @@ func TestIPList_Contains(t *testing.T) { for i := 0; i < 255; i++ { list.Add(&IPItem{ Id: int64(i), - IPFrom: IP2Long("192.168.1." + strconv.Itoa(i)), - IPTo: 0, + IPFrom: IP2Long(strconv.Itoa(i) + ".168.0.1"), + IPTo: IP2Long(strconv.Itoa(i) + ".168.255.1"), ExpiredAt: 0, }) } + t.Log(len(list.ipMap)) + + before := time.Now() t.Log(list.Contains(IP2Long("192.168.1.100"))) t.Log(list.Contains(IP2Long("192.168.2.100"))) + t.Log(time.Since(before).Seconds()*1000, "ms") +} + +func TestIPList_Delete(t *testing.T) { + list := NewIPList() + list.Add(&IPItem{ + Id: 1, + IPFrom: IP2Long("192.168.0.1"), + ExpiredAt: 0, + }) + list.Add(&IPItem{ + Id: 2, + IPFrom: IP2Long("192.168.0.1"), + ExpiredAt: 0, + }) + t.Log("===BEFORE===") + logs.PrintAsJSON(list.itemsMap, t) + logs.PrintAsJSON(list.ipMap, t) + + list.Delete(1) + + t.Log("===AFTER===") + logs.PrintAsJSON(list.itemsMap, t) + logs.PrintAsJSON(list.ipMap, t) +} + +func TestGC(t *testing.T) { + list := NewIPList() + list.Add(&IPItem{ + Id: 1, + IPFrom: IP2Long("192.168.1.100"), + IPTo: IP2Long("192.168.1.101"), + ExpiredAt: time.Now().Unix() + 1, + }) + list.Add(&IPItem{ + Id: 2, + IPFrom: IP2Long("192.168.1.102"), + IPTo: IP2Long("192.168.1.103"), + ExpiredAt: 0, + }) + logs.PrintAsJSON(list.itemsMap, t) + logs.PrintAsJSON(list.ipMap, t) + + time.Sleep(2 * time.Second) + t.Log("===AFTER GC===") + logs.PrintAsJSON(list.itemsMap, t) + logs.PrintAsJSON(list.ipMap, t) } func BenchmarkIPList_Contains(b *testing.B) { runtime.GOMAXPROCS(1) list := NewIPList() - for i := 0; i < 10_000; i++ { + for i := 192; i < 194; i++ { list.Add(&IPItem{ - Id: int64(i), - IPFrom: IP2Long("192.168.1." + strconv.Itoa(i)), - IPTo: 0, + Id: int64(1), + IPFrom: IP2Long(strconv.Itoa(i) + ".1.0.1"), + IPTo: IP2Long(strconv.Itoa(i) + ".2.0.1"), ExpiredAt: time.Now().Unix() + 60, }) } + b.Log(len(list.ipMap), "ip") for i := 0; i < b.N; i++ { _ = list.Contains(IP2Long("192.168.1.100")) } diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 7b5a69b..55550b7 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -74,7 +74,7 @@ type HTTPRequest struct { func (this *HTTPRequest) init() { this.writer = NewHTTPWriter(this, this.RawWriter) this.web = &serverconfigs.HTTPWebConfig{IsOn: true} - //this.uri = this.RawReq.URL.RequestURI() + // this.uri = this.RawReq.URL.RequestURI() // 之所以不使用RequestURI(),是不想让URL中的Path被Encode if len(this.RawReq.URL.RawQuery) > 0 { this.uri = this.RawReq.URL.Path + "?" + this.RawReq.URL.RawQuery @@ -82,7 +82,6 @@ func (this *HTTPRequest) init() { this.uri = this.RawReq.URL.Path } - this.uri = this.RawReq.URL.Path this.rawURI = this.uri this.varMapping = map[string]string{ // 缓存相关初始化 @@ -300,6 +299,9 @@ func (this *HTTPRequest) configureWeb(web *serverconfigs.HTTPWebConfig, isTop bo // waf if web.FirewallRef != nil && (web.FirewallRef.IsPrior || isTop) { this.web.FirewallRef = web.FirewallRef + if web.FirewallPolicy != nil { + this.web.FirewallPolicy = web.FirewallPolicy + } } // access log diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 9318621..34114ac 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -1,6 +1,7 @@ package nodes import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/waf" @@ -11,8 +12,26 @@ import ( // 调用WAF func (this *HTTPRequest) doWAFRequest() (blocked bool) { - firewallPolicy := sharedNodeConfig.HTTPFirewallPolicy + // 当前服务的独立设置 + if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { + blocked = this.checkWAFRequest(this.web.FirewallPolicy) + if blocked { + return + } + } + // 公用的防火墙设置 + if sharedNodeConfig.HTTPFirewallPolicy != nil { + blocked = this.checkWAFRequest(sharedNodeConfig.HTTPFirewallPolicy) + if blocked { + return + } + } + + return +} + +func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFirewallPolicy) (blocked bool) { // 检查配置是否为空 if firewallPolicy == nil || !firewallPolicy.IsOn || firewallPolicy.Inbound == nil || !firewallPolicy.Inbound.IsOn { return @@ -21,16 +40,16 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) { // 检查IP白名单 remoteAddr := this.requestRemoteAddr() inbound := firewallPolicy.Inbound - if inbound.WhiteListRef != nil && inbound.WhiteListRef.IsOn && inbound.WhiteListRef.ListId > 0 { - list := iplibrary.SharedIPListManager.FindList(inbound.WhiteListRef.ListId) + if inbound.AllowListRef != nil && inbound.AllowListRef.IsOn && inbound.AllowListRef.ListId > 0 { + list := iplibrary.SharedIPListManager.FindList(inbound.AllowListRef.ListId) if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) { return } } // 检查IP黑名单 - if inbound.BlackListRef != nil && inbound.BlackListRef.IsOn && inbound.BlackListRef.ListId > 0 { - list := iplibrary.SharedIPListManager.FindList(inbound.BlackListRef.ListId) + if inbound.DenyListRef != nil && inbound.DenyListRef.IsOn && inbound.DenyListRef.ListId > 0 { + list := iplibrary.SharedIPListManager.FindList(inbound.DenyListRef.ListId) if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) { // TODO 可以配置对封禁的处理方式等 this.writer.WriteHeader(http.StatusForbidden) diff --git a/internal/nodes/node_status_executor_windows.go b/internal/nodes/node_status_executor_windows.go index c49a912..d53016f 100644 --- a/internal/nodes/node_status_executor_windows.go +++ b/internal/nodes/node_status_executor_windows.go @@ -1,6 +1,6 @@ // +build windows -package agent +package nodes import ( "context" diff --git a/internal/utils/expires/list.go b/internal/utils/expires/list.go new file mode 100644 index 0000000..2e83fe0 --- /dev/null +++ b/internal/utils/expires/list.go @@ -0,0 +1,114 @@ +package expires + +import ( + "sync" + "time" +) + +type ItemMap = map[int64]bool + +type List struct { + expireMap map[int64]ItemMap // expires timestamp => map[id]bool + itemsMap map[int64]int64 // itemId => timestamp + + locker sync.Mutex +} + +func NewList() *List { + return &List{ + expireMap: map[int64]ItemMap{}, + itemsMap: map[int64]int64{}, + } +} + +func (this *List) Add(itemId int64, expiredAt int64) { + if expiredAt <= time.Now().Unix() { + return + } + this.locker.Lock() + defer this.locker.Unlock() + + // 是否已经存在 + _, ok := this.itemsMap[itemId] + if ok { + this.removeItem(itemId) + } + + expireItemMap, ok := this.expireMap[expiredAt] + if ok { + expireItemMap[itemId] = true + } else { + expireItemMap = ItemMap{ + itemId: true, + } + this.expireMap[expiredAt] = expireItemMap + } + + this.itemsMap[itemId] = expiredAt +} + +func (this *List) Remove(itemId int64) { + this.locker.Lock() + defer this.locker.Unlock() + this.removeItem(itemId) +} + +func (this *List) GC(timestamp int64, callback func(itemId int64)) { + this.locker.Lock() + itemMap := this.gcItems(timestamp) + this.locker.Unlock() + + for itemId := range itemMap { + callback(itemId) + } +} + +func (this *List) StartGC(callback func(itemId int64)) { + ticker := time.NewTicker(1 * time.Second) + lastTimestamp := int64(0) + for range ticker.C { + timestamp := time.Now().Unix() + if lastTimestamp == 0 { + lastTimestamp = timestamp - 3600 + } + + // 防止死循环 + if lastTimestamp > timestamp { + continue + } + + for i := lastTimestamp; i <= timestamp; i++ { + this.GC(timestamp, callback) + } + + // 这样做是为了防止系统时钟突变 + lastTimestamp = timestamp + } +} + +func (this *List) removeItem(itemId int64) { + expiresAt, ok := this.itemsMap[itemId] + if !ok { + return + } + delete(this.itemsMap, itemId) + + expireItemMap, ok := this.expireMap[expiresAt] + if ok { + delete(expireItemMap, itemId) + if len(expireItemMap) == 0 { + delete(this.expireMap, expiresAt) + } + } +} + +func (this *List) gcItems(timestamp int64) ItemMap { + expireItemsMap, ok := this.expireMap[timestamp] + if ok { + for itemId := range expireItemsMap { + delete(this.itemsMap, itemId) + } + delete(this.expireMap, timestamp) + } + return expireItemsMap +} diff --git a/internal/utils/expires/list_test.go b/internal/utils/expires/list_test.go new file mode 100644 index 0000000..c4b06d3 --- /dev/null +++ b/internal/utils/expires/list_test.go @@ -0,0 +1,115 @@ +package expires + +import ( + "github.com/iwind/TeaGo/logs" + timeutil "github.com/iwind/TeaGo/utils/time" + "math" + "testing" + "time" +) + +func TestList_Add(t *testing.T) { + list := NewList() + list.Add(1, time.Now().Unix()) + t.Log("===BEFORE===") + logs.PrintAsJSON(list.expireMap, t) + logs.PrintAsJSON(list.itemsMap, t) + + list.Add(1, time.Now().Unix()+1) + list.Add(2, time.Now().Unix()+1) + list.Add(3, time.Now().Unix()+2) + t.Log("===AFTER===") + logs.PrintAsJSON(list.expireMap, t) + logs.PrintAsJSON(list.itemsMap, t) +} + +func TestList_Add_Overwrite(t *testing.T) { + list := NewList() + list.Add(1, time.Now().Unix()+1) + list.Add(1, time.Now().Unix()+1) + list.Add(1, time.Now().Unix()+2) + logs.PrintAsJSON(list.expireMap, t) + logs.PrintAsJSON(list.itemsMap, t) +} + +func TestList_Remove(t *testing.T) { + list := NewList() + list.Add(1, time.Now().Unix()+1) + list.Remove(1) + logs.PrintAsJSON(list.expireMap, t) + logs.PrintAsJSON(list.itemsMap, t) +} + +func TestList_GC(t *testing.T) { + list := NewList() + list.Add(1, time.Now().Unix()+1) + list.Add(2, time.Now().Unix()+1) + list.Add(3, time.Now().Unix()+2) + list.GC(time.Now().Unix()+2, func(itemId int64) { + t.Log("gc:", itemId) + }) + logs.PrintAsJSON(list.expireMap, t) + logs.PrintAsJSON(list.itemsMap, t) +} + +func TestList_Start_GC(t *testing.T) { + list := NewList() + list.Add(1, time.Now().Unix()+1) + list.Add(2, time.Now().Unix()+1) + list.Add(3, time.Now().Unix()+2) + list.Add(4, time.Now().Unix()+5) + + go func() { + list.StartGC(func(itemId int64) { + t.Log("gc:", itemId, timeutil.Format("H:i:s")) + time.Sleep(2 * time.Second) + }) + }() + + time.Sleep(10 * time.Second) +} + +func TestList_ManyItems(t *testing.T) { + list := NewList() + for i := 0; i < 100_000; i++ { + list.Add(int64(i), time.Now().Unix()+1) + } + + now := time.Now() + count := 0 + list.GC(time.Now().Unix()+1, func(itemId int64) { + count++ + }) + t.Log("gc", count, "items") + t.Log(time.Since(now).Seconds()*1000, "ms") +} + +func TestList_Map_Performance(t *testing.T) { + t.Log("max uint32", math.MaxUint32) + + { + m := map[int64]int64{} + for i := 0; i < 1_000_000; i++ { + m[int64(i)] = time.Now().Unix() + } + + now := time.Now() + for i := 0; i < 100_000; i++ { + delete(m, int64(i)) + } + t.Log(time.Since(now).Seconds()*1000, "ms") + } + + { + m := map[uint32]int64{} + for i := 0; i < 1_000_000; i++ { + m[uint32(i)] = time.Now().Unix() + } + + now := time.Now() + for i := 0; i < 100_000; i++ { + delete(m, uint32(i)) + } + t.Log(time.Since(now).Seconds()*1000, "ms") + } +}