From 9ad1c3a3c8c9bed8630fb6e741b490ed0b46e7f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Wed, 3 Aug 2022 11:04:58 +0800 Subject: [PATCH] =?UTF-8?q?TLS=E6=94=AF=E6=8C=81=E9=BB=98=E8=AE=A4SNI?= =?UTF-8?q?=E5=9B=9E=E6=BA=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/nodes/listener_tcp.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index 38114ae..fbfbb54 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" "github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/stats" @@ -108,8 +109,9 @@ func (this *TCPListener) handleConn(conn net.Conn) error { // 记录域名排行 tlsConn, ok := conn.(*tls.Conn) var recordStat = false + var serverName = "" if ok { - var serverName = tlsConn.ConnectionState().ServerName + serverName = tlsConn.ConnectionState().ServerName if len(serverName) > 0 { // 统计 stats.SharedTrafficStatManager.Add(server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) @@ -122,7 +124,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error { stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) } - originConn, err := this.connectOrigin(server.Id, server.ReverseProxy, conn.RemoteAddr().String()) + originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String()) if err != nil { _ = conn.Close() return err @@ -219,26 +221,27 @@ func (this *TCPListener) Close() error { } // 连接源站 -func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { +func (this *TCPListener) connectOrigin(serverId int64, requestHost string, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { if reverseProxy == nil { return nil, errors.New("no reverse proxy config") } + var requestCall = shared.NewRequestCall() + requestCall.Domain = requestHost + var retries = 3 var addr string for i := 0; i < retries; i++ { - var origin = reverseProxy.NextOrigin(nil) + var origin = reverseProxy.NextOrigin(requestCall) if origin == nil { continue } // 回源主机名 - var requestHost = "" - if len(reverseProxy.RequestHost) > 0 { - requestHost = reverseProxy.RequestHost - } if len(origin.RequestHost) > 0 { requestHost = origin.RequestHost + } else if len(reverseProxy.RequestHost) > 0 { + requestHost = reverseProxy.RequestHost } conn, addr, err = OriginConnect(origin, this.port, remoteAddr, requestHost)