IP名单增加是否全局

This commit is contained in:
GoEdgeLab
2021-11-17 16:14:55 +08:00
parent dfb4e6a155
commit 09cfc13c7e
6 changed files with 85 additions and 30 deletions

View File

@@ -77,17 +77,37 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error {
// DisableIPItemsWithListId 禁用某个IP名单内的所有IP
func (this *IPItemDAO) DisableIPItemsWithListId(tx *dbs.Tx, listId int64) error {
version, err := SharedIPListDAO.IncreaseVersion(tx)
if err != nil {
return err
for {
ones, err := this.Query(tx).
ResultPk().
Attr("listId", listId).
State(IPItemStateEnabled).
Limit(1000).
FindAll()
if err != nil {
return err
}
if len(ones) == 0 {
break
}
for _, one := range ones {
var itemId = one.(*IPItem).Id
version, err := SharedIPListDAO.IncreaseVersion(tx)
if err != nil {
return err
}
err = this.Query(tx).
Pk(itemId).
State(IPItemStateEnabled).
Set("version", version).
Set("state", IPItemStateDisabled).
UpdateQuickly()
if err != nil {
return err
}
}
}
return this.Query(tx).
Attr("listId", listId).
State(IPItemStateEnabled).
Set("version", version).
Set("state", IPItemStateDisabled).
UpdateQuickly()
return nil
}
// FindEnabledIPItem 查找启用中的条目

View File

@@ -16,3 +16,14 @@ func TestIPItemDAO_NotifyClustersUpdate(t *testing.T) {
}
t.Log("ok")
}
func TestIPItemDAO_DisableIPItemsWithListId(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
err := SharedIPItemDAO.DisableIPItemsWithListId(tx, 67)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -3,6 +3,7 @@ package models
import (
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
@@ -16,7 +17,7 @@ const (
IPListStateDisabled = 0 // 已禁用
)
var listTypeCacheMap = map[int64]string{} // listId => type
var listTypeCacheMap = map[int64]*IPList{} // listId => *IPList
type IPListDAO dbs.DAO
@@ -77,38 +78,46 @@ func (this *IPListDAO) FindIPListName(tx *dbs.Tx, id int64) (string, error) {
FindStringCol("")
}
// FindIPListTypeCacheable 获取名单类型
func (this *IPListDAO) FindIPListTypeCacheable(tx *dbs.Tx, listId int64) (string, error) {
// FindIPListCacheable 获取名单
func (this *IPListDAO) FindIPListCacheable(tx *dbs.Tx, listId int64) (*IPList, error) {
// 全局黑名单
if listId == firewallconfigs.GlobalListId {
return &IPList{
Id: uint32(listId),
IsPublic: 1,
IsGlobal: 1,
Type: "black",
State: IPListStateEnabled,
IsOn: 1,
}, nil
}
// 检查缓存
SharedCacheLocker.RLock()
listType, ok := listTypeCacheMap[listId]
list, ok := listTypeCacheMap[listId]
SharedCacheLocker.RUnlock()
if ok {
return listType, nil
return list, nil
}
listType, err := this.Query(tx).
one, err := this.Query(tx).
Pk(listId).
Result("type").
FindStringCol("")
if err != nil {
return "", err
}
if len(listType) == 0 {
return "", nil
Result("isGlobal", "type", "state", "id", "isPublic", "isGlobal").
Find()
if err != nil || one == nil {
return nil, err
}
// 保存缓存
SharedCacheLocker.Lock()
listTypeCacheMap[listId] = listType
listTypeCacheMap[listId] = one.(*IPList)
SharedCacheLocker.Unlock()
return listType, nil
return one.(*IPList), nil
}
// CreateIPList 创建名单
func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool) (int64, error) {
func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs.IPListType, name string, code string, timeoutJSON []byte, description string, isPublic bool, isGlobal bool) (int64, error) {
op := NewIPListOperator()
op.IsOn = true
op.UserId = userId
@@ -121,6 +130,7 @@ func (this *IPListDAO) CreateIPList(tx *dbs.Tx, userId int64, listType ipconfigs
}
op.Description = description
op.IsPublic = isPublic
op.IsGlobal = isGlobal
err := this.Save(tx, op)
if err != nil {
return 0, err

View File

@@ -15,6 +15,7 @@ type IPList struct {
Actions string `field:"actions"` // IP触发的动作
Description string `field:"description"` // 描述
IsPublic uint8 `field:"isPublic"` // 是否公用
IsGlobal uint8 `field:"isGlobal"` // 是否全局
}
type IPListOperator struct {
@@ -31,6 +32,7 @@ type IPListOperator struct {
Actions interface{} // IP触发的动作
Description interface{} // 描述
IsPublic interface{} // 是否公用
IsGlobal interface{} // 是否全局
}
func NewIPListOperator() *IPListOperator {

View File

@@ -332,10 +332,18 @@ func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb.
}
// List类型
listType, err := models.SharedIPListDAO.FindIPListTypeCacheable(tx, int64(item.ListId))
list, err := models.SharedIPListDAO.FindIPListCacheable(tx, int64(item.ListId))
if err != nil {
return nil, err
}
if list == nil {
continue
}
// 如果已经删除
if list.State != models.IPListStateEnabled {
item.State = models.IPItemStateDisabled
}
result = append(result, &pb.IPItem{
Id: int64(item.Id),
@@ -349,7 +357,8 @@ func (this *IPItemService) ListIPItemsAfterVersion(ctx context.Context, req *pb.
IsDeleted: item.State == 0,
Type: item.Type,
EventLevel: item.EventLevel,
ListType: listType,
ListType: list.Type,
IsGlobal: list.IsPublic == 1 && list.IsGlobal == 1,
NodeId: int64(item.NodeId),
ServerId: int64(item.ServerId),
})

View File

@@ -22,7 +22,7 @@ func (this *IPListService) CreateIPList(ctx context.Context, req *pb.CreateIPLis
tx := this.NullTx()
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic)
listId, err := models.SharedIPListDAO.CreateIPList(tx, userId, req.Type, req.Name, req.Code, req.TimeoutJSON, req.Description, req.IsPublic, req.IsGlobal)
if err != nil {
return nil, err
}
@@ -71,6 +71,7 @@ func (this *IPListService) FindEnabledIPList(ctx context.Context, req *pb.FindEn
Code: list.Code,
TimeoutJSON: []byte(list.Timeout),
Description: list.Description,
IsGlobal: list.IsGlobal == 1,
}}, nil
}
@@ -112,6 +113,7 @@ func (this *IPListService) ListEnabledIPLists(ctx context.Context, req *pb.ListE
TimeoutJSON: []byte(list.Timeout),
IsPublic: list.IsPublic == 1,
Description: list.Description,
IsGlobal: list.IsGlobal == 1,
})
}
return &pb.ListEnabledIPListsResponse{IpLists: pbLists}, nil
@@ -191,6 +193,7 @@ func (this *IPListService) FindEnabledIPListContainsIP(ctx context.Context, req
Name: list.Name,
Code: list.Code,
IsPublic: list.IsPublic == 1,
IsGlobal: list.IsGlobal == 1,
Description: "",
})