修复回源跟随无法跨不同协议、不同服务器地址的问题

This commit is contained in:
GoEdgeLab
2024-04-17 20:38:00 +08:00
parent cbf024b3ed
commit ddcc56b288
2 changed files with 36 additions and 18 deletions

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/utils/fasttime"
@@ -52,13 +53,27 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
origin *serverconfigs.OriginConfig,
originAddr string,
proxyProtocol *serverconfigs.ProxyProtocolConfig,
followRedirects bool,
host string) (rawClient *http.Client, err error) {
followRedirects bool) (rawClient *http.Client, err error) {
if origin.Addr == nil {
return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
}
var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + host
if req.RawReq.URL == nil {
err = errors.New("invalid request url")
return
}
var originHost = req.RawReq.URL.Host
var urlPort = req.RawReq.URL.Port()
if len(urlPort) == 0 {
if req.RawReq.URL.Scheme == "http" {
urlPort = "80"
} else {
urlPort = "443"
}
}
originHost = configutils.QuoteIP(originHost) + ":" + urlPort
var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + originHost
// if we are under available ProxyProtocol, we add client ip to key to make every client unique
var isProxyProtocol = false
@@ -67,6 +82,11 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
isProxyProtocol = true
}
// follow redirects
if followRedirects {
rawKey += "@follow"
}
var key = xxhash.Sum64String(rawKey)
var isLnRequest = origin.Id == 0
@@ -146,17 +166,24 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
var transport = &HTTPClientTransport{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// 普通的连接
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
var realAddr = originAddr
// for redirections
if originHost != addr {
realAddr = addr
}
// connect
conn, dialErr := (&net.Dialer{
Timeout: connectionTimeout,
KeepAlive: 1 * time.Minute,
}).DialContext(ctx, network, originAddr)
}).DialContext(ctx, network, realAddr)
if dialErr != nil {
return nil, dialErr
}
// 处理PROXY protocol
// handle PROXY protocol
proxyErr := this.handlePROXYProtocol(conn, req, proxyProtocol)
if proxyErr != nil {
return nil, proxyErr
@@ -187,16 +214,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest,
CheckRedirect: func(targetReq *http.Request, via []*http.Request) error {
// 是否跟随
if followRedirects {
var schemeIsSame = true
for _, r := range via {
if r.URL.Scheme != targetReq.URL.Scheme {
schemeIsSame = false
break
}
}
if schemeIsSame {
return nil
}
return nil
}
return http.ErrUseLastResponse

View File

@@ -275,7 +275,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId
}
// 获取请求客户端
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects, requestHost)
client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects)
if err != nil {
remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error())
this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)