From 175a29aab8c40a59987c2efd3c19c1ba43f7b2c6 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Tue, 20 Oct 2020 20:18:06 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=B6=88=E6=81=AF=E7=AE=A1?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/db/models/db_node_initializer.go | 12 +- internal/db/models/http_web_dao.go | 15 +-- internal/db/models/message_dao.go | 53 +++++++++ internal/nodes/api_node.go | 1 + internal/rpc/services/service_message.go | 131 ++++++++++++++++++++++ 5 files changed, 189 insertions(+), 23 deletions(-) create mode 100644 internal/rpc/services/service_message.go diff --git a/internal/db/models/db_node_initializer.go b/internal/db/models/db_node_initializer.go index e10f815e..638ecb88 100644 --- a/internal/db/models/db_node_initializer.go +++ b/internal/db/models/db_node_initializer.go @@ -156,7 +156,7 @@ func (this *DBNodeInitializer) loop() error { accessLogLocker.Lock() closingDbs := []*dbs.DB{} for nodeId, db := range accessLogDBMapping { - if !this.containsInt64(nodeIds, nodeId) { + if !lists.ContainsInt64(nodeIds, nodeId) { closingDbs = append(closingDbs, db) delete(accessLogDBMapping, nodeId) delete(accessLogDAOMapping, nodeId) @@ -250,13 +250,3 @@ func (this *DBNodeInitializer) loop() error { return nil } - -// 判断是否包含某数字 -func (this *DBNodeInitializer) containsInt64(values []int64, value int64) bool { - for _, v := range values { - if v == value { - return true - } - } - return false -} diff --git a/internal/db/models/http_web_dao.go b/internal/db/models/http_web_dao.go index e1e5077c..b3812a9b 100644 --- a/internal/db/models/http_web_dao.go +++ b/internal/db/models/http_web_dao.go @@ -9,6 +9,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" "strconv" ) @@ -548,7 +549,7 @@ func (this *HTTPWebDAO) FindAllWebIdsWithCachePolicyId(cachePolicyId int64) ([]i // 如果非Location if locationId == 0 { - if !this.containsInt64(result, webId) { + if !lists.ContainsInt64(result, webId) { result = append(result, webId) } break @@ -592,7 +593,7 @@ func (this *HTTPWebDAO) FindAllWebIdsWithHTTPFirewallPolicyId(firewallPolicyId i // 如果非Location if locationId == 0 { - if !this.containsInt64(result, webId) { + if !lists.ContainsInt64(result, webId) { result = append(result, webId) } break @@ -621,13 +622,3 @@ func (this *HTTPWebDAO) FindEnabledWebIdWithLocationId(locationId int64) (webId Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 FindInt64Col(0) } - -// 判断slice是否包含某个int64值 -func (this *HTTPWebDAO) containsInt64(values []int64, value int64) bool { - for _, v := range values { - if v == value { - return true - } - } - return false -} diff --git a/internal/db/models/message_dao.go b/internal/db/models/message_dao.go index f5c05c65..bfff7c42 100644 --- a/internal/db/models/message_dao.go +++ b/internal/db/models/message_dao.go @@ -3,6 +3,7 @@ package models import ( "crypto/md5" "fmt" + "github.com/TeaOSLab/EdgeAPI/internal/errors" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" @@ -93,6 +94,58 @@ func (this *MessageDAO) DeleteMessagesBeforeDay(dayTime time.Time) error { return err } +// 计算未读消息数量 +func (this *MessageDAO) CountUnreadMessages() (int64, error) { + return this.Query(). + Attr("isRead", false). + Count() +} + +// 列出单页未读消息 +func (this *MessageDAO) ListUnreadMessages(offset int64, size int64) (result []*Message, err error) { + _, err = this.Query(). + Attr("isRead", false). + Offset(offset). + Limit(size). + DescPk(). + Slice(&result). + FindAll() + return +} + +// 设置消息已读状态 +func (this *MessageDAO) UpdateMessageRead(messageId int64, b bool) error { + if messageId <= 0 { + return errors.New("invalid messageId") + } + op := NewMessageOperator() + op.Id = messageId + op.IsRead = b + _, err := this.Save(op) + return err +} + +// 设置一组消息为已读状态 +func (this *MessageDAO) UpdateMessagesRead(messageIds []int64, b bool) error { + // 这里我们一个一个更改,因为In语句不容易Prepare,且效率不高 + for _, messageId := range messageIds { + err := this.UpdateMessageRead(messageId, b) + if err != nil { + return err + } + } + return nil +} + +// 设置所有消息为已读 +func (this *MessageDAO) UpdateAllMessagesRead() error { + _, err := this.Query(). + Attr("isRead", false). + Set("isRead", true). + Update() + return err +} + // 创建消息 func (this *MessageDAO) createMessage(clusterId int64, nodeId int64, messageType MessageType, level string, body string, paramsJSON []byte) (int64, error) { h := md5.New() diff --git a/internal/nodes/api_node.go b/internal/nodes/api_node.go index e8f89c5d..d05af269 100644 --- a/internal/nodes/api_node.go +++ b/internal/nodes/api_node.go @@ -170,6 +170,7 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err pb.RegisterDBNodeServiceServer(rpcServer, &services.DBNodeService{}) pb.RegisterNodeLogServiceServer(rpcServer, &services.NodeLogService{}) pb.RegisterHTTPAccessLogServiceServer(rpcServer, &services.HTTPAccessLogService{}) + pb.RegisterMessageServiceServer(rpcServer, &services.MessageService{}) err := rpcServer.Serve(listener) if err != nil { return errors.New("[API]start rpc failed: " + err.Error()) diff --git a/internal/rpc/services/service_message.go b/internal/rpc/services/service_message.go new file mode 100644 index 00000000..f762a38a --- /dev/null +++ b/internal/rpc/services/service_message.go @@ -0,0 +1,131 @@ +package services + +import ( + "context" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" +) + +// 消息相关服务 +type MessageService struct { +} + +// 计算未读消息数 +func (this *MessageService) CountUnreadMessages(ctx context.Context, req *pb.CountUnreadMessagesRequest) (*pb.CountUnreadMessagesResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + count, err := models.SharedMessageDAO.CountUnreadMessages() + if err != nil { + return nil, err + } + return &pb.CountUnreadMessagesResponse{Count: count}, nil +} + +// 列出单页未读消息 +func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.ListUnreadMessagesRequest) (*pb.ListUnreadMessagesResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + messages, err := models.SharedMessageDAO.ListUnreadMessages(req.Offset, req.Size) + if err != nil { + return nil, err + } + result := []*pb.Message{} + for _, message := range messages { + var pbCluster *pb.NodeCluster = nil + var pbNode *pb.Node = nil + + if message.ClusterId > 0 { + cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(int64(message.ClusterId)) + if err != nil { + return nil, err + } + if cluster != nil { + pbCluster = &pb.NodeCluster{ + Id: int64(cluster.Id), + Name: cluster.Name, + } + } + } + + if message.NodeId > 0 { + node, err := models.SharedNodeDAO.FindEnabledNode(int64(message.NodeId)) + if err != nil { + return nil, err + } + if node != nil { + pbNode = &pb.Node{ + Id: int64(node.Id), + Name: node.Name, + } + } + } + + result = append(result, &pb.Message{ + Id: int64(message.Id), + Type: message.Type, + Body: message.Body, + Level: message.Level, + ParamsJSON: []byte(message.Params), + IsRead: message.IsRead == 1, + CreatedAt: int64(message.CreatedAt), + Cluster: pbCluster, + Node: pbNode, + }) + } + + return &pb.ListUnreadMessagesResponse{Messages: result}, nil +} + +// 设置消息已读状态 +func (this *MessageService) UpdateMessageRead(ctx context.Context, req *pb.UpdateMessageReadRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedMessageDAO.UpdateMessageRead(req.MessageId, req.IsRead) + if err != nil { + return nil, err + } + return rpcutils.RPCUpdateSuccess() +} + +// 设置一组消息已读状态 +func (this *MessageService) UpdateMessagesRead(ctx context.Context, req *pb.UpdateMessagesReadRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedMessageDAO.UpdateMessagesRead(req.MessageIds, req.IsRead) + if err != nil { + return nil, err + } + return rpcutils.RPCUpdateSuccess() +} + +// 设置所有消息为已读 +func (this *MessageService) UpdateAllMessagesRead(ctx context.Context, req *pb.UpdateAllMessagesReadRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedMessageDAO.UpdateAllMessagesRead() + if err != nil { + return nil, err + } + return rpcutils.RPCUpdateSuccess() +}