diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index 9094601..be51dd2 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -1611,9 +1611,7 @@ func (this *HTTPRequest) fixRequestHeader(header http.Header) { } // 处理自定义Response Header -func (this *HTTPRequest) processResponseHeaders(statusCode int) { - var responseHeader = this.writer.Header() - +func (this *HTTPRequest) processResponseHeaders(responseHeader http.Header, statusCode int) { // 删除/添加/替换Header // TODO 实现AddTrailers if this.web.ResponseHeaderPolicy != nil && this.web.ResponseHeaderPolicy.IsOn { diff --git a/internal/nodes/http_request_cache.go b/internal/nodes/http_request_cache.go index 04ac34f..fb995a0 100644 --- a/internal/nodes/http_request_cache.go +++ b/internal/nodes/http_request_cache.go @@ -372,7 +372,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { // 支持 If-None-Match if !this.isLnRequest && !isPartialCache && len(eTag) > 0 && this.requestHeader("If-None-Match") == eTag { // 自定义Header - this.processResponseHeaders(http.StatusNotModified) + this.processResponseHeaders(this.writer.Header(), http.StatusNotModified) this.addExpiresHeader(reader.ExpiresAt()) this.writer.WriteHeader(http.StatusNotModified) this.isCached = true @@ -384,7 +384,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { // 支持 If-Modified-Since if !this.isLnRequest && !isPartialCache && len(modifiedTime) > 0 && this.requestHeader("If-Modified-Since") == modifiedTime { // 自定义Header - this.processResponseHeaders(http.StatusNotModified) + this.processResponseHeaders(this.writer.Header(), http.StatusNotModified) this.addExpiresHeader(reader.ExpiresAt()) this.writer.WriteHeader(http.StatusNotModified) this.isCached = true @@ -393,7 +393,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { return true } - this.processResponseHeaders(reader.Status()) + this.processResponseHeaders(this.writer.Header(), reader.Status()) this.addExpiresHeader(reader.ExpiresAt()) // 返回上级节点过期时间 @@ -422,7 +422,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { if supportRange { if len(rangeHeader) > 0 { if fileSize == 0 { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) + this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } @@ -430,7 +430,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { if len(ranges) == 0 { ranges, ok = httpRequestParseRangeHeader(rangeHeader) if !ok { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) + this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } @@ -439,7 +439,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { for k, r := range ranges { r2, ok := r.Convert(fileSize) if !ok { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) + this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } @@ -466,7 +466,7 @@ func (this *HTTPRequest) doCacheRead(useStale bool) (shouldStop bool) { this.varMapping["cache.status"] = "MISS" if err == caches.ErrInvalidRange { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) + this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } diff --git a/internal/nodes/http_request_error.go b/internal/nodes/http_request_error.go index d32cec2..e487ada 100644 --- a/internal/nodes/http_request_error.go +++ b/internal/nodes/http_request_error.go @@ -57,7 +57,7 @@ func (this *HTTPRequest) writeCode(statusCode int, enMessage string, zhMessage s return "${" + varName + "}" }) - this.processResponseHeaders(statusCode) + this.processResponseHeaders(this.writer.Header(), statusCode) this.writer.WriteHeader(statusCode) _, _ = this.writer.Write([]byte(pageContent)) @@ -110,7 +110,7 @@ func (this *HTTPRequest) write50x(err error, statusCode int, enMessage string, z return "${" + varName + "}" }) - this.processResponseHeaders(statusCode) + this.processResponseHeaders(this.writer.Header(), statusCode) this.writer.WriteHeader(statusCode) _, _ = this.writer.Write([]byte(pageContent)) diff --git a/internal/nodes/http_request_fastcgi.go b/internal/nodes/http_request_fastcgi.go index c85f4eb..520af44 100644 --- a/internal/nodes/http_request_fastcgi.go +++ b/internal/nodes/http_request_fastcgi.go @@ -187,7 +187,7 @@ func (this *HTTPRequest) doFastcgi() (shouldStop bool) { // 响应Header this.writer.AddHeaders(resp.Header) - this.processResponseHeaders(resp.StatusCode) + this.processResponseHeaders(this.writer.Header(), resp.StatusCode) // 准备 this.writer.Prepare(resp, resp.ContentLength, resp.StatusCode, true) diff --git a/internal/nodes/http_request_host_redirect.go b/internal/nodes/http_request_host_redirect.go index 8cec3e9..781d264 100644 --- a/internal/nodes/http_request_host_redirect.go +++ b/internal/nodes/http_request_host_redirect.go @@ -34,10 +34,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) { } if u.Status <= 0 { - this.processResponseHeaders(http.StatusTemporaryRedirect) + this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect) http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect) } else { - this.processResponseHeaders(u.Status) + this.processResponseHeaders(this.writer.Header(), u.Status) http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status) } return true @@ -81,10 +81,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) { } if u.Status <= 0 { - this.processResponseHeaders(http.StatusTemporaryRedirect) + this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect) http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect) } else { - this.processResponseHeaders(u.Status) + this.processResponseHeaders(this.writer.Header(), u.Status) http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status) } return true @@ -104,10 +104,10 @@ func (this *HTTPRequest) doHostRedirect() (blocked bool) { } if u.Status <= 0 { - this.processResponseHeaders(http.StatusTemporaryRedirect) + this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect) http.Redirect(this.RawWriter, this.RawReq, afterURL, http.StatusTemporaryRedirect) } else { - this.processResponseHeaders(u.Status) + this.processResponseHeaders(this.writer.Header(), u.Status) http.Redirect(this.RawWriter, this.RawReq, afterURL, u.Status) } return true diff --git a/internal/nodes/http_request_page.go b/internal/nodes/http_request_page.go index 6380788..6a68abd 100644 --- a/internal/nodes/http_request_page.go +++ b/internal/nodes/http_request_page.go @@ -60,11 +60,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) { // 修改状态码 if page.NewStatus > 0 { // 自定义响应Headers - this.processResponseHeaders(page.NewStatus) + this.processResponseHeaders(this.writer.Header(), page.NewStatus) this.writer.Prepare(nil, stat.Size(), page.NewStatus, true) this.writer.WriteHeader(page.NewStatus) } else { - this.processResponseHeaders(status) + this.processResponseHeaders(this.writer.Header(), status) this.writer.Prepare(nil, stat.Size(), status, true) this.writer.WriteHeader(status) } @@ -99,11 +99,11 @@ func (this *HTTPRequest) doPage(status int) (shouldStop bool) { // 修改状态码 if page.NewStatus > 0 { // 自定义响应Headers - this.processResponseHeaders(page.NewStatus) + this.processResponseHeaders(this.writer.Header(), page.NewStatus) this.writer.Prepare(nil, int64(len(content)), page.NewStatus, true) this.writer.WriteHeader(page.NewStatus) } else { - this.processResponseHeaders(status) + this.processResponseHeaders(this.writer.Header(), status) this.writer.Prepare(nil, int64(len(content)), status, true) this.writer.WriteHeader(status) } diff --git a/internal/nodes/http_request_plan_expires.go b/internal/nodes/http_request_plan_expires.go index 2adae5d..13c77f8 100644 --- a/internal/nodes/http_request_plan_expires.go +++ b/internal/nodes/http_request_plan_expires.go @@ -12,7 +12,7 @@ func (this *HTTPRequest) doPlanExpires() { this.tags = append(this.tags, "plan") var statusCode = http.StatusNotFound - this.processResponseHeaders(statusCode) + this.processResponseHeaders(this.writer.Header(), statusCode) this.writer.WriteHeader(statusCode) _, _ = this.writer.WriteString(this.Format(serverconfigs.DefaultPlanExpireNoticePageBody)) diff --git a/internal/nodes/http_request_redirect_https.go b/internal/nodes/http_request_redirect_https.go index d3da6b4..3415d85 100644 --- a/internal/nodes/http_request_redirect_https.go +++ b/internal/nodes/http_request_redirect_https.go @@ -42,7 +42,7 @@ func (this *HTTPRequest) doRedirectToHTTPS(redirectToHTTPSConfig *serverconfigs. } newURL := "https://" + host + this.RawReq.RequestURI - this.processResponseHeaders(statusCode) + this.processResponseHeaders(this.writer.Header(), statusCode) http.Redirect(this.writer, this.RawReq, newURL, statusCode) return true diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index 85b4c1c..30f6e3b 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -397,7 +397,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId // 响应Header this.writer.AddHeaders(resp.Header) - this.processResponseHeaders(resp.StatusCode) + this.processResponseHeaders(this.writer.Header(), resp.StatusCode) // 是否需要刷新 var shouldAutoFlush = this.reverseProxy.AutoFlush || this.RawReq.Header.Get("Accept") == "text/event-stream" diff --git a/internal/nodes/http_request_rewrite.go b/internal/nodes/http_request_rewrite.go index 0d1e0b2..a51a2a1 100644 --- a/internal/nodes/http_request_rewrite.go +++ b/internal/nodes/http_request_rewrite.go @@ -30,10 +30,10 @@ func (this *HTTPRequest) doRewrite() (shouldShop bool) { // 跳转 if this.rewriteRule.Mode == serverconfigs.HTTPRewriteModeRedirect { if this.rewriteRule.RedirectStatus > 0 { - this.processResponseHeaders(this.rewriteRule.RedirectStatus) + this.processResponseHeaders(this.writer.Header(), this.rewriteRule.RedirectStatus) http.Redirect(this.writer, this.RawReq, this.rewriteReplace, this.rewriteRule.RedirectStatus) } else { - this.processResponseHeaders(http.StatusTemporaryRedirect) + this.processResponseHeaders(this.writer.Header(), http.StatusTemporaryRedirect) http.Redirect(this.writer, this.RawReq, this.rewriteReplace, http.StatusTemporaryRedirect) } return true diff --git a/internal/nodes/http_request_root.go b/internal/nodes/http_request_root.go index 66f5eb6..2a39524 100644 --- a/internal/nodes/http_request_root.go +++ b/internal/nodes/http_request_root.go @@ -217,7 +217,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { // 支持 If-None-Match if this.requestHeader("If-None-Match") == eTag { // 自定义Header - this.processResponseHeaders(http.StatusNotModified) + this.processResponseHeaders(this.writer.Header(), http.StatusNotModified) this.writer.WriteHeader(http.StatusNotModified) return true } @@ -225,7 +225,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { // 支持 If-Modified-Since if this.requestHeader("If-Modified-Since") == modifiedTime { // 自定义Header - this.processResponseHeaders(http.StatusNotModified) + this.processResponseHeaders(this.writer.Header(), http.StatusNotModified) this.writer.WriteHeader(http.StatusNotModified) return true } @@ -253,14 +253,14 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { var contentRange = this.RawReq.Header.Get("Range") if len(contentRange) > 0 { if fileSize == 0 { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) + this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } set, ok := httpRequestParseRangeHeader(contentRange) if !ok { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) + this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } @@ -269,7 +269,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { for k, r := range ranges { r2, ok := r.Convert(fileSize) if !ok { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) + this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } @@ -290,7 +290,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { } // 自定义Header - this.processResponseHeaders(http.StatusOK) + this.processResponseHeaders(this.writer.Header(), http.StatusOK) // 在Range请求中不能缓存 if len(ranges) > 0 { @@ -325,7 +325,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { return true } if !ok { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) + this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } @@ -377,7 +377,7 @@ func (this *HTTPRequest) doRoot() (isBreak bool) { return true } if !ok { - this.processResponseHeaders(http.StatusRequestedRangeNotSatisfiable) + this.processResponseHeaders(this.writer.Header(), http.StatusRequestedRangeNotSatisfiable) this.writer.WriteHeader(http.StatusRequestedRangeNotSatisfiable) return true } diff --git a/internal/nodes/http_request_shutdown.go b/internal/nodes/http_request_shutdown.go index 6ad1820..9ba5416 100644 --- a/internal/nodes/http_request_shutdown.go +++ b/internal/nodes/http_request_shutdown.go @@ -28,10 +28,10 @@ func (this *HTTPRequest) doShutdown() { if len(shutdown.URL) == 0 { // 自定义响应Headers if shutdown.Status > 0 { - this.processResponseHeaders(shutdown.Status) + this.processResponseHeaders(this.writer.Header(), shutdown.Status) this.writer.WriteHeader(shutdown.Status) } else { - this.processResponseHeaders(http.StatusOK) + this.processResponseHeaders(this.writer.Header(), http.StatusOK) this.writer.WriteHeader(http.StatusOK) } _, err := this.writer.WriteString("The site have been shutdown.") @@ -59,10 +59,10 @@ func (this *HTTPRequest) doShutdown() { // 自定义响应Headers if shutdown.Status > 0 { - this.processResponseHeaders(shutdown.Status) + this.processResponseHeaders(this.writer.Header(), shutdown.Status) this.writer.WriteHeader(shutdown.Status) } else { - this.processResponseHeaders(http.StatusOK) + this.processResponseHeaders(this.writer.Header(), http.StatusOK) this.writer.WriteHeader(http.StatusOK) } buf := utils.BytePool1k.Get() @@ -85,10 +85,10 @@ func (this *HTTPRequest) doShutdown() { } else if shutdown.BodyType == shared.BodyTypeHTML { // 自定义响应Headers if shutdown.Status > 0 { - this.processResponseHeaders(shutdown.Status) + this.processResponseHeaders(this.writer.Header(), shutdown.Status) this.writer.WriteHeader(shutdown.Status) } else { - this.processResponseHeaders(http.StatusOK) + this.processResponseHeaders(this.writer.Header(), http.StatusOK) this.writer.WriteHeader(http.StatusOK) } diff --git a/internal/nodes/http_request_traffic_limit.go b/internal/nodes/http_request_traffic_limit.go index 362f7e7..66d4837 100644 --- a/internal/nodes/http_request_traffic_limit.go +++ b/internal/nodes/http_request_traffic_limit.go @@ -13,7 +13,7 @@ func (this *HTTPRequest) doTrafficLimit() { this.tags = append(this.tags, "bandwidth") var statusCode = 509 - this.processResponseHeaders(statusCode) + this.processResponseHeaders(this.writer.Header(), statusCode) this.writer.WriteHeader(statusCode) if len(config.NoticePageBody) != 0 { diff --git a/internal/nodes/http_request_url.go b/internal/nodes/http_request_url.go index bb43a20..52c696b 100644 --- a/internal/nodes/http_request_url.go +++ b/internal/nodes/http_request_url.go @@ -44,9 +44,9 @@ func (this *HTTPRequest) doURL(method string, url string, host string, statusCod // Header if statusCode <= 0 { - this.processResponseHeaders(resp.StatusCode) + this.processResponseHeaders(this.writer.Header(), resp.StatusCode) } else { - this.processResponseHeaders(statusCode) + this.processResponseHeaders(this.writer.Header(), statusCode) } if supportVariables { diff --git a/internal/nodes/http_request_websocket.go b/internal/nodes/http_request_websocket.go index 2f3c69d..f9141ff 100644 --- a/internal/nodes/http_request_websocket.go +++ b/internal/nodes/http_request_websocket.go @@ -1,6 +1,7 @@ package nodes import ( + "bufio" "errors" "github.com/TeaOSLab/EdgeNode/internal/utils" "io" @@ -82,6 +83,33 @@ func (this *HTTPRequest) doWebsocket(requestHost string, isLastRetry bool) (shou }() go func() { + // 读取第一个响应 + resp, err := http.ReadResponse(bufio.NewReader(originConn), this.RawReq) + if err != nil { + _ = clientConn.Close() + _ = originConn.Close() + return + } + + this.processResponseHeaders(resp.Header, resp.StatusCode) + + // 将响应写回客户端 + err = resp.Write(clientConn) + if err != nil { + if resp.Body != nil { + _ = resp.Body.Close() + } + + _ = clientConn.Close() + _ = originConn.Close() + return + } + + if resp.Body != nil { + _ = resp.Body.Close() + } + + // 复制剩余的数据 var buf = utils.BytePool4k.Get() defer utils.BytePool4k.Put(buf) for {