diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index 3c912e5d..d3fb6871 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -77,17 +77,37 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error { // DisableIPItemsWithListId 禁用某个IP名单内的所有IP func (this *IPItemDAO) DisableIPItemsWithListId(tx *dbs.Tx, listId int64) error { - version, err := SharedIPListDAO.IncreaseVersion(tx) - if err != nil { - return err + for { + ones, err := this.Query(tx). + ResultPk(). + Attr("listId", listId). + State(IPItemStateEnabled). + Limit(1000). + FindAll() + if err != nil { + return err + } + if len(ones) == 0 { + break + } + for _, one := range ones { + var itemId = one.(*IPItem).Id + version, err := SharedIPListDAO.IncreaseVersion(tx) + if err != nil { + return err + } + err = this.Query(tx). + Pk(itemId). + State(IPItemStateEnabled). + Set("version", version). + Set("state", IPItemStateDisabled). + UpdateQuickly() + if err != nil { + return err + } + } } - - return this.Query(tx). - Attr("listId", listId). - State(IPItemStateEnabled). - Set("version", version). - Set("state", IPItemStateDisabled). - UpdateQuickly() + return nil } // FindEnabledIPItem 查找启用中的条目 diff --git a/internal/db/models/ip_item_dao_test.go b/internal/db/models/ip_item_dao_test.go index 5c0a59f2..05cdcc57 100644 --- a/internal/db/models/ip_item_dao_test.go +++ b/internal/db/models/ip_item_dao_test.go @@ -16,3 +16,14 @@ func TestIPItemDAO_NotifyClustersUpdate(t *testing.T) { } t.Log("ok") } + +func TestIPItemDAO_DisableIPItemsWithListId(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + err := SharedIPItemDAO.DisableIPItemsWithListId(tx, 67) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go index 541ecdc7..e5c00e6c 100644 --- a/internal/db/models/ip_list_dao.go +++ b/internal/db/models/ip_list_dao.go @@ -3,6 +3,7 @@ package models import ( "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" @@ -16,7 +17,7 @@ const ( IPListStateDisabled = 0 // 已禁用 ) -var listTypeCacheMap = map[int64]string{} // listId => type +var listTypeCacheMap = map[int64]*IPList{} // listId => *IPList type IPListDAO dbs.DAO @@ -77,38 +78,46 @@ func (this *IPListDAO) FindIPListName(tx *dbs.Tx, id int64) (string, error) { FindStringCol("") } -// FindIPListTypeCacheable 获取名单类型 -func (this *IPListDAO) FindIPListTypeCacheable(tx *dbs.Tx, listId int64) (string, error) { +// FindIPListCacheable 获取名单 +func (this *IPListDAO) FindIPListCacheable(tx *dbs.Tx, listId int64) (*IPList, error) { + // 全局黑名单 + if listId == firewallconfigs.GlobalListId { + return &IPList{ + Id: uint32(listId), + IsPublic: 1, + IsGlobal: 1, + Type: "black", + State: IPListStateEnabled, + IsOn: 1, + }, nil + } + // 检查缓存 SharedCacheLocker.RLock() - listType, ok := listTypeCacheMap[listId] + list, ok := listTypeCacheMap[listId] SharedCacheLocker.RUnlock() if ok { - return listType, nil + return list, nil } - listType, err := this.Query(tx). + one, err := this.Query(tx). Pk(listId). - Result("type"). - FindStringCol("") - if err != nil { - return "", err - } - - if len(listType) == 0 { - return "", nil + Result("isGlobal", "type", "state", "id", "isPublic", "isGlobal"). + Find() + if err != nil || one == nil { + return nil, err } // 保存缓存 SharedCacheLocker.Lock() - listTypeCacheMap[listId] = listType + listTypeCacheMap[listId] = one.(*IPList) SharedCacheLocker.Unlock() - return listType, nil + return one.(*IPList), nil } // CreateIPList 创建名单 -func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool) (int64, error) { +func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool, isGlobal bool) (int64, error) { op := NewIPListOperator() op.IsOn = true op.UserId = userId @@ -121,6 +130,7 @@ func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs } op.Description = description op.IsPublic = isPublic + op.IsGlobal = isGlobal err := this.Save(tx, op) if err != nil { return 0, err diff --git a/internal/db/models/ip_list_model.go b/internal/db/models/ip_list_model.go index c1115789..7d39a986 100644 --- a/internal/db/models/ip_list_model.go +++ b/internal/db/models/ip_list_model.go @@ -15,6 +15,7 @@ type IPList struct { Actions string `field:"actions"` // IP触发的动作 Description string `field:"description"` // 描述 IsPublic uint8 `field:"isPublic"` // 是否公用 + IsGlobal uint8 `field:"isGlobal"` // 是否全局 } type IPListOperator struct { @@ -31,6 +32,7 @@ type IPListOperator struct { Actions interface{} // IP触发的动作 Description interface{} // 描述 IsPublic interface{} // 是否公用 + IsGlobal interface{} // 是否全局 } func NewIPListOperator() *IPListOperator { diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index 016b2cb6..305d4b1e 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -332,10 +332,18 @@ func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb. } // List类型 - listType, err := models.SharedIPListDAO.FindIPListTypeCacheable(tx, int64(item.ListId)) + list, err := models.SharedIPListDAO.FindIPListCacheable(tx, int64(item.ListId)) if err != nil { return nil, err } + if list == nil { + continue + } + + // 如果已经删除 + if list.State != models.IPListStateEnabled { + item.State = models.IPItemStateDisabled + } result = append(result, &pb.IPItem{ Id: int64(item.Id), @@ -349,7 +357,8 @@ func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb. IsDeleted: item.State == 0, Type: item.Type, EventLevel: item.EventLevel, - ListType: listType, + ListType: list.Type, + IsGlobal: list.IsPublic == 1 && list.IsGlobal == 1, NodeId: int64(item.NodeId), ServerId: int64(item.ServerId), }) diff --git a/internal/rpc/services/service_ip_list.go b/internal/rpc/services/service_ip_list.go index 475f238e..deb11e75 100644 --- a/internal/rpc/services/service_ip_list.go +++ b/internal/rpc/services/service_ip_list.go @@ -22,7 +22,7 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis tx := this.NullTx() - listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic) + listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal) if err != nil { return nil, err } @@ -71,6 +71,7 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn Code: list.Code, TimeoutJSON: []byte(list.Timeout), Description: list.Description, + IsGlobal: list.IsGlobal == 1, }}, nil } @@ -112,6 +113,7 @@ func (this *IPListService) ListEnabledIPLists(ctx context.Context, req *pb.ListE TimeoutJSON: []byte(list.Timeout), IsPublic: list.IsPublic == 1, Description: list.Description, + IsGlobal: list.IsGlobal == 1, }) } return &pb.ListEnabledIPListsResponse{IpLists: pbLists}, nil @@ -191,6 +193,7 @@ func (this *IPListService) FindEnabledIPListContainsIP(ctx context.Context, req Name: list.Name, Code: list.Code, IsPublic: list.IsPublic == 1, + IsGlobal: list.IsGlobal == 1, Description: "", })