diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go index b2ae636..b780288 100644 --- a/internal/nodes/http_client_pool.go +++ b/internal/nodes/http_client_pool.go @@ -36,17 +36,12 @@ func NewHTTPClientPool() *HTTPClientPool { } // 根据地址获取客户端 -func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig) (rawClient *http.Client, realAddr string, err error) { +func (this *HTTPClientPool) Client(origin *serverconfigs.OriginConfig, originAddr string) (rawClient *http.Client, err error) { if origin.Addr == nil { - return nil, "", errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") + return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") } - key := origin.UniqueKey() - originAddr := origin.Addr.PickAddress() - if origin.Addr.HostHasVariables() { - originAddr = req.Format(originAddr) - } - key += "@" + originAddr + key := origin.UniqueKey() + "@" + originAddr this.locker.Lock() defer this.locker.Unlock() @@ -54,7 +49,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi client, found := this.clientsMap[key] if found { client.UpdateAccessTime() - return client.RawClient(), originAddr, nil + return client.RawClient(), nil } maxConnections := origin.MaxConns @@ -128,7 +123,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi this.clientsMap[key] = NewHTTPClient(rawClient) - return rawClient, originAddr, nil + return rawClient, nil } // 清理不使用的Client diff --git a/internal/nodes/http_client_pool_test.go b/internal/nodes/http_client_pool_test.go index 01d8bd9..2207513 100644 --- a/internal/nodes/http_client_pool_test.go +++ b/internal/nodes/http_client_pool_test.go @@ -21,18 +21,18 @@ func TestHTTPClientPool_Client(t *testing.T) { t.Fatal(err) } { - client, addr, err := pool.Client(nil, origin) + client, err := pool.Client(origin, origin.Addr.PickAddress()) if err != nil { t.Fatal(err) } - t.Log("addr:", addr, "client:", client) + t.Log("client:", client) } for i := 0; i < 10; i++ { - client, addr, err := pool.Client(nil, origin) + client, err := pool.Client(origin, origin.Addr.PickAddress()) if err != nil { t.Fatal(err) } - t.Log("addr:", addr, "client:", client) + t.Log("client:", client) } } } @@ -53,7 +53,7 @@ func TestHTTPClientPool_cleanClients(t *testing.T) { for i := 0; i < 10; i++ { t.Log("get", i) - _, _, _ = pool.Client(nil, origin) + _, _ = pool.Client(origin, origin.Addr.PickAddress()) time.Sleep(1 * time.Second) } } @@ -73,6 +73,6 @@ func BenchmarkHTTPClientPool_Client(b *testing.B) { pool := NewHTTPClientPool() for i := 0; i < b.N; i++ { - _, _, _ = pool.Client(nil, origin) + _, _ = pool.Client(origin, origin.Addr.PickAddress()) } } diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index 75d49d9..b4385a4 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -3,6 +3,7 @@ package nodes import ( "context" "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/TeaOSLab/EdgeNode/internal/logs" "github.com/TeaOSLab/EdgeNode/internal/utils" @@ -22,7 +23,11 @@ func (this *HTTPRequest) doReverseProxy() { stripPrefix := this.reverseProxy.StripPrefix requestURI := this.reverseProxy.RequestURI requestURIHasVariables := this.reverseProxy.RequestURIHasVariables() - requestHost := this.reverseProxy.RequestHost + + var requestHost = "" + if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeCustomized { + requestHost = this.reverseProxy.RequestHost + } requestHostHasVariables := this.reverseProxy.RequestHostHasVariables() // 源站 @@ -91,6 +96,13 @@ func (this *HTTPRequest) doReverseProxy() { this.uri = utils.CleanPath(this.uri) } + // 获取源站地址 + originAddr := origin.Addr.PickAddress() + if origin.Addr.HostHasVariables() { + originAddr = this.Format(originAddr) + } + this.originAddr = originAddr + // RequestHost if len(requestHost) > 0 { if requestHostHasVariables { @@ -99,6 +111,9 @@ func (this *HTTPRequest) doReverseProxy() { this.RawReq.Host = this.reverseProxy.RequestHost } this.RawReq.URL.Host = this.RawReq.Host + } else if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeOrigin { + this.RawReq.Host = originAddr + this.RawReq.URL.Host = this.RawReq.Host } else { this.RawReq.URL.Host = this.Host } @@ -125,15 +140,13 @@ func (this *HTTPRequest) doReverseProxy() { } // 获取请求客户端 - client, addr, err := SharedHTTPClientPool.Client(this, origin) + client, err := SharedHTTPClientPool.Client(origin, originAddr) if err != nil { logs.Error("REQUEST_REVERSE_PROXY", err.Error()) this.write502(err) return } - this.originAddr = addr - // 开始请求 resp, err := client.Do(this.RawReq) if err != nil {