IP名单新增IPv6和所有IP两种类型

This commit is contained in:
GoEdgeLab
2021-02-02 15:26:00 +08:00
parent 992d0560b6
commit 9a679c6bc1
12 changed files with 176 additions and 101 deletions

View File

@@ -2,16 +2,39 @@ package iplibrary
import "github.com/TeaOSLab/EdgeNode/internal/utils"
type IPItemType = string
const (
IPItemTypeIPv4 IPItemType = "ipv4" // IPv4
IPItemTypeIPv6 IPItemType = "ipv6" // IPv6
IPItemTypeAll IPItemType = "all" // 所有IP
)
// IP条目
type IPItem struct {
Type string
Id int64
IPFrom uint32
IPTo uint32
IPFrom uint64
IPTo uint64
ExpiredAt int64
}
// 检查是否包含某个IP
func (this *IPItem) Contains(ip uint32) bool {
func (this *IPItem) Contains(ip uint64) bool {
switch this.Type {
case IPItemTypeIPv4:
return this.containsIPv4(ip)
case IPItemTypeIPv6:
return this.containsIPv6(ip)
case IPItemTypeAll:
return this.containsAll(ip)
default:
return this.containsIPv4(ip)
}
}
// 检查是否包含某个IPv4
func (this *IPItem) containsIPv4(ip uint64) bool {
if this.IPTo == 0 {
if this.IPFrom != ip {
return false
@@ -26,3 +49,22 @@ func (this *IPItem) Contains(ip uint32) bool {
}
return true
}
// 检查是否包含某个IPv6
func (this *IPItem) containsIPv6(ip uint64) bool {
if this.IPFrom != ip {
return false
}
if this.ExpiredAt > 0 && this.ExpiredAt < utils.UnixTime() {
return false
}
return true
}
// 检查是否包所有IP
func (this *IPItem) containsAll(ip uint64) bool {
if this.ExpiredAt > 0 && this.ExpiredAt < utils.UnixTime() {
return false
}
return true
}

View File

@@ -1,6 +1,7 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/assert"
"testing"
"time"
@@ -11,63 +12,63 @@ func TestIPItem_Contains(t *testing.T) {
{
item := &IPItem{
IPFrom: IP2Long("192.168.1.100"),
IPFrom: utils.IP2Long("192.168.1.100"),
IPTo: 0,
ExpiredAt: 0,
}
a.IsTrue(item.Contains(IP2Long("192.168.1.100")))
a.IsTrue(item.Contains(utils.IP2Long("192.168.1.100")))
}
{
item := &IPItem{
IPFrom: IP2Long("192.168.1.100"),
IPFrom: utils.IP2Long("192.168.1.100"),
IPTo: 0,
ExpiredAt: time.Now().Unix() + 1,
}
a.IsTrue(item.Contains(IP2Long("192.168.1.100")))
a.IsTrue(item.Contains(utils.IP2Long("192.168.1.100")))
}
{
item := &IPItem{
IPFrom: IP2Long("192.168.1.100"),
IPFrom: utils.IP2Long("192.168.1.100"),
IPTo: 0,
ExpiredAt: time.Now().Unix() - 1,
}
a.IsFalse(item.Contains(IP2Long("192.168.1.100")))
a.IsFalse(item.Contains(utils.IP2Long("192.168.1.100")))
}
{
item := &IPItem{
IPFrom: IP2Long("192.168.1.100"),
IPFrom: utils.IP2Long("192.168.1.100"),
IPTo: 0,
ExpiredAt: 0,
}
a.IsFalse(item.Contains(IP2Long("192.168.1.101")))
a.IsFalse(item.Contains(utils.IP2Long("192.168.1.101")))
}
{
item := &IPItem{
IPFrom: IP2Long("192.168.1.1"),
IPTo: IP2Long("192.168.1.101"),
IPFrom: utils.IP2Long("192.168.1.1"),
IPTo: utils.IP2Long("192.168.1.101"),
ExpiredAt: 0,
}
a.IsTrue(item.Contains(IP2Long("192.168.1.100")))
a.IsTrue(item.Contains(utils.IP2Long("192.168.1.100")))
}
{
item := &IPItem{
IPFrom: IP2Long("192.168.1.1"),
IPTo: IP2Long("192.168.1.100"),
IPFrom: utils.IP2Long("192.168.1.1"),
IPTo: utils.IP2Long("192.168.1.100"),
ExpiredAt: 0,
}
a.IsTrue(item.Contains(IP2Long("192.168.1.100")))
a.IsTrue(item.Contains(utils.IP2Long("192.168.1.100")))
}
{
item := &IPItem{
IPFrom: IP2Long("192.168.1.1"),
IPTo: IP2Long("192.168.1.101"),
IPFrom: utils.IP2Long("192.168.1.1"),
IPTo: utils.IP2Long("192.168.1.101"),
ExpiredAt: 0,
}
a.IsTrue(item.Contains(IP2Long("192.168.1.1")))
a.IsTrue(item.Contains(utils.IP2Long("192.168.1.1")))
}
}

View File

@@ -8,16 +8,18 @@ import (
// IP名单
type IPList struct {
itemsMap map[int64]*IPItem // id => item
ipMap map[uint32][]int64 // ip => itemIds
ipMap map[uint64][]int64 // ip => itemIds
expireList *expires.List
isAll bool
locker sync.RWMutex
}
func NewIPList() *IPList {
list := &IPList{
itemsMap: map[int64]*IPItem{},
ipMap: map[uint32][]int64{},
ipMap: map[uint64][]int64{},
}
expireList := expires.NewList()
@@ -31,10 +33,16 @@ func NewIPList() *IPList {
}
func (this *IPList) Add(item *IPItem) {
if item == nil || (item.IPFrom == 0 && item.IPTo == 0) {
if item == nil {
return
}
if item.IPFrom == 0 && item.IPTo == 0 {
if item.Type != "all" {
return
}
}
this.locker.Lock()
// 是否已经存在
@@ -64,6 +72,11 @@ func (this *IPList) Add(item *IPItem) {
}
} else if item.IPTo > 0 {
this.addIP(item.IPTo, item.Id)
} else {
this.addIP(0, item.Id)
// 更新isAll
this.isAll = true
}
if item.ExpiredAt > 0 {
@@ -77,11 +90,18 @@ func (this *IPList) Delete(itemId int64) {
this.locker.Lock()
defer this.locker.Unlock()
this.deleteItem(itemId)
// 更新isAll
this.isAll = len(this.ipMap[0]) > 0
}
// 判断是否包含某个IP
func (this *IPList) Contains(ip uint32) bool {
func (this *IPList) Contains(ip uint64) bool {
this.locker.RLock()
if this.isAll {
this.locker.RUnlock()
return true
}
_, ok := this.ipMap[ip]
this.locker.RUnlock()
@@ -117,11 +137,13 @@ func (this *IPList) deleteItem(itemId int64) {
}
} else if item.IPTo > 0 {
this.deleteIP(item.IPTo, item.Id)
} else {
this.deleteIP(0, item.Id)
}
}
// 添加单个IP
func (this *IPList) addIP(ip uint32, itemId int64) {
func (this *IPList) addIP(ip uint64, itemId int64) {
itemIds, ok := this.ipMap[ip]
if ok {
itemIds = append(itemIds, itemId)
@@ -132,7 +154,7 @@ func (this *IPList) addIP(ip uint32, itemId int64) {
}
// 删除单个IP
func (this *IPList) deleteIP(ip uint32, itemId int64) {
func (this *IPList) deleteIP(ip uint64, itemId int64) {
itemIds, ok := this.ipMap[ip]
if !ok {
return

View File

@@ -1,6 +1,7 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/logs"
"runtime"
@@ -22,21 +23,30 @@ func TestIPList_Add_One(t *testing.T) {
ipList := NewIPList()
ipList.Add(&IPItem{
Id: 1,
IPFrom: IP2Long("192.168.1.1"),
IPFrom: utils.IP2Long("192.168.1.1"),
})
ipList.Add(&IPItem{
Id: 2,
IPTo: IP2Long("192.168.1.2"),
IPTo: utils.IP2Long("192.168.1.2"),
})
ipList.Add(&IPItem{
Id: 3,
IPFrom: utils.IP2Long("2001:db8:0:1::101"),
})
ipList.Add(&IPItem{
Id: 4,
IPFrom: 0,
Type: "all",
})
logs.PrintAsJSON(ipList.itemsMap, t)
logs.PrintAsJSON(ipList.ipMap, t)
logs.PrintAsJSON(ipList.ipMap, t) // ip => items
}
func TestIPList_Update(t *testing.T) {
ipList := NewIPList()
ipList.Add(&IPItem{
Id: 1,
IPFrom: IP2Long("192.168.1.1"),
IPFrom: utils.IP2Long("192.168.1.1"),
})
/**ipList.Add(&IPItem{
Id: 2,
@@ -44,7 +54,7 @@ func TestIPList_Update(t *testing.T) {
})**/
ipList.Add(&IPItem{
Id: 1,
IPTo: IP2Long("192.168.1.2"),
IPTo: utils.IP2Long("192.168.1.2"),
})
logs.PrintAsJSON(ipList.itemsMap, t)
logs.PrintAsJSON(ipList.ipMap, t)
@@ -54,12 +64,12 @@ func TestIPList_Add_Range(t *testing.T) {
ipList := NewIPList()
ipList.Add(&IPItem{
Id: 1,
IPFrom: IP2Long("192.168.1.1"),
IPTo: IP2Long("192.168.2.1"),
IPFrom: utils.IP2Long("192.168.1.1"),
IPTo: utils.IP2Long("192.168.2.1"),
})
ipList.Add(&IPItem{
Id: 2,
IPTo: IP2Long("192.168.1.2"),
IPTo: utils.IP2Long("192.168.1.2"),
})
t.Log(len(ipList.ipMap), "ips")
logs.PrintAsJSON(ipList.itemsMap, t)
@@ -72,8 +82,8 @@ func TestIPList_Add_Overflow(t *testing.T) {
ipList := NewIPList()
ipList.Add(&IPItem{
Id: 1,
IPFrom: IP2Long("192.168.1.1"),
IPTo: IP2Long("192.169.255.1"),
IPFrom: utils.IP2Long("192.168.1.1"),
IPTo: utils.IP2Long("192.169.255.1"),
})
t.Log(len(ipList.ipMap), "ips")
a.IsTrue(len(ipList.ipMap) <= 65535)
@@ -98,29 +108,54 @@ func TestIPList_Contains(t *testing.T) {
for i := 0; i < 255; i++ {
list.Add(&IPItem{
Id: int64(i),
IPFrom: IP2Long(strconv.Itoa(i) + ".168.0.1"),
IPTo: IP2Long(strconv.Itoa(i) + ".168.255.1"),
IPFrom: utils.IP2Long(strconv.Itoa(i) + ".168.0.1"),
IPTo: utils.IP2Long(strconv.Itoa(i) + ".168.255.1"),
ExpiredAt: 0,
})
}
t.Log(len(list.ipMap))
t.Log(len(list.ipMap), "ip")
before := time.Now()
t.Log(list.Contains(IP2Long("192.168.1.100")))
t.Log(list.Contains(IP2Long("192.168.2.100")))
t.Log(list.Contains(utils.IP2Long("192.168.1.100")))
t.Log(list.Contains(utils.IP2Long("192.168.2.100")))
t.Log(time.Since(before).Seconds()*1000, "ms")
}
func TestIPList_ContainsAll(t *testing.T) {
list := NewIPList()
list.Add(&IPItem{
Id: 1,
Type: "all",
IPFrom: 0,
})
b := list.Contains(utils.IP2Long("192.168.1.1"))
if b {
t.Log(b)
} else {
t.Fatal("'b' should be true")
}
list.Delete(1)
b = list.Contains(utils.IP2Long("192.168.1.1"))
if !b {
t.Log(b)
} else {
t.Fatal("'b' should be false")
}
}
func TestIPList_Delete(t *testing.T) {
list := NewIPList()
list.Add(&IPItem{
Id: 1,
IPFrom: IP2Long("192.168.0.1"),
IPFrom: utils.IP2Long("192.168.0.1"),
ExpiredAt: 0,
})
list.Add(&IPItem{
Id: 2,
IPFrom: IP2Long("192.168.0.1"),
IPFrom: utils.IP2Long("192.168.0.1"),
ExpiredAt: 0,
})
t.Log("===BEFORE===")
@@ -138,14 +173,14 @@ func TestGC(t *testing.T) {
list := NewIPList()
list.Add(&IPItem{
Id: 1,
IPFrom: IP2Long("192.168.1.100"),
IPTo: IP2Long("192.168.1.101"),
IPFrom: utils.IP2Long("192.168.1.100"),
IPTo: utils.IP2Long("192.168.1.101"),
ExpiredAt: time.Now().Unix() + 1,
})
list.Add(&IPItem{
Id: 2,
IPFrom: IP2Long("192.168.1.102"),
IPTo: IP2Long("192.168.1.103"),
IPFrom: utils.IP2Long("192.168.1.102"),
IPTo: utils.IP2Long("192.168.1.103"),
ExpiredAt: 0,
})
logs.PrintAsJSON(list.itemsMap, t)
@@ -164,13 +199,13 @@ func BenchmarkIPList_Contains(b *testing.B) {
for i := 192; i < 194; i++ {
list.Add(&IPItem{
Id: int64(1),
IPFrom: IP2Long(strconv.Itoa(i) + ".1.0.1"),
IPTo: IP2Long(strconv.Itoa(i) + ".2.0.1"),
IPFrom: utils.IP2Long(strconv.Itoa(i) + ".1.0.1"),
IPTo: utils.IP2Long(strconv.Itoa(i) + ".2.0.1"),
ExpiredAt: time.Now().Unix() + 60,
})
}
b.Log(len(list.ipMap), "ip")
for i := 0; i < b.N; i++ {
_ = list.Contains(IP2Long("192.168.1.100"))
_ = list.Contains(utils.IP2Long("192.168.1.100"))
}
}

View File

@@ -1,19 +0,0 @@
package iplibrary
import (
"encoding/binary"
"net"
)
// 将IP转换为整型
func IP2Long(ip string) uint32 {
s := net.ParseIP(ip)
if s == nil {
return 0
}
if len(s) == 16 {
return binary.BigEndian.Uint32(s[12:16])
}
return binary.BigEndian.Uint32(s)
}

View File

@@ -1,21 +0,0 @@
package iplibrary
import (
"runtime"
"testing"
)
func TestIP2Long(t *testing.T) {
t.Log(IP2Long("192.168.1.100"))
t.Log(IP2Long("192.168.1.101"))
t.Log(IP2Long("202.106.0.20"))
t.Log(IP2Long("192.168.1")) // wrong ip, should return 0
}
func BenchmarkIP2Long(b *testing.B) {
runtime.GOMAXPROCS(1)
for i := 0; i < b.N; i++ {
_ = IP2Long("192.168.1.100")
}
}

View File

@@ -3,8 +3,9 @@ package iplibrary
import (
"fmt"
"github.com/TeaOSLab/EdgeNode/internal/errors"
"github.com/iwind/TeaGo/logs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/lionsoul2014/ip2region/binding/golang/ip2region"
"strings"
)
type IP2RegionLibrary struct {
@@ -22,6 +23,11 @@ func (this *IP2RegionLibrary) Load(dbPath string) error {
}
func (this *IP2RegionLibrary) Lookup(ip string) (*Result, error) {
// 暂不支持IPv6
if strings.Contains(ip, ":") {
return nil, nil
}
if this.db == nil {
return nil, errors.New("library has not been loaded")
}
@@ -30,7 +36,7 @@ func (this *IP2RegionLibrary) Lookup(ip string) (*Result, error) {
// 防止panic发生
err := recover()
if err != nil {
logs.Println("[IP2RegionLibrary]panic: " + fmt.Sprintf("%#v", err))
remotelogs.Error("IP2RegionLibrary", "panic: "+fmt.Sprintf("%#v", err))
}
}()

View File

@@ -5,6 +5,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/events"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/rpc"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/iwind/TeaGo/Tea"
"sync"
"time"
@@ -123,8 +124,9 @@ func (this *IPListManager) fetch() (hasNext bool, err error) {
}
list.Add(&IPItem{
Id: item.Id,
IPFrom: IP2Long(item.IpFrom),
IPTo: IP2Long(item.IpTo),
Type: item.Type,
IPFrom: utils.IP2Long(item.IpFrom),
IPTo: utils.IP2Long(item.IpTo),
ExpiredAt: item.ExpiredAt,
})
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
@@ -49,7 +50,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
inbound := firewallPolicy.Inbound
if inbound.AllowListRef != nil && inbound.AllowListRef.IsOn && inbound.AllowListRef.ListId > 0 {
list := iplibrary.SharedIPListManager.FindList(inbound.AllowListRef.ListId)
if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) {
if list != nil && list.Contains(utils.IP2Long(remoteAddr)) {
breakChecking = true
return
}
@@ -58,7 +59,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
// 检查IP黑名单
if inbound.DenyListRef != nil && inbound.DenyListRef.IsOn && inbound.DenyListRef.ListId > 0 {
list := iplibrary.SharedIPListManager.FindList(inbound.DenyListRef.ListId)
if list != nil && list.Contains(iplibrary.IP2Long(remoteAddr)) {
if list != nil && list.Contains(utils.IP2Long(remoteAddr)) {
// TODO 可以配置对封禁的处理方式等
// TODO 需要记录日志信息
this.writer.WriteHeader(http.StatusForbidden)

View File

@@ -2,18 +2,22 @@ package utils
import (
"encoding/binary"
"github.com/cespare/xxhash"
"math"
"net"
"strings"
)
// 将IP转换为整型
func IP2Long(ip string) uint32 {
// 注意IPv6没有顺序
func IP2Long(ip string) uint64 {
s := net.ParseIP(ip)
if s == nil {
return 0
}
if len(s) == 16 {
return binary.BigEndian.Uint32(s[12:16])
if strings.Contains(ip, ":") {
return math.MaxUint32 + xxhash.Sum64String(ip)
}
return binary.BigEndian.Uint32(s)
return uint64(binary.BigEndian.Uint32(s.To4()))
}

View File

@@ -6,4 +6,6 @@ func TestIP2Long(t *testing.T) {
t.Log(IP2Long("0.0.0.0"))
t.Log(IP2Long("1.0.0.0"))
t.Log(IP2Long("0.0.0.0.0"))
t.Log(IP2Long("2001:db8:0:1::101"))
t.Log(IP2Long("2001:db8:0:1::102"))
}

View File

@@ -14,5 +14,5 @@ func VersionToLong(version string) uint32 {
} else if countDots == 0 {
version += ".0.0.0"
}
return IP2Long(version)
return uint32(IP2Long(version))
}