WAF在输出内容时也加入自定义的响应报头

This commit is contained in:
GoEdgeLab
2023-06-11 10:46:20 +08:00
parent 4501e9c15d
commit f7dc03cbfb
23 changed files with 64 additions and 48 deletions

View File

@@ -1699,8 +1699,8 @@ func (this *HTTPRequest) fixRequestHeader(header http.Header) {
}
}
// 处理自定义Response Header
func (this *HTTPRequest) processResponseHeaders(responseHeader http.Header, statusCode int) {
// ProcessResponseHeaders 处理自定义Response Header
func (this *HTTPRequest) ProcessResponseHeaders(responseHeader http.Header, statusCode int) {
// 删除/添加/替换Header
// TODO 实现AddTrailers
if this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn {

View File

@@ -375,7 +375,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 支持 If-None-Match
if !this.isLnRequest && !isPartialCache && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag {
// 自定义Header
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true
@@ -387,7 +387,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 支持 If-Modified-Since
if !this.isLnRequest && !isPartialCache && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true
@@ -396,7 +396,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
return true
}
this.processResponseHeaders(this.writer.Header(), reader.Status())
this.ProcessResponseHeaders(this.writer.Header(), reader.Status())
this.addExpiresHeader(reader.ExpiresAt())
// 返回上级节点过期时间
@@ -425,7 +425,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if supportRange {
if len(rangeHeader) > 0 {
if fileSize == 0 {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -433,7 +433,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if len(ranges) == 0 {
ranges, ok = httpRequestParseRangeHeader(rangeHeader)
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -442,7 +442,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
for k, r := range ranges {
r2, ok := r.Convert(fileSize)
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -472,7 +472,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.varMapping["cache.status"] = "MISS"
if err == caches.ErrInvalidRange {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}

View File

@@ -57,7 +57,7 @@ func (this *HTTPRequest) writeCode(statusCode int, enMessage string, zhMessage s
return "${" + varName + "}"
})
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent))
@@ -110,7 +110,7 @@ func (this *HTTPRequest) write50x(err error, statusCode int, enMessage string, z
return "${" + varName + "}"
})
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent))

View File

@@ -197,7 +197,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
// 响应Header
this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
// 准备
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)

View File

@@ -54,7 +54,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
return false
}
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
return true
}
@@ -96,7 +96,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
}
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
return true
} else { // 精准匹配
@@ -119,7 +119,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
}
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
return true
}
@@ -155,7 +155,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
return false
}
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
// 参数
var qIndex = strings.Index(this.uri, "?")
@@ -211,7 +211,7 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
return false
}
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, status)
return true
}

View File

@@ -87,11 +87,11 @@ func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, sta
// 修改状态码
if page.NewStatus > 0 {
// 自定义响应Headers
this.processResponseHeaders(this.writer.Header(), page.NewStatus)
this.ProcessResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, stat.Size(), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, stat.Size(), status, true)
this.writer.WriteHeader(status)
}
@@ -126,11 +126,11 @@ func (this *HTTPRequest) doPageLookup(pages []*serverconfigs.HTTPPageConfig, sta
// 修改状态码
if page.NewStatus > 0 {
// 自定义响应Headers
this.processResponseHeaders(this.writer.Header(), page.NewStatus)
this.ProcessResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.processResponseHeaders(this.writer.Header(), status)
this.ProcessResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, int64(len(content)), status, true)
this.writer.WriteHeader(status)
}

View File

@@ -12,7 +12,7 @@ func (this *HTTPRequest) doPlanExpires() {
this.tags = append(this.tags, "plan")
var statusCode = http.StatusNotFound
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody))

View File

@@ -42,7 +42,7 @@ func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.
}
var newURL = "https://" + host + this.RawReq.RequestURI
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
http.Redirect(this.writer, this.RawReq, newURL, statusCode)
return true

View File

@@ -451,7 +451,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// 响应Header
this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
// 是否需要刷新
var shouldAutoFlush = this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream"

View File

@@ -30,10 +30,10 @@ func (this *HTTPRequest) doRewrite() (shouldShop bool) {
// 跳转
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeRedirect {
if this.rewriteRule.RedirectStatus > 0 {
this.processResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus)
this.ProcessResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus)
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, this.rewriteRule.RedirectStatus)
} else {
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, http.StatusTemporaryRedirect)
}
return true

View File

@@ -217,7 +217,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持 If-None-Match
if this.requestHeader("If-None-Match") == eTag {
// 自定义Header
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified)
return true
}
@@ -225,7 +225,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持 If-Modified-Since
if this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified)
return true
}
@@ -253,14 +253,14 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
var contentRange = this.RawReq.Header.Get("Range")
if len(contentRange) > 0 {
if fileSize == 0 {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
set, ok := httpRequestParseRangeHeader(contentRange)
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -269,7 +269,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
for k, r := range ranges {
r2, ok := r.Convert(fileSize)
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -290,7 +290,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
}
// 自定义Header
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
// 在Range请求中不能缓存
if len(ranges) > 0 {
@@ -325,7 +325,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true
}
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -377,7 +377,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true
}
if !ok {
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}

View File

@@ -28,10 +28,10 @@ func (this *HTTPRequest) doShutdown() {
if len(shutdown.URL) == 0 {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
_, err := this.writer.WriteString("The site have been shutdown.")
@@ -59,10 +59,10 @@ func (this *HTTPRequest) doShutdown() {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
buf := utils.BytePool1k.Get()
@@ -85,10 +85,10 @@ func (this *HTTPRequest) doShutdown() {
} else if shutdown.BodyType == shared.BodyTypeHTML {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.ProcessResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.ProcessResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}

View File

@@ -13,7 +13,7 @@ func (this *HTTPRequest) doTrafficLimit() {
this.tags = append(this.tags, "bandwidth")
var statusCode = 509
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
if len(config.NoticePageBody) != 0 {

View File

@@ -44,9 +44,9 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod
// Header
if statusCode <= 0 {
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
this.ProcessResponseHeaders(this.writer.Header(), resp.StatusCode)
} else {
this.processResponseHeaders(this.writer.Header(), statusCode)
this.ProcessResponseHeaders(this.writer.Header(), statusCode)
}
if supportVariables {

View File

@@ -137,7 +137,7 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
return
}
this.processResponseHeaders(resp.Header, resp.StatusCode)
this.ProcessResponseHeaders(resp.Header, resp.StatusCode)
this.writer.statusCode = resp.StatusCode
// 将响应写回客户端

View File

@@ -843,7 +843,7 @@ func (this *HTTPWriter) WriteHeader(statusCode int) {
// Send 直接发送内容,并终止请求
func (this *HTTPWriter) Send(status int, body string) {
this.req.processResponseHeaders(this.Header(), status)
this.req.ProcessResponseHeaders(this.Header(), status)
// content-length
_, hasContentLength := this.Header()["Content-Length"]

View File

@@ -82,8 +82,10 @@ func (this *BlockAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, reque
// output response
if this.StatusCode > 0 {
request.ProcessResponseHeaders(writer.Header(), this.StatusCode)
writer.WriteHeader(this.StatusCode)
} else {
request.ProcessResponseHeaders(writer.Header(), http.StatusForbidden)
writer.WriteHeader(http.StatusForbidden)
}
if len(this.URL) > 0 {

View File

@@ -36,6 +36,7 @@ func (this *PageAction) WillChange() bool {
// Perform the action
func (this *PageAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
request.ProcessResponseHeaders(writer.Header(), this.Status)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(this.Status)
_, _ = writer.Write([]byte(request.Format(this.Body)))

View File

@@ -146,6 +146,7 @@ func (this *RecordIPAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, re
var expiresAt = time.Now().Unix() + int64(timeout)
if this.Type == "black" {
request.ProcessResponseHeaders(writer.Header(), http.StatusForbidden)
writer.WriteHeader(http.StatusForbidden)
request.WAFClose()

View File

@@ -36,6 +36,7 @@ func (this *RedirectAction) WillChange() bool {
// Perform the action
func (this *RedirectAction) Perform(waf *WAF, group *RuleGroup, set *RuleSet, request requests.Request, writer http.ResponseWriter) (continueRequest bool, goNextSet bool) {
request.ProcessResponseHeaders(writer.Header(), this.Status)
writer.Header().Set("Location", this.URL)
writer.WriteHeader(this.Status)

View File

@@ -26,6 +26,7 @@ func NewCaptchaValidator() *CaptchaValidator {
func (this *CaptchaValidator) Run(req requests.Request, writer http.ResponseWriter) {
var info = req.WAFRaw().URL.Query().Get("info")
if len(info) == 0 {
req.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest)
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid request"))
return
@@ -183,8 +184,7 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, req requests.Req
}
}
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
_, _ = writer.Write([]byte(`<!DOCTYPE html>
var msgHTML = `<!DOCTYPE html>
<html>
<head>
<title>` + msgTitle + `</title>
@@ -206,7 +206,13 @@ func (this *CaptchaValidator) show(actionConfig *CaptchaAction, req requests.Req
</head>
<body>` + body + `
</body>
</html>`))
</html>`
req.ProcessResponseHeaders(writer.Header(), http.StatusOK)
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.Header().Set("Content-Length", types.String(len(msgHTML)))
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write([]byte(msgHTML))
}
func (this *CaptchaValidator) validate(actionConfig *CaptchaAction, policyId int64, groupId int64, setId int64, originURL string, req requests.Request, writer http.ResponseWriter) (allow bool) {

View File

@@ -22,6 +22,7 @@ func NewGet302Validator() *Get302Validator {
func (this *Get302Validator) Run(request requests.Request, writer http.ResponseWriter) {
var info = request.WAFRaw().URL.Query().Get("info")
if len(info) == 0 {
request.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest)
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid request"))
return
@@ -34,6 +35,7 @@ func (this *Get302Validator) Run(request requests.Request, writer http.ResponseW
var timestamp = m.GetInt64("timestamp")
if time.Now().Unix()-timestamp > 5 { // 超过5秒认为失效
request.ProcessResponseHeaders(writer.Header(), http.StatusBadRequest)
writer.WriteHeader(http.StatusBadRequest)
_, _ = writer.Write([]byte("invalid request"))
return

View File

@@ -38,6 +38,9 @@ type Request interface {
// Format 格式化变量
Format(string) string
// ProcessResponseHeaders 处理响应Header
ProcessResponseHeaders(headers http.Header, status int)
// DisableAccessLog 在当前请求中不使用访问日志
DisableAccessLog()
}