From dec5ba89549cf29e67b75b04a8f3add98333e1b2 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Mon, 15 Nov 2021 16:57:18 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=A7=E5=B9=85=E6=8F=90=E5=8D=87=E5=9F=9F?= =?UTF-8?q?=E5=90=8D=E5=8C=B9=E9=85=8D=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/nodes/http_request_reverse_proxy.go | 2 +- internal/nodes/listener_base.go | 112 ++++--------------- internal/nodes/listener_base_test.go | 36 ++++++ internal/nodes/listener_http.go | 2 +- internal/nodes/listener_tcp.go | 6 +- internal/nodes/listener_udp.go | 6 +- internal/nodes/named_server.go | 9 -- 7 files changed, 66 insertions(+), 107 deletions(-) create mode 100644 internal/nodes/listener_base_test.go delete mode 100644 internal/nodes/named_server.go diff --git a/internal/nodes/http_request_reverse_proxy.go b/internal/nodes/http_request_reverse_proxy.go index 5115db7..d467694 100644 --- a/internal/nodes/http_request_reverse_proxy.go +++ b/internal/nodes/http_request_reverse_proxy.go @@ -40,7 +40,7 @@ func (this *HTTPRequest) doReverseProxy() { requestCall.CallResponseCallbacks(this.writer) if origin == nil { err := errors.New(this.requestFullURL() + ": no available origin sites for reverse proxy") - remotelogs.Error("HTTP_REQUEST_REVERSE_PROXY", err.Error()) + remotelogs.ServerError(this.Server.Id, "HTTP_REQUEST_REVERSE_PROXY", err.Error()) this.write50x(err, http.StatusBadGateway) return } diff --git a/internal/nodes/listener_base.go b/internal/nodes/listener_base.go index 3b026fd..2336154 100644 --- a/internal/nodes/listener_base.go +++ b/internal/nodes/listener_base.go @@ -5,17 +5,13 @@ import ( "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" + "github.com/TeaOSLab/EdgeNode/internal/remotelogs" "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" "golang.org/x/net/http2" - "sync" ) type BaseListener struct { - serversLocker sync.RWMutex - namedServersLocker sync.RWMutex - namedServers map[string]*NamedServer // 域名 => server - Group *serverconfigs.ServerAddressGroup countActiveConnections int64 // 当前活跃的连接数 @@ -23,14 +19,11 @@ type BaseListener struct { // Init 初始化 func (this *BaseListener) Init() { - this.namedServers = map[string]*NamedServer{} } // Reset 清除既有配置 func (this *BaseListener) Reset() { - this.namedServersLocker.Lock() - this.namedServers = map[string]*NamedServer{} - this.namedServersLocker.Unlock() + } // CountActiveListeners 获取当前活跃连接数 @@ -67,7 +60,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config { return nil, err } if cert == nil { - return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'") + return nil, errors.New("no ssl certs found for '" + info.ServerName + "'") } return cert, nil }, @@ -83,7 +76,7 @@ func (this *BaseListener) buildTLSConfig() *tls.Config { return nil, err } if cert == nil { - return nil, errors.New("[proxy]no certs found for '" + info.ServerName + "'") + return nil, errors.New("no ssl certs found for '" + info.ServerName + "'") } return cert, nil }, @@ -92,9 +85,6 @@ func (this *BaseListener) buildTLSConfig() *tls.Config { // 根据域名匹配证书 func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) { - this.serversLocker.RLock() - defer this.serversLocker.RUnlock() - group := this.Group if group == nil { @@ -108,9 +98,9 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C return nil, nil, errors.New("no tls server name matched") } - firstServer := group.FirstServer() + firstServer := group.FirstTLSServer() if firstServer == nil { - return nil, nil, errors.New("no server available") + return nil, nil, errors.New("no tls server available") } sslConfig := firstServer.SSLPolicy() @@ -119,14 +109,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C } return nil, nil, errors.New("no tls server name found") - } // 通过代理服务域名配置匹配 server, _ := this.findNamedServer(domain) if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn { - // 搜索所有的Server,通过SSL证书内容中的DNSName匹配 - for _, server := range group.Servers { + // 找不到或者此时的服务没有配置证书,需要搜索所有的Server,通过SSL证书内容中的DNSName匹配 + // TODO 需要思考这种情况下是否允许访问 + for _, server := range group.Servers() { if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn { continue } @@ -136,7 +126,7 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C } } - return nil, nil, errors.New("[proxy]no server found for '" + domain + "'") + return nil, nil, errors.New("no server found for '" + domain + "'") } // 证书是否匹配 @@ -146,6 +136,10 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C return sslConfig, cert, nil } + if len(sslConfig.Certs) == 0 { + remotelogs.ServerError(server.Id, "BASE_LISTENER", "no ssl certs found for '"+domain+"', server id: "+types.String(server.Id)) + } + return sslConfig, sslConfig.FirstCert(), nil } @@ -173,11 +167,8 @@ func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconf } // 如果没有找到,则匹配到第一个 - this.serversLocker.RLock() - defer this.serversLocker.RUnlock() - group := this.Group - currentServers := group.Servers + currentServers := group.Servers() countServers := len(currentServers) if countServers == 0 { return nil, "" @@ -192,71 +183,27 @@ func (this *BaseListener) findNamedServerMatched(name string) (serverConfig *ser return nil, "" } - // 读取缓存 - this.namedServersLocker.RLock() - namedServer, found := this.namedServers[name] - if found { - this.namedServersLocker.RUnlock() - return namedServer.Server, namedServer.Name + server := group.MatchServerName(name) + if server != nil { + return server, name } - this.namedServersLocker.RUnlock() - - this.serversLocker.RLock() - defer this.serversLocker.RUnlock() - - currentServers := group.Servers - countServers := len(currentServers) - if countServers == 0 { - return nil, "" - } - - // 只记录N个记录,防止内存耗尽 - maxNamedServers := 100_0000 // 是否严格匹配域名 matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly // 如果只有一个server,则默认为这个 + var currentServers = group.Servers() + var countServers = len(currentServers) if countServers == 1 && !matchDomainStrictly { return currentServers[0], name } - // 精确查找 - for _, server := range currentServers { - if server.MatchNameStrictly(name) { - this.namedServersLocker.Lock() - if len(this.namedServers) < maxNamedServers { - this.namedServers[name] = &NamedServer{ - Name: name, - Server: server, - } - } - this.namedServersLocker.Unlock() - return server, name - } - } - - // 模糊查找 - for _, server := range currentServers { - if matched := server.MatchName(name); matched { - this.namedServersLocker.Lock() - if len(this.namedServers) < maxNamedServers { - this.namedServers[name] = &NamedServer{ - Name: name, - Server: server, - } - } - this.namedServersLocker.Unlock() - return server, name - } - } - return nil, name } // 使用CNAME来查找服务 // TODO 防止单IP随机生成域名攻击 -func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.ServerConfig { +func (this *BaseListener) findServerWithCNAME(domain string) *serverconfigs.ServerConfig { if !sharedNodeConfig.SupportCNAME { return nil } @@ -266,25 +213,10 @@ func (this *BaseListener) findServerWithCname(domain string) *serverconfigs.Serv return nil } - this.serversLocker.Lock() - defer this.serversLocker.Unlock() - group := this.Group if group == nil { return nil } - currentServers := group.Servers - countServers := len(currentServers) - if countServers == 0 { - return nil - } - - for _, server := range currentServers { - if server.SupportCNAME && lists.ContainsString(server.AliasServerNames, realName) { - return server - } - } - - return nil + return group.MatchServerCNAME(realName) } diff --git a/internal/nodes/listener_base_test.go b/internal/nodes/listener_base_test.go new file mode 100644 index 0000000..2b7fe1a --- /dev/null +++ b/internal/nodes/listener_base_test.go @@ -0,0 +1,36 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package nodes + +import ( + "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" + "github.com/iwind/TeaGo/types" + "testing" + "time" +) + +func TestBaseListener_FindServer(t *testing.T) { + sharedNodeConfig = &nodeconfigs.NodeConfig{} + + var listener = &BaseListener{namedServers: map[string]*NamedServer{}} + listener.Group = &serverconfigs.ServerAddressGroup{} + for i := 0; i < 1_000_000; i++ { + var server = &serverconfigs.ServerConfig{ + IsOn: true, + Name: types.String(i) + ".hello.com", + ServerNames: []*serverconfigs.ServerNameConfig{ + {Name: types.String(i) + ".hello.com"}, + }, + } + _ = server.Init() + listener.Group.Servers = append(listener.Group.Servers, server) + } + + var before = time.Now() + defer func() { + t.Log(time.Since(before).Seconds()*1000, "ms") + }() + + t.Log(listener.findNamedServerMatched("855555.hello.com")) +} diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index bfef438..2b00e03 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -152,7 +152,7 @@ func (this *HTTPListener) ServeHTTP(rawWriter http.ResponseWriter, rawReq *http. server, serverName := this.findNamedServer(domain) if server == nil { - server = this.findServerWithCname(domain) + server = this.findServerWithCNAME(domain) if server == nil { // 严格匹配域名模式下,我们拒绝用户访问 if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly { diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index f90d50a..32f6538 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -75,7 +75,7 @@ func (this *TCPListener) handleConn(conn net.Conn) error { stats.SharedTrafficStatManager.Add(firstServer.Id, "", 0, 0, 1, 0, 0, 0, firstServer.ShouldCheckTrafficLimit(), firstServer.PlanId()) } - originConn, err := this.connectOrigin(firstServer.ReverseProxy, conn.RemoteAddr().String()) + originConn, err := this.connectOrigin(firstServer.Id, firstServer.ReverseProxy, conn.RemoteAddr().String()) if err != nil { return err } @@ -164,7 +164,7 @@ func (this *TCPListener) Close() error { return this.Listener.Close() } -func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { +func (this *TCPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr string) (conn net.Conn, err error) { if reverseProxy == nil { return nil, errors.New("no reverse proxy config") } @@ -177,7 +177,7 @@ func (this *TCPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC } conn, err = OriginConnect(origin, remoteAddr) if err != nil { - remotelogs.Error("TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error()) + remotelogs.ServerError(serverId, "TCP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error()) continue } else { return diff --git a/internal/nodes/listener_udp.go b/internal/nodes/listener_udp.go index 6cd5b00..2292277 100644 --- a/internal/nodes/listener_udp.go +++ b/internal/nodes/listener_udp.go @@ -55,7 +55,7 @@ func (this *UDPListener) Serve() error { ok = false } if !ok { - originConn, err := this.connectOrigin(this.reverseProxy, addr) + originConn, err := this.connectOrigin(firstServer.Id, this.reverseProxy, addr) if err != nil { remotelogs.Error("UDP_LISTENER", "unable to connect to origin server: "+err.Error()) continue @@ -101,7 +101,7 @@ func (this *UDPListener) Reload(group *serverconfigs.ServerAddressGroup) { this.reverseProxy = firstServer.ReverseProxy } -func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) { +func (this *UDPListener) connectOrigin(serverId int64, reverseProxy *serverconfigs.ReverseProxyConfig, remoteAddr net.Addr) (conn net.Conn, err error) { if reverseProxy == nil { return nil, errors.New("no reverse proxy config") } @@ -114,7 +114,7 @@ func (this *UDPListener) connectOrigin(reverseProxy *serverconfigs.ReverseProxyC } conn, err = OriginConnect(origin, remoteAddr.String()) if err != nil { - remotelogs.Error("UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error()) + remotelogs.ServerError(serverId, "UDP_LISTENER", "unable to connect origin: "+origin.Addr.Host+":"+origin.Addr.PortRange+": "+err.Error()) continue } else { // PROXY Protocol diff --git a/internal/nodes/named_server.go b/internal/nodes/named_server.go deleted file mode 100644 index 18c894c..0000000 --- a/internal/nodes/named_server.go +++ /dev/null @@ -1,9 +0,0 @@ -package nodes - -import "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" - -// 域名和服务映射 -type NamedServer struct { - Name string // 匹配后的域名 - Server *serverconfigs.ServerConfig // 匹配后的服务配置 -}