diff --git a/internal/db/models/http_firewall_policy_dao.go b/internal/db/models/http_firewall_policy_dao.go index 149db490..74e41026 100644 --- a/internal/db/models/http_firewall_policy_dao.go +++ b/internal/db/models/http_firewall_policy_dao.go @@ -645,7 +645,7 @@ func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx * ones, err := this.Query(tx). ResultPk(). State(HTTPFirewallPolicyStateEnabled). - Where("(JSON_CONTAINS(inbound, :listQuery, '$.whiteListRef') OR JSON_CONTAINS(inbound, :listQuery, '$.blackListRef') OR JSON_CONTAINS(inbound, :listQuery, '$.publicWhiteListRefs') OR JSON_CONTAINS(inbound, :listQuery, '$.publicBlackListRefs'))"). + Where("(JSON_CONTAINS(inbound, :listQuery, '$.whiteListRef') OR JSON_CONTAINS(inbound, :listQuery, '$.blackListRef') OR JSON_CONTAINS(inbound, :listQuery, '$.publicWhiteListRefs') OR JSON_CONTAINS(inbound, :listQuery, '$.publicBlackListRefs') OR JSON_CONTAINS(inbound, :listQuery, '$.publicGreyListRefs'))"). Param("listQuery", maps.Map{"isOn": true, "listId": ipListId}.AsJSON()). FindAll() if err != nil { @@ -663,7 +663,7 @@ func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx * func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyWithIPListId(tx *dbs.Tx, ipListId int64) (*HTTPFirewallPolicy, error) { one, err := this.Query(tx). State(HTTPFirewallPolicyStateEnabled). - Where("(JSON_CONTAINS(inbound, :listQuery, '$.whiteListRef') OR JSON_CONTAINS(inbound, :listQuery, '$.blackListRef'))"). + Where("(JSON_CONTAINS(inbound, :listQuery, '$.whiteListRef') OR JSON_CONTAINS(inbound, :listQuery, '$.blackListRef') OR JSON_CONTAINS(inbound, :listQuery, '$.greyListRef'))"). Param("listQuery", maps.Map{"isOn": true, "listId": ipListId}.AsJSON()). Find() if err != nil || one == nil { diff --git a/internal/db/models/http_firewall_rule_set_dao.go b/internal/db/models/http_firewall_rule_set_dao.go index 99170014..08636e6d 100644 --- a/internal/db/models/http_firewall_rule_set_dao.go +++ b/internal/db/models/http_firewall_rule_set_dao.go @@ -135,7 +135,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int var ipListId = actionConfig.Options.GetInt64("ipListId") if ipListId <= 0 { // default list id if forNode { - actionConfig.Options["ipListId"] = firewallconfigs.GlobalListId + actionConfig.Options["ipListId"] = firewallconfigs.FindGlobalListIdWithType(actionConfig.Options.GetString("type")) } actionConfig.Options["ipListIsDeleted"] = false } else { diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index 648d4e46..b534d963 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -76,7 +76,7 @@ func (this *IPItemDAO) EnableIPItem(tx *dbs.Tx, id int64) error { } // DisableIPItem 禁用条目 -func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64, sourceUserId int64) error { +func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, itemId int64, sourceUserId int64) error { version, err := SharedIPListDAO.IncreaseVersion(tx) if err != nil { return err @@ -91,7 +91,7 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64, sourceUserId int64) e } _, err = query. - Pk(id). + Pk(itemId). Set("state", IPItemStateDisabled). Set("version", version). Update() @@ -99,7 +99,7 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64, sourceUserId int64) e if err != nil { return err } - return this.NotifyUpdate(tx, id) + return this.NotifyUpdate(tx, itemId) } // DisableIPItemsWithIP 禁用某个IP相关条目 @@ -390,7 +390,7 @@ func (this *IPItemDAO) CreateIPItem(tx *dbs.Tx, op.SourceUserId = userId } - var autoAdded = listId == firewallconfigs.GlobalListId || sourceNodeId > 0 || sourceServerId > 0 || sourceHTTPFirewallPolicyId > 0 + var autoAdded = firewallconfigs.IsGlobalListId(listId) || sourceNodeId > 0 || sourceServerId > 0 || sourceHTTPFirewallPolicyId > 0 if autoAdded { op.IsRead = 0 } @@ -477,7 +477,7 @@ func (this *IPItemDAO) CountIPItemsWithListId(tx *dbs.Tx, listId int64, sourceUs State(IPItemStateEnabled). Attr("listId", listId) if sourceUserId > 0 { - if listId <= 0 || listId == firewallconfigs.GlobalListId { + if listId <= 0 || firewallconfigs.IsGlobalListId(listId) { query.Attr("sourceUserId", sourceUserId) } } @@ -503,7 +503,7 @@ func (this *IPItemDAO) ListIPItemsWithListId(tx *dbs.Tx, listId int64, sourceUse State(IPItemStateEnabled). Attr("listId", listId) if sourceUserId > 0 { - if listId <= 0 || listId == firewallconfigs.GlobalListId { + if listId <= 0 || firewallconfigs.IsGlobalListId(listId) { query.Attr("sourceUserId", sourceUserId) } } @@ -600,13 +600,25 @@ func (this *IPItemDAO) ExistsEnabledItem(tx *dbs.Tx, itemId int64) (bool, error) } // CountAllEnabledIPItems 计算数量 -func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string) (int64, error) { +func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, isGlobal bool) (int64, error) { var query = this.Query(tx) + var globalListIdStrings = strings.Join(firewallconfigs.FindGlobalListIdStrings(), ",") + if len(listType) > 0 { + var globalListId = firewallconfigs.FindGlobalListIdWithType(listType) + if globalListId > 0 { + globalListIdStrings = types.String(globalListId) + } + } + if sourceUserId > 0 { if listId <= 0 { - query.Where("((listId=" + types.String(firewallconfigs.GlobalListId) + " AND sourceUserId=:sourceUserId) OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE userId=:sourceUserId AND state=1))") + if isGlobal { + query.Where("(listId IN (" + globalListIdStrings + ") AND sourceUserId=:sourceUserId)") + } else { + query.Where("((listId IN (" + globalListIdStrings + ") AND sourceUserId=:sourceUserId) OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE userId=:sourceUserId AND state=1))") + } query.Param("sourceUserId", sourceUserId) - } else if listId == firewallconfigs.GlobalListId { + } else if firewallconfigs.IsGlobalListId(listId) { query.Attr("sourceUserId", sourceUserId) query.UseIndex("sourceUserId") } @@ -631,10 +643,18 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, ke query.Attr("listId", listId) } else { if len(listType) > 0 { - query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1 AND type=:listType))") + if isGlobal { + query.Where("(listId IN (" + globalListIdStrings + "))") + } else { + query.Where("(listId IN (" + globalListIdStrings + ") OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1 AND type=:listType))") + } query.Param("listType", listType) } else { - query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + if isGlobal { + query.Where("(listId IN (" + globalListIdStrings + "))") + } else { + query.Where("(listId IN (" + globalListIdStrings + ") OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + } } } if unread { @@ -652,13 +672,25 @@ func (this *IPItemDAO) CountAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, ke } // ListAllEnabledIPItems 搜索所有IP -func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, offset int64, size int64) (result []*IPItem, err error) { +func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, isGlobal bool, offset int64, size int64) (result []*IPItem, err error) { + var globalListIdStrings = strings.Join(firewallconfigs.FindGlobalListIdStrings(), ",") + if len(listType) > 0 { + var globalListId = firewallconfigs.FindGlobalListIdWithType(listType) + if globalListId > 0 { + globalListIdStrings = types.String(globalListId) + } + } + var query = this.Query(tx) if sourceUserId > 0 { if listId <= 0 { - query.Where("((listId=" + types.String(firewallconfigs.GlobalListId) + " AND sourceUserId=:sourceUserId) OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE userId=:sourceUserId AND state=1))") + if isGlobal { + query.Where("(listId IN (" + globalListIdStrings + ") AND sourceUserId=:sourceUserId)") + } else { + query.Where("((listId IN (" + globalListIdStrings + ") AND sourceUserId=:sourceUserId) OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE userId=:sourceUserId AND state=1))") + } query.Param("sourceUserId", sourceUserId) - } else if listId == firewallconfigs.GlobalListId { + } else if firewallconfigs.IsGlobalListId(listId) { query.Attr("sourceUserId", sourceUserId) query.UseIndex("sourceUserId") } @@ -683,10 +715,18 @@ func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, key query.Attr("listId", listId) } else { if len(listType) > 0 { - query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1 AND type=:listType))") + if isGlobal { + query.Where("(listId IN (" + globalListIdStrings + "))") + } else { + query.Where("(listId IN (" + globalListIdStrings + ") OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1 AND type=:listType))") + } query.Param("listType", listType) } else { - query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + if isGlobal { + query.Where("(listId IN (" + globalListIdStrings + "))") + } else { + query.Where("(listId IN (" + globalListIdStrings + ") OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + } } } if unread { @@ -709,12 +749,20 @@ func (this *IPItemDAO) ListAllEnabledIPItems(tx *dbs.Tx, sourceUserId int64, key // ListAllIPItemIds 搜索所有IP Id列表 func (this *IPItemDAO) ListAllIPItemIds(tx *dbs.Tx, sourceUserId int64, keyword string, ip string, listId int64, unread bool, eventLevel string, listType string, offset int64, size int64) (itemIds []int64, err error) { + var globalListIdStrings = strings.Join(firewallconfigs.FindGlobalListIdStrings(), ",") + if len(listType) > 0 { + var globalListId = firewallconfigs.FindGlobalListIdWithType(listType) + if globalListId > 0 { + globalListIdStrings = types.String(globalListId) + } + } + var query = this.Query(tx) if sourceUserId > 0 { if listId <= 0 { - query.Where("((listId=" + types.String(firewallconfigs.GlobalListId) + " AND sourceUserId=:sourceUserId) OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE userId=:sourceUserId AND state=1))") + query.Where("((listId IN (" + globalListIdStrings + ") AND sourceUserId=:sourceUserId) OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE userId=:sourceUserId AND state=1))") query.Param("sourceUserId", sourceUserId) - } else if listId == firewallconfigs.GlobalListId { + } else if firewallconfigs.IsGlobalListId(listId) { query.Attr("sourceUserId", sourceUserId) query.UseIndex("sourceUserId") } @@ -733,10 +781,10 @@ func (this *IPItemDAO) ListAllIPItemIds(tx *dbs.Tx, sourceUserId int64, keyword query.Attr("listId", listId) } else { if len(listType) > 0 { - query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1 AND type=:listType))") + query.Where("(listId IN (" + globalListIdStrings + ") OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1 AND type=:listType))") query.Param("listType", listType) } else { - query.Where("(listId=" + types.String(firewallconfigs.GlobalListId) + " OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") + query.Where("(listId IN (" + globalListIdStrings + ") OR listId IN (SELECT id FROM " + SharedIPListDAO.Table + " WHERE state=1))") } } if unread { @@ -888,7 +936,7 @@ func (this *IPItemDAO) NotifyUpdate(tx *dbs.Tx, itemId int64) error { return nil } - if listId == firewallconfigs.GlobalListId { + if firewallconfigs.IsGlobalListId(listId) { sourceNodeId, err := this.Query(tx). Pk(itemId). Result("sourceNodeId"). diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go index 990b8546..c97e7207 100644 --- a/internal/db/models/ip_list_dao.go +++ b/internal/db/models/ip_list_dao.go @@ -22,8 +22,8 @@ const ( ) var listTypeCacheMap = map[int64]*IPList{} // listId => *IPList -var DefaultGlobalIPList = &IPList{ - Id: uint32(firewallconfigs.GlobalListId), +var DefaultGlobalBlackIPList = &IPList{ + Id: uint32(firewallconfigs.GlobalBlackListId), Name: "系统黑名单", IsPublic: true, IsGlobal: true, @@ -32,6 +32,26 @@ var DefaultGlobalIPList = &IPList{ IsOn: true, } +var DefaultGlobalWhiteIPList = &IPList{ + Id: uint32(firewallconfigs.GlobalWhiteListId), + Name: "系统白名单", + IsPublic: true, + IsGlobal: true, + Type: "white", + State: IPListStateEnabled, + IsOn: true, +} + +var DefaultGlobalGreyIPList = &IPList{ + Id: uint32(firewallconfigs.GlobalGreyListId), + Name: "系统灰名单", + IsPublic: true, + IsGlobal: true, + Type: "grey", + State: IPListStateEnabled, + IsOn: true, +} + var ipListCodeRegexp = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) type IPListDAO dbs.DAO @@ -79,8 +99,9 @@ func (this *IPListDAO) DisableIPList(tx *dbs.Tx, listId int64) error { // FindEnabledIPList 查找启用中的条目 func (this *IPListDAO) FindEnabledIPList(tx *dbs.Tx, id int64, cacheMap *utils.CacheMap) (*IPList, error) { - if id == firewallconfigs.GlobalListId { - return DefaultGlobalIPList, nil + globalList, ok := this.findGlobalList(id) + if ok { + return globalList, nil } var cacheKey = this.Table + ":FindEnabledIPList:" + types.String(id) @@ -116,9 +137,9 @@ func (this *IPListDAO) FindIPListName(tx *dbs.Tx, id int64) (string, error) { // FindIPListCacheable 获取名单 func (this *IPListDAO) FindIPListCacheable(tx *dbs.Tx, listId int64) (*IPList, error) { - // 全局黑名单 - if listId == firewallconfigs.GlobalListId { - return DefaultGlobalIPList, nil + globalList, ok := this.findGlobalList(listId) + if ok { + return globalList, nil } // 检查缓存 @@ -165,7 +186,21 @@ func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, serverId int64, li if err != nil { return 0, err } - return types.Int64(op.Id), nil + var newListId = types.Int64(op.Id) + + // 防止和全局名单ID冲突 + if lists.ContainsInt64(firewallconfigs.FindGlobalListIds(), newListId) { + // 先删除 + err = this.Query(tx).Pk(newListId).DeleteQuickly() + if err != nil { + return 0, err + } + + // 自动创建下一个 + return this.CreateIPList(tx, userId, serverId, listType, name, code, timeoutJSON, description, isPublic, isGlobal) + } + + return newListId, nil } // UpdateIPList 修改名单 @@ -372,3 +407,17 @@ func (this *IPListDAO) FindIPListIdWithCode(tx *dbs.Tx, listCode string) (int64, func (this *IPListDAO) ValidateIPListCode(code string) bool { return ipListCodeRegexp.MatchString(code) } + +// 查找ID对应的全局名单 +func (this *IPListDAO) findGlobalList(id int64) (list *IPList, ok bool) { + switch id { + case firewallconfigs.GlobalBlackListId: + return DefaultGlobalBlackIPList, true + case firewallconfigs.GlobalWhiteListId: + return DefaultGlobalWhiteIPList, true + case firewallconfigs.GlobalGreyListId: + return DefaultGlobalGreyIPList, true + } + + return +} diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index 0c2ee956..53536462 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -9,6 +9,7 @@ import ( "github.com/TeaOSLab/EdgeCommon/pkg/iputils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs" "net" "time" ) @@ -255,7 +256,7 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte // 使用value删除 if len(req.Value) > 0 { // 检查IP列表 - if req.IpListId > 0 && userId > 0 && req.IpListId != firewallconfigs.GlobalListId { + if req.IpListId > 0 && userId > 0 && !firewallconfigs.IsGlobalListId(req.IpListId) { err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) if err != nil { return nil, err @@ -272,7 +273,7 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte // 如果是使用ipFrom+ipTo删除 if len(req.IpFrom) > 0 { // 检查IP列表 - if req.IpListId > 0 && userId > 0 && req.IpListId != firewallconfigs.GlobalListId { + if req.IpListId > 0 && userId > 0 && !firewallconfigs.IsGlobalListId(req.IpListId) { err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) if err != nil { return nil, err @@ -318,7 +319,7 @@ func (this *IPItemService) CountIPItemsWithListId(ctx context.Context, req *pb.C if userId > 0 { // 检查用户所属名单 - if req.IpListId != firewallconfigs.GlobalListId { + if !firewallconfigs.IsGlobalListId(req.IpListId) { err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) if err != nil { return nil, err @@ -345,7 +346,7 @@ func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.Li if userId > 0 { // 检查用户所属名单 - if req.IpListId != firewallconfigs.GlobalListId { + if !firewallconfigs.IsGlobalListId(req.IpListId) { err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) if err != nil { return nil, err @@ -357,7 +358,7 @@ func (this *IPItemService) ListIPItemsWithListId(ctx context.Context, req *pb.Li if err != nil { return nil, err } - result := []*pb.IPItem{} + var result = []*pb.IPItem{} for _, item := range items { if len(item.Type) == 0 { item.Type = models.IPItemTypeIPv4 @@ -502,12 +503,17 @@ func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb. var tx = this.NullTx() - result := []*pb.IPItem{} + var result = []*pb.IPItem{} items, err := models.SharedIPItemDAO.ListIPItemsAfterVersion(tx, req.Version, req.Size) if err != nil { return nil, err } + + var latestVersion = req.Version + for _, item := range items { + latestVersion = int64(item.Version) + // 是否已过期 if item.ExpiredAt > 0 && int64(item.ExpiredAt) <= time.Now().Unix() { item.State = models.IPItemStateDisabled @@ -526,6 +532,11 @@ func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb. continue } + // 跳过灰名单 + if list.Type == ipconfigs.IPListTypeGrey { + continue + } + // 如果已经删除 if list.State != models.IPListStateEnabled { item.State = models.IPItemStateDisabled @@ -551,7 +562,10 @@ func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb. }) } - return &pb.ListIPItemsAfterVersionResponse{IpItems: result}, nil + return &pb.ListIPItemsAfterVersionResponse{ + IpItems: result, + Version: latestVersion, + }, nil } // CheckIPItemStatus 检查IP状态 @@ -646,11 +660,7 @@ func (this *IPItemService) CountAllEnabledIPItems(ctx context.Context, req *pb.C } var tx = this.NullTx() - var listId int64 = 0 - if req.GlobalOnly { - listId = firewallconfigs.GlobalListId - } - count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType) + count, err := models.SharedIPItemDAO.CountAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, 0, req.Unread, req.EventLevel, req.ListType, req.GlobalOnly) if err != nil { return nil, err } @@ -670,11 +680,7 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li var results = []*pb.ListAllEnabledIPItemsResponse_Result{} var tx = this.NullTx() - var listId int64 = 0 - if req.GlobalOnly { - listId = firewallconfigs.GlobalListId - } - items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size) + items, err := models.SharedIPItemDAO.ListAllEnabledIPItems(tx, userId, req.Keyword, req.Ip, 0, req.Unread, req.EventLevel, req.ListType, req.GlobalOnly, req.Offset, req.Size) if err != nil { return nil, err } @@ -798,10 +804,10 @@ func (this *IPItemService) ListAllEnabledIPItems(ctx context.Context, req *pb.Li IsGlobal: list.IsGlobal, } - // 所属服务(注意同SourceServer不同) + // 所属服务(注意与SourceServer不同) var pbFirewallServer *pb.Server - // 所属策略(注意同SourceHTTPFirewallPolicy不同) + // 所属策略(注意与SourceHTTPFirewallPolicy不同) var pbFirewallPolicy *pb.HTTPFirewallPolicy if !list.IsPublic { policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyWithIPListId(tx, int64(list.Id)) @@ -859,11 +865,7 @@ func (this *IPItemService) ListAllIPItemIds(ctx context.Context, req *pb.ListAll } var tx = this.NullTx() - var listId int64 = 0 - if req.GlobalOnly { - listId = firewallconfigs.GlobalListId - } - itemIds, err := models.SharedIPItemDAO.ListAllIPItemIds(tx, userId, req.Keyword, req.Ip, listId, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size) + itemIds, err := models.SharedIPItemDAO.ListAllIPItemIds(tx, userId, req.Keyword, req.Ip, 0, req.Unread, req.EventLevel, req.ListType, req.Offset, req.Size) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ip_list.go b/internal/rpc/services/service_ip_list.go index 6f7fda8a..d9881510 100644 --- a/internal/rpc/services/service_ip_list.go +++ b/internal/rpc/services/service_ip_list.go @@ -27,11 +27,12 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis var tx = this.NullTx() // 修正默认的代号 - if req.Code == "white" || req.Code == "black" { + if req.Code == "white" || req.Code == "black" || req.Code == "grey" { req.Code = req.Code + "-" + rands.HexString(8) } // 检查用户相关信息 + var sourceUserId = userId if userId > 0 { // 检查网站ID if req.ServerId > 0 { @@ -40,6 +41,11 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis return nil, err } } + } else if req.ServerId > 0 { + sourceUserId, err = models.SharedServerDAO.FindServerUserId(tx, req.ServerId) + if err != nil { + return nil, err + } } // 检查代号 @@ -57,7 +63,7 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis } } - listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.ServerId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal) + listId, err := models.SharedIPListDAO.CreateIPList(tx, sourceUserId, req.ServerId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal) if err != nil { return nil, err } @@ -107,7 +113,7 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn var tx = this.NullTx() if userId > 0 { // 检查用户所属名单 - if req.IpListId != firewallconfigs.GlobalListId { + if !firewallconfigs.IsGlobalListId(req.IpListId) { err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) if err != nil { return nil, err