删除WAF策略和删除服务时同时也删除关联的IP名单

This commit is contained in:
GoEdgeLab
2021-11-16 17:50:52 +08:00
parent 770884eaca
commit dfb4e6a155
5 changed files with 140 additions and 3 deletions

View File

@@ -63,6 +63,11 @@ func (this *HTTPFirewallPolicyDAO) DisableHTTPFirewallPolicy(tx *dbs.Tx, policyI
return err return err
} }
err = this.NotifyDisable(tx, policyId)
if err != nil {
return err
}
return this.NotifyUpdate(tx, policyId) return this.NotifyUpdate(tx, policyId)
} }
@@ -482,6 +487,23 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyServerId(tx *dbs.Tx, poli
return err return err
} }
// FindFirewallPolicyIdsWithServerId 查找服务独立关联的策略IDs
func (this *HTTPFirewallPolicyDAO) FindFirewallPolicyIdsWithServerId(tx *dbs.Tx, serverId int64) ([]int64, error) {
var result = []int64{}
ones, err := this.Query(tx).
Attr("serverId", serverId).
State(HTTPFirewallPolicyStateEnabled).
Result("id").
FindAll()
if err != nil {
return nil, err
}
for _, one := range ones {
result = append(result, int64(one.(*HTTPFirewallPolicy).Id))
}
return result, nil
}
// NotifyUpdate 通知更新 // NotifyUpdate 通知更新
func (this *HTTPFirewallPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error { func (this *HTTPFirewallPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error {
webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, policyId) webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, policyId)
@@ -508,3 +530,65 @@ func (this *HTTPFirewallPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) erro
return nil return nil
} }
// NotifyDisable 通知禁用
func (this *HTTPFirewallPolicyDAO) NotifyDisable(tx *dbs.Tx, policyId int64) error {
if policyId <= 0 {
return nil
}
// 禁用IP名单
inboundString, err := this.Query(tx).
Pk(policyId).
Result("inbound").
FindStringCol("")
if err != nil {
return err
}
if len(inboundString) > 0 {
var inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{}
err = json.Unmarshal([]byte(inboundString), inboundConfig)
if err != nil {
// 不处理错误
return nil
}
if inboundConfig.AllowListRef != nil && inboundConfig.AllowListRef.ListId > 0 {
err = SharedIPListDAO.DisableIPList(tx, inboundConfig.AllowListRef.ListId)
if err != nil {
return err
}
err = SharedIPItemDAO.DisableIPItemsWithListId(tx, inboundConfig.AllowListRef.ListId)
if err != nil {
return err
}
}
if inboundConfig.DenyListRef != nil && inboundConfig.DenyListRef.ListId > 0 {
err = SharedIPListDAO.DisableIPList(tx, inboundConfig.DenyListRef.ListId)
if err != nil {
return err
}
err = SharedIPItemDAO.DisableIPItemsWithListId(tx, inboundConfig.DenyListRef.ListId)
if err != nil {
return err
}
}
if inboundConfig.GreyListRef != nil && inboundConfig.GreyListRef.ListId > 0 {
err = SharedIPListDAO.DisableIPList(tx, inboundConfig.GreyListRef.ListId)
if err != nil {
return err
}
err = SharedIPItemDAO.DisableIPItemsWithListId(tx, inboundConfig.GreyListRef.ListId)
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -75,6 +75,21 @@ func (this *IPItemDAO) DisableIPItem(tx *dbs.Tx, id int64) error {
return this.NotifyUpdate(tx, id) return this.NotifyUpdate(tx, id)
} }
// DisableIPItemsWithListId 禁用某个IP名单内的所有IP
func (this *IPItemDAO) DisableIPItemsWithListId(tx *dbs.Tx, listId int64) error {
version, err := SharedIPListDAO.IncreaseVersion(tx)
if err != nil {
return err
}
return this.Query(tx).
Attr("listId", listId).
State(IPItemStateEnabled).
Set("version", version).
Set("state", IPItemStateDisabled).
UpdateQuickly()
}
// FindEnabledIPItem 查找启用中的条目 // FindEnabledIPItem 查找启用中的条目
func (this *IPItemDAO) FindEnabledIPItem(tx *dbs.Tx, id int64) (*IPItem, error) { func (this *IPItemDAO) FindEnabledIPItem(tx *dbs.Tx, id int64) (*IPItem, error) {
result, err := this.Query(tx). result, err := this.Query(tx).
@@ -249,6 +264,7 @@ func (this *IPItemDAO) ListIPItemsAfterVersion(tx *dbs.Tx, version int64, size i
Where("(expiredAt=0 OR expiredAt>:expiredAt)"). Where("(expiredAt=0 OR expiredAt>:expiredAt)").
Param("expiredAt", time.Now().Unix()). Param("expiredAt", time.Now().Unix()).
Asc("version"). Asc("version").
Asc("id").
Limit(size). Limit(size).
Slice(&result). Slice(&result).
FindAll() FindAll()

View File

@@ -74,6 +74,13 @@ func (this *ServerDAO) DisableServer(tx *dbs.Tx, serverId int64) (err error) {
if err != nil { if err != nil {
return err return err
} }
// 删除对应的操作
err = this.NotifyDisable(tx, serverId)
if err != nil {
return err
}
err = this.NotifyUpdate(tx, serverId) err = this.NotifyUpdate(tx, serverId)
if err != nil { if err != nil {
return err return err
@@ -2150,3 +2157,25 @@ func (this *ServerDAO) NotifyDNSUpdate(tx *dbs.Tx, serverId int64) error {
} }
return dns.SharedDNSTaskDAO.CreateServerTask(tx, serverId, dns.DNSTaskTypeServerChange) return dns.SharedDNSTaskDAO.CreateServerTask(tx, serverId, dns.DNSTaskTypeServerChange)
} }
// NotifyDisable 通知禁用
func (this *ServerDAO) NotifyDisable(tx *dbs.Tx, serverId int64) error {
// 禁用缓存策略相关的内容
policyIds, err := SharedHTTPFirewallPolicyDAO.FindFirewallPolicyIdsWithServerId(tx, serverId)
if err != nil {
return err
}
for _, policyId := range policyIds {
err = SharedHTTPFirewallPolicyDAO.DisableHTTPFirewallPolicy(tx, policyId)
if err != nil {
return err
}
err = SharedHTTPFirewallPolicyDAO.NotifyDisable(tx, policyId)
if err != nil {
return err
}
}
return nil
}

View File

@@ -34,7 +34,8 @@ func TestSysLocker_Increase(t *testing.T) {
defer wg.Done() defer wg.Done()
v, err := NewSysLockerDAO().Increase(nil, "hello", 0) v, err := NewSysLockerDAO().Increase(nil, "hello", 0)
if err != nil { if err != nil {
t.Fatal(err) t.Log("err:", err)
return
} }
t.Log("v:", v) t.Log("v:", v)
}() }()

View File

@@ -97,12 +97,12 @@ func (this *IPListService) ListEnabledIPLists(ctx context.Context, req *pb.ListE
} }
var tx = this.NullTx() var tx = this.NullTx()
lists, err := models.SharedIPListDAO.ListEnabledIPLists(tx, req.Type, req.IsPublic, req.Keyword, req.Offset, req.Size) ipLists, err := models.SharedIPListDAO.ListEnabledIPLists(tx, req.Type, req.IsPublic, req.Keyword, req.Offset, req.Size)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var pbLists []*pb.IPList var pbLists []*pb.IPList
for _, list := range lists { for _, list := range ipLists {
pbLists = append(pbLists, &pb.IPList{ pbLists = append(pbLists, &pb.IPList{
Id: int64(list.Id), Id: int64(list.Id),
IsOn: list.IsOn == 1, IsOn: list.IsOn == 1,
@@ -129,6 +129,13 @@ func (this *IPListService) DeleteIPList(ctx context.Context, req *pb.DeleteIPLis
if err != nil { if err != nil {
return nil, err return nil, err
} }
// 删除所有IP
err = models.SharedIPItemDAO.DisableIPItemsWithListId(tx, req.IpListId)
if err != nil {
return nil, err
}
return this.Success() return this.Success()
} }