diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index c5eba0d..4611de8 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -174,8 +174,15 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir if !regionConfig.IsAllowedCountry(result.CountryId(), result.ProvinceId()) { this.firewallPolicyId = firewallPolicy.Id + var promptHTML string if len(regionConfig.CountryHTML) > 0 { - var formattedHTML = this.Format(regionConfig.CountryHTML) + promptHTML = regionConfig.CountryHTML + } else if this.ReqServer != nil && this.ReqServer.HTTPFirewallPolicy != nil && len(this.ReqServer.HTTPFirewallPolicy.DenyCountryHTML) > 0 { + promptHTML = this.ReqServer.HTTPFirewallPolicy.DenyCountryHTML + } + + if len(promptHTML) > 0 { + var formattedHTML = this.Format(promptHTML) this.writer.Header().Set("Content-Type", "text/html; charset=utf-8") this.writer.Header().Set("Content-Length", types.String(len(formattedHTML))) this.writer.WriteHeader(http.StatusForbidden) @@ -203,8 +210,15 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir if !regionConfig.IsAllowedProvince(result.CountryId(), result.ProvinceId()) { this.firewallPolicyId = firewallPolicy.Id + var promptHTML string if len(regionConfig.ProvinceHTML) > 0 { - var formattedHTML = this.Format(regionConfig.ProvinceHTML) + promptHTML = regionConfig.ProvinceHTML + } else if this.ReqServer != nil && this.ReqServer.HTTPFirewallPolicy != nil && len(this.ReqServer.HTTPFirewallPolicy.DenyProvinceHTML) > 0 { + promptHTML = this.ReqServer.HTTPFirewallPolicy.DenyProvinceHTML + } + + if len(promptHTML) > 0 { + var formattedHTML = this.Format(promptHTML) this.writer.Header().Set("Content-Type", "text/html; charset=utf-8") this.writer.Header().Set("Content-Length", types.String(len(formattedHTML))) this.writer.WriteHeader(http.StatusForbidden)