From 240698ff0de91038181e9b43395301a9a36f7de1 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Mon, 14 Mar 2022 15:07:18 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=9B=9E=E6=BA=90=E8=B7=9F?= =?UTF-8?q?=E9=9A=8F=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/nodes/http_client_pool.go | 15 ++++++++++++++- internal/nodes/http_request_reverse_proxy.go | 2 +- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go index 9e12d3d..3e43cdc 100644 --- a/internal/nodes/http_client_pool.go +++ b/internal/nodes/http_client_pool.go @@ -42,7 +42,7 @@ func NewHTTPClientPool() *HTTPClientPool { } // Client 根据地址获取客户端 -func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig, originAddr string, proxyProtocol *serverconfigs.ProxyProtocolConfig) (rawClient *http.Client, err error) { +func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig, originAddr string, proxyProtocol *serverconfigs.ProxyProtocolConfig, followRedirects bool) (rawClient *http.Client, err error) { if origin.Addr == nil { return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") } @@ -139,6 +139,19 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi Timeout: readTimeout, Transport: transport, CheckRedirect: func(req *http.Request, via []*http.Request) error { + if followRedirects { + var schemeIsSame = true + for _, r := range via { + if r.URL.Scheme != req.URL.Scheme { + schemeIsSame = false + break + } + } + if schemeIsSame { + return nil + } + } + return http.ErrUseLastResponse }, } diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index c4d7f38..3872478 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -159,7 +159,7 @@ func (this *HTTPRequest) doReverseProxy() { } // 获取请求客户端 - client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol) + client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects) if err != nil { remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error()) this.write50x(err, http.StatusBadGateway, true)