diff --git a/internal/db/models/dns_domain_dao.go b/internal/db/models/dns_domain_dao.go index 7ca935b3..85a4dfef 100644 --- a/internal/db/models/dns_domain_dao.go +++ b/internal/db/models/dns_domain_dao.go @@ -76,9 +76,11 @@ func (this *DNSDomainDAO) FindDNSDomainName(id int64) (string, error) { } // 创建域名 -func (this *DNSDomainDAO) CreateDomain(providerId int64, name string) (int64, error) { +func (this *DNSDomainDAO) CreateDomain(adminId int64, userId int64, providerId int64, name string) (int64, error) { op := NewDNSDomainOperator() op.ProviderId = providerId + op.AdminId = adminId + op.UserId = userId op.Name = name op.State = DNSDomainStateEnabled op.IsOn = true diff --git a/internal/db/models/message_dao.go b/internal/db/models/message_dao.go index 901e3900..b9950fdf 100644 --- a/internal/db/models/message_dao.go +++ b/internal/db/models/message_dao.go @@ -25,12 +25,15 @@ const ( type MessageType = string const ( - MessageTypeHealthCheckFailed MessageType = "HealthCheckFailed" - MessageTypeHealthCheckNodeUp MessageType = "HealthCheckNodeUp" - MessageTypeHealthCheckNodeDown MessageType = "HealthCheckNodeDown" - MessageTypeNodeInactive MessageType = "NodeInactive" - MessageTypeNodeActive MessageType = "NodeActive" - MessageTypeClusterDNSSyncFailed MessageType = "ClusterDNSSyncFailed" + MessageTypeHealthCheckFailed MessageType = "HealthCheckFailed" + MessageTypeHealthCheckNodeUp MessageType = "HealthCheckNodeUp" + MessageTypeHealthCheckNodeDown MessageType = "HealthCheckNodeDown" + MessageTypeNodeInactive MessageType = "NodeInactive" + MessageTypeNodeActive MessageType = "NodeActive" + MessageTypeClusterDNSSyncFailed MessageType = "ClusterDNSSyncFailed" + MessageTypeSSLCertExpiring MessageType = "SSLCertExpiring" // SSL证书即将过期 + MessageTypeSSLCertACMETaskFailed MessageType = "SSLCertACMETaskFailed" // SSL证书任务执行失败 + MessageTypeSSLCertACMETaskSuccess MessageType = "SSLCertACMETaskSuccess" // SSL证书任务执行成功 ) type MessageDAO dbs.DAO @@ -96,6 +99,30 @@ func (this *MessageDAO) CreateNodeMessage(clusterId int64, nodeId int64, message return err } +// 创建普通消息 +func (this *MessageDAO) CreateMessage(adminId int64, userId int64, messageType MessageType, level string, body string, paramsJSON []byte) error { + h := md5.New() + h.Write([]byte(body)) + h.Write(paramsJSON) + hash := fmt.Sprintf("%x", h.Sum(nil)) + + op := NewMessageOperator() + op.AdminId = adminId + op.UserId = userId + op.Type = messageType + op.Level = level + op.Body = body + if len(paramsJSON) > 0 { + op.Params = paramsJSON + } + op.State = MessageStateEnabled + op.IsRead = false + op.Day = timeutil.Format("Ymd") + op.Hash = hash + _, err := this.Save(op) + return err +} + // 删除某天之前的消息 func (this *MessageDAO) DeleteMessagesBeforeDay(dayTime time.Time) error { day := timeutil.Format("Ymd", dayTime) diff --git a/internal/db/models/node_cluster_dao.go b/internal/db/models/node_cluster_dao.go index 4aec9e7c..7f531503 100644 --- a/internal/db/models/node_cluster_dao.go +++ b/internal/db/models/node_cluster_dao.go @@ -101,7 +101,7 @@ func (this *NodeClusterDAO) FindAllEnableClusters() (result []*NodeCluster, err } // 创建集群 -func (this *NodeClusterDAO) CreateCluster(name string, grantId int64, installDir string, dnsDomainId int64, dnsName string) (clusterId int64, err error) { +func (this *NodeClusterDAO) CreateCluster(adminId int64, name string, grantId int64, installDir string, dnsDomainId int64, dnsName string) (clusterId int64, err error) { uniqueId, err := this.genUniqueId() if err != nil { return 0, err @@ -114,6 +114,7 @@ func (this *NodeClusterDAO) CreateCluster(name string, grantId int64, installDir } op := NewNodeClusterOperator() + op.AdminId = adminId op.Name = name op.GrantId = grantId op.InstallDir = installDir @@ -522,6 +523,14 @@ func (this *NodeClusterDAO) CheckClusterDNS(cluster *NodeCluster) (issues []*pb. return } +// 查找集群所属管理员 +func (this *NodeClusterDAO) FindClusterAdminId(clusterId int64) (int64, error) { + return this.Query(). + Pk(clusterId). + Result("adminId"). + FindInt64Col(0) +} + // 生成唯一ID func (this *NodeClusterDAO) genUniqueId() (string, error) { for { diff --git a/internal/db/models/node_dao.go b/internal/db/models/node_dao.go index 734b22cc..9d459bcc 100644 --- a/internal/db/models/node_dao.go +++ b/internal/db/models/node_dao.go @@ -80,7 +80,7 @@ func (this *NodeDAO) FindNodeName(id uint32) (string, error) { } // 创建节点 -func (this *NodeDAO) CreateNode(name string, clusterId int64, groupId int64) (nodeId int64, err error) { +func (this *NodeDAO) CreateNode(adminId int64, name string, clusterId int64, groupId int64) (nodeId int64, err error) { uniqueId, err := this.genUniqueId() if err != nil { return 0, err @@ -95,6 +95,7 @@ func (this *NodeDAO) CreateNode(name string, clusterId int64, groupId int64) (no } op := NewNodeOperator() + op.AdminId = adminId op.Name = name op.UniqueId = uniqueId op.Secret = secret diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index f19df485..c3f2ad05 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -8,6 +8,7 @@ import ( "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/types" + timeutil "github.com/iwind/TeaGo/utils/time" "time" ) @@ -283,3 +284,32 @@ func (this *SSLCertDAO) UpdateCertACME(certId int64, acmeTaskId int64) error { _, err := this.Save(op) return err } + +// 查找需要自动更新的任务 +// 这里我们只返回有限的字段以节省内存 +func (this *SSLCertDAO) FindAllExpiringCerts(days int) (result []*SSLCert, err error) { + if days < 0 { + days = 0 + } + + deltaSeconds := int64(days * 86400) + _, err = this.Query(). + State(SSLCertStateEnabled). + Where("FROM_UNIXTIME(timeEndAt, '%Y-%m-%d')=:day AND FROM_UNIXTIME(notifiedAt, '%Y-%m-%d')!=:today"). + Param("day", timeutil.FormatTime("Y-m-d", time.Now().Unix()+deltaSeconds)). + Param("today", timeutil.Format("Y-m-d")). + Result("id", "adminId", "userId", "timeEndAt", "name", "dnsNames", "notifiedAt", "acmeTaskId"). + Slice(&result). + AscPk(). + FindAll() + return +} + +// 设置当前证书事件通知时间 +func (this *SSLCertDAO) UpdateCertNotifiedAt(certId int64) error { + _, err := this.Query(). + Pk(certId). + Set("notifiedAt", time.Now().Unix()). + Update() + return err +} diff --git a/internal/db/models/ssl_cert_model.go b/internal/db/models/ssl_cert_model.go index d073035e..8fc6622c 100644 --- a/internal/db/models/ssl_cert_model.go +++ b/internal/db/models/ssl_cert_model.go @@ -22,6 +22,7 @@ type SSLCert struct { CommonNames string `field:"commonNames"` // 发行单位列表 IsACME uint8 `field:"isACME"` // 是否为ACME自动生成的 AcmeTaskId uint64 `field:"acmeTaskId"` // ACME任务ID + NotifiedAt uint64 `field:"notifiedAt"` // 最后通知时间 } type SSLCertOperator struct { @@ -45,6 +46,7 @@ type SSLCertOperator struct { CommonNames interface{} // 发行单位列表 IsACME interface{} // 是否为ACME自动生成的 AcmeTaskId interface{} // ACME任务ID + NotifiedAt interface{} // 最后通知时间 } func NewSSLCertOperator() *SSLCertOperator { diff --git a/internal/rpc/services/service_base.go b/internal/rpc/services/service_base.go index 5e128483..f5c4a0af 100644 --- a/internal/rpc/services/service_base.go +++ b/internal/rpc/services/service_base.go @@ -10,6 +10,18 @@ import ( type BaseService struct { } +// 校验管理员 +func (this *BaseService) ValidateAdmin(ctx context.Context, reqAdminId int64) (adminId int64, err error) { + _, reqUserId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return + } + if reqAdminId > 0 && reqUserId != reqAdminId { + return 0, this.PermissionError() + } + return reqUserId, nil +} + // 校验管理员和用户 func (this *BaseService) ValidateAdminAndUser(ctx context.Context, reqUserId int64) (adminId int64, userId int64, err error) { reqUserType, reqUserId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser) diff --git a/internal/rpc/services/service_dns_domain.go b/internal/rpc/services/service_dns_domain.go index d8bb660a..dc53415f 100644 --- a/internal/rpc/services/service_dns_domain.go +++ b/internal/rpc/services/service_dns_domain.go @@ -21,7 +21,7 @@ type DNSDomainService struct { // 创建域名 func (this *DNSDomainService) CreateDNSDomain(ctx context.Context, req *pb.CreateDNSDomainRequest) (*pb.CreateDNSDomainResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } @@ -39,7 +39,7 @@ func (this *DNSDomainService) CreateDNSDomain(ctx context.Context, req *pb.Creat return nil, err } - domainId, err := models.SharedDNSDomainDAO.CreateDomain(req.DnsProviderId, req.Name) + domainId, err := models.SharedDNSDomainDAO.CreateDomain(adminId, userId, req.DnsProviderId, req.Name) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node.go b/internal/rpc/services/service_node.go index 72809b13..c66f7f94 100644 --- a/internal/rpc/services/service_node.go +++ b/internal/rpc/services/service_node.go @@ -38,12 +38,12 @@ type NodeService struct { // 创建节点 func (this *NodeService) CreateNode(ctx context.Context, req *pb.CreateNodeRequest) (*pb.CreateNodeResponse, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, err := this.ValidateAdmin(ctx, 0) if err != nil { return nil, err } - nodeId, err := models.SharedNodeDAO.CreateNode(req.Name, req.ClusterId, req.GroupId) + nodeId, err := models.SharedNodeDAO.CreateNode(adminId, req.Name, req.ClusterId, req.GroupId) if err != nil { return nil, err } @@ -87,7 +87,12 @@ func (this *NodeService) RegisterClusterNode(ctx context.Context, req *pb.Regist return nil, err } - nodeId, err := models.SharedNodeDAO.CreateNode(req.Name, clusterId, 0) + adminId, err := models.SharedNodeClusterDAO.FindClusterAdminId(clusterId) + if err != nil { + return nil, err + } + + nodeId, err := models.SharedNodeDAO.CreateNode(adminId, req.Name, clusterId, 0) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_node_cluster.go b/internal/rpc/services/service_node_cluster.go index 7e450d9a..71a0ad9f 100644 --- a/internal/rpc/services/service_node_cluster.go +++ b/internal/rpc/services/service_node_cluster.go @@ -19,12 +19,12 @@ type NodeClusterService struct { // 创建集群 func (this *NodeClusterService) CreateNodeCluster(ctx context.Context, req *pb.CreateNodeClusterRequest) (*pb.CreateNodeClusterResponse, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, err := this.ValidateAdmin(ctx, 0) if err != nil { return nil, err } - clusterId, err := models.SharedNodeClusterDAO.CreateCluster(req.Name, req.GrantId, req.InstallDir, req.DnsDomainId, req.DnsName) + clusterId, err := models.SharedNodeClusterDAO.CreateCluster(adminId, req.Name, req.GrantId, req.InstallDir, req.DnsDomainId, req.DnsName) if err != nil { return nil, err } diff --git a/internal/setup/sql_data.go b/internal/setup/sql_data.go index 65642c13..50a1a430 100644 --- a/internal/setup/sql_data.go +++ b/internal/setup/sql_data.go @@ -57,11 +57,29 @@ func upgradeV0_0_3(db *dbs.DB) error { return err } + // 升级edgeDNSDomains + _, err = db.Exec("UPDATE edgeDNSDomains SET adminId=? WHERE adminId=0 AND userId=0", adminId) + if err != nil { + return err + } + // 升级edgeSSLCerts _, err = db.Exec("UPDATE edgeSSLCerts SET adminId=? WHERE adminId=0 AND userId=0", adminId) if err != nil { return err } + // 升级edgeNodeClusters + _, err = db.Exec("UPDATE edgeNodeClusters SET adminId=? WHERE adminId=0 AND userId=0", adminId) + if err != nil { + return err + } + + // 升级edgeNodes + _, err = db.Exec("UPDATE edgeNodes SET adminId=? WHERE adminId=0 AND userId=0", adminId) + if err != nil { + return err + } + return nil } diff --git a/internal/tasks/ssl_cert_expire_check_executor.go b/internal/tasks/ssl_cert_expire_check_executor.go new file mode 100644 index 00000000..23559993 --- /dev/null +++ b/internal/tasks/ssl_cert_expire_check_executor.go @@ -0,0 +1,204 @@ +package tasks + +import ( + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/logs" + "github.com/iwind/TeaGo/maps" + timeutil "github.com/iwind/TeaGo/utils/time" + "strconv" + "time" +) + +func init() { + dbs.OnReady(func() { + go NewSSLCertExpireCheckExecutor().Start() + }) +} + +// 证书检查任务 +type SSLCertExpireCheckExecutor struct { +} + +func NewSSLCertExpireCheckExecutor() *SSLCertExpireCheckExecutor { + return &SSLCertExpireCheckExecutor{} +} + +// 启动任务 +func (this *SSLCertExpireCheckExecutor) Start() { + seconds := int64(3600) + ticker := time.NewTicker(time.Duration(seconds) * time.Second) + for range ticker.C { + err := this.loop(seconds) + if err != nil { + logs.Println("[ERROR][SSLCertExpireCheckExecutor]" + err.Error()) + } + } +} + +// 单次执行 +func (this *SSLCertExpireCheckExecutor) loop(seconds int64) error { + // 检查上次运行时间,防止重复运行 + settingKey := "sslCertExpiringCheck" + timestamp := time.Now().Unix() + c, err := models.SharedSysSettingDAO.CompareInt64Setting(settingKey, timestamp-seconds) + if err != nil { + return err + } + if c > 0 { + return nil + } + + // 记录时间 + err = models.SharedSysSettingDAO.UpdateSetting(settingKey, []byte(numberutils.FormatInt64(timestamp))) + if err != nil { + return err + } + + // 查找需要自动更新的证书 + // 30, 14 ... 是到期的天数 + for _, days := range []int{30, 14, 7} { + certs, err := models.SharedSSLCertDAO.FindAllExpiringCerts(days) + if err != nil { + return err + } + for _, cert := range certs { + // 发送消息 + msg := "SSL证书\"" + cert.Name + "\"(" + cert.DnsNames + ")在" + strconv.Itoa(days) + "天后将到期," + + // 是否有自动更新任务 + if cert.AcmeTaskId > 0 { + task, err := models.SharedACMETaskDAO.FindEnabledACMETask(int64(cert.AcmeTaskId)) + if err != nil { + return err + } + if task != nil { + if task.AutoRenew == 1 { + msg += "此证书是免费申请的证书,且已设置了自动续期,将会在到期前三天自动尝试续期。" + } else { + msg += "此证书是免费申请的证书,没有设置自动续期,请在到期前手动执行续期任务。" + } + } + } else { + msg += "请及时更新证书。" + } + + err = models.SharedMessageDAO.CreateMessage(int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertExpiring, models.MessageLevelWarning, msg, maps.Map{ + "certId": cert.Id, + "acmeTaskId": cert.AcmeTaskId, + }.AsJSON()) + if err != nil { + return err + } + + // 设置最后通知时间 + err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(int64(cert.Id)) + if err != nil { + return err + } + } + } + + // 自动续期 + for _, days := range []int{3, 2, 1} { + certs, err := models.SharedSSLCertDAO.FindAllExpiringCerts(days) + if err != nil { + return err + } + for _, cert := range certs { + // 发送消息 + msg := "SSL证书\"" + cert.Name + "\"(" + cert.DnsNames + ")在" + strconv.Itoa(days) + "天后将到期," + + // 是否有自动更新任务 + if cert.AcmeTaskId > 0 { + task, err := models.SharedACMETaskDAO.FindEnabledACMETask(int64(cert.AcmeTaskId)) + if err != nil { + return err + } + if task != nil { + if task.AutoRenew == 1 { + isOk, errMsg, _ := models.SharedACMETaskDAO.RunTask(int64(cert.AcmeTaskId)) + if isOk { + // 发送成功通知 + msg = "系统已成功为你自动更新了证书\"" + cert.Name + "\"(" + cert.DnsNames + ")。" + err = models.SharedMessageDAO.CreateMessage(int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertACMETaskSuccess, models.MessageLevelSuccess, msg, maps.Map{ + "certId": cert.Id, + "acmeTaskId": cert.AcmeTaskId, + }.AsJSON()) + + // 更新通知时间 + err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(int64(cert.Id)) + if err != nil { + return err + } + } else { + // 发送失败通知 + msg = "系统在尝试自动更新证书\"" + cert.Name + "\"(" + cert.DnsNames + ")时发生错误:" + errMsg + "。请检查系统设置并修复错误。" + err = models.SharedMessageDAO.CreateMessage(int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertACMETaskFailed, models.MessageLevelError, msg, maps.Map{ + "certId": cert.Id, + "acmeTaskId": cert.AcmeTaskId, + }.AsJSON()) + + // 更新通知时间 + err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(int64(cert.Id)) + if err != nil { + return err + } + } + + // 中止不发送消息 + continue + + } else { + msg += "此证书是免费申请的证书,没有设置自动续期,请在到期前手动执行续期任务。" + } + } + } else { + msg += "请及时更新证书。" + } + + err = models.SharedMessageDAO.CreateMessage(int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertExpiring, models.MessageLevelWarning, msg, maps.Map{ + "certId": cert.Id, + "acmeTaskId": cert.AcmeTaskId, + }.AsJSON()) + if err != nil { + return err + } + + // 设置最后通知时间 + err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(int64(cert.Id)) + if err != nil { + return err + } + } + } + + // 当天过期 + for _, days := range []int{0} { + certs, err := models.SharedSSLCertDAO.FindAllExpiringCerts(days) + if err != nil { + return err + } + for _, cert := range certs { + // 发送消息 + today := timeutil.Format("Y-m-d") + msg := "SSL证书\"" + cert.Name + "\"(" + cert.DnsNames + ")在今天(" + today + ")过期,请及时更新证书,之后将不再重复提醒。" + err = models.SharedMessageDAO.CreateMessage(int64(cert.AdminId), int64(cert.UserId), models.MessageTypeSSLCertExpiring, models.MessageLevelWarning, msg, maps.Map{ + "certId": cert.Id, + "acmeTaskId": cert.AcmeTaskId, + }.AsJSON()) + if err != nil { + return err + } + + // 设置最后通知时间 + err = models.SharedSSLCertDAO.UpdateCertNotifiedAt(int64(cert.Id)) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/internal/tasks/ssl_cert_expire_check_executor_test.go b/internal/tasks/ssl_cert_expire_check_executor_test.go new file mode 100644 index 00000000..58078e81 --- /dev/null +++ b/internal/tasks/ssl_cert_expire_check_executor_test.go @@ -0,0 +1,25 @@ +package tasks + +import ( + "github.com/iwind/TeaGo/dbs" + timeutil "github.com/iwind/TeaGo/utils/time" + "testing" + "time" +) + +func TestSSLCertExpireCheckExecutor_loop(t *testing.T) { + dbs.NotifyReady() + + t.Log("30 days later: ", timeutil.FormatTime("Y-m-d", time.Now().Unix()+30*86400), time.Now().Unix()+30*86400) + t.Log("14 days later: ", timeutil.FormatTime("Y-m-d", time.Now().Unix()+14*86400), time.Now().Unix()+14*86400) + t.Log("7 days later: ", timeutil.FormatTime("Y-m-d", time.Now().Unix()+7*86400), time.Now().Unix()+7*86400) + t.Log("3 days later: ", timeutil.FormatTime("Y-m-d", time.Now().Unix()+3*86400), time.Now().Unix()+3*86400) + t.Log("today: ", timeutil.FormatTime("Y-m-d", time.Now().Unix()), time.Now().Unix()) + + executor := NewSSLCertExpireCheckExecutor() + err := executor.loop(0) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +}