diff --git a/internal/nodes/http_client.go b/internal/nodes/http_client.go index 8de35b9..771bdfe 100644 --- a/internal/nodes/http_client.go +++ b/internal/nodes/http_client.go @@ -7,15 +7,17 @@ import ( // HTTPClient HTTP客户端 type HTTPClient struct { - rawClient *http.Client - accessAt int64 + rawClient *http.Client + accessAt int64 + isProxyProtocol bool } // NewHTTPClient 获取新客户端对象 -func NewHTTPClient(rawClient *http.Client) *HTTPClient { +func NewHTTPClient(rawClient *http.Client, isProxyProtocol bool) *HTTPClient { return &HTTPClient{ - rawClient: rawClient, - accessAt: fasttime.Now().Unix(), + rawClient: rawClient, + accessAt: fasttime.Now().Unix(), + isProxyProtocol: isProxyProtocol, } } @@ -34,6 +36,11 @@ func (this *HTTPClient) AccessTime() int64 { return this.accessAt } +// IsProxyProtocol 判断是否为PROXY Protocol +func (this *HTTPClient) IsProxyProtocol() bool { + return this.isProxyProtocol +} + // Close 关闭 func (this *HTTPClient) Close() { this.rawClient.CloseIdleConnections() diff --git a/internal/nodes/http_client_pool.go b/internal/nodes/http_client_pool.go index 4463c55..c9e02dd 100644 --- a/internal/nodes/http_client_pool.go +++ b/internal/nodes/http_client_pool.go @@ -7,6 +7,7 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/utils/fasttime" + "github.com/cespare/xxhash/v2" "github.com/pires/go-proxyproto" "golang.org/x/net/http2" "net" @@ -25,7 +26,7 @@ const httpClientProxyProtocolTag = "@ProxyProtocol@" // HTTPClientPool 客户端池 type HTTPClientPool struct { - clientsMap map[string]*HTTPClient // backend key => client + clientsMap map[uint64]*HTTPClient // origin key => client cleanTicker *time.Ticker @@ -36,7 +37,7 @@ type HTTPClientPool struct { func NewHTTPClientPool() *HTTPClientPool { var pool = &HTTPClientPool{ cleanTicker: time.NewTicker(1 * time.Hour), - clientsMap: map[string]*HTTPClient{}, + clientsMap: map[uint64]*HTTPClient{}, } goman.New(func() { @@ -51,20 +52,23 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig, originAddr string, proxyProtocol *serverconfigs.ProxyProtocolConfig, - followRedirects bool) (rawClient *http.Client, err error) { + followRedirects bool, + host string) (rawClient *http.Client, err error) { if origin.Addr == nil { return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") } - var key = origin.UniqueKey() + "@" + originAddr + var rawKey = origin.UniqueKey() + "@" + originAddr + "@" + host // if we are under available ProxyProtocol, we add client ip to key to make every client unique var isProxyProtocol = false if proxyProtocol != nil && proxyProtocol.IsOn { - key += httpClientProxyProtocolTag + req.requestRemoteAddr(true) + rawKey += httpClientProxyProtocolTag + req.requestRemoteAddr(true) isProxyProtocol = true } + var key = xxhash.Sum64String(rawKey) + var isLnRequest = origin.Id == 0 this.locker.RLock() @@ -144,18 +148,18 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, Transport: &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { // 普通的连接 - conn, err := (&net.Dialer{ + conn, dialErr := (&net.Dialer{ Timeout: connectionTimeout, KeepAlive: 1 * time.Minute, }).DialContext(ctx, network, originAddr) - if err != nil { - return nil, err + if dialErr != nil { + return nil, dialErr } // 处理PROXY protocol - err = this.handlePROXYProtocol(conn, req, proxyProtocol) - if err != nil { - return nil, err + proxyErr := this.handlePROXYProtocol(conn, req, proxyProtocol) + if proxyErr != nil { + return nil, proxyErr } return NewOriginConn(conn), nil @@ -199,7 +203,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, }, } - this.clientsMap[key] = NewHTTPClient(rawClient) + this.clientsMap[key] = NewHTTPClient(rawClient, isProxyProtocol) return rawClient, nil } @@ -209,14 +213,14 @@ func (this *HTTPClientPool) cleanClients() { for range this.cleanTicker.C { var nowTime = fasttime.Now().Unix() - var expiredKeys = []string{} + var expiredKeys []uint64 var expiredClients = []*HTTPClient{} // lookup expired clients this.locker.RLock() for k, client := range this.clientsMap { if client.AccessTime() < nowTime-86400 || - (strings.Contains(k, httpClientProxyProtocolTag) && client.AccessTime() < nowTime-3600) { // 超过 N 秒没有调用就关闭 + (client.IsProxyProtocol() && client.AccessTime() < nowTime-3600) { // 超过 N 秒没有调用就关闭 expiredKeys = append(expiredKeys, k) expiredClients = append(expiredClients, client) } diff --git a/internal/nodes/http_client_pool_test.go b/internal/nodes/http_client_pool_test.go index 0b6ddeb..ebb7fbc 100644 --- a/internal/nodes/http_client_pool_test.go +++ b/internal/nodes/http_client_pool_test.go @@ -23,14 +23,14 @@ func TestHTTPClientPool_Client(t *testing.T) { t.Fatal(err) } { - client, err := pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false) + client, err := pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false, "example.com") if err != nil { t.Fatal(err) } t.Log("client:", client) } for i := 0; i < 10; i++ { - client, err := pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false) + client, err := pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false, "example.com") if err != nil { t.Fatal(err) } @@ -54,7 +54,7 @@ func TestHTTPClientPool_cleanClients(t *testing.T) { for i := 0; i < 10; i++ { t.Log("get", i) - _, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false) + _, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false, "example.com") if testutils.IsSingleTesting() { time.Sleep(1 * time.Second) @@ -79,6 +79,6 @@ func BenchmarkHTTPClientPool_Client(b *testing.B) { var pool = NewHTTPClientPool() for i := 0; i < b.N; i++ { - _, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false) + _, _ = pool.Client(nil, origin, origin.Addr.PickAddress(), nil, false, "example.com") } } diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index f9dce1e..6b65599 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -275,7 +275,7 @@ func (this *HTTPRequest) doOriginRequest(failedOriginIds []int64, failedLnNodeId } // 获取请求客户端 - client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects) + client, err := SharedHTTPClientPool.Client(this, origin, originAddr, this.reverseProxy.ProxyProtocol, this.reverseProxy.FollowRedirects, requestHost) if err != nil { remotelogs.ErrorServer("HTTP_REQUEST_REVERSE_PROXY", this.URL()+": Create client failed: "+err.Error()) this.write50x(err, http.StatusBadGateway, "Failed to create origin site client", "构造源站客户端失败", true)