优化操作IP条目时检查用户ID的相关代码

This commit is contained in:
GoEdgeLab
2023-04-03 10:02:17 +08:00
parent 61cb729305
commit 3a3854d443
3 changed files with 81 additions and 61 deletions

View File

@@ -189,19 +189,7 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
// 如果是使用IPItemId删除
if req.IpItemId > 0 {
if userId > 0 {
listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId)
if err != nil {
return nil, err
}
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId)
if err != nil {
return nil, err
}
}
err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId)
err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId, userId)
if err != nil {
return nil, err
}
@@ -210,7 +198,7 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
// 如果是使用ipFrom+ipTo删除
if len(req.IpFrom) > 0 {
// 检查IP列表
if req.IpListId > 0 && userId > 0 {
if req.IpListId > 0 && userId > 0 && req.IpListId != firewallconfigs.GlobalListId {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
@@ -228,14 +216,14 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
// DeleteIPItems 批量删除IP
func (this *IPItemService) DeleteIPItems(ctx context.Context, req *pb.DeleteIPItemsRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
for _, itemId := range req.IpItemIds {
err = models.SharedIPItemDAO.DisableIPItem(tx, itemId)
err = models.SharedIPItemDAO.DisableIPItem(tx, itemId, userId)
if err != nil {
return nil, err
}
@@ -254,13 +242,16 @@ func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.C
var tx = this.NullTx()
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
// 检查用户所属名单
if req.IpListId != firewallconfigs.GlobalListId {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
}
count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel)
count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId, userId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel)
if err != nil {
return nil, err
}
@@ -278,13 +269,16 @@ func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.Li
var tx = this.NullTx()
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
// 检查用户所属名单
if req.IpListId != firewallconfigs.GlobalListId {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
}
items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel, req.Offset, req.Size)
items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, userId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel, req.Offset, req.Size)
if err != nil {
return nil, err
}
@@ -564,17 +558,21 @@ func (this *IPItemService) ExistsEnabledIPItem(ctx context.Context, req *pb.Exis
// CountAllEnabledIPItems 计算所有IP数量
func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.CountAllEnabledIPItemsRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx)
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if adminId > 0 {
userId = req.UserId
}
var tx = this.NullTx()
var listId int64 = 0
if req.GlobalOnly {
listId = firewallconfigs.GlobalListId
}
count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType)
count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType)
if err != nil {
return nil, err
}
@@ -583,18 +581,22 @@ func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.C
// ListAllEnabledIPItems 搜索IP
func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.ListAllEnabledIPItemsRequest) (*pb.ListAllEnabledIPItemsResponse, error) {
_, err := this.ValidateAdmin(ctx)
adminId, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
if adminId > 0 {
userId = req.UserId
}
var results = []*pb.ListAllEnabledIPItemsResponse_Result{}
var tx = this.NullTx()
var listId int64 = 0
if req.GlobalOnly {
listId = firewallconfigs.GlobalListId
}
items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size)
items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size)
if err != nil {
return nil, err
}
@@ -703,7 +705,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li
return nil, err
}
if list == nil {
err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id))
err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id), 0)
if err != nil {
return nil, err
}
@@ -728,7 +730,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li
return nil, err
}
if policy == nil {
err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id))
err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id), 0)
if err != nil {
return nil, err
}
@@ -768,13 +770,13 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li
// UpdateIPItemsRead 设置所有为已读
func (this *IPItemService) UpdateIPItemsRead(ctx context.Context, req *pb.UpdateIPItemsReadRequest) (*pb.RPCSuccess, error) {
_, err := this.ValidateAdmin(ctx)
_, userId, err := this.ValidateAdminAndUser(ctx, true)
if err != nil {
return nil, err
}
var tx = this.NullTx()
err = models.SharedIPItemDAO.UpdateItemsRead(tx)
err = models.SharedIPItemDAO.UpdateItemsRead(tx, userId)
if err != nil {
return nil, err
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/iwind/TeaGo/lists"
)
@@ -68,9 +69,12 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn
var tx = this.NullTx()
if userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
// 检查用户所属名单
if req.IpListId != firewallconfigs.GlobalListId {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
}