diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index a4261a36..cdb36954 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -93,6 +93,71 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error { return this.NotifyUpdate(tx, id) } +// DisableIPItemsWithIP 禁用某个IP相关条目 +func (this *IPItemDAO) DisableIPItemsWithIP(tx *dbs.Tx, ipFrom string, ipTo string, userId int64, listId int64) error { + if len(ipFrom) == 0 { + return errors.New("invalid 'ipFrom'") + } + + var query = this.Query(tx). + Result("id", "listId"). + Attr("ipFrom", ipFrom). + Attr("ipTo", ipTo). + State(IPItemStateEnabled) + + if listId > 0 { + if userId > 0 { + err := SharedIPListDAO.CheckUserIPList(tx, userId, listId) + if err != nil { + return err + } + } + + query.Attr("listId", listId) + } + + ones, err := query.FindAll() + if err != nil { + return err + } + + var itemIds = []int64{} + for _, one := range ones { + var item = one.(*IPItem) + var itemId = int64(item.Id) + var itemListId = int64(item.ListId) + if itemListId != listId && userId > 0 { + err = SharedIPListDAO.CheckUserIPList(tx, userId, itemListId) + if err != nil { + // ignore error + continue + } + } + itemIds = append(itemIds, itemId) + } + + for _, itemId := range itemIds { + version, err := SharedIPListDAO.IncreaseVersion(tx) + if err != nil { + return err + } + + _, err = this.Query(tx). + Pk(itemId). + Set("state", IPItemStateDisabled). + Set("version", version). + Update() + if err != nil { + return err + } + } + + if len(itemIds) > 0 { + return this.NotifyUpdate(tx, itemIds[len(itemIds)-1]) + } + return nil +} + // DisableIPItemsWithListId 禁用某个IP名单内的所有IP func (this *IPItemDAO) DisableIPItemsWithListId(tx *dbs.Tx, listId int64) error { for { diff --git a/internal/db/models/ip_item_dao_test.go b/internal/db/models/ip_item_dao_test.go index e1a92ca2..1ce17501 100644 --- a/internal/db/models/ip_item_dao_test.go +++ b/internal/db/models/ip_item_dao_test.go @@ -1,6 +1,7 @@ -package models +package models_test import ( + "github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/dbs" @@ -14,7 +15,7 @@ func TestIPItemDAO_NotifyClustersUpdate(t *testing.T) { dbs.NotifyReady() var tx *dbs.Tx - err := SharedIPItemDAO.NotifyUpdate(tx, 28) + err := models.SharedIPItemDAO.NotifyUpdate(tx, 28) if err != nil { t.Fatal(err) } @@ -25,7 +26,7 @@ func TestIPItemDAO_DisableIPItemsWithListId(t *testing.T) { dbs.NotifyReady() var tx *dbs.Tx - err := SharedIPItemDAO.DisableIPItemsWithListId(tx, 67) + err := models.SharedIPItemDAO.DisableIPItemsWithListId(tx, 67) if err != nil { t.Fatal(err) } @@ -36,7 +37,7 @@ func TestIPItemDAO_ListIPItemsAfterVersion(t *testing.T) { dbs.NotifyReady() var tx *dbs.Tx - _, err := SharedIPItemDAO.ListIPItemsAfterVersion(tx, 0, 100) + _, err := models.SharedIPItemDAO.ListIPItemsAfterVersion(tx, 0, 100) if err != nil { t.Fatal(err) } @@ -47,10 +48,10 @@ func TestIPItemDAO_CreateManyIPs(t *testing.T) { dbs.NotifyReady() var tx *dbs.Tx - var dao = NewIPItemDAO() + var dao = models.NewIPItemDAO() var n = 10 for i := 0; i < n; i++ { - itemId, err := dao.CreateIPItem(tx, firewallconfigs.GlobalListId, "192."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)), "", time.Now().Unix()+86400, "test", IPItemTypeIPv4, "warning", 0, 0, 0, 0, 0, 0, 0) + itemId, err := dao.CreateIPItem(tx, firewallconfigs.GlobalListId, "192."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255))+"."+types.String(rands.Int(0, 255)), "", time.Now().Unix()+86400, "test", models.IPItemTypeIPv4, "warning", 0, 0, 0, 0, 0, 0, 0) if err != nil { t.Fatal(err) } @@ -62,3 +63,14 @@ func TestIPItemDAO_CreateManyIPs(t *testing.T) { } t.Log("ok") } + +func TestIPItemDAO_DisableIPItemsWithIP(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + err := models.SharedIPItemDAO.DisableIPItemsWithIP(tx, "192.168.1.100", "", 0, 0) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index a43e8657..e9f6e982 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -114,23 +114,42 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte return nil, err } - tx := this.NullTx() + var tx = this.NullTx() - if userId > 0 { - listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId) - if err != nil { - return nil, err + // 如果是使用IPItemId删除 + if req.IpItemId > 0 { + if userId > 0 { + listId, err := models.SharedIPItemDAO.FindItemListId(tx, req.IpItemId) + if err != nil { + return nil, err + } + + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId) + if err != nil { + return nil, err + } } - err = models.SharedIPListDAO.CheckUserIPList(tx, userId, listId) + err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId) if err != nil { return nil, err } } - err = models.SharedIPItemDAO.DisableIPItem(tx, req.IpItemId) - if err != nil { - return nil, err + // 如果是使用ipFrom+ipTo删除 + if len(req.IpFrom) > 0 { + // 检查IP列表 + if req.IpListId > 0 && userId > 0 { + err = models.SharedIPListDAO.CheckUserIPList(tx, userId, req.IpListId) + if err != nil { + return nil, err + } + } + + err = models.SharedIPItemDAO.DisableIPItemsWithIP(tx, req.IpFrom, req.IpTo, userId, req.IpListId) + if err != nil { + return nil, err + } } return this.Success()