diff --git a/internal/firewalls/ddos_protection.go b/internal/firewalls/ddos_protection.go index c80ded7..9054b62 100644 --- a/internal/firewalls/ddos_protection.go +++ b/internal/firewalls/ddos_protection.go @@ -546,7 +546,7 @@ func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error { _, ok := oldMap[ip] if !ok { // 不存在则添加 - err = set.AddIPElement(ip, nil) + err = set.AddIPElement(ip, nil, false) if err != nil { return errors.New("add ip '" + ip + "' failed: " + err.Error()) } diff --git a/internal/firewalls/firewall_nftables.go b/internal/firewalls/firewall_nftables.go index 1e107e5..feeddbe 100644 --- a/internal/firewalls/firewall_nftables.go +++ b/internal/firewalls/firewall_nftables.go @@ -335,14 +335,14 @@ func (this *NFTablesFirewall) AllowSourceIP(ip string) error { if this.allowIPv6Set == nil { return errors.New("ipv6 ip set is nil") } - return this.allowIPv6Set.AddElement(data.To16(), nil) + return this.allowIPv6Set.AddElement(data.To16(), nil, false) } // ipv4 if this.allowIPv4Set == nil { return errors.New("ipv4 ip set is nil") } - return this.allowIPv4Set.AddElement(data.To4(), nil) + return this.allowIPv4Set.AddElement(data.To4(), nil, false) } // RejectSourceIP 拒绝某个源IP连接 @@ -388,7 +388,7 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async } return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{ Timeout: time.Duration(timeoutSeconds) * time.Second, - }) + }, false) } // ipv4 @@ -397,7 +397,7 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async } return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{ Timeout: time.Duration(timeoutSeconds) * time.Second, - }) + }, false) } // RemoveSourceIP 删除某个源IP diff --git a/internal/firewalls/nftables/set.go b/internal/firewalls/nftables/set.go index 3f5deee..2204c53 100644 --- a/internal/firewalls/nftables/set.go +++ b/internal/firewalls/nftables/set.go @@ -56,7 +56,7 @@ func (this *Set) Name() string { return this.rawSet.Name } -func (this *Set) AddElement(key []byte, options *ElementOptions) error { +func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) error { var rawElement = nft.SetElement{ Key: key, } @@ -73,7 +73,7 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error { err = this.conn.Commit() if err != nil { // retry if exists - if strings.Contains(err.Error(), "file exists") { + if overwrite && strings.Contains(err.Error(), "file exists") { deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ { Key: key, @@ -93,16 +93,16 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error { return err } -func (this *Set) AddIPElement(ip string, options *ElementOptions) error { +func (this *Set) AddIPElement(ip string, options *ElementOptions, overwrite bool) error { var ipObj = net.ParseIP(ip) if ipObj == nil { return errors.New("invalid ip '" + ip + "'") } if utils.IsIPv4(ip) { - return this.AddElement(ipObj.To4(), options) + return this.AddElement(ipObj.To4(), options, overwrite) } else { - return this.AddElement(ipObj.To16(), options) + return this.AddElement(ipObj.To16(), options, overwrite) } } diff --git a/internal/waf/action_record_ip.go b/internal/waf/action_record_ip.go index 4fbfabd..d698c34 100644 --- a/internal/waf/action_record_ip.go +++ b/internal/waf/action_record_ip.go @@ -48,7 +48,7 @@ func init() { const maxItems = 512 // 每次上传的最大IP数 for { - var pbItems = []*pb.CreateIPItemsRequest_IPItem{} + var pbItemMap = map[string]*pb.CreateIPItemsRequest_IPItem{} // ip => IPItem func() { for { @@ -63,7 +63,7 @@ func init() { reason = "触发WAF规则自动加入" } - pbItems = append(pbItems, &pb.CreateIPItemsRequest_IPItem{ + pbItemMap[task.ip] = &pb.CreateIPItemsRequest_IPItem{ IpListId: task.listId, IpFrom: task.ip, IpTo: "", @@ -77,9 +77,9 @@ func init() { SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId, SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId, SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId, - }) + } - if len(pbItems) >= maxItems { + if len(pbItemMap) >= maxItems { return } default: @@ -88,7 +88,11 @@ func init() { } }() - if len(pbItems) > 0 { + if len(pbItemMap) > 0 { + var pbItems = []*pb.CreateIPItemsRequest_IPItem{} + for _, pbItem := range pbItemMap { + pbItems = append(pbItems, pbItem) + } _, err = rpcClient.IPItemRPC.CreateIPItems(rpcClient.Context(), &pb.CreateIPItemsRequest{IpItems: pbItems}) if err != nil { remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error()) diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index ac04788..067dab8 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -6,6 +6,7 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/conns" "github.com/TeaOSLab/EdgeNode/internal/firewalls" + "github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils/expires" "github.com/iwind/TeaGo/types" "sync" @@ -33,6 +34,9 @@ type IPList struct { id uint64 locker sync.RWMutex + + lastIP string // 加入到 recordIPTaskChan 之前尽可能去重 + lastTime int64 } // NewIPList 获取新对象 @@ -101,26 +105,29 @@ func (this *IPList) RecordIP(ipType string, } // 加入队列等待上传 - select { - case recordIPTaskChan <- &recordIPTask{ - ip: ip, - listId: firewallconfigs.GlobalListId, - expiresAt: expiresAt, - level: firewallconfigs.DefaultEventLevel, - serverId: scopeServerId, - sourceServerId: serverId, - sourceHTTPFirewallPolicyId: policyId, - sourceHTTPFirewallRuleGroupId: groupId, - sourceHTTPFirewallRuleSetId: setId, - reason: reason, - }: - default: + if this.lastIP != ip || utils.UnixTime()-this.lastTime > 3 /** 3秒外才允许重复添加 **/ { + select { + case recordIPTaskChan <- &recordIPTask{ + ip: ip, + listId: firewallconfigs.GlobalListId, + expiresAt: expiresAt, + level: firewallconfigs.DefaultEventLevel, + serverId: scopeServerId, + sourceServerId: serverId, + sourceHTTPFirewallPolicyId: policyId, + sourceHTTPFirewallRuleGroupId: groupId, + sourceHTTPFirewallRuleSetId: setId, + reason: reason, + }: + this.lastIP = ip + this.lastTime = utils.UnixTime() + default: + } - } - - // 使用本地防火墙 - if useLocalFirewall && expiresAt > 0 { - firewalls.DropTemporaryTo(ip, expiresAt) + // 使用本地防火墙 + if useLocalFirewall && expiresAt > 0 { + firewalls.DropTemporaryTo(ip, expiresAt) + } } // 关闭此IP相关连接