mirror of
https://github.com/TeaOSLab/EdgeNode.git
synced 2025-11-30 04:10:25 +08:00
大幅提升域名匹配性能
This commit is contained in:
@@ -40,7 +40,7 @@ func (this *HTTPRequest) doReverseProxy() {
|
|||||||
requestCall.CallResponseCallbacks(this.writer)
|
requestCall.CallResponseCallbacks(this.writer)
|
||||||
if origin == nil {
|
if origin == nil {
|
||||||
err := errors.New(this.requestFullURL() + ": no available origin sites for reverse proxy")
|
err := errors.New(this.requestFullURL() + ": no available origin sites for reverse proxy")
|
||||||
remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
remotelogs.ServerError(this.Server.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error())
|
||||||
this.write50x(err, http.StatusBadGateway)
|
this.write50x(err, http.StatusBadGateway)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,17 +5,13 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||||
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs"
|
||||||
|
"github.com/TeaOSLab/EdgeNode/internal/remotelogs"
|
||||||
"github.com/iwind/TeaGo/lists"
|
"github.com/iwind/TeaGo/lists"
|
||||||
"github.com/iwind/TeaGo/types"
|
"github.com/iwind/TeaGo/types"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type BaseListener struct {
|
type BaseListener struct {
|
||||||
serversLocker sync.RWMutex
|
|
||||||
namedServersLocker sync.RWMutex
|
|
||||||
namedServers map[string]*NamedServer // 域名 => server
|
|
||||||
|
|
||||||
Group *serverconfigs.ServerAddressGroup
|
Group *serverconfigs.ServerAddressGroup
|
||||||
|
|
||||||
countActiveConnections int64 // 当前活跃的连接数
|
countActiveConnections int64 // 当前活跃的连接数
|
||||||
@@ -23,14 +19,11 @@ type BaseListener struct {
|
|||||||
|
|
||||||
// Init 初始化
|
// Init 初始化
|
||||||
func (this *BaseListener) Init() {
|
func (this *BaseListener) Init() {
|
||||||
this.namedServers = map[string]*NamedServer{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset 清除既有配置
|
// Reset 清除既有配置
|
||||||
func (this *BaseListener) Reset() {
|
func (this *BaseListener) Reset() {
|
||||||
this.namedServersLocker.Lock()
|
|
||||||
this.namedServers = map[string]*NamedServer{}
|
|
||||||
this.namedServersLocker.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountActiveListeners 获取当前活跃连接数
|
// CountActiveListeners 获取当前活跃连接数
|
||||||
@@ -67,7 +60,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if cert == nil {
|
if cert == nil {
|
||||||
return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'")
|
return nil, errors.New("no ssl certs found for '" + info.ServerName + "'")
|
||||||
}
|
}
|
||||||
return cert, nil
|
return cert, nil
|
||||||
},
|
},
|
||||||
@@ -83,7 +76,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if cert == nil {
|
if cert == nil {
|
||||||
return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'")
|
return nil, errors.New("no ssl certs found for '" + info.ServerName + "'")
|
||||||
}
|
}
|
||||||
return cert, nil
|
return cert, nil
|
||||||
},
|
},
|
||||||
@@ -92,9 +85,6 @@ func (this *BaseListener) buildTLSConfig() *tls.Config {
|
|||||||
|
|
||||||
// 根据域名匹配证书
|
// 根据域名匹配证书
|
||||||
func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
|
func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) {
|
||||||
this.serversLocker.RLock()
|
|
||||||
defer this.serversLocker.RUnlock()
|
|
||||||
|
|
||||||
group := this.Group
|
group := this.Group
|
||||||
|
|
||||||
if group == nil {
|
if group == nil {
|
||||||
@@ -108,9 +98,9 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
|||||||
return nil, nil, errors.New("no tls server name matched")
|
return nil, nil, errors.New("no tls server name matched")
|
||||||
}
|
}
|
||||||
|
|
||||||
firstServer := group.FirstServer()
|
firstServer := group.FirstTLSServer()
|
||||||
if firstServer == nil {
|
if firstServer == nil {
|
||||||
return nil, nil, errors.New("no server available")
|
return nil, nil, errors.New("no tls server available")
|
||||||
}
|
}
|
||||||
sslConfig := firstServer.SSLPolicy()
|
sslConfig := firstServer.SSLPolicy()
|
||||||
|
|
||||||
@@ -119,14 +109,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
|||||||
|
|
||||||
}
|
}
|
||||||
return nil, nil, errors.New("no tls server name found")
|
return nil, nil, errors.New("no tls server name found")
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 通过代理服务域名配置匹配
|
// 通过代理服务域名配置匹配
|
||||||
server, _ := this.findNamedServer(domain)
|
server, _ := this.findNamedServer(domain)
|
||||||
if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
|
if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
|
||||||
// 搜索所有的Server,通过SSL证书内容中的DNSName匹配
|
// 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配
|
||||||
for _, server := range group.Servers {
|
// TODO 需要思考这种情况下是否允许访问
|
||||||
|
for _, server := range group.Servers() {
|
||||||
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
|
if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -136,7 +126,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, nil, errors.New("[proxy]no server found for '" + domain + "'")
|
return nil, nil, errors.New("no server found for '" + domain + "'")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 证书是否匹配
|
// 证书是否匹配
|
||||||
@@ -146,6 +136,10 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
|
|||||||
return sslConfig, cert, nil
|
return sslConfig, cert, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(sslConfig.Certs) == 0 {
|
||||||
|
remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+domain+"', server id: "+types.String(server.Id))
|
||||||
|
}
|
||||||
|
|
||||||
return sslConfig, sslConfig.FirstCert(), nil
|
return sslConfig, sslConfig.FirstCert(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,11 +167,8 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 如果没有找到,则匹配到第一个
|
// 如果没有找到,则匹配到第一个
|
||||||
this.serversLocker.RLock()
|
|
||||||
defer this.serversLocker.RUnlock()
|
|
||||||
|
|
||||||
group := this.Group
|
group := this.Group
|
||||||
currentServers := group.Servers
|
currentServers := group.Servers()
|
||||||
countServers := len(currentServers)
|
countServers := len(currentServers)
|
||||||
if countServers == 0 {
|
if countServers == 0 {
|
||||||
return nil, ""
|
return nil, ""
|
||||||
@@ -192,71 +183,27 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser
|
|||||||
return nil, ""
|
return nil, ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// 读取缓存
|
server := group.MatchServerName(name)
|
||||||
this.namedServersLocker.RLock()
|
if server != nil {
|
||||||
namedServer, found := this.namedServers[name]
|
return server, name
|
||||||
if found {
|
|
||||||
this.namedServersLocker.RUnlock()
|
|
||||||
return namedServer.Server, namedServer.Name
|
|
||||||
}
|
}
|
||||||
this.namedServersLocker.RUnlock()
|
|
||||||
|
|
||||||
this.serversLocker.RLock()
|
|
||||||
defer this.serversLocker.RUnlock()
|
|
||||||
|
|
||||||
currentServers := group.Servers
|
|
||||||
countServers := len(currentServers)
|
|
||||||
if countServers == 0 {
|
|
||||||
return nil, ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// 只记录N个记录,防止内存耗尽
|
|
||||||
maxNamedServers := 100_0000
|
|
||||||
|
|
||||||
// 是否严格匹配域名
|
// 是否严格匹配域名
|
||||||
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
|
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
|
||||||
|
|
||||||
// 如果只有一个server,则默认为这个
|
// 如果只有一个server,则默认为这个
|
||||||
|
var currentServers = group.Servers()
|
||||||
|
var countServers = len(currentServers)
|
||||||
if countServers == 1 && !matchDomainStrictly {
|
if countServers == 1 && !matchDomainStrictly {
|
||||||
return currentServers[0], name
|
return currentServers[0], name
|
||||||
}
|
}
|
||||||
|
|
||||||
// 精确查找
|
|
||||||
for _, server := range currentServers {
|
|
||||||
if server.MatchNameStrictly(name) {
|
|
||||||
this.namedServersLocker.Lock()
|
|
||||||
if len(this.namedServers) < maxNamedServers {
|
|
||||||
this.namedServers[name] = &NamedServer{
|
|
||||||
Name: name,
|
|
||||||
Server: server,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
this.namedServersLocker.Unlock()
|
|
||||||
return server, name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 模糊查找
|
|
||||||
for _, server := range currentServers {
|
|
||||||
if matched := server.MatchName(name); matched {
|
|
||||||
this.namedServersLocker.Lock()
|
|
||||||
if len(this.namedServers) < maxNamedServers {
|
|
||||||
this.namedServers[name] = &NamedServer{
|
|
||||||
Name: name,
|
|
||||||
Server: server,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
this.namedServersLocker.Unlock()
|
|
||||||
return server, name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, name
|
return nil, name
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用CNAME来查找服务
|
// 使用CNAME来查找服务
|
||||||
// TODO 防止单IP随机生成域名攻击
|
// TODO 防止单IP随机生成域名攻击
|
||||||
func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.ServerConfig {
|
func (this *BaseListener) findServerWithCNAME(domain string) *serverconfigs.ServerConfig {
|
||||||
if !sharedNodeConfig.SupportCNAME {
|
if !sharedNodeConfig.SupportCNAME {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -266,25 +213,10 @@ func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.Serv
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
this.serversLocker.Lock()
|
|
||||||
defer this.serversLocker.Unlock()
|
|
||||||
|
|
||||||
group := this.Group
|
group := this.Group
|
||||||
if group == nil {
|
if group == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
currentServers := group.Servers
|
return group.MatchServerCNAME(realName)
|
||||||
countServers := len(currentServers)
|
|
||||||
if countServers == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, server := range currentServers {
|
|
||||||
if server.SupportCNAME && lists.ContainsString(server.AliasServerNames, realName) {
|
|
||||||
return server
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
36
internal/nodes/listener_base_test.go
Normal file
36
internal/nodes/listener_base_test.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||||
|
|
||||||
|
package nodes
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
|
||||||
|
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
||||||
|
"github.com/iwind/TeaGo/types"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBaseListener_FindServer(t *testing.T) {
|
||||||
|
sharedNodeConfig = &nodeconfigs.NodeConfig{}
|
||||||
|
|
||||||
|
var listener = &BaseListener{namedServers: map[string]*NamedServer{}}
|
||||||
|
listener.Group = &serverconfigs.ServerAddressGroup{}
|
||||||
|
for i := 0; i < 1_000_000; i++ {
|
||||||
|
var server = &serverconfigs.ServerConfig{
|
||||||
|
IsOn: true,
|
||||||
|
Name: types.String(i) + ".hello.com",
|
||||||
|
ServerNames: []*serverconfigs.ServerNameConfig{
|
||||||
|
{Name: types.String(i) + ".hello.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_ = server.Init()
|
||||||
|
listener.Group.Servers = append(listener.Group.Servers, server)
|
||||||
|
}
|
||||||
|
|
||||||
|
var before = time.Now()
|
||||||
|
defer func() {
|
||||||
|
t.Log(time.Since(before).Seconds()*1000, "ms")
|
||||||
|
}()
|
||||||
|
|
||||||
|
t.Log(listener.findNamedServerMatched("855555.hello.com"))
|
||||||
|
}
|
||||||
@@ -152,7 +152,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http.
|
|||||||
|
|
||||||
server, serverName := this.findNamedServer(domain)
|
server, serverName := this.findNamedServer(domain)
|
||||||
if server == nil {
|
if server == nil {
|
||||||
server = this.findServerWithCname(domain)
|
server = this.findServerWithCNAME(domain)
|
||||||
if server == nil {
|
if server == nil {
|
||||||
// 严格匹配域名模式下,我们拒绝用户访问
|
// 严格匹配域名模式下,我们拒绝用户访问
|
||||||
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
|
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error {
|
|||||||
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
|
stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId())
|
||||||
}
|
}
|
||||||
|
|
||||||
originConn, err := this.connectOrigin(firstServer.ReverseProxy, conn.RemoteAddr().String())
|
originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -164,7 +164,7 @@ func (this *TCPListener) Close() error {
|
|||||||
return this.Listener.Close()
|
return this.Listener.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) {
|
||||||
if reverseProxy == nil {
|
if reverseProxy == nil {
|
||||||
return nil, errors.New("no reverse proxy config")
|
return nil, errors.New("no reverse proxy config")
|
||||||
}
|
}
|
||||||
@@ -177,7 +177,7 @@ func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
|
|||||||
}
|
}
|
||||||
conn, err = OriginConnect(origin, remoteAddr)
|
conn, err = OriginConnect(origin, remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
remotelogs.Error("TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func (this *UDPListener) Serve() error {
|
|||||||
ok = false
|
ok = false
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
originConn, err := this.connectOrigin(this.reverseProxy, addr)
|
originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
|
remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error())
|
||||||
continue
|
continue
|
||||||
@@ -101,7 +101,7 @@ func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) {
|
|||||||
this.reverseProxy = firstServer.ReverseProxy
|
this.reverseProxy = firstServer.ReverseProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
|
func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) {
|
||||||
if reverseProxy == nil {
|
if reverseProxy == nil {
|
||||||
return nil, errors.New("no reverse proxy config")
|
return nil, errors.New("no reverse proxy config")
|
||||||
}
|
}
|
||||||
@@ -114,7 +114,7 @@ func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC
|
|||||||
}
|
}
|
||||||
conn, err = OriginConnect(origin, remoteAddr.String())
|
conn, err = OriginConnect(origin, remoteAddr.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
remotelogs.Error("UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error())
|
||||||
continue
|
continue
|
||||||
} else {
|
} else {
|
||||||
// PROXY Protocol
|
// PROXY Protocol
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
package nodes
|
|
||||||
|
|
||||||
import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
|
|
||||||
|
|
||||||
// 域名和服务映射
|
|
||||||
type NamedServer struct {
|
|
||||||
Name string // 匹配后的域名
|
|
||||||
Server *serverconfigs.ServerConfig // 匹配后的服务配置
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user