WAF动作支持有效范围

This commit is contained in:
GoEdgeLab
2021-10-18 20:08:43 +08:00
parent 1a2681be03
commit c0ddfa2cf1
11 changed files with 73 additions and 35 deletions

View File

@@ -3,6 +3,7 @@
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/waf"
"net"
)
@@ -24,7 +25,7 @@ func (this *TrafficListener) Accept() (net.Conn, error) {
// 是否在WAF名单中
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err == nil {
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackList.Contains(waf.IPTypeAll, ip) {
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) && waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
defer func() {
_ = conn.Close()
}()

View File

@@ -2,5 +2,22 @@
package waf
import (
"net/http"
)
type BaseAction struct {
}
// CloseConn 关闭连接
func (this *BaseAction) CloseConn(writer http.ResponseWriter) error {
// 断开连接
hijack, ok := writer.(http.Hijacker)
if ok {
conn, _, err := hijack.Hijack()
if err == nil {
return conn.Close()
}
}
return nil
}

View File

@@ -24,6 +24,7 @@ type BlockAction struct {
Body string `yaml:"body" json:"body"` // supports HTML
URL string `yaml:"url" json:"url"`
Timeout int32 `yaml:"timeout" json:"timeout"`
Scope string `yaml:"scope" json:"scope"`
}
func (this *BlockAction) Init(waf *WAF) error {
@@ -62,7 +63,7 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
if timeout <= 0 {
timeout = 60 // 默认封锁60秒
}
SharedIPBlackList.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(timeout))
SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout))
if writer != nil {
// close the connection

View File

@@ -23,6 +23,7 @@ type CaptchaAction struct {
Life int32 `yaml:"life" json:"life"`
Language string `yaml:"language" json:"language"` // 语言zh-CN, en-US ...
AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
Scope string `yaml:"scope" json:"scope"`
}
func (this *CaptchaAction) Init(waf *WAF) error {
@@ -43,7 +44,7 @@ func (this *CaptchaAction) WillChange() bool {
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
// 是否在白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true
}

View File

@@ -21,6 +21,7 @@ type Get302Action struct {
BaseAction
Life int32 `yaml:"life" json:"life"`
Scope string `yaml:"scope" json:"scope"`
}
func (this *Get302Action) Init(waf *WAF) error {
@@ -46,7 +47,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
}
// 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true
}
@@ -54,6 +55,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
"url": request.WAFRaw().URL.String(),
"timestamp": time.Now().Unix(),
"life": this.Life,
"scope": this.Scope,
"setId": set.Id,
}
info, err := utils.SimpleEncryptMap(m)
@@ -66,7 +68,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
// 关闭连接
if request.WAFRaw().ProtoMajor == 1 {
request.WAFClose()
_ = this.CloseConn(writer)
}
return true

View File

@@ -11,6 +11,7 @@ import (
type Post307Action struct {
Life int32 `yaml:"life" json:"life"`
Scope string `yaml:"scope" json:"scope"`
BaseAction
}
@@ -40,7 +41,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
}
// 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) {
if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true
}
@@ -54,7 +55,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
life = 600 // 默认10分钟
}
var setId = m.GetString("setId")
SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life)
SharedIPWhiteList.Add("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life)
return true
}
}
@@ -62,6 +63,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
var m = maps.Map{
"timestamp": time.Now().Unix(),
"life": this.Life,
"scope": this.Scope,
"setId": set.Id,
"remoteIP": request.WAFRemoteIP(),
}
@@ -82,7 +84,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)
if request.WAFRaw().ProtoMajor == 1 {
request.WAFClose()
_ = this.CloseConn(writer)
}
return true

View File

@@ -58,6 +58,7 @@ type RecordIPAction struct {
IPListId int64 `yaml:"ipListId" json:"ipListId"`
Level string `yaml:"level" json:"level"`
Timeout int32 `yaml:"timeout" json:"timeout"`
Scope string `yaml:"scope" json:"scope"`
}
func (this *RecordIPAction) Init(waf *WAF) error {
@@ -78,11 +79,10 @@ func (this *RecordIPAction) WillChange() bool {
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
// 是否在本地白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, set.Id) {
if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true
}
// 先加入本地的黑名单
timeout := this.Timeout
if timeout <= 0 {
timeout = 86400 // 1天
@@ -94,14 +94,11 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
request.WAFClose()
SharedIPBlackList.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt)
// 先加入本地的黑名单
SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt)
} else {
// 加入本地白名单
timeout := this.Timeout
if timeout <= 0 {
timeout = 86400 // 1天
}
SharedIPWhiteList.Add("set:"+set.Id, request.WAFRemoteIP(), expiredAt)
SharedIPWhiteList.Add("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt)
}
// 上报

View File

@@ -143,7 +143,7 @@ func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64,
}
// 加入到白名单
SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), request.WAFRemoteIP(), time.Now().Unix()+int64(life)) // TODO
SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life))
http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)

View File

@@ -44,7 +44,7 @@ func (this *Get302Validator) Run(request requests.Request, writer http.ResponseW
life = 600 // 默认10分钟
}
setId := m.GetString("setId")
SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life)
SharedIPWhiteList.Add("set:"+setId, m.GetString("scope"), request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life)
// 返回原始URL
var url = m.GetString("url")

View File

@@ -3,7 +3,9 @@
package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"github.com/iwind/TeaGo/types"
"sync"
"sync/atomic"
)
@@ -43,8 +45,15 @@ func NewIPList() *IPList {
}
// Add 添加IP
func (this *IPList) Add(ipType string, ip string, expiresAt int64) {
ip = ip + "@" + ipType
func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64) {
switch scope {
case firewallconfigs.FirewallScopeGlobal:
ip = "*@" + ip + "@" + ipType
case firewallconfigs.FirewallScopeService:
ip = types.String(serverId) + "@" + ip + "@" + ipType
default:
ip = types.String(serverId) + "@" + ip + "@" + ipType
}
var id = this.nextId()
this.expireList.Add(id, expiresAt)
@@ -55,8 +64,15 @@ func (this *IPList) Add(ipType string, ip string, expiresAt int64) {
}
// Contains 判断是否有某个IP
func (this *IPList) Contains(ipType string, ip string) bool {
ip = ip + "@" + ipType
func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool {
switch scope {
case firewallconfigs.FirewallScopeGlobal:
ip = "*@" + ip + "@" + ipType
case firewallconfigs.FirewallScopeService:
ip = types.String(serverId) + "@" + ip + "@" + ipType
default:
ip = types.String(serverId) + "@" + ip + "@" + ipType
}
this.locker.RLock()
defer this.locker.RUnlock()

View File

@@ -3,6 +3,7 @@
package waf
import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/logs"
"runtime"
@@ -13,11 +14,11 @@ import (
func TestNewIPList(t *testing.T) {
list := NewIPList()
list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix())
list.Add(IPTypeAll, "127.0.0.2", time.Now().Unix()+1)
list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix()+2)
list.Add(IPTypeAll, "127.0.0.3", time.Now().Unix()+3)
list.Add(IPTypeAll, "127.0.0.10", time.Now().Unix()+10)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10)
var ticker = time.NewTicker(1 * time.Second)
for range ticker.C {
@@ -36,10 +37,10 @@ func TestIPList_Contains(t *testing.T) {
list := NewIPList()
for i := 0; i < 1_0000; i++ {
list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
a.IsTrue(list.Contains(IPTypeAll, "192.168.1.100"))
a.IsFalse(list.Contains(IPTypeAll, "192.168.2.100"))
a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
}
func BenchmarkIPList_Add(b *testing.B) {
@@ -47,7 +48,7 @@ func BenchmarkIPList_Add(b *testing.B) {
list := NewIPList()
for i := 0; i < b.N; i++ {
list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
b.Log(len(list.ipMap))
}
@@ -58,10 +59,10 @@ func BenchmarkIPList_Has(b *testing.B) {
list := NewIPList()
for i := 0; i < 1_0000; i++ {
list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
}
for i := 0; i < b.N; i++ {
list.Contains(IPTypeAll, "192.168.1.100")
list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")
}
}