diff --git a/internal/nodes/http_request_waf.go b/internal/nodes/http_request_waf.go index c07fde5..fa57929 100644 --- a/internal/nodes/http_request_waf.go +++ b/internal/nodes/http_request_waf.go @@ -173,7 +173,15 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir if !regionConfig.IsAllowedCountry(result.CountryId(), result.ProvinceId()) { this.firewallPolicyId = firewallPolicy.Id - this.writeCode(http.StatusForbidden, "The region has been denied.", "当前区域禁止访问") + if len(regionConfig.CountryHTML) > 0 { + this.writer.Header().Set("Content-Type", "text/html; charset=utf-8") + this.writer.Header().Set("Content-Length", types.String(len(regionConfig.CountryHTML))) + this.writer.WriteHeader(http.StatusForbidden) + _, _ = this.writer.Write([]byte(regionConfig.CountryHTML)) + } else { + this.writeCode(http.StatusForbidden, "The region has been denied.", "当前区域禁止访问") + } + this.writer.Flush() this.writer.Close() @@ -193,7 +201,14 @@ func (this *HTTPRequest) checkWAFRequest(firewallPolicy *firewallconfigs.HTTPFir if !regionConfig.IsAllowedProvince(result.CountryId(), result.ProvinceId()) { this.firewallPolicyId = firewallPolicy.Id - this.writeCode(http.StatusForbidden, "The region has been denied.", "当前区域禁止访问") + if len(regionConfig.ProvinceHTML) > 0 { + this.writer.Header().Set("Content-Type", "text/html; charset=utf-8") + this.writer.Header().Set("Content-Length", types.String(len(regionConfig.ProvinceHTML))) + this.writer.WriteHeader(http.StatusForbidden) + _, _ = this.writer.Write([]byte(regionConfig.ProvinceHTML)) + } else { + this.writeCode(http.StatusForbidden, "The region has been denied.", "当前区域禁止访问") + } this.writer.Flush() this.writer.Close()