From c0ddfa2cf19c770d7c19d4acf9571ea037a80bbb Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Mon, 18 Oct 2021 20:08:43 +0800 Subject: [PATCH] =?UTF-8?q?WAF=E5=8A=A8=E4=BD=9C=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E6=9C=89=E6=95=88=E8=8C=83=E5=9B=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/nodes/traffic_listener.go | 3 ++- internal/waf/action_base.go | 17 +++++++++++++++++ internal/waf/action_block.go | 3 ++- internal/waf/action_captcha.go | 3 ++- internal/waf/action_get_302.go | 8 +++++--- internal/waf/action_post_307.go | 10 ++++++---- internal/waf/action_record_ip.go | 13 +++++-------- internal/waf/captcha_validator.go | 2 +- internal/waf/get302_validator.go | 2 +- internal/waf/ip_list.go | 24 ++++++++++++++++++++---- internal/waf/ip_list_test.go | 23 ++++++++++++----------- 11 files changed, 73 insertions(+), 35 deletions(-) diff --git a/internal/nodes/traffic_listener.go b/internal/nodes/traffic_listener.go index c0be765..d68a285 100644 --- a/internal/nodes/traffic_listener.go +++ b/internal/nodes/traffic_listener.go @@ -3,6 +3,7 @@ package nodes import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/waf" "net" ) @@ -24,7 +25,7 @@ func (this *TrafficListener) Accept() (net.Conn, error) { // 是否在WAF名单中 ip, _, err := net.SplitHostPort(conn.RemoteAddr().String()) if err == nil { - if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackList.Contains(waf.IPTypeAll, ip) { + if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) && waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) { defer func() { _ = conn.Close() }() diff --git a/internal/waf/action_base.go b/internal/waf/action_base.go index d3a0026..b3cc227 100644 --- a/internal/waf/action_base.go +++ b/internal/waf/action_base.go @@ -2,5 +2,22 @@ package waf +import ( + "net/http" +) + type BaseAction struct { } + +// CloseConn 关闭连接 +func (this *BaseAction) CloseConn(writer http.ResponseWriter) error { + // 断开连接 + hijack, ok := writer.(http.Hijacker) + if ok { + conn, _, err := hijack.Hijack() + if err == nil { + return conn.Close() + } + } + return nil +} diff --git a/internal/waf/action_block.go b/internal/waf/action_block.go index e9ebdba..a367b7b 100644 --- a/internal/waf/action_block.go +++ b/internal/waf/action_block.go @@ -24,6 +24,7 @@ type BlockAction struct { Body string `yaml:"body" json:"body"` // supports HTML URL string `yaml:"url" json:"url"` Timeout int32 `yaml:"timeout" json:"timeout"` + Scope string `yaml:"scope" json:"scope"` } func (this *BlockAction) Init(waf *WAF) error { @@ -62,7 +63,7 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque if timeout <= 0 { timeout = 60 // 默认封锁60秒 } - SharedIPBlackList.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(timeout)) + SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout)) if writer != nil { // close the connection diff --git a/internal/waf/action_captcha.go b/internal/waf/action_captcha.go index 4db432e..8414065 100644 --- a/internal/waf/action_captcha.go +++ b/internal/waf/action_captcha.go @@ -23,6 +23,7 @@ type CaptchaAction struct { Life int32 `yaml:"life" json:"life"` Language string `yaml:"language" json:"language"` // 语言,zh-CN, en-US ... AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单 + Scope string `yaml:"scope" json:"scope"` } func (this *CaptchaAction) Init(waf *WAF) error { @@ -43,7 +44,7 @@ func (this *CaptchaAction) WillChange() bool { func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { // 是否在白名单中 - if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) { + if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) { return true } diff --git a/internal/waf/action_get_302.go b/internal/waf/action_get_302.go index 9dffd05..ab230c2 100644 --- a/internal/waf/action_get_302.go +++ b/internal/waf/action_get_302.go @@ -20,7 +20,8 @@ const ( type Get302Action struct { BaseAction - Life int32 `yaml:"life" json:"life"` + Life int32 `yaml:"life" json:"life"` + Scope string `yaml:"scope" json:"scope"` } func (this *Get302Action) Init(waf *WAF) error { @@ -46,7 +47,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ } // 是否已经在白名单中 - if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) { + if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) { return true } @@ -54,6 +55,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ "url": request.WAFRaw().URL.String(), "timestamp": time.Now().Unix(), "life": this.Life, + "scope": this.Scope, "setId": set.Id, } info, err := utils.SimpleEncryptMap(m) @@ -66,7 +68,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ // 关闭连接 if request.WAFRaw().ProtoMajor == 1 { - request.WAFClose() + _ = this.CloseConn(writer) } return true diff --git a/internal/waf/action_post_307.go b/internal/waf/action_post_307.go index 4eb4393..49aa8ae 100644 --- a/internal/waf/action_post_307.go +++ b/internal/waf/action_post_307.go @@ -10,7 +10,8 @@ import ( ) type Post307Action struct { - Life int32 `yaml:"life" json:"life"` + Life int32 `yaml:"life" json:"life"` + Scope string `yaml:"scope" json:"scope"` BaseAction } @@ -40,7 +41,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req } // 是否已经在白名单中 - if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) { + if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) { return true } @@ -54,7 +55,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req life = 600 // 默认10分钟 } var setId = m.GetString("setId") - SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life) + SharedIPWhiteList.Add("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life) return true } } @@ -62,6 +63,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req var m = maps.Map{ "timestamp": time.Now().Unix(), "life": this.Life, + "scope": this.Scope, "setId": set.Id, "remoteIP": request.WAFRemoteIP(), } @@ -82,7 +84,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect) if request.WAFRaw().ProtoMajor == 1 { - request.WAFClose() + _ = this.CloseConn(writer) } return true diff --git a/internal/waf/action_record_ip.go b/internal/waf/action_record_ip.go index 58546b5..97dd148 100644 --- a/internal/waf/action_record_ip.go +++ b/internal/waf/action_record_ip.go @@ -58,6 +58,7 @@ type RecordIPAction struct { IPListId int64 `yaml:"ipListId" json:"ipListId"` Level string `yaml:"level" json:"level"` Timeout int32 `yaml:"timeout" json:"timeout"` + Scope string `yaml:"scope" json:"scope"` } func (this *RecordIPAction) Init(waf *WAF) error { @@ -78,11 +79,10 @@ func (this *RecordIPAction) WillChange() bool { func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { // 是否在本地白名单中 - if SharedIPWhiteList.Contains("set:"+set.Id, set.Id) { + if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) { return true } - // 先加入本地的黑名单 timeout := this.Timeout if timeout <= 0 { timeout = 86400 // 1天 @@ -94,14 +94,11 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re request.WAFClose() - SharedIPBlackList.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt) + // 先加入本地的黑名单 + SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt) } else { // 加入本地白名单 - timeout := this.Timeout - if timeout <= 0 { - timeout = 86400 // 1天 - } - SharedIPWhiteList.Add("set:"+set.Id, request.WAFRemoteIP(), expiredAt) + SharedIPWhiteList.Add("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt) } // 上报 diff --git a/internal/waf/captcha_validator.go b/internal/waf/captcha_validator.go index 954bbdf..06b60d5 100644 --- a/internal/waf/captcha_validator.go +++ b/internal/waf/captcha_validator.go @@ -143,7 +143,7 @@ func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64, } // 加入到白名单 - SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), request.WAFRemoteIP(), time.Now().Unix()+int64(life)) // TODO + SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life)) http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther) diff --git a/internal/waf/get302_validator.go b/internal/waf/get302_validator.go index fbb562b..0228d28 100644 --- a/internal/waf/get302_validator.go +++ b/internal/waf/get302_validator.go @@ -44,7 +44,7 @@ func (this *Get302Validator) Run(request requests.Request, writer http.ResponseW life = 600 // 默认10分钟 } setId := m.GetString("setId") - SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life) + SharedIPWhiteList.Add("set:"+setId, m.GetString("scope"), request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life) // 返回原始URL var url = m.GetString("url") diff --git a/internal/waf/ip_list.go b/internal/waf/ip_list.go index 5ade6ef..56d43d8 100644 --- a/internal/waf/ip_list.go +++ b/internal/waf/ip_list.go @@ -3,7 +3,9 @@ package waf import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeNode/internal/utils/expires" + "github.com/iwind/TeaGo/types" "sync" "sync/atomic" ) @@ -43,8 +45,15 @@ func NewIPList() *IPList { } // Add 添加IP -func (this *IPList) Add(ipType string, ip string, expiresAt int64) { - ip = ip + "@" + ipType +func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64) { + switch scope { + case firewallconfigs.FirewallScopeGlobal: + ip = "*@" + ip + "@" + ipType + case firewallconfigs.FirewallScopeService: + ip = types.String(serverId) + "@" + ip + "@" + ipType + default: + ip = types.String(serverId) + "@" + ip + "@" + ipType + } var id = this.nextId() this.expireList.Add(id, expiresAt) @@ -55,8 +64,15 @@ func (this *IPList) Add(ipType string, ip string, expiresAt int64) { } // Contains 判断是否有某个IP -func (this *IPList) Contains(ipType string, ip string) bool { - ip = ip + "@" + ipType +func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool { + switch scope { + case firewallconfigs.FirewallScopeGlobal: + ip = "*@" + ip + "@" + ipType + case firewallconfigs.FirewallScopeService: + ip = types.String(serverId) + "@" + ip + "@" + ipType + default: + ip = types.String(serverId) + "@" + ip + "@" + ipType + } this.locker.RLock() defer this.locker.RUnlock() diff --git a/internal/waf/ip_list_test.go b/internal/waf/ip_list_test.go index c175f43..3da3e17 100644 --- a/internal/waf/ip_list_test.go +++ b/internal/waf/ip_list_test.go @@ -3,6 +3,7 @@ package waf import ( + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/logs" "runtime" @@ -13,11 +14,11 @@ import ( func TestNewIPList(t *testing.T) { list := NewIPList() - list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix()) - list.Add(IPTypeAll, "127.0.0.2", time.Now().Unix()+1) - list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix()+2) - list.Add(IPTypeAll, "127.0.0.3", time.Now().Unix()+3) - list.Add(IPTypeAll, "127.0.0.10", time.Now().Unix()+10) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10) var ticker = time.NewTicker(1 * time.Second) for range ticker.C { @@ -36,10 +37,10 @@ func TestIPList_Contains(t *testing.T) { list := NewIPList() for i := 0; i < 1_0000; i++ { - list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) } - a.IsTrue(list.Contains(IPTypeAll, "192.168.1.100")) - a.IsFalse(list.Contains(IPTypeAll, "192.168.2.100")) + a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")) + a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100")) } func BenchmarkIPList_Add(b *testing.B) { @@ -47,7 +48,7 @@ func BenchmarkIPList_Add(b *testing.B) { list := NewIPList() for i := 0; i < b.N; i++ { - list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) } b.Log(len(list.ipMap)) } @@ -58,10 +59,10 @@ func BenchmarkIPList_Has(b *testing.B) { list := NewIPList() for i := 0; i < 1_0000; i++ { - list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) + list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) } for i := 0; i < b.N; i++ { - list.Contains(IPTypeAll, "192.168.1.100") + list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100") } }