diff --git a/internal/iplibrary/ip_list.go b/internal/iplibrary/ip_list.go index 3446f4c..adfd2f2 100644 --- a/internal/iplibrary/ip_list.go +++ b/internal/iplibrary/ip_list.go @@ -1,6 +1,7 @@ package iplibrary import ( + "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils/expires" "sync" ) @@ -108,6 +109,30 @@ func (this *IPList) Contains(ip uint64) bool { return ok } +// 是否包含一组IP +func (this *IPList) ContainsIPStrings(ipStrings []string) bool { + if len(ipStrings) == 0 { + return false + } + this.locker.RLock() + if this.isAll { + this.locker.RUnlock() + return true + } + for _, ipString := range ipStrings { + if len(ipString) == 0 { + continue + } + _, ok := this.ipMap[utils.IP2Long(ipString)] + if ok { + this.locker.RUnlock() + return true + } + } + this.locker.RUnlock() + return false +} + // 在不加锁的情况下删除某个Item // 将会被别的方法引用,切记不能加锁 func (this *IPList) deleteItem(itemId int64) { diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index daaf985..2d57d5d 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -737,6 +737,44 @@ func (this *HTTPRequest) requestRemoteAddr() string { } } +// 获取请求的客户端地址列表 +func (this *HTTPRequest) requestRemoteAddrs() (result []string) { + // X-Forwarded-For + forwardedFor := this.RawReq.Header.Get("X-Forwarded-For") + if len(forwardedFor) > 0 { + commaIndex := strings.Index(forwardedFor, ",") + if commaIndex > 0 { + result = append(result, forwardedFor[:commaIndex]) + } + } + + // Real-IP + { + realIP, ok := this.RawReq.Header["X-Real-IP"] + if ok && len(realIP) > 0 { + result = append(result, realIP[0]) + } + } + + // Real-Ip + { + realIP, ok := this.RawReq.Header["X-Real-Ip"] + if ok && len(realIP) > 0 { + result = append(result, realIP[0]) + } + } + + // Remote-Addr + remoteAddr := this.RawReq.RemoteAddr + host, _, err := net.SplitHostPort(remoteAddr) + if err == nil { + result = append(result, host) + } else { + result = append(result, remoteAddr) + } + return +} + // 请求内容长度 func (this *HTTPRequest) requestLength() int64 { return this.RawReq.ContentLength diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 9c92db6..9d8f68a 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -5,7 +5,6 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/stats" - "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" @@ -46,11 +45,11 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir } // 检查IP白名单 - remoteAddr := this.requestRemoteAddr() + remoteAddrs := this.requestRemoteAddrs() inbound := firewallPolicy.Inbound if inbound.AllowListRef != nil && inbound.AllowListRef.IsOn && inbound.AllowListRef.ListId > 0 { list := iplibrary.SharedIPListManager.FindList(inbound.AllowListRef.ListId) - if list != nil && list.Contains(utils.IP2Long(remoteAddr)) { + if list != nil && list.ContainsIPStrings(remoteAddrs) { breakChecking = true return } @@ -59,7 +58,7 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir // 检查IP黑名单 if inbound.DenyListRef != nil && inbound.DenyListRef.IsOn && inbound.DenyListRef.ListId > 0 { list := iplibrary.SharedIPListManager.FindList(inbound.DenyListRef.ListId) - if list != nil && list.Contains(utils.IP2Long(remoteAddr)) { + if list != nil && list.ContainsIPStrings(remoteAddrs) { // TODO 可以配置对封禁的处理方式等 // TODO 需要记录日志信息 this.writer.WriteHeader(http.StatusForbidden) @@ -77,39 +76,41 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir if firewallPolicy.Inbound.Region != nil && firewallPolicy.Inbound.Region.IsOn { regionConfig := firewallPolicy.Inbound.Region if regionConfig.IsNotEmpty() { - result, err := iplibrary.SharedLibrary.Lookup(remoteAddr) - if err != nil { - remotelogs.Error("REQUEST", "iplibrary lookup failed: "+err.Error()) - } else if result != nil { - // 检查国家级别封禁 - if len(regionConfig.DenyCountryIds) > 0 && len(result.Country) > 0 { - countryId := iplibrary.SharedCountryManager.Lookup(result.Country) - if countryId > 0 && lists.ContainsInt64(regionConfig.DenyCountryIds, countryId) { - // TODO 可以配置对封禁的处理方式等 - // TODO 需要记录日志信息 - this.writer.WriteHeader(http.StatusForbidden) - this.writer.Close() + for _, remoteAddr := range remoteAddrs { + result, err := iplibrary.SharedLibrary.Lookup(remoteAddr) + if err != nil { + remotelogs.Error("REQUEST", "iplibrary lookup failed: "+err.Error()) + } else if result != nil { + // 检查国家级别封禁 + if len(regionConfig.DenyCountryIds) > 0 && len(result.Country) > 0 { + countryId := iplibrary.SharedCountryManager.Lookup(result.Country) + if countryId > 0 && lists.ContainsInt64(regionConfig.DenyCountryIds, countryId) { + // TODO 可以配置对封禁的处理方式等 + // TODO 需要记录日志信息 + this.writer.WriteHeader(http.StatusForbidden) + this.writer.Close() - // 停止日志 - this.disableLog = true + // 停止日志 + this.disableLog = true - return true, false + return true, false + } } - } - // 检查省份封禁 - if len(regionConfig.DenyProvinceIds) > 0 && len(result.Province) > 0 { - provinceId := iplibrary.SharedProvinceManager.Lookup(result.Province) - if provinceId > 0 && lists.ContainsInt64(regionConfig.DenyProvinceIds, provinceId) { - // TODO 可以配置对封禁的处理方式等 - // TODO 需要记录日志信息 - this.writer.WriteHeader(http.StatusForbidden) - this.writer.Close() + // 检查省份封禁 + if len(regionConfig.DenyProvinceIds) > 0 && len(result.Province) > 0 { + provinceId := iplibrary.SharedProvinceManager.Lookup(result.Province) + if provinceId > 0 && lists.ContainsInt64(regionConfig.DenyProvinceIds, provinceId) { + // TODO 可以配置对封禁的处理方式等 + // TODO 需要记录日志信息 + this.writer.WriteHeader(http.StatusForbidden) + this.writer.Close() - // 停止日志 - this.disableLog = true + // 停止日志 + this.disableLog = true - return true, false + return true, false + } } } } diff --git a/internal/utils/ip.go b/internal/utils/ip.go index 4ef80a4..eb85eb6 100644 --- a/internal/utils/ip.go +++ b/internal/utils/ip.go @@ -11,13 +11,16 @@ import ( // 将IP转换为整型 // 注意IPv6没有顺序 func IP2Long(ip string) uint64 { + if len(ip) == 0 { + return 0 + } s := net.ParseIP(ip) - if s == nil { + if len(s) == 0 { return 0 } if strings.Contains(ip, ":") { - return math.MaxUint32 + xxhash.Sum64String(ip) + return math.MaxUint32 + xxhash.Sum64(s) } return uint64(binary.BigEndian.Uint32(s.To4())) } diff --git a/internal/utils/ip_test.go b/internal/utils/ip_test.go index baca495..a074da8 100644 --- a/internal/utils/ip_test.go +++ b/internal/utils/ip_test.go @@ -8,4 +8,5 @@ func TestIP2Long(t *testing.T) { t.Log(IP2Long("0.0.0.0.0")) t.Log(IP2Long("2001:db8:0:1::101")) t.Log(IP2Long("2001:db8:0:1::102")) + t.Log(IP2Long("::1")) }