mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-10 04:20:27 +08:00
修复回源跟随无法跨不同协议、不同服务器地址的问题
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
"github.com/TeaOSLab/EdgeNode/internal/goman"
|
||||||
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
|
||||||
@@ -52,13 +53,27 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
|
|||||||
origin *serverconfigs.OriginConfig,
|
origin *serverconfigs.OriginConfig,
|
||||||
originAddr string,
|
originAddr string,
|
||||||
proxyProtocol *serverconfigs.ProxyProtocolConfig,
|
proxyProtocol *serverconfigs.ProxyProtocolConfig,
|
||||||
followRedirects bool,
|
followRedirects bool) (rawClient *http.Client, err error) {
|
||||||
host string) (rawClient *http.Client, err error) {
|
|
||||||
if origin.Addr == nil {
|
if origin.Addr == nil {
|
||||||
return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
|
return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
|
||||||
}
|
}
|
||||||
|
|
||||||
var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + host
|
if req.RawReq.URL == nil {
|
||||||
|
err = errors.New("invalid request url")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var originHost = req.RawReq.URL.Host
|
||||||
|
var urlPort = req.RawReq.URL.Port()
|
||||||
|
if len(urlPort) == 0 {
|
||||||
|
if req.RawReq.URL.Scheme == "http" {
|
||||||
|
urlPort = "80"
|
||||||
|
} else {
|
||||||
|
urlPort = "443"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
originHost = configutils.QuoteIP(originHost) + ":" + urlPort
|
||||||
|
|
||||||
|
var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + originHost
|
||||||
|
|
||||||
// if we are under available ProxyProtocol, we add client ip to key to make every client unique
|
// if we are under available ProxyProtocol, we add client ip to key to make every client unique
|
||||||
var isProxyProtocol = false
|
var isProxyProtocol = false
|
||||||
@@ -67,6 +82,11 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
|
|||||||
isProxyProtocol = true
|
isProxyProtocol = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// follow redirects
|
||||||
|
if followRedirects {
|
||||||
|
rawKey += "@follow"
|
||||||
|
}
|
||||||
|
|
||||||
var key = xxhash.Sum64String(rawKey)
|
var key = xxhash.Sum64String(rawKey)
|
||||||
|
|
||||||
var isLnRequest = origin.Id == 0
|
var isLnRequest = origin.Id == 0
|
||||||
@@ -146,17 +166,24 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
|
|||||||
|
|
||||||
var transport = &HTTPClientTransport{
|
var transport = &HTTPClientTransport{
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||||
// 普通的连接
|
var realAddr = originAddr
|
||||||
|
|
||||||
|
// for redirections
|
||||||
|
if originHost != addr {
|
||||||
|
realAddr = addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect
|
||||||
conn, dialErr := (&net.Dialer{
|
conn, dialErr := (&net.Dialer{
|
||||||
Timeout: connectionTimeout,
|
Timeout: connectionTimeout,
|
||||||
KeepAlive: 1 * time.Minute,
|
KeepAlive: 1 * time.Minute,
|
||||||
}).DialContext(ctx, network, originAddr)
|
}).DialContext(ctx, network, realAddr)
|
||||||
if dialErr != nil {
|
if dialErr != nil {
|
||||||
return nil, dialErr
|
return nil, dialErr
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理PROXY protocol
|
// handle PROXY protocol
|
||||||
proxyErr := this.handlePROXYProtocol(conn, req, proxyProtocol)
|
proxyErr := this.handlePROXYProtocol(conn, req, proxyProtocol)
|
||||||
if proxyErr != nil {
|
if proxyErr != nil {
|
||||||
return nil, proxyErr
|
return nil, proxyErr
|
||||||
@@ -187,16 +214,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
|
|||||||
CheckRedirect: func(targetReq *http.Request, via []*http.Request) error {
|
CheckRedirect: func(targetReq *http.Request, via []*http.Request) error {
|
||||||
// 是否跟随
|
// 是否跟随
|
||||||
if followRedirects {
|
if followRedirects {
|
||||||
var schemeIsSame = true
|
return nil
|
||||||
for _, r := range via {
|
|
||||||
if r.URL.Scheme != targetReq.URL.Scheme {
|
|
||||||
schemeIsSame = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if schemeIsSame {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return http.ErrUseLastResponse
|
return http.ErrUseLastResponse
|
||||||
|
|||||||
@@ -275,7 +275,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 获取请求客户端
|
// 获取请求客户端
|
||||||
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects, requestHost)
|
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
|
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
|
||||||
this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)
|
this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)
|
||||||
|
|||||||
Reference in New Issue
Block a user