diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go index 2503247..15b89fb 100644 --- a/internal/nodes/http_client_pool.go +++ b/internal/nodes/http_client_pool.go @@ -104,39 +104,41 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, } } - var transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - // 支持TOA的连接 - conn, err := this.handleTOA(req, ctx, network, originAddr, connectionTimeout) - if conn != nil || err != nil { - return conn, err - } + var transport = &HTTPClientTransport{ + Transport: &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + // 支持TOA的连接 + conn, err := this.handleTOA(req, ctx, network, originAddr, connectionTimeout) + if conn != nil || err != nil { + return conn, err + } - // 普通的连接 - conn, err = (&net.Dialer{ - Timeout: connectionTimeout, - KeepAlive: 1 * time.Minute, - }).DialContext(ctx, network, originAddr) - if err != nil { - return nil, err - } + // 普通的连接 + conn, err = (&net.Dialer{ + Timeout: connectionTimeout, + KeepAlive: 1 * time.Minute, + }).DialContext(ctx, network, originAddr) + if err != nil { + return nil, err + } - // 处理PROXY protocol - err = this.handlePROXYProtocol(conn, req, proxyProtocol) - if err != nil { - return nil, err - } + // 处理PROXY protocol + err = this.handlePROXYProtocol(conn, req, proxyProtocol) + if err != nil { + return nil, err + } - return conn, nil + return conn, nil + }, + MaxIdleConns: 0, + MaxIdleConnsPerHost: idleConns, + MaxConnsPerHost: maxConnections, + IdleConnTimeout: idleTimeout, + ExpectContinueTimeout: 1 * time.Second, + TLSHandshakeTimeout: 3 * time.Second, + TLSClientConfig: tlsConfig, + Proxy: nil, }, - MaxIdleConns: 0, - MaxIdleConnsPerHost: idleConns, - MaxConnsPerHost: maxConnections, - IdleConnTimeout: idleTimeout, - ExpectContinueTimeout: 1 * time.Second, - TLSHandshakeTimeout: 3 * time.Second, - TLSClientConfig: tlsConfig, - Proxy: nil, } rawClient = &http.Client{ diff --git a/internal/nodes/http_client_transport.go b/internal/nodes/http_client_transport.go new file mode 100644 index 0000000..e37ec7b --- /dev/null +++ b/internal/nodes/http_client_transport.go @@ -0,0 +1,26 @@ +// Copyright 2022 Liuxiangchao iwind.liu@gmail.com. All rights reserved. Official site: https://goedge.cn . + +package nodes + +import ( + "net/http" +) + +const emptyHTTPLocation = "/$EmptyHTTPLocation$" + +type HTTPClientTransport struct { + *http.Transport +} + +func (this *HTTPClientTransport) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := this.Transport.RoundTrip(req) + if err != nil { + return resp, err + } + + // 检查在跳转相关状态中Location是否存在 + if httpStatusIsRedirect(resp.StatusCode) && len(resp.Header.Get("Location")) == 0 { + resp.Header.Set("Location", emptyHTTPLocation) + } + return resp, nil +} diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index d72c28d..34199be 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -282,16 +282,22 @@ func (this *HTTPRequest) doReverseProxy() { // 替换Location中的源站地址 var locationHeader = resp.Header.Get("Location") if len(locationHeader) > 0 { - locationURL, err := url.Parse(locationHeader) - if err == nil && locationURL.Host != this.ReqHost && (locationURL.Host == originAddr || strings.HasPrefix(originAddr, locationURL.Host+":")) { - locationURL.Host = this.ReqHost - if this.IsHTTP { - locationURL.Scheme = "http" - } else if this.IsHTTPS { - locationURL.Scheme = "https" - } + // 空Location处理 + if locationHeader == emptyHTTPLocation { + resp.Header.Del("Location") + } else { + // 自动修正Location中的源站地址 + locationURL, err := url.Parse(locationHeader) + if err == nil && locationURL.Host != this.ReqHost && (locationURL.Host == originAddr || strings.HasPrefix(originAddr, locationURL.Host+":")) { + locationURL.Host = this.ReqHost + if this.IsHTTP { + locationURL.Scheme = "http" + } else if this.IsHTTPS { + locationURL.Scheme = "https" + } - resp.Header.Set("Location", locationURL.String()) + resp.Header.Set("Location", locationURL.String()) + } } }