diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go index c9e02dd..5bd3149 100644 --- a/internal/nodes/http_client_pool.go +++ b/internal/nodes/http_client_pool.go @@ -4,6 +4,7 @@ import ( "context" "crypto/tls" "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" @@ -52,13 +53,27 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig, originAddr string, proxyProtocol *serverconfigs.ProxyProtocolConfig, - followRedirects bool, - host string) (rawClient *http.Client, err error) { + followRedirects bool) (rawClient *http.Client, err error) { if origin.Addr == nil { return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") } - var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + host + if req.RawReq.URL == nil { + err = errors.New("invalid request url") + return + } + var originHost = req.RawReq.URL.Host + var urlPort = req.RawReq.URL.Port() + if len(urlPort) == 0 { + if req.RawReq.URL.Scheme == "http" { + urlPort = "80" + } else { + urlPort = "443" + } + } + originHost = configutils.QuoteIP(originHost) + ":" + urlPort + + var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + originHost // if we are under available ProxyProtocol, we add client ip to key to make every client unique var isProxyProtocol = false @@ -67,6 +82,11 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, isProxyProtocol = true } + // follow redirects + if followRedirects { + rawKey += "@follow" + } + var key = xxhash.Sum64String(rawKey) var isLnRequest = origin.Id == 0 @@ -146,17 +166,24 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, var transport = &HTTPClientTransport{ Transport: &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - // 普通的连接 + DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { + var realAddr = originAddr + + // for redirections + if originHost != addr { + realAddr = addr + } + + // connect conn, dialErr := (&net.Dialer{ Timeout: connectionTimeout, KeepAlive: 1 * time.Minute, - }).DialContext(ctx, network, originAddr) + }).DialContext(ctx, network, realAddr) if dialErr != nil { return nil, dialErr } - // 处理PROXY protocol + // handle PROXY protocol proxyErr := this.handlePROXYProtocol(conn, req, proxyProtocol) if proxyErr != nil { return nil, proxyErr @@ -187,16 +214,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, CheckRedirect: func(targetReq *http.Request, via []*http.Request) error { // 是否跟随 if followRedirects { - var schemeIsSame = true - for _, r := range via { - if r.URL.Scheme != targetReq.URL.Scheme { - schemeIsSame = false - break - } - } - if schemeIsSame { - return nil - } + return nil } return http.ErrUseLastResponse diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index 6b65599..f9dce1e 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -275,7 +275,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId } // 获取请求客户端 - client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects, requestHost) + client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects) if err != nil { remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error()) this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)