WAF允许动作默认跳过所有规则

This commit is contained in:
刘祥超
2024-01-20 20:54:41 +08:00
parent 7d11b3c63b
commit 095c381ae5
22 changed files with 558 additions and 161 deletions

View File

@@ -67,8 +67,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
// 当前服务的独立设置
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blocked, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false)
if blocked {
blockedRequest, breakChecking := this.checkWAFRequest(this.web.FirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, false)
if blockedRequest {
return true
}
if breakChecking {
@@ -78,8 +78,8 @@ func (this *HTTPRequest) doWAFRequest() (blocked bool) {
// 公用的防火墙设置
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules)
if blocked {
blockedRequest, breakChecking := this.checkWAFRequest(this.ReqServer.HTTPFirewallPolicy, forceLog, forceLogRequestBody, forceLogRegionDenying, this.web.FirewallRef.IgnoreGlobalRules)
if blockedRequest {
return true
}
if breakChecking {
@@ -266,8 +266,11 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
return
}
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType)
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() {
result, err := w.MatchRequest(this, this.writer, this.web.FirewallRef.DefaultCaptchaType)
if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
breakChecking = true
}
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
this.wafHasRequestBody = true
}
if err != nil {
@@ -277,28 +280,28 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir
return
}
if ruleSet != nil {
if result.Set != nil {
if forceLog {
this.forceLog = true
}
if ruleSet.HasSpecialActions() {
if result.Set.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
this.firewallRuleSetId = types.Int64(ruleSet.Id)
this.firewallRuleGroupId = types.Int64(result.Group.Id)
this.firewallRuleSetId = types.Int64(result.Set.Id)
if ruleSet.HasAttackActions() {
if result.Set.HasAttackActions() {
this.isAttack = true
}
// 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions)
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
}
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode)
this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
}
return !goNext, false
return !result.GoNext, breakChecking
}
// call response waf
@@ -316,23 +319,26 @@ func (this *HTTPRequest) doWAFResponse(resp *http.Response) (blocked bool) {
}
if this.web.FirewallPolicy != nil && this.web.FirewallPolicy.IsOn {
blocked = this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
if blocked {
blockedRequest, breakChecking := this.checkWAFResponse(this.web.FirewallPolicy, resp, forceLog, forceLogRequestBody, false)
if blockedRequest {
return true
}
if breakChecking {
return
}
}
// 公用的防火墙设置
if this.ReqServer.HTTPFirewallPolicy != nil && this.ReqServer.HTTPFirewallPolicy.IsOn {
blocked = this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
if blocked {
blockedRequest, _ := this.checkWAFResponse(this.ReqServer.HTTPFirewallPolicy, resp, forceLog, forceLogRequestBody, this.web.FirewallRef.IgnoreGlobalRules)
if blockedRequest {
return true
}
}
return
}
func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool) {
func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFirewallPolicy, resp *http.Response, forceLog bool, logRequestBody bool, ignoreRules bool) (blocked bool, breakChecking bool) {
if firewallPolicy == nil || !firewallPolicy.IsOn || !firewallPolicy.Outbound.IsOn || firewallPolicy.Mode == firewallconfigs.FirewallModeBypass {
return
}
@@ -347,8 +353,11 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
return
}
goNext, hasRequestBody, ruleGroup, ruleSet, err := w.MatchResponse(this, resp, this.writer)
if forceLog && logRequestBody && hasRequestBody && ruleSet != nil && ruleSet.HasAttackActions() {
result, err := w.MatchResponse(this, resp, this.writer)
if result.IsAllowed && (len(result.AllowScope) == 0 || result.AllowScope == waf.AllowScopeGlobal) {
breakChecking = true
}
if forceLog && logRequestBody && result.HasRequestBody && result.Set != nil && result.Set.HasAttackActions() {
this.wafHasRequestBody = true
}
if err != nil {
@@ -358,28 +367,28 @@ func (this *HTTPRequest) checkWAFResponse(firewallPolicy *firewallconfigs.HTTPFi
return
}
if ruleSet != nil {
if result.Set != nil {
if forceLog {
this.forceLog = true
}
if ruleSet.HasSpecialActions() {
if result.Set.HasSpecialActions() {
this.firewallPolicyId = firewallPolicy.Id
this.firewallRuleGroupId = types.Int64(ruleGroup.Id)
this.firewallRuleSetId = types.Int64(ruleSet.Id)
this.firewallRuleGroupId = types.Int64(result.Group.Id)
this.firewallRuleSetId = types.Int64(result.Set.Id)
if ruleSet.HasAttackActions() {
if result.Set.HasAttackActions() {
this.isAttack = true
}
// 添加统计
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, ruleSet.Actions)
stats.SharedHTTPRequestStatManager.AddFirewallRuleGroupId(this.ReqServer.Id, this.firewallRuleGroupId, result.Set.Actions)
}
this.firewallActions = append(ruleSet.ActionCodes(), firewallPolicy.Mode)
this.firewallActions = append(result.Set.ActionCodes(), firewallPolicy.Mode)
}
return !goNext
return !result.GoNext, breakChecking
}
// WAFRaw 原始请求

View File

@@ -5,8 +5,18 @@ import (
"net/http"
)
type AllowScope = string
const (
AllowScopeGroup AllowScope = "group"
AllowScopeServer AllowScope = "server"
AllowScopeGlobal AllowScope = "global"
)
type AllowAction struct {
BaseAction
Scope AllowScope `yaml:"scope" json:"scope"`
}
func (this *AllowAction) Init(waf *WAF) error {
@@ -25,7 +35,12 @@ func (this *AllowAction) WillChange() bool {
return true
}
func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *AllowAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
// do nothing
return true, false
return PerformResult{
ContinueRequest: true,
GoNextGroup: this.Scope == AllowScopeGroup,
IsAllowed: true,
AllowScope: this.Scope,
}
}

View File

@@ -61,7 +61,7 @@ func (this *BlockAction) WillChange() bool {
return true
}
func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
// 加入到黑名单
var timeout = this.Timeout
if timeout <= 0 {
@@ -93,14 +93,14 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
req, err := http.NewRequest(http.MethodGet, this.URL, nil)
if err != nil {
logs.Error(err)
return false, false
return PerformResult{}
}
req.Header.Set("User-Agent", teaconst.GlobalProductName+"/"+teaconst.Version)
resp, err := httpClient.Do(req)
if err != nil {
logs.Error(err)
return false, false
return PerformResult{}
}
defer func() {
_ = resp.Body.Close()
@@ -124,11 +124,11 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
data, err := os.ReadFile(path)
if err != nil {
logs.Error(err)
return false, false
return PerformResult{}
}
_, _ = writer.Write(data)
}
return false, false
return PerformResult{}
}
if len(this.Body) > 0 {
_, _ = writer.Write([]byte(this.Body))
@@ -137,5 +137,5 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
}
}
return false, false
return PerformResult{}
}

View File

@@ -123,10 +123,12 @@ func (this *CaptchaAction) WillChange() bool {
return true
}
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) PerformResult {
// 是否在白名单中
if SharedIPWhiteList.Contains(wafutils.ComposeIPType(set.Id, req), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
var refURL = req.WAFRaw().URL.String()
@@ -153,7 +155,9 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
info, err := utils.SimpleEncryptMap(captchaConfig)
if err != nil {
remotelogs.Error("WAF_CAPTCHA_ACTION", "encode captcha config failed: "+err.Error())
return true, false
return PerformResult{
ContinueRequest: true,
}
}
// 占用一次失败次数
@@ -163,5 +167,5 @@ func (this *CaptchaAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
req.ProcessResponseHeaders(writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(writer, req.WAFRaw(), CaptchaPath+"?info="+url.QueryEscape(info), http.StatusTemporaryRedirect)
return false, false
return PerformResult{}
}

View File

@@ -41,15 +41,19 @@ func (this *Get302Action) WillChange() bool {
return true
}
func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
// 仅限于Get
if request.WAFRaw().Method != http.MethodGet {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
// 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
var m = maps.Map{
@@ -64,7 +68,9 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
info, err := utils.SimpleEncryptMap(m)
if err != nil {
remotelogs.Error("WAF_GET_302_ACTION", "encode info failed: "+err.Error())
return true, false
return PerformResult{
ContinueRequest: true,
}
}
request.DisableStat()
@@ -75,6 +81,6 @@ func (this *Get302Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
if ok {
flusher.Flush()
}
return false, false
return PerformResult{}
}

View File

@@ -29,20 +29,29 @@ func (this *GoGroupAction) WillChange() bool {
return true
}
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId))
func (this *GoGroupAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
var nextGroup = waf.FindRuleGroup(types.Int64(this.GroupId))
if nextGroup == nil || !nextGroup.IsOn {
return true, true
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
b, _, nextSet, err := nextGroup.MatchRequest(request)
if err != nil {
remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error())
return true, false
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
if !b {
return true, false
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
return nextSet.PerformActions(waf, nextGroup, request, writer)

View File

@@ -30,23 +30,35 @@ func (this *GoSetAction) WillChange() bool {
return true
}
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
nextGroup := waf.FindRuleGroup(types.Int64(this.GroupId))
func (this *GoSetAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
var nextGroup = waf.FindRuleGroup(types.Int64(this.GroupId))
if nextGroup == nil || !nextGroup.IsOn {
return true, true
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
nextSet := nextGroup.FindRuleSet(types.Int64(this.SetId))
var nextSet = nextGroup.FindRuleSet(types.Int64(this.SetId))
if nextSet == nil || !nextSet.IsOn {
return true, true
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
b, _, err := nextSet.MatchRequest(request)
if err != nil {
remotelogs.Error("WAF", "GO_GROUP_ACTION: "+err.Error())
return true, false
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
if !b {
return true, false
return PerformResult{
ContinueRequest: true,
GoNextSet: true,
}
}
return nextSet.PerformActions(waf, nextGroup, request, writer)
}

View File

@@ -27,5 +27,5 @@ type ActionInterface interface {
WillChange() bool
// Perform the action
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool)
Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult
}

View File

@@ -42,15 +42,19 @@ func (this *JSCookieAction) WillChange() bool {
return true
}
func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req requests.Request, writer http.ResponseWriter) PerformResult {
// 是否在白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, req.WAFServerId(), req.WAFRemoteIP()) {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
nodeConfig, err := nodeconfigs.SharedNodeConfig()
if err != nil {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
var life = this.Life
@@ -69,7 +73,9 @@ func (this *JSCookieAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
var timestamp = pieces[0]
var sum = pieces[2]
if types.Int64(timestamp) >= time.Now().Unix()-int64(life) && fmt.Sprintf("%x", md5.Sum([]byte(timestamp+"@"+types.String(set.Id)+"@"+nodeConfig.NodeId))) == sum {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
}
}
@@ -103,7 +109,7 @@ window.location.reload();
// 记录失败次数
this.increaseFails(req, waf.Id, group.Id, set.Id)
return false, false
return PerformResult{}
}
func (this *JSCookieAction) increaseFails(req requests.Request, policyId int64, groupId int64, setId int64) (goNext bool) {

View File

@@ -25,6 +25,8 @@ func (this *LogAction) WillChange() bool {
return false
}
func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
return true, false
func (this *LogAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
return PerformResult{
ContinueRequest: true,
}
}

View File

@@ -76,7 +76,7 @@ func (this *NotifyAction) WillChange() bool {
}
// Perform the action
func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
select {
case notifyChan <- &notifyTask{
ServerId: request.WAFServerId(),
@@ -89,5 +89,7 @@ func (this *NotifyAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, requ
}
return true, false
return PerformResult{
ContinueRequest: true,
}
}

View File

@@ -45,9 +45,9 @@ func (this *PageAction) WillChange() bool {
}
// Perform the action
func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
if writer == nil {
return
return PerformResult{}
}
request.ProcessResponseHeaders(writer.Header(), this.Status)
@@ -73,5 +73,5 @@ func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reques
}
_, _ = writer.Write([]byte(request.Format(body)))
return false, false
return PerformResult{}
}

View File

@@ -34,17 +34,21 @@ func (this *Post307Action) WillChange() bool {
return true
}
func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
var cookieName = "WAF_VALIDATOR_ID"
// 仅限于POST
if request.WAFRaw().Method != http.MethodPost {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
// 是否已经在白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
// 判断是否有Cookie
@@ -58,7 +62,9 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
}
var setId = types.String(m.GetInt64("setId"))
SharedIPWhiteList.RecordIP("set:"+setId, this.Scope, request.WAFServerId(), request.WAFRemoteIP(), time.Now().Unix()+life, m.GetInt64("policyId"), false, m.GetInt64("groupId"), m.GetInt64("setId"), "")
return true, false
return PerformResult{
ContinueRequest: true,
}
}
}
@@ -74,7 +80,9 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
info, err := utils.SimpleEncryptMap(m)
if err != nil {
remotelogs.Error("WAF_POST_307_ACTION", "encode info failed: "+err.Error())
return true, false
return PerformResult{
ContinueRequest: true,
}
}
// 清空请求内容
@@ -101,5 +109,5 @@ func (this *Post307Action) Perform(waf *WAF, group *RuleGroup, set *RuleSet, req
flusher.Flush()
}
return false, false
return PerformResult{}
}

View File

@@ -132,7 +132,7 @@ func (this *RecordIPAction) WillChange() bool {
return this.Type == "black"
}
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
var ipListId = this.IPListId
if ipListId <= 0 {
ipListId = firewallconfigs.GlobalListId
@@ -143,7 +143,11 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
// 是否在本地白名单中
if SharedIPWhiteList.Contains("set:"+types.String(set.Id), this.Scope, request.WAFServerId(), request.WAFRemoteIP()) {
return true, false
return PerformResult{
ContinueRequest: true,
IsAllowed: true,
AllowScope: AllowScopeGlobal,
}
}
var timeout = this.Timeout
@@ -200,5 +204,10 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
}
}
return this.Type != "black", false
var isWhite = this.Type != "black"
return PerformResult{
ContinueRequest: isWhite,
IsAllowed: isWhite,
AllowScope: AllowScopeGlobal,
}
}

View File

@@ -35,10 +35,10 @@ func (this *RedirectAction) WillChange() bool {
}
// Perform the action
func (this *RedirectAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *RedirectAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
request.ProcessResponseHeaders(writer.Header(), this.Status)
writer.Header().Set("Location", this.URL)
writer.WriteHeader(this.Status)
return false, false
return PerformResult{}
}

View File

@@ -27,6 +27,8 @@ func (this *TagAction) WillChange() bool {
return false
}
func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
return true, true
func (this *TagAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) PerformResult {
return PerformResult{
ContinueRequest: true,
}
}

22
internal/waf/results.go Normal file
View File

@@ -0,0 +1,22 @@
// Copyright 2024 GoEdge CDN goedge.cdn@gmail.com. All rights reserved. Official site: https://goedge.cn .
package waf
// PerformResult action performing result
type PerformResult struct {
ContinueRequest bool
GoNextGroup bool
GoNextSet bool
IsAllowed bool
AllowScope AllowScope
}
// MatchResult request match result
type MatchResult struct {
GoNext bool
HasRequestBody bool
Group *RuleGroup
Set *RuleSet
IsAllowed bool
AllowScope AllowScope
}

View File

@@ -34,6 +34,9 @@ type RuleSet struct {
actionCodes []string
actionInstances []ActionInterface
hasAllowActions bool
allowScope string
hasRules bool
}
@@ -62,6 +65,12 @@ func (this *RuleSet) Init(waf *WAF) error {
// action codes
var actionCodes = []string{}
for _, action := range this.Actions {
if action.Code == ActionAllow {
this.hasAllowActions = true
if action.Options != nil {
this.allowScope = action.Options.GetString("scope")
}
}
if !lists.ContainsString(actionCodes, action.Code) {
actionCodes = append(actionCodes, action.Code)
}
@@ -141,19 +150,37 @@ func (this *RuleSet) ActionCodes() []string {
return this.actionCodes
}
func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Request, writer http.ResponseWriter) PerformResult {
if len(waf.Mode) != 0 && waf.Mode != firewallconfigs.FirewallModeDefend {
return true, false
return PerformResult{
ContinueRequest: true,
}
}
var isAllowed = this.hasAllowActions
var allowScope = this.allowScope
var continueRequest bool
var goNextGroup bool
var goNextSet bool
// 先执行allow
for _, instance := range this.actionInstances {
if !instance.WillChange() {
continueRequest = req.WAFOnAction(instance)
if !continueRequest {
return false, false
return PerformResult{
IsAllowed: isAllowed,
AllowScope: allowScope,
}
}
var performResult = instance.Perform(waf, group, this, req, writer)
continueRequest = performResult.ContinueRequest
goNextSet = performResult.GoNextSet
if performResult.IsAllowed {
isAllowed = true
allowScope = performResult.AllowScope
goNextGroup = performResult.GoNextGroup
}
_, goNextSet = instance.Perform(waf, group, this, req, writer)
}
}
@@ -163,13 +190,36 @@ func (this *RuleSet) PerformActions(waf *WAF, group *RuleGroup, req requests.Req
if instance.WillChange() {
continueRequest = req.WAFOnAction(instance)
if !continueRequest {
return false, false
return PerformResult{
IsAllowed: isAllowed,
AllowScope: allowScope,
}
}
var performResult = instance.Perform(waf, group, this, req, writer)
continueRequest = performResult.ContinueRequest
goNextSet = performResult.GoNextSet
if performResult.IsAllowed {
isAllowed = true
allowScope = performResult.AllowScope
goNextGroup = performResult.GoNextGroup
}
return PerformResult{
ContinueRequest: performResult.ContinueRequest,
GoNextGroup: goNextGroup,
GoNextSet: performResult.GoNextSet,
IsAllowed: isAllowed,
AllowScope: allowScope,
}
return instance.Perform(waf, group, this, req, writer)
}
}
return true, goNextSet
return PerformResult{
ContinueRequest: true,
GoNextGroup: goNextGroup,
GoNextSet: goNextSet,
IsAllowed: isAllowed,
AllowScope: allowScope,
}
}
func (this *RuleSet) MatchRequest(req requests.Request) (b bool, hasRequestBody bool, err error) {

View File

@@ -6,6 +6,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/cespare/xxhash"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/maps"
"net/http"
"regexp"
"runtime"
@@ -74,6 +75,52 @@ func TestRuleSet_MatchRequest2(t *testing.T) {
a.IsTrue(set.MatchRequest(req))
}
func TestRuleSet_MatchRequest_Allow(t *testing.T) {
var a = assert.NewAssertion(t)
var set = waf.NewRuleSet()
set.Connector = waf.RuleConnectorOr
set.Rules = []*waf.Rule{
{
Param: "${requestPath}",
Operator: waf.RuleOperatorMatch,
Value: "hello",
},
}
set.Actions = []*waf.ActionConfig{
{
Code: "allow",
Options: maps.Map{
"scope": waf.AllowScopeGroup,
},
},
}
var wafInstance = waf.NewWAF()
err := set.Init(wafInstance)
if err != nil {
t.Fatal(err)
}
rawReq, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil)
if err != nil {
t.Fatal(err)
}
var req = requests.NewTestRequest(rawReq)
b, _, err := set.MatchRequest(req)
if err != nil {
t.Fatal(err)
}
a.IsTrue(b)
var result = set.PerformActions(wafInstance, &waf.RuleGroup{}, req, nil)
a.IsTrue(result.IsAllowed)
t.Log("scope:", result.AllowScope)
}
func BenchmarkRuleSet_MatchRequest(b *testing.B) {
runtime.GOMAXPROCS(1)

View File

@@ -52,18 +52,18 @@ func Test_Template2(t *testing.T) {
}
now := time.Now()
goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
t.Log(time.Since(now).Seconds()*1000, "ms")
if goNext {
if result.GoNext {
t.Log("ok")
return
}
logs.PrintAsJSON(set, t)
logs.PrintAsJSON(result.Set, t)
}
func BenchmarkTemplate(b *testing.B) {
@@ -84,7 +84,7 @@ func BenchmarkTemplate(b *testing.B) {
}
req.Header.Set("User-Agent", testUserAgent)
_, _, _, _, _ = wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
_, _ = wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
}
})
}
@@ -103,13 +103,13 @@ func testTemplate1010(a *assert.Assertion, t *testing.T, template *waf.WAF) {
t.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "1010")
a.IsNotNil(result.Set)
if result.Set != nil {
a.IsTrue(result.Set.Code == "1010")
} else {
t.Log("break at:", id)
}
@@ -125,13 +125,13 @@ func testTemplate1010(a *assert.Assertion, t *testing.T, template *waf.WAF) {
t.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNil(result)
if result != nil {
a.IsTrue(result.Code == "1010")
a.IsNil(result.Set)
if result.Set != nil {
a.IsTrue(result.Set.Code == "1010")
}
}
}
@@ -192,13 +192,13 @@ func testTemplate2001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
req.Header.Add("Content-Type", writer.FormDataContentType())
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "2001")
a.IsNotNil(result.Set)
if result.Set != nil {
a.IsTrue(result.Set.Code == "2001")
}
}
@@ -207,13 +207,13 @@ func testTemplate3001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
if err != nil {
t.Fatal(err)
}
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "3001")
a.IsNotNil(result.Set)
if result.Set != nil {
a.IsTrue(result.Set.Code == "3001")
}
}
@@ -222,13 +222,13 @@ func testTemplate4001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
if err != nil {
t.Fatal(err)
}
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "4001")
a.IsNotNil(result.Set)
if result.Set != nil {
a.IsTrue(result.Set.Code == "4001")
}
}
@@ -238,13 +238,13 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
if err != nil {
t.Fatal(err)
}
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "5001")
a.IsNotNil(result.Set)
if result.Set != nil {
a.IsTrue(result.Set.Code == "5001")
}
}
@@ -253,13 +253,13 @@ func testTemplate5001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
if err != nil {
t.Fatal(err)
}
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "5001")
a.IsNotNil(result.Set)
if result.Set != nil {
a.IsTrue(result.Set.Code == "5001")
}
}
}
@@ -271,13 +271,13 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
t.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(result.Code == "6001")
a.IsNotNil(result.Set)
if result.Set != nil {
a.IsTrue(result.Set.Code == "6001")
}
}
@@ -286,11 +286,11 @@ func testTemplate6001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
if err != nil {
t.Fatal(err)
}
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
a.IsNotNil(result.Set)
}
}
@@ -325,13 +325,13 @@ func testTemplate7010(a *assert.Assertion, t *testing.T, template *waf.WAF) {
t.Fatal(err)
}
req.Header.Set("User-Agent", testUserAgent)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(lists.ContainsAny([]string{"7010"}, result.Code))
a.IsNotNil(result.Set)
if result.Set != nil {
a.IsTrue(lists.ContainsAny([]string{"7010"}, result.Set.Code))
} else {
t.Log("break:", id)
}
@@ -423,13 +423,13 @@ func testTemplate20001(a *assert.Assertion, t *testing.T, template *waf.WAF) {
t.Fatal(err)
}
req.Header.Set("User-Agent", bot)
_, _, _, result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := template.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
a.IsNotNil(result)
if result != nil {
a.IsTrue(lists.ContainsAny([]string{"20001"}, result.Code))
a.IsNotNil(result.Set)
if result.Set != nil {
a.IsTrue(lists.ContainsAny([]string{"20001"}, result.Set.Code))
} else {
t.Log("break:", bot)
}

View File

@@ -39,7 +39,8 @@ type WAF struct {
func NewWAF() *WAF {
return &WAF{
IsOn: true,
IsOn: true,
actionMap: map[int64]ActionInterface{},
}
}
@@ -243,9 +244,11 @@ func (this *WAF) MoveOutboundRuleGroup(fromIndex int, toIndex int) {
this.Outbound = result
}
func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter, defaultCaptchaType firewallconfigs.ServerCaptchaType) (goNext bool, hasRequestBody bool, resultGroup *RuleGroup, resultSet *RuleSet, err error) {
func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter, defaultCaptchaType firewallconfigs.ServerCaptchaType) (result MatchResult, err error) {
if !this.hasInboundRules {
return true, hasRequestBody, nil, nil, nil
return MatchResult{
GoNext: true,
}, nil
}
// validate captcha
@@ -266,51 +269,87 @@ func (this *WAF) MatchRequest(req requests.Request, writer http.ResponseWriter,
}
// match rules
var hasRequestBody bool
for _, group := range this.Inbound {
if !group.IsOn {
continue
}
b, hasCheckedRequestBody, set, err := group.MatchRequest(req)
b, hasCheckedRequestBody, set, matchErr := group.MatchRequest(req)
if hasCheckedRequestBody {
hasRequestBody = true
}
if err != nil {
return true, hasRequestBody, nil, nil, err
if matchErr != nil {
return MatchResult{
GoNext: true,
HasRequestBody: hasRequestBody,
}, matchErr
}
if b {
continueRequest, goNextSet := set.PerformActions(this, group, req, writer)
if !goNextSet {
return continueRequest, hasRequestBody, group, set, nil
var performResult = set.PerformActions(this, group, req, writer)
if !performResult.GoNextSet {
if performResult.GoNextGroup {
continue
}
return MatchResult{
GoNext: performResult.ContinueRequest,
HasRequestBody: hasRequestBody,
Group: group,
Set: set,
IsAllowed: performResult.IsAllowed,
AllowScope: performResult.AllowScope,
}, nil
}
}
}
return true, hasRequestBody, nil, nil, nil
return MatchResult{
GoNext: true,
HasRequestBody: hasRequestBody,
}, nil
}
func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (goNext bool, hasRequestBody bool, resultGroup *RuleGroup, resultSet *RuleSet, err error) {
func (this *WAF) MatchResponse(req requests.Request, rawResp *http.Response, writer http.ResponseWriter) (result MatchResult, err error) {
if !this.hasOutboundRules {
return true, hasRequestBody, nil, nil, nil
return MatchResult{
GoNext: true,
}, nil
}
resp := requests.NewResponse(rawResp)
var hasRequestBody bool
var resp = requests.NewResponse(rawResp)
for _, group := range this.Outbound {
if !group.IsOn {
continue
}
b, hasCheckedRequestBody, set, err := group.MatchResponse(req, resp)
b, hasCheckedRequestBody, set, matchErr := group.MatchResponse(req, resp)
if hasCheckedRequestBody {
hasRequestBody = true
}
if err != nil {
return true, hasRequestBody, nil, nil, err
if matchErr != nil {
return MatchResult{
GoNext: true,
HasRequestBody: hasRequestBody,
}, matchErr
}
if b {
continueRequest, goNextSet := set.PerformActions(this, group, req, writer)
if !goNextSet {
return continueRequest, hasRequestBody, group, set, nil
var performResult = set.PerformActions(this, group, req, writer)
if !performResult.GoNextSet {
if performResult.GoNextGroup {
continue
}
return MatchResult{
GoNext: performResult.ContinueRequest,
HasRequestBody: hasRequestBody,
Group: group,
Set: set,
IsAllowed: performResult.IsAllowed,
AllowScope: performResult.AllowScope,
}, nil
}
}
}
return true, hasRequestBody, nil, nil, nil
return MatchResult{
GoNext: true,
HasRequestBody: hasRequestBody,
}, nil
}
// Save to file path

View File

@@ -5,6 +5,7 @@ import (
"github.com/TeaOSLab/EdgeNode/internal/waf"
"github.com/TeaOSLab/EdgeNode/internal/waf/requests"
"github.com/iwind/TeaGo/assert"
"github.com/iwind/TeaGo/maps"
"net/http"
"testing"
)
@@ -44,7 +45,7 @@ func TestWAF_MatchRequest(t *testing.T) {
if err != nil {
t.Fatal(err)
}
goNext, _, _, set, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
@@ -52,6 +53,160 @@ func TestWAF_MatchRequest(t *testing.T) {
t.Log("not match")
return
}
t.Log("goNext:", goNext, "set:", set.Name)
a.IsFalse(goNext)
t.Log("goNext:", result.GoNext, "set:", set.Name)
a.IsFalse(result.GoNext)
}
func TestWAF_MatchRequest_Allow(t *testing.T) {
var a = assert.NewAssertion(t)
var wafInstance = waf.NewWAF()
{
var set = waf.NewRuleSet()
set.Id = 1
set.Name = "set1"
set.Connector = waf.RuleConnectorAnd
set.Rules = []*waf.Rule{
{
Param: "${requestPath}",
Operator: waf.RuleOperatorMatch,
Value: "hello",
},
}
set.AddAction(waf.ActionAllow, maps.Map{
"scope": "global",
})
var group = waf.NewRuleGroup()
group.Id = 1
group.AddRuleSet(set)
group.IsInbound = true
wafInstance.AddRuleGroup(group)
}
{
var set = waf.NewRuleSet()
set.Id = 2
set.Name = "set2"
set.Connector = waf.RuleConnectorAnd
set.Rules = []*waf.Rule{
{
Param: "${requestPath}",
Operator: waf.RuleOperatorMatch,
Value: "he",
},
}
set.AddAction(waf.ActionAllow, maps.Map{
"scope": "global",
})
var group = waf.NewRuleGroup()
group.Id = 2
group.AddRuleSet(set)
group.IsInbound = true
wafInstance.AddRuleGroup(group)
}
errs := wafInstance.Init()
if len(errs) > 0 {
t.Fatal(errs[0])
}
req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil)
if err != nil {
t.Fatal(err)
}
result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
if result.Set == nil {
t.Log("not match")
return
}
t.Log("goNext:", result.GoNext, "set:", result.Set.Name)
a.IsTrue(result.Set.Id == 1)
a.IsTrue(result.GoNext)
a.IsTrue(result.IsAllowed)
a.IsTrue(result.AllowScope == "global")
}
func TestWAF_MatchRequest_Allow2(t *testing.T) {
var a = assert.NewAssertion(t)
var wafInstance = waf.NewWAF()
{
var set = waf.NewRuleSet()
set.Id = 1
set.Name = "set1"
set.Connector = waf.RuleConnectorAnd
set.Rules = []*waf.Rule{
{
Param: "${requestPath}",
Operator: waf.RuleOperatorMatch,
Value: "hello",
},
}
set.AddAction(waf.ActionAllow, maps.Map{
"scope": "group",
})
var group = waf.NewRuleGroup()
group.Id = 1
group.AddRuleSet(set)
group.IsInbound = true
wafInstance.AddRuleGroup(group)
}
{
var set = waf.NewRuleSet()
set.Id = 2
set.Name = "set2"
set.Connector = waf.RuleConnectorAnd
set.Rules = []*waf.Rule{
{
Param: "${requestPath}",
Operator: waf.RuleOperatorMatch,
Value: "he",
},
}
set.AddAction(waf.ActionAllow, maps.Map{
"scope": "global",
})
var group = waf.NewRuleGroup()
group.Id = 2
group.AddRuleSet(set)
group.IsInbound = true
wafInstance.AddRuleGroup(group)
}
errs := wafInstance.Init()
if len(errs) > 0 {
t.Fatal(errs[0])
}
req, err := http.NewRequest(http.MethodGet, "http://teaos.cn/hello?name=lu&age=20", nil)
if err != nil {
t.Fatal(err)
}
result, err := wafInstance.MatchRequest(requests.NewTestRequest(req), nil, firewallconfigs.ServerCaptchaTypeNone)
if err != nil {
t.Fatal(err)
}
if result.Set == nil {
t.Log("not match")
return
}
t.Log("goNext:", result.GoNext, "set:", result.Set.Name)
a.IsTrue(result.Set.Id == 2)
a.IsTrue(result.GoNext)
a.IsTrue(result.IsAllowed)
a.IsTrue(result.AllowScope == "global")
}