mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-08 03:00:27 +08:00
优化WAF黑名单处理
This commit is contained in:
29
internal/firewalls/utils.go
Normal file
29
internal/firewalls/utils.go
Normal file
@@ -0,0 +1,29 @@
|
||||
// Copyright 2023 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn .
|
||||
|
||||
package firewalls
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// DropTemporaryTo 使用本地防火墙临时拦截IP数据包
|
||||
func DropTemporaryTo(ip string, expiresAt int64) {
|
||||
if expiresAt <= 1 {
|
||||
return
|
||||
}
|
||||
|
||||
var timeout = expiresAt - time.Now().Unix()
|
||||
if timeout < 1 {
|
||||
return
|
||||
}
|
||||
if timeout > 3600 {
|
||||
timeout = 3600
|
||||
}
|
||||
|
||||
// 使用本地防火墙延长封禁
|
||||
var fw = Firewall()
|
||||
if fw != nil && !fw.IsMock() {
|
||||
// 这里 int(int64) 转换的前提是限制了 timeout <= 3600,否则将有整型溢出的风险
|
||||
_ = fw.DropSourceIP(ip, int(timeout), true)
|
||||
}
|
||||
}
|
||||
@@ -72,6 +72,25 @@ func (this *IPList) Contains(ip uint64) bool {
|
||||
return item != nil
|
||||
}
|
||||
|
||||
// ContainsExpires 判断是否包含某个IP
|
||||
func (this *IPList) ContainsExpires(ip uint64) (expiresAt int64, ok bool) {
|
||||
this.locker.RLock()
|
||||
if len(this.allItemsMap) > 0 {
|
||||
this.locker.RUnlock()
|
||||
return 0, true
|
||||
}
|
||||
|
||||
var item = this.lookupIP(ip)
|
||||
|
||||
this.locker.RUnlock()
|
||||
|
||||
if item == nil {
|
||||
return
|
||||
}
|
||||
|
||||
return item.ExpiredAt, true
|
||||
}
|
||||
|
||||
// ContainsIPStrings 是否包含一组IP中的任意一个,并返回匹配的第一个Item
|
||||
func (this *IPList) ContainsIPStrings(ipStrings []string) (item *IPItem, found bool) {
|
||||
if len(ipStrings) == 0 {
|
||||
@@ -155,7 +174,7 @@ func (this *IPList) addItem(item *IPItem, sortable bool) {
|
||||
this.locker.Unlock()
|
||||
}
|
||||
|
||||
// 对列表进行排序
|
||||
// 对列表进行排序
|
||||
func (this *IPList) sortItems() {
|
||||
sort.Slice(this.sortedItems, func(i, j int) bool {
|
||||
var item1 = this.sortedItems[i]
|
||||
|
||||
@@ -10,50 +10,54 @@ import (
|
||||
|
||||
// AllowIP 检查IP是否被允许访问
|
||||
// 如果一个IP不在任何名单中,则允许访问
|
||||
func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool) {
|
||||
func AllowIP(ip string, serverId int64) (canGoNext bool, inAllowList bool, expiresAt int64) {
|
||||
if !Tea.IsTesting() { // 如果在测试环境,我们不加入一些白名单,以便于可以在本地和局域网正常测试
|
||||
// 放行lo
|
||||
if ip == "127.0.0.1" || ip == "::1" {
|
||||
return true, true
|
||||
return true, true, 0
|
||||
}
|
||||
|
||||
// check node
|
||||
nodeConfig, err := nodeconfigs.SharedNodeConfig()
|
||||
if err == nil && nodeConfig.IPIsAutoAllowed(ip) {
|
||||
return true, true
|
||||
return true, true, 0
|
||||
}
|
||||
}
|
||||
|
||||
var ipLong = utils.IP2Long(ip)
|
||||
if ipLong == 0 {
|
||||
return false, false
|
||||
return false, false, 0
|
||||
}
|
||||
|
||||
// check white lists
|
||||
if GlobalWhiteIPList.Contains(ipLong) {
|
||||
return true, true
|
||||
return true, true, 0
|
||||
}
|
||||
|
||||
if serverId > 0 {
|
||||
var list = SharedServerListManager.FindWhiteList(serverId, false)
|
||||
if list != nil && list.Contains(ipLong) {
|
||||
return true, true
|
||||
return true, true, 0
|
||||
}
|
||||
}
|
||||
|
||||
// check black lists
|
||||
if GlobalBlackIPList.Contains(ipLong) {
|
||||
return false, false
|
||||
expiresAt, ok := GlobalBlackIPList.ContainsExpires(ipLong)
|
||||
if ok {
|
||||
return false, false, expiresAt
|
||||
}
|
||||
|
||||
if serverId > 0 {
|
||||
var list = SharedServerListManager.FindBlackList(serverId, false)
|
||||
if list != nil && list.Contains(ipLong) {
|
||||
return false, false
|
||||
if list != nil {
|
||||
expiresAt, ok = list.ContainsExpires(ipLong)
|
||||
if ok {
|
||||
return false, false, expiresAt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true, false
|
||||
return true, false, 0
|
||||
}
|
||||
|
||||
// IsInWhiteList 检查IP是否在白名单中
|
||||
@@ -73,7 +77,7 @@ func AllowIPStrings(ipStrings []string, serverId int64) bool {
|
||||
return true
|
||||
}
|
||||
for _, ip := range ipStrings {
|
||||
isAllowed, _ := AllowIP(ip, serverId)
|
||||
isAllowed, _, _ := AllowIP(ip, serverId)
|
||||
if !isAllowed {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -4,6 +4,9 @@ package nodes
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/firewalls"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"net"
|
||||
)
|
||||
|
||||
@@ -61,6 +64,20 @@ func (this *BaseClientConn) SetServerId(serverId int64) {
|
||||
case *ClientConn:
|
||||
conn.SetServerId(serverId)
|
||||
}
|
||||
|
||||
// 检查服务相关IP黑名单
|
||||
if serverId > 0 && len(this.rawIP) > 0 {
|
||||
var list = iplibrary.SharedServerListManager.FindBlackList(serverId, false)
|
||||
if list != nil {
|
||||
expiresAt, ok := list.ContainsExpires(configutils.IPString2Long(this.rawIP))
|
||||
if ok {
|
||||
_ = this.rawConn.Close()
|
||||
if expiresAt > 0 {
|
||||
firewalls.DropTemporaryTo(this.rawIP, expiresAt)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ServerId 读取当前连接绑定的服务ID
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"github.com/TeaOSLab/EdgeNode/internal/iplibrary"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/waf"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClientListener 客户端网络监听
|
||||
@@ -43,24 +42,19 @@ func (this *ClientListener) Accept() (net.Conn, error) {
|
||||
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
var isInAllowList = false
|
||||
if err == nil {
|
||||
canGoNext, inAllowList := iplibrary.AllowIP(ip, 0)
|
||||
canGoNext, inAllowList, expiresAt := iplibrary.AllowIP(ip, 0)
|
||||
isInAllowList = inAllowList
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||
expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
if ok {
|
||||
var timeout = expiresAt - time.Now().Unix()
|
||||
if timeout > 0 {
|
||||
if !canGoNext {
|
||||
if expiresAt > 0 {
|
||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||
}
|
||||
} else {
|
||||
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
|
||||
expiresAt, ok := waf.SharedIPBlackList.ContainsExpires(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip)
|
||||
if ok {
|
||||
canGoNext = false
|
||||
|
||||
if timeout > 3600 {
|
||||
timeout = 3600
|
||||
}
|
||||
|
||||
// 使用本地防火墙延长封禁
|
||||
var fw = firewalls.Firewall()
|
||||
if fw != nil && !fw.IsMock() {
|
||||
// 这里 int(int64) 转换的前提是限制了 timeout <= 3600,否则将有整型溢出的风险
|
||||
_ = fw.DropSourceIP(ip, int(timeout), true)
|
||||
if expiresAt > 0 {
|
||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
func (this *HTTPRequest) doRequestLimit() (shouldStop bool) {
|
||||
// 是否在全局名单中
|
||||
_, isInAllowedList := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id)
|
||||
_, isInAllowedList, _ := iplibrary.AllowIP(this.RemoteAddr(), this.ReqServer.Id)
|
||||
if isInAllowedList {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -35,7 +35,7 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
|
||||
}
|
||||
|
||||
// 是否在全局名单中
|
||||
canGoNext, isInAllowedList := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
|
||||
canGoNext, isInAllowedList, _ := iplibrary.AllowIP(remoteAddr, this.ReqServer.Id)
|
||||
if !canGoNext {
|
||||
this.disableLog = true
|
||||
this.Close()
|
||||
|
||||
@@ -119,13 +119,7 @@ func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64,
|
||||
|
||||
var countFails = ttlcache.SharedCache.IncreaseInt64(key, 1, time.Now().Unix()+300, true)
|
||||
if int(countFails) >= maxFails {
|
||||
var useLocalFirewall = false
|
||||
|
||||
if this.Scope == firewallconfigs.FirewallScopeGlobal {
|
||||
useLocalFirewall = true
|
||||
}
|
||||
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "JS_COOKIE验证连续失败超过"+types.String(maxFails)+"次")
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, true, groupId, setId, "JS_COOKIE验证连续失败超过"+types.String(maxFails)+"次")
|
||||
return false
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ type recordIPTask struct {
|
||||
sourceHTTPFirewallRuleSetId int64
|
||||
}
|
||||
|
||||
var recordIPTaskChan = make(chan *recordIPTask, 1024)
|
||||
var recordIPTaskChan = make(chan *recordIPTask, 2048)
|
||||
|
||||
func init() {
|
||||
if !teaconst.IsMain {
|
||||
@@ -45,32 +45,56 @@ func init() {
|
||||
return
|
||||
}
|
||||
|
||||
for task := range recordIPTaskChan {
|
||||
ipType := "ipv4"
|
||||
if strings.Contains(task.ip, ":") {
|
||||
ipType = "ipv6"
|
||||
}
|
||||
var reason = task.reason
|
||||
if len(reason) == 0 {
|
||||
reason = "触发WAF规则自动加入"
|
||||
}
|
||||
_, err = rpcClient.IPItemRPC.CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{
|
||||
IpListId: task.listId,
|
||||
IpFrom: task.ip,
|
||||
IpTo: "",
|
||||
ExpiredAt: task.expiresAt,
|
||||
Reason: reason,
|
||||
Type: ipType,
|
||||
EventLevel: task.level,
|
||||
ServerId: task.serverId,
|
||||
SourceNodeId: teaconst.NodeId,
|
||||
SourceServerId: task.sourceServerId,
|
||||
SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId,
|
||||
SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId,
|
||||
SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId,
|
||||
})
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
|
||||
const maxItems = 512 // 每次上传的最大IP数
|
||||
|
||||
for {
|
||||
var pbItems = []*pb.CreateIPItemsRequest_IPItem{}
|
||||
|
||||
func() {
|
||||
for {
|
||||
select {
|
||||
case task := <-recordIPTaskChan:
|
||||
var ipType = "ipv4"
|
||||
if strings.Contains(task.ip, ":") {
|
||||
ipType = "ipv6"
|
||||
}
|
||||
var reason = task.reason
|
||||
if len(reason) == 0 {
|
||||
reason = "触发WAF规则自动加入"
|
||||
}
|
||||
|
||||
pbItems = append(pbItems, &pb.CreateIPItemsRequest_IPItem{
|
||||
IpListId: task.listId,
|
||||
IpFrom: task.ip,
|
||||
IpTo: "",
|
||||
ExpiredAt: task.expiresAt,
|
||||
Reason: reason,
|
||||
Type: ipType,
|
||||
EventLevel: task.level,
|
||||
ServerId: task.serverId,
|
||||
SourceNodeId: teaconst.NodeId,
|
||||
SourceServerId: task.sourceServerId,
|
||||
SourceHTTPFirewallPolicyId: task.sourceHTTPFirewallPolicyId,
|
||||
SourceHTTPFirewallRuleGroupId: task.sourceHTTPFirewallRuleGroupId,
|
||||
SourceHTTPFirewallRuleSetId: task.sourceHTTPFirewallRuleSetId,
|
||||
})
|
||||
|
||||
if len(pbItems) >= maxItems {
|
||||
return
|
||||
}
|
||||
default:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if len(pbItems) > 0 {
|
||||
_, err = rpcClient.IPItemRPC.CreateIPItems(rpcClient.Context(), &pb.CreateIPItemsRequest{IpItems: pbItems})
|
||||
if err != nil {
|
||||
remotelogs.Error("WAF_RECORD_IP_ACTION", "create ip item failed: "+err.Error())
|
||||
}
|
||||
} else {
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -29,13 +29,7 @@ func CaptchaIncreaseFails(req requests.Request, actionConfig *CaptchaAction, pol
|
||||
}
|
||||
var countFails = ttlcache.SharedCache.IncreaseInt64(CaptchaCacheKey(req, pageCode), 1, time.Now().Unix()+300, true)
|
||||
if int(countFails) >= maxFails {
|
||||
var useLocalFirewall = false
|
||||
|
||||
if actionConfig.FailBlockScopeAll {
|
||||
useLocalFirewall = true
|
||||
}
|
||||
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, useLocalFirewall, groupId, setId, "CAPTCHA验证连续失败超过"+types.String(maxFails)+"次")
|
||||
SharedIPBlackList.RecordIP(IPTypeAll, firewallconfigs.FirewallScopeService, req.WAFServerId(), req.WAFRemoteIP(), time.Now().Unix()+int64(failBlockTimeout), policyId, true, groupId, setId, "CAPTCHA验证连续失败超过"+types.String(maxFails)+"次")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var SharedIPWhiteList = NewIPList(IPListTypeAllow)
|
||||
@@ -95,6 +94,12 @@ func (this *IPList) RecordIP(ipType string,
|
||||
this.Add(ipType, scope, serverId, ip, expiresAt)
|
||||
|
||||
if this.listType == IPListTypeDeny {
|
||||
// 作用域
|
||||
var scopeServerId int64
|
||||
if scope == firewallconfigs.FirewallScopeService {
|
||||
scopeServerId = serverId
|
||||
}
|
||||
|
||||
// 加入队列等待上传
|
||||
select {
|
||||
case recordIPTaskChan <- &recordIPTask{
|
||||
@@ -102,7 +107,7 @@ func (this *IPList) RecordIP(ipType string,
|
||||
listId: firewallconfigs.GlobalListId,
|
||||
expiresAt: expiresAt,
|
||||
level: firewallconfigs.DefaultEventLevel,
|
||||
serverId: serverId,
|
||||
serverId: scopeServerId,
|
||||
sourceServerId: serverId,
|
||||
sourceHTTPFirewallPolicyId: policyId,
|
||||
sourceHTTPFirewallRuleGroupId: groupId,
|
||||
@@ -114,15 +119,8 @@ func (this *IPList) RecordIP(ipType string,
|
||||
}
|
||||
|
||||
// 使用本地防火墙
|
||||
if useLocalFirewall {
|
||||
var seconds = expiresAt - time.Now().Unix()
|
||||
if seconds > 0 {
|
||||
// 最大3600,防止误封时间过长
|
||||
if seconds > 3600 {
|
||||
seconds = 3600
|
||||
}
|
||||
_ = firewalls.Firewall().DropSourceIP(ip, int(seconds), true)
|
||||
}
|
||||
if useLocalFirewall && expiresAt > 0 {
|
||||
firewalls.DropTemporaryTo(ip, expiresAt)
|
||||
}
|
||||
|
||||
// 关闭此IP相关连接
|
||||
|
||||
Reference in New Issue
Block a user