IP名单改成同时检查多个IP来源

This commit is contained in:
GoEdgeLab
2021-02-02 19:32:19 +08:00
parent 6ca2a42da9
commit b457277b6e
5 changed files with 101 additions and 33 deletions

View File

@@ -1,6 +1,7 @@
package iplibrary package iplibrary
import ( import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/expires" "github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"sync" "sync"
) )
@@ -108,6 +109,30 @@ func (this *IPList) Contains(ip uint64) bool {
return ok return ok
} }
// 是否包含一组IP
func (this *IPList) ContainsIPStrings(ipStrings []string) bool {
if len(ipStrings) == 0 {
return false
}
this.locker.RLock()
if this.isAll {
this.locker.RUnlock()
return true
}
for _, ipString := range ipStrings {
if len(ipString) == 0 {
continue
}
_, ok := this.ipMap[utils.IP2Long(ipString)]
if ok {
this.locker.RUnlock()
return true
}
}
this.locker.RUnlock()
return false
}
// 在不加锁的情况下删除某个Item // 在不加锁的情况下删除某个Item
// 将会被别的方法引用,切记不能加锁 // 将会被别的方法引用,切记不能加锁
func (this *IPList) deleteItem(itemId int64) { func (this *IPList) deleteItem(itemId int64) {

View File

@@ -737,6 +737,44 @@ func (this *HTTPRequest) requestRemoteAddr() string {
} }
} }
// 获取请求的客户端地址列表
func (this *HTTPRequest) requestRemoteAddrs() (result []string) {
// X-Forwarded-For
forwardedFor := this.RawReq.Header.Get("X-Forwarded-For")
if len(forwardedFor) > 0 {
commaIndex := strings.Index(forwardedFor, ",")
if commaIndex > 0 {
result = append(result, forwardedFor[:commaIndex])
}
}
// Real-IP
{
realIP, ok := this.RawReq.Header["X-Real-IP"]
if ok && len(realIP) > 0 {
result = append(result, realIP[0])
}
}
// Real-Ip
{
realIP, ok := this.RawReq.Header["X-Real-Ip"]
if ok && len(realIP) > 0 {
result = append(result, realIP[0])
}
}
// Remote-Addr
remoteAddr := this.RawReq.RemoteAddr
host, _, err := net.SplitHostPort(remoteAddr)
if err == nil {
result = append(result, host)
} else {
result = append(result, remoteAddr)
}
return
}
// 请求内容长度 // 请求内容长度
func (this *HTTPRequest) requestLength() int64 { func (this *HTTPRequest) requestLength() int64 {
return this.RawReq.ContentLength return this.RawReq.ContentLength

View File

@@ -5,7 +5,6 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats" "github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
@@ -46,11 +45,11 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
} }
// 检查IP白名单 // 检查IP白名单
remoteAddr := this.requestRemoteAddr() remoteAddrs := this.requestRemoteAddrs()
inbound := firewallPolicy.Inbound inbound := firewallPolicy.Inbound
if inbound.AllowListRef != nil && inbound.AllowListRef.IsOn && inbound.AllowListRef.ListId > 0 { if inbound.AllowListRef != nil && inbound.AllowListRef.IsOn && inbound.AllowListRef.ListId > 0 {
list := iplibrary.SharedIPListManager.FindList(inbound.AllowListRef.ListId) list := iplibrary.SharedIPListManager.FindList(inbound.AllowListRef.ListId)
if list != nil && list.Contains(utils.IP2Long(remoteAddr)) { if list != nil && list.ContainsIPStrings(remoteAddrs) {
breakChecking = true breakChecking = true
return return
} }
@@ -59,7 +58,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
// 检查IP黑名单 // 检查IP黑名单
if inbound.DenyListRef != nil && inbound.DenyListRef.IsOn && inbound.DenyListRef.ListId > 0 { if inbound.DenyListRef != nil && inbound.DenyListRef.IsOn && inbound.DenyListRef.ListId > 0 {
list := iplibrary.SharedIPListManager.FindList(inbound.DenyListRef.ListId) list := iplibrary.SharedIPListManager.FindList(inbound.DenyListRef.ListId)
if list != nil && list.Contains(utils.IP2Long(remoteAddr)) { if list != nil && list.ContainsIPStrings(remoteAddrs) {
// TODO 可以配置对封禁的处理方式等 // TODO 可以配置对封禁的处理方式等
// TODO 需要记录日志信息 // TODO 需要记录日志信息
this.writer.WriteHeader(http.StatusForbidden) this.writer.WriteHeader(http.StatusForbidden)
@@ -77,39 +76,41 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
if firewallPolicy.Inbound.Region != nil && firewallPolicy.Inbound.Region.IsOn { if firewallPolicy.Inbound.Region != nil && firewallPolicy.Inbound.Region.IsOn {
regionConfig := firewallPolicy.Inbound.Region regionConfig := firewallPolicy.Inbound.Region
if regionConfig.IsNotEmpty() { if regionConfig.IsNotEmpty() {
result, err := iplibrary.SharedLibrary.Lookup(remoteAddr) for _, remoteAddr := range remoteAddrs {
if err != nil { result, err := iplibrary.SharedLibrary.Lookup(remoteAddr)
remotelogs.Error("REQUEST", "iplibrary lookup failed: "+err.Error()) if err != nil {
} else if result != nil { remotelogs.Error("REQUEST", "iplibrary lookup failed: "+err.Error())
// 检查国家级别封禁 } else if result != nil {
if len(regionConfig.DenyCountryIds) > 0 && len(result.Country) > 0 { // 检查国家级别封禁
countryId := iplibrary.SharedCountryManager.Lookup(result.Country) if len(regionConfig.DenyCountryIds) > 0 && len(result.Country) > 0 {
if countryId > 0 && lists.ContainsInt64(regionConfig.DenyCountryIds, countryId) { countryId := iplibrary.SharedCountryManager.Lookup(result.Country)
// TODO 可以配置对封禁的处理方式等 if countryId > 0 && lists.ContainsInt64(regionConfig.DenyCountryIds, countryId) {
// TODO 需要记录日志信息 // TODO 可以配置对封禁的处理方式等
this.writer.WriteHeader(http.StatusForbidden) // TODO 需要记录日志信息
this.writer.Close() this.writer.WriteHeader(http.StatusForbidden)
this.writer.Close()
// 停止日志 // 停止日志
this.disableLog = true this.disableLog = true
return true, false return true, false
}
} }
}
// 检查省份封禁 // 检查省份封禁
if len(regionConfig.DenyProvinceIds) > 0 && len(result.Province) > 0 { if len(regionConfig.DenyProvinceIds) > 0 && len(result.Province) > 0 {
provinceId := iplibrary.SharedProvinceManager.Lookup(result.Province) provinceId := iplibrary.SharedProvinceManager.Lookup(result.Province)
if provinceId > 0 && lists.ContainsInt64(regionConfig.DenyProvinceIds, provinceId) { if provinceId > 0 && lists.ContainsInt64(regionConfig.DenyProvinceIds, provinceId) {
// TODO 可以配置对封禁的处理方式等 // TODO 可以配置对封禁的处理方式等
// TODO 需要记录日志信息 // TODO 需要记录日志信息
this.writer.WriteHeader(http.StatusForbidden) this.writer.WriteHeader(http.StatusForbidden)
this.writer.Close() this.writer.Close()
// 停止日志 // 停止日志
this.disableLog = true this.disableLog = true
return true, false return true, false
}
} }
} }
} }

View File

@@ -11,13 +11,16 @@ import (
// 将IP转换为整型 // 将IP转换为整型
// 注意IPv6没有顺序 // 注意IPv6没有顺序
func IP2Long(ip string) uint64 { func IP2Long(ip string) uint64 {
if len(ip) == 0 {
return 0
}
s := net.ParseIP(ip) s := net.ParseIP(ip)
if s == nil { if len(s) == 0 {
return 0 return 0
} }
if strings.Contains(ip, ":") { if strings.Contains(ip, ":") {
return math.MaxUint32 + xxhash.Sum64String(ip) return math.MaxUint32 + xxhash.Sum64(s)
} }
return uint64(binary.BigEndian.Uint32(s.To4())) return uint64(binary.BigEndian.Uint32(s.To4()))
} }

View File

@@ -8,4 +8,5 @@ func TestIP2Long(t *testing.T) {
t.Log(IP2Long("0.0.0.0.0")) t.Log(IP2Long("0.0.0.0.0"))
t.Log(IP2Long("2001:db8:0:1::101")) t.Log(IP2Long("2001:db8:0:1::101"))
t.Log(IP2Long("2001:db8:0:1::102")) t.Log(IP2Long("2001:db8:0:1::102"))
t.Log(IP2Long("::1"))
} }