diff --git a/internal/db/models/http_firewall_policy_dao.go b/internal/db/models/http_firewall_policy_dao.go index 4e1a64eb..320c8b3c 100644 --- a/internal/db/models/http_firewall_policy_dao.go +++ b/internal/db/models/http_firewall_policy_dao.go @@ -468,6 +468,19 @@ func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx * return result, nil } +// FindEnabledFirewallPolicyWithIPListId 查找使用某个IPList的策略 +func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyWithIPListId(tx *dbs.Tx, ipListId int64) (*HTTPFirewallPolicy, error) { + one, err := this.Query(tx). + State(HTTPFirewallPolicyStateEnabled). + Where("(JSON_CONTAINS(inbound, :listQuery, '$.whiteListRef') OR JSON_CONTAINS(inbound, :listQuery, '$.blackListRef'))"). + Param("listQuery", maps.Map{"isOn": true, "listId": ipListId}.AsJSON()). + Find() + if err != nil || one == nil { + return nil, err + } + return one.(*HTTPFirewallPolicy), err +} + // FindEnabledFirewallPolicyIdWithRuleGroupId 查找包含某个规则分组的策略ID func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdWithRuleGroupId(tx *dbs.Tx, ruleGroupId int64) (int64, error) { return this.Query(tx). diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index d3fb6871..1b2913d0 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -4,6 +4,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" @@ -342,6 +343,35 @@ func (this *IPItemDAO) ExistsEnabledItem(tx *dbs.Tx, itemId int64) (bool, error) Exist() } +// CountAllEnabledIPItems 计算数量 +func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string) (int64, error) { + var query = this.Query(tx) + if len(ip) > 0 { + query.Attr("ipFrom", ip) + } + return query. + Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))"). + State(IPItemStateEnabled). + Count() +} + +// ListAllEnabledIPItems 搜索所有IP +func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, offset int64, size int64) (result []*IPItem, err error) { + var query = this.Query(tx) + if len(ip) > 0 { + query.Attr("ipFrom", ip) + } + _, err = query. + Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))"). + State(IPItemStateEnabled). + DescPk(). + Offset(offset). + Size(size). + Slice(&result). + FindAll() + return +} + // NotifyUpdate 通知更新 func (this *IPItemDAO) NotifyUpdate(tx *dbs.Tx, itemId int64) error { // 获取ListId diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go index e5c00e6c..fb1c33bd 100644 --- a/internal/db/models/ip_list_dao.go +++ b/internal/db/models/ip_list_dao.go @@ -2,6 +2,7 @@ package models import ( "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs" @@ -18,6 +19,15 @@ const ( ) var listTypeCacheMap = map[int64]*IPList{} // listId => *IPList +var DefaultGlobalIPList = &IPList{ + Id: uint32(firewallconfigs.GlobalListId), + Name: "全局封锁名单", + IsPublic: 1, + IsGlobal: 1, + Type: "black", + State: IPListStateEnabled, + IsOn: 1, +} type IPListDAO dbs.DAO @@ -59,7 +69,19 @@ func (this *IPListDAO) DisableIPList(tx *dbs.Tx, id int64) error { } // FindEnabledIPList 查找启用中的条目 -func (this *IPListDAO) FindEnabledIPList(tx *dbs.Tx, id int64) (*IPList, error) { +func (this *IPListDAO) FindEnabledIPList(tx *dbs.Tx, id int64, cacheMap *utils.CacheMap) (*IPList, error) { + if id == firewallconfigs.GlobalListId { + return DefaultGlobalIPList, nil + } + + var cacheKey = this.Table + ":FindEnabledIPList:" + types.String(id) + if cacheMap != nil { + cache, ok := cacheMap.Get(cacheKey) + if ok { + return cache.(*IPList), nil + } + } + result, err := this.Query(tx). Pk(id). Attr("state", IPListStateEnabled). @@ -67,6 +89,11 @@ func (this *IPListDAO) FindEnabledIPList(tx *dbs.Tx, id int64) (*IPList, error) if result == nil { return nil, err } + + if cacheMap != nil { + cacheMap.Put(cacheKey, result) + } + return result.(*IPList), err } @@ -82,14 +109,7 @@ func (this *IPListDAO) FindIPListName(tx *dbs.Tx, id int64) (string, error) { func (this *IPListDAO) FindIPListCacheable(tx *dbs.Tx, listId int64) (*IPList, error) { // 全局黑名单 if listId == firewallconfigs.GlobalListId { - return &IPList{ - Id: uint32(listId), - IsPublic: 1, - IsGlobal: 1, - Type: "black", - State: IPListStateEnabled, - IsOn: 1, - }, nil + return DefaultGlobalIPList, nil } // 检查缓存 diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index 305d4b1e..520f3c89 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -387,7 +387,7 @@ func (this *IPItemService) CheckIPItemStatus(ctx context.Context, req *pb.CheckI tx := this.NullTx() // 名单类型 - list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId) + list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil) if err != nil { return nil, err } @@ -446,3 +446,181 @@ func (this *IPItemService) ExistsEnabledIPItem(ctx context.Context, req *pb.Exis } return &pb.ExistsEnabledIPItemResponse{Exists: b}, nil } + +// CountAllEnabledIPItems 计算所有IP数量 +func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.CountAllEnabledIPItemsRequest) (*pb.RPCCountResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Ip) + if err != nil { + return nil, err + } + return this.SuccessCount(count) +} + +// ListAllEnabledIPItems 搜索IP +func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.ListAllEnabledIPItemsRequest) (*pb.ListAllEnabledIPItemsResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + var results = []*pb.ListAllEnabledIPItemsResponse_Result{} + var tx = this.NullTx() + items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Ip, req.Offset, req.Size) + if err != nil { + return nil, err + } + + var cacheMap = utils.NewCacheMap() + for _, item := range items { + // server + var pbSourceServer *pb.Server + if item.SourceServerId > 0 { + serverName, err := models.SharedServerDAO.FindEnabledServerName(tx, int64(item.SourceServerId)) + if err != nil { + return nil, err + } + pbSourceServer = &pb.Server{ + Id: int64(item.SourceServerId), + Name: serverName, + } + } + + // WAF策略 + var pbSourcePolicy *pb.HTTPFirewallPolicy + if item.SourceHTTPFirewallPolicyId > 0 { + policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledHTTPFirewallPolicyBasic(tx, int64(item.SourceHTTPFirewallPolicyId)) + if err != nil { + return nil, err + } + if policy != nil { + pbSourcePolicy = &pb.HTTPFirewallPolicy{ + Id: int64(item.SourceHTTPFirewallPolicyId), + Name: policy.Name, + ServerId: int64(policy.ServerId), + } + } + } + + // WAF分组 + var pbSourceGroup *pb.HTTPFirewallRuleGroup + if item.SourceHTTPFirewallRuleGroupId > 0 { + groupName, err := models.SharedHTTPFirewallRuleGroupDAO.FindHTTPFirewallRuleGroupName(tx, int64(item.SourceHTTPFirewallRuleGroupId)) + if err != nil { + return nil, err + } + pbSourceGroup = &pb.HTTPFirewallRuleGroup{ + Id: int64(item.SourceHTTPFirewallRuleGroupId), + Name: groupName, + } + } + + // WAF规则集 + var pbSourceSet *pb.HTTPFirewallRuleSet + if item.SourceHTTPFirewallRuleSetId > 0 { + setName, err := models.SharedHTTPFirewallRuleSetDAO.FindHTTPFirewallRuleSetName(tx, int64(item.SourceHTTPFirewallRuleSetId)) + if err != nil { + return nil, err + } + pbSourceSet = &pb.HTTPFirewallRuleSet{ + Id: int64(item.SourceHTTPFirewallRuleSetId), + Name: setName, + } + } + + var pbItem = &pb.IPItem{ + Id: int64(item.Id), + IpFrom: item.IpFrom, + IpTo: item.IpTo, + Version: int64(item.Version), + CreatedAt: int64(item.CreatedAt), + ExpiredAt: int64(item.ExpiredAt), + Reason: item.Reason, + Type: item.Type, + EventLevel: item.EventLevel, + NodeId: int64(item.NodeId), + ServerId: int64(item.ServerId), + SourceNodeId: int64(item.SourceNodeId), + SourceServerId: int64(item.SourceServerId), + SourceHTTPFirewallPolicyId: int64(item.SourceHTTPFirewallPolicyId), + SourceHTTPFirewallRuleGroupId: int64(item.SourceHTTPFirewallRuleGroupId), + SourceHTTPFirewallRuleSetId: int64(item.SourceHTTPFirewallRuleSetId), + SourceServer: pbSourceServer, + SourceHTTPFirewallPolicy: pbSourcePolicy, + SourceHTTPFirewallRuleGroup: pbSourceGroup, + SourceHTTPFirewallRuleSet: pbSourceSet, + } + + // 所属名单 + list, err := models.SharedIPListDAO.FindEnabledIPList(tx, int64(item.ListId), cacheMap) + if err != nil { + return nil, err + } + if list == nil { + err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id)) + if err != nil { + return nil, err + } + continue + } + var pbList = &pb.IPList{ + Id: int64(list.Id), + Name: list.Name, + Type: list.Type, + IsPublic: list.IsPublic == 1, + IsGlobal: list.IsGlobal == 1, + } + + // 所属服务(注意同SourceServer不同) + var pbFirewallServer *pb.Server + + // 所属策略(注意同SourceHTTPFirewallPolicy不同) + var pbFirewallPolicy *pb.HTTPFirewallPolicy + if list.IsPublic == 0 { + policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyWithIPListId(tx, int64(list.Id)) + if err != nil { + return nil, err + } + if policy == nil { + err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id)) + if err != nil { + return nil, err + } + continue + } + + pbFirewallPolicy = &pb.HTTPFirewallPolicy{ + Id: int64(policy.Id), + Name: policy.Name, + } + + if policy.ServerId > 0 { + serverName, err := models.SharedServerDAO.FindEnabledServerName(tx, int64(policy.ServerId)) + if err != nil { + return nil, err + } + if len(serverName) == 0 { + serverName = "[已删除]" + } + pbFirewallServer = &pb.Server{ + Id: int64(policy.ServerId), + Name: serverName, + } + } + } + + results = append(results, &pb.ListAllEnabledIPItemsResponse_Result{ + IpList: pbList, + IpItem: pbItem, + Server: pbFirewallServer, + HttpFirewallPolicy: pbFirewallPolicy, + }) + } + + return &pb.ListAllEnabledIPItemsResponse{Results: results}, nil +} diff --git a/internal/rpc/services/service_ip_list.go b/internal/rpc/services/service_ip_list.go index deb11e75..9cbbd735 100644 --- a/internal/rpc/services/service_ip_list.go +++ b/internal/rpc/services/service_ip_list.go @@ -3,6 +3,7 @@ package services import ( "context" "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/iwind/TeaGo/lists" ) @@ -56,7 +57,7 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn tx := this.NullTx() - list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId) + list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil) if err != nil { return nil, err } @@ -171,12 +172,13 @@ func (this *IPListService) FindEnabledIPListContainsIP(ctx context.Context, req var pbLists = []*pb.IPList{} var listIds = []int64{} + var cacheMap = utils.NewCacheMap() for _, item := range items { if lists.ContainsInt64(listIds, int64(item.ListId)) { continue } - list, err := models.SharedIPListDAO.FindEnabledIPList(tx, int64(item.ListId)) + list, err := models.SharedIPListDAO.FindEnabledIPList(tx, int64(item.ListId), cacheMap) if err != nil { return nil, err }