mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2026-01-04 22:55:48 +08:00
当SNI无法读取到ServerName时,尝试使用节点IP搜索网站
This commit is contained in:
@@ -46,7 +46,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
}
|
||||
}
|
||||
|
||||
tlsPolicy, _, err := this.matchSSL(this.helloServerName(clientInfo))
|
||||
tlsPolicy, _, err := this.matchSSL(this.helloServerNames(clientInfo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -69,7 +69,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
}
|
||||
}
|
||||
|
||||
tlsPolicy, cert, err := this.matchSSL(this.helloServerName(clientInfo))
|
||||
tlsPolicy, cert, err := this.matchSSL(this.helloServerNames(clientInfo))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -85,7 +85,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
}
|
||||
|
||||
// 根据域名匹配证书
|
||||
func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
|
||||
func (this *BaseListener) matchSSL(domains []string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
|
||||
var group = this.Group
|
||||
|
||||
if group == nil {
|
||||
@@ -99,7 +99,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
|
||||
// 如果域名为空,则取第一个
|
||||
// 通常域名为空是因为是直接通过IP访问的
|
||||
if len(domain) == 0 {
|
||||
if len(domains) == 0 {
|
||||
if group.IsHTTPS() && globalServerConfig != nil && globalServerConfig.HTTPAll.MatchDomainStrictly {
|
||||
return nil, nil, errors.New("no tls server name matched")
|
||||
}
|
||||
@@ -116,9 +116,25 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
}
|
||||
return nil, nil, errors.New("no tls server name found")
|
||||
}
|
||||
var firstDomain = domains[0]
|
||||
|
||||
// 通过网站域名配置匹配
|
||||
server, _ := this.findNamedServer(domain)
|
||||
var server *serverconfigs.ServerConfig
|
||||
var matchedDomain string
|
||||
for _, domain := range domains {
|
||||
server, _ = this.findNamedServer(domain, true)
|
||||
if server != nil {
|
||||
matchedDomain = domain
|
||||
break
|
||||
}
|
||||
}
|
||||
if server == nil {
|
||||
server, _ = this.findNamedServer(firstDomain, false)
|
||||
if server != nil {
|
||||
matchedDomain = firstDomain
|
||||
}
|
||||
}
|
||||
|
||||
if server == nil {
|
||||
// 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配
|
||||
// 此功能仅为了兼容以往版本(v1.0.4),不应该作为常态启用
|
||||
@@ -127,14 +143,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
|
||||
continue
|
||||
}
|
||||
cert, ok := searchingServer.SSLPolicy().MatchDomain(domain)
|
||||
cert, ok := searchingServer.SSLPolicy().MatchDomain(firstDomain)
|
||||
if ok {
|
||||
return searchingServer.SSLPolicy(), cert, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("no server found for '" + domain + "'")
|
||||
return nil, nil, errors.New("no server found for '" + firstDomain + "'")
|
||||
}
|
||||
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
|
||||
// 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配
|
||||
@@ -144,32 +160,32 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
if searchingServer.SSLPolicy() == nil || !searchingServer.SSLPolicy().IsOn {
|
||||
continue
|
||||
}
|
||||
cert, ok := searchingServer.SSLPolicy().MatchDomain(domain)
|
||||
cert, ok := searchingServer.SSLPolicy().MatchDomain(matchedDomain)
|
||||
if ok {
|
||||
return searchingServer.SSLPolicy(), cert, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("no cert found for '" + domain + "'")
|
||||
return nil, nil, errors.New("no cert found for '" + matchedDomain + "'")
|
||||
}
|
||||
|
||||
// 证书是否匹配
|
||||
var sslConfig = server.SSLPolicy()
|
||||
cert, ok := sslConfig.MatchDomain(domain)
|
||||
cert, ok := sslConfig.MatchDomain(matchedDomain)
|
||||
if ok {
|
||||
return sslConfig, cert, nil
|
||||
}
|
||||
|
||||
if len(sslConfig.Certs) == 0 {
|
||||
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+domain+"', server id: "+types.String(server.Id), "", nil)
|
||||
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+matchedDomain+"', server id: "+types.String(server.Id), "", nil)
|
||||
}
|
||||
|
||||
return sslConfig, sslConfig.FirstCert(), nil
|
||||
}
|
||||
|
||||
// 根据域名来查找匹配的域名
|
||||
func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
|
||||
func (this *BaseListener) findNamedServer(name string, exactly bool) (serverConfig *serverconfigs.ServerConfig, serverName string) {
|
||||
serverConfig, serverName = this.findNamedServerMatched(name)
|
||||
if serverConfig != nil {
|
||||
return
|
||||
@@ -198,14 +214,18 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
|
||||
return
|
||||
}
|
||||
|
||||
// 如果没有找到,则匹配到第一个
|
||||
var group = this.Group
|
||||
var currentServers = group.Servers()
|
||||
var countServers = len(currentServers)
|
||||
if countServers == 0 {
|
||||
return nil, ""
|
||||
if !exactly {
|
||||
// 如果没有找到,则匹配到第一个
|
||||
var group = this.Group
|
||||
var currentServers = group.Servers()
|
||||
var countServers = len(currentServers)
|
||||
if countServers == 0 {
|
||||
return nil, ""
|
||||
}
|
||||
return currentServers[0], name
|
||||
}
|
||||
return currentServers[0], name
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 严格查找域名
|
||||
@@ -234,16 +254,23 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
|
||||
}
|
||||
|
||||
// 从Hello信息中获取服务名称
|
||||
func (this *BaseListener) helloServerName(clientInfo *tls.ClientHelloInfo) string {
|
||||
var serverName = clientInfo.ServerName
|
||||
if len(serverName) == 0 && clientInfo.Conn != nil {
|
||||
func (this *BaseListener) helloServerNames(clientInfo *tls.ClientHelloInfo) (serverNames []string) {
|
||||
if len(clientInfo.ServerName) != 0 {
|
||||
serverNames = append(serverNames, clientInfo.ServerName)
|
||||
return
|
||||
}
|
||||
|
||||
if clientInfo.Conn != nil {
|
||||
var localAddr = clientInfo.Conn.LocalAddr()
|
||||
if localAddr != nil {
|
||||
tcpAddr, ok := localAddr.(*net.TCPAddr)
|
||||
if ok {
|
||||
serverName = tcpAddr.IP.String()
|
||||
serverNames = append(serverNames, tcpAddr.IP.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
return serverName
|
||||
|
||||
serverNames = append(serverNames, sharedNodeConfig.IPAddresses...)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
||||
domain = reqHost
|
||||
}
|
||||
|
||||
server, serverName := this.findNamedServer(domain)
|
||||
server, serverName := this.findNamedServer(domain, false)
|
||||
if server == nil {
|
||||
if server == nil {
|
||||
// 增加默认的一个服务
|
||||
|
||||
Reference in New Issue
Block a user