mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 07:40:56 +08:00 
			
		
		
		
	优化IP名单上传程序
This commit is contained in:
		@@ -546,7 +546,7 @@ func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
 | 
			
		||||
				_, ok := oldMap[ip]
 | 
			
		||||
				if !ok {
 | 
			
		||||
					// 不存在则添加
 | 
			
		||||
					err = set.AddIPElement(ip, nil)
 | 
			
		||||
					err = set.AddIPElement(ip, nil, false)
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						return errors.New("add ip '" + ip + "' failed: " + err.Error())
 | 
			
		||||
					}
 | 
			
		||||
 
 | 
			
		||||
@@ -335,14 +335,14 @@ func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
 | 
			
		||||
		if this.allowIPv6Set == nil {
 | 
			
		||||
			return errors.New("ipv6 ip set is nil")
 | 
			
		||||
		}
 | 
			
		||||
		return this.allowIPv6Set.AddElement(data.To16(), nil)
 | 
			
		||||
		return this.allowIPv6Set.AddElement(data.To16(), nil, false)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ipv4
 | 
			
		||||
	if this.allowIPv4Set == nil {
 | 
			
		||||
		return errors.New("ipv4 ip set is nil")
 | 
			
		||||
	}
 | 
			
		||||
	return this.allowIPv4Set.AddElement(data.To4(), nil)
 | 
			
		||||
	return this.allowIPv4Set.AddElement(data.To4(), nil, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RejectSourceIP 拒绝某个源IP连接
 | 
			
		||||
@@ -388,7 +388,7 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
 | 
			
		||||
		}
 | 
			
		||||
		return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
 | 
			
		||||
			Timeout: time.Duration(timeoutSeconds) * time.Second,
 | 
			
		||||
		})
 | 
			
		||||
		}, false)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// ipv4
 | 
			
		||||
@@ -397,7 +397,7 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
 | 
			
		||||
	}
 | 
			
		||||
	return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
 | 
			
		||||
		Timeout: time.Duration(timeoutSeconds) * time.Second,
 | 
			
		||||
	})
 | 
			
		||||
	}, false)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveSourceIP 删除某个源IP
 | 
			
		||||
 
 | 
			
		||||
@@ -56,7 +56,7 @@ func (this *Set) Name() string {
 | 
			
		||||
	return this.rawSet.Name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
 | 
			
		||||
func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) error {
 | 
			
		||||
	var rawElement = nft.SetElement{
 | 
			
		||||
		Key: key,
 | 
			
		||||
	}
 | 
			
		||||
@@ -73,7 +73,7 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error {
 | 
			
		||||
	err = this.conn.Commit()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		// retry if exists
 | 
			
		||||
		if strings.Contains(err.Error(), "file exists") {
 | 
			
		||||
		if overwrite && strings.Contains(err.Error(), "file exists") {
 | 
			
		||||
			deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
 | 
			
		||||
				{
 | 
			
		||||
					Key: key,
 | 
			
		||||
@@ -93,16 +93,16 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error {
 | 
			
		||||
	return err
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
 | 
			
		||||
func (this *Set) AddIPElement(ip string, options *ElementOptions, overwrite bool) error {
 | 
			
		||||
	var ipObj = net.ParseIP(ip)
 | 
			
		||||
	if ipObj == nil {
 | 
			
		||||
		return errors.New("invalid ip '" + ip + "'")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if utils.IsIPv4(ip) {
 | 
			
		||||
		return this.AddElement(ipObj.To4(), options)
 | 
			
		||||
		return this.AddElement(ipObj.To4(), options, overwrite)
 | 
			
		||||
	} else {
 | 
			
		||||
		return this.AddElement(ipObj.To16(), options)
 | 
			
		||||
		return this.AddElement(ipObj.To16(), options, overwrite)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -48,7 +48,7 @@ func init() {
 | 
			
		||||
			const maxItems = 512 // 每次上传的最大IP数
 | 
			
		||||
 | 
			
		||||
			for {
 | 
			
		||||
				var pbItems = []*pb.CreateIPItemsRequest_IPItem{}
 | 
			
		||||
				var pbItemMap = map[string]*pb.CreateIPItemsRequest_IPItem{} // ip => IPItem
 | 
			
		||||
 | 
			
		||||
				func() {
 | 
			
		||||
					for {
 | 
			
		||||
@@ -63,7 +63,7 @@ func init() {
 | 
			
		||||
								reason = "触发WAF规则自动加入"
 | 
			
		||||
							}
 | 
			
		||||
 | 
			
		||||
							pbItems = append(pbItems, &pb.CreateIPItemsRequest_IPItem{
 | 
			
		||||
							pbItemMap[task.ip] = &pb.CreateIPItemsRequest_IPItem{
 | 
			
		||||
								IpListId:                      task.listId,
 | 
			
		||||
								IpFrom:                        task.ip,
 | 
			
		||||
								IpTo:                          "",
 | 
			
		||||
@@ -77,9 +77,9 @@ func init() {
 | 
			
		||||
								SourceHTTPFirewallPolicyId:    task.sourceHTTPFirewallPolicyId,
 | 
			
		||||
								SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId,
 | 
			
		||||
								SourceHTTPFirewallRuleSetId:   task.sourceHTTPFirewallRuleSetId,
 | 
			
		||||
							})
 | 
			
		||||
							}
 | 
			
		||||
 | 
			
		||||
							if len(pbItems) >= maxItems {
 | 
			
		||||
							if len(pbItemMap) >= maxItems {
 | 
			
		||||
								return
 | 
			
		||||
							}
 | 
			
		||||
						default:
 | 
			
		||||
@@ -88,7 +88,11 @@ func init() {
 | 
			
		||||
					}
 | 
			
		||||
				}()
 | 
			
		||||
 | 
			
		||||
				if len(pbItems) > 0 {
 | 
			
		||||
				if len(pbItemMap) > 0 {
 | 
			
		||||
					var pbItems = []*pb.CreateIPItemsRequest_IPItem{}
 | 
			
		||||
					for _, pbItem := range pbItemMap {
 | 
			
		||||
						pbItems = append(pbItems, pbItem)
 | 
			
		||||
					}
 | 
			
		||||
					_, err = rpcClient.IPItemRPC.CreateIPItems(rpcClient.Context(), &pb.CreateIPItemsRequest{IpItems: pbItems})
 | 
			
		||||
					if err != nil {
 | 
			
		||||
						remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
 | 
			
		||||
 
 | 
			
		||||
@@ -6,6 +6,7 @@ import (
 | 
			
		||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/conns"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
			
		||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
	"sync"
 | 
			
		||||
@@ -33,6 +34,9 @@ type IPList struct {
 | 
			
		||||
 | 
			
		||||
	id     uint64
 | 
			
		||||
	locker sync.RWMutex
 | 
			
		||||
 | 
			
		||||
	lastIP   string // 加入到 recordIPTaskChan 之前尽可能去重
 | 
			
		||||
	lastTime int64
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// NewIPList 获取新对象
 | 
			
		||||
@@ -101,6 +105,7 @@ func (this *IPList) RecordIP(ipType string,
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 加入队列等待上传
 | 
			
		||||
		if this.lastIP != ip || utils.UnixTime()-this.lastTime > 3 /** 3秒外才允许重复添加 **/ {
 | 
			
		||||
			select {
 | 
			
		||||
			case recordIPTaskChan <- &recordIPTask{
 | 
			
		||||
				ip:                            ip,
 | 
			
		||||
@@ -114,14 +119,16 @@ func (this *IPList) RecordIP(ipType string,
 | 
			
		||||
				sourceHTTPFirewallRuleSetId:   setId,
 | 
			
		||||
				reason:                        reason,
 | 
			
		||||
			}:
 | 
			
		||||
				this.lastIP = ip
 | 
			
		||||
				this.lastTime = utils.UnixTime()
 | 
			
		||||
			default:
 | 
			
		||||
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 使用本地防火墙
 | 
			
		||||
			if useLocalFirewall && expiresAt > 0 {
 | 
			
		||||
				firewalls.DropTemporaryTo(ip, expiresAt)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 关闭此IP相关连接
 | 
			
		||||
		conns.SharedMap.CloseIPConns(ip)
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user