大幅提升域名匹配性能

This commit is contained in:
刘祥超
2021-11-15 16:57:18 +08:00
parent 5a6ead1dd7
commit 04271d77c2
7 changed files with 66 additions and 107 deletions

View File

@@ -40,7 +40,7 @@ func (this *HTTPRequest) doReverseProxy() {
requestCall.CallResponseCallbacks(this.writer) requestCall.CallResponseCallbacks(this.writer)
if origin == nil { if origin == nil {
err := errors.New(this.requestFullURL() + ": no available origin sites for reverse proxy") err := errors.New(this.requestFullURL() + ": no available origin sites for reverse proxy")
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error()) remotelogs.ServerError(this.Server.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error())
this.write50x(err, http.StatusBadGateway) this.write50x(err, http.StatusBadGateway)
return return
} }

View File

@@ -5,17 +5,13 @@ import (
"errors" "errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"sync"
) )
type BaseListener struct { type BaseListener struct {
serversLocker sync.RWMutex
namedServersLocker sync.RWMutex
namedServers map[string]*NamedServer // 域名 => server
Group *serverconfigs.ServerAddressGroup Group *serverconfigs.ServerAddressGroup
countActiveConnections int64 // 当前活跃的连接数 countActiveConnections int64 // 当前活跃的连接数
@@ -23,14 +19,11 @@ type BaseListener struct {
// Init 初始化 // Init 初始化
func (this *BaseListener) Init() { func (this *BaseListener) Init() {
this.namedServers = map[string]*NamedServer{}
} }
// Reset 清除既有配置 // Reset 清除既有配置
func (this *BaseListener) Reset() { func (this *BaseListener) Reset() {
this.namedServersLocker.Lock()
this.namedServers = map[string]*NamedServer{}
this.namedServersLocker.Unlock()
} }
// CountActiveListeners 获取当前活跃连接数 // CountActiveListeners 获取当前活跃连接数
@@ -67,7 +60,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
return nil, err return nil, err
} }
if cert == nil { if cert == nil {
return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'") return nil, errors.New("no ssl certs found for '" + info.ServerName + "'")
} }
return cert, nil return cert, nil
}, },
@@ -83,7 +76,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
return nil, err return nil, err
} }
if cert == nil { if cert == nil {
return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'") return nil, errors.New("no ssl certs found for '" + info.ServerName + "'")
} }
return cert, nil return cert, nil
}, },
@@ -92,9 +85,6 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
// 根据域名匹配证书 // 根据域名匹配证书
func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) { func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
this.serversLocker.RLock()
defer this.serversLocker.RUnlock()
group := this.Group group := this.Group
if group == nil { if group == nil {
@@ -108,9 +98,9 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
return nil, nil, errors.New("no tls server name matched") return nil, nil, errors.New("no tls server name matched")
} }
firstServer := group.FirstServer() firstServer := group.FirstTLSServer()
if firstServer == nil { if firstServer == nil {
return nil, nil, errors.New("no server available") return nil, nil, errors.New("no tls server available")
} }
sslConfig := firstServer.SSLPolicy() sslConfig := firstServer.SSLPolicy()
@@ -119,14 +109,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
} }
return nil, nil, errors.New("no tls server name found") return nil, nil, errors.New("no tls server name found")
} }
// 通过代理服务域名配置匹配 // 通过代理服务域名配置匹配
server, _ := this.findNamedServer(domain) server, _ := this.findNamedServer(domain)
if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn { if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
// 搜索所有的Server通过SSL证书内容中的DNSName匹配 // 找不到或者此时的服务没有配置证书,需要搜索所有的Server通过SSL证书内容中的DNSName匹配
for _, server := range group.Servers { // TODO 需要思考这种情况下是否允许访问
for _, server := range group.Servers() {
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn { if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
continue continue
} }
@@ -136,7 +126,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
} }
} }
return nil, nil, errors.New("[proxy]no server found for '" + domain + "'") return nil, nil, errors.New("no server found for '" + domain + "'")
} }
// 证书是否匹配 // 证书是否匹配
@@ -146,6 +136,10 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
return sslConfig, cert, nil return sslConfig, cert, nil
} }
if len(sslConfig.Certs) == 0 {
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+domain+"', server id: "+types.String(server.Id))
}
return sslConfig, sslConfig.FirstCert(), nil return sslConfig, sslConfig.FirstCert(), nil
} }
@@ -173,11 +167,8 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
} }
// 如果没有找到,则匹配到第一个 // 如果没有找到,则匹配到第一个
this.serversLocker.RLock()
defer this.serversLocker.RUnlock()
group := this.Group group := this.Group
currentServers := group.Servers currentServers := group.Servers()
countServers := len(currentServers) countServers := len(currentServers)
if countServers == 0 { if countServers == 0 {
return nil, "" return nil, ""
@@ -192,71 +183,27 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
return nil, "" return nil, ""
} }
// 读取缓存 server := group.MatchServerName(name)
this.namedServersLocker.RLock() if server != nil {
namedServer, found := this.namedServers[name] return server, name
if found {
this.namedServersLocker.RUnlock()
return namedServer.Server, namedServer.Name
} }
this.namedServersLocker.RUnlock()
this.serversLocker.RLock()
defer this.serversLocker.RUnlock()
currentServers := group.Servers
countServers := len(currentServers)
if countServers == 0 {
return nil, ""
}
// 只记录N个记录防止内存耗尽
maxNamedServers := 100_0000
// 是否严格匹配域名 // 是否严格匹配域名
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
// 如果只有一个server则默认为这个 // 如果只有一个server则默认为这个
var currentServers = group.Servers()
var countServers = len(currentServers)
if countServers == 1 && !matchDomainStrictly { if countServers == 1 && !matchDomainStrictly {
return currentServers[0], name return currentServers[0], name
} }
// 精确查找
for _, server := range currentServers {
if server.MatchNameStrictly(name) {
this.namedServersLocker.Lock()
if len(this.namedServers) < maxNamedServers {
this.namedServers[name] = &NamedServer{
Name: name,
Server: server,
}
}
this.namedServersLocker.Unlock()
return server, name
}
}
// 模糊查找
for _, server := range currentServers {
if matched := server.MatchName(name); matched {
this.namedServersLocker.Lock()
if len(this.namedServers) < maxNamedServers {
this.namedServers[name] = &NamedServer{
Name: name,
Server: server,
}
}
this.namedServersLocker.Unlock()
return server, name
}
}
return nil, name return nil, name
} }
// 使用CNAME来查找服务 // 使用CNAME来查找服务
// TODO 防止单IP随机生成域名攻击 // TODO 防止单IP随机生成域名攻击
func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.ServerConfig { func (this *BaseListener) findServerWithCNAME(domain string) *serverconfigs.ServerConfig {
if !sharedNodeConfig.SupportCNAME { if !sharedNodeConfig.SupportCNAME {
return nil return nil
} }
@@ -266,25 +213,10 @@ func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.Serv
return nil return nil
} }
this.serversLocker.Lock()
defer this.serversLocker.Unlock()
group := this.Group group := this.Group
if group == nil { if group == nil {
return nil return nil
} }
currentServers := group.Servers return group.MatchServerCNAME(realName)
countServers := len(currentServers)
if countServers == 0 {
return nil
}
for _, server := range currentServers {
if server.SupportCNAME && lists.ContainsString(server.AliasServerNames, realName) {
return server
}
}
return nil
} }

View File

@@ -0,0 +1,36 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/types"
"testing"
"time"
)
func TestBaseListener_FindServer(t *testing.T) {
sharedNodeConfig = &nodeconfigs.NodeConfig{}
var listener = &BaseListener{namedServers: map[string]*NamedServer{}}
listener.Group = &serverconfigs.ServerAddressGroup{}
for i := 0; i < 1_000_000; i++ {
var server = &serverconfigs.ServerConfig{
IsOn: true,
Name: types.String(i) + ".hello.com",
ServerNames: []*serverconfigs.ServerNameConfig{
{Name: types.String(i) + ".hello.com"},
},
}
_ = server.Init()
listener.Group.Servers = append(listener.Group.Servers, server)
}
var before = time.Now()
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
t.Log(listener.findNamedServerMatched("855555.hello.com"))
}

View File

@@ -152,7 +152,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
server, serverName := this.findNamedServer(domain) server, serverName := this.findNamedServer(domain)
if server == nil { if server == nil {
server = this.findServerWithCname(domain) server = this.findServerWithCNAME(domain)
if server == nil { if server == nil {
// 严格匹配域名模式下,我们拒绝用户访问 // 严格匹配域名模式下,我们拒绝用户访问
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly { if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {

View File

@@ -75,7 +75,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId()) stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
} }
originConn, err := this.connectOrigin(firstServer.ReverseProxy, conn.RemoteAddr().String()) originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String())
if err != nil { if err != nil {
return err return err
} }
@@ -164,7 +164,7 @@ func (this *TCPListener) Close() error {
return this.Listener.Close() return this.Listener.Close()
} }
func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
if reverseProxy == nil { if reverseProxy == nil {
return nil, errors.New("no reverse proxy config") return nil, errors.New("no reverse proxy config")
} }
@@ -177,7 +177,7 @@ func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
} }
conn, err = OriginConnect(origin, remoteAddr) conn, err = OriginConnect(origin, remoteAddr)
if err != nil { if err != nil {
remotelogs.Error("TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error()) remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
continue continue
} else { } else {
return return

View File

@@ -55,7 +55,7 @@ func (this *UDPListener) Serve() error {
ok = false ok = false
} }
if !ok { if !ok {
originConn, err := this.connectOrigin(this.reverseProxy, addr) originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, addr)
if err != nil { if err != nil {
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error()) remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
continue continue
@@ -101,7 +101,7 @@ func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.reverseProxy = firstServer.ReverseProxy this.reverseProxy = firstServer.ReverseProxy
} }
func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) { func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
if reverseProxy == nil { if reverseProxy == nil {
return nil, errors.New("no reverse proxy config") return nil, errors.New("no reverse proxy config")
} }
@@ -114,7 +114,7 @@ func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
} }
conn, err = OriginConnect(origin, remoteAddr.String()) conn, err = OriginConnect(origin, remoteAddr.String())
if err != nil { if err != nil {
remotelogs.Error("UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error()) remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
continue continue
} else { } else {
// PROXY Protocol // PROXY Protocol

View File

@@ -1,9 +0,0 @@
package nodes
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
// 域名和服务映射
type NamedServer struct {
Name string // 匹配后的域名
Server *serverconfigs.ServerConfig // 匹配后的服务配置
}