Websocket支持自定义响应Header

This commit is contained in:
GoEdgeLab
2022-09-23 14:21:53 +08:00
parent efe6cbc881
commit 579d5ab3e1
15 changed files with 71 additions and 45 deletions

View File

@@ -1611,9 +1611,7 @@ func (this *HTTPRequest) fixRequestHeader(header http.Header) {
}
// 处理自定义Response Header
func (this *HTTPRequest) processResponseHeaders(statusCode int) {
var responseHeader = this.writer.Header()
func (this *HTTPRequest) processResponseHeaders(responseHeader http.Header, statusCode int) {
// 删除/添加/替换Header
// TODO 实现AddTrailers
if this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn {

View File

@@ -372,7 +372,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 支持 If-None-Match
if !this.isLnRequest && !isPartialCache && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag {
// 自定义Header
this.processResponseHeaders(http.StatusNotModified)
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true
@@ -384,7 +384,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 支持 If-Modified-Since
if !this.isLnRequest && !isPartialCache && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header
this.processResponseHeaders(http.StatusNotModified)
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true
@@ -393,7 +393,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
return true
}
this.processResponseHeaders(reader.Status())
this.processResponseHeaders(this.writer.Header(), reader.Status())
this.addExpiresHeader(reader.ExpiresAt())
// 返回上级节点过期时间
@@ -422,7 +422,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if supportRange {
if len(rangeHeader) > 0 {
if fileSize == 0 {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -430,7 +430,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if len(ranges) == 0 {
ranges, ok = httpRequestParseRangeHeader(rangeHeader)
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -439,7 +439,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
for k, r := range ranges {
r2, ok := r.Convert(fileSize)
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -466,7 +466,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.varMapping["cache.status"] = "MISS"
if err == caches.ErrInvalidRange {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}

View File

@@ -57,7 +57,7 @@ func (this *HTTPRequest) writeCode(statusCode int, enMessage string, zhMessage s
return "${" + varName + "}"
})
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent))
@@ -110,7 +110,7 @@ func (this *HTTPRequest) write50x(err error, statusCode int, enMessage string, z
return "${" + varName + "}"
})
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent))

View File

@@ -187,7 +187,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
// 响应Header
this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(resp.StatusCode)
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
// 准备
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)

View File

@@ -34,10 +34,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect)
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(u.Status)
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true
@@ -81,10 +81,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect)
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(u.Status)
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true
@@ -104,10 +104,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
}
if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect)
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else {
this.processResponseHeaders(u.Status)
this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
}
return true

View File

@@ -60,11 +60,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
// 修改状态码
if page.NewStatus > 0 {
// 自定义响应Headers
this.processResponseHeaders(page.NewStatus)
this.processResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, stat.Size(), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.processResponseHeaders(status)
this.processResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, stat.Size(), status, true)
this.writer.WriteHeader(status)
}
@@ -99,11 +99,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
// 修改状态码
if page.NewStatus > 0 {
// 自定义响应Headers
this.processResponseHeaders(page.NewStatus)
this.processResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus)
} else {
this.processResponseHeaders(status)
this.processResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, int64(len(content)), status, true)
this.writer.WriteHeader(status)
}

View File

@@ -12,7 +12,7 @@ func (this *HTTPRequest) doPlanExpires() {
this.tags = append(this.tags, "plan")
var statusCode = http.StatusNotFound
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody))

View File

@@ -42,7 +42,7 @@ func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.
}
newURL := "https://" + host + this.RawReq.RequestURI
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
http.Redirect(this.writer, this.RawReq, newURL, statusCode)
return true

View File

@@ -397,7 +397,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// 响应Header
this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(resp.StatusCode)
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
// 是否需要刷新
var shouldAutoFlush = this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream"

View File

@@ -30,10 +30,10 @@ func (this *HTTPRequest) doRewrite() (shouldShop bool) {
// 跳转
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeRedirect {
if this.rewriteRule.RedirectStatus > 0 {
this.processResponseHeaders(this.rewriteRule.RedirectStatus)
this.processResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus)
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, this.rewriteRule.RedirectStatus)
} else {
this.processResponseHeaders(http.StatusTemporaryRedirect)
this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, http.StatusTemporaryRedirect)
}
return true

View File

@@ -217,7 +217,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持 If-None-Match
if this.requestHeader("If-None-Match") == eTag {
// 自定义Header
this.processResponseHeaders(http.StatusNotModified)
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified)
return true
}
@@ -225,7 +225,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持 If-Modified-Since
if this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header
this.processResponseHeaders(http.StatusNotModified)
this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified)
return true
}
@@ -253,14 +253,14 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
var contentRange = this.RawReq.Header.Get("Range")
if len(contentRange) > 0 {
if fileSize == 0 {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
set, ok := httpRequestParseRangeHeader(contentRange)
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -269,7 +269,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
for k, r := range ranges {
r2, ok := r.Convert(fileSize)
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -290,7 +290,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
}
// 自定义Header
this.processResponseHeaders(http.StatusOK)
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
// 在Range请求中不能缓存
if len(ranges) > 0 {
@@ -325,7 +325,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true
}
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}
@@ -377,7 +377,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true
}
if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable)
this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true
}

View File

@@ -28,10 +28,10 @@ func (this *HTTPRequest) doShutdown() {
if len(shutdown.URL) == 0 {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(shutdown.Status)
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(http.StatusOK)
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
_, err := this.writer.WriteString("The site have been shutdown.")
@@ -59,10 +59,10 @@ func (this *HTTPRequest) doShutdown() {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(shutdown.Status)
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(http.StatusOK)
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}
buf := utils.BytePool1k.Get()
@@ -85,10 +85,10 @@ func (this *HTTPRequest) doShutdown() {
} else if shutdown.BodyType == shared.BodyTypeHTML {
// 自定义响应Headers
if shutdown.Status > 0 {
this.processResponseHeaders(shutdown.Status)
this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status)
} else {
this.processResponseHeaders(http.StatusOK)
this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK)
}

View File

@@ -13,7 +13,7 @@ func (this *HTTPRequest) doTrafficLimit() {
this.tags = append(this.tags, "bandwidth")
var statusCode = 509
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode)
if len(config.NoticePageBody) != 0 {

View File

@@ -44,9 +44,9 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod
// Header
if statusCode <= 0 {
this.processResponseHeaders(resp.StatusCode)
this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
} else {
this.processResponseHeaders(statusCode)
this.processResponseHeaders(this.writer.Header(), statusCode)
}
if supportVariables {

View File

@@ -1,6 +1,7 @@
package nodes
import (
"bufio"
"errors"
"github.com/TeaOSLab/EdgeNode/internal/utils"
"io"
@@ -82,6 +83,33 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
}()
go func() {
// 读取第一个响应
resp, err := http.ReadResponse(bufio.NewReader(originConn), this.RawReq)
if err != nil {
_ = clientConn.Close()
_ = originConn.Close()
return
}
this.processResponseHeaders(resp.Header, resp.StatusCode)
// 将响应写回客户端
err = resp.Write(clientConn)
if err != nil {
if resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
if resp.Body != nil {
_ = resp.Body.Close()
}
// 复制剩余的数据
var buf = utils.BytePool4k.Get()
defer utils.BytePool4k.Put(buf)
for {