实现TLS配置

This commit is contained in:
GoEdgeLab
2020-10-01 16:51:24 +08:00
parent 04e7db9dd3
commit 1da53beee4
7 changed files with 34 additions and 19 deletions

View File

@@ -57,27 +57,27 @@ func (this *Listener) Listen() error {
case serverconfigs.ProtocolHTTPS, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolHTTPS6: case serverconfigs.ProtocolHTTPS, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolHTTPS6:
this.listener = &HTTPListener{ this.listener = &HTTPListener{
BaseListener: BaseListener{Group: this.group}, BaseListener: BaseListener{Group: this.group},
Listener: netListener, Listener: netListener,
} }
case serverconfigs.ProtocolTCP, serverconfigs.ProtocolTCP4, serverconfigs.ProtocolTCP6: case serverconfigs.ProtocolTCP, serverconfigs.ProtocolTCP4, serverconfigs.ProtocolTCP6:
this.listener = &TCPListener{ this.listener = &TCPListener{
Group: this.group, BaseListener: BaseListener{Group: this.group},
Listener: netListener, Listener: netListener,
} }
case serverconfigs.ProtocolTLS, serverconfigs.ProtocolTLS4, serverconfigs.ProtocolTLS6: case serverconfigs.ProtocolTLS, serverconfigs.ProtocolTLS4, serverconfigs.ProtocolTLS6:
this.listener = &TCPListener{ this.listener = &TCPListener{
Group: this.group, BaseListener: BaseListener{Group: this.group},
Listener: netListener, Listener: netListener,
} }
case serverconfigs.ProtocolUnix: case serverconfigs.ProtocolUnix:
this.listener = &UnixListener{ this.listener = &UnixListener{
Group: this.group, BaseListener: BaseListener{Group: this.group},
Listener: netListener, Listener: netListener,
} }
case serverconfigs.ProtocolUDP: case serverconfigs.ProtocolUDP:
this.listener = &UDPListener{ this.listener = &UDPListener{
Group: this.group, BaseListener: BaseListener{Group: this.group},
Listener: netListener, Listener: netListener,
} }
default: default:
return errors.New("unknown protocol '" + protocol.String() + "'") return errors.New("unknown protocol '" + protocol.String() + "'")

View File

@@ -88,10 +88,14 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
group := this.Group group := this.Group
if group == nil {
return nil, nil, errors.New("no configure found")
}
// 如果域名为空,则取第一个 // 如果域名为空,则取第一个
// 通常域名为空是因为是直接通过IP访问的 // 通常域名为空是因为是直接通过IP访问的
if len(domain) == 0 { if len(domain) == 0 {
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly { if group.IsHTTPS() && sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {
return nil, nil, errors.New("no tls server name matched") return nil, nil, errors.New("no tls server name matched")
} }
@@ -106,10 +110,11 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
} }
return nil, nil, errors.New("no tls server name found") return nil, nil, errors.New("no tls server name found")
} }
// 通过代理服务域名配置匹配 // 通过代理服务域名配置匹配
server, _ := this.findNamedServer(group, domain) server, _ := this.findNamedServer(domain)
if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn { if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn {
// 搜索所有的Server通过SSL证书内容中的DNSName匹配 // 搜索所有的Server通过SSL证书内容中的DNSName匹配
for _, server := range group.Servers { for _, server := range group.Servers {
@@ -136,7 +141,12 @@ func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.C
} }
// 根据域名来查找匹配的域名 // 根据域名来查找匹配的域名
func (this *BaseListener) findNamedServer(group *serverconfigs.ServerGroup, name string) (serverConfig *serverconfigs.ServerConfig, serverName string) { func (this *BaseListener) findNamedServer(name string) (serverConfig *serverconfigs.ServerConfig, serverName string) {
group := this.Group
if group == nil {
return nil, ""
}
// 读取缓存 // 读取缓存
this.namedServersLocker.RLock() this.namedServersLocker.RLock()
namedServer, found := this.namedServers[name] namedServer, found := this.namedServers[name]
@@ -159,7 +169,7 @@ func (this *BaseListener) findNamedServer(group *serverconfigs.ServerGroup, name
maxNamedServers := 10240 maxNamedServers := 10240
// 是否严格匹配域名 // 是否严格匹配域名
matchDomainStrictly := sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly matchDomainStrictly := group.IsHTTPS() && sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly
// 如果只有一个server则默认为这个 // 如果只有一个server则默认为这个
if countServers == 1 && !matchDomainStrictly { if countServers == 1 && !matchDomainStrictly {

View File

@@ -115,7 +115,7 @@ func (this *HTTPListener) handleHTTP(rawWriter http.ResponseWriter, rawReq *http
domain = reqHost domain = reqHost
} }
server, serverName := this.findNamedServer(this.Group, domain) server, serverName := this.findNamedServer(domain)
if server == nil { if server == nil {
// 严格匹配域名模式下,我们拒绝用户访问 // 严格匹配域名模式下,我们拒绝用户访问
if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly { if sharedNodeConfig.GlobalConfig != nil && sharedNodeConfig.GlobalConfig.HTTPAll.MatchDomainStrictly {

View File

@@ -56,11 +56,13 @@ func (this *ListenerManager) Start(node *nodeconfigs.NodeConfig) error {
} }
// 停掉老的 // 停掉老的
for _, listener := range this.listenersMap { for listenerKey, listener := range this.listenersMap {
addr := listener.FullAddr() addr := listener.FullAddr()
if !lists.ContainsString(groupAddrs, addr) { if !lists.ContainsString(groupAddrs, addr) {
logs.Println("[LISTENER_MANAGER]close '" + addr + "'") logs.Println("[LISTENER_MANAGER]close '" + addr + "'")
_ = listener.Close() _ = listener.Close()
delete(this.listenersMap, listenerKey)
} }
} }

View File

@@ -1,6 +1,7 @@
package nodes package nodes
import ( import (
"crypto/tls"
"errors" "errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/iwind/TeaGo/logs" "github.com/iwind/TeaGo/logs"
@@ -10,13 +11,17 @@ import (
type TCPListener struct { type TCPListener struct {
BaseListener BaseListener
Group *serverconfigs.ServerGroup
Listener net.Listener Listener net.Listener
} }
func (this *TCPListener) Serve() error { func (this *TCPListener) Serve() error {
listener := this.Listener
if this.Group.IsTLS() {
listener = tls.NewListener(listener, this.buildTLSConfig())
}
for { for {
conn, err := this.Listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
break break
} }

View File

@@ -8,7 +8,6 @@ import (
type UDPListener struct { type UDPListener struct {
BaseListener BaseListener
Group *serverconfigs.ServerGroup
Listener net.Listener Listener net.Listener
} }

View File

@@ -8,7 +8,6 @@ import (
type UnixListener struct { type UnixListener struct {
BaseListener BaseListener
Group *serverconfigs.ServerGroup
Listener net.Listener Listener net.Listener
} }