From c9eb577c0614278b2e846718d367283a27fcf451 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Sat, 6 Apr 2024 10:07:39 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E5=A5=BD=E5=9C=B0=E6=94=AF=E6=8C=81IP?= =?UTF-8?q?v6/=E4=BC=98=E5=8C=96IP=E5=90=8D=E5=8D=95=E5=86=85=E5=AD=98?= =?UTF-8?q?=E7=94=A8=E9=87=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/firewalls/firewall_nftables.go | 16 +-- internal/iplibrary/ip_item.go | 38 +++--- internal/iplibrary/ip_item_test.go | 81 +++++++------ internal/iplibrary/ip_list.go | 73 +++++------ internal/iplibrary/ip_list_test.go | 134 ++++++++++----------- internal/iplibrary/list_utils.go | 64 ++++++++-- internal/iplibrary/list_utils_test.go | 2 +- internal/iplibrary/manager_ip_list.go | 19 +-- internal/iplibrary/manager_ip_list_test.go | 26 ++-- internal/iplibrary/server_list_manager.go | 4 + internal/utils/ip.go | 20 --- internal/utils/ip_test.go | 9 -- internal/utils/version.go | 9 +- 13 files changed, 254 insertions(+), 241 deletions(-) diff --git a/internal/firewalls/firewall_nftables.go b/internal/firewalls/firewall_nftables.go index 938d16d..b4c9fc0 100644 --- a/internal/firewalls/firewall_nftables.go +++ b/internal/firewalls/firewall_nftables.go @@ -6,7 +6,7 @@ package firewalls import ( "errors" "fmt" - "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + "github.com/TeaOSLab/EdgeCommon/pkg/iputils" "github.com/TeaOSLab/EdgeNode/internal/conns" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" "github.com/TeaOSLab/EdgeNode/internal/events" @@ -386,12 +386,12 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async // 再次尝试关闭连接 defer conns.SharedMap.CloseIPConns(ip) - var ipLong = configutils.IPString2Long(ip) if strings.Contains(ip, ":") { // ipv6 if len(this.denyIPv6Sets) == 0 { return errors.New("ipv6 ip set not found") } - return this.denyIPv6Sets[ipLong%uint64(len(this.denyIPv6Sets))].AddElement(data.To16(), &nftables.ElementOptions{ + var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv6Sets)) + return this.denyIPv6Sets[setIndex].AddElement(data.To16(), &nftables.ElementOptions{ Timeout: time.Duration(timeoutSeconds) * time.Second, }, false) } @@ -400,7 +400,8 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async if len(this.denyIPv4Sets) == 0 { return errors.New("ipv4 ip set not found") } - return this.denyIPv4Sets[ipLong%uint64(len(this.denyIPv4Sets))].AddElement(data.To4(), &nftables.ElementOptions{ + var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv4Sets)) + return this.denyIPv4Sets[setIndex].AddElement(data.To4(), &nftables.ElementOptions{ Timeout: time.Duration(timeoutSeconds) * time.Second, }, false) } @@ -412,10 +413,10 @@ func (this *NFTablesFirewall) RemoveSourceIP(ip string) error { return errors.New("invalid ip '" + ip + "'") } - var ipLong = configutils.IPString2Long(ip) if strings.Contains(ip, ":") { // ipv6 + var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv6Sets)) if len(this.denyIPv6Sets) > 0 { - err := this.denyIPv6Sets[ipLong%uint64(len(this.denyIPv6Sets))].DeleteElement(data.To16()) + err := this.denyIPv6Sets[setIndex].DeleteElement(data.To16()) if err != nil { return err } @@ -433,7 +434,8 @@ func (this *NFTablesFirewall) RemoveSourceIP(ip string) error { // ipv4 if len(this.denyIPv4Sets) > 0 { - err := this.denyIPv4Sets[ipLong%uint64(len(this.denyIPv4Sets))].DeleteElement(data.To4()) + var setIndex = iputils.ParseIP(ip).Mod(len(this.denyIPv4Sets)) + err := this.denyIPv4Sets[setIndex].DeleteElement(data.To4()) if err != nil { return err } diff --git a/internal/iplibrary/ip_item.go b/internal/iplibrary/ip_item.go index 556b04a..a7488df 100644 --- a/internal/iplibrary/ip_item.go +++ b/internal/iplibrary/ip_item.go @@ -14,36 +14,37 @@ const ( // IPItem IP条目 type IPItem struct { - Type string `json:"type"` - Id uint64 `json:"id"` - IPFrom uint64 `json:"ipFrom"` - IPTo uint64 `json:"ipTo"` + Type string `json:"type"` + Id uint64 `json:"id"` + IPFrom []byte `json:"ipFrom"` + IPTo []byte `json:"ipTo"` + ExpiredAt int64 `json:"expiredAt"` EventLevel string `json:"eventLevel"` } // Contains 检查是否包含某个IP -func (this *IPItem) Contains(ip uint64) bool { +func (this *IPItem) Contains(ipBytes []byte) bool { switch this.Type { case IPItemTypeIPv4: - return this.containsIPv4(ip) + return this.containsIP(ipBytes) case IPItemTypeIPv6: - return this.containsIPv6(ip) + return this.containsIP(ipBytes) case IPItemTypeAll: return this.containsAll() default: - return this.containsIPv4(ip) + return this.containsIP(ipBytes) } } -// 检查是否包含某个IPv4 -func (this *IPItem) containsIPv4(ip uint64) bool { - if this.IPTo == 0 { - if this.IPFrom != ip { +// 检查是否包含某个 +func (this *IPItem) containsIP(ipBytes []byte) bool { + if IsZero(this.IPTo) { + if CompareBytes(this.IPFrom, ipBytes) != 0 { return false } } else { - if this.IPFrom > ip || this.IPTo < ip { + if CompareBytes(this.IPFrom, ipBytes) > 0 || CompareBytes(this.IPTo, ipBytes) < 0 { return false } } @@ -53,17 +54,6 @@ func (this *IPItem) containsIPv4(ip uint64) bool { return true } -// 检查是否包含某个IPv6 -func (this *IPItem) containsIPv6(ip uint64) bool { - if this.IPFrom != ip { - return false - } - if this.ExpiredAt > 0 && this.ExpiredAt < fasttime.Now().Unix() { - return false - } - return true -} - // 检查是否包所有IP func (this *IPItem) containsAll() bool { if this.ExpiredAt > 0 && this.ExpiredAt < fasttime.Now().Unix() { diff --git a/internal/iplibrary/ip_item_test.go b/internal/iplibrary/ip_item_test.go index 9d5c777..3904602 100644 --- a/internal/iplibrary/ip_item_test.go +++ b/internal/iplibrary/ip_item_test.go @@ -1,7 +1,7 @@ -package iplibrary +package iplibrary_test import ( - "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/utils/testutils" "github.com/iwind/TeaGo/assert" "math/rand" @@ -12,89 +12,92 @@ import ( ) func TestIPItem_Contains(t *testing.T) { - a := assert.NewAssertion(t) + var a = assert.NewAssertion(t) { - item := &IPItem{ - IPFrom: utils.IP2LongHash("192.168.1.100"), - IPTo: 0, + var item = &iplibrary.IPItem{ + IPFrom: iplibrary.IPBytes("192.168.1.100"), + IPTo: nil, ExpiredAt: 0, } - a.IsTrue(item.Contains(utils.IP2LongHash("192.168.1.100"))) + a.IsTrue(item.Contains(iplibrary.IPBytes("192.168.1.100"))) } { - item := &IPItem{ - IPFrom: utils.IP2LongHash("192.168.1.100"), - IPTo: 0, + var item = &iplibrary.IPItem{ + IPFrom: iplibrary.IPBytes("192.168.1.100"), + IPTo: nil, ExpiredAt: time.Now().Unix() + 1, } - a.IsTrue(item.Contains(utils.IP2LongHash("192.168.1.100"))) + a.IsTrue(item.Contains(iplibrary.IPBytes("192.168.1.100"))) } { - item := &IPItem{ - IPFrom: utils.IP2LongHash("192.168.1.100"), - IPTo: 0, + var item = &iplibrary.IPItem{ + IPFrom: iplibrary.IPBytes("192.168.1.100"), + IPTo: nil, ExpiredAt: time.Now().Unix() - 1, } - a.IsFalse(item.Contains(utils.IP2LongHash("192.168.1.100"))) + a.IsFalse(item.Contains(iplibrary.IPBytes("192.168.1.100"))) } { - item := &IPItem{ - IPFrom: utils.IP2LongHash("192.168.1.100"), - IPTo: 0, + var item = &iplibrary.IPItem{ + IPFrom: iplibrary.IPBytes("192.168.1.100"), + IPTo: nil, ExpiredAt: 0, } - a.IsFalse(item.Contains(utils.IP2LongHash("192.168.1.101"))) + a.IsFalse(item.Contains(iplibrary.IPBytes("192.168.1.101"))) } { - item := &IPItem{ - IPFrom: utils.IP2LongHash("192.168.1.1"), - IPTo: utils.IP2LongHash("192.168.1.101"), + var item = &iplibrary.IPItem{ + IPFrom: iplibrary.IPBytes("192.168.1.1"), + IPTo: iplibrary.IPBytes("192.168.1.101"), ExpiredAt: 0, } - a.IsTrue(item.Contains(utils.IP2LongHash("192.168.1.100"))) + a.IsTrue(item.Contains(iplibrary.IPBytes("192.168.1.100"))) } { - item := &IPItem{ - IPFrom: utils.IP2LongHash("192.168.1.1"), - IPTo: utils.IP2LongHash("192.168.1.100"), + var item = &iplibrary.IPItem{ + IPFrom: iplibrary.IPBytes("192.168.1.1"), + IPTo: iplibrary.IPBytes("192.168.1.100"), ExpiredAt: 0, } - a.IsTrue(item.Contains(utils.IP2LongHash("192.168.1.100"))) + a.IsTrue(item.Contains(iplibrary.IPBytes("192.168.1.100"))) } { - item := &IPItem{ - IPFrom: utils.IP2LongHash("192.168.1.1"), - IPTo: utils.IP2LongHash("192.168.1.101"), + var item = &iplibrary.IPItem{ + IPFrom: iplibrary.IPBytes("192.168.1.1"), + IPTo: iplibrary.IPBytes("192.168.1.101"), ExpiredAt: 0, } - a.IsTrue(item.Contains(utils.IP2LongHash("192.168.1.1"))) + a.IsTrue(item.Contains(iplibrary.IPBytes("192.168.1.1"))) } } func TestIPItem_Memory(t *testing.T) { var isSingleTest = testutils.IsSingleTesting() - var list = NewIPList() + var list = iplibrary.NewIPList() var count = 100 if isSingleTest { count = 2_000_000 } for i := 0; i < count; i++ { - list.Add(&IPItem{ + list.Add(&iplibrary.IPItem{ Type: "ip", Id: uint64(i), - IPFrom: utils.IP2LongHash("192.168.1.1"), - IPTo: 0, + IPFrom: iplibrary.IPBytes("192.168.1.1"), + IPTo: nil, ExpiredAt: time.Now().Unix(), EventLevel: "", }) } + + runtime.GC() + t.Log("waiting") if isSingleTest { time.Sleep(10 * time.Second) @@ -104,9 +107,9 @@ func TestIPItem_Memory(t *testing.T) { func BenchmarkIPItem_Contains(b *testing.B) { runtime.GOMAXPROCS(1) - var item = &IPItem{ - IPFrom: utils.IP2LongHash("192.168.1.1"), - IPTo: utils.IP2LongHash("192.168.1.101"), + var item = &iplibrary.IPItem{ + IPFrom: iplibrary.IPBytes("192.168.1.1"), + IPTo: iplibrary.IPBytes("192.168.1.101"), ExpiredAt: 0, } @@ -114,7 +117,7 @@ func BenchmarkIPItem_Contains(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - var ip = utils.IP2LongHash("192.168.1." + strconv.Itoa(rand.Int()%255)) + var ip = iplibrary.IPBytes("192.168.1." + strconv.Itoa(rand.Int()%255)) item.Contains(ip) } }) diff --git a/internal/iplibrary/ip_list.go b/internal/iplibrary/ip_list.go index f9185cf..6ae29b6 100644 --- a/internal/iplibrary/ip_list.go +++ b/internal/iplibrary/ip_list.go @@ -1,7 +1,6 @@ package iplibrary import ( - "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils/expires" "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" "sort" @@ -12,7 +11,6 @@ var GlobalBlackIPList = NewIPList() var GlobalWhiteIPList = NewIPList() // IPList IP名单 -// TODO 考虑将ipv6单独放入buckets // TODO 对ipMap进行分区 type IPList struct { isDeleted bool @@ -20,8 +18,8 @@ type IPList struct { itemsMap map[uint64]*IPItem // id => item sortedRangeItems []*IPItem - ipMap map[uint64]*IPItem // ipFrom => *IPItem - bufferItemsMap map[uint64]*IPItem // id => *IPItem + ipMap map[string]*IPItem // ipFrom => IPItem + bufferItemsMap map[uint64]*IPItem // id => IPItem allItemsMap map[uint64]*IPItem // id => item @@ -35,7 +33,7 @@ func NewIPList() *IPList { itemsMap: map[uint64]*IPItem{}, bufferItemsMap: map[uint64]*IPItem{}, allItemsMap: map[uint64]*IPItem{}, - ipMap: map[uint64]*IPItem{}, + ipMap: map[string]*IPItem{}, } var expireList = expires.NewList() @@ -59,7 +57,7 @@ func (this *IPList) AddDelay(item *IPItem) { return } - if item.IPTo > 0 { + if !IsZero(item.IPTo) { this.mu.Lock() this.bufferItemsMap[item.Id] = item this.mu.Unlock() @@ -81,7 +79,7 @@ func (this *IPList) Delete(itemId uint64) { } // Contains 判断是否包含某个IP -func (this *IPList) Contains(ip uint64) bool { +func (this *IPList) Contains(ipBytes []byte) bool { if this.isDeleted { return false } @@ -93,13 +91,12 @@ func (this *IPList) Contains(ip uint64) bool { return true } - var item = this.lookupIP(ip) - + var item = this.lookupIP(ipBytes) return item != nil } // ContainsExpires 判断是否包含某个IP -func (this *IPList) ContainsExpires(ip uint64) (expiresAt int64, ok bool) { +func (this *IPList) ContainsExpires(ipBytes []byte) (expiresAt int64, ok bool) { if this.isDeleted { return } @@ -111,7 +108,7 @@ func (this *IPList) ContainsExpires(ip uint64) (expiresAt int64, ok bool) { return 0, true } - var item = this.lookupIP(ip) + var item = this.lookupIP(ipBytes) if item == nil { return @@ -148,7 +145,7 @@ func (this *IPList) ContainsIPStrings(ipStrings []string) (item *IPItem, found b if len(ipString) == 0 { continue } - item = this.lookupIP(utils.IP2LongHash(ipString)) + item = this.lookupIP(IPBytes(ipString)) if item != nil { found = true return @@ -165,7 +162,7 @@ func (this *IPList) SortedRangeItems() []*IPItem { return this.sortedRangeItems } -func (this *IPList) IPMap() map[uint64]*IPItem { +func (this *IPList) IPMap() map[string]*IPItem { return this.ipMap } @@ -177,6 +174,10 @@ func (this *IPList) AllItemsMap() map[uint64]*IPItem { return this.allItemsMap } +func (this *IPList) BufferItemsMap() map[uint64]*IPItem { + return this.bufferItemsMap +} + func (this *IPList) addItem(item *IPItem, lock bool, sortable bool) { if item == nil { return @@ -188,20 +189,20 @@ func (this *IPList) addItem(item *IPItem, lock bool, sortable bool) { var shouldSort bool - if item.IPFrom == item.IPTo { - item.IPTo = 0 + if CompareBytes(item.IPFrom, item.IPTo) == 0 { + item.IPTo = nil } - if item.IPFrom == 0 && item.IPTo == 0 { + if IsZero(item.IPFrom) && IsZero(item.IPTo) { if item.Type != IPItemTypeAll { return } - } else if item.IPTo > 0 { - if item.IPFrom > item.IPTo { + } else if !IsZero(item.IPTo) { + if CompareBytes(item.IPFrom, item.IPTo) > 0 { item.IPFrom, item.IPTo = item.IPTo, item.IPFrom - } else if item.IPFrom == 0 { + } else if IsZero(item.IPFrom) { item.IPFrom = item.IPTo - item.IPTo = 0 + item.IPTo = nil } } @@ -219,12 +220,12 @@ func (this *IPList) addItem(item *IPItem, lock bool, sortable bool) { this.itemsMap[item.Id] = item // 展开 - if item.IPFrom > 0 { - if item.IPTo > 0 { + if !IsZero(item.IPFrom) { + if !IsZero(item.IPTo) { this.sortedRangeItems = append(this.sortedRangeItems, item) shouldSort = true } else { - this.ipMap[item.IPFrom] = item + this.ipMap[ToHex(item.IPFrom)] = item } } else { this.allItemsMap[item.Id] = item @@ -253,18 +254,18 @@ func (this *IPList) sortRangeItems(force bool) { sort.Slice(this.sortedRangeItems, func(i, j int) bool { var item1 = this.sortedRangeItems[i] var item2 = this.sortedRangeItems[j] - if item1.IPFrom == item2.IPFrom { - return item1.IPTo < item2.IPTo + if CompareBytes(item1.IPFrom, item2.IPFrom) == 0 { + return CompareBytes(item1.IPTo, item2.IPTo) < 0 } - return item1.IPFrom < item2.IPFrom + return CompareBytes(item1.IPFrom, item2.IPFrom) < 0 }) } } // 不加锁的情况下查找Item -func (this *IPList) lookupIP(ip uint64) *IPItem { +func (this *IPList) lookupIP(ipBytes []byte) *IPItem { { - item, ok := this.ipMap[ip] + item, ok := this.ipMap[ToHex(ipBytes)] if ok { return item } @@ -278,12 +279,13 @@ func (this *IPList) lookupIP(ip uint64) *IPItem { var resultIndex = -1 sort.Search(count, func(i int) bool { var item = this.sortedRangeItems[i] - if item.IPFrom < ip { - if item.IPTo >= ip { + var cmp = CompareBytes(item.IPFrom, ipBytes) + if cmp < 0 { + if CompareBytes(item.IPTo, ipBytes) >= 0 { resultIndex = i } return false - } else if item.IPFrom == ip { + } else if cmp == 0 { resultIndex = i return false } @@ -310,10 +312,11 @@ func (this *IPList) deleteItem(itemId uint64) { } // 从ipMap中删除 - if oldItem.IPTo == 0 { - ipItem, ok := this.ipMap[oldItem.IPFrom] + if IsZero(oldItem.IPTo) { + var ipHex = ToHex(oldItem.IPFrom) + ipItem, ok := this.ipMap[ipHex] if ok && ipItem.Id == itemId { - delete(this.ipMap, oldItem.IPFrom) + delete(this.ipMap, ipHex) } } @@ -327,7 +330,7 @@ func (this *IPList) deleteItem(itemId uint64) { } // 删除排序中的Item - if oldItem.IPTo > 0 { + if !IsZero(oldItem.IPTo) { var index = -1 for itemIndex, item := range this.sortedRangeItems { if item.Id == itemId { diff --git a/internal/iplibrary/ip_list_test.go b/internal/iplibrary/ip_list_test.go index f49c0d3..4486168 100644 --- a/internal/iplibrary/ip_list_test.go +++ b/internal/iplibrary/ip_list_test.go @@ -3,7 +3,6 @@ package iplibrary_test import ( "fmt" "github.com/TeaOSLab/EdgeNode/internal/iplibrary" - "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" "github.com/TeaOSLab/EdgeNode/internal/utils/testutils" "github.com/iwind/TeaGo/assert" @@ -19,7 +18,7 @@ import ( ) func TestIPList_Add_Empty(t *testing.T) { - ipList := iplibrary.NewIPList() + var ipList = iplibrary.NewIPList() ipList.Add(&iplibrary.IPItem{ Id: 1, }) @@ -29,31 +28,33 @@ func TestIPList_Add_Empty(t *testing.T) { } func TestIPList_Add_One(t *testing.T) { + var a = assert.NewAssertion(t) + var ipList = iplibrary.NewIPList() ipList.Add(&iplibrary.IPItem{ Id: 1, - IPFrom: utils.IP2LongHash("192.168.1.1"), + IPFrom: iplibrary.IPBytes("192.168.1.1"), }) ipList.Add(&iplibrary.IPItem{ Id: 2, - IPTo: utils.IP2LongHash("192.168.1.2"), + IPTo: iplibrary.IPBytes("192.168.1.2"), }) ipList.Add(&iplibrary.IPItem{ Id: 3, - IPFrom: utils.IP2LongHash("192.168.0.2"), + IPFrom: iplibrary.IPBytes("192.168.0.2"), }) ipList.Add(&iplibrary.IPItem{ Id: 4, - IPFrom: utils.IP2LongHash("192.168.0.2"), - IPTo: utils.IP2LongHash("192.168.0.1"), + IPFrom: iplibrary.IPBytes("192.168.0.2"), + IPTo: iplibrary.IPBytes("192.168.0.1"), }) ipList.Add(&iplibrary.IPItem{ Id: 5, - IPFrom: utils.IP2LongHash("2001:db8:0:1::101"), + IPFrom: iplibrary.IPBytes("2001:db8:0:1::101"), }) ipList.Add(&iplibrary.IPItem{ Id: 6, - IPFrom: 0, + IPFrom: nil, Type: "all", }) t.Log("===items===") @@ -63,6 +64,7 @@ func TestIPList_Add_One(t *testing.T) { logs.PrintAsJSON(ipList.SortedRangeItems(), t) t.Log("===all items===") + a.IsTrue(len(ipList.AllItemsMap()) == 1) logs.PrintAsJSON(ipList.AllItemsMap(), t) // ip => items t.Log("===ip items===") @@ -73,7 +75,7 @@ func TestIPList_Update(t *testing.T) { var ipList = iplibrary.NewIPList() ipList.Add(&iplibrary.IPItem{ Id: 1, - IPFrom: utils.IP2LongHash("192.168.1.1"), + IPFrom: iplibrary.IPBytes("192.168.1.1"), }) t.Log("===before===") @@ -83,12 +85,12 @@ func TestIPList_Update(t *testing.T) { /**ipList.Add(&iplibrary.IPItem{ Id: 2, - IPFrom: utils.IP2LongHash("192.168.1.1"), + IPFrom: iplibrary.IPBytes("192.168.1.1"), })**/ ipList.Add(&iplibrary.IPItem{ Id: 1, //IPFrom: 123, - IPTo: utils.IP2LongHash("192.168.1.2"), + IPTo: iplibrary.IPBytes("192.168.1.2"), }) t.Log("===after===") @@ -102,11 +104,11 @@ func TestIPList_Update_AllItems(t *testing.T) { ipList.Add(&iplibrary.IPItem{ Id: 1, Type: iplibrary.IPItemTypeAll, - IPFrom: 0, + IPFrom: nil, }) ipList.Add(&iplibrary.IPItem{ Id: 1, - IPTo: 0, + IPTo: nil, }) t.Log("===items map===") logs.PrintAsJSON(ipList.ItemsMap(), t) @@ -122,17 +124,17 @@ func TestIPList_Add_Range(t *testing.T) { var ipList = iplibrary.NewIPList() ipList.Add(&iplibrary.IPItem{ Id: 1, - IPFrom: utils.IP2LongHash("192.168.1.1"), - IPTo: utils.IP2LongHash("192.168.2.1"), + IPFrom: iplibrary.IPBytes("192.168.1.1"), + IPTo: iplibrary.IPBytes("192.168.2.1"), }) ipList.Add(&iplibrary.IPItem{ Id: 2, - IPTo: utils.IP2LongHash("192.168.1.2"), + IPTo: iplibrary.IPBytes("192.168.1.2"), }) ipList.Add(&iplibrary.IPItem{ Id: 3, - IPFrom: utils.IP2LongHash("192.168.0.1"), - IPTo: utils.IP2LongHash("192.168.0.2"), + IPFrom: iplibrary.IPBytes("192.168.0.1"), + IPTo: iplibrary.IPBytes("192.168.0.2"), }) a.IsTrue(len(ipList.SortedRangeItems()) == 2) @@ -149,19 +151,6 @@ func TestIPList_Add_Range(t *testing.T) { logs.PrintAsJSON(ipList.IPMap(), t) } -func TestIPList_Add_Overflow(t *testing.T) { - var a = assert.NewAssertion(t) - - var ipList = iplibrary.NewIPList() - ipList.Add(&iplibrary.IPItem{ - Id: 1, - IPFrom: utils.IP2LongHash("192.168.1.1"), - IPTo: utils.IP2LongHash("192.169.255.1"), - }) - t.Log(len(ipList.ItemsMap()), "ips") - a.IsTrue(len(ipList.ItemsMap()) <= 65535) -} - func TestNewIPList_Memory(t *testing.T) { var list = iplibrary.NewIPList() @@ -174,11 +163,12 @@ func TestNewIPList_Memory(t *testing.T) { for i := 0; i < count; i++ { list.AddDelay(&iplibrary.IPItem{ Id: uint64(i), - IPFrom: 1, - IPTo: 2, + IPFrom: iplibrary.IPBytes(testutils.RandIP()), + IPTo: iplibrary.IPBytes(testutils.RandIP()), ExpiredAt: time.Now().Unix(), }) } + list.Sort() runtime.GC() @@ -194,25 +184,25 @@ func TestIPList_Contains(t *testing.T) { for i := 0; i < 255; i++ { list.Add(&iplibrary.IPItem{ Id: uint64(i), - IPFrom: utils.IP2LongHash(strconv.Itoa(i) + ".168.0.1"), - IPTo: utils.IP2LongHash(strconv.Itoa(i) + ".168.255.1"), + IPFrom: iplibrary.IPBytes(strconv.Itoa(i) + ".168.0.1"), + IPTo: iplibrary.IPBytes(strconv.Itoa(i) + ".168.255.1"), ExpiredAt: 0, }) } for i := 0; i < 255; i++ { list.Add(&iplibrary.IPItem{ Id: uint64(1000 + i), - IPFrom: utils.IP2LongHash("192.167.2." + strconv.Itoa(i)), + IPFrom: iplibrary.IPBytes("192.167.2." + strconv.Itoa(i)), }) } t.Log(len(list.ItemsMap()), "ip") var before = time.Now() - a.IsTrue(list.Contains(utils.IP2LongHash("192.168.1.100"))) - a.IsTrue(list.Contains(utils.IP2LongHash("192.168.2.100"))) - a.IsFalse(list.Contains(utils.IP2LongHash("192.169.3.100"))) - a.IsFalse(list.Contains(utils.IP2LongHash("192.167.3.100"))) - a.IsTrue(list.Contains(utils.IP2LongHash("192.167.2.100"))) + a.IsTrue(list.Contains(iplibrary.IPBytes("192.168.1.100"))) + a.IsTrue(list.Contains(iplibrary.IPBytes("192.168.2.100"))) + a.IsFalse(list.Contains(iplibrary.IPBytes("192.169.3.100"))) + a.IsFalse(list.Contains(iplibrary.IPBytes("192.167.3.100"))) + a.IsTrue(list.Contains(iplibrary.IPBytes("192.167.2.100"))) t.Log(time.Since(before).Seconds()*1000, "ms") } @@ -221,20 +211,19 @@ func TestIPList_Contains_Many(t *testing.T) { for i := 0; i < 1_000_000; i++ { list.AddDelay(&iplibrary.IPItem{ Id: uint64(i), - IPFrom: utils.IP2LongHash(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255))), - IPTo: utils.IP2LongHash(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255))), + IPFrom: iplibrary.IPBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255))), + IPTo: iplibrary.IPBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255))), ExpiredAt: 0, }) } - list.Sort() - var before = time.Now() + list.Sort() t.Log("sort cost:", time.Since(before).Seconds()*1000, "ms") t.Log(len(list.ItemsMap()), "ip") before = time.Now() - _ = list.Contains(utils.IP2LongHash("192.168.1.100")) + _ = list.Contains(iplibrary.IPBytes("192.168.1.100")) t.Log("contains cost:", time.Since(before).Seconds()*1000, "ms") } @@ -245,14 +234,14 @@ func TestIPList_ContainsAll(t *testing.T) { list.Add(&iplibrary.IPItem{ Id: 1, Type: "all", - IPFrom: 0, + IPFrom: nil, }) - var b = list.Contains(utils.IP2LongHash("192.168.1.1")) + var b = list.Contains(iplibrary.IPBytes("192.168.1.1")) a.IsTrue(b) list.Delete(1) - b = list.Contains(utils.IP2LongHash("192.168.1.1")) + b = list.Contains(iplibrary.IPBytes("192.168.1.1")) a.IsFalse(b) } @@ -263,8 +252,8 @@ func TestIPList_ContainsIPStrings(t *testing.T) { for i := 0; i < 255; i++ { list.Add(&iplibrary.IPItem{ Id: uint64(i), - IPFrom: utils.IP2LongHash(strconv.Itoa(i) + ".168.0.1"), - IPTo: utils.IP2LongHash(strconv.Itoa(i) + ".168.255.1"), + IPFrom: iplibrary.IPBytes(strconv.Itoa(i) + ".168.0.1"), + IPTo: iplibrary.IPBytes(strconv.Itoa(i) + ".168.255.1"), ExpiredAt: 0, }) } @@ -286,18 +275,18 @@ func TestIPList_Delete(t *testing.T) { var list = iplibrary.NewIPList() list.Add(&iplibrary.IPItem{ Id: 1, - IPFrom: utils.IP2LongHash("192.168.0.1"), + IPFrom: iplibrary.IPBytes("192.168.0.1"), ExpiredAt: 0, }) list.Add(&iplibrary.IPItem{ Id: 2, - IPFrom: utils.IP2LongHash("192.168.0.1"), + IPFrom: iplibrary.IPBytes("192.168.0.1"), ExpiredAt: 0, }) list.Add(&iplibrary.IPItem{ Id: 3, - IPFrom: utils.IP2LongHash("192.168.1.1"), - IPTo: utils.IP2LongHash("192.168.2.1"), + IPFrom: iplibrary.IPBytes("192.168.1.1"), + IPTo: iplibrary.IPBytes("192.168.2.1"), ExpiredAt: 0, }) t.Log("===before===") @@ -349,14 +338,14 @@ func TestIPList_GC(t *testing.T) { var list = iplibrary.NewIPList() list.Add(&iplibrary.IPItem{ Id: 1, - IPFrom: utils.IP2LongHash("192.168.1.100"), - IPTo: utils.IP2LongHash("192.168.1.101"), + IPFrom: iplibrary.IPBytes("192.168.1.100"), + IPTo: iplibrary.IPBytes("192.168.1.101"), ExpiredAt: time.Now().Unix() + 1, }) list.Add(&iplibrary.IPItem{ Id: 2, - IPFrom: utils.IP2LongHash("192.168.1.102"), - IPTo: utils.IP2LongHash("192.168.1.103"), + IPFrom: iplibrary.IPBytes("192.168.1.102"), + IPTo: iplibrary.IPBytes("192.168.1.103"), ExpiredAt: 0, }) logs.PrintAsJSON(list.ItemsMap(), t) @@ -372,7 +361,7 @@ func TestIPList_GC(t *testing.T) { a.IsTrue(len(list.SortedRangeItems()) == 1) } -func TestTooManyLists(t *testing.T) { +func TestManyLists(t *testing.T) { debug.SetMaxThreads(20) var lists = []*iplibrary.IPList{} @@ -397,8 +386,8 @@ func BenchmarkIPList_Add(b *testing.B) { for i := 1; i < 200_000; i++ { list.AddDelay(&iplibrary.IPItem{ Id: uint64(i), - IPFrom: utils.IP2LongHash(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"), - IPTo: utils.IP2LongHash(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"), + IPFrom: iplibrary.IPBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"), + IPTo: iplibrary.IPBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"), ExpiredAt: time.Now().Unix() + 60, }) } @@ -414,8 +403,8 @@ func BenchmarkIPList_Add(b *testing.B) { list.Add(&iplibrary.IPItem{ Type: "", Id: uint64(i % 1_000_000), - IPFrom: utils.IP2LongHash(ip), - IPTo: 0, + IPFrom: iplibrary.IPBytes(ip), + IPTo: nil, ExpiredAt: fasttime.Now().Unix() + 3600, EventLevel: "", }) @@ -429,11 +418,11 @@ func BenchmarkIPList_Contains(b *testing.B) { for i := 1; i < 1_000_000; i++ { var item = &iplibrary.IPItem{ Id: uint64(i), - IPFrom: utils.IP2LongHash(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"), + IPFrom: iplibrary.IPBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"), ExpiredAt: time.Now().Unix() + 60, } - if i%1000 == 0 { - item.IPTo = utils.IP2LongHash(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1") + if i%100 == 0 { + item.IPTo = iplibrary.IPBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1") } list.Add(item) } @@ -443,8 +432,7 @@ func BenchmarkIPList_Contains(b *testing.B) { b.ResetTimer() b.RunParallel(func(pb *testing.PB) { for pb.Next() { - var ip = fmt.Sprintf("%d.%d.%d.%d", rand.Int()%255, rand.Int()%255, rand.Int()%255, rand.Int()%255) - _ = list.Contains(utils.IP2LongHash(ip)) + _ = list.Contains(iplibrary.IPBytes(testutils.RandIP())) } }) } @@ -454,15 +442,15 @@ func BenchmarkIPList_Sort(b *testing.B) { for i := 0; i < 1_000_000; i++ { var item = &iplibrary.IPItem{ Id: uint64(i), - IPFrom: utils.IP2LongHash(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"), + IPFrom: iplibrary.IPBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1"), ExpiredAt: time.Now().Unix() + 60, } if i%100 == 0 { - item.IPTo = utils.IP2LongHash(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1") + item.IPTo = iplibrary.IPBytes(strconv.Itoa(rands.Int(0, 255)) + "." + strconv.Itoa(rands.Int(0, 255)) + ".0.1") } - list.Add(item) + list.AddDelay(item) } b.ResetTimer() diff --git a/internal/iplibrary/list_utils.go b/internal/iplibrary/list_utils.go index 25f66c0..e389d8f 100644 --- a/internal/iplibrary/list_utils.go +++ b/internal/iplibrary/list_utils.go @@ -3,9 +3,11 @@ package iplibrary import ( + "bytes" + "encoding/hex" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" - "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/iwind/TeaGo/Tea" + "net" ) // AllowIP 检查IP是否被允许访问 @@ -24,25 +26,25 @@ func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool, expir } } - var ipLong = utils.IP2LongHash(ip) - if ipLong == 0 { + var ipBytes = IPBytes(ip) + if IsZero(ipBytes) { return false, false, 0 } // check white lists - if GlobalWhiteIPList.Contains(ipLong) { + if GlobalWhiteIPList.Contains(ipBytes) { return true, true, 0 } if serverId > 0 { var list = SharedServerListManager.FindWhiteList(serverId, false) - if list != nil && list.Contains(ipLong) { + if list != nil && list.Contains(ipBytes) { return true, true, 0 } } // check black lists - expiresAt, ok := GlobalBlackIPList.ContainsExpires(ipLong) + expiresAt, ok := GlobalBlackIPList.ContainsExpires(ipBytes) if ok { return false, false, expiresAt } @@ -50,7 +52,7 @@ func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool, expir if serverId > 0 { var list = SharedServerListManager.FindBlackList(serverId, false) if list != nil { - expiresAt, ok = list.ContainsExpires(ipLong) + expiresAt, ok = list.ContainsExpires(ipBytes) if ok { return false, false, expiresAt } @@ -62,13 +64,13 @@ func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool, expir // IsInWhiteList 检查IP是否在白名单中 func IsInWhiteList(ip string) bool { - var ipLong = utils.IP2LongHash(ip) - if ipLong == 0 { + var ipBytes = IPBytes(ip) + if IsZero(ipBytes) { return false } // check white lists - return GlobalWhiteIPList.Contains(ipLong) + return GlobalWhiteIPList.Contains(ipBytes) } // AllowIPStrings 检查一组IP是否被允许访问 @@ -84,3 +86,45 @@ func AllowIPStrings(ipStrings []string, serverId int64) bool { } return true } + +func IsZero(ipBytes []byte) bool { + return len(ipBytes) == 0 +} + +func CompareBytes(b1 []byte, b2 []byte) int { + var l1 = len(b1) + var l2 = len(b2) + if l1 < l2 { + return -1 + } + if l1 > l2 { + return 1 + } + return bytes.Compare(b1, b2) +} + +func IPBytes(ip string) []byte { + if len(ip) == 0 { + return nil + } + + var i = net.ParseIP(ip) + if i == nil { + return nil + } + + var i4 = i.To4() + if i4 != nil { + return i4 + } + + return i.To16() +} + +func ToHex(b []byte) string { + if len(b) == 0 { + return "" + } + + return hex.EncodeToString(b) +} diff --git a/internal/iplibrary/list_utils_test.go b/internal/iplibrary/list_utils_test.go index e21af44..0e91ba2 100644 --- a/internal/iplibrary/list_utils_test.go +++ b/internal/iplibrary/list_utils_test.go @@ -14,7 +14,7 @@ func TestIPIsAllowed(t *testing.T) { } var manager = NewIPListManager() - manager.init() + manager.Init() var before = time.Now() defer func() { diff --git a/internal/iplibrary/manager_ip_list.go b/internal/iplibrary/manager_ip_list.go index c6dffb9..922962c 100644 --- a/internal/iplibrary/manager_ip_list.go +++ b/internal/iplibrary/manager_ip_list.go @@ -9,7 +9,6 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/rpc" "github.com/TeaOSLab/EdgeNode/internal/trackers" - "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/zero" "github.com/iwind/TeaGo/Tea" @@ -68,10 +67,10 @@ func NewIPListManager() *IPListManager { } func (this *IPListManager) Start() { - this.init() + this.Init() // 第一次读取 - err := this.loop() + err := this.Loop() if err != nil { remotelogs.ErrorObject("IP_LIST_MANAGER", err) } @@ -86,7 +85,7 @@ func (this *IPListManager) Start() { case <-this.ticker.C: case <-IPListUpdateNotify: } - err = this.loop() + err = this.Loop() if err != nil { countErrors++ @@ -111,7 +110,7 @@ func (this *IPListManager) Stop() { } } -func (this *IPListManager) init() { +func (this *IPListManager) Init() { // 从数据库中当中读取数据 // 检查sqlite文件是否存在,以便决定使用sqlite还是kv var sqlitePath = Tea.Root + "/data/ip_list.db" @@ -164,7 +163,7 @@ func (this *IPListManager) init() { } } -func (this *IPListManager) loop() error { +func (this *IPListManager) Loop() error { // 是否同步IP名单 nodeConfig, _ := nodeconfigs.SharedNodeConfig() if nodeConfig != nil && !nodeConfig.EnableIPLists { @@ -245,6 +244,10 @@ func (this *IPListManager) DeleteExpiredItems() { } } +func (this *IPListManager) ListMap() map[int64]*IPList { + return this.listMap +} + // 处理IP条目 func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) { var changedLists = map[*IPList]zero.Zero{} @@ -301,8 +304,8 @@ func (this *IPListManager) processItems(items []*pb.IPItem, fromRemote bool) { list.AddDelay(&IPItem{ Id: uint64(item.Id), Type: item.Type, - IPFrom: utils.IP2LongHash(item.IpFrom), - IPTo: utils.IP2LongHash(item.IpTo), + IPFrom: IPBytes(item.IpFrom), + IPTo: IPBytes(item.IpTo), ExpiredAt: item.ExpiredAt, EventLevel: item.EventLevel, }) diff --git a/internal/iplibrary/manager_ip_list_test.go b/internal/iplibrary/manager_ip_list_test.go index b596825..5469ef0 100644 --- a/internal/iplibrary/manager_ip_list_test.go +++ b/internal/iplibrary/manager_ip_list_test.go @@ -1,7 +1,7 @@ -package iplibrary +package iplibrary_test import ( - "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/utils/testutils" "github.com/iwind/TeaGo/logs" "testing" @@ -13,11 +13,11 @@ func TestIPListManager_init(t *testing.T) { return } - var manager = NewIPListManager() - manager.init() - t.Log(manager.listMap) - t.Log(SharedServerListManager.blackMap) - logs.PrintAsJSON(GlobalBlackIPList.SortedRangeItems(), t) + var manager = iplibrary.NewIPListManager() + manager.Init() + t.Log(manager.ListMap()) + t.Log(iplibrary.SharedServerListManager.BlackMap()) + logs.PrintAsJSON(iplibrary.GlobalBlackIPList.SortedRangeItems(), t) } func TestIPListManager_check(t *testing.T) { @@ -25,15 +25,15 @@ func TestIPListManager_check(t *testing.T) { return } - var manager = NewIPListManager() - manager.init() + var manager = iplibrary.NewIPListManager() + manager.Init() var before = time.Now() defer func() { t.Log(time.Since(before).Seconds()*1000, "ms") }() - t.Log(SharedServerListManager.FindBlackList(23, true).Contains(utils.IP2LongHash("127.0.0.2"))) - t.Log(GlobalBlackIPList.Contains(utils.IP2LongHash("127.0.0.6"))) + t.Log(iplibrary.SharedServerListManager.FindBlackList(23, true).Contains(iplibrary.IPBytes("127.0.0.2"))) + t.Log(iplibrary.GlobalBlackIPList.Contains(iplibrary.IPBytes("127.0.0.6"))) } func TestIPListManager_loop(t *testing.T) { @@ -41,9 +41,9 @@ func TestIPListManager_loop(t *testing.T) { return } - var manager = NewIPListManager() + var manager = iplibrary.NewIPListManager() manager.Start() - err := manager.loop() + err := manager.Loop() if err != nil { t.Fatal(err) } diff --git a/internal/iplibrary/server_list_manager.go b/internal/iplibrary/server_list_manager.go index 46719ff..67731d6 100644 --- a/internal/iplibrary/server_list_manager.go +++ b/internal/iplibrary/server_list_manager.go @@ -59,3 +59,7 @@ func (this *ServerListManager) FindBlackList(serverId int64, autoCreate bool) *I return nil } + +func (this *ServerListManager) BlackMap() map[int64]*IPList { + return this.blackMap +} diff --git a/internal/utils/ip.go b/internal/utils/ip.go index fb332d6..13f3146 100644 --- a/internal/utils/ip.go +++ b/internal/utils/ip.go @@ -1,30 +1,10 @@ package utils import ( - "encoding/binary" - "github.com/cespare/xxhash" - "math" "net" "strings" ) -// IP2LongHash 非标地将IP转换为整型 -// 注意IPv6没有顺序 -func IP2LongHash(ip string) uint64 { - if len(ip) == 0 { - return 0 - } - s := net.ParseIP(ip) - if len(s) == 0 { - return 0 - } - - if strings.Contains(ip, ":") { - return math.MaxUint32 + xxhash.Sum64(s) - } - return uint64(binary.BigEndian.Uint32(s.To4())) -} - // IsLocalIP 判断是否为本地IP func IsLocalIP(ipString string) bool { var ip = net.ParseIP(ipString) diff --git a/internal/utils/ip_test.go b/internal/utils/ip_test.go index e056be5..536cc85 100644 --- a/internal/utils/ip_test.go +++ b/internal/utils/ip_test.go @@ -6,15 +6,6 @@ import ( "testing" ) -func TestIP2Long(t *testing.T) { - t.Log(utils.IP2LongHash("0.0.0.0")) - t.Log(utils.IP2LongHash("1.0.0.0")) - t.Log(utils.IP2LongHash("0.0.0.0.0")) - t.Log(utils.IP2LongHash("2001:db8:0:1::101")) - t.Log(utils.IP2LongHash("2001:db8:0:1::102")) - t.Log(utils.IP2LongHash("::1")) -} - func TestIsLocalIP(t *testing.T) { var a = assert.NewAssertion(t) a.IsFalse(utils.IsLocalIP("a")) diff --git a/internal/utils/version.go b/internal/utils/version.go index a0aff4c..aaafa08 100644 --- a/internal/utils/version.go +++ b/internal/utils/version.go @@ -1,7 +1,8 @@ package utils import ( - "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + "encoding/binary" + "net" "strings" ) @@ -15,5 +16,9 @@ func VersionToLong(version string) uint32 { } else if countDots == 0 { version += ".0.0.0" } - return uint32(configutils.IPString2Long(version)) + var ip = net.ParseIP(version) + if ip == nil || ip.To4() == nil { + return 0 + } + return binary.BigEndian.Uint32(ip.To4()) }