diff --git a/internal/nodes/listener.go b/internal/nodes/listener.go index cc680ff..147b6ce 100644 --- a/internal/nodes/listener.go +++ b/internal/nodes/listener.go @@ -57,27 +57,27 @@ func (this *Listener) Listen() error { case serverconfigs.ProtocolHTTPS, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolHTTPS6: this.listener = &HTTPListener{ BaseListener: BaseListener{Group: this.group}, - Listener: netListener, + Listener: netListener, } case serverconfigs.ProtocolTCP, serverconfigs.ProtocolTCP4, serverconfigs.ProtocolTCP6: this.listener = &TCPListener{ - Group: this.group, - Listener: netListener, + BaseListener: BaseListener{Group: this.group}, + Listener: netListener, } case serverconfigs.ProtocolTLS, serverconfigs.ProtocolTLS4, serverconfigs.ProtocolTLS6: this.listener = &TCPListener{ - Group: this.group, - Listener: netListener, + BaseListener: BaseListener{Group: this.group}, + Listener: netListener, } case serverconfigs.ProtocolUnix: this.listener = &UnixListener{ - Group: this.group, - Listener: netListener, + BaseListener: BaseListener{Group: this.group}, + Listener: netListener, } case serverconfigs.ProtocolUDP: this.listener = &UDPListener{ - Group: this.group, - Listener: netListener, + BaseListener: BaseListener{Group: this.group}, + Listener: netListener, } default: return errors.New("unknown protocol '" + protocol.String() + "'") diff --git a/internal/nodes/listener_base.go b/internal/nodes/listener_base.go index aec7f9d..96b951b 100644 --- a/internal/nodes/listener_base.go +++ b/internal/nodes/listener_base.go @@ -88,10 +88,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C group := this.Group + if group == nil { + return nil, nil, errors.New("no configure found") + } + // 如果域名为空,则取第一个 // 通常域名为空是因为是直接通过IP访问的 if len(domain) == 0 { - if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly { + if group.IsHTTPS() && sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly { return nil, nil, errors.New("no tls server name matched") } @@ -106,10 +110,11 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C } return nil, nil, errors.New("no tls server name found") + } // 通过代理服务域名配置匹配 - server, _ := this.findNamedServer(group, domain) + server, _ := this.findNamedServer(domain) if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn { // 搜索所有的Server,通过SSL证书内容中的DNSName匹配 for _, server := range group.Servers { @@ -136,7 +141,12 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C } // 根据域名来查找匹配的域名 -func (this *BaseListener) findNamedServer(group *serverconfigs.ServerGroup, name string) (serverConfig *serverconfigs.ServerConfig, serverName string) { +func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) { + group := this.Group + if group == nil { + return nil, "" + } + // 读取缓存 this.namedServersLocker.RLock() namedServer, found := this.namedServers[name] @@ -159,7 +169,7 @@ func (this *BaseListener) findNamedServer(group *serverconfigs.ServerGroup, name maxNamedServers := 10240 // 是否严格匹配域名 - matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly + matchDomainStrictly := group.IsHTTPS() && sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly // 如果只有一个server,则默认为这个 if countServers == 1 && !matchDomainStrictly { diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index 43ebb26..247fd4e 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -115,7 +115,7 @@ func (this *HTTPListener) handleHTTP(rawWriter http.ResponseWriter, rawReq *http domain = reqHost } - server, serverName := this.findNamedServer(this.Group, domain) + server, serverName := this.findNamedServer(domain) if server == nil { // 严格匹配域名模式下,我们拒绝用户访问 if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly { diff --git a/internal/nodes/listener_manager.go b/internal/nodes/listener_manager.go index 93c5a2e..daf7d6d 100644 --- a/internal/nodes/listener_manager.go +++ b/internal/nodes/listener_manager.go @@ -56,11 +56,13 @@ func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error { } // 停掉老的 - for _, listener := range this.listenersMap { + for listenerKey, listener := range this.listenersMap { addr := listener.FullAddr() if !lists.ContainsString(groupAddrs, addr) { logs.Println("[LISTENER_MANAGER]close '" + addr + "'") _ = listener.Close() + + delete(this.listenersMap, listenerKey) } } diff --git a/internal/nodes/listener_tcp.go b/internal/nodes/listener_tcp.go index ee70d57..429c5d2 100644 --- a/internal/nodes/listener_tcp.go +++ b/internal/nodes/listener_tcp.go @@ -1,6 +1,7 @@ package nodes import ( + "crypto/tls" "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/iwind/TeaGo/logs" @@ -10,13 +11,17 @@ import ( type TCPListener struct { BaseListener - Group *serverconfigs.ServerGroup Listener net.Listener } func (this *TCPListener) Serve() error { + listener := this.Listener + if this.Group.IsTLS() { + listener = tls.NewListener(listener, this.buildTLSConfig()) + } + for { - conn, err := this.Listener.Accept() + conn, err := listener.Accept() if err != nil { break } diff --git a/internal/nodes/listener_udp.go b/internal/nodes/listener_udp.go index 4d4537b..be05e70 100644 --- a/internal/nodes/listener_udp.go +++ b/internal/nodes/listener_udp.go @@ -8,7 +8,6 @@ import ( type UDPListener struct { BaseListener - Group *serverconfigs.ServerGroup Listener net.Listener } diff --git a/internal/nodes/listener_unix.go b/internal/nodes/listener_unix.go index f0d9b9a..787c593 100644 --- a/internal/nodes/listener_unix.go +++ b/internal/nodes/listener_unix.go @@ -8,7 +8,6 @@ import ( type UnixListener struct { BaseListener - Group *serverconfigs.ServerGroup Listener net.Listener }