优化代码

This commit is contained in:
刘祥超
2021-12-13 14:58:45 +08:00
parent abd5c6dbb1
commit 173175a248
3 changed files with 108 additions and 12 deletions

View File

@@ -0,0 +1,18 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package sslconfigs
import (
"github.com/iwind/TeaGo/assert"
"testing"
)
func TestSSLCertConfig_MatchDomain(t *testing.T) {
var a = assert.NewAssertion(t)
var cert = &SSLCertConfig{
DNSNames: []string{"a.com", "b.com"},
}
a.IsTrue(cert.MatchDomain("a.com"))
a.IsFalse(cert.MatchDomain("z.com"))
}

View File

@@ -3,15 +3,17 @@ package sslconfigs
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
"golang.org/x/net/http2"
) )
// TLS Version // TLSVersion TLS Version
type TLSVersion = string type TLSVersion = string
// Cipher Suites // TLSCipherSuite Cipher Suites
type TLSCipherSuite = string type TLSCipherSuite = string
// SSL配置 // SSLPolicy SSL配置
type SSLPolicy struct { type SSLPolicy struct {
Id int64 `yaml:"id" json:"id"` // ID Id int64 `yaml:"id" json:"id"` // ID
IsOn bool `yaml:"isOn" json:"isOn"` // 是否开启 IsOn bool `yaml:"isOn" json:"isOn"` // 是否开启
@@ -35,16 +37,25 @@ type SSLPolicy struct {
cipherSuites []uint16 cipherSuites []uint16
clientCAPool *x509.CertPool clientCAPool *x509.CertPool
tlsConfig *tls.Config
} }
// 校验配置 // Init 校验配置
func (this *SSLPolicy) Init() error { func (this *SSLPolicy) Init() error {
this.nameMapping = map[string]*tls.Certificate{}
// certs // certs
var certs = []tls.Certificate{}
for _, cert := range this.Certs { for _, cert := range this.Certs {
err := cert.Init() err := cert.Init()
if err != nil { if err != nil {
return err return err
} }
certs = append(certs, *cert.CertObject())
for _, dnsName := range cert.DNSNames {
this.nameMapping[dnsName] = cert.CertObject()
}
} }
// CA certs // CA certs
@@ -53,6 +64,10 @@ func (this *SSLPolicy) Init() error {
if err != nil { if err != nil {
return err return err
} }
certs = append(certs, *cert.CertObject())
for _, dnsName := range cert.DNSNames {
this.nameMapping[dnsName] = cert.CertObject()
}
} }
// min version // min version
@@ -69,30 +84,56 @@ func (this *SSLPolicy) Init() error {
} }
} }
// tls config
this.tlsConfig = &tls.Config{}
cipherSuites := this.TLSCipherSuites()
if !this.CipherSuitesIsOn || len(cipherSuites) == 0 {
cipherSuites = nil
}
nextProto := []string{}
if this.HTTP2Enabled {
nextProto = []string{http2.NextProtoTLS}
}
this.tlsConfig = &tls.Config{
Certificates: certs,
MinVersion: this.TLSMinVersion(),
CipherSuites: cipherSuites,
GetCertificate: nil,
ClientAuth: GoSSLClientAuthType(this.ClientAuthType),
ClientCAs: this.CAPool(),
NextProtos: nextProto,
}
return nil return nil
} }
// 取得最小版本 // TLSMinVersion 取得最小版本
func (this *SSLPolicy) TLSMinVersion() uint16 { func (this *SSLPolicy) TLSMinVersion() uint16 {
return this.minVersion return this.minVersion
} }
// 套件 // TLSCipherSuites 套件
func (this *SSLPolicy) TLSCipherSuites() []uint16 { func (this *SSLPolicy) TLSCipherSuites() []uint16 {
return this.cipherSuites return this.cipherSuites
} }
// 校验是否匹配某个域名 // MatchDomain 校验是否匹配某个域名
func (this *SSLPolicy) MatchDomain(domain string) (cert *tls.Certificate, ok bool) { func (this *SSLPolicy) MatchDomain(domain string) (cert *tls.Certificate, ok bool) {
for _, cert := range this.Certs { cert, ok = this.nameMapping[domain]
if cert.MatchDomain(domain) { if ok {
return cert.CertObject(), true return
}
for name, cert := range this.nameMapping {
if configutils.MatchDomain(name, domain) {
return cert, true
} }
} }
return nil, false return nil, false
} }
// 取得第一个证书 // FirstCert 取得第一个证书
func (this *SSLPolicy) FirstCert() *tls.Certificate { func (this *SSLPolicy) FirstCert() *tls.Certificate {
for _, cert := range this.Certs { for _, cert := range this.Certs {
return cert.CertObject() return cert.CertObject()
@@ -100,7 +141,11 @@ func (this *SSLPolicy) FirstCert() *tls.Certificate {
return nil return nil
} }
// CA证书Pool用于TLS对客户端进行认证 // CAPool CA证书Pool用于TLS对客户端进行认证
func (this *SSLPolicy) CAPool() *x509.CertPool { func (this *SSLPolicy) CAPool() *x509.CertPool {
return this.clientCAPool return this.clientCAPool
} }
func (this *SSLPolicy) TLSConfig() *tls.Config {
return this.tlsConfig
}

View File

@@ -0,0 +1,33 @@
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
package sslconfigs
import (
"github.com/iwind/TeaGo/assert"
"testing"
)
func TestSSLPolicy_MatchDomain(t *testing.T) {
var a = assert.NewAssertion(t)
var policy = &SSLPolicy{}
policy.Certs = []*SSLCertConfig{
{
Id: 1,
DNSNames: []string{"a.com", "b.com"},
},
{
Id: 2,
DNSNames: []string{"c.com", "d.com"},
},
{
Id: 3,
DNSNames: []string{"e.com", "f.com"},
},
}
{
_, ok := policy.MatchDomain("c.com")
a.IsTrue(ok)
}
}