WAF动作支持有效范围

This commit is contained in:
GoEdgeLab
2021-10-18 20:08:43 +08:00
parent 1a2681be03
commit c0ddfa2cf1
11 changed files with 73 additions and 35 deletions

View File

@@ -3,6 +3,7 @@
package nodes package nodes
import ( import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/waf"
"net" "net"
) )
@@ -24,7 +25,7 @@ func (this *TrafficListener) Accept() (net.Conn, error) {
// 是否在WAF名单中 // 是否在WAF名单中
ip, _, err := net.SplitHostPort(conn.RemoteAddr().String()) ip, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err == nil { if err == nil {
if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, ip) && waf.SharedIPBlackList.Contains(waf.IPTypeAll, ip) { if !waf.SharedIPWhiteList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) && waf.SharedIPBlackList.Contains(waf.IPTypeAll, firewallconfigs.FirewallScopeGlobal, 0, ip) {
defer func() { defer func() {
_ = conn.Close() _ = conn.Close()
}() }()

View File

@@ -2,5 +2,22 @@
package waf package waf
import (
"net/http"
)
type BaseAction struct { type BaseAction struct {
} }
// CloseConn 关闭连接
func (this *BaseAction) CloseConn(writer http.ResponseWriter) error {
// 断开连接
hijack, ok := writer.(http.Hijacker)
if ok {
conn, _, err := hijack.Hijack()
if err == nil {
return conn.Close()
}
}
return nil
}

View File

@@ -24,6 +24,7 @@ type BlockAction struct {
Body string `yaml:"body" json:"body"` // supports HTML Body string `yaml:"body" json:"body"` // supports HTML
URL string `yaml:"url" json:"url"` URL string `yaml:"url" json:"url"`
Timeout int32 `yaml:"timeout" json:"timeout"` Timeout int32 `yaml:"timeout" json:"timeout"`
Scope string `yaml:"scope" json:"scope"`
} }
func (this *BlockAction) Init(waf *WAF) error { func (this *BlockAction) Init(waf *WAF) error {
@@ -62,7 +63,7 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
if timeout <= 0 { if timeout <= 0 {
timeout = 60 // 默认封锁60秒 timeout = 60 // 默认封锁60秒
} }
SharedIPBlackList.Add(IPTypeAll, request.WAFRemoteIP(), time.Now().Unix()+int64(timeout)) SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(timeout))
if writer != nil { if writer != nil {
// close the connection // close the connection

View File

@@ -23,6 +23,7 @@ type CaptchaAction struct {
Life int32 `yaml:"life" json:"life"` Life int32 `yaml:"life" json:"life"`
Language string `yaml:"language" json:"language"` // 语言zh-CN, en-US ... Language string `yaml:"language" json:"language"` // 语言zh-CN, en-US ...
AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单 AddToWhiteList bool `yaml:"addToWhiteList" json:"addToWhiteList"` // 是否加入到白名单
Scope string `yaml:"scope" json:"scope"`
} }
func (this *CaptchaAction) Init(waf *WAF) error { func (this *CaptchaAction) Init(waf *WAF) error {
@@ -43,7 +44,7 @@ func (this *CaptchaAction) WillChange() bool {
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
// 是否在白名单中 // 是否在白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) { if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true return true
} }

View File

@@ -21,6 +21,7 @@ type Get302Action struct {
BaseAction BaseAction
Life int32 `yaml:"life" json:"life"` Life int32 `yaml:"life" json:"life"`
Scope string `yaml:"scope" json:"scope"`
} }
func (this *Get302Action) Init(waf *WAF) error { func (this *Get302Action) Init(waf *WAF) error {
@@ -46,7 +47,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
} }
// 是否已经在白名单中 // 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) { if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true return true
} }
@@ -54,6 +55,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
"url": request.WAFRaw().URL.String(), "url": request.WAFRaw().URL.String(),
"timestamp": time.Now().Unix(), "timestamp": time.Now().Unix(),
"life": this.Life, "life": this.Life,
"scope": this.Scope,
"setId": set.Id, "setId": set.Id,
} }
info, err := utils.SimpleEncryptMap(m) info, err := utils.SimpleEncryptMap(m)
@@ -66,7 +68,7 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
// 关闭连接 // 关闭连接
if request.WAFRaw().ProtoMajor == 1 { if request.WAFRaw().ProtoMajor == 1 {
request.WAFClose() _ = this.CloseConn(writer)
} }
return true return true

View File

@@ -11,6 +11,7 @@ import (
type Post307Action struct { type Post307Action struct {
Life int32 `yaml:"life" json:"life"` Life int32 `yaml:"life" json:"life"`
Scope string `yaml:"scope" json:"scope"`
BaseAction BaseAction
} }
@@ -40,7 +41,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
} }
// 是否已经在白名单中 // 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, request.WAFRemoteIP()) { if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true return true
} }
@@ -54,7 +55,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
life = 600 // 默认10分钟 life = 600 // 默认10分钟
} }
var setId = m.GetString("setId") var setId = m.GetString("setId")
SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life) SharedIPWhiteList.Add("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life)
return true return true
} }
} }
@@ -62,6 +63,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
var m = maps.Map{ var m = maps.Map{
"timestamp": time.Now().Unix(), "timestamp": time.Now().Unix(),
"life": this.Life, "life": this.Life,
"scope": this.Scope,
"setId": set.Id, "setId": set.Id,
"remoteIP": request.WAFRemoteIP(), "remoteIP": request.WAFRemoteIP(),
} }
@@ -82,7 +84,7 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect) http.Redirect(writer, request.WAFRaw(), request.WAFRaw().URL.String(), http.StatusTemporaryRedirect)
if request.WAFRaw().ProtoMajor == 1 { if request.WAFRaw().ProtoMajor == 1 {
request.WAFClose() _ = this.CloseConn(writer)
} }
return true return true

View File

@@ -58,6 +58,7 @@ type RecordIPAction struct {
IPListId int64 `yaml:"ipListId" json:"ipListId"` IPListId int64 `yaml:"ipListId" json:"ipListId"`
Level string `yaml:"level" json:"level"` Level string `yaml:"level" json:"level"`
Timeout int32 `yaml:"timeout" json:"timeout"` Timeout int32 `yaml:"timeout" json:"timeout"`
Scope string `yaml:"scope" json:"scope"`
} }
func (this *RecordIPAction) Init(waf *WAF) error { func (this *RecordIPAction) Init(waf *WAF) error {
@@ -78,11 +79,10 @@ func (this *RecordIPAction) WillChange() bool {
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) { func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (allow bool) {
// 是否在本地白名单中 // 是否在本地白名单中
if SharedIPWhiteList.Contains("set:"+set.Id, set.Id) { if SharedIPWhiteList.Contains("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true return true
} }
// 先加入本地的黑名单
timeout := this.Timeout timeout := this.Timeout
if timeout <= 0 { if timeout <= 0 {
timeout = 86400 // 1天 timeout = 86400 // 1天
@@ -94,14 +94,11 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
request.WAFClose() request.WAFClose()
SharedIPBlackList.Add(IPTypeAll, request.WAFRemoteIP(), expiredAt) // 先加入本地的黑名单
SharedIPBlackList.Add(IPTypeAll, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt)
} else { } else {
// 加入本地白名单 // 加入本地白名单
timeout := this.Timeout SharedIPWhiteList.Add("set:"+set.Id, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), expiredAt)
if timeout <= 0 {
timeout = 86400 // 1天
}
SharedIPWhiteList.Add("set:"+set.Id, request.WAFRemoteIP(), expiredAt)
} }
// 上报 // 上报

View File

@@ -143,7 +143,7 @@ func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, setId int64,
} }
// 加入到白名单 // 加入到白名单
SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), request.WAFRemoteIP(), time.Now().Unix()+int64(life)) // TODO SharedIPWhiteList.Add("set:"+strconv.FormatInt(setId, 10), actionConfig.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+int64(life))
http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther) http.Redirect(writer, request.WAFRaw(), originURL, http.StatusSeeOther)

View File

@@ -44,7 +44,7 @@ func (this *Get302Validator) Run(request requests.Request, writer http.ResponseW
life = 600 // 默认10分钟 life = 600 // 默认10分钟
} }
setId := m.GetString("setId") setId := m.GetString("setId")
SharedIPWhiteList.Add("set:"+setId, request.WAFRemoteIP(), time.Now().Unix()+life) SharedIPWhiteList.Add("set:"+setId, m.GetString("scope"), request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life)
// 返回原始URL // 返回原始URL
var url = m.GetString("url") var url = m.GetString("url")

View File

@@ -3,7 +3,9 @@
package waf package waf
import ( import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeNode/internal/utils/expires" "github.com/TeaOSLab/EdgeNode/internal/utils/expires"
"github.com/iwind/TeaGo/types"
"sync" "sync"
"sync/atomic" "sync/atomic"
) )
@@ -43,8 +45,15 @@ func NewIPList() *IPList {
} }
// Add 添加IP // Add 添加IP
func (this *IPList) Add(ipType string, ip string, expiresAt int64) { func (this *IPList) Add(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string, expiresAt int64) {
ip = ip + "@" + ipType switch scope {
case firewallconfigs.FirewallScopeGlobal:
ip = "*@" + ip + "@" + ipType
case firewallconfigs.FirewallScopeService:
ip = types.String(serverId) + "@" + ip + "@" + ipType
default:
ip = types.String(serverId) + "@" + ip + "@" + ipType
}
var id = this.nextId() var id = this.nextId()
this.expireList.Add(id, expiresAt) this.expireList.Add(id, expiresAt)
@@ -55,8 +64,15 @@ func (this *IPList) Add(ipType string, ip string, expiresAt int64) {
} }
// Contains 判断是否有某个IP // Contains 判断是否有某个IP
func (this *IPList) Contains(ipType string, ip string) bool { func (this *IPList) Contains(ipType string, scope firewallconfigs.FirewallScope, serverId int64, ip string) bool {
ip = ip + "@" + ipType switch scope {
case firewallconfigs.FirewallScopeGlobal:
ip = "*@" + ip + "@" + ipType
case firewallconfigs.FirewallScopeService:
ip = types.String(serverId) + "@" + ip + "@" + ipType
default:
ip = types.String(serverId) + "@" + ip + "@" + ipType
}
this.locker.RLock() this.locker.RLock()
defer this.locker.RUnlock() defer this.locker.RUnlock()

View File

@@ -3,6 +3,7 @@
package waf package waf
import ( import (
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
"runtime" "runtime"
@@ -13,11 +14,11 @@ import (
func TestNewIPList(t *testing.T) { func TestNewIPList(t *testing.T) {
list := NewIPList() list := NewIPList()
list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix()) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix())
list.Add(IPTypeAll, "127.0.0.2", time.Now().Unix()+1) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.2", time.Now().Unix()+1)
list.Add(IPTypeAll, "127.0.0.1", time.Now().Unix()+2) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.1", time.Now().Unix()+2)
list.Add(IPTypeAll, "127.0.0.3", time.Now().Unix()+3) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.3", time.Now().Unix()+3)
list.Add(IPTypeAll, "127.0.0.10", time.Now().Unix()+10) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "127.0.0.10", time.Now().Unix()+10)
var ticker = time.NewTicker(1 * time.Second) var ticker = time.NewTicker(1 * time.Second)
for range ticker.C { for range ticker.C {
@@ -36,10 +37,10 @@ func TestIPList_Contains(t *testing.T) {
list := NewIPList() list := NewIPList()
for i := 0; i < 1_0000; i++ { for i := 0; i < 1_0000; i++ {
list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
} }
a.IsTrue(list.Contains(IPTypeAll, "192.168.1.100")) a.IsTrue(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100"))
a.IsFalse(list.Contains(IPTypeAll, "192.168.2.100")) a.IsFalse(list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.2.100"))
} }
func BenchmarkIPList_Add(b *testing.B) { func BenchmarkIPList_Add(b *testing.B) {
@@ -47,7 +48,7 @@ func BenchmarkIPList_Add(b *testing.B) {
list := NewIPList() list := NewIPList()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
} }
b.Log(len(list.ipMap)) b.Log(len(list.ipMap))
} }
@@ -58,10 +59,10 @@ func BenchmarkIPList_Has(b *testing.B) {
list := NewIPList() list := NewIPList()
for i := 0; i < 1_0000; i++ { for i := 0; i < 1_0000; i++ {
list.Add(IPTypeAll, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600) list.Add(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1."+strconv.Itoa(i), time.Now().Unix()+3600)
} }
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
list.Contains(IPTypeAll, "192.168.1.100") list.Contains(IPTypeAll, firewallconfigs.FirewallScopeGlobal, 1, "192.168.1.100")
} }
} }