WAF允许动作默认跳过所有规则

This commit is contained in:
刘祥超
2024-01-20 20:54:41 +08:00
parent 7d11b3c63b
commit 095c381ae5
22 changed files with 558 additions and 161 deletions

View File

@@ -67,8 +67,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
// 当前服务的独立设置 // 当前服务的独立设置
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blocked, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false) blockedRequest, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false)
if blocked { if blockedRequest {
return true return true
} }
if breakChecking { if breakChecking {
@@ -78,8 +78,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
// 公用的防火墙设置 // 公用的防火墙设置
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn { if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules) blockedRequest, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules)
if blocked { if blockedRequest {
return true return true
} }
if breakChecking { if breakChecking {
@@ -266,8 +266,11 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
return return
} }
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType) result, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType)
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() { if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
breakChecking = true
}
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
this.wafHasRequestBody = true this.wafHasRequestBody = true
} }
if err != nil { if err != nil {
@@ -277,28 +280,28 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
return return
} }
if ruleSet != nil { if result.Set != nil {
if forceLog { if forceLog {
this.forceLog = true this.forceLog = true
} }
if ruleSet.HasSpecialActions() { if result.Set.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(ruleGroup.Id) this.firewallRuleGroupId = types.Int64(result.Group.Id)
this.firewallRuleSetId = types.Int64(ruleSet.Id) this.firewallRuleSetId = types.Int64(result.Set.Id)
if ruleSet.HasAttackActions() { if result.Set.HasAttackActions() {
this.isAttack = true this.isAttack = true
} }
// 添加统计 // 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions) stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
} }
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode) this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
} }
return !goNext, false return !result.GoNext, breakChecking
} }
// call response waf // call response waf
@@ -316,23 +319,26 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
} }
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn { if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blocked = this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false) blockedRequest, breakChecking := this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
if blocked { if blockedRequest {
return true return true
} }
if breakChecking {
return
}
} }
// 公用的防火墙设置 // 公用的防火墙设置
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn { if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked = this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules) blockedRequest, _ := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
if blocked { if blockedRequest {
return true return true
} }
} }
return return
} }
func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool) { func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool, breakChecking bool) {
if firewallPolicy == nil || !firewallPolicy.IsOn || !firewallPolicy.Outbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass { if firewallPolicy == nil || !firewallPolicy.IsOn || !firewallPolicy.Outbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass {
return return
} }
@@ -347,8 +353,11 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
return return
} }
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer) result, err := w.MatchResponse(this, resp, this.writer)
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() { if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
breakChecking = true
}
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
this.wafHasRequestBody = true this.wafHasRequestBody = true
} }
if err != nil { if err != nil {
@@ -358,28 +367,28 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
return return
} }
if ruleSet != nil { if result.Set != nil {
if forceLog { if forceLog {
this.forceLog = true this.forceLog = true
} }
if ruleSet.HasSpecialActions() { if result.Set.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(ruleGroup.Id) this.firewallRuleGroupId = types.Int64(result.Group.Id)
this.firewallRuleSetId = types.Int64(ruleSet.Id) this.firewallRuleSetId = types.Int64(result.Set.Id)
if ruleSet.HasAttackActions() { if result.Set.HasAttackActions() {
this.isAttack = true this.isAttack = true
} }
// 添加统计 // 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions) stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
} }
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode) this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
} }
return !goNext return !result.GoNext, breakChecking
} }
// WAFRaw 原始请求 // WAFRaw 原始请求

View File

@@ -5,8 +5,18 @@ import (
"net/http" "net/http"
) )
type AllowScope = string
const (
AllowScopeGroup AllowScope = "group"
AllowScopeServer AllowScope = "server"
AllowScopeGlobal AllowScope = "global"
)
type AllowAction struct { type AllowAction struct {
BaseAction BaseAction
Scope AllowScope `yaml:"scope" json:"scope"`
} }
func (this *AllowAction) Init(waf *WAF) error { func (this *AllowAction) Init(waf *WAF) error {
@@ -25,7 +35,12 @@ func (this *AllowAction) WillChange() bool {
return true return true
} }
func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
// do nothing // do nothing
return true, false return PerformResult{
ContinueRequest: true,
GoNextGroup: this.Scope == AllowScopeGroup,
IsAllowed: true,
AllowScope: this.Scope,
}
} }

View File

@@ -61,7 +61,7 @@ func (this *BlockAction) WillChange() bool {
return true return true
} }
func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
// 加入到黑名单 // 加入到黑名单
var timeout = this.Timeout var timeout = this.Timeout
if timeout <= 0 { if timeout <= 0 {
@@ -93,14 +93,14 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
req, err := http.NewRequest(http.MethodGet, this.URL, nil) req, err := http.NewRequest(http.MethodGet, this.URL, nil)
if err != nil { if err != nil {
logs.Error(err) logs.Error(err)
return false, false return PerformResult{}
} }
req.Header.Set("User-Agent", teaconst.GlobalProductName+"/"+teaconst.Version) req.Header.Set("User-Agent", teaconst.GlobalProductName+"/"+teaconst.Version)
resp, err := httpClient.Do(req) resp, err := httpClient.Do(req)
if err != nil { if err != nil {
logs.Error(err) logs.Error(err)
return false, false return PerformResult{}
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = resp.Body.Close()
@@ -124,11 +124,11 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
logs.Error(err) logs.Error(err)
return false, false return PerformResult{}
} }
_, _ = writer.Write(data) _, _ = writer.Write(data)
} }
return false, false return PerformResult{}
} }
if len(this.Body) > 0 { if len(this.Body) > 0 {
_, _ = writer.Write([]byte(this.Body)) _, _ = writer.Write([]byte(this.Body))
@@ -137,5 +137,5 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
} }
} }
return false, false return PerformResult{}
} }

View File

@@ -123,10 +123,12 @@ func (this *CaptchaAction) WillChange() bool {
return true return true
} }
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) PerformResult {
// 是否在白名单中 // 是否在白名单中
if SharedIPWhiteList.Contains(wafutils.ComposeIPType(set.Id, req), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) { if SharedIPWhiteList.Contains(wafutils.ComposeIPType(set.Id, req), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
return true, false return PerformResult{
ContinueRequest: true,
}
} }
var refURL = req.WAFRaw().URL.String() var refURL = req.WAFRaw().URL.String()
@@ -153,7 +155,9 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
info, err := utils.SimpleEncryptMap(captchaConfig) info, err := utils.SimpleEncryptMap(captchaConfig)
if err != nil { if err != nil {
remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error()) remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error())
return true, false return PerformResult{
ContinueRequest: true,
}
} }
// 占用一次失败次数 // 占用一次失败次数
@@ -163,5 +167,5 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect) req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect) http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
return false, false return PerformResult{}
} }

View File

@@ -41,15 +41,19 @@ func (this *Get302Action) WillChange() bool {
return true return true
} }
func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
// 仅限于Get // 仅限于Get
if request.WAFRaw().Method != http.MethodGet { if request.WAFRaw().Method != http.MethodGet {
return true, false return PerformResult{
ContinueRequest: true,
}
} }
// 是否已经在白名单中 // 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) { if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true, false return PerformResult{
ContinueRequest: true,
}
} }
var m = maps.Map{ var m = maps.Map{
@@ -64,7 +68,9 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
info, err := utils.SimpleEncryptMap(m) info, err := utils.SimpleEncryptMap(m)
if err != nil { if err != nil {
remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error()) remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error())
return true, false return PerformResult{
ContinueRequest: true,
}
} }
request.DisableStat() request.DisableStat()
@@ -76,5 +82,5 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
flusher.Flush() flusher.Flush()
} }
return false, false return PerformResult{}
} }

View File

@@ -29,20 +29,29 @@ func (this *GoGroupAction) WillChange() bool {
return true return true
} }
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId)) var nextGroup = waf.FindRuleGroup(types.Int64(this.GroupId))
if nextGroup == nil || !nextGroup.IsOn { if nextGroup == nil || !nextGroup.IsOn {
return true, true return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
} }
b, _, nextSet, err := nextGroup.MatchRequest(request) b, _, nextSet, err := nextGroup.MatchRequest(request)
if err != nil { if err != nil {
remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error()) remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error())
return true, false return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
} }
if !b { if !b {
return true, false return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
} }
return nextSet.PerformActions(waf, nextGroup, request, writer) return nextSet.PerformActions(waf, nextGroup, request, writer)

View File

@@ -30,23 +30,35 @@ func (this *GoSetAction) WillChange() bool {
return true return true
} }
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId)) var nextGroup = waf.FindRuleGroup(types.Int64(this.GroupId))
if nextGroup == nil || !nextGroup.IsOn { if nextGroup == nil || !nextGroup.IsOn {
return true, true return PerformResult{
ContinueRequest: true,
GoNextSet: true,
} }
nextSet := nextGroup.FindRuleSet(types.Int64(this.SetId)) }
var nextSet = nextGroup.FindRuleSet(types.Int64(this.SetId))
if nextSet == nil || !nextSet.IsOn { if nextSet == nil || !nextSet.IsOn {
return true, true return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
} }
b, _, err := nextSet.MatchRequest(request) b, _, err := nextSet.MatchRequest(request)
if err != nil { if err != nil {
remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error()) remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error())
return true, false return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
} }
if !b { if !b {
return true, false return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
} }
return nextSet.PerformActions(waf, nextGroup, request, writer) return nextSet.PerformActions(waf, nextGroup, request, writer)
} }

View File

@@ -27,5 +27,5 @@ type ActionInterface interface {
WillChange() bool WillChange() bool
// Perform the action // Perform the action
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult
} }

View File

@@ -42,15 +42,19 @@ func (this *JSCookieAction) WillChange() bool {
return true return true
} }
func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) PerformResult {
// 是否在白名单中 // 是否在白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) { if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
return true, false return PerformResult{
ContinueRequest: true,
}
} }
nodeConfig, err := nodeconfigs.SharedNodeConfig() nodeConfig, err := nodeconfigs.SharedNodeConfig()
if err != nil { if err != nil {
return true, false return PerformResult{
ContinueRequest: true,
}
} }
var life = this.Life var life = this.Life
@@ -69,7 +73,9 @@ func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
var timestamp = pieces[0] var timestamp = pieces[0]
var sum = pieces[2] var sum = pieces[2]
if types.Int64(timestamp) >= time.Now().Unix()-int64(life) && fmt.Sprintf("%x", md5.Sum([]byte(timestamp+"@"+types.String(set.Id)+"@"+nodeConfig.NodeId))) == sum { if types.Int64(timestamp) >= time.Now().Unix()-int64(life) && fmt.Sprintf("%x", md5.Sum([]byte(timestamp+"@"+types.String(set.Id)+"@"+nodeConfig.NodeId))) == sum {
return true, false return PerformResult{
ContinueRequest: true,
}
} }
} }
} }
@@ -103,7 +109,7 @@ window.location.reload();
// 记录失败次数 // 记录失败次数
this.increaseFails(req, waf.Id, group.Id, set.Id) this.increaseFails(req, waf.Id, group.Id, set.Id)
return false, false return PerformResult{}
} }
func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64, groupId int64, setId int64) (goNext bool) { func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64, groupId int64, setId int64) (goNext bool) {

View File

@@ -25,6 +25,8 @@ func (this *LogAction) WillChange() bool {
return false return false
} }
func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
return true, false return PerformResult{
ContinueRequest: true,
}
} }

View File

@@ -76,7 +76,7 @@ func (this *NotifyAction) WillChange() bool {
} }
// Perform the action // Perform the action
func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
select { select {
case notifyChan <- &notifyTask{ case notifyChan <- &notifyTask{
ServerId: request.WAFServerId(), ServerId: request.WAFServerId(),
@@ -89,5 +89,7 @@ func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
} }
return true, false return PerformResult{
ContinueRequest: true,
}
} }

View File

@@ -45,9 +45,9 @@ func (this *PageAction) WillChange() bool {
} }
// Perform the action // Perform the action
func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
if writer == nil { if writer == nil {
return return PerformResult{}
} }
request.ProcessResponseHeaders(writer.Header(), this.Status) request.ProcessResponseHeaders(writer.Header(), this.Status)
@@ -73,5 +73,5 @@ func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reques
} }
_, _ = writer.Write([]byte(request.Format(body))) _, _ = writer.Write([]byte(request.Format(body)))
return false, false return PerformResult{}
} }

View File

@@ -34,17 +34,21 @@ func (this *Post307Action) WillChange() bool {
return true return true
} }
func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
var cookieName = "WAF_VALIDATOR_ID" var cookieName = "WAF_VALIDATOR_ID"
// 仅限于POST // 仅限于POST
if request.WAFRaw().Method != http.MethodPost { if request.WAFRaw().Method != http.MethodPost {
return true, false return PerformResult{
ContinueRequest: true,
}
} }
// 是否已经在白名单中 // 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) { if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true, false return PerformResult{
ContinueRequest: true,
}
} }
// 判断是否有Cookie // 判断是否有Cookie
@@ -58,7 +62,9 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
} }
var setId = types.String(m.GetInt64("setId")) var setId = types.String(m.GetInt64("setId"))
SharedIPWhiteList.RecordIP("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId"), "") SharedIPWhiteList.RecordIP("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId"), "")
return true, false return PerformResult{
ContinueRequest: true,
}
} }
} }
@@ -74,7 +80,9 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
info, err := utils.SimpleEncryptMap(m) info, err := utils.SimpleEncryptMap(m)
if err != nil { if err != nil {
remotelogs.Error("WAF_POST_307_ACTION", "encode info failed: "+err.Error()) remotelogs.Error("WAF_POST_307_ACTION", "encode info failed: "+err.Error())
return true, false return PerformResult{
ContinueRequest: true,
}
} }
// 清空请求内容 // 清空请求内容
@@ -101,5 +109,5 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
flusher.Flush() flusher.Flush()
} }
return false, false return PerformResult{}
} }

View File

@@ -132,7 +132,7 @@ func (this *RecordIPAction) WillChange() bool {
return this.Type == "black" return this.Type == "black"
} }
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
var ipListId = this.IPListId var ipListId = this.IPListId
if ipListId <= 0 { if ipListId <= 0 {
ipListId = firewallconfigs.GlobalListId ipListId = firewallconfigs.GlobalListId
@@ -143,7 +143,11 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
// 是否在本地白名单中 // 是否在本地白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) { if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true, false return PerformResult{
ContinueRequest: true,
IsAllowed: true,
AllowScope: AllowScopeGlobal,
}
} }
var timeout = this.Timeout var timeout = this.Timeout
@@ -200,5 +204,10 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
} }
} }
return this.Type != "black", false var isWhite = this.Type != "black"
return PerformResult{
ContinueRequest: isWhite,
IsAllowed: isWhite,
AllowScope: AllowScopeGlobal,
}
} }

View File

@@ -35,10 +35,10 @@ func (this *RedirectAction) WillChange() bool {
} }
// Perform the action // Perform the action
func (this *RedirectAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *RedirectAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
request.ProcessResponseHeaders(writer.Header(), this.Status) request.ProcessResponseHeaders(writer.Header(), this.Status)
writer.Header().Set("Location", this.URL) writer.Header().Set("Location", this.URL)
writer.WriteHeader(this.Status) writer.WriteHeader(this.Status)
return false, false return PerformResult{}
} }

View File

@@ -27,6 +27,8 @@ func (this *TagAction) WillChange() bool {
return false return false
} }
func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
return true, true return PerformResult{
ContinueRequest: true,
}
} }

22
internal/waf/results.go Normal file
View File

@@ -0,0 +1,22 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package waf
// PerformResult action performing result
type PerformResult struct {
ContinueRequest bool
GoNextGroup bool
GoNextSet bool
IsAllowed bool
AllowScope AllowScope
}
// MatchResult request match result
type MatchResult struct {
GoNext bool
HasRequestBody bool
Group *RuleGroup
Set *RuleSet
IsAllowed bool
AllowScope AllowScope
}

View File

@@ -34,6 +34,9 @@ type RuleSet struct {
actionCodes []string actionCodes []string
actionInstances []ActionInterface actionInstances []ActionInterface
hasAllowActions bool
allowScope string
hasRules bool hasRules bool
} }
@@ -62,6 +65,12 @@ func (this *RuleSet) Init(waf *WAF) error {
// action codes // action codes
var actionCodes = []string{} var actionCodes = []string{}
for _, action := range this.Actions { for _, action := range this.Actions {
if action.Code == ActionAllow {
this.hasAllowActions = true
if action.Options != nil {
this.allowScope = action.Options.GetString("scope")
}
}
if !lists.ContainsString(actionCodes, action.Code) { if !lists.ContainsString(actionCodes, action.Code) {
actionCodes = append(actionCodes, action.Code) actionCodes = append(actionCodes, action.Code)
} }
@@ -141,19 +150,37 @@ func (this *RuleSet) ActionCodes() []string {
return this.actionCodes return this.actionCodes
} }
func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) { func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) PerformResult {
if len(waf.Mode) != 0 && waf.Mode != firewallconfigs.FirewallModeDefend { if len(waf.Mode) != 0 && waf.Mode != firewallconfigs.FirewallModeDefend {
return true, false return PerformResult{
ContinueRequest: true,
} }
}
var isAllowed = this.hasAllowActions
var allowScope = this.allowScope
var continueRequest bool
var goNextGroup bool
var goNextSet bool
// 先执行allow // 先执行allow
for _, instance := range this.actionInstances { for _, instance := range this.actionInstances {
if !instance.WillChange() { if !instance.WillChange() {
continueRequest = req.WAFOnAction(instance) continueRequest = req.WAFOnAction(instance)
if !continueRequest { if !continueRequest {
return false, false return PerformResult{
IsAllowed: isAllowed,
AllowScope: allowScope,
}
}
var performResult = instance.Perform(waf, group, this, req, writer)
continueRequest = performResult.ContinueRequest
goNextSet = performResult.GoNextSet
if performResult.IsAllowed {
isAllowed = true
allowScope = performResult.AllowScope
goNextGroup = performResult.GoNextGroup
} }
_, goNextSet = instance.Perform(waf, group, this, req, writer)
} }
} }
@@ -163,13 +190,36 @@ func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Req
if instance.WillChange() { if instance.WillChange() {
continueRequest = req.WAFOnAction(instance) continueRequest = req.WAFOnAction(instance)
if !continueRequest { if !continueRequest {
return false, false return PerformResult{
IsAllowed: isAllowed,
AllowScope: allowScope,
}
}
var performResult = instance.Perform(waf, group, this, req, writer)
continueRequest = performResult.ContinueRequest
goNextSet = performResult.GoNextSet
if performResult.IsAllowed {
isAllowed = true
allowScope = performResult.AllowScope
goNextGroup = performResult.GoNextGroup
}
return PerformResult{
ContinueRequest: performResult.ContinueRequest,
GoNextGroup: goNextGroup,
GoNextSet: performResult.GoNextSet,
IsAllowed: isAllowed,
AllowScope: allowScope,
} }
return instance.Perform(waf, group, this, req, writer)
} }
} }
return true, goNextSet return PerformResult{
ContinueRequest: true,
GoNextGroup: goNextGroup,
GoNextSet: goNextSet,
IsAllowed: isAllowed,
AllowScope: allowScope,
}
} }
func (this *RuleSet) MatchRequest(req requests.Request) (b bool, hasRequestBody bool, err error) { func (this *RuleSet) MatchRequest(req requests.Request) (b bool, hasRequestBody bool, err error) {

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/cespare/xxhash" "github.com/cespare/xxhash"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/maps"
"net/http" "net/http"
"regexp" "regexp"
"runtime" "runtime"
@@ -74,6 +75,52 @@ func TestRuleSet_MatchRequest2(t *testing.T) {
a.IsTrue(set.MatchRequest(req)) a.IsTrue(set.MatchRequest(req))
} }
func TestRuleSet_MatchRequest_Allow(t *testing.T) {
var a = assert.NewAssertion(t)
var set = waf.NewRuleSet()
set.Connector = waf.RuleConnectorOr
set.Rules = []*waf.Rule{
{
Param: "${requestPath}",
Operator: waf.RuleOperatorMatch,
Value: "hello",
},
}
set.Actions = []*waf.ActionConfig{
{
Code: "allow",
Options: maps.Map{
"scope": waf.AllowScopeGroup,
},
},
}
var wafInstance = waf.NewWAF()
err := set.Init(wafInstance)
if err != nil {
t.Fatal(err)
}
rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil)
if err != nil {
t.Fatal(err)
}
var req = requests.NewTestRequest(rawReq)
b, _, err := set.MatchRequest(req)
if err != nil {
t.Fatal(err)
}
a.IsTrue(b)
var result = set.PerformActions(wafInstance, &waf.RuleGroup{}, req, nil)
a.IsTrue(result.IsAllowed)
t.Log("scope:", result.AllowScope)
}
func BenchmarkRuleSet_MatchRequest(b *testing.B) { func BenchmarkRuleSet_MatchRequest(b *testing.B) {
runtime.GOMAXPROCS(1) runtime.GOMAXPROCS(1)

View File

@@ -52,18 +52,18 @@ func Test_Template2(t *testing.T) {
} }
now := time.Now() now := time.Now()
goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log(time.Since(now).Seconds()*1000, "ms") t.Log(time.Since(now).Seconds()*1000, "ms")
if goNext { if result.GoNext {
t.Log("ok") t.Log("ok")
return return
} }
logs.PrintAsJSON(set, t) logs.PrintAsJSON(result.Set, t)
} }
func BenchmarkTemplate(b *testing.B) { func BenchmarkTemplate(b *testing.B) {
@@ -84,7 +84,7 @@ func BenchmarkTemplate(b *testing.B) {
} }
req.Header.Set("User-Agent", testUserAgent) req.Header.Set("User-Agent", testUserAgent)
_, _, _, _, _ = wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) _, _ = wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
} }
}) })
} }
@@ -103,13 +103,13 @@ func testTemplate1010(a *assert.Assertion, t *testing.T, template *waf.WAF) {
t.Fatal(err) t.Fatal(err)
} }
req.Header.Set("User-Agent", testUserAgent) req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNotNil(result) a.IsNotNil(result.Set)
if result != nil { if result.Set != nil {
a.IsTrue(result.Code == "1010") a.IsTrue(result.Set.Code == "1010")
} else { } else {
t.Log("break at:", id) t.Log("break at:", id)
} }
@@ -125,13 +125,13 @@ func testTemplate1010(a *assert.Assertion, t *testing.T, template *waf.WAF) {
t.Fatal(err) t.Fatal(err)
} }
req.Header.Set("User-Agent", testUserAgent) req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNil(result) a.IsNil(result.Set)
if result != nil { if result.Set != nil {
a.IsTrue(result.Code == "1010") a.IsTrue(result.Set.Code == "1010")
} }
} }
} }
@@ -192,13 +192,13 @@ func testTemplate2001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
req.Header.Add("Content-Type", writer.FormDataContentType()) req.Header.Add("Content-Type", writer.FormDataContentType())
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNotNil(result) a.IsNotNil(result.Set)
if result != nil { if result.Set != nil {
a.IsTrue(result.Code == "2001") a.IsTrue(result.Set.Code == "2001")
} }
} }
@@ -207,13 +207,13 @@ func testTemplate3001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNotNil(result) a.IsNotNil(result.Set)
if result != nil { if result.Set != nil {
a.IsTrue(result.Code == "3001") a.IsTrue(result.Set.Code == "3001")
} }
} }
@@ -222,13 +222,13 @@ func testTemplate4001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNotNil(result) a.IsNotNil(result.Set)
if result != nil { if result.Set != nil {
a.IsTrue(result.Code == "4001") a.IsTrue(result.Set.Code == "4001")
} }
} }
@@ -238,13 +238,13 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNotNil(result) a.IsNotNil(result.Set)
if result != nil { if result.Set != nil {
a.IsTrue(result.Code == "5001") a.IsTrue(result.Set.Code == "5001")
} }
} }
@@ -253,13 +253,13 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNotNil(result) a.IsNotNil(result.Set)
if result != nil { if result.Set != nil {
a.IsTrue(result.Code == "5001") a.IsTrue(result.Set.Code == "5001")
} }
} }
} }
@@ -271,13 +271,13 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
t.Fatal(err) t.Fatal(err)
} }
req.Header.Set("User-Agent", testUserAgent) req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNotNil(result) a.IsNotNil(result.Set)
if result != nil { if result.Set != nil {
a.IsTrue(result.Code == "6001") a.IsTrue(result.Set.Code == "6001")
} }
} }
@@ -286,11 +286,11 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNotNil(result) a.IsNotNil(result.Set)
} }
} }
@@ -325,13 +325,13 @@ func testTemplate7010(a *assert.Assertion, t *testing.T, template *waf.WAF) {
t.Fatal(err) t.Fatal(err)
} }
req.Header.Set("User-Agent", testUserAgent) req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNotNil(result) a.IsNotNil(result.Set)
if result != nil { if result.Set != nil {
a.IsTrue(lists.ContainsAny([]string{"7010"}, result.Code)) a.IsTrue(lists.ContainsAny([]string{"7010"}, result.Set.Code))
} else { } else {
t.Log("break:", id) t.Log("break:", id)
} }
@@ -423,13 +423,13 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
t.Fatal(err) t.Fatal(err)
} }
req.Header.Set("User-Agent", bot) req.Header.Set("User-Agent", bot)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.IsNotNil(result) a.IsNotNil(result.Set)
if result != nil { if result.Set != nil {
a.IsTrue(lists.ContainsAny([]string{"20001"}, result.Code)) a.IsTrue(lists.ContainsAny([]string{"20001"}, result.Set.Code))
} else { } else {
t.Log("break:", bot) t.Log("break:", bot)
} }

View File

@@ -40,6 +40,7 @@ type WAF struct {
func NewWAF() *WAF { func NewWAF() *WAF {
return &WAF{ return &WAF{
IsOn: true, IsOn: true,
actionMap: map[int64]ActionInterface{},
} }
} }
@@ -243,9 +244,11 @@ func (this *WAF) MoveOutboundRuleGroup(fromIndex int, toIndex int) {
this.Outbound = result this.Outbound = result
} }
func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter, defaultCaptchaType firewallconfigs.ServerCaptchaType) (goNext bool, hasRequestBody bool, resultGroup *RuleGroup, resultSet *RuleSet, err error) { func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter, defaultCaptchaType firewallconfigs.ServerCaptchaType) (result MatchResult, err error) {
if !this.hasInboundRules { if !this.hasInboundRules {
return true, hasRequestBody, nil, nil, nil return MatchResult{
GoNext: true,
}, nil
} }
// validate captcha // validate captcha
@@ -266,51 +269,87 @@ func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter,
} }
// match rules // match rules
var hasRequestBody bool
for _, group := range this.Inbound { for _, group := range this.Inbound {
if !group.IsOn { if !group.IsOn {
continue continue
} }
b, hasCheckedRequestBody, set, err := group.MatchRequest(req) b, hasCheckedRequestBody, set, matchErr := group.MatchRequest(req)
if hasCheckedRequestBody { if hasCheckedRequestBody {
hasRequestBody = true hasRequestBody = true
} }
if err != nil { if matchErr != nil {
return true, hasRequestBody, nil, nil, err return MatchResult{
GoNext: true,
HasRequestBody: hasRequestBody,
}, matchErr
} }
if b { if b {
continueRequest, goNextSet := set.PerformActions(this, group, req, writer) var performResult = set.PerformActions(this, group, req, writer)
if !goNextSet { if !performResult.GoNextSet {
return continueRequest, hasRequestBody, group, set, nil if performResult.GoNextGroup {
continue
}
return MatchResult{
GoNext: performResult.ContinueRequest,
HasRequestBody: hasRequestBody,
Group: group,
Set: set,
IsAllowed: performResult.IsAllowed,
AllowScope: performResult.AllowScope,
}, nil
} }
} }
} }
return true, hasRequestBody, nil, nil, nil return MatchResult{
GoNext: true,
HasRequestBody: hasRequestBody,
}, nil
} }
func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, hasRequestBody bool, resultGroup *RuleGroup, resultSet *RuleSet, err error) { func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (result MatchResult, err error) {
if !this.hasOutboundRules { if !this.hasOutboundRules {
return true, hasRequestBody, nil, nil, nil return MatchResult{
GoNext: true,
}, nil
} }
resp := requests.NewResponse(rawResp) var hasRequestBody bool
var resp = requests.NewResponse(rawResp)
for _, group := range this.Outbound { for _, group := range this.Outbound {
if !group.IsOn { if !group.IsOn {
continue continue
} }
b, hasCheckedRequestBody, set, err := group.MatchResponse(req, resp) b, hasCheckedRequestBody, set, matchErr := group.MatchResponse(req, resp)
if hasCheckedRequestBody { if hasCheckedRequestBody {
hasRequestBody = true hasRequestBody = true
} }
if err != nil { if matchErr != nil {
return true, hasRequestBody, nil, nil, err return MatchResult{
GoNext: true,
HasRequestBody: hasRequestBody,
}, matchErr
} }
if b { if b {
continueRequest, goNextSet := set.PerformActions(this, group, req, writer) var performResult = set.PerformActions(this, group, req, writer)
if !goNextSet { if !performResult.GoNextSet {
return continueRequest, hasRequestBody, group, set, nil if performResult.GoNextGroup {
continue
}
return MatchResult{
GoNext: performResult.ContinueRequest,
HasRequestBody: hasRequestBody,
Group: group,
Set: set,
IsAllowed: performResult.IsAllowed,
AllowScope: performResult.AllowScope,
}, nil
} }
} }
} }
return true, hasRequestBody, nil, nil, nil return MatchResult{
GoNext: true,
HasRequestBody: hasRequestBody,
}, nil
} }
// Save to file path // Save to file path

View File

@@ -5,6 +5,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/waf" "github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests" "github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/assert" "github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/maps"
"net/http" "net/http"
"testing" "testing"
) )
@@ -44,7 +45,7 @@ func TestWAF_MatchRequest(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone) result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -52,6 +53,160 @@ func TestWAF_MatchRequest(t *testing.T) {
t.Log("not match") t.Log("not match")
return return
} }
t.Log("goNext:", goNext, "set:", set.Name) t.Log("goNext:", result.GoNext, "set:", set.Name)
a.IsFalse(goNext) a.IsFalse(result.GoNext)
}
func TestWAF_MatchRequest_Allow(t *testing.T) {
var a = assert.NewAssertion(t)
var wafInstance = waf.NewWAF()
{
var set = waf.NewRuleSet()
set.Id = 1
set.Name = "set1"
set.Connector = waf.RuleConnectorAnd
set.Rules = []*waf.Rule{
{
Param: "${requestPath}",
Operator: waf.RuleOperatorMatch,
Value: "hello",
},
}
set.AddAction(waf.ActionAllow, maps.Map{
"scope": "global",
})
var group = waf.NewRuleGroup()
group.Id = 1
group.AddRuleSet(set)
group.IsInbound = true
wafInstance.AddRuleGroup(group)
}
{
var set = waf.NewRuleSet()
set.Id = 2
set.Name = "set2"
set.Connector = waf.RuleConnectorAnd
set.Rules = []*waf.Rule{
{
Param: "${requestPath}",
Operator: waf.RuleOperatorMatch,
Value: "he",
},
}
set.AddAction(waf.ActionAllow, maps.Map{
"scope": "global",
})
var group = waf.NewRuleGroup()
group.Id = 2
group.AddRuleSet(set)
group.IsInbound = true
wafInstance.AddRuleGroup(group)
}
errs := wafInstance.Init()
if len(errs) > 0 {
t.Fatal(errs[0])
}
req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil)
if err != nil {
t.Fatal(err)
}
result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
if result.Set == nil {
t.Log("not match")
return
}
t.Log("goNext:", result.GoNext, "set:", result.Set.Name)
a.IsTrue(result.Set.Id == 1)
a.IsTrue(result.GoNext)
a.IsTrue(result.IsAllowed)
a.IsTrue(result.AllowScope == "global")
}
func TestWAF_MatchRequest_Allow2(t *testing.T) {
var a = assert.NewAssertion(t)
var wafInstance = waf.NewWAF()
{
var set = waf.NewRuleSet()
set.Id = 1
set.Name = "set1"
set.Connector = waf.RuleConnectorAnd
set.Rules = []*waf.Rule{
{
Param: "${requestPath}",
Operator: waf.RuleOperatorMatch,
Value: "hello",
},
}
set.AddAction(waf.ActionAllow, maps.Map{
"scope": "group",
})
var group = waf.NewRuleGroup()
group.Id = 1
group.AddRuleSet(set)
group.IsInbound = true
wafInstance.AddRuleGroup(group)
}
{
var set = waf.NewRuleSet()
set.Id = 2
set.Name = "set2"
set.Connector = waf.RuleConnectorAnd
set.Rules = []*waf.Rule{
{
Param: "${requestPath}",
Operator: waf.RuleOperatorMatch,
Value: "he",
},
}
set.AddAction(waf.ActionAllow, maps.Map{
"scope": "global",
})
var group = waf.NewRuleGroup()
group.Id = 2
group.AddRuleSet(set)
group.IsInbound = true
wafInstance.AddRuleGroup(group)
}
errs := wafInstance.Init()
if len(errs) > 0 {
t.Fatal(errs[0])
}
req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil)
if err != nil {
t.Fatal(err)
}
result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
if result.Set == nil {
t.Log("not match")
return
}
t.Log("goNext:", result.GoNext, "set:", result.Set.Name)
a.IsTrue(result.Set.Id == 2)
a.IsTrue(result.GoNext)
a.IsTrue(result.IsAllowed)
a.IsTrue(result.AllowScope == "global")
} }