[反向代理]增加请求主机名类型选择

This commit is contained in:
刘祥超
2020-11-30 22:27:50 +08:00
parent 021ef3dd84
commit 49d3c1b586
3 changed files with 28 additions and 20 deletions

View File

@@ -36,17 +36,12 @@ func NewHTTPClientPool() *HTTPClientPool {
} }
// 根据地址获取客户端 // 根据地址获取客户端
func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.OriginConfig) (rawClient *http.Client, realAddr string, err error) { func (this *HTTPClientPool) Client(origin *serverconfigs.OriginConfig, originAddr string) (rawClient *http.Client, err error) {
if origin.Addr == nil { if origin.Addr == nil {
return nil, "", errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")") return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
} }
key := origin.UniqueKey() key := origin.UniqueKey() + "@" + originAddr
originAddr := origin.Addr.PickAddress()
if origin.Addr.HostHasVariables() {
originAddr = req.Format(originAddr)
}
key += "@" + originAddr
this.locker.Lock() this.locker.Lock()
defer this.locker.Unlock() defer this.locker.Unlock()
@@ -54,7 +49,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi
client, found := this.clientsMap[key] client, found := this.clientsMap[key]
if found { if found {
client.UpdateAccessTime() client.UpdateAccessTime()
return client.RawClient(), originAddr, nil return client.RawClient(), nil
} }
maxConnections := origin.MaxConns maxConnections := origin.MaxConns
@@ -128,7 +123,7 @@ func (this *HTTPClientPool) Client(req *HTTPRequest, origin *serverconfigs.Origi
this.clientsMap[key] = NewHTTPClient(rawClient) this.clientsMap[key] = NewHTTPClient(rawClient)
return rawClient, originAddr, nil return rawClient, nil
} }
// 清理不使用的Client // 清理不使用的Client

View File

@@ -21,18 +21,18 @@ func TestHTTPClientPool_Client(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
{ {
client, addr, err := pool.Client(nil, origin) client, err := pool.Client(origin, origin.Addr.PickAddress())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log("addr:", addr, "client:", client) t.Log("client:", client)
} }
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
client, addr, err := pool.Client(nil, origin) client, err := pool.Client(origin, origin.Addr.PickAddress())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
t.Log("addr:", addr, "client:", client) t.Log("client:", client)
} }
} }
} }
@@ -53,7 +53,7 @@ func TestHTTPClientPool_cleanClients(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
t.Log("get", i) t.Log("get", i)
_, _, _ = pool.Client(nil, origin) _, _ = pool.Client(origin, origin.Addr.PickAddress())
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
} }
} }
@@ -73,6 +73,6 @@ func BenchmarkHTTPClientPool_Client(b *testing.B) {
pool := NewHTTPClientPool() pool := NewHTTPClientPool()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_, _, _ = pool.Client(nil, origin) _, _ = pool.Client(origin, origin.Addr.PickAddress())
} }
} }

View File

@@ -3,6 +3,7 @@ package nodes
import ( import (
"context" "context"
"errors" "errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeNode/internal/logs" "github.com/TeaOSLab/EdgeNode/internal/logs"
"github.com/TeaOSLab/EdgeNode/internal/utils" "github.com/TeaOSLab/EdgeNode/internal/utils"
@@ -22,7 +23,11 @@ func (this *HTTPRequest) doReverseProxy() {
stripPrefix := this.reverseProxy.StripPrefix stripPrefix := this.reverseProxy.StripPrefix
requestURI := this.reverseProxy.RequestURI requestURI := this.reverseProxy.RequestURI
requestURIHasVariables := this.reverseProxy.RequestURIHasVariables() requestURIHasVariables := this.reverseProxy.RequestURIHasVariables()
requestHost := this.reverseProxy.RequestHost
var requestHost = ""
if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeCustomized {
requestHost = this.reverseProxy.RequestHost
}
requestHostHasVariables := this.reverseProxy.RequestHostHasVariables() requestHostHasVariables := this.reverseProxy.RequestHostHasVariables()
// 源站 // 源站
@@ -91,6 +96,13 @@ func (this *HTTPRequest) doReverseProxy() {
this.uri = utils.CleanPath(this.uri) this.uri = utils.CleanPath(this.uri)
} }
// 获取源站地址
originAddr := origin.Addr.PickAddress()
if origin.Addr.HostHasVariables() {
originAddr = this.Format(originAddr)
}
this.originAddr = originAddr
// RequestHost // RequestHost
if len(requestHost) > 0 { if len(requestHost) > 0 {
if requestHostHasVariables { if requestHostHasVariables {
@@ -99,6 +111,9 @@ func (this *HTTPRequest) doReverseProxy() {
this.RawReq.Host = this.reverseProxy.RequestHost this.RawReq.Host = this.reverseProxy.RequestHost
} }
this.RawReq.URL.Host = this.RawReq.Host this.RawReq.URL.Host = this.RawReq.Host
} else if this.reverseProxy.RequestHostType == serverconfigs.RequestHostTypeOrigin {
this.RawReq.Host = originAddr
this.RawReq.URL.Host = this.RawReq.Host
} else { } else {
this.RawReq.URL.Host = this.Host this.RawReq.URL.Host = this.Host
} }
@@ -125,15 +140,13 @@ func (this *HTTPRequest) doReverseProxy() {
} }
// 获取请求客户端 // 获取请求客户端
client, addr, err := SharedHTTPClientPool.Client(this, origin) client, err := SharedHTTPClientPool.Client(origin, originAddr)
if err != nil { if err != nil {
logs.Error("REQUEST_REVERSE_PROXY", err.Error()) logs.Error("REQUEST_REVERSE_PROXY", err.Error())
this.write502(err) this.write502(err)
return return
} }
this.originAddr = addr
// 开始请求 // 开始请求
resp, err := client.Do(this.RawReq) resp, err := client.Do(this.RawReq)
if err != nil { if err != nil {