diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index 1b2913d0..fc64be5f 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -344,25 +344,33 @@ func (this *IPItemDAO) ExistsEnabledItem(tx *dbs.Tx, itemId int64) (bool, error) } // CountAllEnabledIPItems 计算数量 -func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string) (int64, error) { +func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64) (int64, error) { var query = this.Query(tx) if len(ip) > 0 { query.Attr("ipFrom", ip) } + if listId > 0 { + query.Attr("listId", listId) + } else { + query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + } return query. - Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))"). State(IPItemStateEnabled). Count() } // ListAllEnabledIPItems 搜索所有IP -func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, offset int64, size int64) (result []*IPItem, err error) { +func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64, offset int64, size int64) (result []*IPItem, err error) { var query = this.Query(tx) if len(ip) > 0 { query.Attr("ipFrom", ip) } + if listId > 0 { + query.Attr("listId", listId) + } else { + query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + } _, err = query. - Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))"). State(IPItemStateEnabled). DescPk(). Offset(offset). diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index 520f3c89..3a9300cd 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -7,6 +7,7 @@ import ( rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "net" ) @@ -455,7 +456,11 @@ func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.C } var tx = this.NullTx() - count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Ip) + var listId int64 = 0 + if req.GlobalOnly { + listId = firewallconfigs.GlobalListId + } + count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Ip, listId) if err != nil { return nil, err } @@ -471,7 +476,11 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li var results = []*pb.ListAllEnabledIPItemsResponse_Result{} var tx = this.NullTx() - items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Ip, req.Offset, req.Size) + var listId int64 = 0 + if req.GlobalOnly { + listId = firewallconfigs.GlobalListId + } + items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Ip, listId, req.Offset, req.Size) if err != nil { return nil, err }