diff --git a/internal/apis/api_node.go b/internal/apis/api_node.go index 9c5716e1..d4205957 100644 --- a/internal/apis/api_node.go +++ b/internal/apis/api_node.go @@ -71,6 +71,7 @@ func (this *APINode) listenRPC() error { pb.RegisterHTTPLocationServiceServer(rpcServer, &services.HTTPLocationService{}) pb.RegisterHTTPWebsocketServiceServer(rpcServer, &services.HTTPWebsocketService{}) pb.RegisterHTTPRewriteRuleServiceServer(rpcServer, &services.HTTPRewriteRuleService{}) + pb.RegisterSSLCertServiceServer(rpcServer, &services.SSLCertService{}) err = rpcServer.Serve(listener) if err != nil { return errors.New("[API]start rpc failed: " + err.Error()) diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index dd40dc8a..6f965ea5 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -10,6 +10,8 @@ import ( "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/types" + "strconv" + "strings" ) const ( @@ -512,6 +514,18 @@ func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.Serve if err != nil { return nil, err } + + // SSL + if httpsConfig.SSLPolicyRef != nil { + sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(httpsConfig.SSLPolicyRef.SSLPolicyId) + if err != nil { + return nil, err + } + if sslPolicyConfig != nil { + httpsConfig.SSLPolicy = sslPolicyConfig + } + } + config.HTTPS = httpsConfig } @@ -532,6 +546,18 @@ func (this *ServerDAO) ComposeServerConfig(serverId int64) (*serverconfigs.Serve if err != nil { return nil, err } + + // SSL + if tlsConfig.SSLPolicyRef != nil { + sslPolicyConfig, err := SharedSSLPolicyDAO.ComposePolicyConfig(tlsConfig.SSLPolicyRef.SSLPolicyId) + if err != nil { + return nil, err + } + if sslPolicyConfig != nil { + tlsConfig.SSLPolicy = sslPolicyConfig + } + } + config.TLS = tlsConfig } @@ -617,6 +643,7 @@ func (this *ServerDAO) FindReverseProxyRef(serverId int64) (*serverconfigs.Rever return config, err } +// 查找Server对应的WebId func (this *ServerDAO) FindServerWebId(serverId int64) (int64, error) { webId, err := this.Query(). Pk(serverId). @@ -628,6 +655,42 @@ func (this *ServerDAO) FindServerWebId(serverId int64) (int64, error) { return int64(webId), nil } +// 计算使用SSL策略的所有服务数量 +func (this *ServerDAO) CountServersWithSSLPolicyIds(sslPolicyIds []int64) (count int64, err error) { + if len(sslPolicyIds) == 0 { + return + } + policyStringIds := []string{} + for _, policyId := range sslPolicyIds { + policyStringIds = append(policyStringIds, strconv.FormatInt(policyId, 10)) + } + return this.Query(). + State(ServerStateEnabled). + Where("(FIND_IN_SET(JSON_EXTRACT(https, '$.sslPolicyRef.sslPolicyId'), :policyIds) OR FIND_IN_SET(JSON_EXTRACT(tls, '$.sslPolicyRef.sslPolicyId'), :policyIds))"). + Param("policyIds", strings.Join(policyStringIds, ",")). + Count() +} + +// 查找使用SSL策略的所有服务 +func (this *ServerDAO) FindAllServersWithSSLPolicyIds(sslPolicyIds []int64) (result []*Server, err error) { + if len(sslPolicyIds) == 0 { + return + } + policyStringIds := []string{} + for _, policyId := range sslPolicyIds { + policyStringIds = append(policyStringIds, strconv.FormatInt(policyId, 10)) + } + _, err = this.Query(). + State(ServerStateEnabled). + Result("id", "name", "https", "tls", "isOn", "type"). + Where("(FIND_IN_SET(JSON_EXTRACT(https, '$.sslPolicyRef.sslPolicyId'), :policyIds) OR FIND_IN_SET(JSON_EXTRACT(tls, '$.sslPolicyRef.sslPolicyId'), :policyIds))"). + Param("policyIds", strings.Join(policyStringIds, ",")). + Slice(&result). + AscPk(). + FindAll() + return +} + // 创建事件 func (this *ServerDAO) createEvent() error { return SharedSysEventDAO.CreateEvent(NewServerChangeEvent()) diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go new file mode 100644 index 00000000..5c00d755 --- /dev/null +++ b/internal/db/models/ssl_cert_dao.go @@ -0,0 +1,242 @@ +package models + +import ( + "encoding/json" + "errors" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/types" + "time" +) + +const ( + SSLCertStateEnabled = 1 // 已启用 + SSLCertStateDisabled = 0 // 已禁用 +) + +type SSLCertDAO dbs.DAO + +func NewSSLCertDAO() *SSLCertDAO { + return dbs.NewDAO(&SSLCertDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeSSLCerts", + Model: new(SSLCert), + PkName: "id", + }, + }).(*SSLCertDAO) +} + +var SharedSSLCertDAO = NewSSLCertDAO() + +// 启用条目 +func (this *SSLCertDAO) EnableSSLCert(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", SSLCertStateEnabled). + Update() + return err +} + +// 禁用条目 +func (this *SSLCertDAO) DisableSSLCert(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", SSLCertStateDisabled). + Update() + return err +} + +// 查找启用中的条目 +func (this *SSLCertDAO) FindEnabledSSLCert(id int64) (*SSLCert, error) { + result, err := this.Query(). + Pk(id). + Attr("state", SSLCertStateEnabled). + Find() + if result == nil { + return nil, err + } + return result.(*SSLCert), err +} + +// 根据主键查找名称 +func (this *SSLCertDAO) FindSSLCertName(id int64) (string, error) { + return this.Query(). + Pk(id). + Result("name"). + FindStringCol("") +} + +// 创建证书 +func (this *SSLCertDAO) CreateCert(isOn bool, name string, description string, serverName string, isCA bool, certData []byte, keyData []byte, timeBeginAt int64, timeEndAt int64, dnsNames []string, commonNames []string) (int64, error) { + op := NewSSLCertOperator() + op.State = SSLCertStateEnabled + op.IsOn = isOn + op.Name = name + op.Description = description + op.ServerName = serverName + op.IsCA = isCA + op.CertData = certData + op.KeyData = keyData + op.TimeBeginAt = timeBeginAt + op.TimeEndAt = timeEndAt + + dnsNamesJSON, err := json.Marshal(dnsNames) + if err != nil { + return 0, err + } + op.DnsNames = dnsNamesJSON + + commonNamesJSON, err := json.Marshal(commonNames) + if err != nil { + return 0, err + } + op.CommonNames = commonNamesJSON + + _, err = this.Save(op) + if err != nil { + return 0, err + } + return types.Int64(op.Id), nil +} + +// 修改证书 +func (this *SSLCertDAO) UpdateCert(certId int64, isOn bool, name string, description string, serverName string, isCA bool, certData []byte, keyData []byte, timeBeginAt int64, timeEndAt int64, dnsNames []string, commonNames []string) error { + if certId <= 0 { + return errors.New("invalid certId") + } + op := NewSSLCertOperator() + op.Id = certId + op.IsOn = isOn + op.Name = name + op.Description = description + op.ServerName = serverName + op.IsCA = isCA + op.CertData = certData + op.KeyData = keyData + op.TimeBeginAt = timeBeginAt + op.TimeEndAt = timeEndAt + + dnsNamesJSON, err := json.Marshal(dnsNames) + if err != nil { + return err + } + op.DnsNames = dnsNamesJSON + + commonNamesJSON, err := json.Marshal(commonNames) + if err != nil { + return err + } + op.CommonNames = commonNamesJSON + + _, err = this.Save(op) + return err +} + +// 组合配置 +func (this *SSLCertDAO) ComposeCertConfig(certId int64) (*sslconfigs.SSLCertConfig, error) { + cert, err := this.FindEnabledSSLCert(certId) + if err != nil { + return nil, err + } + if cert == nil { + return nil, nil + } + + config := &sslconfigs.SSLCertConfig{} + config.Id = int64(cert.Id) + config.IsOn = cert.IsOn == 1 + config.IsCA = cert.IsCA == 1 + config.Name = cert.Name + config.Description = cert.Description + config.CertData = []byte(cert.CertData) + config.KeyData = []byte(cert.KeyData) + config.ServerName = cert.ServerName + config.TimeBeginAt = int64(cert.TimeBeginAt) + config.TimeEndAt = int64(cert.TimeEndAt) + + if IsNotNull(cert.DnsNames) { + dnsNames := []string{} + err := json.Unmarshal([]byte(cert.DnsNames), &dnsNames) + if err != nil { + return nil, err + } + config.DNSNames = dnsNames + } + + if IsNotNull(cert.CommonNames) { + commonNames := []string{} + err := json.Unmarshal([]byte(cert.CommonNames), &commonNames) + if err != nil { + return nil, err + } + config.CommonNames = commonNames + } + + return config, nil +} + +// 计算符合条件的证书数量 +func (this *SSLCertDAO) CountCerts(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string) (int64, error) { + query := this.Query(). + State(SSLCertStateEnabled) + if isCA { + query.Attr("isCA", true) + } + if isAvailable { + query.Where("timeBeginAt<=UNIX_TIMESTAMP() AND timeEndAt>=UNIX_TIMESTAMP()") + } + if isExpired { + query.Where("timeEndAt 0 { + query.Where("timeEndAt>UNIX_TIMESTAMP() AND timeEndAt<:expiredAt"). + Param("expiredAt", time.Now().Unix()+expiringDays*86400) + } + if len(keyword) > 0 { + query.Where("(name LIKE :keyword OR description LIKE :keyword OR dnsNames LIKE :keyword OR commonNames LIKE :keyword)"). + Param("keyword", "%"+keyword+"%") + } + return query.Count() +} + +// 列出符合条件的证书 +func (this *SSLCertDAO) ListCertIds(isCA bool, isAvailable bool, isExpired bool, expiringDays int64, keyword string, offset int64, size int64) (certIds []int64, err error) { + query := this.Query(). + State(SSLCertStateEnabled) + if isCA { + query.Attr("isCA", true) + } + if isAvailable { + query.Where("timeBeginAt<=UNIX_TIMESTAMP() AND timeEndAt>=UNIX_TIMESTAMP()") + } + if isExpired { + query.Where("timeEndAt 0 { + query.Where("timeEndAt>UNIX_TIMESTAMP() AND timeEndAt<:expiredAt"). + Param("expiredAt", time.Now().Unix()+expiringDays*86400) + } + if len(keyword) > 0 { + query.Where("(name LIKE :keyword OR description LIKE :keyword OR dnsNames LIKE :keyword OR commonNames LIKE :keyword)"). + Param("keyword", "%"+keyword+"%") + } + + ones, err := query. + ResultPk(). + DescPk(). + Offset(offset). + Limit(size). + FindAll() + if err != nil { + return nil, err + } + + result := []int64{} + for _, one := range ones { + result = append(result, int64(one.(*SSLCert).Id)) + } + return result, nil +} diff --git a/internal/db/models/ssl_cert_dao_test.go b/internal/db/models/ssl_cert_dao_test.go new file mode 100644 index 00000000..97c24b56 --- /dev/null +++ b/internal/db/models/ssl_cert_dao_test.go @@ -0,0 +1,5 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" +) diff --git a/internal/db/models/ssl_cert_group_dao.go b/internal/db/models/ssl_cert_group_dao.go new file mode 100644 index 00000000..61fe1397 --- /dev/null +++ b/internal/db/models/ssl_cert_group_dao.go @@ -0,0 +1,65 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" +) + +const ( + SSLCertGroupStateEnabled = 1 // 已启用 + SSLCertGroupStateDisabled = 0 // 已禁用 +) + +type SSLCertGroupDAO dbs.DAO + +func NewSSLCertGroupDAO() *SSLCertGroupDAO { + return dbs.NewDAO(&SSLCertGroupDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeSSLCertGroups", + Model: new(SSLCertGroup), + PkName: "id", + }, + }).(*SSLCertGroupDAO) +} + +var SharedSSLCertGroupDAO = NewSSLCertGroupDAO() + +// 启用条目 +func (this *SSLCertGroupDAO) EnableSSLCertGroup(id uint32) error { + _, err := this.Query(). + Pk(id). + Set("state", SSLCertGroupStateEnabled). + Update() + return err +} + +// 禁用条目 +func (this *SSLCertGroupDAO) DisableSSLCertGroup(id uint32) error { + _, err := this.Query(). + Pk(id). + Set("state", SSLCertGroupStateDisabled). + Update() + return err +} + +// 查找启用中的条目 +func (this *SSLCertGroupDAO) FindEnabledSSLCertGroup(id uint32) (*SSLCertGroup, error) { + result, err := this.Query(). + Pk(id). + Attr("state", SSLCertGroupStateEnabled). + Find() + if result == nil { + return nil, err + } + return result.(*SSLCertGroup), err +} + +// 根据主键查找名称 +func (this *SSLCertGroupDAO) FindSSLCertGroupName(id uint32) (string, error) { + return this.Query(). + Pk(id). + Result("name"). + FindStringCol("") +} diff --git a/internal/db/models/ssl_cert_group_dao_test.go b/internal/db/models/ssl_cert_group_dao_test.go new file mode 100644 index 00000000..97c24b56 --- /dev/null +++ b/internal/db/models/ssl_cert_group_dao_test.go @@ -0,0 +1,5 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" +) diff --git a/internal/db/models/ssl_cert_group_model.go b/internal/db/models/ssl_cert_group_model.go new file mode 100644 index 00000000..6b055b22 --- /dev/null +++ b/internal/db/models/ssl_cert_group_model.go @@ -0,0 +1,26 @@ +package models + +// +type SSLCertGroup struct { + Id uint32 `field:"id"` // ID + AdminId uint32 `field:"adminId"` // 管理员ID + UserId uint32 `field:"userId"` // 用户ID + Name string `field:"name"` // 分组名 + Order uint32 `field:"order"` // 分组排序 + State uint8 `field:"state"` // 状态 + CreatedAt uint64 `field:"createdAt"` // 创建时间 +} + +type SSLCertGroupOperator struct { + Id interface{} // ID + AdminId interface{} // 管理员ID + UserId interface{} // 用户ID + Name interface{} // 分组名 + Order interface{} // 分组排序 + State interface{} // 状态 + CreatedAt interface{} // 创建时间 +} + +func NewSSLCertGroupOperator() *SSLCertGroupOperator { + return &SSLCertGroupOperator{} +} diff --git a/internal/db/models/ssl_cert_group_model_ext.go b/internal/db/models/ssl_cert_group_model_ext.go new file mode 100644 index 00000000..2640e7f9 --- /dev/null +++ b/internal/db/models/ssl_cert_group_model_ext.go @@ -0,0 +1 @@ +package models diff --git a/internal/db/models/ssl_cert_model.go b/internal/db/models/ssl_cert_model.go new file mode 100644 index 00000000..fe2c4d19 --- /dev/null +++ b/internal/db/models/ssl_cert_model.go @@ -0,0 +1,48 @@ +package models + +// SSL证书 +type SSLCert struct { + Id uint32 `field:"id"` // ID + AdminId uint32 `field:"adminId"` // 管理员ID + UserId uint32 `field:"userId"` // 用户ID + State uint8 `field:"state"` // 状态 + CreatedAt uint64 `field:"createdAt"` // 创建时间 + UpdatedAt uint64 `field:"updatedAt"` // 修改时间 + IsOn uint8 `field:"isOn"` // 是否启用 + Name string `field:"name"` // 证书名 + Description string `field:"description"` // 描述 + CertData string `field:"certData"` // 证书内容 + KeyData string `field:"keyData"` // 密钥内容 + ServerName string `field:"serverName"` // 证书使用的主机名 + IsCA uint8 `field:"isCA"` // 是否为CA证书 + GroupIds string `field:"groupIds"` // 证书分组 + TimeBeginAt uint64 `field:"timeBeginAt"` // 开始时间 + TimeEndAt uint64 `field:"timeEndAt"` // 结束时间 + DnsNames string `field:"dnsNames"` // DNS名称列表 + CommonNames string `field:"commonNames"` // 发行单位列表 +} + +type SSLCertOperator struct { + Id interface{} // ID + AdminId interface{} // 管理员ID + UserId interface{} // 用户ID + State interface{} // 状态 + CreatedAt interface{} // 创建时间 + UpdatedAt interface{} // 修改时间 + IsOn interface{} // 是否启用 + Name interface{} // 证书名 + Description interface{} // 描述 + CertData interface{} // 证书内容 + KeyData interface{} // 密钥内容 + ServerName interface{} // 证书使用的主机名 + IsCA interface{} // 是否为CA证书 + GroupIds interface{} // 证书分组 + TimeBeginAt interface{} // 开始时间 + TimeEndAt interface{} // 结束时间 + DnsNames interface{} // DNS名称列表 + CommonNames interface{} // 发行单位列表 +} + +func NewSSLCertOperator() *SSLCertOperator { + return &SSLCertOperator{} +} diff --git a/internal/db/models/ssl_cert_model_ext.go b/internal/db/models/ssl_cert_model_ext.go new file mode 100644 index 00000000..2640e7f9 --- /dev/null +++ b/internal/db/models/ssl_cert_model_ext.go @@ -0,0 +1 @@ +package models diff --git a/internal/db/models/ssl_policy_dao.go b/internal/db/models/ssl_policy_dao.go new file mode 100644 index 00000000..138de08b --- /dev/null +++ b/internal/db/models/ssl_policy_dao.go @@ -0,0 +1,142 @@ +package models + +import ( + "encoding/json" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" + "strconv" +) + +const ( + SSLPolicyStateEnabled = 1 // 已启用 + SSLPolicyStateDisabled = 0 // 已禁用 +) + +type SSLPolicyDAO dbs.DAO + +func NewSSLPolicyDAO() *SSLPolicyDAO { + return dbs.NewDAO(&SSLPolicyDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeSSLPolicies", + Model: new(SSLPolicy), + PkName: "id", + }, + }).(*SSLPolicyDAO) +} + +var SharedSSLPolicyDAO = NewSSLPolicyDAO() + +// 启用条目 +func (this *SSLPolicyDAO) EnableSSLPolicy(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", SSLPolicyStateEnabled). + Update() + return err +} + +// 禁用条目 +func (this *SSLPolicyDAO) DisableSSLPolicy(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", SSLPolicyStateDisabled). + Update() + return err +} + +// 查找启用中的条目 +func (this *SSLPolicyDAO) FindEnabledSSLPolicy(id int64) (*SSLPolicy, error) { + result, err := this.Query(). + Pk(id). + Attr("state", SSLPolicyStateEnabled). + Find() + if result == nil { + return nil, err + } + return result.(*SSLPolicy), err +} + +// 组合配置 +func (this *SSLPolicyDAO) ComposePolicyConfig(policyId int64) (*sslconfigs.SSLPolicy, error) { + policy, err := this.FindEnabledSSLPolicy(policyId) + if err != nil { + return nil, err + } + if policy == nil { + return nil, nil + } + config := &sslconfigs.SSLPolicy{} + config.Id = int64(policy.Id) + config.IsOn = policy.IsOn == 1 + config.ClientAuthType = int(policy.ClientAuthType) + config.HTTP2Enabled = policy.Http2Enabled == 1 + config.MinVersion = policy.MinVersion + + // certs + if IsNotNull(policy.Certs) { + refs := []*sslconfigs.SSLCertRef{} + err = json.Unmarshal([]byte(policy.Certs), &refs) + if err != nil { + return nil, err + } + if len(refs) > 0 { + for _, ref := range refs { + certConfig, err := SharedSSLCertDAO.ComposeCertConfig(ref.CertId) + if err != nil { + return nil, err + } + if certConfig == nil { + continue + } + config.CertRefs = append(config.CertRefs, ref) + config.Certs = append(config.Certs, certConfig) + } + } + } + + // cipher suites + if IsNotNull(policy.CipherSuites) { + cipherSuites := []string{} + err = json.Unmarshal([]byte(policy.CipherSuites), &cipherSuites) + if err != nil { + return nil, err + } + config.CipherSuites = cipherSuites + } + + // hsts + if IsNotNull(policy.Hsts) { + hstsConfig := &sslconfigs.HSTSConfig{} + err = json.Unmarshal([]byte(policy.Hsts), hstsConfig) + if err != nil { + return nil, err + } + config.HSTS = hstsConfig + } + + return config, nil +} + +// 查询使用单个证书的所有策略ID +func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(certId int64) (policyIds []int64, err error) { + if certId <= 0 { + return + } + + ones, err := this.Query(). + State(SSLPolicyStateEnabled). + ResultPk(). + Where(`JSON_CONTAINS(certs, '{"certId": ` + strconv.FormatInt(certId, 10) + ` }')`). + Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 + FindAll() + if err != nil { + return nil, err + } + for _, one := range ones { + policyIds = append(policyIds, int64(one.(*SSLPolicy).Id)) + } + return policyIds, nil +} diff --git a/internal/db/models/ssl_policy_dao_test.go b/internal/db/models/ssl_policy_dao_test.go new file mode 100644 index 00000000..97c24b56 --- /dev/null +++ b/internal/db/models/ssl_policy_dao_test.go @@ -0,0 +1,5 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" +) diff --git a/internal/db/models/ssl_policy_model.go b/internal/db/models/ssl_policy_model.go new file mode 100644 index 00000000..c3837b85 --- /dev/null +++ b/internal/db/models/ssl_policy_model.go @@ -0,0 +1,36 @@ +package models + +// +type SSLPolicy struct { + Id uint32 `field:"id"` // ID + AdminId uint32 `field:"adminId"` // 管理员ID + UserId uint32 `field:"userId"` // 用户ID + IsOn uint8 `field:"isOn"` // 是否启用 + Certs string `field:"certs"` // 证书列表 + ClientAuthType uint32 `field:"clientAuthType"` // 客户端认证类型 + MinVersion string `field:"minVersion"` // 支持的SSL最小版本 + CipherSuites string `field:"cipherSuites"` // 加密算法套件 + Hsts string `field:"hsts"` // HSTS设置 + Http2Enabled uint8 `field:"http2Enabled"` // 是否启用HTTP/2 + State uint8 `field:"state"` // 状态 + CreatedAt uint64 `field:"createdAt"` // 创建时间 +} + +type SSLPolicyOperator struct { + Id interface{} // ID + AdminId interface{} // 管理员ID + UserId interface{} // 用户ID + IsOn interface{} // 是否启用 + Certs interface{} // 证书列表 + ClientAuthType interface{} // 客户端认证类型 + MinVersion interface{} // 支持的SSL最小版本 + CipherSuites interface{} // 加密算法套件 + Hsts interface{} // HSTS设置 + Http2Enabled interface{} // 是否启用HTTP/2 + State interface{} // 状态 + CreatedAt interface{} // 创建时间 +} + +func NewSSLPolicyOperator() *SSLPolicyOperator { + return &SSLPolicyOperator{} +} diff --git a/internal/db/models/ssl_policy_model_ext.go b/internal/db/models/ssl_policy_model_ext.go new file mode 100644 index 00000000..2640e7f9 --- /dev/null +++ b/internal/db/models/ssl_policy_model_ext.go @@ -0,0 +1 @@ +package models diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index f88f8a1f..7faca1bf 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -15,6 +15,7 @@ type ServerService struct { // 创建服务 func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServerRequest) (*pb.CreateServerResponse, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -35,6 +36,7 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe // 修改服务 func (this *ServerService) UpdateServerBasic(ctx context.Context, req *pb.UpdateServerBasicRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -77,6 +79,7 @@ func (this *ServerService) UpdateServerBasic(ctx context.Context, req *pb.Update // 修改HTTP服务 func (this *ServerService) UpdateServerHTTP(ctx context.Context, req *pb.UpdateServerHTTPRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -106,6 +109,7 @@ func (this *ServerService) UpdateServerHTTP(ctx context.Context, req *pb.UpdateS // 修改HTTPS服务 func (this *ServerService) UpdateServerHTTPS(ctx context.Context, req *pb.UpdateServerHTTPSRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -135,6 +139,7 @@ func (this *ServerService) UpdateServerHTTPS(ctx context.Context, req *pb.Update // 修改TCP服务 func (this *ServerService) UpdateServerTCP(ctx context.Context, req *pb.UpdateServerTCPRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -164,6 +169,7 @@ func (this *ServerService) UpdateServerTCP(ctx context.Context, req *pb.UpdateSe // 修改TLS服务 func (this *ServerService) UpdateServerTLS(ctx context.Context, req *pb.UpdateServerTLSRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -193,6 +199,7 @@ func (this *ServerService) UpdateServerTLS(ctx context.Context, req *pb.UpdateSe // 修改Unix服务 func (this *ServerService) UpdateServerUnix(ctx context.Context, req *pb.UpdateServerUnixRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -222,6 +229,7 @@ func (this *ServerService) UpdateServerUnix(ctx context.Context, req *pb.UpdateS // 修改UDP服务 func (this *ServerService) UpdateServerUDP(ctx context.Context, req *pb.UpdateServerUDPRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -251,6 +259,7 @@ func (this *ServerService) UpdateServerUDP(ctx context.Context, req *pb.UpdateSe // 修改Web服务 func (this *ServerService) UpdateServerWeb(ctx context.Context, req *pb.UpdateServerWebRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -280,6 +289,7 @@ func (this *ServerService) UpdateServerWeb(ctx context.Context, req *pb.UpdateSe // 修改反向代理服务 func (this *ServerService) UpdateServerReverseProxy(ctx context.Context, req *pb.UpdateServerReverseProxyRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -309,6 +319,7 @@ func (this *ServerService) UpdateServerReverseProxy(ctx context.Context, req *pb // 修改域名服务 func (this *ServerService) UpdateServerNames(ctx context.Context, req *pb.UpdateServerNamesRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -338,8 +349,8 @@ func (this *ServerService) UpdateServerNames(ctx context.Context, req *pb.Update // 计算服务数量 func (this *ServerService) CountAllEnabledServers(ctx context.Context, req *pb.CountAllEnabledServersRequest) (*pb.CountAllEnabledServersResponse, error) { - _ = req - _, _, err := rpcutils.ValidateRequest(ctx) + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err } @@ -353,6 +364,7 @@ func (this *ServerService) CountAllEnabledServers(ctx context.Context, req *pb.C // 列出单页服务 func (this *ServerService) ListEnabledServers(ctx context.Context, req *pb.ListEnabledServersRequest) (*pb.ListEnabledServersResponse, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -394,6 +406,7 @@ func (this *ServerService) ListEnabledServers(ctx context.Context, req *pb.ListE // 禁用某服务 func (this *ServerService) DisableServer(ctx context.Context, req *pb.DisableServerRequest) (*pb.DisableServerResponse, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -425,6 +438,7 @@ func (this *ServerService) DisableServer(ctx context.Context, req *pb.DisableSer // 查找单个服务 func (this *ServerService) FindEnabledServer(ctx context.Context, req *pb.FindEnabledServerRequest) (*pb.FindEnabledServerResponse, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -473,6 +487,7 @@ func (this *ServerService) FindEnabledServer(ctx context.Context, req *pb.FindEn // func (this *ServerService) FindEnabledServerType(ctx context.Context, req *pb.FindEnabledServerTypeRequest) (*pb.FindEnabledServerTypeResponse, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -488,6 +503,7 @@ func (this *ServerService) FindEnabledServerType(ctx context.Context, req *pb.Fi // 查找反向代理设置 func (this *ServerService) FindAndInitServerReverseProxyConfig(ctx context.Context, req *pb.FindAndInitServerReverseProxyConfigRequest) (*pb.FindAndInitServerReverseProxyConfigResponse, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -538,6 +554,7 @@ func (this *ServerService) FindAndInitServerReverseProxyConfig(ctx context.Conte // 初始化Web设置 func (this *ServerService) FindAndInitServerWebConfig(ctx context.Context, req *pb.FindAndInitServerWebConfigRequest) (*pb.FindAndInitServerWebConfigResponse, error) { + // 校验请求 _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) if err != nil { return nil, err @@ -566,3 +583,60 @@ func (this *ServerService) FindAndInitServerWebConfig(ctx context.Context, req * return &pb.FindAndInitServerWebConfigResponse{WebJSON: configJSON}, nil } + +// 计算使用某个SSL证书的服务数量 +func (this *ServerService) CountServersWithSSLCertId(ctx context.Context, req *pb.CountServersWithSSLCertIdRequest) (*pb.CountServersWithSSLCertIdResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.CertId) + if err != nil { + return nil, err + } + + if len(policyIds) == 0 { + return &pb.CountServersWithSSLCertIdResponse{Count: 0}, nil + } + + count, err := models.SharedServerDAO.CountServersWithSSLPolicyIds(policyIds) + if err != nil { + return nil, err + } + + return &pb.CountServersWithSSLCertIdResponse{Count: count}, nil +} + +// 查找使用某个SSL证书的所有服务 +func (this *ServerService) FindAllServersWithSSLCertId(ctx context.Context, req *pb.FindAllServersWithSSLCertIdRequest) (*pb.FindAllServersWithSSLCertIdResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(req.CertId) + if err != nil { + return nil, err + } + if len(policyIds) == 0 { + return &pb.FindAllServersWithSSLCertIdResponse{Servers: nil}, nil + } + + servers, err := models.SharedServerDAO.FindAllServersWithSSLPolicyIds(policyIds) + if err != nil { + return nil, err + } + result := []*pb.Server{} + for _, server := range servers { + result = append(result, &pb.Server{ + Id: int64(server.Id), + Name: server.Name, + IsOn: server.IsOn == 1, + Type: server.Type, + }) + } + return &pb.FindAllServersWithSSLCertIdResponse{Servers: result}, nil +} diff --git a/internal/rpc/services/service_ssl_cert.go b/internal/rpc/services/service_ssl_cert.go new file mode 100644 index 00000000..0f9c2209 --- /dev/null +++ b/internal/rpc/services/service_ssl_cert.go @@ -0,0 +1,133 @@ +package services + +import ( + "context" + "encoding/json" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/sslconfigs" +) + +// SSL证书相关服务 +type SSLCertService struct { +} + +// 创建Cert +func (this *SSLCertService) CreateSSLCert(ctx context.Context, req *pb.CreateSSLCertRequest) (*pb.CreateSSLCertResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + certId, err := models.SharedSSLCertDAO.CreateCert(req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames) + if err != nil { + return nil, err + } + + return &pb.CreateSSLCertResponse{CertId: certId}, nil +} + +// 修改Cert +func (this *SSLCertService) UpdateSSLCert(ctx context.Context, req *pb.UpdateSSLCertRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedSSLCertDAO.UpdateCert(req.CertId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames) + if err != nil { + return nil, err + } + + return rpcutils.RPCUpdateSuccess() +} + +// 查找证书配置 +func (this *SSLCertService) FindEnabledSSLCertConfig(ctx context.Context, req *pb.FindEnabledSSLCertConfigRequest) (*pb.FindEnabledSSLCertConfigResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + config, err := models.SharedSSLCertDAO.ComposeCertConfig(req.CertId) + if err != nil { + return nil, err + } + + configJSON, err := json.Marshal(config) + if err != nil { + return nil, err + } + return &pb.FindEnabledSSLCertConfigResponse{CertJSON: configJSON}, nil +} + +// 删除证书 +func (this *SSLCertService) DeleteSSLCert(ctx context.Context, req *pb.DeleteSSLCertRequest) (*pb.RPCDeleteSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedSSLCertDAO.DisableSSLCert(req.CertId) + if err != nil { + return nil, err + } + + return rpcutils.RPCDeleteSuccess() +} + +// 计算匹配的Cert数量 +func (this *SSLCertService) CountSSLCerts(ctx context.Context, req *pb.CountSSLCertRequest) (*pb.CountSSLCertResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + count, err := models.SharedSSLCertDAO.CountCerts(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword) + if err != nil { + return nil, err + } + + return &pb.CountSSLCertResponse{ + Count: count, + }, nil +} + +// 列出单页匹配的Cert +func (this *SSLCertService) ListSSLCerts(ctx context.Context, req *pb.ListSSLCertsRequest) (*pb.ListSSLCertsResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + certIds, err := models.SharedSSLCertDAO.ListCertIds(req.IsCA, req.IsAvailable, req.IsExpired, int64(req.ExpiringDays), req.Keyword, req.Offset, req.Size) + if err != nil { + return nil, err + } + + certConfigs := []*sslconfigs.SSLCertConfig{} + for _, certId := range certIds { + certConfig, err := models.SharedSSLCertDAO.ComposeCertConfig(certId) + if err != nil { + return nil, err + } + + // 这里不需要数据内容 + certConfig.CertData = nil + certConfig.KeyData = nil + + certConfigs = append(certConfigs, certConfig) + } + certConfigsJSON, err := json.Marshal(certConfigs) + if err != nil { + return nil, err + } + return &pb.ListSSLCertsResponse{CertsJSON: certConfigsJSON}, nil +}