diff --git a/internal/db/models/acme_task_dao.go b/internal/db/models/acme_task_dao.go new file mode 100644 index 00000000..9d848799 --- /dev/null +++ b/internal/db/models/acme_task_dao.go @@ -0,0 +1,175 @@ +package models + +import ( + "encoding/json" + "github.com/TeaOSLab/EdgeAPI/internal/errors" + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/types" +) + +const ( + ACMETaskStateEnabled = 1 // 已启用 + ACMETaskStateDisabled = 0 // 已禁用 +) + +type ACMETaskDAO dbs.DAO + +func NewACMETaskDAO() *ACMETaskDAO { + return dbs.NewDAO(&ACMETaskDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeACMETasks", + Model: new(ACMETask), + PkName: "id", + }, + }).(*ACMETaskDAO) +} + +var SharedACMETaskDAO *ACMETaskDAO + +func init() { + dbs.OnReady(func() { + SharedACMETaskDAO = NewACMETaskDAO() + }) +} + +// 启用条目 +func (this *ACMETaskDAO) EnableACMETask(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", ACMETaskStateEnabled). + Update() + return err +} + +// 禁用条目 +func (this *ACMETaskDAO) DisableACMETask(id int64) error { + _, err := this.Query(). + Pk(id). + Set("state", ACMETaskStateDisabled). + Update() + return err +} + +// 查找启用中的条目 +func (this *ACMETaskDAO) FindEnabledACMETask(id int64) (*ACMETask, error) { + result, err := this.Query(). + Pk(id). + Attr("state", ACMETaskStateEnabled). + Find() + if result == nil { + return nil, err + } + return result.(*ACMETask), err +} + +// 计算某个ACME用户相关的任务数量 +func (this *ACMETaskDAO) CountACMETasksWithACMEUserId(acmeUserId int64) (int64, error) { + return this.Query(). + State(ACMETaskStateEnabled). + Attr("acmeUserId", acmeUserId). + Count() +} + +// 计算某个DNS服务商相关的任务数量 +func (this *ACMETaskDAO) CountACMETasksWithDNSProviderId(dnsProviderId int64) (int64, error) { + return this.Query(). + State(ACMETaskStateEnabled). + Attr("dnsProviderId", dnsProviderId). + Count() +} + +// 停止某个证书相关任务 +func (this *ACMETaskDAO) DisableAllTasksWithCertId(certId int64) error { + _, err := this.Query(). + Attr("certId", certId). + Set("state", ACMETaskStateDisabled). + Update() + return err +} + +// 计算所有任务数量 +func (this *ACMETaskDAO) CountAllEnabledACMETasks(adminId int64, userId int64) (int64, error) { + return NewQuery(this, adminId, userId). + State(ACMETaskStateEnabled). + Count() +} + +// 列出单页任务 +func (this *ACMETaskDAO) ListEnabledACMETasks(adminId int64, userId int64, offset int64, size int64) (result []*ACMETask, err error) { + _, err = NewQuery(this, adminId, userId). + State(ACMETaskStateEnabled). + DescPk(). + Offset(offset). + Limit(size). + Slice(&result). + FindAll() + return +} + +// 创建任务 +func (this *ACMETaskDAO) CreateACMETask(adminId int64, userId int64, acmeUserId int64, dnsProviderId int64, dnsDomain string, domains []string, autoRenew bool) (int64, error) { + op := NewACMETaskOperator() + op.AdminId = adminId + op.UserId = userId + op.AcmeUserId = acmeUserId + op.DnsProviderId = dnsProviderId + op.DnsDomain = dnsDomain + + if len(domains) > 0 { + domainsJSON, err := json.Marshal(domains) + if err != nil { + return 0, err + } + op.Domains = domainsJSON + } else { + op.Domains = "[]" + } + + op.AutoRenew = autoRenew + op.IsOn = true + op.State = ACMETaskStateEnabled + op.IsOk = false + _, err := this.Save(op) + if err != nil { + return 0, err + } + return types.Int64(op.Id), nil +} + +// 修改任务 +func (this *ACMETaskDAO) UpdateACMETask(acmeTaskId int64, acmeUserId int64, dnsProviderId int64, dnsDomain string, domains []string, autoRenew bool) error { + if acmeTaskId <= 0 { + return errors.New("invalid acmeTaskId") + } + + op := NewACMETaskOperator() + op.Id = acmeTaskId + op.AcmeUserId = acmeUserId + op.DnsProviderId = dnsProviderId + op.DnsDomain = dnsDomain + + if len(domains) > 0 { + domainsJSON, err := json.Marshal(domains) + if err != nil { + return err + } + op.Domains = domainsJSON + } else { + op.Domains = "[]" + } + + op.AutoRenew = autoRenew + _, err := this.Save(op) + return err +} + +// 检查权限 +func (this *ACMETaskDAO) CheckACMETask(adminId int64, userId int64, acmeTaskId int64) (bool, error) { + return NewQuery(this, adminId, userId). + State(ACMETaskStateEnabled). + Pk(acmeTaskId). + Exist() +} diff --git a/internal/db/models/acme_task_dao_test.go b/internal/db/models/acme_task_dao_test.go new file mode 100644 index 00000000..97c24b56 --- /dev/null +++ b/internal/db/models/acme_task_dao_test.go @@ -0,0 +1,5 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" +) diff --git a/internal/db/models/acme_task_log_dao.go b/internal/db/models/acme_task_log_dao.go new file mode 100644 index 00000000..fd149e3c --- /dev/null +++ b/internal/db/models/acme_task_log_dao.go @@ -0,0 +1,28 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" +) + +type ACMETaskLogDAO dbs.DAO + +func NewACMETaskLogDAO() *ACMETaskLogDAO { + return dbs.NewDAO(&ACMETaskLogDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeACMETaskLogs", + Model: new(ACMETaskLog), + PkName: "id", + }, + }).(*ACMETaskLogDAO) +} + +var SharedACMETaskLogDAO *ACMETaskLogDAO + +func init() { + dbs.OnReady(func() { + SharedACMETaskLogDAO = NewACMETaskLogDAO() + }) +} diff --git a/internal/db/models/acme_task_log_dao_test.go b/internal/db/models/acme_task_log_dao_test.go new file mode 100644 index 00000000..97c24b56 --- /dev/null +++ b/internal/db/models/acme_task_log_dao_test.go @@ -0,0 +1,5 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" +) diff --git a/internal/db/models/acme_task_log_model.go b/internal/db/models/acme_task_log_model.go new file mode 100644 index 00000000..070c0ba3 --- /dev/null +++ b/internal/db/models/acme_task_log_model.go @@ -0,0 +1,22 @@ +package models + +// ACME任务运行日志 +type ACMETaskLog struct { + Id uint64 `field:"id"` // ID + TaskId uint64 `field:"taskId"` // 任务ID + IsOk uint8 `field:"isOk"` // 是否成功 + Error string `field:"error"` // 错误信息 + CreatedAt uint64 `field:"createdAt"` // 运行时间 +} + +type ACMETaskLogOperator struct { + Id interface{} // ID + TaskId interface{} // 任务ID + IsOk interface{} // 是否成功 + Error interface{} // 错误信息 + CreatedAt interface{} // 运行时间 +} + +func NewACMETaskLogOperator() *ACMETaskLogOperator { + return &ACMETaskLogOperator{} +} diff --git a/internal/db/models/acme_task_log_model_ext.go b/internal/db/models/acme_task_log_model_ext.go new file mode 100644 index 00000000..2640e7f9 --- /dev/null +++ b/internal/db/models/acme_task_log_model_ext.go @@ -0,0 +1 @@ +package models diff --git a/internal/db/models/acme_task_model.go b/internal/db/models/acme_task_model.go new file mode 100644 index 00000000..8eab2b8a --- /dev/null +++ b/internal/db/models/acme_task_model.go @@ -0,0 +1,38 @@ +package models + +// ACME任务 +type ACMETask struct { + Id uint64 `field:"id"` // ID + AdminId uint32 `field:"adminId"` // 管理员ID + UserId uint32 `field:"userId"` // 用户ID + IsOn uint8 `field:"isOn"` // 是否启用 + AcmeUserId uint32 `field:"acmeUserId"` // ACME用户ID + DnsDomain string `field:"dnsDomain"` // DNS主域名 + DnsProviderId uint64 `field:"dnsProviderId"` // DNS服务商 + Domains string `field:"domains"` // 证书域名 + CreatedAt uint64 `field:"createdAt"` // 创建时间 + State uint8 `field:"state"` // 状态 + IsOk uint8 `field:"isOk"` // 最后运行是否正常 + CertId uint64 `field:"certId"` // 生成的证书ID + AutoRenew uint8 `field:"autoRenew"` // 是否自动更新 +} + +type ACMETaskOperator struct { + Id interface{} // ID + AdminId interface{} // 管理员ID + UserId interface{} // 用户ID + IsOn interface{} // 是否启用 + AcmeUserId interface{} // ACME用户ID + DnsDomain interface{} // DNS主域名 + DnsProviderId interface{} // DNS服务商 + Domains interface{} // 证书域名 + CreatedAt interface{} // 创建时间 + State interface{} // 状态 + IsOk interface{} // 最后运行是否正常 + CertId interface{} // 生成的证书ID + AutoRenew interface{} // 是否自动更新 +} + +func NewACMETaskOperator() *ACMETaskOperator { + return &ACMETaskOperator{} +} diff --git a/internal/db/models/acme_task_model_ext.go b/internal/db/models/acme_task_model_ext.go new file mode 100644 index 00000000..42c8c0e1 --- /dev/null +++ b/internal/db/models/acme_task_model_ext.go @@ -0,0 +1,20 @@ +package models + +import ( + "encoding/json" + "github.com/iwind/TeaGo/logs" +) + +// 将域名解析成字符串数组 +func (this *ACMETask) DecodeDomains() []string { + if len(this.Domains) == 0 || this.Domains == "null" { + return nil + } + result := []string{} + err := json.Unmarshal([]byte(this.Domains), &result) + if err != nil { + logs.Error(err) + return nil + } + return result +} diff --git a/internal/db/models/acme_user_dao.go b/internal/db/models/acme_user_dao.go index 57927c7a..c74f1907 100644 --- a/internal/db/models/acme_user_dao.go +++ b/internal/db/models/acme_user_dao.go @@ -144,6 +144,28 @@ func (this *ACMEUserDAO) ListACMEUsers(adminId int64, userId int64, offset int64 return } +// 查找所有用户 +func (this *ACMEUserDAO) FindAllACMEUsers(adminId int64, userId int64) (result []*ACMEUser, err error) { + // 防止没有传入条件导致返回的数据过多 + if adminId <= 0 && userId <= 0 { + return nil, errors.New("'adminId' or 'userId' should not be empty") + } + + query := this.Query() + if adminId > 0 { + query.Attr("adminId", adminId) + } + if userId > 0 { + query.Attr("userId", userId) + } + _, err = query. + State(ACMEUserStateEnabled). + Slice(&result). + DescPk(). + FindAll() + return +} + // 检查用户权限 func (this *ACMEUserDAO) CheckACMEUser(acmeUserId int64, adminId int64, userId int64) (bool, error) { if acmeUserId <= 0 { diff --git a/internal/db/models/dns_provider_dao.go b/internal/db/models/dns_provider_dao.go index a8187199..ae04fc8b 100644 --- a/internal/db/models/dns_provider_dao.go +++ b/internal/db/models/dns_provider_dao.go @@ -66,8 +66,10 @@ func (this *DNSProviderDAO) FindEnabledDNSProvider(id int64) (*DNSProvider, erro } // 创建服务商 -func (this *DNSProviderDAO) CreateDNSProvider(providerType string, name string, apiParamsJSON []byte) (int64, error) { +func (this *DNSProviderDAO) CreateDNSProvider(adminId int64, userId int64, providerType string, name string, apiParamsJSON []byte) (int64, error) { op := NewDNSProviderOperator() + op.AdminId = adminId + op.UserId = userId op.Type = providerType op.Name = name if len(apiParamsJSON) > 0 { @@ -104,15 +106,15 @@ func (this *DNSProviderDAO) UpdateDNSProvider(dnsProviderId int64, name string, } // 计算服务商数量 -func (this *DNSProviderDAO) CountAllEnabledDNSProviders() (int64, error) { - return this.Query(). +func (this *DNSProviderDAO) CountAllEnabledDNSProviders(adminId int64, userId int64) (int64, error) { + return NewQuery(this, adminId, userId). State(DNSProviderStateEnabled). Count() } // 列出单页服务商 -func (this *DNSProviderDAO) ListEnabledDNSProviders(offset int64, size int64) (result []*DNSProvider, err error) { - _, err = this.Query(). +func (this *DNSProviderDAO) ListEnabledDNSProviders(adminId int64, userId int64, offset int64, size int64) (result []*DNSProvider, err error) { + _, err = NewQuery(this, adminId, userId). State(DNSProviderStateEnabled). Offset(offset). Limit(size). @@ -122,6 +124,16 @@ func (this *DNSProviderDAO) ListEnabledDNSProviders(offset int64, size int64) (r return } +// 列出所有服务商 +func (this *DNSProviderDAO) FindAllEnabledDNSProviders(adminId int64, userId int64) (result []*DNSProvider, err error) { + _, err = NewQuery(this, adminId, userId). + State(DNSProviderStateEnabled). + DescPk(). + Slice(&result). + FindAll() + return +} + // 查询某个类型下的所有服务商 func (this *DNSProviderDAO) FindAllEnabledDNSProvidersWithType(providerType string) (result []*DNSProvider, err error) { _, err = this.Query(). diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index 5d26b34b..cb3a10ac 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -269,49 +269,3 @@ func (this *SSLCertDAO) ListCertIds(isCA bool, isAvailable bool, isExpired bool, } return result, nil } - -// 计算所有某个管理员/用户下所有的ACME用户生成的证书数量 -func (this *SSLCertDAO) CountAllSSLCertsWithACME(adminId int64, userId int64) (int64, error) { - query := this.Query() - if adminId > 0 { - query.Attr("adminId", adminId) - } - if userId > 0 { - query.Attr("userId", userId) - } - return query. - State(SSLCertStateEnabled). - Where("acmeUserId>0"). - Count() -} - -// 列出某个管理员/用户下所有的ACME用户生成的证书Ids -func (this *SSLCertDAO) ListSSLCertIdsWithACME(adminId int64, userId int64, offset int64, size int64) (certIds []int64, err error) { - query := this.Query() - if adminId > 0 { - query.Attr("adminId", adminId) - } - if userId > 0 { - query.Attr("userId", userId) - } - ones, err := query. - ResultPk(). - State(SSLCertStateEnabled). - Where("acmeUserId>0"). - Offset(offset). - Limit(size). - DescPk(). - FindAll() - for _, one := range ones { - certIds = append(certIds, int64(one.(*SSLCert).Id)) - } - return -} - -// 计算某个ACME用户生成的证书数量 -func (this *SSLCertDAO) CountSSLCertsWithACMEUserId(acmeUserId int64) (int64, error) { - return this.Query(). - State(SSLCertStateEnabled). - Attr("acmeUserId", acmeUserId). - Count() -} diff --git a/internal/db/models/ssl_cert_model.go b/internal/db/models/ssl_cert_model.go index a6c7c51d..d073035e 100644 --- a/internal/db/models/ssl_cert_model.go +++ b/internal/db/models/ssl_cert_model.go @@ -2,51 +2,49 @@ package models // SSL证书 type SSLCert struct { - Id uint32 `field:"id"` // ID - AdminId uint32 `field:"adminId"` // 管理员ID - UserId uint32 `field:"userId"` // 用户ID - State uint8 `field:"state"` // 状态 - CreatedAt uint64 `field:"createdAt"` // 创建时间 - UpdatedAt uint64 `field:"updatedAt"` // 修改时间 - IsOn uint8 `field:"isOn"` // 是否启用 - Name string `field:"name"` // 证书名 - Description string `field:"description"` // 描述 - CertData string `field:"certData"` // 证书内容 - KeyData string `field:"keyData"` // 密钥内容 - ServerName string `field:"serverName"` // 证书使用的主机名 - IsCA uint8 `field:"isCA"` // 是否为CA证书 - GroupIds string `field:"groupIds"` // 证书分组 - TimeBeginAt uint64 `field:"timeBeginAt"` // 开始时间 - TimeEndAt uint64 `field:"timeEndAt"` // 结束时间 - DnsNames string `field:"dnsNames"` // DNS名称列表 - CommonNames string `field:"commonNames"` // 发行单位列表 - IsACME uint8 `field:"isACME"` // 是否为ACME自动生成的 - AcmeUserId uint64 `field:"acmeUserId"` // ACME用户ID - AcmeAutoRenew uint8 `field:"acmeAutoRenew"` // ACME生成后是否自动更新 + Id uint32 `field:"id"` // ID + AdminId uint32 `field:"adminId"` // 管理员ID + UserId uint32 `field:"userId"` // 用户ID + State uint8 `field:"state"` // 状态 + CreatedAt uint64 `field:"createdAt"` // 创建时间 + UpdatedAt uint64 `field:"updatedAt"` // 修改时间 + IsOn uint8 `field:"isOn"` // 是否启用 + Name string `field:"name"` // 证书名 + Description string `field:"description"` // 描述 + CertData string `field:"certData"` // 证书内容 + KeyData string `field:"keyData"` // 密钥内容 + ServerName string `field:"serverName"` // 证书使用的主机名 + IsCA uint8 `field:"isCA"` // 是否为CA证书 + GroupIds string `field:"groupIds"` // 证书分组 + TimeBeginAt uint64 `field:"timeBeginAt"` // 开始时间 + TimeEndAt uint64 `field:"timeEndAt"` // 结束时间 + DnsNames string `field:"dnsNames"` // DNS名称列表 + CommonNames string `field:"commonNames"` // 发行单位列表 + IsACME uint8 `field:"isACME"` // 是否为ACME自动生成的 + AcmeTaskId uint64 `field:"acmeTaskId"` // ACME任务ID } type SSLCertOperator struct { - Id interface{} // ID - AdminId interface{} // 管理员ID - UserId interface{} // 用户ID - State interface{} // 状态 - CreatedAt interface{} // 创建时间 - UpdatedAt interface{} // 修改时间 - IsOn interface{} // 是否启用 - Name interface{} // 证书名 - Description interface{} // 描述 - CertData interface{} // 证书内容 - KeyData interface{} // 密钥内容 - ServerName interface{} // 证书使用的主机名 - IsCA interface{} // 是否为CA证书 - GroupIds interface{} // 证书分组 - TimeBeginAt interface{} // 开始时间 - TimeEndAt interface{} // 结束时间 - DnsNames interface{} // DNS名称列表 - CommonNames interface{} // 发行单位列表 - IsACME interface{} // 是否为ACME自动生成的 - AcmeUserId interface{} // ACME用户ID - AcmeAutoRenew interface{} // ACME生成后是否自动更新 + Id interface{} // ID + AdminId interface{} // 管理员ID + UserId interface{} // 用户ID + State interface{} // 状态 + CreatedAt interface{} // 创建时间 + UpdatedAt interface{} // 修改时间 + IsOn interface{} // 是否启用 + Name interface{} // 证书名 + Description interface{} // 描述 + CertData interface{} // 证书内容 + KeyData interface{} // 密钥内容 + ServerName interface{} // 证书使用的主机名 + IsCA interface{} // 是否为CA证书 + GroupIds interface{} // 证书分组 + TimeBeginAt interface{} // 开始时间 + TimeEndAt interface{} // 结束时间 + DnsNames interface{} // DNS名称列表 + CommonNames interface{} // 发行单位列表 + IsACME interface{} // 是否为ACME自动生成的 + AcmeTaskId interface{} // ACME任务ID } func NewSSLCertOperator() *SSLCertOperator { diff --git a/internal/db/models/utils.go b/internal/db/models/utils.go index f95d8c24..c13bf506 100644 --- a/internal/db/models/utils.go +++ b/internal/db/models/utils.go @@ -1,5 +1,7 @@ package models +import "github.com/iwind/TeaGo/dbs" + // 处理JSON字节Slice func JSONBytes(data []byte) []byte { if len(data) == 0 { @@ -12,3 +14,15 @@ func JSONBytes(data []byte) []byte { func IsNotNull(data string) bool { return len(data) > 0 && data != "null" } + +// 构造Query +func NewQuery(dao dbs.DAOWrapper, adminId int64, userId int64) *dbs.Query { + query := dao.Object().Query() + if adminId > 0 { + query.Attr("adminId", adminId) + } + if userId > 0 { + query.Attr("userId", userId) + } + return query +} diff --git a/internal/nodes/api_node.go b/internal/nodes/api_node.go index cc131adc..c5b15818 100644 --- a/internal/nodes/api_node.go +++ b/internal/nodes/api_node.go @@ -198,6 +198,7 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err pb.RegisterDNSDomainServiceServer(rpcServer, &services.DNSDomainService{}) pb.RegisterDNSServiceServer(rpcServer, &services.DNSService{}) pb.RegisterACMEUserServiceServer(rpcServer, &services.ACMEUserService{}) + pb.RegisterACMETaskServiceServer(rpcServer, &services.ACMETaskService{}) err := rpcServer.Serve(listener) if err != nil { return errors.New("[API_NODE]start rpc failed: " + err.Error()) diff --git a/internal/rpc/services/service_acme_task.go b/internal/rpc/services/service_acme_task.go new file mode 100644 index 00000000..fe49e167 --- /dev/null +++ b/internal/rpc/services/service_acme_task.go @@ -0,0 +1,281 @@ +package services + +import ( + "context" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/dnsclients" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" +) + +// ACME任务相关服务 +type ACMETaskService struct { + BaseService +} + +// 计算某个ACME用户相关的任务数量 +func (this *ACMETaskService) CountAllEnabledACMETasksWithACMEUserId(ctx context.Context, req *pb.CountAllEnabledACMETasksWithACMEUserIdRequest) (*pb.RPCCountResponse, error) { + _, _, err := this.ValidateAdminAndUser(ctx, 0) + if err != nil { + return nil, err + } + + // TODO 校验权限 + + count, err := models.SharedACMETaskDAO.CountACMETasksWithACMEUserId(req.AcmeUserId) + if err != nil { + return nil, err + } + return this.SuccessCount(count) +} + +// 计算跟某个DNS服务商相关的任务数量 +func (this *ACMETaskService) CountEnabledACMETasksWithDNSProviderId(ctx context.Context, req *pb.CountEnabledACMETasksWithDNSProviderIdRequest) (*pb.RPCCountResponse, error) { + _, _, err := this.ValidateAdminAndUser(ctx, 0) + if err != nil { + return nil, err + } + + // TODO 校验权限 + + count, err := models.SharedACMETaskDAO.CountACMETasksWithDNSProviderId(req.DnsProviderId) + if err != nil { + return nil, err + } + return this.SuccessCount(count) +} + +// 计算所有任务数量 +func (this *ACMETaskService) CountAllEnabledACMETasks(ctx context.Context, req *pb.CountAllEnabledACMETasksRequest) (*pb.RPCCountResponse, error) { + _, _, err := this.ValidateAdminAndUser(ctx, req.UserId) + if err != nil { + return nil, err + } + + count, err := models.SharedACMETaskDAO.CountAllEnabledACMETasks(req.AdminId, req.UserId) + if err != nil { + return nil, err + } + return this.SuccessCount(count) +} + +// 列出单页任务 +func (this *ACMETaskService) ListEnabledACMETasks(ctx context.Context, req *pb.ListEnabledACMETasksRequest) (*pb.ListEnabledACMETasksResponse, error) { + _, _, err := this.ValidateAdminAndUser(ctx, req.UserId) + if err != nil { + return nil, err + } + + tasks, err := models.SharedACMETaskDAO.ListEnabledACMETasks(req.AdminId, req.UserId, req.Offset, req.Size) + if err != nil { + return nil, err + } + + result := []*pb.ACMETask{} + for _, task := range tasks { + // ACME用户 + acmeUser, err := models.SharedACMEUserDAO.FindEnabledACMEUser(int64(task.AcmeUserId)) + if err != nil { + return nil, err + } + if acmeUser == nil { + continue + } + pbACMEUser := &pb.ACMEUser{ + Id: int64(acmeUser.Id), + Email: acmeUser.Email, + Description: acmeUser.Description, + CreatedAt: int64(acmeUser.CreatedAt), + } + + // DNS + provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(int64(task.DnsProviderId)) + if err != nil { + return nil, err + } + if provider == nil { + continue + } + pbProvider := &pb.DNSProvider{ + Id: int64(provider.Id), + Name: provider.Name, + Type: provider.Type, + TypeName: dnsclients.FindProviderTypeName(provider.Type), + } + + // 证书 + var pbCert *pb.SSLCert = nil + if task.CertId > 0 { + cert, err := models.SharedSSLCertDAO.FindEnabledSSLCert(int64(task.CertId)) + if err != nil { + return nil, err + } + if cert == nil { + continue + } + pbCert = &pb.SSLCert{ + Id: int64(cert.Id), + IsOn: cert.IsOn == 1, + Name: cert.Name, + } + } + + result = append(result, &pb.ACMETask{ + Id: int64(task.Id), + IsOn: task.IsOn == 1, + DnsDomain: task.DnsDomain, + Domains: task.DecodeDomains(), + CreatedAt: int64(task.CreatedAt), + IsOk: task.IsOk == 1, + AutoRenew: task.AutoRenew == 1, + AcmeUser: pbACMEUser, + DnsProvider: pbProvider, + SslCert: pbCert, + }) + } + + return &pb.ListEnabledACMETasksResponse{AcmeTasks: result}, nil +} + +// 创建任务 +func (this *ACMETaskService) CreateACMETask(ctx context.Context, req *pb.CreateACMETaskRequest) (*pb.CreateACMETaskResponse, error) { + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + if err != nil { + return nil, err + } + taskId, err := models.SharedACMETaskDAO.CreateACMETask(adminId, userId, req.AcmeUserId, req.DnsProviderId, req.DnsDomain, req.Domains, req.AutoRenew) + if err != nil { + return nil, err + } + return &pb.CreateACMETaskResponse{AcmeTaskId: taskId}, nil +} + +// 修改任务 +func (this *ACMETaskService) UpdateACMETask(ctx context.Context, req *pb.UpdateACMETaskRequest) (*pb.RPCSuccess, error) { + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + if err != nil { + return nil, err + } + + canAccess, err := models.SharedACMETaskDAO.CheckACMETask(adminId, userId, req.AcmeTaskId) + if err != nil { + return nil, err + } + if !canAccess { + return nil, this.PermissionError() + } + + err = models.SharedACMETaskDAO.UpdateACMETask(req.AcmeTaskId, req.AcmeUserId, req.DnsProviderId, req.DnsDomain, req.Domains, req.AutoRenew) + if err != nil { + return nil, err + } + return this.Success() +} + +// 删除任务 +func (this *ACMETaskService) DeleteACMETask(ctx context.Context, req *pb.DeleteACMETaskRequest) (*pb.RPCSuccess, error) { + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + if err != nil { + return nil, err + } + + canAccess, err := models.SharedACMETaskDAO.CheckACMETask(adminId, userId, req.AcmeTaskId) + if err != nil { + return nil, err + } + if !canAccess { + return nil, this.PermissionError() + } + + err = models.SharedACMETaskDAO.DisableACMETask(req.AcmeTaskId) + if err != nil { + return nil, err + } + return this.Success() +} + +// 运行某个任务 +func (this *ACMETaskService) RunACMETask(ctx context.Context, req *pb.RunACMETaskRequest) (*pb.RunACMETaskResponse, error) { + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + if err != nil { + return nil, err + } + + canAccess, err := models.SharedACMETaskDAO.CheckACMETask(adminId, userId, req.AcmeTaskId) + if err != nil { + return nil, err + } + if !canAccess { + return nil, this.PermissionError() + } + + // TODO + + return &pb.RunACMETaskResponse{}, nil +} + +// 查找单个任务信息 +func (this *ACMETaskService) FindEnabledACMETask(ctx context.Context, req *pb.FindEnabledACMETaskRequest) (*pb.FindEnabledACMETaskResponse, error) { + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) + if err != nil { + return nil, err + } + + canAccess, err := models.SharedACMETaskDAO.CheckACMETask(adminId, userId, req.AcmeTaskId) + if err != nil { + return nil, err + } + if !canAccess { + return nil, this.PermissionError() + } + + task, err := models.SharedACMETaskDAO.FindEnabledACMETask(req.AcmeTaskId) + if err != nil { + return nil, err + } + if task == nil { + return &pb.FindEnabledACMETaskResponse{AcmeTask: nil}, nil + } + + // 用户 + var pbACMEUser *pb.ACMEUser = nil + if task.AcmeUserId > 0 { + acmeUser, err := models.SharedACMEUserDAO.FindEnabledACMEUser(int64(task.AcmeUserId)) + if err != nil { + return nil, err + } + if acmeUser != nil { + pbACMEUser = &pb.ACMEUser{ + Id: int64(acmeUser.Id), + Email: acmeUser.Email, + Description: acmeUser.Description, + CreatedAt: int64(acmeUser.CreatedAt), + } + } + } + + // DNS + var pbProvider *pb.DNSProvider + provider, err := models.SharedDNSProviderDAO.FindEnabledDNSProvider(int64(task.DnsProviderId)) + if err != nil { + return nil, err + } + if provider != nil { + pbProvider = &pb.DNSProvider{ + Id: int64(provider.Id), + Name: provider.Name, + Type: provider.Type, + TypeName: dnsclients.FindProviderTypeName(provider.Type), + } + } + + return &pb.FindEnabledACMETaskResponse{AcmeTask: &pb.ACMETask{ + Id: int64(task.Id), + IsOn: task.IsOn == 1, + DnsDomain: task.DnsDomain, + Domains: task.DecodeDomains(), + CreatedAt: int64(task.CreatedAt), + AutoRenew: task.AutoRenew == 1, + DnsProvider: pbProvider, + AcmeUser: pbACMEUser, + }}, nil +} diff --git a/internal/rpc/services/service_acme_user.go b/internal/rpc/services/service_acme_user.go index 7215e4f9..f49d9ab2 100644 --- a/internal/rpc/services/service_acme_user.go +++ b/internal/rpc/services/service_acme_user.go @@ -14,7 +14,7 @@ type ACMEUserService struct { // 创建用户 func (this *ACMEUserService) CreateACMEUser(ctx context.Context, req *pb.CreateACMEUserRequest) (*pb.CreateACMEUserResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } @@ -29,7 +29,7 @@ func (this *ACMEUserService) CreateACMEUser(ctx context.Context, req *pb.CreateA // 修改用户 func (this *ACMEUserService) UpdateACMEUser(ctx context.Context, req *pb.UpdateACMEUserRequest) (*pb.RPCSuccess, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } @@ -53,7 +53,7 @@ func (this *ACMEUserService) UpdateACMEUser(ctx context.Context, req *pb.UpdateA // 删除用户 func (this *ACMEUserService) DeleteACMEUser(ctx context.Context, req *pb.DeleteACMEUserRequest) (*pb.RPCSuccess, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } @@ -77,7 +77,7 @@ func (this *ACMEUserService) DeleteACMEUser(ctx context.Context, req *pb.DeleteA // 计算用户数量 func (this *ACMEUserService) CountACMEUsers(ctx context.Context, req *pb.CountAcmeUsersRequest) (*pb.RPCCountResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx) + adminId, userId, err := this.ValidateAdminAndUser(ctx, req.UserId) if err != nil { return nil, err } @@ -92,7 +92,7 @@ func (this *ACMEUserService) CountACMEUsers(ctx context.Context, req *pb.CountAc // 列出单页用户 func (this *ACMEUserService) ListACMEUsers(ctx context.Context, req *pb.ListACMEUsersRequest) (*pb.ListACMEUsersResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx) + adminId, userId, err := this.ValidateAdminAndUser(ctx, req.UserId) if err != nil { return nil, err } @@ -116,7 +116,7 @@ func (this *ACMEUserService) ListACMEUsers(ctx context.Context, req *pb.ListACME // 查找单个用户 func (this *ACMEUserService) FindEnabledACMEUser(ctx context.Context, req *pb.FindEnabledACMEUserRequest) (*pb.FindEnabledACMEUserResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } @@ -144,3 +144,27 @@ func (this *ACMEUserService) FindEnabledACMEUser(ctx context.Context, req *pb.Fi CreatedAt: int64(acmeUser.CreatedAt), }}, nil } + +// 查找所有用户 +func (this *ACMEUserService) FindAllACMEUsers(ctx context.Context, req *pb.FindAllACMEUsersRequest) (*pb.FindAllACMEUsersResponse, error) { + // 校验请求 + adminId, userId, err := this.ValidateAdminAndUser(ctx, req.UserId) + if err != nil { + return nil, err + } + + acmeUsers, err := models.SharedACMEUserDAO.FindAllACMEUsers(adminId, userId) + if err != nil { + return nil, err + } + result := []*pb.ACMEUser{} + for _, user := range acmeUsers { + result = append(result, &pb.ACMEUser{ + Id: int64(user.Id), + Email: user.Email, + Description: user.Description, + CreatedAt: int64(user.CreatedAt), + }) + } + return &pb.FindAllACMEUsersResponse{AcmeUsers: result}, nil +} diff --git a/internal/rpc/services/service_base.go b/internal/rpc/services/service_base.go index c7860883..5e128483 100644 --- a/internal/rpc/services/service_base.go +++ b/internal/rpc/services/service_base.go @@ -11,7 +11,7 @@ type BaseService struct { } // 校验管理员和用户 -func (this *BaseService) ValidateAdminAndUser(ctx context.Context) (adminId int64, userId int64, err error) { +func (this *BaseService) ValidateAdminAndUser(ctx context.Context, reqUserId int64) (adminId int64, userId int64, err error) { reqUserType, reqUserId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser) if err != nil { return @@ -22,9 +22,26 @@ func (this *BaseService) ValidateAdminAndUser(ctx context.Context) (adminId int6 switch reqUserType { case rpcutils.UserTypeAdmin: adminId = reqUserId + if adminId <= 0 { + err = errors.New("invalid 'adminId'") + return + } case rpcutils.UserTypeUser: userId = reqUserId + if userId <= 0 { + err = errors.New("invalid 'userId'") + return + } + + // 校验权限 + if reqUserId > 0 && reqUserId != userId { + err = this.PermissionError() + return + } + default: + err = errors.New("invalid user type") } + return } diff --git a/internal/rpc/services/service_dns_provider.go b/internal/rpc/services/service_dns_provider.go index 6d40398c..62bcd04b 100644 --- a/internal/rpc/services/service_dns_provider.go +++ b/internal/rpc/services/service_dns_provider.go @@ -16,12 +16,12 @@ type DNSProviderService struct { // 创建服务商 func (this *DNSProviderService) CreateDNSProvider(ctx context.Context, req *pb.CreateDNSProviderRequest) (*pb.CreateDNSProviderResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } - providerId, err := models.SharedDNSProviderDAO.CreateDNSProvider(req.Type, req.Name, req.ApiParamsJSON) + providerId, err := models.SharedDNSProviderDAO.CreateDNSProvider(adminId, userId, req.Type, req.Name, req.ApiParamsJSON) if err != nil { return nil, err } @@ -32,11 +32,13 @@ func (this *DNSProviderService) CreateDNSProvider(ctx context.Context, req *pb.C // 修改服务商 func (this *DNSProviderService) UpdateDNSProvider(ctx context.Context, req *pb.UpdateDNSProviderRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } + // TODO 校验权限 + err = models.SharedDNSProviderDAO.UpdateDNSProvider(req.DnsProviderId, req.Name, req.ApiParamsJSON) if err != nil { return nil, err @@ -47,12 +49,12 @@ func (this *DNSProviderService) UpdateDNSProvider(ctx context.Context, req *pb.U // 计算服务商数量 func (this *DNSProviderService) CountAllEnabledDNSProviders(ctx context.Context, req *pb.CountAllEnabledDNSProvidersRequest) (*pb.RPCCountResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := this.ValidateAdminAndUser(ctx, req.UserId) if err != nil { return nil, err } - count, err := models.SharedDNSProviderDAO.CountAllEnabledDNSProviders() + count, err := models.SharedDNSProviderDAO.CountAllEnabledDNSProviders(req.AdminId, req.UserId) if err != nil { return nil, err } @@ -62,12 +64,14 @@ func (this *DNSProviderService) CountAllEnabledDNSProviders(ctx context.Context, // 列出单页服务商信息 func (this *DNSProviderService) ListEnabledDNSProviders(ctx context.Context, req *pb.ListEnabledDNSProvidersRequest) (*pb.ListEnabledDNSProvidersResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := this.ValidateAdminAndUser(ctx, req.UserId) if err != nil { return nil, err } - providers, err := models.SharedDNSProviderDAO.ListEnabledDNSProviders(req.Offset, req.Size) + // TODO 校验权限 + + providers, err := models.SharedDNSProviderDAO.ListEnabledDNSProviders(req.AdminId, req.UserId, req.Offset, req.Size) if err != nil { return nil, err } @@ -85,14 +89,44 @@ func (this *DNSProviderService) ListEnabledDNSProviders(ctx context.Context, req return &pb.ListEnabledDNSProvidersResponse{DnsProviders: result}, nil } -// 删除服务商 -func (this *DNSProviderService) DeleteDNSProvider(ctx context.Context, req *pb.DeleteDNSProviderRequest) (*pb.RPCSuccess, error) { +// 查找所有的DNS服务商 +func (this *DNSProviderService) FindAllEnabledDNSProviders(ctx context.Context, req *pb.FindAllEnabledDNSProvidersRequest) (*pb.FindAllEnabledDNSProvidersResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + _, _, err := this.ValidateAdminAndUser(ctx, req.UserId) if err != nil { return nil, err } + // TODO 校验权限 + + providers, err := models.SharedDNSProviderDAO.FindAllEnabledDNSProviders(req.AdminId, req.UserId) + if err != nil { + return nil, err + } + result := []*pb.DNSProvider{} + for _, provider := range providers { + result = append(result, &pb.DNSProvider{ + Id: int64(provider.Id), + Name: provider.Name, + Type: provider.Type, + TypeName: dnsclients.FindProviderTypeName(provider.Type), + ApiParamsJSON: []byte(provider.ApiParams), + DataUpdatedAt: int64(provider.DataUpdatedAt), + }) + } + return &pb.FindAllEnabledDNSProvidersResponse{DnsProviders: result}, nil +} + +// 删除服务商 +func (this *DNSProviderService) DeleteDNSProvider(ctx context.Context, req *pb.DeleteDNSProviderRequest) (*pb.RPCSuccess, error) { + // 校验请求 + _, _, err := this.ValidateAdminAndUser(ctx, 0) + if err != nil { + return nil, err + } + + // TODO 校验权限 + err = models.SharedDNSProviderDAO.DisableDNSProvider(req.DnsProviderId) if err != nil { return nil, err diff --git a/internal/rpc/services/service_ssl_cert.go b/internal/rpc/services/service_ssl_cert.go index c1deb798..e77ab001 100644 --- a/internal/rpc/services/service_ssl_cert.go +++ b/internal/rpc/services/service_ssl_cert.go @@ -17,11 +17,13 @@ type SSLCertService struct { // 创建Cert func (this *SSLCertService) CreateSSLCert(ctx context.Context, req *pb.CreateSSLCertRequest) (*pb.CreateSSLCertResponse, error) { // 校验请求 - adminId, userId, err := this.ValidateAdminAndUser(ctx) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } + // TODO 校验权限 + certId, err := models.SharedSSLCertDAO.CreateCert(adminId, userId, req.IsOn, req.Name, req.Description, req.ServerName, req.IsCA, req.CertData, req.KeyData, req.TimeBeginAt, req.TimeEndAt, req.DnsNames, req.CommonNames) if err != nil { return nil, err @@ -79,6 +81,12 @@ func (this *SSLCertService) DeleteSSLCert(ctx context.Context, req *pb.DeleteSSL return nil, err } + // 停止相关ACME任务 + err = models.SharedACMETaskDAO.DisableAllTasksWithCertId(req.CertId) + if err != nil { + return nil, err + } + return this.Success() } @@ -130,72 +138,3 @@ func (this *SSLCertService) ListSSLCerts(ctx context.Context, req *pb.ListSSLCer } return &pb.ListSSLCertsResponse{CertsJSON: certConfigsJSON}, nil } - -// 计算某个ACME用户生成的证书数量 -func (this *SSLCertService) CountSSLCertsWithACMEUserId(ctx context.Context, req *pb.CountSSLCertsWithACMEUserIdRequest) (*pb.RPCCountResponse, error) { - // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) - if err != nil { - return nil, err - } - - // TODO 检查用户权限 - - count, err := models.SharedSSLCertDAO.CountSSLCertsWithACMEUserId(req.AcmeUserId) - if err != nil { - return nil, err - } - return this.SuccessCount(count) -} - -// 计算所有某个管理员/用户下所有的ACME用户生成的证书 -func (this *SSLCertService) CountAllSSLCertsWithACME(ctx context.Context, req *pb.CountAllSSLCertsWithACMERequest) (*pb.RPCCountResponse, error) { - // 校验请求 - _, _, err := this.ValidateAdminAndUser(ctx) - if err != nil { - return nil, err - } - - // TODO 校验用户 - - count, err := models.SharedSSLCertDAO.CountAllSSLCertsWithACME(req.AdminId, req.UserId) - if err != nil { - return nil, err - } - return this.SuccessCount(count) -} - -// 列出单个管理员/用户下所有的ACME用户生成的证书 -func (this *SSLCertService) ListSSLCertsWithACME(ctx context.Context, req *pb.ListSSLCertsWithACMERequest) (*pb.ListSSLCertsWithACMEResponse, error) { - // 校验请求 - _, _, err := this.ValidateAdminAndUser(ctx) - if err != nil { - return nil, err - } - - // TODO 校验用户 - - certIds, err := models.SharedSSLCertDAO.ListSSLCertIdsWithACME(req.AdminId, req.UserId, req.Offset, req.Size) - if err != nil { - return nil, err - } - - certConfigs := []*sslconfigs.SSLCertConfig{} - for _, certId := range certIds { - certConfig, err := models.SharedSSLCertDAO.ComposeCertConfig(certId) - if err != nil { - return nil, err - } - - // 这里不需要数据内容 - certConfig.CertData = nil - certConfig.KeyData = nil - - certConfigs = append(certConfigs, certConfig) - } - certConfigsJSON, err := json.Marshal(certConfigs) - if err != nil { - return nil, err - } - return &pb.ListSSLCertsWithACMEResponse{CertsJSON: certConfigsJSON}, nil -} diff --git a/internal/setup/sql_data.go b/internal/setup/sql_data.go new file mode 100644 index 00000000..65642c13 --- /dev/null +++ b/internal/setup/sql_data.go @@ -0,0 +1,67 @@ +package setup + +import ( + "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/types" + stringutil "github.com/iwind/TeaGo/utils/string" +) + +type upgradeVersion struct { + version string + f func(db *dbs.DB) error +} + +var upgradeFuncs = []*upgradeVersion{ + { + "0.0.3", upgradeV0_0_3, + }, +} + +// 升级SQL数据 +func UpgradeSQLData(db *dbs.DB) error { + version, err := db.FindCol(0, "SELECT version FROM edgeVersions") + if err != nil { + return err + } + versionString := types.String(version) + if len(versionString) > 0 { + for _, f := range upgradeFuncs { + if stringutil.VersionCompare(versionString, f.version) >= 0 { + continue + } + err = f.f(db) + if err != nil { + return err + } + } + } + return nil +} + +// v0.0.3 +func upgradeV0_0_3(db *dbs.DB) error { + // 获取第一个管理员 + adminIdCol, err := db.FindCol(0, "SELECT id FROM edgeAdmins ORDER BY id ASC LIMIT 1") + if err != nil { + return err + } + adminId := types.Int64(adminIdCol) + if adminId <= 0 { + return errors.New("'edgeAdmins' table should not be empty") + } + + // 升级edgeDNSProviders + _, err = db.Exec("UPDATE edgeDNSProviders SET adminId=? WHERE adminId=0 AND userId=0", adminId) + if err != nil { + return err + } + + // 升级edgeSSLCerts + _, err = db.Exec("UPDATE edgeSSLCerts SET adminId=? WHERE adminId=0 AND userId=0", adminId) + if err != nil { + return err + } + + return nil +} diff --git a/internal/setup/sql_data_test.go b/internal/setup/sql_data_test.go new file mode 100644 index 00000000..c4379d87 --- /dev/null +++ b/internal/setup/sql_data_test.go @@ -0,0 +1,18 @@ +package setup + +import ( + "github.com/iwind/TeaGo/dbs" + "testing" +) + +func TestUpgradeSQLData(t *testing.T) { + db, err := dbs.Default() + if err != nil { + t.Fatal(err) + } + err = UpgradeSQLData(db) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} diff --git a/internal/setup/sql_dump.go b/internal/setup/sql_dump.go index 7fb81d35..cdef83b3 100644 --- a/internal/setup/sql_dump.go +++ b/internal/setup/sql_dump.go @@ -1,6 +1,7 @@ package setup import ( + "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/types" "regexp" @@ -240,6 +241,12 @@ func (this *SQLDump) Apply(db *dbs.DB, newResult *SQLDumpResult) (ops []string, // 减少表格 // 由于我们不删除任何表格,所以这里什么都不做 + // 升级数据 + err = UpgradeSQLData(db) + if err != nil { + return nil, errors.New("upgrade data failed: " + err.Error()) + } + return }