diff --git a/internal/nodes/client_listener.go b/internal/nodes/client_listener.go index c1044c9..916e456 100644 --- a/internal/nodes/client_listener.go +++ b/internal/nodes/client_listener.go @@ -8,6 +8,7 @@ import ( "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/waf" "net" + "time" ) // ClientListener 客户端网络监听 @@ -42,10 +43,28 @@ func (this *ClientListener) Accept() (net.Conn, error) { ip, _, err := net.SplitHostPort(conn.RemoteAddr().String()) if err == nil { canGoNext, _ := iplibrary.AllowIP(ip, 0) - var beingDenied = !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) && - waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) + if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) { + expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) + if ok { + var timeout = expiresAt - time.Now().Unix() + if timeout > 0 { + canGoNext = false - if !canGoNext || beingDenied { + if timeout > 3600 { + timeout = 3600 + } + + // 使用本地防火墙延长封禁 + var fw = firewalls.Firewall() + if fw != nil && !fw.IsMock() { + // 这里 int(int64) 转换的前提是限制了 timeout <= 3600,否则将有整型溢出的风险 + _ = fw.DropSourceIP(ip, int(timeout), true) + } + } + } + } + + if !canGoNext { tcpConn, ok := conn.(*net.TCPConn) if ok { _ = tcpConn.SetLinger(0) @@ -53,14 +72,6 @@ func (this *ClientListener) Accept() (net.Conn, error) { _ = conn.Close() - // 使用本地防火墙延长封禁 - if beingDenied { - var fw = firewalls.Firewall() - if fw != nil && !fw.IsMock() { - _ = fw.DropSourceIP(ip, 120, true) - } - } - return this.Accept() } } diff --git a/internal/utils/expires/list.go b/internal/utils/expires/list.go index 96ab4c9..e965a95 100644 --- a/internal/utils/expires/list.go +++ b/internal/utils/expires/list.go @@ -77,6 +77,12 @@ func (this *List) Remove(itemId uint64) { this.removeItem(itemId) } +func (this *List) ExpiresAt(itemId uint64) int64 { + this.locker.Lock() + defer this.locker.Unlock() + return this.itemsMap[itemId] +} + func (this *List) GC(timestamp int64) ItemMap { if this.lastTimestamp > timestamp+1 { return nil diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index 77a0f6a..b053e48 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -68,6 +68,14 @@ func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serv var id = this.nextId() this.expireList.Add(id, expiresAt) this.locker.Lock() + + // 删除以前 + oldId, ok := this.ipMap[ip] + if ok { + delete(this.idMap, oldId) + this.expireList.Remove(oldId) + } + this.ipMap[ip] = id this.idMap[id] = ip this.locker.Unlock() @@ -117,7 +125,7 @@ func (this *IPList) RecordIP(ipType string, } } - // 关闭所有连接 + // 关闭此IP相关连接 conns.SharedMap.CloseIPConns(ip) } } @@ -139,13 +147,52 @@ func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, return ok } +// ContainsExpires 判断是否有某个IP,并返回过期时间 +func (this *IPList) ContainsExpires(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) (expiresAt int64, ok bool) { + switch scope { + case firewallconfigs.FirewallScopeGlobal: + ip = "*@" + ip + "@" + ipType + case firewallconfigs.FirewallScopeService: + ip = types.String(serverId) + "@" + ip + "@" + ipType + default: + ip = "*@" + ip + "@" + ipType + } + + this.locker.RLock() + id, ok := this.ipMap[ip] + if ok { + expiresAt = this.expireList.ExpiresAt(id) + } + this.locker.RUnlock() + return expiresAt, ok +} + // RemoveIP 删除IP func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) { this.locker.Lock() - delete(this.ipMap, "*@"+ip+"@"+IPTypeAll) - if serverId > 0 { - delete(this.ipMap, types.String(serverId)+"@"+ip+"@"+IPTypeAll) + + { + var key = "*@" + ip + "@" + IPTypeAll + id, ok := this.ipMap[key] + if ok { + delete(this.ipMap, key) + delete(this.idMap, id) + + this.expireList.Remove(id) + } } + + if serverId > 0 { + var key = types.String(serverId) + "@" + ip + "@" + IPTypeAll + id, ok := this.ipMap[key] + if ok { + delete(this.ipMap, key) + delete(this.idMap, id) + + this.expireList.Remove(id) + } + } + this.locker.Unlock() // 从本地防火墙中删除 diff --git a/internal/waf/ip_list_test.go b/internal/waf/ip_list_test.go index 01e570a..2b822da 100644 --- a/internal/waf/ip_list_test.go +++ b/internal/waf/ip_list_test.go @@ -6,6 +6,7 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/logs" + timeutil "github.com/iwind/TeaGo/utils/time" "runtime" "strconv" "testing" @@ -13,12 +14,26 @@ import ( ) func TestNewIPList(t *testing.T) { - list := NewIPList(IPListTypeDeny) + var list = NewIPList(IPListTypeDeny) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeService, 1, "127.0.0.3", time.Now().Unix()+3) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10) + + list.RemoveIP("127.0.0.1", 1, false) + + logs.PrintAsJSON(list.ipMap, t) + logs.PrintAsJSON(list.idMap, t) +} + +func TestIPList_Expire(t *testing.T) { + var list = NewIPList(IPListTypeDeny) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3) - list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+6) var ticker = time.NewTicker(1 * time.Second) for range ticker.C { @@ -32,22 +47,39 @@ func TestNewIPList(t *testing.T) { } func TestIPList_Contains(t *testing.T) { - a := assert.NewAssertion(t) + var a = assert.NewAssertion(t) - list := NewIPList(IPListTypeDeny) + var list = NewIPList(IPListTypeDeny) for i := 0; i < 1_0000; i++ { list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) } //list.RemoveIP("192.168.1.100") - a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")) - a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100")) + { + a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")) + } + { + a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100")) + } +} + +func TestIPList_ContainsExpires(t *testing.T) { + var list = NewIPList(IPListTypeDeny) + + for i := 0; i < 1_0000; i++ { + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + } + // list.RemoveIP("192.168.1.100", 1, false) + for _, ip := range []string{"192.168.1.100", "192.168.2.100"} { + expiresAt, ok := list.ContainsExpires(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, ip) + t.Log(ok, expiresAt, timeutil.FormatTime("Y-m-d H:i:s", expiresAt)) + } } func BenchmarkIPList_Add(b *testing.B) { runtime.GOMAXPROCS(1) - list := NewIPList(IPListTypeDeny) + var list = NewIPList(IPListTypeDeny) for i := 0; i < b.N; i++ { list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) } @@ -57,7 +89,8 @@ func BenchmarkIPList_Add(b *testing.B) { func BenchmarkIPList_Has(b *testing.B) { runtime.GOMAXPROCS(1) - list := NewIPList(IPListTypeDeny) + var list = NewIPList(IPListTypeDeny) + b.ResetTimer() for i := 0; i < 1_0000; i++ { list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)