Websocket支持自定义响应Header

This commit is contained in:
GoEdgeLab
2022-09-23 14:21:53 +08:00
parent efe6cbc881
commit 579d5ab3e1
15 changed files with 71 additions and 45 deletions

View File

@@ -1611,9 +1611,7 @@ func (this *HTTPRequest) fixRequestHeader(header http.Header) {
} }
// 处理自定义Response Header // 处理自定义Response Header
func (this *HTTPRequest) processResponseHeaders(statusCode int) { func (this *HTTPRequest) processResponseHeaders(responseHeader http.Header, statusCode int) {
var responseHeader = this.writer.Header()
// 删除/添加/替换Header // 删除/添加/替换Header
// TODO 实现AddTrailers // TODO 实现AddTrailers
if this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn { if this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn {

View File

@@ -372,7 +372,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 支持 If-None-Match // 支持 If-None-Match
if !this.isLnRequest && !isPartialCache && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag { if !this.isLnRequest && !isPartialCache && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag {
// 自定义Header // 自定义Header
this.processResponseHeaders(http.StatusNotModified) this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt()) this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified) this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true this.isCached = true
@@ -384,7 +384,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
// 支持 If-Modified-Since // 支持 If-Modified-Since
if !this.isLnRequest && !isPartialCache && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime { if !this.isLnRequest && !isPartialCache && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header // 自定义Header
this.processResponseHeaders(http.StatusNotModified) this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.addExpiresHeader(reader.ExpiresAt()) this.addExpiresHeader(reader.ExpiresAt())
this.writer.WriteHeader(http.StatusNotModified) this.writer.WriteHeader(http.StatusNotModified)
this.isCached = true this.isCached = true
@@ -393,7 +393,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
return true return true
} }
this.processResponseHeaders(reader.Status()) this.processResponseHeaders(this.writer.Header(), reader.Status())
this.addExpiresHeader(reader.ExpiresAt()) this.addExpiresHeader(reader.ExpiresAt())
// 返回上级节点过期时间 // 返回上级节点过期时间
@@ -422,7 +422,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if supportRange { if supportRange {
if len(rangeHeader) > 0 { if len(rangeHeader) > 0 {
if fileSize == 0 { if fileSize == 0 {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true return true
} }
@@ -430,7 +430,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
if len(ranges) == 0 { if len(ranges) == 0 {
ranges, ok = httpRequestParseRangeHeader(rangeHeader) ranges, ok = httpRequestParseRangeHeader(rangeHeader)
if !ok { if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true return true
} }
@@ -439,7 +439,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
for k, r := range ranges { for k, r := range ranges {
r2, ok := r.Convert(fileSize) r2, ok := r.Convert(fileSize)
if !ok { if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true return true
} }
@@ -466,7 +466,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) {
this.varMapping["cache.status"] = "MISS" this.varMapping["cache.status"] = "MISS"
if err == caches.ErrInvalidRange { if err == caches.ErrInvalidRange {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true return true
} }

View File

@@ -57,7 +57,7 @@ func (this *HTTPRequest) writeCode(statusCode int, enMessage string, zhMessage s
return "${" + varName + "}" return "${" + varName + "}"
}) })
this.processResponseHeaders(statusCode) this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode) this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent)) _, _ = this.writer.Write([]byte(pageContent))
@@ -110,7 +110,7 @@ func (this *HTTPRequest) write50x(err error, statusCode int, enMessage string, z
return "${" + varName + "}" return "${" + varName + "}"
}) })
this.processResponseHeaders(statusCode) this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode) this.writer.WriteHeader(statusCode)
_, _ = this.writer.Write([]byte(pageContent)) _, _ = this.writer.Write([]byte(pageContent))

View File

@@ -187,7 +187,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) {
// 响应Header // 响应Header
this.writer.AddHeaders(resp.Header) this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(resp.StatusCode) this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
// 准备 // 准备
this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true) this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true)

View File

@@ -34,10 +34,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
} }
if u.Status <= 0 { if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect) this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect) http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else { } else {
this.processResponseHeaders(u.Status) this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status) http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
} }
return true return true
@@ -81,10 +81,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
} }
if u.Status <= 0 { if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect) this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect) http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else { } else {
this.processResponseHeaders(u.Status) this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status) http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
} }
return true return true
@@ -104,10 +104,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) {
} }
if u.Status <= 0 { if u.Status <= 0 {
this.processResponseHeaders(http.StatusTemporaryRedirect) this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect) http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect)
} else { } else {
this.processResponseHeaders(u.Status) this.processResponseHeaders(this.writer.Header(), u.Status)
http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status) http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status)
} }
return true return true

View File

@@ -60,11 +60,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
// 修改状态码 // 修改状态码
if page.NewStatus > 0 { if page.NewStatus > 0 {
// 自定义响应Headers // 自定义响应Headers
this.processResponseHeaders(page.NewStatus) this.processResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, stat.Size(), page.NewStatus, true) this.writer.Prepare(nil, stat.Size(), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus) this.writer.WriteHeader(page.NewStatus)
} else { } else {
this.processResponseHeaders(status) this.processResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, stat.Size(), status, true) this.writer.Prepare(nil, stat.Size(), status, true)
this.writer.WriteHeader(status) this.writer.WriteHeader(status)
} }
@@ -99,11 +99,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) {
// 修改状态码 // 修改状态码
if page.NewStatus > 0 { if page.NewStatus > 0 {
// 自定义响应Headers // 自定义响应Headers
this.processResponseHeaders(page.NewStatus) this.processResponseHeaders(this.writer.Header(), page.NewStatus)
this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true) this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true)
this.writer.WriteHeader(page.NewStatus) this.writer.WriteHeader(page.NewStatus)
} else { } else {
this.processResponseHeaders(status) this.processResponseHeaders(this.writer.Header(), status)
this.writer.Prepare(nil, int64(len(content)), status, true) this.writer.Prepare(nil, int64(len(content)), status, true)
this.writer.WriteHeader(status) this.writer.WriteHeader(status)
} }

View File

@@ -12,7 +12,7 @@ func (this *HTTPRequest) doPlanExpires() {
this.tags = append(this.tags, "plan") this.tags = append(this.tags, "plan")
var statusCode = http.StatusNotFound var statusCode = http.StatusNotFound
this.processResponseHeaders(statusCode) this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode) this.writer.WriteHeader(statusCode)
_, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody)) _, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody))

View File

@@ -42,7 +42,7 @@ func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs.
} }
newURL := "https://" + host + this.RawReq.RequestURI newURL := "https://" + host + this.RawReq.RequestURI
this.processResponseHeaders(statusCode) this.processResponseHeaders(this.writer.Header(), statusCode)
http.Redirect(this.writer, this.RawReq, newURL, statusCode) http.Redirect(this.writer, this.RawReq, newURL, statusCode)
return true return true

View File

@@ -397,7 +397,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
// 响应Header // 响应Header
this.writer.AddHeaders(resp.Header) this.writer.AddHeaders(resp.Header)
this.processResponseHeaders(resp.StatusCode) this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
// 是否需要刷新 // 是否需要刷新
var shouldAutoFlush = this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream" var shouldAutoFlush = this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream"

View File

@@ -30,10 +30,10 @@ func (this *HTTPRequest) doRewrite() (shouldShop bool) {
// 跳转 // 跳转
if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeRedirect { if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeRedirect {
if this.rewriteRule.RedirectStatus > 0 { if this.rewriteRule.RedirectStatus > 0 {
this.processResponseHeaders(this.rewriteRule.RedirectStatus) this.processResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus)
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, this.rewriteRule.RedirectStatus) http.Redirect(this.writer, this.RawReq, this.rewriteReplace, this.rewriteRule.RedirectStatus)
} else { } else {
this.processResponseHeaders(http.StatusTemporaryRedirect) this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect)
http.Redirect(this.writer, this.RawReq, this.rewriteReplace, http.StatusTemporaryRedirect) http.Redirect(this.writer, this.RawReq, this.rewriteReplace, http.StatusTemporaryRedirect)
} }
return true return true

View File

@@ -217,7 +217,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持 If-None-Match // 支持 If-None-Match
if this.requestHeader("If-None-Match") == eTag { if this.requestHeader("If-None-Match") == eTag {
// 自定义Header // 自定义Header
this.processResponseHeaders(http.StatusNotModified) this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified) this.writer.WriteHeader(http.StatusNotModified)
return true return true
} }
@@ -225,7 +225,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
// 支持 If-Modified-Since // 支持 If-Modified-Since
if this.requestHeader("If-Modified-Since") == modifiedTime { if this.requestHeader("If-Modified-Since") == modifiedTime {
// 自定义Header // 自定义Header
this.processResponseHeaders(http.StatusNotModified) this.processResponseHeaders(this.writer.Header(), http.StatusNotModified)
this.writer.WriteHeader(http.StatusNotModified) this.writer.WriteHeader(http.StatusNotModified)
return true return true
} }
@@ -253,14 +253,14 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
var contentRange = this.RawReq.Header.Get("Range") var contentRange = this.RawReq.Header.Get("Range")
if len(contentRange) > 0 { if len(contentRange) > 0 {
if fileSize == 0 { if fileSize == 0 {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true return true
} }
set, ok := httpRequestParseRangeHeader(contentRange) set, ok := httpRequestParseRangeHeader(contentRange)
if !ok { if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true return true
} }
@@ -269,7 +269,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
for k, r := range ranges { for k, r := range ranges {
r2, ok := r.Convert(fileSize) r2, ok := r.Convert(fileSize)
if !ok { if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true return true
} }
@@ -290,7 +290,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
} }
// 自定义Header // 自定义Header
this.processResponseHeaders(http.StatusOK) this.processResponseHeaders(this.writer.Header(), http.StatusOK)
// 在Range请求中不能缓存 // 在Range请求中不能缓存
if len(ranges) > 0 { if len(ranges) > 0 {
@@ -325,7 +325,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true return true
} }
if !ok { if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true return true
} }
@@ -377,7 +377,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) {
return true return true
} }
if !ok { if !ok {
this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable)
this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable)
return true return true
} }

View File

@@ -28,10 +28,10 @@ func (this *HTTPRequest) doShutdown() {
if len(shutdown.URL) == 0 { if len(shutdown.URL) == 0 {
// 自定义响应Headers // 自定义响应Headers
if shutdown.Status > 0 { if shutdown.Status > 0 {
this.processResponseHeaders(shutdown.Status) this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status) this.writer.WriteHeader(shutdown.Status)
} else { } else {
this.processResponseHeaders(http.StatusOK) this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK) this.writer.WriteHeader(http.StatusOK)
} }
_, err := this.writer.WriteString("The site have been shutdown.") _, err := this.writer.WriteString("The site have been shutdown.")
@@ -59,10 +59,10 @@ func (this *HTTPRequest) doShutdown() {
// 自定义响应Headers // 自定义响应Headers
if shutdown.Status > 0 { if shutdown.Status > 0 {
this.processResponseHeaders(shutdown.Status) this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status) this.writer.WriteHeader(shutdown.Status)
} else { } else {
this.processResponseHeaders(http.StatusOK) this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK) this.writer.WriteHeader(http.StatusOK)
} }
buf := utils.BytePool1k.Get() buf := utils.BytePool1k.Get()
@@ -85,10 +85,10 @@ func (this *HTTPRequest) doShutdown() {
} else if shutdown.BodyType == shared.BodyTypeHTML { } else if shutdown.BodyType == shared.BodyTypeHTML {
// 自定义响应Headers // 自定义响应Headers
if shutdown.Status > 0 { if shutdown.Status > 0 {
this.processResponseHeaders(shutdown.Status) this.processResponseHeaders(this.writer.Header(), shutdown.Status)
this.writer.WriteHeader(shutdown.Status) this.writer.WriteHeader(shutdown.Status)
} else { } else {
this.processResponseHeaders(http.StatusOK) this.processResponseHeaders(this.writer.Header(), http.StatusOK)
this.writer.WriteHeader(http.StatusOK) this.writer.WriteHeader(http.StatusOK)
} }

View File

@@ -13,7 +13,7 @@ func (this *HTTPRequest) doTrafficLimit() {
this.tags = append(this.tags, "bandwidth") this.tags = append(this.tags, "bandwidth")
var statusCode = 509 var statusCode = 509
this.processResponseHeaders(statusCode) this.processResponseHeaders(this.writer.Header(), statusCode)
this.writer.WriteHeader(statusCode) this.writer.WriteHeader(statusCode)
if len(config.NoticePageBody) != 0 { if len(config.NoticePageBody) != 0 {

View File

@@ -44,9 +44,9 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod
// Header // Header
if statusCode <= 0 { if statusCode <= 0 {
this.processResponseHeaders(resp.StatusCode) this.processResponseHeaders(this.writer.Header(), resp.StatusCode)
} else { } else {
this.processResponseHeaders(statusCode) this.processResponseHeaders(this.writer.Header(), statusCode)
} }
if supportVariables { if supportVariables {

View File

@@ -1,6 +1,7 @@
package nodes package nodes
import ( import (
"bufio"
"errors" "errors"
"github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils"
"io" "io"
@@ -82,6 +83,33 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou
}() }()
go func() { go func() {
// 读取第一个响应
resp, err := http.ReadResponse(bufio.NewReader(originConn), this.RawReq)
if err != nil {
_ = clientConn.Close()
_ = originConn.Close()
return
}
this.processResponseHeaders(resp.Header, resp.StatusCode)
// 将响应写回客户端
err = resp.Write(clientConn)
if err != nil {
if resp.Body != nil {
_ = resp.Body.Close()
}
_ = clientConn.Close()
_ = originConn.Close()
return
}
if resp.Body != nil {
_ = resp.Body.Close()
}
// 复制剩余的数据
var buf = utils.BytePool4k.Get() var buf = utils.BytePool4k.Get()
defer utils.BytePool4k.Put(buf) defer utils.BytePool4k.Put(buf)
for { for {