diff --git a/internal/apis/api_node.go b/internal/apis/api_node.go index d4205957..3378f6a8 100644 --- a/internal/apis/api_node.go +++ b/internal/apis/api_node.go @@ -72,6 +72,7 @@ func (this *APINode) listenRPC() error { pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{}) pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{}) pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{}) + pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{}) err = rpcServer.Serve(listener) if err != nil { return errors.New("[API]start rpc failed: " + err.Error()) diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index 6f965ea5..b879c576 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -133,7 +133,12 @@ func (this *ServerDAO) CreateServer(adminId int64, userId int64, serverType serv serverId = types.Int64(op.Id) - _, err = this.RenewServerConfig(serverId) + _, err = this.RenewServerConfig(serverId, false) + if err != nil { + return serverId, err + } + + err = this.createEvent() if err != nil { return serverId, err } @@ -157,7 +162,7 @@ func (this *ServerDAO) UpdateServerBasic(serverId int64, name string, descriptio return err } - _, err = this.RenewServerConfig(serverId) + _, err = this.RenewServerConfig(serverId, false) if err != nil { return err } @@ -166,7 +171,7 @@ func (this *ServerDAO) UpdateServerBasic(serverId int64, name string, descriptio } // 修改服务配置 -func (this *ServerDAO) UpdateServerConfig(serverId int64, configJSON []byte) (isChanged bool, err error) { +func (this *ServerDAO) UpdateServerConfig(serverId int64, configJSON []byte, updateMd5 bool) (isChanged bool, err error) { if serverId <= 0 { return false, errors.New("serverId should not be smaller than 0") } @@ -195,7 +200,9 @@ func (this *ServerDAO) UpdateServerConfig(serverId int64, configJSON []byte) (is op.Config = JSONBytes(configJSON) op.Version = dbs.SQL("version+1") - op.ConfigMd5 = newConfigMd5 + if updateMd5 { + op.ConfigMd5 = newConfigMd5 + } _, err = this.Save(op) return true, err } @@ -216,7 +223,7 @@ func (this *ServerDAO) UpdateServerHTTP(serverId int64, config []byte) error { return err } - _, err = this.RenewServerConfig(serverId) + _, err = this.RenewServerConfig(serverId, false) if err != nil { return err } @@ -225,22 +232,22 @@ func (this *ServerDAO) UpdateServerHTTP(serverId int64, config []byte) error { } // 修改HTTPS配置 -func (this *ServerDAO) UpdateServerHTTPS(serverId int64, config []byte) error { +func (this *ServerDAO) UpdateServerHTTPS(serverId int64, httpsJSON []byte) error { if serverId <= 0 { return errors.New("serverId should not be smaller than 0") } - if len(config) == 0 { - config = []byte("null") + if len(httpsJSON) == 0 { + httpsJSON = []byte("null") } _, err := this.Query(). Pk(serverId). - Set("https", string(config)). + Set("https", string(httpsJSON)). Update() if err != nil { return err } - _, err = this.RenewServerConfig(serverId) + _, err = this.RenewServerConfig(serverId, false) if err != nil { return err } @@ -516,7 +523,7 @@ func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.Serve } // SSL - if httpsConfig.SSLPolicyRef != nil { + if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 { sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(httpsConfig.SSLPolicyRef.SSLPolicyId) if err != nil { return nil, err @@ -614,7 +621,7 @@ func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.Serve } // 更新服务的Config配置 -func (this *ServerDAO) RenewServerConfig(serverId int64) (isChanged bool, err error) { +func (this *ServerDAO) RenewServerConfig(serverId int64, updateMd5 bool) (isChanged bool, err error) { serverConfig, err := this.ComposeServerConfig(serverId) if err != nil { return false, err @@ -623,7 +630,7 @@ func (this *ServerDAO) RenewServerConfig(serverId int64) (isChanged bool, err er if err != nil { return false, err } - return this.UpdateServerConfig(serverId, data) + return this.UpdateServerConfig(serverId, data, updateMd5) } // 根据条件获取反向代理配置 diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index 5c00d755..b6c24263 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -31,6 +31,20 @@ func NewSSLCertDAO() *SSLCertDAO { var SharedSSLCertDAO = NewSSLCertDAO() +// 初始化 +func (this *SSLCertDAO) Init() { + this.DAOObject.Init() + this.DAOObject.OnUpdate(func() error { + return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + }) + this.DAOObject.OnInsert(func() error { + return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + }) + this.DAOObject.OnDelete(func() error { + return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + }) +} + // 启用条目 func (this *SSLCertDAO) EnableSSLCert(id int64) error { _, err := this.Query(). diff --git a/internal/db/models/ssl_policy_dao.go b/internal/db/models/ssl_policy_dao.go index 138de08b..ae262e33 100644 --- a/internal/db/models/ssl_policy_dao.go +++ b/internal/db/models/ssl_policy_dao.go @@ -2,10 +2,12 @@ package models import ( "encoding/json" + "errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/types" "strconv" ) @@ -29,6 +31,20 @@ func NewSSLPolicyDAO() *SSLPolicyDAO { var SharedSSLPolicyDAO = NewSSLPolicyDAO() +// 初始化 +func (this *SSLPolicyDAO) Init() { + this.DAOObject.Init() + this.DAOObject.OnUpdate(func() error { + return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + }) + this.DAOObject.OnInsert(func() error { + return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + }) + this.DAOObject.OnDelete(func() error { + return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) + }) +} + // 启用条目 func (this *SSLPolicyDAO) EnableSSLPolicy(id int64) error { _, err := this.Query(). @@ -97,7 +113,30 @@ func (this *SSLPolicyDAO) ComposePolicyConfig(policyId int64) (*sslconfigs.SSLPo } } + // client CA certs + if IsNotNull(policy.ClientCACerts) { + refs := []*sslconfigs.SSLCertRef{} + err = json.Unmarshal([]byte(policy.ClientCACerts), &refs) + if err != nil { + return nil, err + } + if len(refs) > 0 { + for _, ref := range refs { + certConfig, err := SharedSSLCertDAO.ComposeCertConfig(ref.CertId) + if err != nil { + return nil, err + } + if certConfig == nil { + continue + } + config.ClientCARefs = append(config.ClientCARefs, ref) + config.ClientCACerts = append(config.ClientCACerts, certConfig) + } + } + } + // cipher suites + config.CipherSuitesIsOn = policy.CipherSuitesIsOn == 1 if IsNotNull(policy.CipherSuites) { cipherSuites := []string{} err = json.Unmarshal([]byte(policy.CipherSuites), &cipherSuites) @@ -140,3 +179,76 @@ func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(certId int64) (polic } return policyIds, nil } + +// 创建Policy +func (this *SSLPolicyDAO) CreatePolicy(http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) (int64, error) { + op := NewSSLPolicyOperator() + op.State = SSLPolicyStateEnabled + op.IsOn = true + op.Http2Enabled = http2Enabled + op.MinVersion = minVersion + + if len(certsJSON) > 0 { + op.Certs = certsJSON + } + if len(hstsJSON) > 0 { + op.Hsts = hstsJSON + } + + op.ClientAuthType = clientAuthType + if len(clientCACertsJSON) > 0 { + op.ClientCACerts = clientCACertsJSON + } + + op.CipherSuitesIsOn = cipherSuitesIsOn + if len(cipherSuites) > 0 { + cipherSuitesJSON, err := json.Marshal(cipherSuites) + if err != nil { + return 0, err + } + op.CipherSuites = cipherSuitesJSON + } + _, err := this.Save(op) + if err != nil { + return 0, err + } + return types.Int64(op.Id), nil +} + +// 修改Policy +// 创建Policy +func (this *SSLPolicyDAO) UpdatePolicy(policyId int64, http2Enabled bool, minVersion string, certsJSON []byte, hstsJSON []byte, clientAuthType int32, clientCACertsJSON []byte, cipherSuitesIsOn bool, cipherSuites []string) error { + if policyId <= 0 { + return errors.New("invalid policyId") + } + + op := NewSSLPolicyOperator() + op.Id = policyId + op.Http2Enabled = http2Enabled + op.MinVersion = minVersion + + if len(certsJSON) > 0 { + op.Certs = certsJSON + } + if len(hstsJSON) > 0 { + op.Hsts = hstsJSON + } + + op.ClientAuthType = clientAuthType + if len(clientCACertsJSON) > 0 { + op.ClientCACerts = clientCACertsJSON + } + + op.CipherSuitesIsOn = cipherSuitesIsOn + if len(cipherSuites) > 0 { + cipherSuitesJSON, err := json.Marshal(cipherSuites) + if err != nil { + return err + } + op.CipherSuites = cipherSuitesJSON + } else { + op.CipherSuites = "[]" + } + _, err := this.Save(op) + return err +} diff --git a/internal/db/models/ssl_policy_model.go b/internal/db/models/ssl_policy_model.go index c3837b85..7b3f0936 100644 --- a/internal/db/models/ssl_policy_model.go +++ b/internal/db/models/ssl_policy_model.go @@ -2,33 +2,37 @@ package models // type SSLPolicy struct { - Id uint32 `field:"id"` // ID - AdminId uint32 `field:"adminId"` // 管理员ID - UserId uint32 `field:"userId"` // 用户ID - IsOn uint8 `field:"isOn"` // 是否启用 - Certs string `field:"certs"` // 证书列表 - ClientAuthType uint32 `field:"clientAuthType"` // 客户端认证类型 - MinVersion string `field:"minVersion"` // 支持的SSL最小版本 - CipherSuites string `field:"cipherSuites"` // 加密算法套件 - Hsts string `field:"hsts"` // HSTS设置 - Http2Enabled uint8 `field:"http2Enabled"` // 是否启用HTTP/2 - State uint8 `field:"state"` // 状态 - CreatedAt uint64 `field:"createdAt"` // 创建时间 + Id uint32 `field:"id"` // ID + AdminId uint32 `field:"adminId"` // 管理员ID + UserId uint32 `field:"userId"` // 用户ID + IsOn uint8 `field:"isOn"` // 是否启用 + Certs string `field:"certs"` // 证书列表 + ClientCACerts string `field:"clientCACerts"` // 客户端证书 + ClientAuthType uint32 `field:"clientAuthType"` // 客户端认证类型 + MinVersion string `field:"minVersion"` // 支持的SSL最小版本 + CipherSuitesIsOn uint8 `field:"cipherSuitesIsOn"` // 是否自定义加密算法套件 + CipherSuites string `field:"cipherSuites"` // 加密算法套件 + Hsts string `field:"hsts"` // HSTS设置 + Http2Enabled uint8 `field:"http2Enabled"` // 是否启用HTTP/2 + State uint8 `field:"state"` // 状态 + CreatedAt uint64 `field:"createdAt"` // 创建时间 } type SSLPolicyOperator struct { - Id interface{} // ID - AdminId interface{} // 管理员ID - UserId interface{} // 用户ID - IsOn interface{} // 是否启用 - Certs interface{} // 证书列表 - ClientAuthType interface{} // 客户端认证类型 - MinVersion interface{} // 支持的SSL最小版本 - CipherSuites interface{} // 加密算法套件 - Hsts interface{} // HSTS设置 - Http2Enabled interface{} // 是否启用HTTP/2 - State interface{} // 状态 - CreatedAt interface{} // 创建时间 + Id interface{} // ID + AdminId interface{} // 管理员ID + UserId interface{} // 用户ID + IsOn interface{} // 是否启用 + Certs interface{} // 证书列表 + ClientCACerts interface{} // 客户端证书 + ClientAuthType interface{} // 客户端认证类型 + MinVersion interface{} // 支持的SSL最小版本 + CipherSuitesIsOn interface{} // 是否自定义加密算法套件 + CipherSuites interface{} // 加密算法套件 + Hsts interface{} // HSTS设置 + Http2Enabled interface{} // 是否启用HTTP/2 + State interface{} // 状态 + CreatedAt interface{} // 创建时间 } func NewSSLPolicyOperator() *SSLPolicyOperator { diff --git a/internal/db/models/sys_event_types.go b/internal/db/models/sys_event_types.go index 9458f9f7..3e88f265 100644 --- a/internal/db/models/sys_event_types.go +++ b/internal/db/models/sys_event_types.go @@ -38,7 +38,7 @@ func (this *ServerChangeEvent) Run() error { return err } for _, serverId := range serverIds { - isChanged, err := SharedServerDAO.RenewServerConfig(serverId) + isChanged, err := SharedServerDAO.RenewServerConfig(serverId, true) if err != nil { return err } diff --git a/internal/rpc/services/service_ssl_policy.go b/internal/rpc/services/service_ssl_policy.go new file mode 100644 index 00000000..311c4fea --- /dev/null +++ b/internal/rpc/services/service_ssl_policy.go @@ -0,0 +1,64 @@ +package services + +import ( + "context" + "encoding/json" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" +) + +type SSLPolicyService struct { +} + +// 创建Policy +func (this *SSLPolicyService) CreateSSLPolicy(ctx context.Context, req *pb.CreateSSLPolicyRequest) (*pb.CreateSSLPolicyResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + policyId, err := models.SharedSSLPolicyDAO.CreatePolicy(req.Http2Enabled, req.MinVersion, req.CertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites) + if err != nil { + return nil, err + } + + return &pb.CreateSSLPolicyResponse{SslPolicyId: policyId}, nil +} + +// 修改Policy +func (this *SSLPolicyService) UpdateSSLPolicy(ctx context.Context, req *pb.UpdateSSLPolicyRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedSSLPolicyDAO.UpdatePolicy(req.SslPolicyId, req.Http2Enabled, req.MinVersion, req.CertsJSON, req.HstsJSON, req.ClientAuthType, req.ClientCACertsJSON, req.CipherSuitesIsOn, req.CipherSuites) + if err != nil { + return nil, err + } + + return rpcutils.RPCUpdateSuccess() +} + +// 查找Policy +func (this *SSLPolicyService) FindEnabledSSLPolicyConfig(ctx context.Context, req *pb.FindEnabledSSLPolicyConfigRequest) (*pb.FindEnabledSSLPolicyConfigResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + config, err := models.SharedSSLPolicyDAO.ComposePolicyConfig(req.SslPolicyId) + if err != nil { + return nil, err + } + configJSON, err := json.Marshal(config) + if err != nil { + return nil, err + } + + return &pb.FindEnabledSSLPolicyConfigResponse{SslPolicyJSON: configJSON}, nil +}