diff --git a/internal/iplibrary/ip_item.go b/internal/iplibrary/ip_item.go index 257fd00..8bb0c74 100644 --- a/internal/iplibrary/ip_item.go +++ b/internal/iplibrary/ip_item.go @@ -2,16 +2,39 @@ package iplibrary import "github.com/TeaOSLab/EdgeNode/internal/utils" +type IPItemType = string + +const ( + IPItemTypeIPv4 IPItemType = "ipv4" // IPv4 + IPItemTypeIPv6 IPItemType = "ipv6" // IPv6 + IPItemTypeAll IPItemType = "all" // 所有IP +) + // IP条目 type IPItem struct { + Type string Id int64 - IPFrom uint32 - IPTo uint32 + IPFrom uint64 + IPTo uint64 ExpiredAt int64 } // 检查是否包含某个IP -func (this *IPItem) Contains(ip uint32) bool { +func (this *IPItem) Contains(ip uint64) bool { + switch this.Type { + case IPItemTypeIPv4: + return this.containsIPv4(ip) + case IPItemTypeIPv6: + return this.containsIPv6(ip) + case IPItemTypeAll: + return this.containsAll(ip) + default: + return this.containsIPv4(ip) + } +} + +// 检查是否包含某个IPv4 +func (this *IPItem) containsIPv4(ip uint64) bool { if this.IPTo == 0 { if this.IPFrom != ip { return false @@ -26,3 +49,22 @@ func (this *IPItem) Contains(ip uint32) bool { } return true } + +// 检查是否包含某个IPv6 +func (this *IPItem) containsIPv6(ip uint64) bool { + if this.IPFrom != ip { + return false + } + if this.ExpiredAt > 0 && this.ExpiredAt < utils.UnixTime() { + return false + } + return true +} + +// 检查是否包所有IP +func (this *IPItem) containsAll(ip uint64) bool { + if this.ExpiredAt > 0 && this.ExpiredAt < utils.UnixTime() { + return false + } + return true +} diff --git a/internal/iplibrary/ip_item_test.go b/internal/iplibrary/ip_item_test.go index 632cb0a..3597a9c 100644 --- a/internal/iplibrary/ip_item_test.go +++ b/internal/iplibrary/ip_item_test.go @@ -1,6 +1,7 @@ package iplibrary import ( + "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/iwind/TeaGo/assert" "testing" "time" @@ -11,63 +12,63 @@ func TestIPItem_Contains(t *testing.T) { { item := &IPItem{ - IPFrom: IP2Long("192.168.1.100"), + IPFrom: utils.IP2Long("192.168.1.100"), IPTo: 0, ExpiredAt: 0, } - a.IsTrue(item.Contains(IP2Long("192.168.1.100"))) + a.IsTrue(item.Contains(utils.IP2Long("192.168.1.100"))) } { item := &IPItem{ - IPFrom: IP2Long("192.168.1.100"), + IPFrom: utils.IP2Long("192.168.1.100"), IPTo: 0, ExpiredAt: time.Now().Unix() + 1, } - a.IsTrue(item.Contains(IP2Long("192.168.1.100"))) + a.IsTrue(item.Contains(utils.IP2Long("192.168.1.100"))) } { item := &IPItem{ - IPFrom: IP2Long("192.168.1.100"), + IPFrom: utils.IP2Long("192.168.1.100"), IPTo: 0, ExpiredAt: time.Now().Unix() - 1, } - a.IsFalse(item.Contains(IP2Long("192.168.1.100"))) + a.IsFalse(item.Contains(utils.IP2Long("192.168.1.100"))) } { item := &IPItem{ - IPFrom: IP2Long("192.168.1.100"), + IPFrom: utils.IP2Long("192.168.1.100"), IPTo: 0, ExpiredAt: 0, } - a.IsFalse(item.Contains(IP2Long("192.168.1.101"))) + a.IsFalse(item.Contains(utils.IP2Long("192.168.1.101"))) } { item := &IPItem{ - IPFrom: IP2Long("192.168.1.1"), - IPTo: IP2Long("192.168.1.101"), + IPFrom: utils.IP2Long("192.168.1.1"), + IPTo: utils.IP2Long("192.168.1.101"), ExpiredAt: 0, } - a.IsTrue(item.Contains(IP2Long("192.168.1.100"))) + a.IsTrue(item.Contains(utils.IP2Long("192.168.1.100"))) } { item := &IPItem{ - IPFrom: IP2Long("192.168.1.1"), - IPTo: IP2Long("192.168.1.100"), + IPFrom: utils.IP2Long("192.168.1.1"), + IPTo: utils.IP2Long("192.168.1.100"), ExpiredAt: 0, } - a.IsTrue(item.Contains(IP2Long("192.168.1.100"))) + a.IsTrue(item.Contains(utils.IP2Long("192.168.1.100"))) } { item := &IPItem{ - IPFrom: IP2Long("192.168.1.1"), - IPTo: IP2Long("192.168.1.101"), + IPFrom: utils.IP2Long("192.168.1.1"), + IPTo: utils.IP2Long("192.168.1.101"), ExpiredAt: 0, } - a.IsTrue(item.Contains(IP2Long("192.168.1.1"))) + a.IsTrue(item.Contains(utils.IP2Long("192.168.1.1"))) } } diff --git a/internal/iplibrary/ip_list.go b/internal/iplibrary/ip_list.go index b402190..3446f4c 100644 --- a/internal/iplibrary/ip_list.go +++ b/internal/iplibrary/ip_list.go @@ -8,16 +8,18 @@ import ( // IP名单 type IPList struct { itemsMap map[int64]*IPItem // id => item - ipMap map[uint32][]int64 // ip => itemIds + ipMap map[uint64][]int64 // ip => itemIds expireList *expires.List + isAll bool + locker sync.RWMutex } func NewIPList() *IPList { list := &IPList{ itemsMap: map[int64]*IPItem{}, - ipMap: map[uint32][]int64{}, + ipMap: map[uint64][]int64{}, } expireList := expires.NewList() @@ -31,10 +33,16 @@ func NewIPList() *IPList { } func (this *IPList) Add(item *IPItem) { - if item == nil || (item.IPFrom == 0 && item.IPTo == 0) { + if item == nil { return } + if item.IPFrom == 0 && item.IPTo == 0 { + if item.Type != "all" { + return + } + } + this.locker.Lock() // 是否已经存在 @@ -64,6 +72,11 @@ func (this *IPList) Add(item *IPItem) { } } else if item.IPTo > 0 { this.addIP(item.IPTo, item.Id) + } else { + this.addIP(0, item.Id) + + // 更新isAll + this.isAll = true } if item.ExpiredAt > 0 { @@ -77,11 +90,18 @@ func (this *IPList) Delete(itemId int64) { this.locker.Lock() defer this.locker.Unlock() this.deleteItem(itemId) + + // 更新isAll + this.isAll = len(this.ipMap[0]) > 0 } // 判断是否包含某个IP -func (this *IPList) Contains(ip uint32) bool { +func (this *IPList) Contains(ip uint64) bool { this.locker.RLock() + if this.isAll { + this.locker.RUnlock() + return true + } _, ok := this.ipMap[ip] this.locker.RUnlock() @@ -117,11 +137,13 @@ func (this *IPList) deleteItem(itemId int64) { } } else if item.IPTo > 0 { this.deleteIP(item.IPTo, item.Id) + } else { + this.deleteIP(0, item.Id) } } // 添加单个IP -func (this *IPList) addIP(ip uint32, itemId int64) { +func (this *IPList) addIP(ip uint64, itemId int64) { itemIds, ok := this.ipMap[ip] if ok { itemIds = append(itemIds, itemId) @@ -132,7 +154,7 @@ func (this *IPList) addIP(ip uint32, itemId int64) { } // 删除单个IP -func (this *IPList) deleteIP(ip uint32, itemId int64) { +func (this *IPList) deleteIP(ip uint64, itemId int64) { itemIds, ok := this.ipMap[ip] if !ok { return diff --git a/internal/iplibrary/ip_list_test.go b/internal/iplibrary/ip_list_test.go index d0d67e8..82a3dcc 100644 --- a/internal/iplibrary/ip_list_test.go +++ b/internal/iplibrary/ip_list_test.go @@ -1,6 +1,7 @@ package iplibrary import ( + "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/logs" "runtime" @@ -22,21 +23,30 @@ func TestIPList_Add_One(t *testing.T) { ipList := NewIPList() ipList.Add(&IPItem{ Id: 1, - IPFrom: IP2Long("192.168.1.1"), + IPFrom: utils.IP2Long("192.168.1.1"), }) ipList.Add(&IPItem{ Id: 2, - IPTo: IP2Long("192.168.1.2"), + IPTo: utils.IP2Long("192.168.1.2"), + }) + ipList.Add(&IPItem{ + Id: 3, + IPFrom: utils.IP2Long("2001:db8:0:1::101"), + }) + ipList.Add(&IPItem{ + Id: 4, + IPFrom: 0, + Type: "all", }) logs.PrintAsJSON(ipList.itemsMap, t) - logs.PrintAsJSON(ipList.ipMap, t) + logs.PrintAsJSON(ipList.ipMap, t) // ip => items } func TestIPList_Update(t *testing.T) { ipList := NewIPList() ipList.Add(&IPItem{ Id: 1, - IPFrom: IP2Long("192.168.1.1"), + IPFrom: utils.IP2Long("192.168.1.1"), }) /**ipList.Add(&IPItem{ Id: 2, @@ -44,7 +54,7 @@ func TestIPList_Update(t *testing.T) { })**/ ipList.Add(&IPItem{ Id: 1, - IPTo: IP2Long("192.168.1.2"), + IPTo: utils.IP2Long("192.168.1.2"), }) logs.PrintAsJSON(ipList.itemsMap, t) logs.PrintAsJSON(ipList.ipMap, t) @@ -54,12 +64,12 @@ func TestIPList_Add_Range(t *testing.T) { ipList := NewIPList() ipList.Add(&IPItem{ Id: 1, - IPFrom: IP2Long("192.168.1.1"), - IPTo: IP2Long("192.168.2.1"), + IPFrom: utils.IP2Long("192.168.1.1"), + IPTo: utils.IP2Long("192.168.2.1"), }) ipList.Add(&IPItem{ Id: 2, - IPTo: IP2Long("192.168.1.2"), + IPTo: utils.IP2Long("192.168.1.2"), }) t.Log(len(ipList.ipMap), "ips") logs.PrintAsJSON(ipList.itemsMap, t) @@ -72,8 +82,8 @@ func TestIPList_Add_Overflow(t *testing.T) { ipList := NewIPList() ipList.Add(&IPItem{ Id: 1, - IPFrom: IP2Long("192.168.1.1"), - IPTo: IP2Long("192.169.255.1"), + IPFrom: utils.IP2Long("192.168.1.1"), + IPTo: utils.IP2Long("192.169.255.1"), }) t.Log(len(ipList.ipMap), "ips") a.IsTrue(len(ipList.ipMap) <= 65535) @@ -98,29 +108,54 @@ func TestIPList_Contains(t *testing.T) { for i := 0; i < 255; i++ { list.Add(&IPItem{ Id: int64(i), - IPFrom: IP2Long(strconv.Itoa(i) + ".168.0.1"), - IPTo: IP2Long(strconv.Itoa(i) + ".168.255.1"), + IPFrom: utils.IP2Long(strconv.Itoa(i) + ".168.0.1"), + IPTo: utils.IP2Long(strconv.Itoa(i) + ".168.255.1"), ExpiredAt: 0, }) } - t.Log(len(list.ipMap)) + t.Log(len(list.ipMap), "ip") before := time.Now() - t.Log(list.Contains(IP2Long("192.168.1.100"))) - t.Log(list.Contains(IP2Long("192.168.2.100"))) + t.Log(list.Contains(utils.IP2Long("192.168.1.100"))) + t.Log(list.Contains(utils.IP2Long("192.168.2.100"))) t.Log(time.Since(before).Seconds()*1000, "ms") } +func TestIPList_ContainsAll(t *testing.T) { + list := NewIPList() + list.Add(&IPItem{ + Id: 1, + Type: "all", + IPFrom: 0, + }) + b := list.Contains(utils.IP2Long("192.168.1.1")) + if b { + t.Log(b) + } else { + t.Fatal("'b' should be true") + } + + list.Delete(1) + + b = list.Contains(utils.IP2Long("192.168.1.1")) + if !b { + t.Log(b) + } else { + t.Fatal("'b' should be false") + } + +} + func TestIPList_Delete(t *testing.T) { list := NewIPList() list.Add(&IPItem{ Id: 1, - IPFrom: IP2Long("192.168.0.1"), + IPFrom: utils.IP2Long("192.168.0.1"), ExpiredAt: 0, }) list.Add(&IPItem{ Id: 2, - IPFrom: IP2Long("192.168.0.1"), + IPFrom: utils.IP2Long("192.168.0.1"), ExpiredAt: 0, }) t.Log("===BEFORE===") @@ -138,14 +173,14 @@ func TestGC(t *testing.T) { list := NewIPList() list.Add(&IPItem{ Id: 1, - IPFrom: IP2Long("192.168.1.100"), - IPTo: IP2Long("192.168.1.101"), + IPFrom: utils.IP2Long("192.168.1.100"), + IPTo: utils.IP2Long("192.168.1.101"), ExpiredAt: time.Now().Unix() + 1, }) list.Add(&IPItem{ Id: 2, - IPFrom: IP2Long("192.168.1.102"), - IPTo: IP2Long("192.168.1.103"), + IPFrom: utils.IP2Long("192.168.1.102"), + IPTo: utils.IP2Long("192.168.1.103"), ExpiredAt: 0, }) logs.PrintAsJSON(list.itemsMap, t) @@ -164,13 +199,13 @@ func BenchmarkIPList_Contains(b *testing.B) { for i := 192; i < 194; i++ { list.Add(&IPItem{ Id: int64(1), - IPFrom: IP2Long(strconv.Itoa(i) + ".1.0.1"), - IPTo: IP2Long(strconv.Itoa(i) + ".2.0.1"), + IPFrom: utils.IP2Long(strconv.Itoa(i) + ".1.0.1"), + IPTo: utils.IP2Long(strconv.Itoa(i) + ".2.0.1"), ExpiredAt: time.Now().Unix() + 60, }) } b.Log(len(list.ipMap), "ip") for i := 0; i < b.N; i++ { - _ = list.Contains(IP2Long("192.168.1.100")) + _ = list.Contains(utils.IP2Long("192.168.1.100")) } } diff --git a/internal/iplibrary/ip_utils.go b/internal/iplibrary/ip_utils.go deleted file mode 100644 index 16c7dd6..0000000 --- a/internal/iplibrary/ip_utils.go +++ /dev/null @@ -1,19 +0,0 @@ -package iplibrary - -import ( - "encoding/binary" - "net" -) - -// 将IP转换为整型 -func IP2Long(ip string) uint32 { - s := net.ParseIP(ip) - if s == nil { - return 0 - } - - if len(s) == 16 { - return binary.BigEndian.Uint32(s[12:16]) - } - return binary.BigEndian.Uint32(s) -} diff --git a/internal/iplibrary/ip_utils_test.go b/internal/iplibrary/ip_utils_test.go deleted file mode 100644 index da7007b..0000000 --- a/internal/iplibrary/ip_utils_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package iplibrary - -import ( - "runtime" - "testing" -) - -func TestIP2Long(t *testing.T) { - t.Log(IP2Long("192.168.1.100")) - t.Log(IP2Long("192.168.1.101")) - t.Log(IP2Long("202.106.0.20")) - t.Log(IP2Long("192.168.1")) // wrong ip, should return 0 -} - -func BenchmarkIP2Long(b *testing.B) { - runtime.GOMAXPROCS(1) - - for i := 0; i < b.N; i++ { - _ = IP2Long("192.168.1.100") - } -} diff --git a/internal/iplibrary/library_ip2region.go b/internal/iplibrary/library_ip2region.go index 2f20f98..cba9715 100644 --- a/internal/iplibrary/library_ip2region.go +++ b/internal/iplibrary/library_ip2region.go @@ -3,8 +3,9 @@ package iplibrary import ( "fmt" "github.com/TeaOSLab/EdgeNode/internal/errors" - "github.com/iwind/TeaGo/logs" + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/lionsoul2014/ip2region/binding/golang/ip2region" + "strings" ) type IP2RegionLibrary struct { @@ -22,6 +23,11 @@ func (this *IP2RegionLibrary) Load(dbPath string) error { } func (this *IP2RegionLibrary) Lookup(ip string) (*Result, error) { + // 暂不支持IPv6 + if strings.Contains(ip, ":") { + return nil, nil + } + if this.db == nil { return nil, errors.New("library has not been loaded") } @@ -30,7 +36,7 @@ func (this *IP2RegionLibrary) Lookup(ip string) (*Result, error) { // 防止panic发生 err := recover() if err != nil { - logs.Println("[IP2RegionLibrary]panic: " + fmt.Sprintf("%#v", err)) + remotelogs.Error("IP2RegionLibrary", "panic: "+fmt.Sprintf("%#v", err)) } }() diff --git a/internal/iplibrary/manager_ip_list.go b/internal/iplibrary/manager_ip_list.go index 017cf97..a89dc3d 100644 --- a/internal/iplibrary/manager_ip_list.go +++ b/internal/iplibrary/manager_ip_list.go @@ -5,6 +5,7 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/events" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/rpc" + "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/iwind/TeaGo/Tea" "sync" "time" @@ -123,8 +124,9 @@ func (this *IPListManager) fetch() (hasNext bool, err error) { } list.Add(&IPItem{ Id: item.Id, - IPFrom: IP2Long(item.IpFrom), - IPTo: IP2Long(item.IpTo), + Type: item.Type, + IPFrom: utils.IP2Long(item.IpFrom), + IPTo: utils.IP2Long(item.IpTo), ExpiredAt: item.ExpiredAt, }) } diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 5009a31..9c92db6 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -5,6 +5,7 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/stats" + "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" @@ -49,7 +50,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir inbound := firewallPolicy.Inbound if inbound.AllowListRef != nil && inbound.AllowListRef.IsOn && inbound.AllowListRef.ListId > 0 { list := iplibrary.SharedIPListManager.FindList(inbound.AllowListRef.ListId) - if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) { + if list != nil && list.Contains(utils.IP2Long(remoteAddr)) { breakChecking = true return } @@ -58,7 +59,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir // 检查IP黑名单 if inbound.DenyListRef != nil && inbound.DenyListRef.IsOn && inbound.DenyListRef.ListId > 0 { list := iplibrary.SharedIPListManager.FindList(inbound.DenyListRef.ListId) - if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) { + if list != nil && list.Contains(utils.IP2Long(remoteAddr)) { // TODO 可以配置对封禁的处理方式等 // TODO 需要记录日志信息 this.writer.WriteHeader(http.StatusForbidden) diff --git a/internal/utils/ip.go b/internal/utils/ip.go index 5810520..4ef80a4 100644 --- a/internal/utils/ip.go +++ b/internal/utils/ip.go @@ -2,18 +2,22 @@ package utils import ( "encoding/binary" + "github.com/cespare/xxhash" + "math" "net" + "strings" ) // 将IP转换为整型 -func IP2Long(ip string) uint32 { +// 注意IPv6没有顺序 +func IP2Long(ip string) uint64 { s := net.ParseIP(ip) if s == nil { return 0 } - if len(s) == 16 { - return binary.BigEndian.Uint32(s[12:16]) + if strings.Contains(ip, ":") { + return math.MaxUint32 + xxhash.Sum64String(ip) } - return binary.BigEndian.Uint32(s) + return uint64(binary.BigEndian.Uint32(s.To4())) } diff --git a/internal/utils/ip_test.go b/internal/utils/ip_test.go index 132b4aa..baca495 100644 --- a/internal/utils/ip_test.go +++ b/internal/utils/ip_test.go @@ -6,4 +6,6 @@ func TestIP2Long(t *testing.T) { t.Log(IP2Long("0.0.0.0")) t.Log(IP2Long("1.0.0.0")) t.Log(IP2Long("0.0.0.0.0")) + t.Log(IP2Long("2001:db8:0:1::101")) + t.Log(IP2Long("2001:db8:0:1::102")) } diff --git a/internal/utils/version.go b/internal/utils/version.go index 8940a0c..5742711 100644 --- a/internal/utils/version.go +++ b/internal/utils/version.go @@ -14,5 +14,5 @@ func VersionToLong(version string) uint32 { } else if countDots == 0 { version += ".0.0.0" } - return IP2Long(version) + return uint32(IP2Long(version)) }