大幅提升域名匹配性能

This commit is contained in:
刘祥超
2021-11-15 16:57:18 +08:00
parent 5a6ead1dd7
commit 04271d77c2
7 changed files with 66 additions and 107 deletions

View File

@@ -40,7 +40,7 @@ func (this *HTTPRequest) doReverseProxy() {
requestCall.CallResponseCallbacks(this.writer)
if origin == nil {
err := errors.New(this.requestFullURL() + ": no available origin sites for reverse proxy")
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error())
remotelogs.ServerError(this.Server.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error())
this.write50x(err, http.StatusBadGateway)
return
}

View File

@@ -5,17 +5,13 @@ import (
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
"github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/types"
"golang.org/x/net/http2"
"sync"
)
type BaseListener struct {
serversLocker sync.RWMutex
namedServersLocker sync.RWMutex
namedServers map[string]*NamedServer // 域名 => server
Group *serverconfigs.ServerAddressGroup
countActiveConnections int64 // 当前活跃的连接数
@@ -23,14 +19,11 @@ type BaseListener struct {
// Init 初始化
func (this *BaseListener) Init() {
this.namedServers = map[string]*NamedServer{}
}
// Reset 清除既有配置
func (this *BaseListener) Reset() {
this.namedServersLocker.Lock()
this.namedServers = map[string]*NamedServer{}
this.namedServersLocker.Unlock()
}
// CountActiveListeners 获取当前活跃连接数
@@ -67,7 +60,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
return nil, err
}
if cert == nil {
return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'")
return nil, errors.New("no ssl certs found for '" + info.ServerName + "'")
}
return cert, nil
},
@@ -83,7 +76,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
return nil, err
}
if cert == nil {
return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'")
return nil, errors.New("no ssl certs found for '" + info.ServerName + "'")
}
return cert, nil
},
@@ -92,9 +85,6 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
// 根据域名匹配证书
func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
this.serversLocker.RLock()
defer this.serversLocker.RUnlock()
group := this.Group
if group == nil {
@@ -108,9 +98,9 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
return nil, nil, errors.New("no tls server name matched")
}
firstServer := group.FirstServer()
firstServer := group.FirstTLSServer()
if firstServer == nil {
return nil, nil, errors.New("no server available")
return nil, nil, errors.New("no tls server available")
}
sslConfig := firstServer.SSLPolicy()
@@ -119,14 +109,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
}
return nil, nil, errors.New("no tls server name found")
}
// 通过代理服务域名配置匹配
server, _ := this.findNamedServer(domain)
if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
// 搜索所有的Server通过SSL证书内容中的DNSName匹配
for _, server := range group.Servers {
// 找不到或者此时的服务没有配置证书,需要搜索所有的Server通过SSL证书内容中的DNSName匹配
// TODO 需要思考这种情况下是否允许访问
for _, server := range group.Servers() {
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
continue
}
@@ -136,7 +126,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
}
}
return nil, nil, errors.New("[proxy]no server found for '" + domain + "'")
return nil, nil, errors.New("no server found for '" + domain + "'")
}
// 证书是否匹配
@@ -146,6 +136,10 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
return sslConfig, cert, nil
}
if len(sslConfig.Certs) == 0 {
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+domain+"', server id: "+types.String(server.Id))
}
return sslConfig, sslConfig.FirstCert(), nil
}
@@ -173,11 +167,8 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
}
// 如果没有找到,则匹配到第一个
this.serversLocker.RLock()
defer this.serversLocker.RUnlock()
group := this.Group
currentServers := group.Servers
currentServers := group.Servers()
countServers := len(currentServers)
if countServers == 0 {
return nil, ""
@@ -192,71 +183,27 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
return nil, ""
}
// 读取缓存
this.namedServersLocker.RLock()
namedServer, found := this.namedServers[name]
if found {
this.namedServersLocker.RUnlock()
return namedServer.Server, namedServer.Name
server := group.MatchServerName(name)
if server != nil {
return server, name
}
this.namedServersLocker.RUnlock()
this.serversLocker.RLock()
defer this.serversLocker.RUnlock()
currentServers := group.Servers
countServers := len(currentServers)
if countServers == 0 {
return nil, ""
}
// 只记录N个记录防止内存耗尽
maxNamedServers := 100_0000
// 是否严格匹配域名
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
// 如果只有一个server则默认为这个
var currentServers = group.Servers()
var countServers = len(currentServers)
if countServers == 1 && !matchDomainStrictly {
return currentServers[0], name
}
// 精确查找
for _, server := range currentServers {
if server.MatchNameStrictly(name) {
this.namedServersLocker.Lock()
if len(this.namedServers) < maxNamedServers {
this.namedServers[name] = &NamedServer{
Name: name,
Server: server,
}
}
this.namedServersLocker.Unlock()
return server, name
}
}
// 模糊查找
for _, server := range currentServers {
if matched := server.MatchName(name); matched {
this.namedServersLocker.Lock()
if len(this.namedServers) < maxNamedServers {
this.namedServers[name] = &NamedServer{
Name: name,
Server: server,
}
}
this.namedServersLocker.Unlock()
return server, name
}
}
return nil, name
}
// 使用CNAME来查找服务
// TODO 防止单IP随机生成域名攻击
func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.ServerConfig {
func (this *BaseListener) findServerWithCNAME(domain string) *serverconfigs.ServerConfig {
if !sharedNodeConfig.SupportCNAME {
return nil
}
@@ -266,25 +213,10 @@ func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.Serv
return nil
}
this.serversLocker.Lock()
defer this.serversLocker.Unlock()
group := this.Group
if group == nil {
return nil
}
currentServers := group.Servers
countServers := len(currentServers)
if countServers == 0 {
return nil
}
for _, server := range currentServers {
if server.SupportCNAME && lists.ContainsString(server.AliasServerNames, realName) {
return server
}
}
return nil
return group.MatchServerCNAME(realName)
}

View File

@@ -0,0 +1,36 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package nodes
import (
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/types"
"testing"
"time"
)
func TestBaseListener_FindServer(t *testing.T) {
sharedNodeConfig = &nodeconfigs.NodeConfig{}
var listener = &BaseListener{namedServers: map[string]*NamedServer{}}
listener.Group = &serverconfigs.ServerAddressGroup{}
for i := 0; i < 1_000_000; i++ {
var server = &serverconfigs.ServerConfig{
IsOn: true,
Name: types.String(i) + ".hello.com",
ServerNames: []*serverconfigs.ServerNameConfig{
{Name: types.String(i) + ".hello.com"},
},
}
_ = server.Init()
listener.Group.Servers = append(listener.Group.Servers, server)
}
var before = time.Now()
defer func() {
t.Log(time.Since(before).Seconds()*1000, "ms")
}()
t.Log(listener.findNamedServerMatched("855555.hello.com"))
}

View File

@@ -152,7 +152,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
server, serverName := this.findNamedServer(domain)
if server == nil {
server = this.findServerWithCname(domain)
server = this.findServerWithCNAME(domain)
if server == nil {
// 严格匹配域名模式下,我们拒绝用户访问
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {

View File

@@ -75,7 +75,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
}
originConn, err := this.connectOrigin(firstServer.ReverseProxy, conn.RemoteAddr().String())
originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String())
if err != nil {
return err
}
@@ -164,7 +164,7 @@ func (this *TCPListener) Close() error {
return this.Listener.Close()
}
func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
if reverseProxy == nil {
return nil, errors.New("no reverse proxy config")
}
@@ -177,7 +177,7 @@ func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
}
conn, err = OriginConnect(origin, remoteAddr)
if err != nil {
remotelogs.Error("TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
continue
} else {
return

View File

@@ -55,7 +55,7 @@ func (this *UDPListener) Serve() error {
ok = false
}
if !ok {
originConn, err := this.connectOrigin(this.reverseProxy, addr)
originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, addr)
if err != nil {
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
continue
@@ -101,7 +101,7 @@ func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
this.reverseProxy = firstServer.ReverseProxy
}
func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
if reverseProxy == nil {
return nil, errors.New("no reverse proxy config")
}
@@ -114,7 +114,7 @@ func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
}
conn, err = OriginConnect(origin, remoteAddr.String())
if err != nil {
remotelogs.Error("UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
continue
} else {
// PROXY Protocol

View File

@@ -1,9 +0,0 @@
package nodes
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
// 域名和服务映射
type NamedServer struct {
Name string // 匹配后的域名
Server *serverconfigs.ServerConfig // 匹配后的服务配置
}