From 7130154bc839eb8ecbf6f46b8c5c275c2a550ebf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Wed, 17 Apr 2024 20:38:00 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=9B=9E=E6=BA=90=E8=B7=9F?= =?UTF-8?q?=E9=9A=8F=E6=97=A0=E6=B3=95=E8=B7=A8=E4=B8=8D=E5=90=8C=E5=8D=8F?= =?UTF-8?q?=E8=AE=AE=E3=80=81=E4=B8=8D=E5=90=8C=E6=9C=8D=E5=8A=A1=E5=99=A8?= =?UTF-8?q?=E5=9C=B0=E5=9D=80=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/nodes/http_client_pool.go | 52 +++++++++++++------- internal/nodes/http_request_reverse_proxy.go | 2 +- 2 files changed, 36 insertions(+), 18 deletions(-) diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go index c9e02dd..5bd3149 100644 --- a/internal/nodes/http_client_pool.go +++ b/internal/nodes/http_client_pool.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" @@ -52,13 +53,27 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig, originAddr string, proxyProtocol *serverconfigs.ProxyProtocolConfig, - followRedirects bool, - host string) (rawClient *http.Client, err error) { + followRedirects bool) (rawClient *http.Client, err error) { if origin.Addr == nil { return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") } - var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + host + if req.RawReq.URL == nil { + err = errors.New("invalid request url") + return + } + var originHost = req.RawReq.URL.Host + var urlPort = req.RawReq.URL.Port() + if len(urlPort) == 0 { + if req.RawReq.URL.Scheme == "http" { + urlPort = "80" + } else { + urlPort = "443" + } + } + originHost = configutils.QuoteIP(originHost) + ":" + urlPort + + var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + originHost // if we are under available ProxyProtocol, we add client ip to key to make every client unique var isProxyProtocol = false @@ -67,6 +82,11 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, isProxyProtocol = true } + // follow redirects + if followRedirects { + rawKey += "@follow" + } + var key = xxhash.Sum64String(rawKey) var isLnRequest = origin.Id == 0 @@ -146,17 +166,24 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, var transport = &HTTPClientTransport{ Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - // 普通的连接 + DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { + var realAddr = originAddr + + // for redirections + if originHost != addr { + realAddr = addr + } + + // connect conn, dialErr := (&net.Dialer{ Timeout: connectionTimeout, KeepAlive: 1 * time.Minute, - }).DialContext(ctx, network, originAddr) + }).DialContext(ctx, network, realAddr) if dialErr != nil { return nil, dialErr } - // 处理PROXY protocol + // handle PROXY protocol proxyErr := this.handlePROXYProtocol(conn, req, proxyProtocol) if proxyErr != nil { return nil, proxyErr @@ -187,16 +214,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, CheckRedirect: func(targetReq *http.Request, via []*http.Request) error { // 是否跟随 if followRedirects { - var schemeIsSame = true - for _, r := range via { - if r.URL.Scheme != targetReq.URL.Scheme { - schemeIsSame = false - break - } - } - if schemeIsSame { - return nil - } + return nil } return http.ErrUseLastResponse diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index 6b65599..f9dce1e 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -275,7 +275,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId } // 获取请求客户端 - client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects, requestHost) + client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects) if err != nil { remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error()) this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)