动态更新OCSP,防止过期

This commit is contained in:
刘祥超
2022-03-18 17:08:51 +08:00
parent 06c9c9403b
commit 3f9c250dff
10 changed files with 195 additions and 78 deletions

View File

@@ -26,5 +26,5 @@ const (
ReportNodeVersion = "0.1.0" ReportNodeVersion = "0.1.0"
// SQLVersion SQL版本号 // SQLVersion SQL版本号
SQLVersion = "3" SQLVersion = "4"
) )

View File

@@ -53,7 +53,7 @@ func TestHTTPAccessLogDAO_ListAccessLogs(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, "", 10, timeutil.Format("Ymd"), 0, 0, 0, false, false, 0, 0, 0, false, 0, "", "", "") accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, "", 10, timeutil.Format("Ymd"), "", "", 0, 0, 0, false, false, 0, 0, 0, false, 0, "", "", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@@ -80,7 +80,7 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page(t *testing.T) {
times := 0 // 防止循环次数太多 times := 0 // 防止循环次数太多
for { for {
before := time.Now() before := time.Now()
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, lastRequestId, 2, timeutil.Format("Ymd"), 0, 0, 0, false, false, 0, 0, 0, false, 0, "", "", "") accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, lastRequestId, 2, timeutil.Format("Ymd"), "", "", 0, 0, 0, false, false, 0, 0, 0, false, 0, "", "", "")
cost := time.Since(before).Seconds() cost := time.Since(before).Seconds()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -111,7 +111,7 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Reverse(t *testing.T) {
} }
before := time.Now() before := time.Now()
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, "16023261176446590001000000000000003500000004", 2, timeutil.Format("Ymd"), 0, 0, 0, true, false, 0, 0, 0, false, 0, "", "", "") accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, "16023261176446590001000000000000003500000004", 2, timeutil.Format("Ymd"), "", "", 0, 0, 0, true, false, 0, 0, 0, false, 0, "", "", "")
cost := time.Since(before).Seconds() cost := time.Since(before).Seconds()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@@ -136,7 +136,7 @@ func TestHTTPAccessLogDAO_ListAccessLogs_Page_NotExists(t *testing.T) {
times := 0 // 防止循环次数太多 times := 0 // 防止循环次数太多
for { for {
before := time.Now() before := time.Now()
accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, lastRequestId, 2, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 1)), 0, 0, 0, false, false, 0, 0, 0, false, 0, "", "", "") accessLogs, requestId, hasMore, err := SharedHTTPAccessLogDAO.ListAccessLogs(tx, lastRequestId, 2, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 1)), "", "", 0, 0, 0, false, false, 0, 0, 0, false, 0, "", "", "")
cost := time.Since(before).Seconds() cost := time.Since(before).Seconds()
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)

View File

@@ -948,6 +948,13 @@ func (this *NodeDAO) ComposeNodeConfig(tx *dbs.Tx, nodeId int64, cacheMap *utils
} }
} }
// OCSP
ocspVersion, err := SharedSSLCertDAO.FindCertOCSPLatestVersion(tx)
if err != nil {
return nil, err
}
config.OCSPVersion = ocspVersion
return config, nil return config, nil
} }

View File

@@ -374,18 +374,64 @@ func (this *SSLCertDAO) CheckUserCert(tx *dbs.Tx, certId int64, userId int64) er
} }
// ListCertsToUpdateOCSP 查找需要更新OCSP的证书 // ListCertsToUpdateOCSP 查找需要更新OCSP的证书
func (this *SSLCertDAO) ListCertsToUpdateOCSP(tx *dbs.Tx, size int64) (result []*SSLCert, err error) { func (this *SSLCertDAO) ListCertsToUpdateOCSP(tx *dbs.Tx, maxAge int, size int64) (result []*SSLCert, err error) {
if maxAge <= 0 {
maxAge = 7200
}
_, err = this.Query(tx). _, err = this.Query(tx).
State(SSLCertStateEnabled). State(SSLCertStateEnabled).
Attr("ocspIsUpdated", false). Where("ocspUpdatedAt<:timestamp").
Param("timestamp", time.Now().Unix()-int64(maxAge)).
// TODO 需要排除没有被server使用的policy或许可以增加一个字段记录policy最近使用时间
Where("JSON_CONTAINS((SELECT JSON_ARRAYAGG(JSON_EXTRACT(certs, '$[*].certId')) FROM edgeSSLPolicies WHERE state=1 AND ocspIsOn=1 AND certs IS NOT NULL), CAST(id AS CHAR))").
Asc("ocspUpdatedAt").
Limit(size). Limit(size).
Slice(&result). Slice(&result).
FindAll() FindAll()
return return
} }
// UpdateCertOSCP 修改OCSP // ListCertOCSPAfterVersion 列出某个版本后的OCSP
func (this *SSLCertDAO) UpdateCertOSCP(tx *dbs.Tx, certId int64, ocsp []byte, errString string) error { func (this *SSLCertDAO) ListCertOCSPAfterVersion(tx *dbs.Tx, version int64, size int64) (result []*SSLCert, err error) {
// 不需要判断ocsp是否为空
_, err = this.Query(tx).
Result("id", "ocsp", "ocspUpdatedVersion").
State(SSLCertStateEnabled).
Attr("ocspIsUpdated", 1).
Gt("ocspUpdatedVersion", version).
Asc("ocspUpdatedVersion").
Limit(size).
Slice(&result).
FindAll()
return
}
// FindCertOCSPLatestVersion 获取OCSP最新版本
func (this *SSLCertDAO) FindCertOCSPLatestVersion(tx *dbs.Tx) (int64, error) {
return this.Query(tx).
Result("ocspUpdatedVersion").
Desc("ocspUpdatedVersion").
Limit(1).
FindInt64Col(0)
}
// PrepareCertOCSPUpdating 更新OCSP更新时间以便于准备更新相当于锁定
func (this *SSLCertDAO) PrepareCertOCSPUpdating(tx *dbs.Tx, certId int64) error {
return this.Query(tx).
Pk(certId).
Set("ocspUpdatedAt", time.Now().Unix()).
UpdateQuickly()
}
// UpdateCertOCSP 修改OCSP
func (this *SSLCertDAO) UpdateCertOCSP(tx *dbs.Tx, certId int64, ocsp []byte, errString string) error {
version, err := SharedSysLockerDAO.Increase(tx, "SSL_CERT_OCSP_VERSION", 1)
if err != nil {
return err
}
if ocsp == nil { if ocsp == nil {
ocsp = []byte{} ocsp = []byte{}
} }
@@ -395,17 +441,20 @@ func (this *SSLCertDAO) UpdateCertOSCP(tx *dbs.Tx, certId int64, ocsp []byte, er
errString = errString[:300] errString = errString[:300]
} }
err := this.Query(tx). err = this.Query(tx).
Pk(certId). Pk(certId).
Set("ocsp", ocsp). Set("ocsp", ocsp).
Set("ocspError", errString). Set("ocspError", errString).
Set("ocspIsUpdated", true). Set("ocspIsUpdated", true).
Set("ocspUpdatedAt", time.Now().Unix()).
Set("ocspUpdatedVersion", version).
UpdateQuickly() UpdateQuickly()
if err != nil { if err != nil {
return err return err
} }
return this.NotifyUpdate(tx, certId) // 注意:这里不通知更新,避免频繁的更新导致服务不稳定
return nil
} }
// CountAllSSLCertsWithOCSPError 计算有OCSP错误的证书数量 // CountAllSSLCertsWithOCSPError 计算有OCSP错误的证书数量
@@ -465,6 +514,7 @@ func (this *SSLCertDAO) ResetSSLCertsWithOCSPError(tx *dbs.Tx, certIds []int64)
err := this.Query(tx). err := this.Query(tx).
Pk(certId). Pk(certId).
Set("ocspIsUpdated", 0). Set("ocspIsUpdated", 0).
Set("ocspUpdatedAt", 0).
Set("ocspError", ""). Set("ocspError", "").
UpdateQuickly() UpdateQuickly()
if err != nil { if err != nil {
@@ -481,6 +531,7 @@ func (this *SSLCertDAO) ResetAllSSLCertsWithOCSPError(tx *dbs.Tx) error {
Attr("ocspIsUpdated", 1). Attr("ocspIsUpdated", 1).
Where("LENGTH(ocspError)>0"). Where("LENGTH(ocspError)>0").
Set("ocspIsUpdated", 0). Set("ocspIsUpdated", 0).
Set("ocspUpdatedAt", 0).
Set("ocspError", ""). Set("ocspError", "").
UpdateQuickly() UpdateQuickly()
} }

View File

@@ -1,5 +1,19 @@
package models package models_test
import ( import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
timeutil "github.com/iwind/TeaGo/utils/time"
"testing"
) )
func TestSSLCertDAO_ListCertsToUpdateOCSP(t *testing.T) {
var dao = models.NewSSLCertDAO()
certs, err := dao.ListCertsToUpdateOCSP(nil, 3600, 10)
if err != nil {
t.Fatal(err)
}
for _, cert := range certs {
t.Log(cert.Id, cert.Name, "updatedAt:", cert.OcspUpdatedAt, timeutil.FormatTime("Y-m-d H:i:s", int64(cert.OcspUpdatedAt)), cert.OcspIsUpdated)
}
}

View File

@@ -25,7 +25,9 @@ type SSLCert struct {
NotifiedAt uint64 `field:"notifiedAt"` // 最后通知时间 NotifiedAt uint64 `field:"notifiedAt"` // 最后通知时间
Ocsp string `field:"ocsp"` // OCSP缓存 Ocsp string `field:"ocsp"` // OCSP缓存
OcspIsUpdated uint8 `field:"ocspIsUpdated"` // OCSP是否已更新 OcspIsUpdated uint8 `field:"ocspIsUpdated"` // OCSP是否已更新
OcspUpdatedAt uint32 `field:"ocspUpdatedAt"` // OCSP更新时间
OcspError string `field:"ocspError"` // OCSP更新错误 OcspError string `field:"ocspError"` // OCSP更新错误
OcspUpdatedVersion uint64 `field:"ocspUpdatedVersion"` // OCSP更新版本
} }
type SSLCertOperator struct { type SSLCertOperator struct {
@@ -52,7 +54,9 @@ type SSLCertOperator struct {
NotifiedAt interface{} // 最后通知时间 NotifiedAt interface{} // 最后通知时间
Ocsp interface{} // OCSP缓存 Ocsp interface{} // OCSP缓存
OcspIsUpdated interface{} // OCSP是否已更新 OcspIsUpdated interface{} // OCSP是否已更新
OcspUpdatedAt interface{} // OCSP更新时间
OcspError interface{} // OCSP更新错误 OcspError interface{} // OCSP更新错误
OcspUpdatedVersion interface{} // OCSP更新版本
} }
func NewSSLCertOperator() *SSLCertOperator { func NewSSLCertOperator() *SSLCertOperator {

View File

@@ -282,3 +282,30 @@ func (this *SSLCertService) ResetAllSSLCertsWithOCSPError(ctx context.Context, r
} }
return this.Success() return this.Success()
} }
// ListUpdatedSSLCertOCSP 读取证书的OCSP
func (this *SSLCertService) ListUpdatedSSLCertOCSP(ctx context.Context, req *pb.ListUpdatedSSLCertOCSPRequest) (*pb.ListUpdatedSSLCertOCSPResponse, error) {
_, err := this.ValidateNode(ctx)
if err != nil {
return nil, err
}
var tx = this.NullTx()
certs, err := models.SharedSSLCertDAO.ListCertOCSPAfterVersion(tx, req.Version, int64(req.Size))
if err != nil {
return nil, err
}
var result = []*pb.ListUpdatedSSLCertOCSPResponse_SSLCertOCSP{}
for _, cert := range certs {
result = append(result, &pb.ListUpdatedSSLCertOCSPResponse_SSLCertOCSP{
SslCertId: int64(cert.Id),
Ocsp: []byte(cert.Ocsp),
Version: int64(cert.OcspUpdatedVersion),
})
}
return &pb.ListUpdatedSSLCertOCSPResponse{
SslCertOCSP: result,
}, nil
}

File diff suppressed because one or more lines are too long

View File

@@ -30,24 +30,27 @@ func init() {
type SSLCertUpdateOCSPTask struct { type SSLCertUpdateOCSPTask struct {
ticker *time.Ticker ticker *time.Ticker
httpClient *http.Client
} }
func NewSSLCertUpdateOCSPTask() *SSLCertUpdateOCSPTask { func NewSSLCertUpdateOCSPTask() *SSLCertUpdateOCSPTask {
return &SSLCertUpdateOCSPTask{ return &SSLCertUpdateOCSPTask{
ticker: time.NewTicker(1 * time.Minute), ticker: time.NewTicker(1 * time.Minute),
httpClient: utils.SharedHttpClient(5 * time.Second),
} }
} }
func (this *SSLCertUpdateOCSPTask) Start() { func (this *SSLCertUpdateOCSPTask) Start() {
for range this.ticker.C { for range this.ticker.C {
err := this.Loop() err := this.Loop(true)
if err != nil { if err != nil {
remotelogs.Error("SSLCertUpdateOCSPTask", err.Error()) remotelogs.Error("SSLCertUpdateOCSPTask", err.Error())
} }
} }
} }
func (this *SSLCertUpdateOCSPTask) Loop() error { func (this *SSLCertUpdateOCSPTask) Loop(checkLock bool) error {
if checkLock {
ok, err := models.SharedSysLockerDAO.Lock(nil, "ssl_cert_update_ocsp_task", 60-1) // 假设执行时间为1秒 ok, err := models.SharedSysLockerDAO.Lock(nil, "ssl_cert_update_ocsp_task", 60-1) // 假设执行时间为1秒
if err != nil { if err != nil {
return err return err
@@ -55,18 +58,29 @@ func (this *SSLCertUpdateOCSPTask) Loop() error {
if !ok { if !ok {
return nil return nil
} }
}
var tx *dbs.Tx var tx *dbs.Tx
// TODO 将来可以设置单次任务条数 // TODO 将来可以设置单次任务条数
var size int64 = 20 var size int64 = 60
certs, err := models.SharedSSLCertDAO.ListCertsToUpdateOCSP(tx, size) var maxAge = 7200
certs, err := models.SharedSSLCertDAO.ListCertsToUpdateOCSP(tx, maxAge, size)
if err != nil { if err != nil {
return errors.New("list certs failed: " + err.Error()) return errors.New("list certs failed: " + err.Error())
} }
if len(certs) == 0 { if len(certs) == 0 {
return nil return nil
} }
// 锁定
for _, cert := range certs {
err := models.SharedSSLCertDAO.PrepareCertOCSPUpdating(tx, int64(cert.Id))
if err != nil {
return errors.New("prepare cert ocsp updating failed: " + err.Error())
}
}
for _, cert := range certs { for _, cert := range certs {
ocspData, err := this.UpdateCertOCSP(cert) ocspData, err := this.UpdateCertOCSP(cert)
var errString = "" var errString = ""
@@ -75,7 +89,7 @@ func (this *SSLCertUpdateOCSPTask) Loop() error {
remotelogs.Warn("SSLCertUpdateOCSPTask", "update ocsp failed: "+errString) remotelogs.Warn("SSLCertUpdateOCSPTask", "update ocsp failed: "+errString)
} }
err = models.SharedSSLCertDAO.UpdateCertOSCP(tx, int64(cert.Id), ocspData, errString) err = models.SharedSSLCertDAO.UpdateCertOCSP(tx, int64(cert.Id), ocspData, errString)
if err != nil { if err != nil {
return errors.New("update ocsp failed: " + err.Error()) return errors.New("update ocsp failed: " + err.Error())
} }
@@ -84,6 +98,7 @@ func (this *SSLCertUpdateOCSPTask) Loop() error {
return nil return nil
} }
// UpdateCertOCSP 更新单个证书OCSP
func (this *SSLCertUpdateOCSPTask) UpdateCertOCSP(certOne *models.SSLCert) (ocspData []byte, err error) { func (this *SSLCertUpdateOCSPTask) UpdateCertOCSP(certOne *models.SSLCert) (ocspData []byte, err error) {
if certOne.IsCA == 1 || len(certOne.CertData) == 0 || len(certOne.KeyData) == 0 { if certOne.IsCA == 1 || len(certOne.CertData) == 0 || len(certOne.KeyData) == 0 {
return return
@@ -120,13 +135,12 @@ func (this *SSLCertUpdateOCSPTask) UpdateCertOCSP(certOne *models.SSLCert) (ocsp
var issuerURL = cert.IssuingCertificateURL[0] var issuerURL = cert.IssuingCertificateURL[0]
var ocspServerURL = cert.OCSPServer[0] var ocspServerURL = cert.OCSPServer[0]
var httpClient = utils.SharedHttpClient(5 * time.Second)
issuerReq, err := http.NewRequest(http.MethodGet, issuerURL, nil) issuerReq, err := http.NewRequest(http.MethodGet, issuerURL, nil)
if err != nil { if err != nil {
return nil, errors.New("request issuer certificate failed: " + err.Error()) return nil, errors.New("request issuer certificate failed: " + err.Error())
} }
issuerReq.Header.Set("User-Agent", teaconst.ProductName+"/"+teaconst.Version) issuerReq.Header.Set("User-Agent", teaconst.ProductName+"/"+teaconst.Version)
issuerResp, err := httpClient.Do(issuerReq) issuerResp, err := this.httpClient.Do(issuerReq)
if err != nil { if err != nil {
return nil, errors.New("request issuer certificate failed: '" + issuerURL + "': " + err.Error()) return nil, errors.New("request issuer certificate failed: '" + issuerURL + "': " + err.Error())
} }
@@ -156,7 +170,7 @@ func (this *SSLCertUpdateOCSPTask) UpdateCertOCSP(certOne *models.SSLCert) (ocsp
ocspReq.Header.Set("Content-Type", "application/ocsp-request") ocspReq.Header.Set("Content-Type", "application/ocsp-request")
ocspReq.Header.Set("Accept", "application/ocsp-response") ocspReq.Header.Set("Accept", "application/ocsp-response")
ocspResp, err := httpClient.Do(ocspReq) ocspResp, err := this.httpClient.Do(ocspReq)
if err != nil { if err != nil {
return nil, errors.New("request ocsp failed: '" + ocspServerURL + "': " + err.Error()) return nil, errors.New("request ocsp failed: '" + ocspServerURL + "': " + err.Error())
} }

View File

@@ -12,7 +12,7 @@ func TestSSLCertUpdateOCSPTask_Loop(t *testing.T) {
dbs.NotifyReady() dbs.NotifyReady()
var task = tasks.NewSSLCertUpdateOCSPTask() var task = tasks.NewSSLCertUpdateOCSPTask()
err := task.Loop() err := task.Loop(false)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }