TLS支持默认SNI回源

This commit is contained in:
刘祥超
2022-08-03 11:04:58 +08:00
parent 8cfde43f5d
commit 9ad1c3a3c8

View File

@@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"errors" "errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
"github.com/TeaOSLab/EdgeNode/internal/goman" "github.com/TeaOSLab/EdgeNode/internal/goman"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/TeaOSLab/EdgeNode/internal/stats" "github.com/TeaOSLab/EdgeNode/internal/stats"
@@ -108,8 +109,9 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
// 记录域名排行 // 记录域名排行
tlsConn, ok := conn.(*tls.Conn) tlsConn, ok := conn.(*tls.Conn)
var recordStat = false var recordStat = false
var serverName = ""
if ok { if ok {
var serverName = tlsConn.ConnectionState().ServerName serverName = tlsConn.ConnectionState().ServerName
if len(serverName) > 0 { if len(serverName) > 0 {
// 统计 // 统计
stats.SharedTrafficStatManager.Add(server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) stats.SharedTrafficStatManager.Add(server.Id, serverName, 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
@@ -122,7 +124,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId()) stats.SharedTrafficStatManager.Add(server.Id, "", 0, 0, 1, 0, 0, 0, server.ShouldCheckTrafficLimit(), server.PlanId())
} }
originConn, err := this.connectOrigin(server.Id, server.ReverseProxy, conn.RemoteAddr().String()) originConn, err := this.connectOrigin(server.Id, serverName, server.ReverseProxy, conn.RemoteAddr().String())
if err != nil { if err != nil {
_ = conn.Close() _ = conn.Close()
return err return err
@@ -219,26 +221,27 @@ func (this *TCPListener) Close() error {
} }
// 连接源站 // 连接源站
func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { func (this *TCPListener) connectOrigin(serverId int64, requestHost string, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
if reverseProxy == nil { if reverseProxy == nil {
return nil, errors.New("no reverse proxy config") return nil, errors.New("no reverse proxy config")
} }
var requestCall = shared.NewRequestCall()
requestCall.Domain = requestHost
var retries = 3 var retries = 3
var addr string var addr string
for i := 0; i < retries; i++ { for i := 0; i < retries; i++ {
var origin = reverseProxy.NextOrigin(nil) var origin = reverseProxy.NextOrigin(requestCall)
if origin == nil { if origin == nil {
continue continue
} }
// 回源主机名 // 回源主机名
var requestHost = ""
if len(reverseProxy.RequestHost) > 0 {
requestHost = reverseProxy.RequestHost
}
if len(origin.RequestHost) > 0 { if len(origin.RequestHost) > 0 {
requestHost = origin.RequestHost requestHost = origin.RequestHost
} else if len(reverseProxy.RequestHost) > 0 {
requestHost = reverseProxy.RequestHost
} }
conn, addr, err = OriginConnect(origin, this.port, remoteAddr, requestHost) conn, addr, err = OriginConnect(origin, this.port, remoteAddr, requestHost)