mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 07:40:56 +08:00 
			
		
		
		
	优化IP名单上传程序
This commit is contained in:
		@@ -546,7 +546,7 @@ func (this *DDoSProtectionManager) updateAllowIPList(allIPList []string) error {
 | 
				
			|||||||
				_, ok := oldMap[ip]
 | 
									_, ok := oldMap[ip]
 | 
				
			||||||
				if !ok {
 | 
									if !ok {
 | 
				
			||||||
					// 不存在则添加
 | 
										// 不存在则添加
 | 
				
			||||||
					err = set.AddIPElement(ip, nil)
 | 
										err = set.AddIPElement(ip, nil, false)
 | 
				
			||||||
					if err != nil {
 | 
										if err != nil {
 | 
				
			||||||
						return errors.New("add ip '" + ip + "' failed: " + err.Error())
 | 
											return errors.New("add ip '" + ip + "' failed: " + err.Error())
 | 
				
			||||||
					}
 | 
										}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -335,14 +335,14 @@ func (this *NFTablesFirewall) AllowSourceIP(ip string) error {
 | 
				
			|||||||
		if this.allowIPv6Set == nil {
 | 
							if this.allowIPv6Set == nil {
 | 
				
			||||||
			return errors.New("ipv6 ip set is nil")
 | 
								return errors.New("ipv6 ip set is nil")
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
		return this.allowIPv6Set.AddElement(data.To16(), nil)
 | 
							return this.allowIPv6Set.AddElement(data.To16(), nil, false)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ipv4
 | 
						// ipv4
 | 
				
			||||||
	if this.allowIPv4Set == nil {
 | 
						if this.allowIPv4Set == nil {
 | 
				
			||||||
		return errors.New("ipv4 ip set is nil")
 | 
							return errors.New("ipv4 ip set is nil")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return this.allowIPv4Set.AddElement(data.To4(), nil)
 | 
						return this.allowIPv4Set.AddElement(data.To4(), nil, false)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// RejectSourceIP 拒绝某个源IP连接
 | 
					// RejectSourceIP 拒绝某个源IP连接
 | 
				
			||||||
@@ -388,7 +388,7 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
		return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
 | 
							return this.denyIPv6Set.AddElement(data.To16(), &nftables.ElementOptions{
 | 
				
			||||||
			Timeout: time.Duration(timeoutSeconds) * time.Second,
 | 
								Timeout: time.Duration(timeoutSeconds) * time.Second,
 | 
				
			||||||
		})
 | 
							}, false)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// ipv4
 | 
						// ipv4
 | 
				
			||||||
@@ -397,7 +397,7 @@ func (this *NFTablesFirewall) DropSourceIP(ip string, timeoutSeconds int, async
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
 | 
						return this.denyIPv4Set.AddElement(data.To4(), &nftables.ElementOptions{
 | 
				
			||||||
		Timeout: time.Duration(timeoutSeconds) * time.Second,
 | 
							Timeout: time.Duration(timeoutSeconds) * time.Second,
 | 
				
			||||||
	})
 | 
						}, false)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// RemoveSourceIP 删除某个源IP
 | 
					// RemoveSourceIP 删除某个源IP
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -56,7 +56,7 @@ func (this *Set) Name() string {
 | 
				
			|||||||
	return this.rawSet.Name
 | 
						return this.rawSet.Name
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *Set) AddElement(key []byte, options *ElementOptions) error {
 | 
					func (this *Set) AddElement(key []byte, options *ElementOptions, overwrite bool) error {
 | 
				
			||||||
	var rawElement = nft.SetElement{
 | 
						var rawElement = nft.SetElement{
 | 
				
			||||||
		Key: key,
 | 
							Key: key,
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -73,7 +73,7 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error {
 | 
				
			|||||||
	err = this.conn.Commit()
 | 
						err = this.conn.Commit()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		// retry if exists
 | 
							// retry if exists
 | 
				
			||||||
		if strings.Contains(err.Error(), "file exists") {
 | 
							if overwrite && strings.Contains(err.Error(), "file exists") {
 | 
				
			||||||
			deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
 | 
								deleteErr := this.conn.Raw().SetDeleteElements(this.rawSet, []nft.SetElement{
 | 
				
			||||||
				{
 | 
									{
 | 
				
			||||||
					Key: key,
 | 
										Key: key,
 | 
				
			||||||
@@ -93,16 +93,16 @@ func (this *Set) AddElement(key []byte, options *ElementOptions) error {
 | 
				
			|||||||
	return err
 | 
						return err
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *Set) AddIPElement(ip string, options *ElementOptions) error {
 | 
					func (this *Set) AddIPElement(ip string, options *ElementOptions, overwrite bool) error {
 | 
				
			||||||
	var ipObj = net.ParseIP(ip)
 | 
						var ipObj = net.ParseIP(ip)
 | 
				
			||||||
	if ipObj == nil {
 | 
						if ipObj == nil {
 | 
				
			||||||
		return errors.New("invalid ip '" + ip + "'")
 | 
							return errors.New("invalid ip '" + ip + "'")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if utils.IsIPv4(ip) {
 | 
						if utils.IsIPv4(ip) {
 | 
				
			||||||
		return this.AddElement(ipObj.To4(), options)
 | 
							return this.AddElement(ipObj.To4(), options, overwrite)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		return this.AddElement(ipObj.To16(), options)
 | 
							return this.AddElement(ipObj.To16(), options, overwrite)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -48,7 +48,7 @@ func init() {
 | 
				
			|||||||
			const maxItems = 512 // 每次上传的最大IP数
 | 
								const maxItems = 512 // 每次上传的最大IP数
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			for {
 | 
								for {
 | 
				
			||||||
				var pbItems = []*pb.CreateIPItemsRequest_IPItem{}
 | 
									var pbItemMap = map[string]*pb.CreateIPItemsRequest_IPItem{} // ip => IPItem
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				func() {
 | 
									func() {
 | 
				
			||||||
					for {
 | 
										for {
 | 
				
			||||||
@@ -63,7 +63,7 @@ func init() {
 | 
				
			|||||||
								reason = "触发WAF规则自动加入"
 | 
													reason = "触发WAF规则自动加入"
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							pbItems = append(pbItems, &pb.CreateIPItemsRequest_IPItem{
 | 
												pbItemMap[task.ip] = &pb.CreateIPItemsRequest_IPItem{
 | 
				
			||||||
								IpListId:                      task.listId,
 | 
													IpListId:                      task.listId,
 | 
				
			||||||
								IpFrom:                        task.ip,
 | 
													IpFrom:                        task.ip,
 | 
				
			||||||
								IpTo:                          "",
 | 
													IpTo:                          "",
 | 
				
			||||||
@@ -77,9 +77,9 @@ func init() {
 | 
				
			|||||||
								SourceHTTPFirewallPolicyId:    task.sourceHTTPFirewallPolicyId,
 | 
													SourceHTTPFirewallPolicyId:    task.sourceHTTPFirewallPolicyId,
 | 
				
			||||||
								SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId,
 | 
													SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId,
 | 
				
			||||||
								SourceHTTPFirewallRuleSetId:   task.sourceHTTPFirewallRuleSetId,
 | 
													SourceHTTPFirewallRuleSetId:   task.sourceHTTPFirewallRuleSetId,
 | 
				
			||||||
							})
 | 
												}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
							if len(pbItems) >= maxItems {
 | 
												if len(pbItemMap) >= maxItems {
 | 
				
			||||||
								return
 | 
													return
 | 
				
			||||||
							}
 | 
												}
 | 
				
			||||||
						default:
 | 
											default:
 | 
				
			||||||
@@ -88,7 +88,11 @@ func init() {
 | 
				
			|||||||
					}
 | 
										}
 | 
				
			||||||
				}()
 | 
									}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
				if len(pbItems) > 0 {
 | 
									if len(pbItemMap) > 0 {
 | 
				
			||||||
 | 
										var pbItems = []*pb.CreateIPItemsRequest_IPItem{}
 | 
				
			||||||
 | 
										for _, pbItem := range pbItemMap {
 | 
				
			||||||
 | 
											pbItems = append(pbItems, pbItem)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
					_, err = rpcClient.IPItemRPC.CreateIPItems(rpcClient.Context(), &pb.CreateIPItemsRequest{IpItems: pbItems})
 | 
										_, err = rpcClient.IPItemRPC.CreateIPItems(rpcClient.Context(), &pb.CreateIPItemsRequest{IpItems: pbItems})
 | 
				
			||||||
					if err != nil {
 | 
										if err != nil {
 | 
				
			||||||
						remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
 | 
											remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,6 +6,7 @@ import (
 | 
				
			|||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/conns"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/conns"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/firewalls"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/firewalls"
 | 
				
			||||||
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
 | 
				
			||||||
	"github.com/iwind/TeaGo/types"
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
	"sync"
 | 
						"sync"
 | 
				
			||||||
@@ -33,6 +34,9 @@ type IPList struct {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	id     uint64
 | 
						id     uint64
 | 
				
			||||||
	locker sync.RWMutex
 | 
						locker sync.RWMutex
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						lastIP   string // 加入到 recordIPTaskChan 之前尽可能去重
 | 
				
			||||||
 | 
						lastTime int64
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// NewIPList 获取新对象
 | 
					// NewIPList 获取新对象
 | 
				
			||||||
@@ -101,26 +105,29 @@ func (this *IPList) RecordIP(ipType string,
 | 
				
			|||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 加入队列等待上传
 | 
							// 加入队列等待上传
 | 
				
			||||||
		select {
 | 
							if this.lastIP != ip || utils.UnixTime()-this.lastTime > 3 /** 3秒外才允许重复添加 **/ {
 | 
				
			||||||
		case recordIPTaskChan <- &recordIPTask{
 | 
								select {
 | 
				
			||||||
			ip:                            ip,
 | 
								case recordIPTaskChan <- &recordIPTask{
 | 
				
			||||||
			listId:                        firewallconfigs.GlobalListId,
 | 
									ip:                            ip,
 | 
				
			||||||
			expiresAt:                     expiresAt,
 | 
									listId:                        firewallconfigs.GlobalListId,
 | 
				
			||||||
			level:                         firewallconfigs.DefaultEventLevel,
 | 
									expiresAt:                     expiresAt,
 | 
				
			||||||
			serverId:                      scopeServerId,
 | 
									level:                         firewallconfigs.DefaultEventLevel,
 | 
				
			||||||
			sourceServerId:                serverId,
 | 
									serverId:                      scopeServerId,
 | 
				
			||||||
			sourceHTTPFirewallPolicyId:    policyId,
 | 
									sourceServerId:                serverId,
 | 
				
			||||||
			sourceHTTPFirewallRuleGroupId: groupId,
 | 
									sourceHTTPFirewallPolicyId:    policyId,
 | 
				
			||||||
			sourceHTTPFirewallRuleSetId:   setId,
 | 
									sourceHTTPFirewallRuleGroupId: groupId,
 | 
				
			||||||
			reason:                        reason,
 | 
									sourceHTTPFirewallRuleSetId:   setId,
 | 
				
			||||||
		}:
 | 
									reason:                        reason,
 | 
				
			||||||
		default:
 | 
								}:
 | 
				
			||||||
 | 
									this.lastIP = ip
 | 
				
			||||||
 | 
									this.lastTime = utils.UnixTime()
 | 
				
			||||||
 | 
								default:
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		}
 | 
								// 使用本地防火墙
 | 
				
			||||||
 | 
								if useLocalFirewall && expiresAt > 0 {
 | 
				
			||||||
		// 使用本地防火墙
 | 
									firewalls.DropTemporaryTo(ip, expiresAt)
 | 
				
			||||||
		if useLocalFirewall && expiresAt > 0 {
 | 
								}
 | 
				
			||||||
			firewalls.DropTemporaryTo(ip, expiresAt)
 | 
					 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 关闭此IP相关连接
 | 
							// 关闭此IP相关连接
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user