增加全局查看、检索IP功能

This commit is contained in:
刘祥超
2021-11-17 19:51:00 +08:00
parent 4d7e82d0a2
commit f7cbf051bd
5 changed files with 255 additions and 12 deletions

View File

@@ -468,6 +468,19 @@ func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx *
return result, nil
}
// FindEnabledFirewallPolicyWithIPListId 查找使用某个IPList的策略
func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyWithIPListId(tx *dbs.Tx, ipListId int64) (*HTTPFirewallPolicy, error) {
one, err := this.Query(tx).
State(HTTPFirewallPolicyStateEnabled).
Where("(JSON_CONTAINS(inbound, :listQuery, '$.whiteListRef') OR JSON_CONTAINS(inbound, :listQuery, '$.blackListRef'))").
Param("listQuery", maps.Map{"isOn": true, "listId": ipListId}.AsJSON()).
Find()
if err != nil || one == nil {
return nil, err
}
return one.(*HTTPFirewallPolicy), err
}
// FindEnabledFirewallPolicyIdWithRuleGroupId 查找包含某个规则分组的策略ID
func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdWithRuleGroupId(tx *dbs.Tx, ruleGroupId int64) (int64, error) {
return this.Query(tx).

View File

@@ -4,6 +4,7 @@ import (
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
@@ -342,6 +343,35 @@ func (this *IPItemDAO) ExistsEnabledItem(tx *dbs.Tx, itemId int64) (bool, error)
Exist()
}
// CountAllEnabledIPItems 计算数量
func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, ip string) (int64, error) {
var query = this.Query(tx)
if len(ip) > 0 {
query.Attr("ipFrom", ip)
}
return query.
Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))").
State(IPItemStateEnabled).
Count()
}
// ListAllEnabledIPItems 搜索所有IP
func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, ip string, offset int64, size int64) (result []*IPItem, err error) {
var query = this.Query(tx)
if len(ip) > 0 {
query.Attr("ipFrom", ip)
}
_, err = query.
Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))").
State(IPItemStateEnabled).
DescPk().
Offset(offset).
Size(size).
Slice(&result).
FindAll()
return
}
// NotifyUpdate 通知更新
func (this *IPItemDAO) NotifyUpdate(tx *dbs.Tx, itemId int64) error {
// 获取ListId

View File

@@ -2,6 +2,7 @@ package models
import (
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
@@ -18,6 +19,15 @@ const (
)
var listTypeCacheMap = map[int64]*IPList{} // listId => *IPList
var DefaultGlobalIPList = &IPList{
Id: uint32(firewallconfigs.GlobalListId),
Name: "全局封锁名单",
IsPublic: 1,
IsGlobal: 1,
Type: "black",
State: IPListStateEnabled,
IsOn: 1,
}
type IPListDAO dbs.DAO
@@ -59,7 +69,19 @@ func (this *IPListDAO) DisableIPList(tx *dbs.Tx, id int64) error {
}
// FindEnabledIPList 查找启用中的条目
func (this *IPListDAO) FindEnabledIPList(tx *dbs.Tx, id int64) (*IPList, error) {
func (this *IPListDAO) FindEnabledIPList(tx *dbs.Tx, id int64, cacheMap *utils.CacheMap) (*IPList, error) {
if id == firewallconfigs.GlobalListId {
return DefaultGlobalIPList, nil
}
var cacheKey = this.Table + ":FindEnabledIPList:" + types.String(id)
if cacheMap != nil {
cache, ok := cacheMap.Get(cacheKey)
if ok {
return cache.(*IPList), nil
}
}
result, err := this.Query(tx).
Pk(id).
Attr("state", IPListStateEnabled).
@@ -67,6 +89,11 @@ func (this *IPListDAO) FindEnabledIPList(tx *dbs.Tx, id int64) (*IPList, error)
if result == nil {
return nil, err
}
if cacheMap != nil {
cacheMap.Put(cacheKey, result)
}
return result.(*IPList), err
}
@@ -82,14 +109,7 @@ func (this *IPListDAO) FindIPListName(tx *dbs.Tx, id int64) (string, error) {
func (this *IPListDAO) FindIPListCacheable(tx *dbs.Tx, listId int64) (*IPList, error) {
// 全局黑名单
if listId == firewallconfigs.GlobalListId {
return &IPList{
Id: uint32(listId),
IsPublic: 1,
IsGlobal: 1,
Type: "black",
State: IPListStateEnabled,
IsOn: 1,
}, nil
return DefaultGlobalIPList, nil
}
// 检查缓存

View File

@@ -387,7 +387,7 @@ func (this *IPItemService) CheckIPItemStatus(ctx context.Context, req *pb.CheckI
tx := this.NullTx()
// 名单类型
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId)
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil)
if err != nil {
return nil, err
}
@@ -446,3 +446,181 @@ func (this *IPItemService) ExistsEnabledIPItem(ctx context.Context, req *pb.Exis
}
return &pb.ExistsEnabledIPItemResponse{Exists: b}, nil
}
// CountAllEnabledIPItems 计算所有IP数量
func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.CountAllEnabledIPItemsRequest) (*pb.RPCCountResponse, error) {
_, err := this.ValidateAdmin(ctx, 0)
if err != nil {
return nil, err
}
var tx = this.NullTx()
count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, req.Ip)
if err != nil {
return nil, err
}
return this.SuccessCount(count)
}
// ListAllEnabledIPItems 搜索IP
func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.ListAllEnabledIPItemsRequest) (*pb.ListAllEnabledIPItemsResponse, error) {
_, err := this.ValidateAdmin(ctx, 0)
if err != nil {
return nil, err
}
var results = []*pb.ListAllEnabledIPItemsResponse_Result{}
var tx = this.NullTx()
items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, req.Ip, req.Offset, req.Size)
if err != nil {
return nil, err
}
var cacheMap = utils.NewCacheMap()
for _, item := range items {
// server
var pbSourceServer *pb.Server
if item.SourceServerId > 0 {
serverName, err := models.SharedServerDAO.FindEnabledServerName(tx, int64(item.SourceServerId))
if err != nil {
return nil, err
}
pbSourceServer = &pb.Server{
Id: int64(item.SourceServerId),
Name: serverName,
}
}
// WAF策略
var pbSourcePolicy *pb.HTTPFirewallPolicy
if item.SourceHTTPFirewallPolicyId > 0 {
policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledHTTPFirewallPolicyBasic(tx, int64(item.SourceHTTPFirewallPolicyId))
if err != nil {
return nil, err
}
if policy != nil {
pbSourcePolicy = &pb.HTTPFirewallPolicy{
Id: int64(item.SourceHTTPFirewallPolicyId),
Name: policy.Name,
ServerId: int64(policy.ServerId),
}
}
}
// WAF分组
var pbSourceGroup *pb.HTTPFirewallRuleGroup
if item.SourceHTTPFirewallRuleGroupId > 0 {
groupName, err := models.SharedHTTPFirewallRuleGroupDAO.FindHTTPFirewallRuleGroupName(tx, int64(item.SourceHTTPFirewallRuleGroupId))
if err != nil {
return nil, err
}
pbSourceGroup = &pb.HTTPFirewallRuleGroup{
Id: int64(item.SourceHTTPFirewallRuleGroupId),
Name: groupName,
}
}
// WAF规则集
var pbSourceSet *pb.HTTPFirewallRuleSet
if item.SourceHTTPFirewallRuleSetId > 0 {
setName, err := models.SharedHTTPFirewallRuleSetDAO.FindHTTPFirewallRuleSetName(tx, int64(item.SourceHTTPFirewallRuleSetId))
if err != nil {
return nil, err
}
pbSourceSet = &pb.HTTPFirewallRuleSet{
Id: int64(item.SourceHTTPFirewallRuleSetId),
Name: setName,
}
}
var pbItem = &pb.IPItem{
Id: int64(item.Id),
IpFrom: item.IpFrom,
IpTo: item.IpTo,
Version: int64(item.Version),
CreatedAt: int64(item.CreatedAt),
ExpiredAt: int64(item.ExpiredAt),
Reason: item.Reason,
Type: item.Type,
EventLevel: item.EventLevel,
NodeId: int64(item.NodeId),
ServerId: int64(item.ServerId),
SourceNodeId: int64(item.SourceNodeId),
SourceServerId: int64(item.SourceServerId),
SourceHTTPFirewallPolicyId: int64(item.SourceHTTPFirewallPolicyId),
SourceHTTPFirewallRuleGroupId: int64(item.SourceHTTPFirewallRuleGroupId),
SourceHTTPFirewallRuleSetId: int64(item.SourceHTTPFirewallRuleSetId),
SourceServer: pbSourceServer,
SourceHTTPFirewallPolicy: pbSourcePolicy,
SourceHTTPFirewallRuleGroup: pbSourceGroup,
SourceHTTPFirewallRuleSet: pbSourceSet,
}
// 所属名单
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, int64(item.ListId), cacheMap)
if err != nil {
return nil, err
}
if list == nil {
err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id))
if err != nil {
return nil, err
}
continue
}
var pbList = &pb.IPList{
Id: int64(list.Id),
Name: list.Name,
Type: list.Type,
IsPublic: list.IsPublic == 1,
IsGlobal: list.IsGlobal == 1,
}
// 所属服务注意同SourceServer不同
var pbFirewallServer *pb.Server
// 所属策略注意同SourceHTTPFirewallPolicy不同
var pbFirewallPolicy *pb.HTTPFirewallPolicy
if list.IsPublic == 0 {
policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyWithIPListId(tx, int64(list.Id))
if err != nil {
return nil, err
}
if policy == nil {
err = models.SharedIPItemDAO.DisableIPItem(tx, int64(item.Id))
if err != nil {
return nil, err
}
continue
}
pbFirewallPolicy = &pb.HTTPFirewallPolicy{
Id: int64(policy.Id),
Name: policy.Name,
}
if policy.ServerId > 0 {
serverName, err := models.SharedServerDAO.FindEnabledServerName(tx, int64(policy.ServerId))
if err != nil {
return nil, err
}
if len(serverName) == 0 {
serverName = "[已删除]"
}
pbFirewallServer = &pb.Server{
Id: int64(policy.ServerId),
Name: serverName,
}
}
}
results = append(results, &pb.ListAllEnabledIPItemsResponse_Result{
IpList: pbList,
IpItem: pbItem,
Server: pbFirewallServer,
HttpFirewallPolicy: pbFirewallPolicy,
})
}
return &pb.ListAllEnabledIPItemsResponse{Results: results}, nil
}

View File

@@ -3,6 +3,7 @@ package services
import (
"context"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/iwind/TeaGo/lists"
)
@@ -56,7 +57,7 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn
tx := this.NullTx()
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId)
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, req.IpListId, nil)
if err != nil {
return nil, err
}
@@ -171,12 +172,13 @@ func (this *IPListService) FindEnabledIPListContainsIP(ctx context.Context, req
var pbLists = []*pb.IPList{}
var listIds = []int64{}
var cacheMap = utils.NewCacheMap()
for _, item := range items {
if lists.ContainsInt64(listIds, int64(item.ListId)) {
continue
}
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, int64(item.ListId))
list, err := models.SharedIPListDAO.FindEnabledIPList(tx, int64(item.ListId), cacheMap)
if err != nil {
return nil, err
}