mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-30 04:10:25 +08:00
大幅提升域名匹配性能
This commit is contained in:
@@ -40,7 +40,7 @@ func (this *HTTPRequest) doReverseProxy() {
|
||||
requestCall.CallResponseCallbacks(this.writer)
|
||||
if origin == nil {
|
||||
err := errors.New(this.requestFullURL() + ": no available origin sites for reverse proxy")
|
||||
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
||||
remotelogs.ServerError(this.Server.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
||||
this.write50x(err, http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -5,17 +5,13 @@ import (
|
||||
"errors"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
|
||||
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"golang.org/x/net/http2"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type BaseListener struct {
|
||||
serversLocker sync.RWMutex
|
||||
namedServersLocker sync.RWMutex
|
||||
namedServers map[string]*NamedServer // 域名 => server
|
||||
|
||||
Group *serverconfigs.ServerAddressGroup
|
||||
|
||||
countActiveConnections int64 // 当前活跃的连接数
|
||||
@@ -23,14 +19,11 @@ type BaseListener struct {
|
||||
|
||||
// Init 初始化
|
||||
func (this *BaseListener) Init() {
|
||||
this.namedServers = map[string]*NamedServer{}
|
||||
}
|
||||
|
||||
// Reset 清除既有配置
|
||||
func (this *BaseListener) Reset() {
|
||||
this.namedServersLocker.Lock()
|
||||
this.namedServers = map[string]*NamedServer{}
|
||||
this.namedServersLocker.Unlock()
|
||||
|
||||
}
|
||||
|
||||
// CountActiveListeners 获取当前活跃连接数
|
||||
@@ -67,7 +60,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
return nil, err
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'")
|
||||
return nil, errors.New("no ssl certs found for '" + info.ServerName + "'")
|
||||
}
|
||||
return cert, nil
|
||||
},
|
||||
@@ -83,7 +76,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
return nil, err
|
||||
}
|
||||
if cert == nil {
|
||||
return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'")
|
||||
return nil, errors.New("no ssl certs found for '" + info.ServerName + "'")
|
||||
}
|
||||
return cert, nil
|
||||
},
|
||||
@@ -92,9 +85,6 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
||||
|
||||
// 根据域名匹配证书
|
||||
func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
|
||||
this.serversLocker.RLock()
|
||||
defer this.serversLocker.RUnlock()
|
||||
|
||||
group := this.Group
|
||||
|
||||
if group == nil {
|
||||
@@ -108,9 +98,9 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
return nil, nil, errors.New("no tls server name matched")
|
||||
}
|
||||
|
||||
firstServer := group.FirstServer()
|
||||
firstServer := group.FirstTLSServer()
|
||||
if firstServer == nil {
|
||||
return nil, nil, errors.New("no server available")
|
||||
return nil, nil, errors.New("no tls server available")
|
||||
}
|
||||
sslConfig := firstServer.SSLPolicy()
|
||||
|
||||
@@ -119,14 +109,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
|
||||
}
|
||||
return nil, nil, errors.New("no tls server name found")
|
||||
|
||||
}
|
||||
|
||||
// 通过代理服务域名配置匹配
|
||||
server, _ := this.findNamedServer(domain)
|
||||
if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
|
||||
// 搜索所有的Server,通过SSL证书内容中的DNSName匹配
|
||||
for _, server := range group.Servers {
|
||||
// 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配
|
||||
// TODO 需要思考这种情况下是否允许访问
|
||||
for _, server := range group.Servers() {
|
||||
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
|
||||
continue
|
||||
}
|
||||
@@ -136,7 +126,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil, errors.New("[proxy]no server found for '" + domain + "'")
|
||||
return nil, nil, errors.New("no server found for '" + domain + "'")
|
||||
}
|
||||
|
||||
// 证书是否匹配
|
||||
@@ -146,6 +136,10 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
||||
return sslConfig, cert, nil
|
||||
}
|
||||
|
||||
if len(sslConfig.Certs) == 0 {
|
||||
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+domain+"', server id: "+types.String(server.Id))
|
||||
}
|
||||
|
||||
return sslConfig, sslConfig.FirstCert(), nil
|
||||
}
|
||||
|
||||
@@ -173,11 +167,8 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
|
||||
}
|
||||
|
||||
// 如果没有找到,则匹配到第一个
|
||||
this.serversLocker.RLock()
|
||||
defer this.serversLocker.RUnlock()
|
||||
|
||||
group := this.Group
|
||||
currentServers := group.Servers
|
||||
currentServers := group.Servers()
|
||||
countServers := len(currentServers)
|
||||
if countServers == 0 {
|
||||
return nil, ""
|
||||
@@ -192,71 +183,27 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// 读取缓存
|
||||
this.namedServersLocker.RLock()
|
||||
namedServer, found := this.namedServers[name]
|
||||
if found {
|
||||
this.namedServersLocker.RUnlock()
|
||||
return namedServer.Server, namedServer.Name
|
||||
server := group.MatchServerName(name)
|
||||
if server != nil {
|
||||
return server, name
|
||||
}
|
||||
this.namedServersLocker.RUnlock()
|
||||
|
||||
this.serversLocker.RLock()
|
||||
defer this.serversLocker.RUnlock()
|
||||
|
||||
currentServers := group.Servers
|
||||
countServers := len(currentServers)
|
||||
if countServers == 0 {
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
// 只记录N个记录,防止内存耗尽
|
||||
maxNamedServers := 100_0000
|
||||
|
||||
// 是否严格匹配域名
|
||||
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
|
||||
|
||||
// 如果只有一个server,则默认为这个
|
||||
var currentServers = group.Servers()
|
||||
var countServers = len(currentServers)
|
||||
if countServers == 1 && !matchDomainStrictly {
|
||||
return currentServers[0], name
|
||||
}
|
||||
|
||||
// 精确查找
|
||||
for _, server := range currentServers {
|
||||
if server.MatchNameStrictly(name) {
|
||||
this.namedServersLocker.Lock()
|
||||
if len(this.namedServers) < maxNamedServers {
|
||||
this.namedServers[name] = &NamedServer{
|
||||
Name: name,
|
||||
Server: server,
|
||||
}
|
||||
}
|
||||
this.namedServersLocker.Unlock()
|
||||
return server, name
|
||||
}
|
||||
}
|
||||
|
||||
// 模糊查找
|
||||
for _, server := range currentServers {
|
||||
if matched := server.MatchName(name); matched {
|
||||
this.namedServersLocker.Lock()
|
||||
if len(this.namedServers) < maxNamedServers {
|
||||
this.namedServers[name] = &NamedServer{
|
||||
Name: name,
|
||||
Server: server,
|
||||
}
|
||||
}
|
||||
this.namedServersLocker.Unlock()
|
||||
return server, name
|
||||
}
|
||||
}
|
||||
|
||||
return nil, name
|
||||
}
|
||||
|
||||
// 使用CNAME来查找服务
|
||||
// TODO 防止单IP随机生成域名攻击
|
||||
func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.ServerConfig {
|
||||
func (this *BaseListener) findServerWithCNAME(domain string) *serverconfigs.ServerConfig {
|
||||
if !sharedNodeConfig.SupportCNAME {
|
||||
return nil
|
||||
}
|
||||
@@ -266,25 +213,10 @@ func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.Serv
|
||||
return nil
|
||||
}
|
||||
|
||||
this.serversLocker.Lock()
|
||||
defer this.serversLocker.Unlock()
|
||||
|
||||
group := this.Group
|
||||
if group == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
currentServers := group.Servers
|
||||
countServers := len(currentServers)
|
||||
if countServers == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, server := range currentServers {
|
||||
if server.SupportCNAME && lists.ContainsString(server.AliasServerNames, realName) {
|
||||
return server
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
return group.MatchServerCNAME(realName)
|
||||
}
|
||||
|
||||
36
internal/nodes/listener_base_test.go
Normal file
36
internal/nodes/listener_base_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package nodes
|
||||
|
||||
import (
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBaseListener_FindServer(t *testing.T) {
|
||||
sharedNodeConfig = &nodeconfigs.NodeConfig{}
|
||||
|
||||
var listener = &BaseListener{namedServers: map[string]*NamedServer{}}
|
||||
listener.Group = &serverconfigs.ServerAddressGroup{}
|
||||
for i := 0; i < 1_000_000; i++ {
|
||||
var server = &serverconfigs.ServerConfig{
|
||||
IsOn: true,
|
||||
Name: types.String(i) + ".hello.com",
|
||||
ServerNames: []*serverconfigs.ServerNameConfig{
|
||||
{Name: types.String(i) + ".hello.com"},
|
||||
},
|
||||
}
|
||||
_ = server.Init()
|
||||
listener.Group.Servers = append(listener.Group.Servers, server)
|
||||
}
|
||||
|
||||
var before = time.Now()
|
||||
defer func() {
|
||||
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||
}()
|
||||
|
||||
t.Log(listener.findNamedServerMatched("855555.hello.com"))
|
||||
}
|
||||
@@ -152,7 +152,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
||||
|
||||
server, serverName := this.findNamedServer(domain)
|
||||
if server == nil {
|
||||
server = this.findServerWithCname(domain)
|
||||
server = this.findServerWithCNAME(domain)
|
||||
if server == nil {
|
||||
// 严格匹配域名模式下,我们拒绝用户访问
|
||||
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
|
||||
|
||||
@@ -75,7 +75,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
||||
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
|
||||
}
|
||||
|
||||
originConn, err := this.connectOrigin(firstServer.ReverseProxy, conn.RemoteAddr().String())
|
||||
originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -164,7 +164,7 @@ func (this *TCPListener) Close() error {
|
||||
return this.Listener.Close()
|
||||
}
|
||||
|
||||
func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||
func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
@@ -177,7 +177,7 @@ func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
|
||||
}
|
||||
conn, err = OriginConnect(origin, remoteAddr)
|
||||
if err != nil {
|
||||
remotelogs.Error("TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||
continue
|
||||
} else {
|
||||
return
|
||||
|
||||
@@ -55,7 +55,7 @@ func (this *UDPListener) Serve() error {
|
||||
ok = false
|
||||
}
|
||||
if !ok {
|
||||
originConn, err := this.connectOrigin(this.reverseProxy, addr)
|
||||
originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, addr)
|
||||
if err != nil {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
|
||||
continue
|
||||
@@ -101,7 +101,7 @@ func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
||||
this.reverseProxy = firstServer.ReverseProxy
|
||||
}
|
||||
|
||||
func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
|
||||
func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
|
||||
if reverseProxy == nil {
|
||||
return nil, errors.New("no reverse proxy config")
|
||||
}
|
||||
@@ -114,7 +114,7 @@ func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
|
||||
}
|
||||
conn, err = OriginConnect(origin, remoteAddr.String())
|
||||
if err != nil {
|
||||
remotelogs.Error("UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||
continue
|
||||
} else {
|
||||
// PROXY Protocol
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
package nodes
|
||||
|
||||
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||
|
||||
// 域名和服务映射
|
||||
type NamedServer struct {
|
||||
Name string // 匹配后的域名
|
||||
Server *serverconfigs.ServerConfig // 匹配后的服务配置
|
||||
}
|
||||
Reference in New Issue
Block a user