From 8059ff4e650c297d22f9904ca38794962bd931c0 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Thu, 21 Apr 2022 15:09:18 +0800 Subject: [PATCH] =?UTF-8?q?IP=E5=88=97=E8=A1=A8=E5=A2=9E=E5=8A=A0=E5=90=8D?= =?UTF-8?q?=E5=8D=95=E7=B1=BB=E5=9E=8B=E7=AD=9B=E9=80=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/db/models/ip_item_dao.go | 19 +++++++++++++++---- internal/rpc/services/service_ip_item.go | 4 ++-- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index d956dd96..a4261a36 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -378,7 +378,7 @@ func (this *IPItemDAO) ExistsEnabledItem(tx *dbs.Tx, itemId int64) (bool, error) } // CountAllEnabledIPItems 计算数量 -func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64, unread bool, eventLevel string) (int64, error) { +func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64, unread bool, eventLevel string, listType string) (int64, error) { var query = this.Query(tx) if len(ip) > 0 { query.Attr("ipFrom", ip) @@ -386,7 +386,12 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string, listId int6 if listId > 0 { query.Attr("listId", listId) } else { - query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + if len(listType) > 0 { + query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1 AND type=:listType))") + query.Param("listType", listType) + } else { + query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + } } if unread { query.Attr("isRead", 0) @@ -394,6 +399,7 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string, listId int6 if len(eventLevel) > 0 { query.Attr("eventLevel", eventLevel) } + return query. State(IPItemStateEnabled). Where("(expiredAt=0 OR expiredAt>:expiredAt)"). @@ -402,7 +408,7 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string, listId int6 } // ListAllEnabledIPItems 搜索所有IP -func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64, unread bool, eventLevel string, offset int64, size int64) (result []*IPItem, err error) { +func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64, unread bool, eventLevel string, listType string, offset int64, size int64) (result []*IPItem, err error) { var query = this.Query(tx) if len(ip) > 0 { query.Attr("ipFrom", ip) @@ -410,7 +416,12 @@ func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64 if listId > 0 { query.Attr("listId", listId) } else { - query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + if len(listType) > 0 { + query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1 AND type=:listType))") + query.Param("listType", listType) + } else { + query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + } } if unread { query.Attr("isRead", 0) diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index 345c8e7c..a43e8657 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -484,7 +484,7 @@ func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.C if req.GlobalOnly { listId = firewallconfigs.GlobalListId } - count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Ip, listId, req.Unread, req.EventLevel) + count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Ip, listId, req.Unread, req.EventLevel, req.ListType) if err != nil { return nil, err } @@ -504,7 +504,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li if req.GlobalOnly { listId = firewallconfigs.GlobalListId } - items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Ip, listId, req.Unread, req.EventLevel, req.Offset, req.Size) + items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size) if err != nil { return nil, err }