优化WAF黑名单处理

This commit is contained in:
GoEdgeLab
2023-03-31 21:37:15 +08:00
parent 20c802c51d
commit e016029c8e
11 changed files with 157 additions and 84 deletions

View File

@@ -0,0 +1,29 @@
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
package firewalls
import (
"time"
)
// DropTemporaryTo 使用本地防火墙临时拦截IP数据包
func DropTemporaryTo(ip string, expiresAt int64) {
if expiresAt <= 1 {
return
}
var timeout = expiresAt - time.Now().Unix()
if timeout < 1 {
return
}
if timeout > 3600 {
timeout = 3600
}
// 使用本地防火墙延长封禁
var fw = Firewall()
if fw != nil && !fw.IsMock() {
// 这里 int(int64) 转换的前提是限制了 timeout <= 3600否则将有整型溢出的风险
_ = fw.DropSourceIP(ip, int(timeout), true)
}
}

View File

@@ -72,6 +72,25 @@ func (this *IPList) Contains(ip uint64) bool {
return item != nil return item != nil
} }
// ContainsExpires 判断是否包含某个IP
func (this *IPList) ContainsExpires(ip uint64) (expiresAt int64, ok bool) {
this.locker.RLock()
if len(this.allItemsMap) > 0 {
this.locker.RUnlock()
return 0, true
}
var item = this.lookupIP(ip)
this.locker.RUnlock()
if item == nil {
return
}
return item.ExpiredAt, true
}
// ContainsIPStrings 是否包含一组IP中的任意一个并返回匹配的第一个Item // ContainsIPStrings 是否包含一组IP中的任意一个并返回匹配的第一个Item
func (this *IPList) ContainsIPStrings(ipStrings []string) (item *IPItem, found bool) { func (this *IPList) ContainsIPStrings(ipStrings []string) (item *IPItem, found bool) {
if len(ipStrings) == 0 { if len(ipStrings) == 0 {
@@ -155,7 +174,7 @@ func (this *IPList) addItem(item *IPItem, sortable bool) {
this.locker.Unlock() this.locker.Unlock()
} }
// 对列表进行排序 // 对列表进行排序
func (this *IPList) sortItems() { func (this *IPList) sortItems() {
sort.Slice(this.sortedItems, func(i, j int) bool { sort.Slice(this.sortedItems, func(i, j int) bool {
var item1 = this.sortedItems[i] var item1 = this.sortedItems[i]

View File

@@ -10,50 +10,54 @@ import (
// AllowIP 检查IP是否被允许访问 // AllowIP 检查IP是否被允许访问
// 如果一个IP不在任何名单中则允许访问 // 如果一个IP不在任何名单中则允许访问
func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool) { func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool, expiresAt int64) {
if !Tea.IsTesting() { // 如果在测试环境,我们不加入一些白名单,以便于可以在本地和局域网正常测试 if !Tea.IsTesting() { // 如果在测试环境,我们不加入一些白名单,以便于可以在本地和局域网正常测试
// 放行lo // 放行lo
if ip == "127.0.0.1" || ip == "::1" { if ip == "127.0.0.1" || ip == "::1" {
return true, true return true, true, 0
} }
// check node // check node
nodeConfig, err := nodeconfigs.SharedNodeConfig() nodeConfig, err := nodeconfigs.SharedNodeConfig()
if err == nil && nodeConfig.IPIsAutoAllowed(ip) { if err == nil && nodeConfig.IPIsAutoAllowed(ip) {
return true, true return true, true, 0
} }
} }
var ipLong = utils.IP2Long(ip) var ipLong = utils.IP2Long(ip)
if ipLong == 0 { if ipLong == 0 {
return false, false return false, false, 0
} }
// check white lists // check white lists
if GlobalWhiteIPList.Contains(ipLong) { if GlobalWhiteIPList.Contains(ipLong) {
return true, true return true, true, 0
} }
if serverId > 0 { if serverId > 0 {
var list = SharedServerListManager.FindWhiteList(serverId, false) var list = SharedServerListManager.FindWhiteList(serverId, false)
if list != nil && list.Contains(ipLong) { if list != nil && list.Contains(ipLong) {
return true, true return true, true, 0
} }
} }
// check black lists // check black lists
if GlobalBlackIPList.Contains(ipLong) { expiresAt, ok := GlobalBlackIPList.ContainsExpires(ipLong)
return false, false if ok {
return false, false, expiresAt
} }
if serverId > 0 { if serverId > 0 {
var list = SharedServerListManager.FindBlackList(serverId, false) var list = SharedServerListManager.FindBlackList(serverId, false)
if list != nil && list.Contains(ipLong) { if list != nil {
return false, false expiresAt, ok = list.ContainsExpires(ipLong)
if ok {
return false, false, expiresAt
}
} }
} }
return true, false return true, false, 0
} }
// IsInWhiteList 检查IP是否在白名单中 // IsInWhiteList 检查IP是否在白名单中
@@ -73,7 +77,7 @@ func AllowIPStrings(ipStrings []string, serverId int64) bool {
return true return true
} }
for _, ip := range ipStrings { for _, ip := range ipStrings {
isAllowed, _ := AllowIP(ip, serverId) isAllowed, _, _ := AllowIP(ip, serverId)
if !isAllowed { if !isAllowed {
return false return false
} }

View File

@@ -4,6 +4,9 @@ package nodes
import ( import (
"crypto/tls" "crypto/tls"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"net" "net"
) )
@@ -61,6 +64,20 @@ func (this *BaseClientConn) SetServerId(serverId int64) {
case *ClientConn: case *ClientConn:
conn.SetServerId(serverId) conn.SetServerId(serverId)
} }
// 检查服务相关IP黑名单
if serverId > 0 && len(this.rawIP) > 0 {
var list = iplibrary.SharedServerListManager.FindBlackList(serverId, false)
if list != nil {
expiresAt, ok := list.ContainsExpires(configutils.IPString2Long(this.rawIP))
if ok {
_ = this.rawConn.Close()
if expiresAt > 0 {
firewalls.DropTemporaryTo(this.rawIP, expiresAt)
}
}
}
}
} }
// ServerId 读取当前连接绑定的服务ID // ServerId 读取当前连接绑定的服务ID

View File

@@ -8,7 +8,6 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/iplibrary"
"github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/waf"
"net" "net"
"time"
) )
// ClientListener 客户端网络监听 // ClientListener 客户端网络监听
@@ -43,24 +42,19 @@ func (this *ClientListener) Accept() (net.Conn, error) {
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String()) ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
var isInAllowList = false var isInAllowList = false
if err == nil { if err == nil {
canGoNext, inAllowList := iplibrary.AllowIP(ip, 0) canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0)
isInAllowList = inAllowList isInAllowList = inAllowList
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) { if !canGoNext {
expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) if expiresAt > 0 {
if ok { firewalls.DropTemporaryTo(ip, expiresAt)
var timeout = expiresAt - time.Now().Unix() }
if timeout > 0 { } else {
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
if ok {
canGoNext = false canGoNext = false
if expiresAt > 0 {
if timeout > 3600 { firewalls.DropTemporaryTo(ip, expiresAt)
timeout = 3600
}
// 使用本地防火墙延长封禁
var fw = firewalls.Firewall()
if fw != nil && !fw.IsMock() {
// 这里 int(int64) 转换的前提是限制了 timeout <= 3600否则将有整型溢出的风险
_ = fw.DropSourceIP(ip, int(timeout), true)
} }
} }
} }

View File

@@ -9,7 +9,7 @@ import (
func (this *HTTPRequest) doRequestLimit() (shouldStop bool) { func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
// 是否在全局名单中 // 是否在全局名单中
_, isInAllowedList := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id) _, isInAllowedList, _ := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id)
if isInAllowedList { if isInAllowedList {
return false return false
} }

View File

@@ -35,7 +35,7 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
} }
// 是否在全局名单中 // 是否在全局名单中
canGoNext, isInAllowedList := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id) canGoNext, isInAllowedList, _ := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
if !canGoNext { if !canGoNext {
this.disableLog = true this.disableLog = true
this.Close() this.Close()

View File

@@ -119,13 +119,7 @@ func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64,
var countFails = ttlcache.SharedCache.IncreaseInt64(key, 1, time.Now().Unix()+300, true) var countFails = ttlcache.SharedCache.IncreaseInt64(key, 1, time.Now().Unix()+300, true)
if int(countFails) >= maxFails { if int(countFails) >= maxFails {
var useLocalFirewall = false SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, true, groupId, setId, "JS_COOKIE验证连续失败超过"+types.String(maxFails)+"次")
if this.Scope == firewallconfigs.FirewallScopeGlobal {
useLocalFirewall = true
}
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "JS_COOKIE验证连续失败超过"+types.String(maxFails)+"次")
return false return false
} }

View File

@@ -30,7 +30,7 @@ type recordIPTask struct {
sourceHTTPFirewallRuleSetId int64 sourceHTTPFirewallRuleSetId int64
} }
var recordIPTaskChan = make(chan *recordIPTask, 1024) var recordIPTaskChan = make(chan *recordIPTask, 2048)
func init() { func init() {
if !teaconst.IsMain { if !teaconst.IsMain {
@@ -45,32 +45,56 @@ func init() {
return return
} }
for task := range recordIPTaskChan { const maxItems = 512 // 每次上传的最大IP数
ipType := "ipv4"
if strings.Contains(task.ip, ":") { for {
ipType = "ipv6" var pbItems = []*pb.CreateIPItemsRequest_IPItem{}
}
var reason = task.reason func() {
if len(reason) == 0 { for {
reason = "触发WAF规则自动加入" select {
} case task := <-recordIPTaskChan:
_, err = rpcClient.IPItemRPC.CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{ var ipType = "ipv4"
IpListId: task.listId, if strings.Contains(task.ip, ":") {
IpFrom: task.ip, ipType = "ipv6"
IpTo: "", }
ExpiredAt: task.expiresAt, var reason = task.reason
Reason: reason, if len(reason) == 0 {
Type: ipType, reason = "触发WAF规则自动加入"
EventLevel: task.level, }
ServerId: task.serverId,
SourceNodeId: teaconst.NodeId, pbItems = append(pbItems, &pb.CreateIPItemsRequest_IPItem{
SourceServerId: task.sourceServerId, IpListId: task.listId,
SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId, IpFrom: task.ip,
SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId, IpTo: "",
SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId, ExpiredAt: task.expiresAt,
}) Reason: reason,
if err != nil { Type: ipType,
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error()) EventLevel: task.level,
ServerId: task.serverId,
SourceNodeId: teaconst.NodeId,
SourceServerId: task.sourceServerId,
SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId,
SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId,
SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId,
})
if len(pbItems) >= maxItems {
return
}
default:
return
}
}
}()
if len(pbItems) > 0 {
_, err = rpcClient.IPItemRPC.CreateIPItems(rpcClient.Context(), &pb.CreateIPItemsRequest{IpItems: pbItems})
if err != nil {
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
}
} else {
time.Sleep(1 * time.Second)
} }
} }
}) })

View File

@@ -29,13 +29,7 @@ func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, pol
} }
var countFails = ttlcache.SharedCache.IncreaseInt64(CaptchaCacheKey(req, pageCode), 1, time.Now().Unix()+300, true) var countFails = ttlcache.SharedCache.IncreaseInt64(CaptchaCacheKey(req, pageCode), 1, time.Now().Unix()+300, true)
if int(countFails) >= maxFails { if int(countFails) >= maxFails {
var useLocalFirewall = false SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, true, groupId, setId, "CAPTCHA验证连续失败超过"+types.String(maxFails)+"次")
if actionConfig.FailBlockScopeAll {
useLocalFirewall = true
}
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "CAPTCHA验证连续失败超过"+types.String(maxFails)+"次")
return false return false
} }
} }

View File

@@ -10,7 +10,6 @@ import (
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
) )
var SharedIPWhiteList = NewIPList(IPListTypeAllow) var SharedIPWhiteList = NewIPList(IPListTypeAllow)
@@ -95,6 +94,12 @@ func (this *IPList) RecordIP(ipType string,
this.Add(ipType, scope, serverId, ip, expiresAt) this.Add(ipType, scope, serverId, ip, expiresAt)
if this.listType == IPListTypeDeny { if this.listType == IPListTypeDeny {
// 作用域
var scopeServerId int64
if scope == firewallconfigs.FirewallScopeService {
scopeServerId = serverId
}
// 加入队列等待上传 // 加入队列等待上传
select { select {
case recordIPTaskChan <- &recordIPTask{ case recordIPTaskChan <- &recordIPTask{
@@ -102,7 +107,7 @@ func (this *IPList) RecordIP(ipType string,
listId: firewallconfigs.GlobalListId, listId: firewallconfigs.GlobalListId,
expiresAt: expiresAt, expiresAt: expiresAt,
level: firewallconfigs.DefaultEventLevel, level: firewallconfigs.DefaultEventLevel,
serverId: serverId, serverId: scopeServerId,
sourceServerId: serverId, sourceServerId: serverId,
sourceHTTPFirewallPolicyId: policyId, sourceHTTPFirewallPolicyId: policyId,
sourceHTTPFirewallRuleGroupId: groupId, sourceHTTPFirewallRuleGroupId: groupId,
@@ -114,15 +119,8 @@ func (this *IPList) RecordIP(ipType string,
} }
// 使用本地防火墙 // 使用本地防火墙
if useLocalFirewall { if useLocalFirewall && expiresAt > 0 {
var seconds = expiresAt - time.Now().Unix() firewalls.DropTemporaryTo(ip, expiresAt)
if seconds > 0 {
// 最大3600防止误封时间过长
if seconds > 3600 {
seconds = 3600
}
_ = firewalls.Firewall().DropSourceIP(ip, int(seconds), true)
}
} }
// 关闭此IP相关连接 // 关闭此IP相关连接