mirror of
				https://github.com/TeaOSLab/EdgeNode.git
				synced 2025-11-04 16:00:25 +08:00 
			
		
		
		
	优化WAF中IP名单
This commit is contained in:
		@@ -8,6 +8,7 @@ import (
 | 
				
			|||||||
	"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
 | 
				
			||||||
	"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
						"github.com/TeaOSLab/EdgeNode/internal/waf"
 | 
				
			||||||
	"net"
 | 
						"net"
 | 
				
			||||||
 | 
						"time"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// ClientListener 客户端网络监听
 | 
					// ClientListener 客户端网络监听
 | 
				
			||||||
@@ -42,10 +43,28 @@ func (this *ClientListener) Accept() (net.Conn, error) {
 | 
				
			|||||||
	ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
 | 
						ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
 | 
				
			||||||
	if err == nil {
 | 
						if err == nil {
 | 
				
			||||||
		canGoNext, _ := iplibrary.AllowIP(ip, 0)
 | 
							canGoNext, _ := iplibrary.AllowIP(ip, 0)
 | 
				
			||||||
		var beingDenied = !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
 | 
							if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
 | 
				
			||||||
			waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
 | 
								expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
 | 
				
			||||||
 | 
								if ok {
 | 
				
			||||||
 | 
									var timeout = expiresAt - time.Now().Unix()
 | 
				
			||||||
 | 
									if timeout > 0 {
 | 
				
			||||||
 | 
										canGoNext = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		if !canGoNext || beingDenied {
 | 
										if timeout > 3600 {
 | 
				
			||||||
 | 
											timeout = 3600
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
										// 使用本地防火墙延长封禁
 | 
				
			||||||
 | 
										var fw = firewalls.Firewall()
 | 
				
			||||||
 | 
										if fw != nil && !fw.IsMock() {
 | 
				
			||||||
 | 
											// 这里 int(int64) 转换的前提是限制了 timeout <= 3600,否则将有整型溢出的风险
 | 
				
			||||||
 | 
											_ = fw.DropSourceIP(ip, int(timeout), true)
 | 
				
			||||||
 | 
										}
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							if !canGoNext {
 | 
				
			||||||
			tcpConn, ok := conn.(*net.TCPConn)
 | 
								tcpConn, ok := conn.(*net.TCPConn)
 | 
				
			||||||
			if ok {
 | 
								if ok {
 | 
				
			||||||
				_ = tcpConn.SetLinger(0)
 | 
									_ = tcpConn.SetLinger(0)
 | 
				
			||||||
@@ -53,14 +72,6 @@ func (this *ClientListener) Accept() (net.Conn, error) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
			_ = conn.Close()
 | 
								_ = conn.Close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
			// 使用本地防火墙延长封禁
 | 
					 | 
				
			||||||
			if beingDenied {
 | 
					 | 
				
			||||||
				var fw = firewalls.Firewall()
 | 
					 | 
				
			||||||
				if fw != nil && !fw.IsMock() {
 | 
					 | 
				
			||||||
					_ = fw.DropSourceIP(ip, 120, true)
 | 
					 | 
				
			||||||
				}
 | 
					 | 
				
			||||||
			}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
			return this.Accept()
 | 
								return this.Accept()
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -77,6 +77,12 @@ func (this *List) Remove(itemId uint64) {
 | 
				
			|||||||
	this.removeItem(itemId)
 | 
						this.removeItem(itemId)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (this *List) ExpiresAt(itemId uint64) int64 {
 | 
				
			||||||
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
						defer this.locker.Unlock()
 | 
				
			||||||
 | 
						return this.itemsMap[itemId]
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (this *List) GC(timestamp int64) ItemMap {
 | 
					func (this *List) GC(timestamp int64) ItemMap {
 | 
				
			||||||
	if this.lastTimestamp > timestamp+1 {
 | 
						if this.lastTimestamp > timestamp+1 {
 | 
				
			||||||
		return nil
 | 
							return nil
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -68,6 +68,14 @@ func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serv
 | 
				
			|||||||
	var id = this.nextId()
 | 
						var id = this.nextId()
 | 
				
			||||||
	this.expireList.Add(id, expiresAt)
 | 
						this.expireList.Add(id, expiresAt)
 | 
				
			||||||
	this.locker.Lock()
 | 
						this.locker.Lock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 删除以前
 | 
				
			||||||
 | 
						oldId, ok := this.ipMap[ip]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							delete(this.idMap, oldId)
 | 
				
			||||||
 | 
							this.expireList.Remove(oldId)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	this.ipMap[ip] = id
 | 
						this.ipMap[ip] = id
 | 
				
			||||||
	this.idMap[id] = ip
 | 
						this.idMap[id] = ip
 | 
				
			||||||
	this.locker.Unlock()
 | 
						this.locker.Unlock()
 | 
				
			||||||
@@ -117,7 +125,7 @@ func (this *IPList) RecordIP(ipType string,
 | 
				
			|||||||
			}
 | 
								}
 | 
				
			||||||
		}
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
		// 关闭所有连接
 | 
							// 关闭此IP相关连接
 | 
				
			||||||
		conns.SharedMap.CloseIPConns(ip)
 | 
							conns.SharedMap.CloseIPConns(ip)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
@@ -139,13 +147,52 @@ func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope,
 | 
				
			|||||||
	return ok
 | 
						return ok
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// ContainsExpires 判断是否有某个IP,并返回过期时间
 | 
				
			||||||
 | 
					func (this *IPList) ContainsExpires(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) (expiresAt int64, ok bool) {
 | 
				
			||||||
 | 
						switch scope {
 | 
				
			||||||
 | 
						case firewallconfigs.FirewallScopeGlobal:
 | 
				
			||||||
 | 
							ip = "*@" + ip + "@" + ipType
 | 
				
			||||||
 | 
						case firewallconfigs.FirewallScopeService:
 | 
				
			||||||
 | 
							ip = types.String(serverId) + "@" + ip + "@" + ipType
 | 
				
			||||||
 | 
						default:
 | 
				
			||||||
 | 
							ip = "*@" + ip + "@" + ipType
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						this.locker.RLock()
 | 
				
			||||||
 | 
						id, ok := this.ipMap[ip]
 | 
				
			||||||
 | 
						if ok {
 | 
				
			||||||
 | 
							expiresAt = this.expireList.ExpiresAt(id)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						this.locker.RUnlock()
 | 
				
			||||||
 | 
						return expiresAt, ok
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// RemoveIP 删除IP
 | 
					// RemoveIP 删除IP
 | 
				
			||||||
func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) {
 | 
					func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) {
 | 
				
			||||||
	this.locker.Lock()
 | 
						this.locker.Lock()
 | 
				
			||||||
	delete(this.ipMap, "*@"+ip+"@"+IPTypeAll)
 | 
					
 | 
				
			||||||
	if serverId > 0 {
 | 
						{
 | 
				
			||||||
		delete(this.ipMap, types.String(serverId)+"@"+ip+"@"+IPTypeAll)
 | 
							var key = "*@" + ip + "@" + IPTypeAll
 | 
				
			||||||
 | 
							id, ok := this.ipMap[key]
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								delete(this.ipMap, key)
 | 
				
			||||||
 | 
								delete(this.idMap, id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								this.expireList.Remove(id)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						if serverId > 0 {
 | 
				
			||||||
 | 
							var key = types.String(serverId) + "@" + ip + "@" + IPTypeAll
 | 
				
			||||||
 | 
							id, ok := this.ipMap[key]
 | 
				
			||||||
 | 
							if ok {
 | 
				
			||||||
 | 
								delete(this.ipMap, key)
 | 
				
			||||||
 | 
								delete(this.idMap, id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								this.expireList.Remove(id)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	this.locker.Unlock()
 | 
						this.locker.Unlock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// 从本地防火墙中删除
 | 
						// 从本地防火墙中删除
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -6,6 +6,7 @@ import (
 | 
				
			|||||||
	"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
						"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
 | 
				
			||||||
	"github.com/iwind/TeaGo/assert"
 | 
						"github.com/iwind/TeaGo/assert"
 | 
				
			||||||
	"github.com/iwind/TeaGo/logs"
 | 
						"github.com/iwind/TeaGo/logs"
 | 
				
			||||||
 | 
						timeutil "github.com/iwind/TeaGo/utils/time"
 | 
				
			||||||
	"runtime"
 | 
						"runtime"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"testing"
 | 
						"testing"
 | 
				
			||||||
@@ -13,12 +14,26 @@ import (
 | 
				
			|||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestNewIPList(t *testing.T) {
 | 
					func TestNewIPList(t *testing.T) {
 | 
				
			||||||
	list := NewIPList(IPListTypeDeny)
 | 
						var list = NewIPList(IPListTypeDeny)
 | 
				
			||||||
 | 
						list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
 | 
				
			||||||
 | 
						list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
 | 
				
			||||||
 | 
						list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
 | 
				
			||||||
 | 
						list.Add(IPTypeAll, firewallconfigs.FirewallScopeService, 1, "127.0.0.3", time.Now().Unix()+3)
 | 
				
			||||||
 | 
						list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						list.RemoveIP("127.0.0.1", 1, false)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						logs.PrintAsJSON(list.ipMap, t)
 | 
				
			||||||
 | 
						logs.PrintAsJSON(list.idMap, t)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestIPList_Expire(t *testing.T) {
 | 
				
			||||||
 | 
						var list = NewIPList(IPListTypeDeny)
 | 
				
			||||||
	list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
 | 
						list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
 | 
				
			||||||
	list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
 | 
						list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
 | 
				
			||||||
	list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
 | 
						list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
 | 
				
			||||||
	list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3)
 | 
						list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3)
 | 
				
			||||||
	list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10)
 | 
						list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+6)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	var ticker = time.NewTicker(1 * time.Second)
 | 
						var ticker = time.NewTicker(1 * time.Second)
 | 
				
			||||||
	for range ticker.C {
 | 
						for range ticker.C {
 | 
				
			||||||
@@ -32,22 +47,39 @@ func TestNewIPList(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestIPList_Contains(t *testing.T) {
 | 
					func TestIPList_Contains(t *testing.T) {
 | 
				
			||||||
	a := assert.NewAssertion(t)
 | 
						var a = assert.NewAssertion(t)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	list := NewIPList(IPListTypeDeny)
 | 
						var list = NewIPList(IPListTypeDeny)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for i := 0; i < 1_0000; i++ {
 | 
						for i := 0; i < 1_0000; i++ {
 | 
				
			||||||
		list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
 | 
							list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	//list.RemoveIP("192.168.1.100")
 | 
						//list.RemoveIP("192.168.1.100")
 | 
				
			||||||
	a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
 | 
						{
 | 
				
			||||||
	a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
 | 
							a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						{
 | 
				
			||||||
 | 
							a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestIPList_ContainsExpires(t *testing.T) {
 | 
				
			||||||
 | 
						var list = NewIPList(IPListTypeDeny)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						for i := 0; i < 1_0000; i++ {
 | 
				
			||||||
 | 
							list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						// list.RemoveIP("192.168.1.100", 1, false)
 | 
				
			||||||
 | 
						for _, ip := range []string{"192.168.1.100", "192.168.2.100"} {
 | 
				
			||||||
 | 
							expiresAt, ok := list.ContainsExpires(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, ip)
 | 
				
			||||||
 | 
							t.Log(ok, expiresAt, timeutil.FormatTime("Y-m-d H:i:s", expiresAt))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func BenchmarkIPList_Add(b *testing.B) {
 | 
					func BenchmarkIPList_Add(b *testing.B) {
 | 
				
			||||||
	runtime.GOMAXPROCS(1)
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	list := NewIPList(IPListTypeDeny)
 | 
						var list = NewIPList(IPListTypeDeny)
 | 
				
			||||||
	for i := 0; i < b.N; i++ {
 | 
						for i := 0; i < b.N; i++ {
 | 
				
			||||||
		list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
 | 
							list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -57,7 +89,8 @@ func BenchmarkIPList_Add(b *testing.B) {
 | 
				
			|||||||
func BenchmarkIPList_Has(b *testing.B) {
 | 
					func BenchmarkIPList_Has(b *testing.B) {
 | 
				
			||||||
	runtime.GOMAXPROCS(1)
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	list := NewIPList(IPListTypeDeny)
 | 
						var list = NewIPList(IPListTypeDeny)
 | 
				
			||||||
 | 
						b.ResetTimer()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	for i := 0; i < 1_0000; i++ {
 | 
						for i := 0; i < 1_0000; i++ {
 | 
				
			||||||
		list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
 | 
							list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user