优化IP名单上传程序

This commit is contained in:
刘祥超
2023-04-01 20:51:49 +08:00
parent 8988765cef
commit 888df02d0c
5 changed files with 45 additions and 34 deletions

View File

@@ -546,7 +546,7 @@ func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
_, ok := oldMap[ip] _, ok := oldMap[ip]
if !ok { if !ok {
// 不存在则添加 // 不存在则添加
err = set.AddIPElement(ip, nil) err = set.AddIPElement(ip, nil, false)
if err != nil { if err != nil {
return errors.New("add ip '" + ip + "' failed: " + err.Error()) return errors.New("add ip '" + ip + "' failed: " + err.Error())
} }

View File

@@ -335,14 +335,14 @@ func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
if this.allowIPv6Set == nil { if this.allowIPv6Set == nil {
return errors.New("ipv6 ip set is nil") return errors.New("ipv6 ip set is nil")
} }
return this.allowIPv6Set.AddElement(data.To16(), nil) return this.allowIPv6Set.AddElement(data.To16(), nil, false)
} }
// ipv4 // ipv4
if this.allowIPv4Set == nil { if this.allowIPv4Set == nil {
return errors.New("ipv4 ip set is nil") return errors.New("ipv4 ip set is nil")
} }
return this.allowIPv4Set.AddElement(data.To4(), nil) return this.allowIPv4Set.AddElement(data.To4(), nil, false)
} }
// RejectSourceIP 拒绝某个源IP连接 // RejectSourceIP 拒绝某个源IP连接
@@ -388,7 +388,7 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
} }
return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{ return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
Timeout: time.Duration(timeoutSeconds) * time.Second, Timeout: time.Duration(timeoutSeconds) * time.Second,
}) }, false)
} }
// ipv4 // ipv4
@@ -397,7 +397,7 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
} }
return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{ return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
Timeout: time.Duration(timeoutSeconds) * time.Second, Timeout: time.Duration(timeoutSeconds) * time.Second,
}) }, false)
} }
// RemoveSourceIP 删除某个源IP // RemoveSourceIP 删除某个源IP

View File

@@ -56,7 +56,7 @@ func (this *Set) Name() string {
return this.rawSet.Name return this.rawSet.Name
} }
func (this *Set) AddElement(key []byte, options *ElementOptions) error { func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) error {
var rawElement = nft.SetElement{ var rawElement = nft.SetElement{
Key: key, Key: key,
} }
@@ -73,7 +73,7 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error {
err = this.conn.Commit() err = this.conn.Commit()
if err != nil { if err != nil {
// retry if exists // retry if exists
if strings.Contains(err.Error(), "file exists") { if overwrite && strings.Contains(err.Error(), "file exists") {
deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{ deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
{ {
Key: key, Key: key,
@@ -93,16 +93,16 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error {
return err return err
} }
func (this *Set) AddIPElement(ip string, options *ElementOptions) error { func (this *Set) AddIPElement(ip string, options *ElementOptions, overwrite bool) error {
var ipObj = net.ParseIP(ip) var ipObj = net.ParseIP(ip)
if ipObj == nil { if ipObj == nil {
return errors.New("invalid ip '" + ip + "'") return errors.New("invalid ip '" + ip + "'")
} }
if utils.IsIPv4(ip) { if utils.IsIPv4(ip) {
return this.AddElement(ipObj.To4(), options) return this.AddElement(ipObj.To4(), options, overwrite)
} else { } else {
return this.AddElement(ipObj.To16(), options) return this.AddElement(ipObj.To16(), options, overwrite)
} }
} }

View File

@@ -48,7 +48,7 @@ func init() {
const maxItems = 512 // 每次上传的最大IP数 const maxItems = 512 // 每次上传的最大IP数
for { for {
var pbItems = []*pb.CreateIPItemsRequest_IPItem{} var pbItemMap = map[string]*pb.CreateIPItemsRequest_IPItem{} // ip => IPItem
func() { func() {
for { for {
@@ -63,7 +63,7 @@ func init() {
reason = "触发WAF规则自动加入" reason = "触发WAF规则自动加入"
} }
pbItems = append(pbItems, &pb.CreateIPItemsRequest_IPItem{ pbItemMap[task.ip] = &pb.CreateIPItemsRequest_IPItem{
IpListId: task.listId, IpListId: task.listId,
IpFrom: task.ip, IpFrom: task.ip,
IpTo: "", IpTo: "",
@@ -77,9 +77,9 @@ func init() {
SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId, SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId,
SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId, SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId,
SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId, SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId,
}) }
if len(pbItems) >= maxItems { if len(pbItemMap) >= maxItems {
return return
} }
default: default:
@@ -88,7 +88,11 @@ func init() {
} }
}() }()
if len(pbItems) > 0 { if len(pbItemMap) > 0 {
var pbItems = []*pb.CreateIPItemsRequest_IPItem{}
for _, pbItem := range pbItemMap {
pbItems = append(pbItems, pbItem)
}
_, err = rpcClient.IPItemRPC.CreateIPItems(rpcClient.Context(), &pb.CreateIPItemsRequest{IpItems: pbItems}) _, err = rpcClient.IPItemRPC.CreateIPItems(rpcClient.Context(), &pb.CreateIPItemsRequest{IpItems: pbItems})
if err != nil { if err != nil {
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error()) remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/conns" "github.com/TeaOSLab/EdgeNode/internal/conns"
"github.com/TeaOSLab/EdgeNode/internal/firewalls" "github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"github.com/TeaOSLab/EdgeNode/internal/utils/expires" "github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"sync" "sync"
@@ -33,6 +34,9 @@ type IPList struct {
id uint64 id uint64
locker sync.RWMutex locker sync.RWMutex
lastIP string // 加入到 recordIPTaskChan 之前尽可能去重
lastTime int64
} }
// NewIPList 获取新对象 // NewIPList 获取新对象
@@ -101,6 +105,7 @@ func (this *IPList) RecordIP(ipType string,
} }
// 加入队列等待上传 // 加入队列等待上传
if this.lastIP != ip || utils.UnixTime()-this.lastTime > 3 /** 3秒外才允许重复添加 **/ {
select { select {
case recordIPTaskChan <- &recordIPTask{ case recordIPTaskChan <- &recordIPTask{
ip: ip, ip: ip,
@@ -114,14 +119,16 @@ func (this *IPList) RecordIP(ipType string,
sourceHTTPFirewallRuleSetId: setId, sourceHTTPFirewallRuleSetId: setId,
reason: reason, reason: reason,
}: }:
this.lastIP = ip
this.lastTime = utils.UnixTime()
default: default:
} }
// 使用本地防火墙 // 使用本地防火墙
if useLocalFirewall && expiresAt > 0 { if useLocalFirewall && expiresAt > 0 {
firewalls.DropTemporaryTo(ip, expiresAt) firewalls.DropTemporaryTo(ip, expiresAt)
} }
}
// 关闭此IP相关连接 // 关闭此IP相关连接
conns.SharedMap.CloseIPConns(ip) conns.SharedMap.CloseIPConns(ip)