mirror of
https://github.com/TeaOSLab/EdgeCommon.git
synced 2025-11-03 20:40:25 +08:00
优化代码
This commit is contained in:
18
pkg/serverconfigs/sslconfigs/ssl_cert_config_test.go
Normal file
18
pkg/serverconfigs/sslconfigs/ssl_cert_config_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package sslconfigs
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSSLCertConfig_MatchDomain(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
var cert = &SSLCertConfig{
|
||||
DNSNames: []string{"a.com", "b.com"},
|
||||
}
|
||||
a.IsTrue(cert.MatchDomain("a.com"))
|
||||
a.IsFalse(cert.MatchDomain("z.com"))
|
||||
}
|
||||
@@ -3,15 +3,17 @@ package sslconfigs
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"github.com/TeaOSLab/EdgeCommon/pkg/configutils"
|
||||
"golang.org/x/net/http2"
|
||||
)
|
||||
|
||||
// TLS Version
|
||||
// TLSVersion TLS Version
|
||||
type TLSVersion = string
|
||||
|
||||
// Cipher Suites
|
||||
// TLSCipherSuite Cipher Suites
|
||||
type TLSCipherSuite = string
|
||||
|
||||
// SSL配置
|
||||
// SSLPolicy SSL配置
|
||||
type SSLPolicy struct {
|
||||
Id int64 `yaml:"id" json:"id"` // ID
|
||||
IsOn bool `yaml:"isOn" json:"isOn"` // 是否开启
|
||||
@@ -35,16 +37,25 @@ type SSLPolicy struct {
|
||||
cipherSuites []uint16
|
||||
|
||||
clientCAPool *x509.CertPool
|
||||
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
// 校验配置
|
||||
// Init 校验配置
|
||||
func (this *SSLPolicy) Init() error {
|
||||
this.nameMapping = map[string]*tls.Certificate{}
|
||||
|
||||
// certs
|
||||
var certs = []tls.Certificate{}
|
||||
for _, cert := range this.Certs {
|
||||
err := cert.Init()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certs = append(certs, *cert.CertObject())
|
||||
for _, dnsName := range cert.DNSNames {
|
||||
this.nameMapping[dnsName] = cert.CertObject()
|
||||
}
|
||||
}
|
||||
|
||||
// CA certs
|
||||
@@ -53,6 +64,10 @@ func (this *SSLPolicy) Init() error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certs = append(certs, *cert.CertObject())
|
||||
for _, dnsName := range cert.DNSNames {
|
||||
this.nameMapping[dnsName] = cert.CertObject()
|
||||
}
|
||||
}
|
||||
|
||||
// min version
|
||||
@@ -69,30 +84,56 @@ func (this *SSLPolicy) Init() error {
|
||||
}
|
||||
}
|
||||
|
||||
// tls config
|
||||
this.tlsConfig = &tls.Config{}
|
||||
cipherSuites := this.TLSCipherSuites()
|
||||
if !this.CipherSuitesIsOn || len(cipherSuites) == 0 {
|
||||
cipherSuites = nil
|
||||
}
|
||||
|
||||
nextProto := []string{}
|
||||
if this.HTTP2Enabled {
|
||||
nextProto = []string{http2.NextProtoTLS}
|
||||
}
|
||||
this.tlsConfig = &tls.Config{
|
||||
Certificates: certs,
|
||||
MinVersion: this.TLSMinVersion(),
|
||||
CipherSuites: cipherSuites,
|
||||
GetCertificate: nil,
|
||||
ClientAuth: GoSSLClientAuthType(this.ClientAuthType),
|
||||
ClientCAs: this.CAPool(),
|
||||
NextProtos: nextProto,
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 取得最小版本
|
||||
// TLSMinVersion 取得最小版本
|
||||
func (this *SSLPolicy) TLSMinVersion() uint16 {
|
||||
return this.minVersion
|
||||
}
|
||||
|
||||
// 套件
|
||||
// TLSCipherSuites 套件
|
||||
func (this *SSLPolicy) TLSCipherSuites() []uint16 {
|
||||
return this.cipherSuites
|
||||
}
|
||||
|
||||
// 校验是否匹配某个域名
|
||||
// MatchDomain 校验是否匹配某个域名
|
||||
func (this *SSLPolicy) MatchDomain(domain string) (cert *tls.Certificate, ok bool) {
|
||||
for _, cert := range this.Certs {
|
||||
if cert.MatchDomain(domain) {
|
||||
return cert.CertObject(), true
|
||||
cert, ok = this.nameMapping[domain]
|
||||
if ok {
|
||||
return
|
||||
}
|
||||
|
||||
for name, cert := range this.nameMapping {
|
||||
if configutils.MatchDomain(name, domain) {
|
||||
return cert, true
|
||||
}
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 取得第一个证书
|
||||
// FirstCert 取得第一个证书
|
||||
func (this *SSLPolicy) FirstCert() *tls.Certificate {
|
||||
for _, cert := range this.Certs {
|
||||
return cert.CertObject()
|
||||
@@ -100,7 +141,11 @@ func (this *SSLPolicy) FirstCert() *tls.Certificate {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CA证书Pool,用于TLS对客户端进行认证
|
||||
// CAPool CA证书Pool,用于TLS对客户端进行认证
|
||||
func (this *SSLPolicy) CAPool() *x509.CertPool {
|
||||
return this.clientCAPool
|
||||
}
|
||||
|
||||
func (this *SSLPolicy) TLSConfig() *tls.Config {
|
||||
return this.tlsConfig
|
||||
}
|
||||
|
||||
33
pkg/serverconfigs/sslconfigs/ssl_policy_test.go
Normal file
33
pkg/serverconfigs/sslconfigs/ssl_policy_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
// Copyright 2021 Liuxiangchao iwind.liu@gmail.com. All rights reserved.
|
||||
|
||||
package sslconfigs
|
||||
|
||||
import (
|
||||
"github.com/iwind/TeaGo/assert"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSSLPolicy_MatchDomain(t *testing.T) {
|
||||
var a = assert.NewAssertion(t)
|
||||
|
||||
var policy = &SSLPolicy{}
|
||||
policy.Certs = []*SSLCertConfig{
|
||||
{
|
||||
Id: 1,
|
||||
DNSNames: []string{"a.com", "b.com"},
|
||||
},
|
||||
{
|
||||
Id: 2,
|
||||
DNSNames: []string{"c.com", "d.com"},
|
||||
},
|
||||
{
|
||||
Id: 3,
|
||||
DNSNames: []string{"e.com", "f.com"},
|
||||
},
|
||||
}
|
||||
|
||||
{
|
||||
_, ok := policy.MatchDomain("c.com")
|
||||
a.IsTrue(ok)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user