diff --git a/internal/db/models/admin_dao.go b/internal/db/models/admin_dao.go index fcee164e..20e725b0 100644 --- a/internal/db/models/admin_dao.go +++ b/internal/db/models/admin_dao.go @@ -198,7 +198,9 @@ func (this *AdminDAO) UpdateAdminLogin(adminId int64, username string, password op := NewAdminOperator() op.Id = adminId op.Username = username - op.Password = stringutil.Md5(password) + if len(password) > 0 { + op.Password = stringutil.Md5(password) + } err := this.Save(op) return err } diff --git a/internal/db/models/message_dao.go b/internal/db/models/message_dao.go index 93cfea75..63009ebb 100644 --- a/internal/db/models/message_dao.go +++ b/internal/db/models/message_dao.go @@ -135,16 +135,29 @@ func (this *MessageDAO) DeleteMessagesBeforeDay(dayTime time.Time) error { } // 计算未读消息数量 -func (this *MessageDAO) CountUnreadMessages() (int64, error) { - return this.Query(). - Attr("isRead", false). - Count() +func (this *MessageDAO) CountUnreadMessages(adminId int64, userId int64) (int64, error) { + query := this.Query(). + Attr("isRead", false) + if adminId > 0 { + query.Where("(adminId=:adminId OR (adminId=0 AND userId=0))"). + Param("adminId", adminId) + } else if userId > 0 { + query.Attr("userId", userId) + } + return query.Count() } // 列出单页未读消息 -func (this *MessageDAO) ListUnreadMessages(offset int64, size int64) (result []*Message, err error) { - _, err = this.Query(). - Attr("isRead", false). +func (this *MessageDAO) ListUnreadMessages(adminId int64, userId int64, offset int64, size int64) (result []*Message, err error) { + query := this.Query(). + Attr("isRead", false) + if adminId > 0 { + query.Where("(adminId=:adminId OR (adminId=0 AND userId=0))"). + Param("adminId", adminId) + } else if userId > 0 { + query.Attr("userId", userId) + } + _, err = query. Offset(offset). Limit(size). DescPk(). @@ -178,14 +191,37 @@ func (this *MessageDAO) UpdateMessagesRead(messageIds []int64, b bool) error { } // 设置所有消息为已读 -func (this *MessageDAO) UpdateAllMessagesRead() error { - _, err := this.Query(). - Attr("isRead", false). +func (this *MessageDAO) UpdateAllMessagesRead(adminId int64, userId int64) error { + query := this.Query(). + Attr("isRead", false) + if adminId > 0 { + query.Where("(adminId=:adminId OR (adminId=0 AND userId=0))"). + Param("adminId", adminId) + } else if userId > 0 { + query.Attr("userId", userId) + } + _, err := query. Set("isRead", true). Update() return err } +// 检查消息权限 +func (this *MessageDAO) CheckMessageUser(messageId int64, adminId int64, userId int64) (bool, error) { + if messageId <= 0 || (adminId <= 0 && userId <= 0) { + return false, nil + } + query := this.Query(). + Pk(messageId) + if adminId > 0 { + query.Where("(adminId=:adminId OR (adminId=0 AND userId=0))"). + Param("adminId", adminId) + } else if userId > 0 { + query.Attr("userId", userId) + } + return query.Exist() +} + // 创建消息 func (this *MessageDAO) createMessage(clusterId int64, nodeId int64, messageType MessageType, level string, body string, paramsJSON []byte) (int64, error) { h := md5.New() diff --git a/internal/db/models/user_dao.go b/internal/db/models/user_dao.go index 5807edbd..f592f9d4 100644 --- a/internal/db/models/user_dao.go +++ b/internal/db/models/user_dao.go @@ -126,6 +126,32 @@ func (this *UserDAO) UpdateUser(userId int64, username string, password string, return err } +// 修改用户基本信息 +func (this *UserDAO) UpdateUserInfo(userId int64, fullname string) error { + if userId <= 0 { + return errors.New("invalid userId") + } + op := NewUserOperator() + op.Id = userId + op.Fullname = fullname + return this.Save(op) +} + +// 修改用户登录信息 +func (this *UserDAO) UpdateUserLogin(userId int64, username string, password string) error { + if userId <= 0 { + return errors.New("invalid userId") + } + op := NewUserOperator() + op.Id = userId + op.Username = username + if len(password) > 0 { + op.Password = stringutil.Md5(password) + } + err := this.Save(op) + return err +} + // 计算用户数量 func (this *UserDAO) CountAllEnabledUsers(keyword string) (int64, error) { query := this.Query() diff --git a/internal/rpc/services/service_message.go b/internal/rpc/services/service_message.go index 9ce0e7df..5ac997fe 100644 --- a/internal/rpc/services/service_message.go +++ b/internal/rpc/services/service_message.go @@ -3,7 +3,6 @@ package services import ( "context" "github.com/TeaOSLab/EdgeAPI/internal/db/models" - rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" ) @@ -15,12 +14,12 @@ type MessageService struct { // 计算未读消息数 func (this *MessageService) CountUnreadMessages(ctx context.Context, req *pb.CountUnreadMessagesRequest) (*pb.RPCCountResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } - count, err := models.SharedMessageDAO.CountUnreadMessages() + count, err := models.SharedMessageDAO.CountUnreadMessages(adminId, userId) if err != nil { return nil, err } @@ -30,12 +29,12 @@ func (this *MessageService) CountUnreadMessages(ctx context.Context, req *pb.Cou // 列出单页未读消息 func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.ListUnreadMessagesRequest) (*pb.ListUnreadMessagesResponse, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } - messages, err := models.SharedMessageDAO.ListUnreadMessages(req.Offset, req.Size) + messages, err := models.SharedMessageDAO.ListUnreadMessages(adminId, userId, req.Offset, req.Size) if err != nil { return nil, err } @@ -89,11 +88,20 @@ func (this *MessageService) ListUnreadMessages(ctx context.Context, req *pb.List // 设置消息已读状态 func (this *MessageService) UpdateMessageRead(ctx context.Context, req *pb.UpdateMessageReadRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } + // 校验权限 + exists, err := models.SharedMessageDAO.CheckMessageUser(req.MessageId, adminId, userId) + if err != nil { + return nil, err + } + if !exists { + return nil, this.PermissionError() + } + err = models.SharedMessageDAO.UpdateMessageRead(req.MessageId, req.IsRead) if err != nil { return nil, err @@ -104,14 +112,25 @@ func (this *MessageService) UpdateMessageRead(ctx context.Context, req *pb.Updat // 设置一组消息已读状态 func (this *MessageService) UpdateMessagesRead(ctx context.Context, req *pb.UpdateMessagesReadRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } - err = models.SharedMessageDAO.UpdateMessagesRead(req.MessageIds, req.IsRead) - if err != nil { - return nil, err + // 校验权限 + for _, messageId := range req.MessageIds { + exists, err := models.SharedMessageDAO.CheckMessageUser(messageId, adminId, userId) + if err != nil { + return nil, err + } + if !exists { + return nil, this.PermissionError() + } + + err = models.SharedMessageDAO.UpdateMessageRead(messageId, req.IsRead) + if err != nil { + return nil, err + } } return this.Success() } @@ -119,12 +138,13 @@ func (this *MessageService) UpdateMessagesRead(ctx context.Context, req *pb.Upda // 设置所有消息为已读 func (this *MessageService) UpdateAllMessagesRead(ctx context.Context, req *pb.UpdateAllMessagesReadRequest) (*pb.RPCSuccess, error) { // 校验请求 - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + // 校验请求 + adminId, userId, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } - err = models.SharedMessageDAO.UpdateAllMessagesRead() + err = models.SharedMessageDAO.UpdateAllMessagesRead(adminId, userId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_user.go b/internal/rpc/services/service_user.go index 487ac08d..85376a6d 100644 --- a/internal/rpc/services/service_user.go +++ b/internal/rpc/services/service_user.go @@ -101,7 +101,7 @@ func (this *UserService) ListEnabledUsers(ctx context.Context, req *pb.ListEnabl // 查询单个用户信息 func (this *UserService) FindEnabledUser(ctx context.Context, req *pb.FindEnabledUserRequest) (*pb.FindEnabledUserResponse, error) { - _, err := this.ValidateAdmin(ctx, 0) + _, _, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } @@ -127,17 +127,22 @@ func (this *UserService) FindEnabledUser(ctx context.Context, req *pb.FindEnable } // 检查用户名是否存在 -func (this *UserService) CheckUsername(ctx context.Context, req *pb.CheckUsernameRequest) (*pb.CheckUsernameResponse, error) { - _, err := this.ValidateAdmin(ctx, 0) +func (this *UserService) CheckUserUsername(ctx context.Context, req *pb.CheckUserUsernameRequest) (*pb.CheckUserUsernameResponse, error) { + userType, userId, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin, rpcutils.UserTypeUser) if err != nil { return nil, err } + // 校验权限 + if userType == rpcutils.UserTypeUser && userId != req.UserId { + return nil, this.PermissionError() + } + b, err := models.SharedUserDAO.ExistUser(req.UserId, req.Username) if err != nil { return nil, err } - return &pb.CheckUsernameResponse{Exists: b}, nil + return &pb.CheckUserUsernameResponse{Exists: b}, nil } // 登录 @@ -174,3 +179,39 @@ func (this *UserService) LoginUser(ctx context.Context, req *pb.LoginUserRequest IsOk: true, }, nil } + +// 修改用户基本信息 +func (this *UserService) UpdateUserInfo(ctx context.Context, req *pb.UpdateUserInfoRequest) (*pb.RPCSuccess, error) { + userId, err := this.ValidateUser(ctx) + if err != nil { + return nil, err + } + + if userId != req.UserId { + return nil, this.PermissionError() + } + + err = models.SharedUserDAO.UpdateUserInfo(req.UserId, req.Fullname) + if err != nil { + return nil, err + } + return this.Success() +} + +// 修改用户登录信息 +func (this *UserService) UpdateUserLogin(ctx context.Context, req *pb.UpdateUserLoginRequest) (*pb.RPCSuccess, error) { + userId, err := this.ValidateUser(ctx) + if err != nil { + return nil, err + } + + if userId != req.UserId { + return nil, this.PermissionError() + } + + err = models.SharedUserDAO.UpdateUserLogin(req.UserId, req.Username, req.Password) + if err != nil { + return nil, err + } + return this.Success() +} diff --git a/internal/rpc/services/service_user_bill.go b/internal/rpc/services/service_user_bill.go index 78637a1a..fa7d445d 100644 --- a/internal/rpc/services/service_user_bill.go +++ b/internal/rpc/services/service_user_bill.go @@ -39,7 +39,7 @@ func (this *UserBillService) GenerateAllUserBills(ctx context.Context, req *pb.G // 计算所有账单数量 func (this *UserBillService) CountAllUserBills(ctx context.Context, req *pb.CountAllUserBillsRequest) (*pb.RPCCountResponse, error) { - _, err := this.ValidateAdmin(ctx, 0) + _, _, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } @@ -53,7 +53,7 @@ func (this *UserBillService) CountAllUserBills(ctx context.Context, req *pb.Coun // 列出单页账单 func (this *UserBillService) ListUserBills(ctx context.Context, req *pb.ListUserBillsRequest) (*pb.ListUserBillsResponse, error) { - _, err := this.ValidateAdmin(ctx, 0) + _, _, err := this.ValidateAdminAndUser(ctx, 0) if err != nil { return nil, err } diff --git a/internal/rpc/utils/utils.go b/internal/rpc/utils/utils.go index be2b58ce..2a895dc7 100644 --- a/internal/rpc/utils/utils.go +++ b/internal/rpc/utils/utils.go @@ -93,7 +93,7 @@ func ValidateRequest(ctx context.Context, userTypes ...UserType) (userType UserT t := m.GetString("type") if len(userTypes) > 0 && !lists.ContainsString(userTypes, t) { - return UserTypeNone, 0, errors.New("not supported user type: '" + userType + "'") + return UserTypeNone, 0, errors.New("not supported node type: '" + t + "'") } switch apiToken.Role {