diff --git a/internal/nodes/listener_base.go b/internal/nodes/listener_base.go index 5305ac9..2f0ced9 100644 --- a/internal/nodes/listener_base.go +++ b/internal/nodes/listener_base.go @@ -46,7 +46,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config { } } - tlsPolicy, _, err := this.matchSSL(this.helloServerName(clientInfo)) + tlsPolicy, _, err := this.matchSSL(this.helloServerNames(clientInfo)) if err != nil { return nil, err } @@ -69,7 +69,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config { } } - tlsPolicy, cert, err := this.matchSSL(this.helloServerName(clientInfo)) + tlsPolicy, cert, err := this.matchSSL(this.helloServerNames(clientInfo)) if err != nil { return nil, err } @@ -85,7 +85,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config { } // 根据域名匹配证书 -func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) { +func (this *BaseListener) matchSSL(domains []string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) { var group = this.Group if group == nil { @@ -99,7 +99,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C // 如果域名为空,则取第一个 // 通常域名为空是因为是直接通过IP访问的 - if len(domain) == 0 { + if len(domains) == 0 { if group.IsHTTPS() && globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly { return nil, nil, errors.New("no tls server name matched") } @@ -116,9 +116,25 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C } return nil, nil, errors.New("no tls server name found") } + var firstDomain = domains[0] // 通过网站域名配置匹配 - server, _ := this.findNamedServer(domain) + var server *serverconfigs.ServerConfig + var matchedDomain string + for _, domain := range domains { + server, _ = this.findNamedServer(domain, true) + if server != nil { + matchedDomain = domain + break + } + } + if server == nil { + server, _ = this.findNamedServer(firstDomain, false) + if server != nil { + matchedDomain = firstDomain + } + } + if server == nil { // 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配 // 此功能仅为了兼容以往版本(v1.0.4),不应该作为常态启用 @@ -127,14 +143,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn { continue } - cert, ok := searchingServer.SSLPolicy().MatchDomain(domain) + cert, ok := searchingServer.SSLPolicy().MatchDomain(firstDomain) if ok { return searchingServer.SSLPolicy(), cert, nil } } } - return nil, nil, errors.New("no server found for '" + domain + "'") + return nil, nil, errors.New("no server found for '" + firstDomain + "'") } if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn { // 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配 @@ -144,32 +160,32 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn { continue } - cert, ok := searchingServer.SSLPolicy().MatchDomain(domain) + cert, ok := searchingServer.SSLPolicy().MatchDomain(matchedDomain) if ok { return searchingServer.SSLPolicy(), cert, nil } } } - return nil, nil, errors.New("no cert found for '" + domain + "'") + return nil, nil, errors.New("no cert found for '" + matchedDomain + "'") } // 证书是否匹配 var sslConfig = server.SSLPolicy() - cert, ok := sslConfig.MatchDomain(domain) + cert, ok := sslConfig.MatchDomain(matchedDomain) if ok { return sslConfig, cert, nil } if len(sslConfig.Certs) == 0 { - remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+domain+"', server id: "+types.String(server.Id), "", nil) + remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+matchedDomain+"', server id: "+types.String(server.Id), "", nil) } return sslConfig, sslConfig.FirstCert(), nil } // 根据域名来查找匹配的域名 -func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) { +func (this *BaseListener) findNamedServer(name string, exactly bool) (serverConfig *serverconfigs.ServerConfig, serverName string) { serverConfig, serverName = this.findNamedServerMatched(name) if serverConfig != nil { return @@ -198,14 +214,18 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf return } - // 如果没有找到,则匹配到第一个 - var group = this.Group - var currentServers = group.Servers() - var countServers = len(currentServers) - if countServers == 0 { - return nil, "" + if !exactly { + // 如果没有找到,则匹配到第一个 + var group = this.Group + var currentServers = group.Servers() + var countServers = len(currentServers) + if countServers == 0 { + return nil, "" + } + return currentServers[0], name } - return currentServers[0], name + + return } // 严格查找域名 @@ -234,16 +254,23 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser } // 从Hello信息中获取服务名称 -func (this *BaseListener) helloServerName(clientInfo *tls.ClientHelloInfo) string { - var serverName = clientInfo.ServerName - if len(serverName) == 0 && clientInfo.Conn != nil { +func (this *BaseListener) helloServerNames(clientInfo *tls.ClientHelloInfo) (serverNames []string) { + if len(clientInfo.ServerName) != 0 { + serverNames = append(serverNames, clientInfo.ServerName) + return + } + + if clientInfo.Conn != nil { var localAddr = clientInfo.Conn.LocalAddr() if localAddr != nil { tcpAddr, ok := localAddr.(*net.TCPAddr) if ok { - serverName = tcpAddr.IP.String() + serverNames = append(serverNames, tcpAddr.IP.String()) } } } - return serverName + + serverNames = append(serverNames, sharedNodeConfig.IPAddresses...) + + return } diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index 92a6f59..a77d54a 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -162,7 +162,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http. domain = reqHost } - server, serverName := this.findNamedServer(domain) + server, serverName := this.findNamedServer(domain, false) if server == nil { if server == nil { // 增加默认的一个服务