IP名单增加是否全局

This commit is contained in:
GoEdgeLab
2021-11-17 16:14:55 +08:00
parent dfb4e6a155
commit 09cfc13c7e
6 changed files with 85 additions and 30 deletions

View File

@@ -77,17 +77,37 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error {
// DisableIPItemsWithListId 禁用某个IP名单内的所有IP // DisableIPItemsWithListId 禁用某个IP名单内的所有IP
func (this *IPItemDAO) DisableIPItemsWithListId(tx *dbs.Tx, listId int64) error { func (this *IPItemDAO) DisableIPItemsWithListId(tx *dbs.Tx, listId int64) error {
version, err := SharedIPListDAO.IncreaseVersion(tx) for {
if err != nil { ones, err := this.Query(tx).
return err ResultPk().
Attr("listId", listId).
State(IPItemStateEnabled).
Limit(1000).
FindAll()
if err != nil {
return err
}
if len(ones) == 0 {
break
}
for _, one := range ones {
var itemId = one.(*IPItem).Id
version, err := SharedIPListDAO.IncreaseVersion(tx)
if err != nil {
return err
}
err = this.Query(tx).
Pk(itemId).
State(IPItemStateEnabled).
Set("version", version).
Set("state", IPItemStateDisabled).
UpdateQuickly()
if err != nil {
return err
}
}
} }
return nil
return this.Query(tx).
Attr("listId", listId).
State(IPItemStateEnabled).
Set("version", version).
Set("state", IPItemStateDisabled).
UpdateQuickly()
} }
// FindEnabledIPItem 查找启用中的条目 // FindEnabledIPItem 查找启用中的条目

View File

@@ -16,3 +16,14 @@ func TestIPItemDAO_NotifyClustersUpdate(t *testing.T) {
} }
t.Log("ok") t.Log("ok")
} }
func TestIPItemDAO_DisableIPItemsWithListId(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
err := SharedIPItemDAO.DisableIPItemsWithListId(tx, 67)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -3,6 +3,7 @@ package models
import ( import (
"github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
@@ -16,7 +17,7 @@ const (
IPListStateDisabled = 0 // 已禁用 IPListStateDisabled = 0 // 已禁用
) )
var listTypeCacheMap = map[int64]string{} // listId => type var listTypeCacheMap = map[int64]*IPList{} // listId => *IPList
type IPListDAO dbs.DAO type IPListDAO dbs.DAO
@@ -77,38 +78,46 @@ func (this *IPListDAO) FindIPListName(tx *dbs.Tx, id int64) (string, error) {
FindStringCol("") FindStringCol("")
} }
// FindIPListTypeCacheable 获取名单类型 // FindIPListCacheable 获取名单
func (this *IPListDAO) FindIPListTypeCacheable(tx *dbs.Tx, listId int64) (string, error) { func (this *IPListDAO) FindIPListCacheable(tx *dbs.Tx, listId int64) (*IPList, error) {
// 全局黑名单
if listId == firewallconfigs.GlobalListId {
return &IPList{
Id: uint32(listId),
IsPublic: 1,
IsGlobal: 1,
Type: "black",
State: IPListStateEnabled,
IsOn: 1,
}, nil
}
// 检查缓存 // 检查缓存
SharedCacheLocker.RLock() SharedCacheLocker.RLock()
listType, ok := listTypeCacheMap[listId] list, ok := listTypeCacheMap[listId]
SharedCacheLocker.RUnlock() SharedCacheLocker.RUnlock()
if ok { if ok {
return listType, nil return list, nil
} }
listType, err := this.Query(tx). one, err := this.Query(tx).
Pk(listId). Pk(listId).
Result("type"). Result("isGlobal", "type", "state", "id", "isPublic", "isGlobal").
FindStringCol("") Find()
if err != nil { if err != nil || one == nil {
return "", err return nil, err
}
if len(listType) == 0 {
return "", nil
} }
// 保存缓存 // 保存缓存
SharedCacheLocker.Lock() SharedCacheLocker.Lock()
listTypeCacheMap[listId] = listType listTypeCacheMap[listId] = one.(*IPList)
SharedCacheLocker.Unlock() SharedCacheLocker.Unlock()
return listType, nil return one.(*IPList), nil
} }
// CreateIPList 创建名单 // CreateIPList 创建名单
func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool) (int64, error) { func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool, isGlobal bool) (int64, error) {
op := NewIPListOperator() op := NewIPListOperator()
op.IsOn = true op.IsOn = true
op.UserId = userId op.UserId = userId
@@ -121,6 +130,7 @@ func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs
} }
op.Description = description op.Description = description
op.IsPublic = isPublic op.IsPublic = isPublic
op.IsGlobal = isGlobal
err := this.Save(tx, op) err := this.Save(tx, op)
if err != nil { if err != nil {
return 0, err return 0, err

View File

@@ -15,6 +15,7 @@ type IPList struct {
Actions string `field:"actions"` // IP触发的动作 Actions string `field:"actions"` // IP触发的动作
Description string `field:"description"` // 描述 Description string `field:"description"` // 描述
IsPublic uint8 `field:"isPublic"` // 是否公用 IsPublic uint8 `field:"isPublic"` // 是否公用
IsGlobal uint8 `field:"isGlobal"` // 是否全局
} }
type IPListOperator struct { type IPListOperator struct {
@@ -31,6 +32,7 @@ type IPListOperator struct {
Actions interface{} // IP触发的动作 Actions interface{} // IP触发的动作
Description interface{} // 描述 Description interface{} // 描述
IsPublic interface{} // 是否公用 IsPublic interface{} // 是否公用
IsGlobal interface{} // 是否全局
} }
func NewIPListOperator() *IPListOperator { func NewIPListOperator() *IPListOperator {

View File

@@ -332,10 +332,18 @@ func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb.
} }
// List类型 // List类型
listType, err := models.SharedIPListDAO.FindIPListTypeCacheable(tx, int64(item.ListId)) list, err := models.SharedIPListDAO.FindIPListCacheable(tx, int64(item.ListId))
if err != nil { if err != nil {
return nil, err return nil, err
} }
if list == nil {
continue
}
// 如果已经删除
if list.State != models.IPListStateEnabled {
item.State = models.IPItemStateDisabled
}
result = append(result, &pb.IPItem{ result = append(result, &pb.IPItem{
Id: int64(item.Id), Id: int64(item.Id),
@@ -349,7 +357,8 @@ func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb.
IsDeleted: item.State == 0, IsDeleted: item.State == 0,
Type: item.Type, Type: item.Type,
EventLevel: item.EventLevel, EventLevel: item.EventLevel,
ListType: listType, ListType: list.Type,
IsGlobal: list.IsPublic == 1 && list.IsGlobal == 1,
NodeId: int64(item.NodeId), NodeId: int64(item.NodeId),
ServerId: int64(item.ServerId), ServerId: int64(item.ServerId),
}) })

View File

@@ -22,7 +22,7 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis
tx := this.NullTx() tx := this.NullTx()
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic) listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -71,6 +71,7 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn
Code: list.Code, Code: list.Code,
TimeoutJSON: []byte(list.Timeout), TimeoutJSON: []byte(list.Timeout),
Description: list.Description, Description: list.Description,
IsGlobal: list.IsGlobal == 1,
}}, nil }}, nil
} }
@@ -112,6 +113,7 @@ func (this *IPListService) ListEnabledIPLists(ctx context.Context, req *pb.ListE
TimeoutJSON: []byte(list.Timeout), TimeoutJSON: []byte(list.Timeout),
IsPublic: list.IsPublic == 1, IsPublic: list.IsPublic == 1,
Description: list.Description, Description: list.Description,
IsGlobal: list.IsGlobal == 1,
}) })
} }
return &pb.ListEnabledIPListsResponse{IpLists: pbLists}, nil return &pb.ListEnabledIPListsResponse{IpLists: pbLists}, nil
@@ -191,6 +193,7 @@ func (this *IPListService) FindEnabledIPListContainsIP(ctx context.Context, req
Name: list.Name, Name: list.Name,
Code: list.Code, Code: list.Code,
IsPublic: list.IsPublic == 1, IsPublic: list.IsPublic == 1,
IsGlobal: list.IsGlobal == 1,
Description: "", Description: "",
}) })