IP名单改成同时检查多个IP来源

This commit is contained in:
GoEdgeLab
2021-02-02 19:32:19 +08:00
parent 6ca2a42da9
commit b457277b6e
5 changed files with 101 additions and 33 deletions

View File

@@ -1,6 +1,7 @@
package iplibrary
import (
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"sync"
)
@@ -108,6 +109,30 @@ func (this *IPList) Contains(ip uint64) bool {
return ok
}
// 是否包含一组IP
func (this *IPList) ContainsIPStrings(ipStrings []string) bool {
if len(ipStrings) == 0 {
return false
}
this.locker.RLock()
if this.isAll {
this.locker.RUnlock()
return true
}
for _, ipString := range ipStrings {
if len(ipString) == 0 {
continue
}
_, ok := this.ipMap[utils.IP2Long(ipString)]
if ok {
this.locker.RUnlock()
return true
}
}
this.locker.RUnlock()
return false
}
// 在不加锁的情况下删除某个Item
// 将会被别的方法引用,切记不能加锁
func (this *IPList) deleteItem(itemId int64) {

View File

@@ -737,6 +737,44 @@ func (this *HTTPRequest) requestRemoteAddr() string {
}
}
// 获取请求的客户端地址列表
func (this *HTTPRequest) requestRemoteAddrs() (result []string) {
// X-Forwarded-For
forwardedFor := this.RawReq.Header.Get("X-Forwarded-For")
if len(forwardedFor) > 0 {
commaIndex := strings.Index(forwardedFor, ",")
if commaIndex > 0 {
result = append(result, forwardedFor[:commaIndex])
}
}
// Real-IP
{
realIP, ok := this.RawReq.Header["X-Real-IP"]
if ok && len(realIP) > 0 {
result = append(result, realIP[0])
}
}
// Real-Ip
{
realIP, ok := this.RawReq.Header["X-Real-Ip"]
if ok && len(realIP) > 0 {
result = append(result, realIP[0])
}
}
// Remote-Addr
remoteAddr := this.RawReq.RemoteAddr
host, _, err := net.SplitHostPort(remoteAddr)
if err == nil {
result = append(result, host)
} else {
result = append(result, remoteAddr)
}
return
}
// 请求内容长度
func (this *HTTPRequest) requestLength() int64 {
return this.RawReq.ContentLength

View File

@@ -5,7 +5,6 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
@@ -46,11 +45,11 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
}
// 检查IP白名单
remoteAddr := this.requestRemoteAddr()
remoteAddrs := this.requestRemoteAddrs()
inbound := firewallPolicy.Inbound
if inbound.AllowListRef != nil && inbound.AllowListRef.IsOn && inbound.AllowListRef.ListId > 0 {
list := iplibrary.SharedIPListManager.FindList(inbound.AllowListRef.ListId)
if list != nil && list.Contains(utils.IP2Long(remoteAddr)) {
if list != nil && list.ContainsIPStrings(remoteAddrs) {
breakChecking = true
return
}
@@ -59,7 +58,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
// 检查IP黑名单
if inbound.DenyListRef != nil && inbound.DenyListRef.IsOn && inbound.DenyListRef.ListId > 0 {
list := iplibrary.SharedIPListManager.FindList(inbound.DenyListRef.ListId)
if list != nil && list.Contains(utils.IP2Long(remoteAddr)) {
if list != nil && list.ContainsIPStrings(remoteAddrs) {
// TODO 可以配置对封禁的处理方式等
// TODO 需要记录日志信息
this.writer.WriteHeader(http.StatusForbidden)
@@ -77,6 +76,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
if firewallPolicy.Inbound.Region != nil && firewallPolicy.Inbound.Region.IsOn {
regionConfig := firewallPolicy.Inbound.Region
if regionConfig.IsNotEmpty() {
for _, remoteAddr := range remoteAddrs {
result, err := iplibrary.SharedLibrary.Lookup(remoteAddr)
if err != nil {
remotelogs.Error("REQUEST", "iplibrary lookup failed: "+err.Error())
@@ -116,6 +116,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
}
}
}
}
// 规则测试
w := sharedWAFManager.FindWAF(firewallPolicy.Id)

View File

@@ -11,13 +11,16 @@ import (
// 将IP转换为整型
// 注意IPv6没有顺序
func IP2Long(ip string) uint64 {
if len(ip) == 0 {
return 0
}
s := net.ParseIP(ip)
if s == nil {
if len(s) == 0 {
return 0
}
if strings.Contains(ip, ":") {
return math.MaxUint32 + xxhash.Sum64String(ip)
return math.MaxUint32 + xxhash.Sum64(s)
}
return uint64(binary.BigEndian.Uint32(s.To4()))
}

View File

@@ -8,4 +8,5 @@ func TestIP2Long(t *testing.T) {
t.Log(IP2Long("0.0.0.0.0"))
t.Log(IP2Long("2001:db8:0:1::101"))
t.Log(IP2Long("2001:db8:0:1::102"))
t.Log(IP2Long("::1"))
}