Files
EdgeNode/internal/nodes/http_client_pool.go
2023-06-23 11:43:02 +08:00

283 lines
7.0 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package nodes
import (
"context"
"crypto/tls"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/pires/go-proxyproto"
"golang.org/x/net/http2"
"net"
"net/http"
"runtime"
"strconv"
"strings"
"sync"
"time"
)
// SharedHTTPClientPool is the package-wide HTTP client pool singleton.
// It is built during package initialization via NewHTTPClientPool, which also
// hands the pool's periodic cleanup loop to goman.New.
var SharedHTTPClientPool = NewHTTPClientPool()
// HTTPClientPool caches one *HTTPClient per backend key (origin unique key +
// "@" + dial address) so that later requests to the same backend reuse the
// same underlying http.Client.
type HTTPClientPool struct {
	// locker guards clientsMap; it is held for reads on the lookup fast path
	// and for writes when creating or evicting clients.
	locker sync.RWMutex

	clientsMap  map[string]*HTTPClient // backend key => client
	cleanTicker *time.Ticker           // drives the periodic cleanup loop
}
// NewHTTPClientPool creates an empty pool with an hourly cleanup ticker and
// hands the pool's cleanClients loop to goman.New.
func NewHTTPClientPool() *HTTPClientPool {
	pool := &HTTPClientPool{
		clientsMap:  make(map[string]*HTTPClient),
		cleanTicker: time.NewTicker(time.Hour),
	}

	// cleanClients loops on cleanTicker, so it must not run on this goroutine.
	goman.New(pool.cleanClients)

	return pool
}
// Client returns the pooled *http.Client for origin at originAddr, creating
// and caching it on first use.
//
// Clients are cached under origin.UniqueKey() + "@" + originAddr.
//   - req: the inbound request; the new client's dialer uses it for TOA and
//     PROXY protocol handling.
//   - origin: origin configuration; origin.Addr must not be nil.
//   - originAddr: the address every connection of the new client dials,
//     regardless of the address net/http passes to DialContext.
//   - proxyProtocol: optional PROXY protocol settings applied to new connections.
//   - followRedirects: whether same-scheme redirects are followed.
//
// The only error returned is for a nil origin.Addr.
//
// NOTE(review): req, proxyProtocol and followRedirects are captured by the
// closures of the client built on the first call for a key. That client is then
// returned for every later call with the same key. As a result, later requests
// reuse the first caller's request for TOA/PROXY headers and its redirect
// policy. Confirm this is intended.
func (this *HTTPClientPool) Client(req *HTTPRequest,
origin *serverconfigs.OriginConfig,
originAddr string,
proxyProtocol *serverconfigs.ProxyProtocolConfig,
followRedirects bool) (rawClient *http.Client, err error) {
if origin.Addr == nil {
return nil, errors.New("origin addr should not be empty (originId:" + strconv.FormatInt(origin.Id, 10) + ")")
}
var key = origin.UniqueKey() + "@" + originAddr
// An origin with Id 0 is treated as an Ln-node request (see the limit scaling below).
var isLnRequest = origin.Id == 0
// Fast path: reuse an existing client under the read lock.
this.locker.RLock()
client, found := this.clientsMap[key]
this.locker.RUnlock()
if found {
client.UpdateAccessTime()
return client.RawClient(), nil
}
// A write lock (not RLock) is required here so that concurrent callers cannot create several identical client instances.
this.locker.Lock()
defer this.locker.Unlock()
// Look up again: another goroutine may have created the client while we waited for the lock.
client, found = this.clientsMap[key]
if found {
client.UpdateAccessTime()
return client.RawClient(), nil
}
var maxConnections = origin.MaxConns
var connectionTimeout = origin.ConnTimeoutDuration()
var readTimeout = origin.ReadTimeoutDuration()
var idleTimeout = origin.IdleTimeoutDuration()
var idleConns = origin.MaxIdleConns
// Default the connect and idle timeouts (readTimeout is intentionally not defaulted here).
if connectionTimeout <= 0 {
connectionTimeout = 15 * time.Second
}
if idleTimeout <= 0 {
idleTimeout = 2 * time.Minute
}
// Default connection limits scale with CPU count, treating fewer than 8 CPUs as 8.
var numberCPU = runtime.NumCPU()
if numberCPU < 8 {
numberCPU = 8
}
if maxConnections <= 0 {
maxConnections = numberCPU * 64
}
if idleConns <= 0 {
idleConns = numberCPU * 16
}
// Requests identified as Ln-node requests get much larger limits.
if isLnRequest {
maxConnections *= 8
idleConns *= 8
idleTimeout *= 4
} else if sharedNodeConfig != nil && sharedNodeConfig.Level > 1 {
// Ln nodes (level > 1) may use more connections.
maxConnections *= 2
idleConns *= 2
}
// TLS: server certificate verification is skipped unless the origin has a
// usable certificate. In that case verification is enabled and the certificate
// is placed in tlsConfig.Certificates, which crypto/tls presents as the client
// certificate.
var tlsConfig = &tls.Config{
InsecureSkipVerify: true,
}
if origin.Cert != nil {
var obj = origin.Cert.CertObject()
if obj != nil {
tlsConfig.InsecureSkipVerify = false
tlsConfig.Certificates = []tls.Certificate{*obj}
if len(origin.Cert.ServerName) > 0 {
tlsConfig.ServerName = origin.Cert.ServerName
}
}
}
var transport = &HTTPClientTransport{
Transport: &http.Transport{
// The addr argument from net/http is ignored; originAddr is always dialed.
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
// TOA-enabled connection (handleTOA returns nil, nil when TOA is not used)
conn, err := this.handleTOA(req, ctx, network, originAddr, connectionTimeout)
if conn != nil || err != nil {
return conn, err
}
// Plain connection
conn, err = (&net.Dialer{
Timeout: connectionTimeout,
KeepAlive: 1 * time.Minute,
}).DialContext(ctx, network, originAddr)
if err != nil {
return nil, err
}
// Write the PROXY protocol header if configured (conn is closed on failure)
err = this.handlePROXYProtocol(conn, req, proxyProtocol)
if err != nil {
return nil, err
}
return conn, nil
},
MaxIdleConns: 0, // zero means no global idle-connection limit in net/http
MaxIdleConnsPerHost: idleConns,
MaxConnsPerHost: maxConnections,
IdleConnTimeout: idleTimeout,
ExpectContinueTimeout: 1 * time.Second,
TLSHandshakeTimeout: 5 * time.Second,
TLSClientConfig: tlsConfig,
ReadBufferSize: 8 * 1024,
Proxy: nil,
},
}
// Enable HTTP/2 for HTTPS origins when configured; a ConfigureTransport error is ignored.
if origin.HTTP2Enabled && origin.Addr != nil && origin.Addr.Protocol == serverconfigs.ProtocolHTTPS {
_ = http2.ConfigureTransport(transport.Transport)
}
rawClient = &http.Client{
// NOTE(review): if ReadTimeoutDuration() yields 0, the client has no overall timeout (http.Client semantics).
Timeout: readTimeout,
Transport: transport,
CheckRedirect: func(targetReq *http.Request, via []*http.Request) error {
// Follow only when enabled and every previous hop used the same scheme as the
// target; otherwise return the redirect response itself to the caller.
if followRedirects {
var schemeIsSame = true
for _, r := range via {
if r.URL.Scheme != targetReq.URL.Scheme {
schemeIsSame = false
break
}
}
if schemeIsSame {
return nil
}
}
return http.ErrUseLastResponse
},
}
this.clientsMap[key] = NewHTTPClient(rawClient)
return rawClient, nil
}
// cleanClients evicts and closes pooled clients whose last access is more
// than a day old. It runs once per tick of cleanTicker.
//
// The loop ends only if cleanTicker.C is closed. A standard time.Ticker never
// closes its channel, so in practice this runs for the life of the pool.
func (this *HTTPClientPool) cleanClients() {
	// Clients not accessed for longer than this many seconds are closed.
	const maxIdleSeconds = 86400

	for range this.cleanTicker.C {
		// AccessTime() is compared against Unix seconds, so the cutoff is too.
		// The cutoff lies in the past: a client is expired when its last access
		// is older than it. (Adding the window instead of subtracting it would
		// evict every client on every tick.)
		var expireBefore = time.Now().Unix() - maxIdleSeconds

		var expiredClients []*HTTPClient
		this.locker.Lock()
		for key, client := range this.clientsMap {
			if client.AccessTime() < expireBefore {
				delete(this.clientsMap, key)
				expiredClients = append(expiredClients, client)
			}
		}
		this.locker.Unlock()

		// Close outside the lock so that Client() callers are not blocked
		// while Close runs. The clients are already unreachable from the pool.
		for _, client := range expiredClients {
			client.Close()
		}
	}
}
// handleTOA dials originAddr from a local port registered with the TOA manager
// when TOA is enabled.
//
// For each attempt (at most 3), it picks a random local port and sends
// "add:<port>:<client addr>" to sharedTOAManager. If that succeeds, it dials
// originAddr from that port. A failed SendMsg is logged and the next attempt is
// tried. A failed dial is retried, except on the last attempt, where its result
// is returned as-is.
//
// It returns (nil, nil) when TOA is disabled or no attempt reached the dial step.
// The caller then falls back to a plain connection.
func (this *HTTPClientPool) handleTOA(req *HTTPRequest, ctx context.Context, network string, originAddr string, connectionTimeout time.Duration) (net.Conn, error) {
	// TODO each server should read the TOA settings of its own cluster
	var toaConfig = sharedTOAManager.Config()
	if toaConfig == nil || !toaConfig.IsOn {
		return nil, nil
	}

	const maxAttempts = 3
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		var localPort = int(toaConfig.RandLocalPort())

		// TODO consider supporting X-Real-IP / X-Forwarded-IP
		var addMsg = "add:" + strconv.Itoa(localPort) + ":" + req.requestRemoteAddr(true)
		if sendErr := sharedTOAManager.SendMsg(addMsg); sendErr != nil {
			remotelogs.Error("TOA", "add failed: "+sendErr.Error())
			continue
		}

		var dialer = &net.Dialer{
			Timeout:   connectionTimeout,
			KeepAlive: time.Minute,
			LocalAddr: &net.TCPAddr{Port: localPort},
		}
		conn, dialErr := dialer.DialContext(ctx, network, originAddr)
		// TODO the TOA record should be removed at an appropriate time
		if dialErr == nil || attempt == maxAttempts {
			return conn, dialErr
		}
	}

	return nil, nil
}
// handlePROXYProtocol writes a PROXY protocol header (version 1 or 2) to a
// freshly dialed origin connection when proxyProtocol is enabled. It does
// nothing when proxyProtocol is nil, disabled, or set to another version.
//
// Header fields:
//   - source: req.requestRemoteAddr(true) as the IP and req.requestRemotePort()
//     as the port. The transport is TCPv6 if that address contains ':',
//     otherwise TCPv4.
//   - destination: the local address of the inbound connection stored in the
//     request context under HTTPConnContextKey, falling back to
//     conn.RemoteAddr() when absent.
//
// If writing the header fails, conn is closed and the write error is returned.
func (this *HTTPClientPool) handlePROXYProtocol(conn net.Conn, req *HTTPRequest, proxyProtocol *serverconfigs.ProxyProtocolConfig) error {
	if proxyProtocol == nil || !proxyProtocol.IsOn {
		return nil
	}
	if v := proxyProtocol.Version; v != serverconfigs.ProxyProtocolVersion1 && v != serverconfigs.ProxyProtocolVersion2 {
		return nil
	}

	var clientAddr = req.requestRemoteAddr(true)

	var protocol = proxyproto.TCPv4
	if strings.Contains(clientAddr, ":") {
		protocol = proxyproto.TCPv6
	}

	var dstAddr = conn.RemoteAddr()
	if inbound := req.RawReq.Context().Value(HTTPConnContextKey); inbound != nil {
		dstAddr = inbound.(net.Conn).LocalAddr()
	}

	var srcAddr = &net.TCPAddr{
		IP:   net.ParseIP(clientAddr),
		Port: req.requestRemotePort(),
	}
	var header = proxyproto.Header{
		Version:           byte(proxyProtocol.Version),
		Command:           proxyproto.PROXY,
		TransportProtocol: protocol,
		SourceAddr:        srcAddr,
		DestinationAddr:   dstAddr,
	}

	if _, err := header.WriteTo(conn); err != nil {
		_ = conn.Close()
		return err
	}
	return nil
}