diff --git a/pkg/serverconfigs/sslconfigs/ssl_cert_config_test.go b/pkg/serverconfigs/sslconfigs/ssl_cert_config_test.go new file mode 100644 index 0000000..6220c3e --- /dev/null +++ b/pkg/serverconfigs/sslconfigs/ssl_cert_config_test.go @@ -0,0 +1,18 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package sslconfigs + +import ( + "github.com/iwind/TeaGo/assert" + "testing" +) + +func TestSSLCertConfig_MatchDomain(t *testing.T) { + var a = assert.NewAssertion(t) + + var cert = &SSLCertConfig{ + DNSNames: []string{"a.com", "b.com"}, + } + a.IsTrue(cert.MatchDomain("a.com")) + a.IsFalse(cert.MatchDomain("z.com")) +} diff --git a/pkg/serverconfigs/sslconfigs/ssl_policy.go b/pkg/serverconfigs/sslconfigs/ssl_policy.go index e0bf29f..750206d 100644 --- a/pkg/serverconfigs/sslconfigs/ssl_policy.go +++ b/pkg/serverconfigs/sslconfigs/ssl_policy.go @@ -3,15 +3,17 @@ package sslconfigs import ( "crypto/tls" "crypto/x509" + "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + "golang.org/x/net/http2" ) -// TLS Version +// TLSVersion TLS Version type TLSVersion = string -// Cipher Suites +// TLSCipherSuite Cipher Suites type TLSCipherSuite = string -// SSL配置 +// SSLPolicy SSL配置 type SSLPolicy struct { Id int64 `yaml:"id" json:"id"` // ID IsOn bool `yaml:"isOn" json:"isOn"` // 是否开启 @@ -35,16 +37,25 @@ type SSLPolicy struct { cipherSuites []uint16 clientCAPool *x509.CertPool + + tlsConfig *tls.Config } -// 校验配置 +// Init 校验配置 func (this *SSLPolicy) Init() error { + this.nameMapping = map[string]*tls.Certificate{} + // certs + var certs = []tls.Certificate{} for _, cert := range this.Certs { err := cert.Init() if err != nil { return err } + certs = append(certs, *cert.CertObject()) + for _, dnsName := range cert.DNSNames { + this.nameMapping[dnsName] = cert.CertObject() + } } // CA certs @@ -53,6 +64,10 @@ func (this *SSLPolicy) Init() error { if err != nil { return err } + certs = append(certs, *cert.CertObject()) + for _, dnsName := range cert.DNSNames { + this.nameMapping[dnsName] = cert.CertObject() + } } // min version @@ -69,30 +84,56 @@ func (this *SSLPolicy) Init() error { } } + // tls config + this.tlsConfig = &tls.Config{} + cipherSuites := this.TLSCipherSuites() + if !this.CipherSuitesIsOn || len(cipherSuites) == 0 { + cipherSuites = nil + } + + nextProto := []string{} + if this.HTTP2Enabled { + nextProto = []string{http2.NextProtoTLS} + } + this.tlsConfig = &tls.Config{ + Certificates: certs, + MinVersion: this.TLSMinVersion(), + CipherSuites: cipherSuites, + GetCertificate: nil, + ClientAuth: GoSSLClientAuthType(this.ClientAuthType), + ClientCAs: this.CAPool(), + NextProtos: nextProto, + } + return nil } -// 取得最小版本 +// TLSMinVersion 取得最小版本 func (this *SSLPolicy) TLSMinVersion() uint16 { return this.minVersion } -// 套件 +// TLSCipherSuites 套件 func (this *SSLPolicy) TLSCipherSuites() []uint16 { return this.cipherSuites } -// 校验是否匹配某个域名 +// MatchDomain 校验是否匹配某个域名 func (this *SSLPolicy) MatchDomain(domain string) (cert *tls.Certificate, ok bool) { - for _, cert := range this.Certs { - if cert.MatchDomain(domain) { - return cert.CertObject(), true + cert, ok = this.nameMapping[domain] + if ok { + return + } + + for name, cert := range this.nameMapping { + if configutils.MatchDomain(name, domain) { + return cert, true } } return nil, false } -// 取得第一个证书 +// FirstCert 取得第一个证书 func (this *SSLPolicy) FirstCert() *tls.Certificate { for _, cert := range this.Certs { return cert.CertObject() @@ -100,7 +141,11 @@ func (this *SSLPolicy) FirstCert() *tls.Certificate { return nil } -// CA证书Pool,用于TLS对客户端进行认证 +// CAPool CA证书Pool,用于TLS对客户端进行认证 func (this *SSLPolicy) CAPool() *x509.CertPool { return this.clientCAPool } + +func (this *SSLPolicy) TLSConfig() *tls.Config { + return this.tlsConfig +} diff --git a/pkg/serverconfigs/sslconfigs/ssl_policy_test.go b/pkg/serverconfigs/sslconfigs/ssl_policy_test.go new file mode 100644 index 0000000..87cb870 --- /dev/null +++ b/pkg/serverconfigs/sslconfigs/ssl_policy_test.go @@ -0,0 +1,33 @@ +// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved. + +package sslconfigs + +import ( + "github.com/iwind/TeaGo/assert" + "testing" +) + +func TestSSLPolicy_MatchDomain(t *testing.T) { + var a = assert.NewAssertion(t) + + var policy = &SSLPolicy{} + policy.Certs = []*SSLCertConfig{ + { + Id: 1, + DNSNames: []string{"a.com", "b.com"}, + }, + { + Id: 2, + DNSNames: []string{"c.com", "d.com"}, + }, + { + Id: 3, + DNSNames: []string{"e.com", "f.com"}, + }, + } + + { + _, ok := policy.MatchDomain("c.com") + a.IsTrue(ok) + } +}