diff --git a/internal/firewalls/firewall_firewalld.go b/internal/firewalls/firewall_firewalld.go index 8789d20..65b4f7d 100644 --- a/internal/firewalls/firewall_firewalld.go +++ b/internal/firewalls/firewall_firewalld.go @@ -18,7 +18,7 @@ type Firewalld struct { func NewFirewalld() *Firewalld { var firewalld = &Firewalld{ - cmdQueue: make(chan *exec.Cmd, 2048), + cmdQueue: make(chan *exec.Cmd, 4096), } path, err := exec.LookPath("firewall-cmd") diff --git a/internal/iplibrary/list_utils.go b/internal/iplibrary/list_utils.go index 7ac5607..cdb0a5d 100644 --- a/internal/iplibrary/list_utils.go +++ b/internal/iplibrary/list_utils.go @@ -7,6 +7,7 @@ import ( ) // AllowIP 检查IP是否被允许访问 +// 如果一个IP不在任何名单中,则允许访问 func AllowIP(ip string, serverId int64) bool { var ipLong = utils.IP2Long(ip) if ipLong == 0 { @@ -40,6 +41,17 @@ func AllowIP(ip string, serverId int64) bool { return true } +// IsInWhiteList 检查IP是否在白名单中 +func IsInWhiteList(ip string) bool { + var ipLong = utils.IP2Long(ip) + if ipLong == 0 { + return false + } + + // check white lists + return GlobalWhiteIPList.Contains(ipLong) +} + // AllowIPStrings 检查一组IP是否被允许访问 func AllowIPStrings(ipStrings []string, serverId int64) bool { if len(ipStrings) == 0 { diff --git a/internal/nodes/client_conn.go b/internal/nodes/client_conn.go index c971478..33f2392 100644 --- a/internal/nodes/client_conn.go +++ b/internal/nodes/client_conn.go @@ -4,9 +4,16 @@ package nodes import ( "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" teaconst "github.com/TeaOSLab/EdgeNode/internal/const" + "github.com/TeaOSLab/EdgeNode/internal/iplibrary" "github.com/TeaOSLab/EdgeNode/internal/ratelimit" + "github.com/TeaOSLab/EdgeNode/internal/ttlcache" + "github.com/TeaOSLab/EdgeNode/internal/utils" + "github.com/TeaOSLab/EdgeNode/internal/waf" + "github.com/iwind/TeaGo/types" "net" + "os" "sync" "sync/atomic" "time" @@ -17,8 +24,9 @@ type ClientConn struct { once sync.Once globalLimiter *ratelimit.Counter - isTLS bool - hasRead bool + isTLS bool + hasDeadline bool + hasRead bool BaseClientConn } @@ -38,9 +46,9 @@ func NewClientConn(conn net.Conn, isTLS bool, quickClose bool, globalLimiter *ra func (this *ClientConn) Read(b []byte) (n int, err error) { if this.isTLS { - if !this.hasRead { + if !this.hasDeadline { _ = this.rawConn.SetReadDeadline(time.Now().Add(time.Duration(nodeconfigs.DefaultTLSHandshakeTimeout) * time.Second)) // TODO 握手超时时间可以设置 - this.hasRead = true + this.hasDeadline = true defer func() { _ = this.rawConn.SetReadDeadline(time.Time{}) }() @@ -50,7 +58,21 @@ func (this *ClientConn) Read(b []byte) (n int, err error) { n, err = this.rawConn.Read(b) if n > 0 { atomic.AddUint64(&teaconst.InTrafficBytes, uint64(n)) + this.hasRead = true } + + // SYN Flood检测 + var synFloodConfig = sharedNodeConfig.SYNFloodConfig() + if synFloodConfig != nil && synFloodConfig.IsOn { + if err != nil && os.IsTimeout(err) { + if !this.hasRead { + this.checkSYNFlood() + } + } else { + this.resetSYNFlood() + } + } + return } @@ -99,3 +121,32 @@ func (this *ClientConn) SetReadDeadline(t time.Time) error { func (this *ClientConn) SetWriteDeadline(t time.Time) error { return this.rawConn.SetWriteDeadline(t) } + +func (this *ClientConn) resetSYNFlood() { + // 为了不影响性能,暂时不清除状态 + //ttlcache.SharedCache.Delete("SYN_FLOOD:" + this.RawIP()) +} + +func (this *ClientConn) checkSYNFlood() { + var synFloodConfig = sharedNodeConfig.SYNFloodConfig() + if synFloodConfig == nil || !synFloodConfig.IsOn { + return + } + + var ip = this.RawIP() + if len(ip) > 0 && !iplibrary.IsInWhiteList(ip) && (!synFloodConfig.IgnoreLocal || !utils.IsLocalIP(ip)) { + var timestamp = (utils.UnixTime()/60)*60 + 60 + var result = ttlcache.SharedCache.IncreaseInt64("SYN_FLOOD:"+ip, 1, timestamp) + var minAttempts = synFloodConfig.MinAttempts + if minAttempts < 3 { + minAttempts = 3 + } + if result >= int64(minAttempts) { + var timeout = synFloodConfig.TimeoutSeconds + if timeout <= 0 { + timeout = 600 + } + waf.SharedIPBlackList.RecordIP(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip, time.Now().Unix()+int64(timeout), 0, true, 0, 0, "疑似SYN Flood攻击,当前1分钟"+types.String(result)+"次空连接") + } + } +} diff --git a/internal/nodes/client_conn_base.go b/internal/nodes/client_conn_base.go index aad784e..aa386ea 100644 --- a/internal/nodes/client_conn_base.go +++ b/internal/nodes/client_conn_base.go @@ -36,3 +36,8 @@ func (this *BaseClientConn) Bind(serverId int64, remoteAddr string, maxConnsPerS return sharedClientConnLimiter.Add(this.rawConn.RemoteAddr().String(), serverId, remoteAddr, maxConnsPerServer, maxConnsPerIP) } +// RawIP 原本IP +func (this *BaseClientConn) RawIP() string { + ip, _, _ := net.SplitHostPort(this.rawConn.RemoteAddr().String()) + return ip +} diff --git a/internal/nodes/waf_manager.go b/internal/nodes/waf_manager.go index 09a2e74..b622bcc 100644 --- a/internal/nodes/waf_manager.go +++ b/internal/nodes/waf_manager.go @@ -65,6 +65,7 @@ func (this *WAFManager) convertWAF(policy *firewallconfigs.HTTPFirewallPolicy) ( Name: policy.Name, Mode: policy.Mode, UseLocalFirewall: policy.UseLocalFirewall, + SYNFlood: policy.SYNFlood, } // inbound diff --git a/internal/ttlcache/cache.go b/internal/ttlcache/cache.go index 84a0d14..f7a10ff 100644 --- a/internal/ttlcache/cache.go +++ b/internal/ttlcache/cache.go @@ -5,6 +5,8 @@ import ( "time" ) +var SharedCache = NewCache() + // Cache TTL缓存 // 最大的缓存时间为30 * 86400 // Piece数据结构: diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go index c041cb4..fffade1 100644 --- a/internal/waf/action_block.go +++ b/internal/waf/action_block.go @@ -64,7 +64,7 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque timeout = 60 // 默认封锁60秒 } - SharedIPBlackList.RecordIP(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout), waf.Id, waf.UseLocalFirewall, group.Id, set.Id) + SharedIPBlackList.RecordIP(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout), waf.Id, waf.UseLocalFirewall, group.Id, set.Id, "") if writer != nil { // close the connection diff --git a/internal/waf/action_post_307.go b/internal/waf/action_post_307.go index c494dde..c8c7c76 100644 --- a/internal/waf/action_post_307.go +++ b/internal/waf/action_post_307.go @@ -56,7 +56,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req life = 600 // 默认10分钟 } var setId = m.GetString("setId") - SharedIPWhiteList.RecordIP("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId")) + SharedIPWhiteList.RecordIP("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId"), "") return true } } diff --git a/internal/waf/action_record_ip.go b/internal/waf/action_record_ip.go index a04c983..ecf2fe8 100644 --- a/internal/waf/action_record_ip.go +++ b/internal/waf/action_record_ip.go @@ -22,6 +22,8 @@ type recordIPTask struct { level string serverId int64 + reason string + sourceServerId int64 sourceHTTPFirewallPolicyId int64 sourceHTTPFirewallRuleGroupId int64 @@ -44,12 +46,16 @@ func init() { if strings.Contains(task.ip, ":") { ipType = "ipv6" } + var reason = task.reason + if len(reason) == 0 { + reason = "触发WAF规则自动加入" + } _, err = rpcClient.IPItemRPC().CreateIPItem(rpcClient.Context(), &pb.CreateIPItemRequest{ IpListId: task.listId, IpFrom: task.ip, IpTo: "", ExpiredAt: task.expiredAt, - Reason: "触发WAF规则自动加入", + Reason: reason, Type: ipType, EventLevel: task.level, ServerId: task.serverId, diff --git a/internal/waf/captcha_validator.go b/internal/waf/captcha_validator.go index da5fb92..775db5b 100644 --- a/internal/waf/captcha_validator.go +++ b/internal/waf/captcha_validator.go @@ -153,7 +153,7 @@ func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, policyId int } // 加入到白名单 - SharedIPWhiteList.RecordIP("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life), policyId, false, groupId, setId) + SharedIPWhiteList.RecordIP("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life), policyId, false, groupId, setId, "") http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther) diff --git a/internal/waf/get302_validator.go b/internal/waf/get302_validator.go index 4721a7f..1b16a43 100644 --- a/internal/waf/get302_validator.go +++ b/internal/waf/get302_validator.go @@ -44,7 +44,7 @@ func (this *Get302Validator) Run(request requests.Request, writer http.ResponseW life = 600 // 默认10分钟 } setId := m.GetString("setId") - SharedIPWhiteList.RecordIP("set:"+setId, m.GetString("scope"), request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId")) + SharedIPWhiteList.RecordIP("set:"+setId, m.GetString("scope"), request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId"), "") // 返回原始URL var url = m.GetString("url") diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index 1807af4..b2f529b 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -81,7 +81,8 @@ func (this *IPList) RecordIP(ipType string, policyId int64, useLocalFirewall bool, groupId int64, - setId int64) { + setId int64, + reason string) { this.Add(ipType, scope, serverId, ip, expiresAt) if this.listType == IPListTypeDeny { @@ -97,6 +98,7 @@ func (this *IPList) RecordIP(ipType string, sourceHTTPFirewallPolicyId: policyId, sourceHTTPFirewallRuleGroupId: groupId, sourceHTTPFirewallRuleSetId: setId, + reason: reason, }: default: diff --git a/internal/waf/waf.go b/internal/waf/waf.go index 2a69a22..6c4a40f 100644 --- a/internal/waf/waf.go +++ b/internal/waf/waf.go @@ -15,14 +15,15 @@ import ( ) type WAF struct { - Id int64 `yaml:"id" json:"id"` - IsOn bool `yaml:"isOn" json:"isOn"` - Name string `yaml:"name" json:"name"` - Inbound []*RuleGroup `yaml:"inbound" json:"inbound"` - Outbound []*RuleGroup `yaml:"outbound" json:"outbound"` - CreatedVersion string `yaml:"createdVersion" json:"createdVersion"` - Mode firewallconfigs.FirewallMode `yaml:"mode" json:"mode"` - UseLocalFirewall bool `yaml:"useLocalFirewall" json:"useLocalFirewall"` + Id int64 `yaml:"id" json:"id"` + IsOn bool `yaml:"isOn" json:"isOn"` + Name string `yaml:"name" json:"name"` + Inbound []*RuleGroup `yaml:"inbound" json:"inbound"` + Outbound []*RuleGroup `yaml:"outbound" json:"outbound"` + CreatedVersion string `yaml:"createdVersion" json:"createdVersion"` + Mode firewallconfigs.FirewallMode `yaml:"mode" json:"mode"` + UseLocalFirewall bool `yaml:"useLocalFirewall" json:"useLocalFirewall"` + SYNFlood *firewallconfigs.SYNFloodConfig `yaml:"synFlood" json:"synFlood"` DefaultBlockAction *BlockAction