diff --git a/internal/iplibrary/ip_list.go b/internal/iplibrary/ip_list.go index 15786ae..6ecd99a 100644 --- a/internal/iplibrary/ip_list.go +++ b/internal/iplibrary/ip_list.go @@ -221,15 +221,15 @@ func (this *IPList) addItem(item *IPItem, lock bool, sortable bool) { this.itemsMap[item.Id] = item // 展开 - if !IsZero(item.IPFrom) { + if item.Type == IPItemTypeAll { + this.allItemsMap[item.Id] = item + } else if !IsZero(item.IPFrom) { if !IsZero(item.IPTo) { this.sortedRangeItems = append(this.sortedRangeItems, item) shouldSort = true } else { this.ipMap[ToHex(item.IPFrom)] = item } - } else { - this.allItemsMap[item.Id] = item } if item.ExpiredAt > 0 { @@ -310,6 +310,12 @@ func (this *IPList) deleteItem(itemId uint64) { // 从buffer中删除 delete(this.bufferItemsMap, itemId) + // 从all items中删除 + _, ok := this.allItemsMap[itemId] + if ok { + delete(this.allItemsMap, itemId) + } + // 检查是否存在 oldItem, existsOld := this.itemsMap[itemId] if !existsOld { @@ -327,13 +333,6 @@ func (this *IPList) deleteItem(itemId uint64) { delete(this.itemsMap, itemId) - // 是否为All Item - _, ok := this.allItemsMap[itemId] - if ok { - delete(this.allItemsMap, itemId) - return - } - // 删除排序中的Item if !IsZero(oldItem.IPTo) { var index = -1 diff --git a/internal/iplibrary/ip_list_test.go b/internal/iplibrary/ip_list_test.go index 5475e28..af39d58 100644 --- a/internal/iplibrary/ip_list_test.go +++ b/internal/iplibrary/ip_list_test.go @@ -245,19 +245,37 @@ func TestIPList_Contains_Many(t *testing.T) { func TestIPList_ContainsAll(t *testing.T) { var a = assert.NewAssertion(t) - var list = iplibrary.NewIPList() - list.Add(&iplibrary.IPItem{ - Id: 1, - Type: "all", - IPFrom: nil, - }) - var b = list.Contains(iputils.ToBytes("192.168.1.1")) - a.IsTrue(b) + { + var list = iplibrary.NewIPList() + list.Add(&iplibrary.IPItem{ + Id: 1, + Type: "all", + IPFrom: nil, + }) + var b = list.Contains(iputils.ToBytes("192.168.1.1")) + a.IsTrue(b) - list.Delete(1) + list.Delete(1) - b = list.Contains(iputils.ToBytes("192.168.1.1")) - a.IsFalse(b) + b = list.Contains(iputils.ToBytes("192.168.1.1")) + a.IsFalse(b) + } + + { + var list = iplibrary.NewIPList() + list.Add(&iplibrary.IPItem{ + Id: 1, + Type: "all", + IPFrom: iputils.ToBytes("0.0.0.0"), + }) + var b = list.Contains(iputils.ToBytes("192.168.1.1")) + a.IsTrue(b) + + list.Delete(1) + + b = list.Contains(iputils.ToBytes("192.168.1.1")) + a.IsFalse(b) + } } func TestIPList_ContainsIPStrings(t *testing.T) {