diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index bd3cd296..82fa8537 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -156,7 +156,7 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, return 0, err } - op := NewIPItemOperator() + var op = NewIPItemOperator() op.ListId = listId op.IpFrom = ipFrom op.IpTo = ipTo @@ -179,6 +179,11 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, op.SourceHTTPFirewallRuleGroupId = sourceHTTPFirewallRuleGroupId op.SourceHTTPFirewallRuleSetId = sourceHTTPFirewallRuleSetId + var autoAdded = listId == firewallconfigs.GlobalListId || sourceNodeId > 0 || sourceServerId > 0 || sourceHTTPFirewallPolicyId > 0 + if autoAdded { + op.IsRead = 0 + } + op.State = IPItemStateEnabled err = this.Save(tx, op) if err != nil { @@ -187,7 +192,7 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, itemId := types.Int64(op.Id) // 自动加入名单不需要即时更新,防止数量过多而导致性能问题 - if listId == firewallconfigs.GlobalListId || sourceNodeId > 0 || sourceServerId > 0 || sourceHTTPFirewallPolicyId > 0 { + if autoAdded { return itemId, nil } @@ -379,7 +384,7 @@ func (this *IPItemDAO) ExistsEnabledItem(tx *dbs.Tx, itemId int64) (bool, error) } // CountAllEnabledIPItems 计算数量 -func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64) (int64, error) { +func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64, unread bool) (int64, error) { var query = this.Query(tx) if len(ip) > 0 { query.Attr("ipFrom", ip) @@ -389,6 +394,9 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string, listId int6 } else { query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") } + if unread { + query.Attr("isRead", 0) + } return query. State(IPItemStateEnabled). Where("(expiredAt=0 OR expiredAt>:expiredAt)"). @@ -397,7 +405,7 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string, listId int6 } // ListAllEnabledIPItems 搜索所有IP -func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64, offset int64, size int64) (result []*IPItem, err error) { +func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64, unread bool, offset int64, size int64) (result []*IPItem, err error) { var query = this.Query(tx) if len(ip) > 0 { query.Attr("ipFrom", ip) @@ -407,6 +415,9 @@ func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64 } else { query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") } + if unread { + query.Attr("isRead", 0) + } _, err = query. State(IPItemStateEnabled). Where("(expiredAt=0 OR expiredAt>:expiredAt)"). @@ -419,6 +430,14 @@ func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, listId int64 return } +// UpdateItemsRead 设置所有未已读 +func (this *IPItemDAO) UpdateItemsRead(tx *dbs.Tx) error { + return this.Query(tx). + Attr("isRead", 0). + Set("isRead", 1). + UpdateQuickly() +} + // NotifyUpdate 通知更新 func (this *IPItemDAO) NotifyUpdate(tx *dbs.Tx, itemId int64) error { // 获取ListId diff --git a/internal/db/models/ip_item_model.go b/internal/db/models/ip_item_model.go index a2072498..4295a2e3 100644 --- a/internal/db/models/ip_item_model.go +++ b/internal/db/models/ip_item_model.go @@ -23,6 +23,7 @@ type IPItem struct { SourceHTTPFirewallPolicyId uint32 `field:"sourceHTTPFirewallPolicyId"` // 来源策略ID SourceHTTPFirewallRuleGroupId uint32 `field:"sourceHTTPFirewallRuleGroupId"` // 来源规则集分组ID SourceHTTPFirewallRuleSetId uint32 `field:"sourceHTTPFirewallRuleSetId"` // 来源规则集ID + IsRead uint8 `field:"isRead"` // 是否已读 } type IPItemOperator struct { @@ -47,6 +48,7 @@ type IPItemOperator struct { SourceHTTPFirewallPolicyId interface{} // 来源策略ID SourceHTTPFirewallRuleGroupId interface{} // 来源规则集分组ID SourceHTTPFirewallRuleSetId interface{} // 来源规则集ID + IsRead interface{} // 是否已读 } func NewIPItemOperator() *IPItemOperator { diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index 2b661f66..c10f9784 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -280,6 +280,7 @@ func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.Li SourceHTTPFirewallPolicy: pbSourcePolicy, SourceHTTPFirewallRuleGroup: pbSourceGroup, SourceHTTPFirewallRuleSet: pbSourceSet, + IsRead: item.IsRead == 1, }) } @@ -483,7 +484,7 @@ func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.C if req.GlobalOnly { listId = firewallconfigs.GlobalListId } - count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Ip, listId) + count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Ip, listId, req.Unread) if err != nil { return nil, err } @@ -503,7 +504,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li if req.GlobalOnly { listId = firewallconfigs.GlobalListId } - items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Ip, listId, req.Offset, req.Size) + items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Ip, listId, req.Unread, req.Offset, req.Size) if err != nil { return nil, err } @@ -586,6 +587,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li SourceHTTPFirewallPolicy: pbSourcePolicy, SourceHTTPFirewallRuleGroup: pbSourceGroup, SourceHTTPFirewallRuleSet: pbSourceSet, + IsRead: item.IsRead == 1, } // 所属名单 @@ -656,3 +658,18 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li return &pb.ListAllEnabledIPItemsResponse{Results: results}, nil } + +// UpdateIPItemsRead 设置所有为已读 +func (this *IPItemService) UpdateIPItemsRead(ctx context.Context, req *pb.UpdateIPItemsReadRequest) (*pb.RPCSuccess, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + err = models.SharedIPItemDAO.UpdateItemsRead(tx) + if err != nil { + return nil, err + } + return this.Success() +}