增加通过IP删除IP名单中IP参数

This commit is contained in:
GoEdgeLab
2022-05-10 15:11:48 +08:00
parent 3d0ddeb129
commit 7c32617d08
3 changed files with 111 additions and 15 deletions

View File

@@ -93,6 +93,71 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error {
return this.NotifyUpdate(tx, id)
}
// DisableIPItemsWithIP 禁用某个IP相关条目
func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo string, userId int64, listId int64) error {
if len(ipFrom) == 0 {
return errors.New("invalid 'ipFrom'")
}
var query = this.Query(tx).
Result("id", "listId").
Attr("ipFrom", ipFrom).
Attr("ipTo", ipTo).
State(IPItemStateEnabled)
if listId > 0 {
if userId > 0 {
err := SharedIPListDAO.CheckUserIPList(tx, userId, listId)
if err != nil {
return err
}
}
query.Attr("listId", listId)
}
ones, err := query.FindAll()
if err != nil {
return err
}
var itemIds = []int64{}
for _, one := range ones {
var item = one.(*IPItem)
var itemId = int64(item.Id)
var itemListId = int64(item.ListId)
if itemListId != listId && userId > 0 {
err = SharedIPListDAO.CheckUserIPList(tx, userId, itemListId)
if err != nil {
// ignore error
continue
}
}
itemIds = append(itemIds, itemId)
}
for _, itemId := range itemIds {
version, err := SharedIPListDAO.IncreaseVersion(tx)
if err != nil {
return err
}
_, err = this.Query(tx).
Pk(itemId).
Set("state", IPItemStateDisabled).
Set("version", version).
Update()
if err != nil {
return err
}
}
if len(itemIds) > 0 {
return this.NotifyUpdate(tx, itemIds[len(itemIds)-1])
}
return nil
}
// DisableIPItemsWithListId 禁用某个IP名单内的所有IP
func (this *IPItemDAO) DisableIPItemsWithListId(tx *dbs.Tx, listId int64) error {
for {

View File

@@ -1,6 +1,7 @@
package models
package models_test
import (
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs"
@@ -14,7 +15,7 @@ func TestIPItemDAO_NotifyClustersUpdate(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
err := SharedIPItemDAO.NotifyUpdate(tx, 28)
err := models.SharedIPItemDAO.NotifyUpdate(tx, 28)
if err != nil {
t.Fatal(err)
}
@@ -25,7 +26,7 @@ func TestIPItemDAO_DisableIPItemsWithListId(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
err := SharedIPItemDAO.DisableIPItemsWithListId(tx, 67)
err := models.SharedIPItemDAO.DisableIPItemsWithListId(tx, 67)
if err != nil {
t.Fatal(err)
}
@@ -36,7 +37,7 @@ func TestIPItemDAO_ListIPItemsAfterVersion(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
_, err := SharedIPItemDAO.ListIPItemsAfterVersion(tx, 0, 100)
_, err := models.SharedIPItemDAO.ListIPItemsAfterVersion(tx, 0, 100)
if err != nil {
t.Fatal(err)
}
@@ -47,10 +48,10 @@ func TestIPItemDAO_CreateManyIPs(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
var dao = NewIPItemDAO()
var dao = models.NewIPItemDAO()
var n = 10
for i := 0; i < n; i++ {
itemId, err := dao.CreateIPItem(tx, firewallconfigs.GlobalListId, "192."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)), "", time.Now().Unix()+86400, "test", IPItemTypeIPv4, "warning", 0, 0, 0, 0, 0, 0, 0)
itemId, err := dao.CreateIPItem(tx, firewallconfigs.GlobalListId, "192."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)), "", time.Now().Unix()+86400, "test", models.IPItemTypeIPv4, "warning", 0, 0, 0, 0, 0, 0, 0)
if err != nil {
t.Fatal(err)
}
@@ -62,3 +63,14 @@ func TestIPItemDAO_CreateManyIPs(t *testing.T) {
}
t.Log("ok")
}
func TestIPItemDAO_DisableIPItemsWithIP(t *testing.T) {
dbs.NotifyReady()
var tx *dbs.Tx
err := models.SharedIPItemDAO.DisableIPItemsWithIP(tx, "192.168.1.100", "", 0, 0)
if err != nil {
t.Fatal(err)
}
t.Log("ok")
}

View File

@@ -114,23 +114,42 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte
return nil, err
}
tx := this.NullTx()
var tx = this.NullTx()
if userId > 0 {
listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId)
if err != nil {
return nil, err
// 如果是使用IPItemId删除
if req.IpItemId > 0 {
if userId > 0 {
listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId)
if err != nil {
return nil, err
}
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId)
if err != nil {
return nil, err
}
}
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId)
err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId)
if err != nil {
return nil, err
}
}
err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId)
if err != nil {
return nil, err
// 如果是使用ipFrom+ipTo删除
if len(req.IpFrom) > 0 {
// 检查IP列表
if req.IpListId > 0 && userId > 0 {
err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId)
if err != nil {
return nil, err
}
}
err = models.SharedIPItemDAO.DisableIPItemsWithIP(tx, req.IpFrom, req.IpTo, userId, req.IpListId)
if err != nil {
return nil, err
}
}
return this.Success()