OCSP支持过期时间

This commit is contained in:
刘祥超
2022-03-18 20:21:24 +08:00
parent a23a2ed826
commit e370944233
6 changed files with 85 additions and 40 deletions

View File

@@ -26,5 +26,5 @@ const (
ReportNodeVersion = "0.1.0" ReportNodeVersion = "0.1.0"
// SQLVersion SQL版本号 // SQLVersion SQL版本号
SQLVersion = "4" SQLVersion = "5"
) )

View File

@@ -1,6 +1,7 @@
package models package models
import ( import (
"bytes"
"encoding/json" "encoding/json"
"errors" "errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeAPI/internal/utils"
@@ -138,7 +139,18 @@ func (this *SSLCertDAO) UpdateCert(tx *dbs.Tx,
if certId <= 0 { if certId <= 0 {
return errors.New("invalid certId") return errors.New("invalid certId")
} }
op := NewSSLCertOperator()
oldOne, err := this.Query(tx).Find()
if err != nil {
return err
}
if oldOne == nil {
return nil
}
var oldCert = oldOne.(*SSLCert)
var dataIsChanged = bytes.Compare(certData, []byte(oldCert.CertData)) != 0 || bytes.Compare(keyData, []byte(oldCert.KeyData)) != 0
var op = NewSSLCertOperator()
op.Id = certId op.Id = certId
op.IsOn = isOn op.IsOn = isOn
op.Name = name op.Name = name
@@ -169,7 +181,15 @@ func (this *SSLCertDAO) UpdateCert(tx *dbs.Tx,
} }
op.CommonNames = commonNamesJSON op.CommonNames = commonNamesJSON
op.OcspIsUpdated = false // OCSP
if dataIsChanged {
op.OcspIsUpdated = 0
op.Ocsp = ""
op.OcspUpdatedAt = 0
op.OcspError = ""
op.OcspTries = 0
op.OcspExpiresAt = 0
}
err = this.Save(tx, op) err = this.Save(tx, op)
if err != nil { if err != nil {
@@ -209,7 +229,12 @@ func (this *SSLCertDAO) ComposeCertConfig(tx *dbs.Tx, certId int64, cacheMap *ut
config.ServerName = cert.ServerName config.ServerName = cert.ServerName
config.TimeBeginAt = int64(cert.TimeBeginAt) config.TimeBeginAt = int64(cert.TimeBeginAt)
config.TimeEndAt = int64(cert.TimeEndAt) config.TimeEndAt = int64(cert.TimeEndAt)
config.OCSP = []byte(cert.Ocsp)
// OCSP
if int64(cert.OcspExpiresAt) > time.Now().Unix() {
config.OCSP = []byte(cert.Ocsp)
config.OCSPExpiresAt = int64(cert.OcspExpiresAt)
}
config.OCSPError = cert.OcspError config.OCSPError = cert.OcspError
if IsNotNull(cert.DnsNames) { if IsNotNull(cert.DnsNames) {
@@ -374,15 +399,14 @@ func (this *SSLCertDAO) CheckUserCert(tx *dbs.Tx, certId int64, userId int64) er
} }
// ListCertsToUpdateOCSP 查找需要更新OCSP的证书 // ListCertsToUpdateOCSP 查找需要更新OCSP的证书
func (this *SSLCertDAO) ListCertsToUpdateOCSP(tx *dbs.Tx, maxAge int, size int64) (result []*SSLCert, err error) { func (this *SSLCertDAO) ListCertsToUpdateOCSP(tx *dbs.Tx, maxTries int, size int64) (result []*SSLCert, err error) {
if maxAge <= 0 { var nowTime = time.Now().Unix()
maxAge = 7200
}
var query = this.Query(tx). var query = this.Query(tx).
State(SSLCertStateEnabled). State(SSLCertStateEnabled).
Where("ocspUpdatedAt<:timestamp"). Lt("ocspExpiresAt", nowTime+120). // 提前 N 秒钟准备更新
Param("timestamp", time.Now().Unix()-int64(maxAge)) Lt("ocspTries", maxTries).
Lt("timeBeginAt", nowTime).
Gt("timeEndAt", nowTime)
// TODO 需要排除没有被server使用的policy或许可以增加一个字段记录policy最近使用时间 // TODO 需要排除没有被server使用的policy或许可以增加一个字段记录policy最近使用时间
@@ -411,7 +435,7 @@ func (this *SSLCertDAO) ListCertsToUpdateOCSP(tx *dbs.Tx, maxAge int, size int64
func (this *SSLCertDAO) ListCertOCSPAfterVersion(tx *dbs.Tx, version int64, size int64) (result []*SSLCert, err error) { func (this *SSLCertDAO) ListCertOCSPAfterVersion(tx *dbs.Tx, version int64, size int64) (result []*SSLCert, err error) {
// 不需要判断ocsp是否为空 // 不需要判断ocsp是否为空
_, err = this.Query(tx). _, err = this.Query(tx).
Result("id", "ocsp", "ocspUpdatedVersion"). Result("id", "ocsp", "ocspUpdatedVersion", "ocspExpiresAt").
State(SSLCertStateEnabled). State(SSLCertStateEnabled).
Attr("ocspIsUpdated", 1). Attr("ocspIsUpdated", 1).
Gt("ocspUpdatedVersion", version). Gt("ocspUpdatedVersion", version).
@@ -441,7 +465,11 @@ func (this *SSLCertDAO) PrepareCertOCSPUpdating(tx *dbs.Tx, certId int64) error
} }
// UpdateCertOCSP 修改OCSP // UpdateCertOCSP 修改OCSP
func (this *SSLCertDAO) UpdateCertOCSP(tx *dbs.Tx, certId int64, ocsp []byte, errString string) error { func (this *SSLCertDAO) UpdateCertOCSP(tx *dbs.Tx, certId int64, ocsp []byte, expiresAt int64, hasErr bool, errString string) error {
if hasErr && len(errString) == 0 {
errString = "failed"
}
version, err := SharedSysLockerDAO.Increase(tx, "SSL_CERT_OCSP_VERSION", 1) version, err := SharedSysLockerDAO.Increase(tx, "SSL_CERT_OCSP_VERSION", 1)
if err != nil { if err != nil {
return err return err
@@ -456,14 +484,22 @@ func (this *SSLCertDAO) UpdateCertOCSP(tx *dbs.Tx, certId int64, ocsp []byte, er
errString = errString[:300] errString = errString[:300]
} }
err = this.Query(tx). var query = this.Query(tx).
Pk(certId). Pk(certId).
Set("ocsp", ocsp). Set("ocsp", ocsp).
Set("ocspError", errString). Set("ocspError", errString).
Set("ocspIsUpdated", true). Set("ocspIsUpdated", true).
Set("ocspUpdatedAt", time.Now().Unix()). Set("ocspUpdatedAt", time.Now().Unix()).
Set("ocspUpdatedVersion", version). Set("ocspUpdatedVersion", version).
UpdateQuickly() Set("ocspExpiresAt", expiresAt)
if hasErr {
query.Set("ocspTries", dbs.SQL("ocspTries+1"))
} else {
query.Set("ocspTries", 0)
}
err = query.UpdateQuickly()
if err != nil { if err != nil {
return err return err
} }
@@ -531,6 +567,7 @@ func (this *SSLCertDAO) ResetSSLCertsWithOCSPError(tx *dbs.Tx, certIds []int64)
Set("ocspIsUpdated", 0). Set("ocspIsUpdated", 0).
Set("ocspUpdatedAt", 0). Set("ocspUpdatedAt", 0).
Set("ocspError", ""). Set("ocspError", "").
Set("ocspTries", 0).
UpdateQuickly() UpdateQuickly()
if err != nil { if err != nil {
return err return err
@@ -548,6 +585,7 @@ func (this *SSLCertDAO) ResetAllSSLCertsWithOCSPError(tx *dbs.Tx) error {
Set("ocspIsUpdated", 0). Set("ocspIsUpdated", 0).
Set("ocspUpdatedAt", 0). Set("ocspUpdatedAt", 0).
Set("ocspError", ""). Set("ocspError", "").
Set("ocspTries", 0).
UpdateQuickly() UpdateQuickly()
} }

View File

@@ -25,9 +25,11 @@ type SSLCert struct {
NotifiedAt uint64 `field:"notifiedAt"` // 最后通知时间 NotifiedAt uint64 `field:"notifiedAt"` // 最后通知时间
Ocsp string `field:"ocsp"` // OCSP缓存 Ocsp string `field:"ocsp"` // OCSP缓存
OcspIsUpdated uint8 `field:"ocspIsUpdated"` // OCSP是否已更新 OcspIsUpdated uint8 `field:"ocspIsUpdated"` // OCSP是否已更新
OcspUpdatedAt uint32 `field:"ocspUpdatedAt"` // OCSP更新时间 OcspUpdatedAt uint64 `field:"ocspUpdatedAt"` // OCSP更新时间
OcspError string `field:"ocspError"` // OCSP更新错误 OcspError string `field:"ocspError"` // OCSP更新错误
OcspUpdatedVersion uint64 `field:"ocspUpdatedVersion"` // OCSP更新版本 OcspUpdatedVersion uint64 `field:"ocspUpdatedVersion"` // OCSP更新版本
OcspExpiresAt uint64 `field:"ocspExpiresAt"` // OCSP过期时间(UTC)
OcspTries uint32 `field:"ocspTries"` // OCSP尝试次数
} }
type SSLCertOperator struct { type SSLCertOperator struct {
@@ -57,6 +59,8 @@ type SSLCertOperator struct {
OcspUpdatedAt interface{} // OCSP更新时间 OcspUpdatedAt interface{} // OCSP更新时间
OcspError interface{} // OCSP更新错误 OcspError interface{} // OCSP更新错误
OcspUpdatedVersion interface{} // OCSP更新版本 OcspUpdatedVersion interface{} // OCSP更新版本
OcspExpiresAt interface{} // OCSP过期时间(UTC)
OcspTries interface{} // OCSP尝试次数
} }
func NewSSLCertOperator() *SSLCertOperator { func NewSSLCertOperator() *SSLCertOperator {

View File

@@ -300,7 +300,8 @@ func (this *SSLCertService) ListUpdatedSSLCertOCSP(ctx context.Context, req *pb.
for _, cert := range certs { for _, cert := range certs {
result = append(result, &pb.ListUpdatedSSLCertOCSPResponse_SSLCertOCSP{ result = append(result, &pb.ListUpdatedSSLCertOCSPResponse_SSLCertOCSP{
SslCertId: int64(cert.Id), SslCertId: int64(cert.Id),
Ocsp: []byte(cert.Ocsp), Data: []byte(cert.Ocsp),
ExpiresAt: int64(cert.OcspExpiresAt),
Version: int64(cert.OcspUpdatedVersion), Version: int64(cert.OcspUpdatedVersion),
}) })
} }

File diff suppressed because one or more lines are too long

View File

@@ -63,8 +63,8 @@ func (this *SSLCertUpdateOCSPTask) Loop(checkLock bool) error {
var tx *dbs.Tx var tx *dbs.Tx
// TODO 将来可以设置单次任务条数 // TODO 将来可以设置单次任务条数
var size int64 = 60 var size int64 = 60
var maxAge = 7200 var maxTries = 5
certs, err := models.SharedSSLCertDAO.ListCertsToUpdateOCSP(tx, maxAge, size) certs, err := models.SharedSSLCertDAO.ListCertsToUpdateOCSP(tx, maxTries, size)
if err != nil { if err != nil {
return errors.New("list certs failed: " + err.Error()) return errors.New("list certs failed: " + err.Error())
} }
@@ -82,14 +82,16 @@ func (this *SSLCertUpdateOCSPTask) Loop(checkLock bool) error {
} }
for _, cert := range certs { for _, cert := range certs {
ocspData, err := this.UpdateCertOCSP(cert) ocspData, expiresAt, err := this.UpdateCertOCSP(cert)
var errString = "" var errString = ""
var hasErr = false
if err != nil { if err != nil {
errString = err.Error() errString = err.Error()
hasErr = true
remotelogs.Warn("SSLCertUpdateOCSPTask", "update ocsp failed: "+errString) remotelogs.Warn("SSLCertUpdateOCSPTask", "update ocsp failed: "+errString)
} }
err = models.SharedSSLCertDAO.UpdateCertOCSP(tx, int64(cert.Id), ocspData, errString) err = models.SharedSSLCertDAO.UpdateCertOCSP(tx, int64(cert.Id), ocspData, expiresAt, hasErr, errString)
if err != nil { if err != nil {
return errors.New("update ocsp failed: " + err.Error()) return errors.New("update ocsp failed: " + err.Error())
} }
@@ -99,37 +101,37 @@ func (this *SSLCertUpdateOCSPTask) Loop(checkLock bool) error {
} }
// UpdateCertOCSP 更新单个证书OCSP // UpdateCertOCSP 更新单个证书OCSP
func (this *SSLCertUpdateOCSPTask) UpdateCertOCSP(certOne *models.SSLCert) (ocspData []byte, err error) { func (this *SSLCertUpdateOCSPTask) UpdateCertOCSP(certOne *models.SSLCert) (ocspData []byte, expiresAt int64, err error) {
if certOne.IsCA == 1 || len(certOne.CertData) == 0 || len(certOne.KeyData) == 0 { if certOne.IsCA == 1 || len(certOne.CertData) == 0 || len(certOne.KeyData) == 0 {
return return
} }
keyPair, err := tls.X509KeyPair([]byte(certOne.CertData), []byte(certOne.KeyData)) keyPair, err := tls.X509KeyPair([]byte(certOne.CertData), []byte(certOne.KeyData))
if err != nil { if err != nil {
return nil, errors.New("parse certificate failed: " + err.Error()) return nil, 0, errors.New("parse certificate failed: " + err.Error())
} }
if len(keyPair.Certificate) == 0 { if len(keyPair.Certificate) == 0 {
return nil, nil return nil, 0, nil
} }
var certData = keyPair.Certificate[0] var certData = keyPair.Certificate[0]
cert, err := x509.ParseCertificate(certData) cert, err := x509.ParseCertificate(certData)
if err != nil { if err != nil {
return nil, errors.New("parse certificate block failed: " + err.Error()) return nil, 0, errors.New("parse certificate block failed: " + err.Error())
} }
// 是否已过期 // 是否已过期
var now = time.Now() var now = time.Now()
if cert.NotBefore.After(now) || cert.NotAfter.Before(now) { if cert.NotBefore.After(now) || cert.NotAfter.Before(now) {
return nil, nil return nil, 0, nil
} }
if len(cert.IssuingCertificateURL) == 0 || len(cert.OCSPServer) == 0 { if len(cert.IssuingCertificateURL) == 0 || len(cert.OCSPServer) == 0 {
return nil, nil return nil, 0, nil
} }
if len(cert.DNSNames) == 0 { if len(cert.DNSNames) == 0 {
return nil, nil return nil, 0, nil
} }
var issuerURL = cert.IssuingCertificateURL[0] var issuerURL = cert.IssuingCertificateURL[0]
@@ -137,12 +139,12 @@ func (this *SSLCertUpdateOCSPTask) UpdateCertOCSP(certOne *models.SSLCert) (ocsp
issuerReq, err := http.NewRequest(http.MethodGet, issuerURL, nil) issuerReq, err := http.NewRequest(http.MethodGet, issuerURL, nil)
if err != nil { if err != nil {
return nil, errors.New("request issuer certificate failed: " + err.Error()) return nil, 0, errors.New("request issuer certificate failed: " + err.Error())
} }
issuerReq.Header.Set("User-Agent", teaconst.ProductName+"/"+teaconst.Version) issuerReq.Header.Set("User-Agent", teaconst.ProductName+"/"+teaconst.Version)
issuerResp, err := this.httpClient.Do(issuerReq) issuerResp, err := this.httpClient.Do(issuerReq)
if err != nil { if err != nil {
return nil, errors.New("request issuer certificate failed: '" + issuerURL + "': " + err.Error()) return nil, 0, errors.New("request issuer certificate failed: '" + issuerURL + "': " + err.Error())
} }
defer func() { defer func() {
_ = issuerResp.Body.Close() _ = issuerResp.Body.Close()
@@ -150,29 +152,29 @@ func (this *SSLCertUpdateOCSPTask) UpdateCertOCSP(certOne *models.SSLCert) (ocsp
issuerData, err := ioutil.ReadAll(issuerResp.Body) issuerData, err := ioutil.ReadAll(issuerResp.Body)
if err != nil { if err != nil {
return nil, errors.New("read issuer certificate failed: '" + issuerURL + "': " + err.Error()) return nil, 0, errors.New("read issuer certificate failed: '" + issuerURL + "': " + err.Error())
} }
issuerCert, err := x509.ParseCertificate(issuerData) issuerCert, err := x509.ParseCertificate(issuerData)
if err != nil { if err != nil {
return nil, errors.New("parse issuer certificate failed: '" + issuerURL + "': " + err.Error()) return nil, 0, errors.New("parse issuer certificate failed: '" + issuerURL + "': " + err.Error())
} }
buf, err := ocsp.CreateRequest(cert, issuerCert, &ocsp.RequestOptions{ buf, err := ocsp.CreateRequest(cert, issuerCert, &ocsp.RequestOptions{
Hash: crypto.SHA1, Hash: crypto.SHA1,
}) })
if err != nil { if err != nil {
return nil, errors.New("create ocsp request failed: " + err.Error()) return nil, 0, errors.New("create ocsp request failed: " + err.Error())
} }
ocspReq, err := http.NewRequest(http.MethodPost, ocspServerURL, bytes.NewBuffer(buf)) ocspReq, err := http.NewRequest(http.MethodPost, ocspServerURL, bytes.NewBuffer(buf))
if err != nil { if err != nil {
return nil, errors.New("request ocsp failed: " + err.Error()) return nil, 0, errors.New("request ocsp failed: " + err.Error())
} }
ocspReq.Header.Set("Content-Type", "application/ocsp-request") ocspReq.Header.Set("Content-Type", "application/ocsp-request")
ocspReq.Header.Set("Accept", "application/ocsp-response") ocspReq.Header.Set("Accept", "application/ocsp-response")
ocspResp, err := this.httpClient.Do(ocspReq) ocspResp, err := this.httpClient.Do(ocspReq)
if err != nil { if err != nil {
return nil, errors.New("request ocsp failed: '" + ocspServerURL + "': " + err.Error()) return nil, 0, errors.New("request ocsp failed: '" + ocspServerURL + "': " + err.Error())
} }
defer func() { defer func() {
@@ -181,17 +183,17 @@ func (this *SSLCertUpdateOCSPTask) UpdateCertOCSP(certOne *models.SSLCert) (ocsp
respData, err := ioutil.ReadAll(ocspResp.Body) respData, err := ioutil.ReadAll(ocspResp.Body)
if err != nil { if err != nil {
return nil, errors.New("read ocsp failed: '" + ocspServerURL + "': " + err.Error()) return nil, 0, errors.New("read ocsp failed: '" + ocspServerURL + "': " + err.Error())
} }
ocspResult, err := ocsp.ParseResponse(respData, issuerCert) ocspResult, err := ocsp.ParseResponse(respData, issuerCert)
if err != nil { if err != nil {
return nil, errors.New("decode ocsp failed: " + err.Error()) return nil, 0, errors.New("decode ocsp failed: " + err.Error())
} }
// 只返回Good的ocsp // 只返回Good的ocsp
if ocspResult.Status == ocsp.Good { if ocspResult.Status == ocsp.Good {
return respData, nil return respData, ocspResult.NextUpdate.Unix(), nil
} }
return nil, nil return nil, 0, nil
} }