diff --git a/internal/db/models/origin_dao.go b/internal/db/models/origin_dao.go index 08c55816..7ae9bcb4 100644 --- a/internal/db/models/origin_dao.go +++ b/internal/db/models/origin_dao.go @@ -87,7 +87,20 @@ func (this *OriginDAO) FindOriginName(tx *dbs.Tx, id int64) (string, error) { } // CreateOrigin 创建源站 -func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32, domains []string) (originId int64, err error) { +func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, + adminId int64, + userId int64, + name string, + addrJSON string, + description string, + weight int32, isOn bool, + connTimeout *shared.TimeDuration, + readTimeout *shared.TimeDuration, + idleTimeout *shared.TimeDuration, + maxConns int32, + maxIdleConns int32, + certRef *sslconfigs.SSLCertRef, + domains []string) (originId int64, err error) { op := NewOriginOperator() op.AdminId = adminId op.UserId = userId @@ -133,6 +146,15 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, nam } op.Weight = weight + // cert + if certRef != nil { + certRefJSON, err := json.Marshal(certRef) + if err != nil { + return 0, err + } + op.Cert = certRefJSON + } + if len(domains) > 0 { domainsJSON, err := json.Marshal(domains) if err != nil { @@ -152,7 +174,20 @@ func (this *OriginDAO) CreateOrigin(tx *dbs.Tx, adminId int64, userId int64, nam } // UpdateOrigin 修改源站 -func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, addrJSON string, description string, weight int32, isOn bool, connTimeout *shared.TimeDuration, readTimeout *shared.TimeDuration, idleTimeout *shared.TimeDuration, maxConns int32, maxIdleConns int32, domains []string) error { +func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, + originId int64, + name string, + addrJSON string, + description string, + weight int32, + isOn bool, + connTimeout *shared.TimeDuration, + readTimeout *shared.TimeDuration, + idleTimeout *shared.TimeDuration, + maxConns int32, + maxIdleConns int32, + certRef *sslconfigs.SSLCertRef, + domains []string) error { if originId <= 0 { return errors.New("invalid originId") } @@ -201,6 +236,17 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, add op.IsOn = isOn op.Version = dbs.SQL("version+1") + // cert + if certRef != nil { + certRefJSON, err := json.Marshal(certRef) + if err != nil { + return err + } + op.Cert = certRefJSON + } else { + op.Cert = dbs.SQL("NULL") + } + if len(domains) > 0 { domainsJSON, err := json.Marshal(domains) if err != nil { diff --git a/internal/rpc/services/service_origin.go b/internal/rpc/services/service_origin.go index 2255bff0..57764e3e 100644 --- a/internal/rpc/services/service_origin.go +++ b/internal/rpc/services/service_origin.go @@ -7,6 +7,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" "github.com/iwind/TeaGo/maps" ) @@ -58,7 +59,20 @@ func (this *OriginService) CreateOrigin(ctx context.Context, req *pb.CreateOrigi } } - originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, req.Domains) + // cert + var certRef *sslconfigs.SSLCertRef + if len(req.CertRefJSON) > 0 { + certRef = &sslconfigs.SSLCertRef{} + err = json.Unmarshal(req.CertRefJSON, certRef) + if err != nil { + return nil, err + } + if certRef.CertId <= 0 { + certRef = nil + } + } + + originId, err := models.SharedOriginDAO.CreateOrigin(tx, adminId, userId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, certRef, req.Domains) if err != nil { return nil, err } @@ -112,7 +126,20 @@ func (this *OriginService) UpdateOrigin(ctx context.Context, req *pb.UpdateOrigi } } - err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, req.Domains) + // cert + var certRef *sslconfigs.SSLCertRef + if len(req.CertRefJSON) > 0 { + certRef = &sslconfigs.SSLCertRef{} + err = json.Unmarshal(req.CertRefJSON, certRef) + if err != nil { + return nil, err + } + if certRef.CertId <= 0 { + certRef = nil + } + } + + err = models.SharedOriginDAO.UpdateOrigin(tx, req.OriginId, req.Name, string(addrMap.AsJSON()), req.Description, req.Weight, req.IsOn, connTimeout, readTimeout, idleTimeout, req.MaxConns, req.MaxIdleConns, certRef, req.Domains) if err != nil { return nil, err }