diff --git a/internal/nodes/http_request.go b/internal/nodes/http_request.go index bd014cd..fba1843 100644 --- a/internal/nodes/http_request.go +++ b/internal/nodes/http_request.go @@ -948,12 +948,12 @@ func (this *HTTPRequest) processResponseHeaders(statusCode int) { // HSTS if this.IsHTTPS && this.Server.HTTPS != nil && - this.Server.HTTPS.SSL != nil && - this.Server.HTTPS.SSL.IsOn && - this.Server.HTTPS.SSL.HSTS != nil && - this.Server.HTTPS.SSL.HSTS.IsOn && - this.Server.HTTPS.SSL.HSTS.Match(this.Host) { - responseHeader.Set(this.Server.HTTPS.SSL.HSTS.HeaderKey(), this.Server.HTTPS.SSL.HSTS.HeaderValue()) + this.Server.HTTPS.SSLPolicy != nil && + this.Server.HTTPS.SSLPolicy.IsOn && + this.Server.HTTPS.SSLPolicy.HSTS != nil && + this.Server.HTTPS.SSLPolicy.HSTS.IsOn && + this.Server.HTTPS.SSLPolicy.HSTS.Match(this.Host) { + responseHeader.Set(this.Server.HTTPS.SSLPolicy.HSTS.HeaderKey(), this.Server.HTTPS.SSLPolicy.HSTS.HeaderValue()) } } diff --git a/internal/nodes/listener.go b/internal/nodes/listener.go index d0e5752..cc680ff 100644 --- a/internal/nodes/listener.go +++ b/internal/nodes/listener.go @@ -51,12 +51,12 @@ func (this *Listener) Listen() error { switch protocol { case serverconfigs.ProtocolHTTP, serverconfigs.ProtocolHTTP4, serverconfigs.ProtocolHTTP6: this.listener = &HTTPListener{ - Group: this.group, - Listener: netListener, + BaseListener: BaseListener{Group: this.group}, + Listener: netListener, } case serverconfigs.ProtocolHTTPS, serverconfigs.ProtocolHTTPS4, serverconfigs.ProtocolHTTPS6: this.listener = &HTTPListener{ - Group: this.group, + BaseListener: BaseListener{Group: this.group}, Listener: netListener, } case serverconfigs.ProtocolTCP, serverconfigs.ProtocolTCP4, serverconfigs.ProtocolTCP6: diff --git a/internal/nodes/listener_base.go b/internal/nodes/listener_base.go index fa47013..aec7f9d 100644 --- a/internal/nodes/listener_base.go +++ b/internal/nodes/listener_base.go @@ -13,6 +13,8 @@ type BaseListener struct { serversLocker sync.RWMutex namedServersLocker sync.RWMutex namedServers map[string]*NamedServer // 域名 => server + + Group *serverconfigs.ServerGroup } // 初始化 @@ -28,22 +30,22 @@ func (this *BaseListener) Reset() { } // 构造TLS配置 -func (this *BaseListener) buildTLSConfig(group *serverconfigs.ServerGroup) *tls.Config { +func (this *BaseListener) buildTLSConfig() *tls.Config { return &tls.Config{ Certificates: nil, GetConfigForClient: func(info *tls.ClientHelloInfo) (config *tls.Config, e error) { - ssl, _, err := this.matchSSL(group, info.ServerName) + ssl, _, err := this.matchSSL(info.ServerName) if err != nil { return nil, err } cipherSuites := ssl.TLSCipherSuites() - if len(cipherSuites) == 0 { + if !ssl.CipherSuitesIsOn || len(cipherSuites) == 0 { cipherSuites = nil } nextProto := []string{} - if !ssl.HTTP2Disabled { + if ssl.HTTP2Enabled { nextProto = []string{http2.NextProtoTLS} } return &tls.Config{ @@ -51,7 +53,7 @@ func (this *BaseListener) buildTLSConfig(group *serverconfigs.ServerGroup) *tls. MinVersion: ssl.TLSMinVersion(), CipherSuites: cipherSuites, GetCertificate: func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) { - _, cert, err := this.matchSSL(group, info.ServerName) + _, cert, err := this.matchSSL(info.ServerName) if err != nil { return nil, err } @@ -67,7 +69,7 @@ func (this *BaseListener) buildTLSConfig(group *serverconfigs.ServerGroup) *tls. }, nil }, GetCertificate: func(info *tls.ClientHelloInfo) (certificate *tls.Certificate, e error) { - _, cert, err := this.matchSSL(group, info.ServerName) + _, cert, err := this.matchSSL(info.ServerName) if err != nil { return nil, err } @@ -80,10 +82,12 @@ func (this *BaseListener) buildTLSConfig(group *serverconfigs.ServerGroup) *tls. } // 根据域名匹配证书 -func (this *BaseListener) matchSSL(group *serverconfigs.ServerGroup, domain string) (*sslconfigs.SSLConfig, *tls.Certificate, error) { +func (this *BaseListener) matchSSL(domain string) (*sslconfigs.SSLPolicy, *tls.Certificate, error) { this.serversLocker.RLock() defer this.serversLocker.RUnlock() + group := this.Group + // 如果域名为空,则取第一个 // 通常域名为空是因为是直接通过IP访问的 if len(domain) == 0 { @@ -95,7 +99,7 @@ func (this *BaseListener) matchSSL(group *serverconfigs.ServerGroup, domain stri if firstServer == nil { return nil, nil, errors.New("no server available") } - sslConfig := firstServer.SSLConfig() + sslConfig := firstServer.SSLPolicy() if sslConfig != nil { return sslConfig, sslConfig.FirstCert(), nil @@ -106,15 +110,15 @@ func (this *BaseListener) matchSSL(group *serverconfigs.ServerGroup, domain stri // 通过代理服务域名配置匹配 server, _ := this.findNamedServer(group, domain) - if server == nil || server.SSLConfig() == nil || !server.SSLConfig().IsOn { + if server == nil || server.SSLPolicy() == nil || !server.SSLPolicy().IsOn { // 搜索所有的Server,通过SSL证书内容中的DNSName匹配 for _, server := range group.Servers { - if server.SSLConfig() == nil || !server.SSLConfig().IsOn { + if server.SSLPolicy() == nil || !server.SSLPolicy().IsOn { continue } - cert, ok := server.SSLConfig().MatchDomain(domain) + cert, ok := server.SSLPolicy().MatchDomain(domain) if ok { - return server.SSLConfig(), cert, nil + return server.SSLPolicy(), cert, nil } } @@ -122,7 +126,7 @@ func (this *BaseListener) matchSSL(group *serverconfigs.ServerGroup, domain stri } // 证书是否匹配 - sslConfig := server.SSLConfig() + sslConfig := server.SSLPolicy() cert, ok := sslConfig.MatchDomain(domain) if ok { return sslConfig, cert, nil diff --git a/internal/nodes/listener_http.go b/internal/nodes/listener_http.go index bfdda62..43ebb26 100644 --- a/internal/nodes/listener_http.go +++ b/internal/nodes/listener_http.go @@ -13,7 +13,6 @@ import ( type HTTPListener struct { BaseListener - Group *serverconfigs.ServerGroup Listener net.Listener addr string @@ -49,7 +48,7 @@ func (this *HTTPListener) Serve() error { // HTTPS协议 if this.isHTTPS { - this.httpServer.TLSConfig = this.buildTLSConfig(this.Group) + this.httpServer.TLSConfig = this.buildTLSConfig() // support http/2 err := http2.ConfigureServer(this.httpServer, nil) @@ -76,10 +75,6 @@ func (this *HTTPListener) Close() error { func (this *HTTPListener) Reload(group *serverconfigs.ServerGroup) { this.Group = group - if this.isHTTPS { - this.httpServer.TLSConfig = this.buildTLSConfig(this.Group) - } - this.Reset() }