diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index 36cecec4..4706ea68 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -75,13 +75,18 @@ func (this *IPItemDAO) EnableIPItem(tx *dbs.Tx, id int64) error { } // DisableIPItem 禁用条目 -func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error { +func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64, sourceUserId int64) error { version, err := SharedIPListDAO.IncreaseVersion(tx) if err != nil { return err } - _, err = this.Query(tx). + var query = this.Query(tx) + if sourceUserId > 0 { + query.Attr("sourceUserId", sourceUserId) + } + + _, err = query. Pk(id). Set("state", IPItemStateDisabled). Set("version", version). @@ -94,7 +99,7 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error { } // DisableIPItemsWithIP 禁用某个IP相关条目 -func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo string, userId int64, listId int64) error { +func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo string, sourceUserId int64, listId int64) error { if len(ipFrom) == 0 { return errors.New("invalid 'ipFrom'") } @@ -106,16 +111,13 @@ func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo stri State(IPItemStateEnabled) if listId > 0 { - if userId > 0 { - err := SharedIPListDAO.CheckUserIPList(tx, userId, listId) - if err != nil { - return err - } - } - query.Attr("listId", listId) } + if sourceUserId > 0 { + query.Attr("sourceUserId", sourceUserId) + } + ones, err := query.FindAll() if err != nil { return err @@ -125,14 +127,6 @@ func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo stri for _, one := range ones { var item = one.(*IPItem) var itemId = int64(item.Id) - var itemListId = int64(item.ListId) - if itemListId != listId && userId > 0 { - err = SharedIPListDAO.CheckUserIPList(tx, userId, itemListId) - if err != nil { - // ignore error - continue - } - } itemIds = append(itemIds, itemId) } @@ -366,10 +360,13 @@ func (this *IPItemDAO) UpdateIPItem(tx *dbs.Tx, itemId int64, ipFrom string, ipT } // CountIPItemsWithListId 计算IP数量 -func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64, keyword string, ipFrom string, ipTo string, eventLevel string) (int64, error) { +func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64, sourceUserId int64, keyword string, ipFrom string, ipTo string, eventLevel string) (int64, error) { var query = this.Query(tx). State(IPItemStateEnabled). Attr("listId", listId) + if sourceUserId > 0 { + query.Attr("sourceUserId", sourceUserId) + } if len(keyword) > 0 { query.Where("(ipFrom LIKE :keyword OR ipTo LIKE :keyword)"). Param("keyword", dbutils.QuoteLike(keyword)) @@ -387,10 +384,13 @@ func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64, keyword } // ListIPItemsWithListId 查找IP列表 -func (this *IPItemDAO) ListIPItemsWithListId(tx *dbs.Tx, listId int64, keyword string, ipFrom string, ipTo string, eventLevel string, offset int64, size int64) (result []*IPItem, err error) { +func (this *IPItemDAO) ListIPItemsWithListId(tx *dbs.Tx, listId int64, sourceUserId int64, keyword string, ipFrom string, ipTo string, eventLevel string, offset int64, size int64) (result []*IPItem, err error) { var query = this.Query(tx). State(IPItemStateEnabled). Attr("listId", listId) + if sourceUserId > 0 { + query.Attr("sourceUserId", sourceUserId) + } if len(keyword) > 0 { query.Where("(ipFrom LIKE :keyword OR ipTo LIKE :keyword)"). Param("keyword", dbutils.QuoteLike(keyword)) @@ -479,8 +479,12 @@ func (this *IPItemDAO) ExistsEnabledItem(tx *dbs.Tx, itemId int64) (bool, error) } // CountAllEnabledIPItems 计算数量 -func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string) (int64, error) { +func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string) (int64, error) { var query = this.Query(tx) + if sourceUserId > 0 { + query.Attr("sourceUserId", sourceUserId) + query.UseIndex("sourceUserId") + } if len(keyword) > 0 { query.Like("ipFrom", dbutils.QuoteLike(keyword)) } @@ -512,8 +516,12 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, keyword string, ip str } // ListAllEnabledIPItems 搜索所有IP -func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, offset int64, size int64) (result []*IPItem, err error) { +func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, offset int64, size int64) (result []*IPItem, err error) { var query = this.Query(tx) + if sourceUserId > 0 { + query.Attr("sourceUserId", sourceUserId) + query.UseIndex("sourceUserId") + } if len(keyword) > 0 { query.Like("ipFrom", dbutils.QuoteLike(keyword)) } @@ -549,11 +557,17 @@ func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, keyword string, ip stri } // UpdateItemsRead 设置所有未已读 -func (this *IPItemDAO) UpdateItemsRead(tx *dbs.Tx) error { - return this.Query(tx). +func (this *IPItemDAO) UpdateItemsRead(tx *dbs.Tx, sourceUserId int64) error { + var query = this.Query(tx). Attr("isRead", 0). - Set("isRead", 1). - UpdateQuickly() + Set("isRead", 1) + + if sourceUserId > 0 { + query.Attr("sourceUserId", sourceUserId) + query.UseIndex("sourceUserId") + } + + return query.UpdateQuickly() } // CleanExpiredIPItems 清除过期数据 diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index afa4ec52..e4a7a6bd 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -189,19 +189,7 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte // 如果是使用IPItemId删除 if req.IpItemId > 0 { - if userId > 0 { - listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId) - if err != nil { - return nil, err - } - - err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId) - if err != nil { - return nil, err - } - } - - err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId) + err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId, userId) if err != nil { return nil, err } @@ -210,7 +198,7 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte // 如果是使用ipFrom+ipTo删除 if len(req.IpFrom) > 0 { // 检查IP列表 - if req.IpListId > 0 && userId > 0 { + if req.IpListId > 0 && userId > 0 && req.IpListId != firewallconfigs.GlobalListId { err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) if err != nil { return nil, err @@ -228,14 +216,14 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte // DeleteIPItems 批量删除IP func (this *IPItemService) DeleteIPItems(ctx context.Context, req *pb.DeleteIPItemsRequest) (*pb.RPCSuccess, error) { - _, err := this.ValidateAdmin(ctx) + _, userId, err := this.ValidateAdminAndUser(ctx, true) if err != nil { return nil, err } var tx = this.NullTx() for _, itemId := range req.IpItemIds { - err = models.SharedIPItemDAO.DisableIPItem(tx, itemId) + err = models.SharedIPItemDAO.DisableIPItem(tx, itemId, userId) if err != nil { return nil, err } @@ -254,13 +242,16 @@ func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.C var tx = this.NullTx() if userId > 0 { - err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) - if err != nil { - return nil, err + // 检查用户所属名单 + if req.IpListId != firewallconfigs.GlobalListId { + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) + if err != nil { + return nil, err + } } } - count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel) + count, err := models.SharedIPItemDAO.CountIPItemsWithListId(tx, req.IpListId, userId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel) if err != nil { return nil, err } @@ -278,13 +269,16 @@ func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.Li var tx = this.NullTx() if userId > 0 { - err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) - if err != nil { - return nil, err + // 检查用户所属名单 + if req.IpListId != firewallconfigs.GlobalListId { + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) + if err != nil { + return nil, err + } } } - items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel, req.Offset, req.Size) + items, err := models.SharedIPItemDAO.ListIPItemsWithListId(tx, req.IpListId, userId, req.Keyword, req.IpFrom, req.IpTo, req.EventLevel, req.Offset, req.Size) if err != nil { return nil, err } @@ -564,17 +558,21 @@ func (this *IPItemService) ExistsEnabledIPItem(ctx context.Context, req *pb.Exis // CountAllEnabledIPItems 计算所有IP数量 func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.CountAllEnabledIPItemsRequest) (*pb.RPCCountResponse, error) { - _, err := this.ValidateAdmin(ctx) + adminId, userId, err := this.ValidateAdminAndUser(ctx, true) if err != nil { return nil, err } + if adminId > 0 { + userId = req.UserId + } + var tx = this.NullTx() var listId int64 = 0 if req.GlobalOnly { listId = firewallconfigs.GlobalListId } - count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType) + count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType) if err != nil { return nil, err } @@ -583,18 +581,22 @@ func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.C // ListAllEnabledIPItems 搜索IP func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.ListAllEnabledIPItemsRequest) (*pb.ListAllEnabledIPItemsResponse, error) { - _, err := this.ValidateAdmin(ctx) + adminId, userId, err := this.ValidateAdminAndUser(ctx, true) if err != nil { return nil, err } + if adminId > 0 { + userId = req.UserId + } + var results = []*pb.ListAllEnabledIPItemsResponse_Result{} var tx = this.NullTx() var listId int64 = 0 if req.GlobalOnly { listId = firewallconfigs.GlobalListId } - items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size) + items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size) if err != nil { return nil, err } @@ -703,7 +705,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li return nil, err } if list == nil { - err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id)) + err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id), 0) if err != nil { return nil, err } @@ -728,7 +730,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li return nil, err } if policy == nil { - err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id)) + err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id), 0) if err != nil { return nil, err } @@ -768,13 +770,13 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li // UpdateIPItemsRead 设置所有为已读 func (this *IPItemService) UpdateIPItemsRead(ctx context.Context, req *pb.UpdateIPItemsReadRequest) (*pb.RPCSuccess, error) { - _, err := this.ValidateAdmin(ctx) + _, userId, err := this.ValidateAdminAndUser(ctx, true) if err != nil { return nil, err } var tx = this.NullTx() - err = models.SharedIPItemDAO.UpdateItemsRead(tx) + err = models.SharedIPItemDAO.UpdateItemsRead(tx, userId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ip_list.go b/internal/rpc/services/service_ip_list.go index 83f30944..cf5c0eab 100644 --- a/internal/rpc/services/service_ip_list.go +++ b/internal/rpc/services/service_ip_list.go @@ -5,6 +5,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/iwind/TeaGo/lists" ) @@ -68,9 +69,12 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn var tx = this.NullTx() if userId > 0 { - err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) - if err != nil { - return nil, err + // 检查用户所属名单 + if req.IpListId != firewallconfigs.GlobalListId { + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) + if err != nil { + return nil, err + } } }