diff --git a/internal/firewalls/utils.go b/internal/firewalls/utils.go new file mode 100644 index 0000000..ea38919 --- /dev/null +++ b/internal/firewalls/utils.go @@ -0,0 +1,29 @@ +// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package firewalls + +import ( + "time" +) + +// DropTemporaryTo 使用本地防火墙临时拦截IP数据包 +func DropTemporaryTo(ip string, expiresAt int64) { + if expiresAt <= 1 { + return + } + + var timeout = expiresAt - time.Now().Unix() + if timeout < 1 { + return + } + if timeout > 3600 { + timeout = 3600 + } + + // 使用本地防火墙延长封禁 + var fw = Firewall() + if fw != nil && !fw.IsMock() { + // 这里 int(int64) 转换的前提是限制了 timeout <= 3600,否则将有整型溢出的风险 + _ = fw.DropSourceIP(ip, int(timeout), true) + } +} diff --git a/internal/iplibrary/ip_list.go b/internal/iplibrary/ip_list.go index 5081dc4..cd4cc5c 100644 --- a/internal/iplibrary/ip_list.go +++ b/internal/iplibrary/ip_list.go @@ -72,6 +72,25 @@ func (this *IPList) Contains(ip uint64) bool { return item != nil } +// ContainsExpires 判断是否包含某个IP +func (this *IPList) ContainsExpires(ip uint64) (expiresAt int64, ok bool) { + this.locker.RLock() + if len(this.allItemsMap) > 0 { + this.locker.RUnlock() + return 0, true + } + + var item = this.lookupIP(ip) + + this.locker.RUnlock() + + if item == nil { + return + } + + return item.ExpiredAt, true +} + // ContainsIPStrings 是否包含一组IP中的任意一个,并返回匹配的第一个Item func (this *IPList) ContainsIPStrings(ipStrings []string) (item *IPItem, found bool) { if len(ipStrings) == 0 { @@ -155,7 +174,7 @@ func (this *IPList) addItem(item *IPItem, sortable bool) { this.locker.Unlock() } -// 对列表进行排序 +// 对列表进行排序 func (this *IPList) sortItems() { sort.Slice(this.sortedItems, func(i, j int) bool { var item1 = this.sortedItems[i] diff --git a/internal/iplibrary/list_utils.go b/internal/iplibrary/list_utils.go index af1b589..e70a964 100644 --- a/internal/iplibrary/list_utils.go +++ b/internal/iplibrary/list_utils.go @@ -10,50 +10,54 @@ import ( // AllowIP 检查IP是否被允许访问 // 如果一个IP不在任何名单中,则允许访问 -func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool) { +func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool, expiresAt int64) { if !Tea.IsTesting() { // 如果在测试环境,我们不加入一些白名单,以便于可以在本地和局域网正常测试 // 放行lo if ip == "127.0.0.1" || ip == "::1" { - return true, true + return true, true, 0 } // check node nodeConfig, err := nodeconfigs.SharedNodeConfig() if err == nil && nodeConfig.IPIsAutoAllowed(ip) { - return true, true + return true, true, 0 } } var ipLong = utils.IP2Long(ip) if ipLong == 0 { - return false, false + return false, false, 0 } // check white lists if GlobalWhiteIPList.Contains(ipLong) { - return true, true + return true, true, 0 } if serverId > 0 { var list = SharedServerListManager.FindWhiteList(serverId, false) if list != nil && list.Contains(ipLong) { - return true, true + return true, true, 0 } } // check black lists - if GlobalBlackIPList.Contains(ipLong) { - return false, false + expiresAt, ok := GlobalBlackIPList.ContainsExpires(ipLong) + if ok { + return false, false, expiresAt } if serverId > 0 { var list = SharedServerListManager.FindBlackList(serverId, false) - if list != nil && list.Contains(ipLong) { - return false, false + if list != nil { + expiresAt, ok = list.ContainsExpires(ipLong) + if ok { + return false, false, expiresAt + } } } - return true, false + return true, false, 0 } // IsInWhiteList 检查IP是否在白名单中 @@ -73,7 +77,7 @@ func AllowIPStrings(ipStrings []string, serverId int64) bool { return true } for _, ip := range ipStrings { - isAllowed, _ := AllowIP(ip, serverId) + isAllowed, _, _ := AllowIP(ip, serverId) if !isAllowed { return false } diff --git a/internal/nodes/client_conn_base.go b/internal/nodes/client_conn_base.go index fe54891..e7bcb45 100644 --- a/internal/nodes/client_conn_base.go +++ b/internal/nodes/client_conn_base.go @@ -4,6 +4,9 @@ package nodes import ( "crypto/tls" + "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + "github.com/TeaOSLab/EdgeNode/internal/firewalls" + "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "net" ) @@ -61,6 +64,20 @@ func (this *BaseClientConn) SetServerId(serverId int64) { case *ClientConn: conn.SetServerId(serverId) } + + // 检查服务相关IP黑名单 + if serverId > 0 && len(this.rawIP) > 0 { + var list = iplibrary.SharedServerListManager.FindBlackList(serverId, false) + if list != nil { + expiresAt, ok := list.ContainsExpires(configutils.IPString2Long(this.rawIP)) + if ok { + _ = this.rawConn.Close() + if expiresAt > 0 { + firewalls.DropTemporaryTo(this.rawIP, expiresAt) + } + } + } + } } // ServerId 读取当前连接绑定的服务ID diff --git a/internal/nodes/client_listener.go b/internal/nodes/client_listener.go index de3c4e5..18c6b29 100644 --- a/internal/nodes/client_listener.go +++ b/internal/nodes/client_listener.go @@ -8,7 +8,6 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/waf" "net" - "time" ) // ClientListener 客户端网络监听 @@ -43,24 +42,19 @@ func (this *ClientListener) Accept() (net.Conn, error) { ip, _, err := net.SplitHostPort(conn.RemoteAddr().String()) var isInAllowList = false if err == nil { - canGoNext, inAllowList := iplibrary.AllowIP(ip, 0) + canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0) isInAllowList = inAllowList - if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) { - expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) - if ok { - var timeout = expiresAt - time.Now().Unix() - if timeout > 0 { + if !canGoNext { + if expiresAt > 0 { + firewalls.DropTemporaryTo(ip, expiresAt) + } + } else { + if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) { + expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) + if ok { canGoNext = false - - if timeout > 3600 { - timeout = 3600 - } - - // 使用本地防火墙延长封禁 - var fw = firewalls.Firewall() - if fw != nil && !fw.IsMock() { - // 这里 int(int64) 转换的前提是限制了 timeout <= 3600,否则将有整型溢出的风险 - _ = fw.DropSourceIP(ip, int(timeout), true) + if expiresAt > 0 { + firewalls.DropTemporaryTo(ip, expiresAt) } } } diff --git a/internal/nodes/http_request_limit.go b/internal/nodes/http_request_limit.go index a6e50e4..f78ad13 100644 --- a/internal/nodes/http_request_limit.go +++ b/internal/nodes/http_request_limit.go @@ -9,7 +9,7 @@ import ( func (this *HTTPRequest) doRequestLimit() (shouldStop bool) { // 是否在全局名单中 - _, isInAllowedList := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id) + _, isInAllowedList, _ := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id) if isInAllowedList { return false } diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index 47d0e56..76cd9b1 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -35,7 +35,7 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) { } // 是否在全局名单中 - canGoNext, isInAllowedList := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id) + canGoNext, isInAllowedList, _ := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id) if !canGoNext { this.disableLog = true this.Close() diff --git a/internal/waf/action_js_cookie.go b/internal/waf/action_js_cookie.go index 32d6124..caf94f3 100644 --- a/internal/waf/action_js_cookie.go +++ b/internal/waf/action_js_cookie.go @@ -119,13 +119,7 @@ func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64, var countFails = ttlcache.SharedCache.IncreaseInt64(key, 1, time.Now().Unix()+300, true) if int(countFails) >= maxFails { - var useLocalFirewall = false - - if this.Scope == firewallconfigs.FirewallScopeGlobal { - useLocalFirewall = true - } - - SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "JS_COOKIE验证连续失败超过"+types.String(maxFails)+"次") + SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, true, groupId, setId, "JS_COOKIE验证连续失败超过"+types.String(maxFails)+"次") return false } diff --git a/internal/waf/action_record_ip.go b/internal/waf/action_record_ip.go index 6f3fd73..4fbfabd 100644 --- a/internal/waf/action_record_ip.go +++ b/internal/waf/action_record_ip.go @@ -30,7 +30,7 @@ type recordIPTask struct { sourceHTTPFirewallRuleSetId int64 } -var recordIPTaskChan = make(chan *recordIPTask, 1024) +var recordIPTaskChan = make(chan *recordIPTask, 2048) func init() { if !teaconst.IsMain { @@ -45,32 +45,56 @@ func init() { return } - for task := range recordIPTaskChan { - ipType := "ipv4" - if strings.Contains(task.ip, ":") { - ipType = "ipv6" - } - var reason = task.reason - if len(reason) == 0 { - reason = "触发WAF规则自动加入" - } - _, err = rpcClient.IPItemRPC.CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{ - IpListId: task.listId, - IpFrom: task.ip, - IpTo: "", - ExpiredAt: task.expiresAt, - Reason: reason, - Type: ipType, - EventLevel: task.level, - ServerId: task.serverId, - SourceNodeId: teaconst.NodeId, - SourceServerId: task.sourceServerId, - SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId, - SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId, - SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId, - }) - if err != nil { - remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error()) + const maxItems = 512 // 每次上传的最大IP数 + + for { + var pbItems = []*pb.CreateIPItemsRequest_IPItem{} + + func() { + for { + select { + case task := <-recordIPTaskChan: + var ipType = "ipv4" + if strings.Contains(task.ip, ":") { + ipType = "ipv6" + } + var reason = task.reason + if len(reason) == 0 { + reason = "触发WAF规则自动加入" + } + + pbItems = append(pbItems, &pb.CreateIPItemsRequest_IPItem{ + IpListId: task.listId, + IpFrom: task.ip, + IpTo: "", + ExpiredAt: task.expiresAt, + Reason: reason, + Type: ipType, + EventLevel: task.level, + ServerId: task.serverId, + SourceNodeId: teaconst.NodeId, + SourceServerId: task.sourceServerId, + SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId, + SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId, + SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId, + }) + + if len(pbItems) >= maxItems { + return + } + default: + return + } + } + }() + + if len(pbItems) > 0 { + _, err = rpcClient.IPItemRPC.CreateIPItems(rpcClient.Context(), &pb.CreateIPItemsRequest{IpItems: pbItems}) + if err != nil { + remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error()) + } + } else { + time.Sleep(1 * time.Second) } } }) diff --git a/internal/waf/captcha_counter.go b/internal/waf/captcha_counter.go index f62d80e..7591797 100644 --- a/internal/waf/captcha_counter.go +++ b/internal/waf/captcha_counter.go @@ -29,13 +29,7 @@ func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, pol } var countFails = ttlcache.SharedCache.IncreaseInt64(CaptchaCacheKey(req, pageCode), 1, time.Now().Unix()+300, true) if int(countFails) >= maxFails { - var useLocalFirewall = false - - if actionConfig.FailBlockScopeAll { - useLocalFirewall = true - } - - SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "CAPTCHA验证连续失败超过"+types.String(maxFails)+"次") + SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, true, groupId, setId, "CAPTCHA验证连续失败超过"+types.String(maxFails)+"次") return false } } diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index e3a69c6..ac04788 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -10,7 +10,6 @@ import ( "github.com/iwind/TeaGo/types" "sync" "sync/atomic" - "time" ) var SharedIPWhiteList = NewIPList(IPListTypeAllow) @@ -95,6 +94,12 @@ func (this *IPList) RecordIP(ipType string, this.Add(ipType, scope, serverId, ip, expiresAt) if this.listType == IPListTypeDeny { + // 作用域 + var scopeServerId int64 + if scope == firewallconfigs.FirewallScopeService { + scopeServerId = serverId + } + // 加入队列等待上传 select { case recordIPTaskChan <- &recordIPTask{ @@ -102,7 +107,7 @@ func (this *IPList) RecordIP(ipType string, listId: firewallconfigs.GlobalListId, expiresAt: expiresAt, level: firewallconfigs.DefaultEventLevel, - serverId: serverId, + serverId: scopeServerId, sourceServerId: serverId, sourceHTTPFirewallPolicyId: policyId, sourceHTTPFirewallRuleGroupId: groupId, @@ -114,15 +119,8 @@ func (this *IPList) RecordIP(ipType string, } // 使用本地防火墙 - if useLocalFirewall { - var seconds = expiresAt - time.Now().Unix() - if seconds > 0 { - // 最大3600,防止误封时间过长 - if seconds > 3600 { - seconds = 3600 - } - _ = firewalls.Firewall().DropSourceIP(ip, int(seconds), true) - } + if useLocalFirewall && expiresAt > 0 { + firewalls.DropTemporaryTo(ip, expiresAt) } // 关闭此IP相关连接