mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-03 23:20:25 +08:00
优化WAF中IP名单
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientListener 客户端网络监听
|
||||
@@ -42,10 +43,28 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
if err == nil {
|
||||
canGoNext, _ := iplibrary.AllowIP(ip, 0)
|
||||
var beingDenied = !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) &&
|
||||
waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||
expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
if ok {
|
||||
var timeout = expiresAt - time.Now().Unix()
|
||||
if timeout > 0 {
|
||||
canGoNext = false
|
||||
|
||||
if !canGoNext || beingDenied {
|
||||
if timeout > 3600 {
|
||||
timeout = 3600
|
||||
}
|
||||
|
||||
// 使用本地防火墙延长封禁
|
||||
var fw = firewalls.Firewall()
|
||||
if fw != nil && !fw.IsMock() {
|
||||
// 这里 int(int64) 转换的前提是限制了 timeout <= 3600,否则将有整型溢出的风险
|
||||
_ = fw.DropSourceIP(ip, int(timeout), true)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !canGoNext {
|
||||
tcpConn, ok := conn.(*net.TCPConn)
|
||||
if ok {
|
||||
_ = tcpConn.SetLinger(0)
|
||||
@@ -53,14 +72,6 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
|
||||
_ = conn.Close()
|
||||
|
||||
// 使用本地防火墙延长封禁
|
||||
if beingDenied {
|
||||
var fw = firewalls.Firewall()
|
||||
if fw != nil && !fw.IsMock() {
|
||||
_ = fw.DropSourceIP(ip, 120, true)
|
||||
}
|
||||
}
|
||||
|
||||
return this.Accept()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +77,12 @@ func (this *List) Remove(itemId uint64) {
|
||||
this.removeItem(itemId)
|
||||
}
|
||||
|
||||
func (this *List) ExpiresAt(itemId uint64) int64 {
|
||||
this.locker.Lock()
|
||||
defer this.locker.Unlock()
|
||||
return this.itemsMap[itemId]
|
||||
}
|
||||
|
||||
func (this *List) GC(timestamp int64) ItemMap {
|
||||
if this.lastTimestamp > timestamp+1 {
|
||||
return nil
|
||||
|
||||
@@ -68,6 +68,14 @@ func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serv
|
||||
var id = this.nextId()
|
||||
this.expireList.Add(id, expiresAt)
|
||||
this.locker.Lock()
|
||||
|
||||
// 删除以前
|
||||
oldId, ok := this.ipMap[ip]
|
||||
if ok {
|
||||
delete(this.idMap, oldId)
|
||||
this.expireList.Remove(oldId)
|
||||
}
|
||||
|
||||
this.ipMap[ip] = id
|
||||
this.idMap[id] = ip
|
||||
this.locker.Unlock()
|
||||
@@ -117,7 +125,7 @@ func (this *IPList) RecordIP(ipType string,
|
||||
}
|
||||
}
|
||||
|
||||
// 关闭所有连接
|
||||
// 关闭此IP相关连接
|
||||
conns.SharedMap.CloseIPConns(ip)
|
||||
}
|
||||
}
|
||||
@@ -139,13 +147,52 @@ func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope,
|
||||
return ok
|
||||
}
|
||||
|
||||
// ContainsExpires 判断是否有某个IP,并返回过期时间
|
||||
func (this *IPList) ContainsExpires(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) (expiresAt int64, ok bool) {
|
||||
switch scope {
|
||||
case firewallconfigs.FirewallScopeGlobal:
|
||||
ip = "*@" + ip + "@" + ipType
|
||||
case firewallconfigs.FirewallScopeService:
|
||||
ip = types.String(serverId) + "@" + ip + "@" + ipType
|
||||
default:
|
||||
ip = "*@" + ip + "@" + ipType
|
||||
}
|
||||
|
||||
this.locker.RLock()
|
||||
id, ok := this.ipMap[ip]
|
||||
if ok {
|
||||
expiresAt = this.expireList.ExpiresAt(id)
|
||||
}
|
||||
this.locker.RUnlock()
|
||||
return expiresAt, ok
|
||||
}
|
||||
|
||||
// RemoveIP 删除IP
|
||||
func (this *IPList) RemoveIP(ip string, serverId int64, shouldExecute bool) {
|
||||
this.locker.Lock()
|
||||
delete(this.ipMap, "*@"+ip+"@"+IPTypeAll)
|
||||
if serverId > 0 {
|
||||
delete(this.ipMap, types.String(serverId)+"@"+ip+"@"+IPTypeAll)
|
||||
|
||||
{
|
||||
var key = "*@" + ip + "@" + IPTypeAll
|
||||
id, ok := this.ipMap[key]
|
||||
if ok {
|
||||
delete(this.ipMap, key)
|
||||
delete(this.idMap, id)
|
||||
|
||||
this.expireList.Remove(id)
|
||||
}
|
||||
}
|
||||
|
||||
if serverId > 0 {
|
||||
var key = types.String(serverId) + "@" + ip + "@" + IPTypeAll
|
||||
id, ok := this.ipMap[key]
|
||||
if ok {
|
||||
delete(this.ipMap, key)
|
||||
delete(this.idMap, id)
|
||||
|
||||
this.expireList.Remove(id)
|
||||
}
|
||||
}
|
||||
|
||||
this.locker.Unlock()
|
||||
|
||||
// 从本地防火墙中删除
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"github.com/iwind/TeaGo/logs"
|
||||
timeutil "github.com/iwind/TeaGo/utils/time"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
@@ -13,12 +14,26 @@ import (
|
||||
)
|
||||
|
||||
func TestNewIPList(t *testing.T) {
|
||||
list := NewIPList(IPListTypeDeny)
|
||||
var list = NewIPList(IPListTypeDeny)
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeService, 1, "127.0.0.3", time.Now().Unix()+3)
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10)
|
||||
|
||||
list.RemoveIP("127.0.0.1", 1, false)
|
||||
|
||||
logs.PrintAsJSON(list.ipMap, t)
|
||||
logs.PrintAsJSON(list.idMap, t)
|
||||
}
|
||||
|
||||
func TestIPList_Expire(t *testing.T) {
|
||||
var list = NewIPList(IPListTypeDeny)
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3)
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10)
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+6)
|
||||
|
||||
var ticker = time.NewTicker(1 * time.Second)
|
||||
for range ticker.C {
|
||||
@@ -32,22 +47,39 @@ func TestNewIPList(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPList_Contains(t *testing.T) {
|
||||
a := assert.NewAssertion(t)
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
list := NewIPList(IPListTypeDeny)
|
||||
var list = NewIPList(IPListTypeDeny)
|
||||
|
||||
for i := 0; i < 1_0000; i++ {
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
|
||||
}
|
||||
//list.RemoveIP("192.168.1.100")
|
||||
a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
|
||||
a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
|
||||
{
|
||||
a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
|
||||
}
|
||||
{
|
||||
a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPList_ContainsExpires(t *testing.T) {
|
||||
var list = NewIPList(IPListTypeDeny)
|
||||
|
||||
for i := 0; i < 1_0000; i++ {
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
|
||||
}
|
||||
// list.RemoveIP("192.168.1.100", 1, false)
|
||||
for _, ip := range []string{"192.168.1.100", "192.168.2.100"} {
|
||||
expiresAt, ok := list.ContainsExpires(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, ip)
|
||||
t.Log(ok, expiresAt, timeutil.FormatTime("Y-m-d H:i:s", expiresAt))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIPList_Add(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
list := NewIPList(IPListTypeDeny)
|
||||
var list = NewIPList(IPListTypeDeny)
|
||||
for i := 0; i < b.N; i++ {
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
|
||||
}
|
||||
@@ -57,7 +89,8 @@ func BenchmarkIPList_Add(b *testing.B) {
|
||||
func BenchmarkIPList_Has(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
list := NewIPList(IPListTypeDeny)
|
||||
var list = NewIPList(IPListTypeDeny)
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < 1_0000; i++ {
|
||||
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
|
||||
|
||||
Reference in New Issue
Block a user