优化删除IP名单时操作

This commit is contained in:
GoEdgeLab
2023-09-13 17:16:00 +08:00
parent bd8d88bf18
commit d79d54e1e3
5 changed files with 101 additions and 22 deletions

View File

@@ -92,7 +92,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
if set == nil { if set == nil {
return nil, nil return nil, nil
} }
config := &firewallconfigs.HTTPFirewallRuleSet{} var config = &firewallconfigs.HTTPFirewallRuleSet{}
config.Id = int64(set.Id) config.Id = int64(set.Id)
config.IsOn = set.IsOn config.IsOn = set.IsOn
config.Name = set.Name config.Name = set.Name
@@ -102,7 +102,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
config.IgnoreLocal = set.IgnoreLocal == 1 config.IgnoreLocal = set.IgnoreLocal == 1
if IsNotNull(set.Rules) { if IsNotNull(set.Rules) {
ruleRefs := []*firewallconfigs.HTTPFirewallRuleRef{} var ruleRefs = []*firewallconfigs.HTTPFirewallRuleRef{}
err = json.Unmarshal(set.Rules, &ruleRefs) err = json.Unmarshal(set.Rules, &ruleRefs)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -128,6 +128,22 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
config.Actions = actionConfigs config.Actions = actionConfigs
} }
// 检查各个选项
for _, actionConfig := range actionConfigs {
if actionConfig.Code == firewallconfigs.HTTPFirewallActionRecordIP { // 记录IP动作
if actionConfig.Options != nil {
var ipListId = actionConfig.Options.GetInt64("ipListId")
exists, err := SharedIPListDAO.ExistsEnabledIPList(tx, ipListId)
if err != nil {
return nil, err
}
if !exists {
actionConfig.Options["ipListIsDeleted"] = true
}
}
}
}
return config, nil return config, nil
} }
@@ -212,6 +228,28 @@ func (this *HTTPFirewallRuleSetDAO) FindEnabledRuleSetIdWithRuleId(tx *dbs.Tx, r
FindInt64Col(0) FindInt64Col(0)
} }
// FindAllEnabledRuleSetIdsWithIPListId 根据IP名单ID查找对应动作的WAF规则集
func (this *HTTPFirewallRuleSetDAO) FindAllEnabledRuleSetIdsWithIPListId(tx *dbs.Tx, ipListId int64) (setIds []int64, err error) {
ones, err := this.Query(tx).
State(HTTPFirewallRuleStateEnabled).
Where("JSON_CONTAINS(actions, :jsonQuery)").
Param("jsonQuery", maps.Map{
"code": firewallconfigs.HTTPFirewallActionRecordIP,
"options": maps.Map{
"ipListId": ipListId,
},
}.AsJSON()).
ResultPk().
FindAll()
if err != nil {
return nil, err
}
for _, one := range ones {
setIds = append(setIds, int64(one.(*HTTPFirewallRuleSet).Id))
}
return
}
// CheckUserRuleSet 检查用户 // CheckUserRuleSet 检查用户
func (this *HTTPFirewallRuleSetDAO) CheckUserRuleSet(tx *dbs.Tx, userId int64, setId int64) error { func (this *HTTPFirewallRuleSetDAO) CheckUserRuleSet(tx *dbs.Tx, userId int64, setId int64) error {
groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, setId) groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, setId)

View File

@@ -11,6 +11,7 @@ import (
"github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/lists"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types" "github.com/iwind/TeaGo/types"
) )
@@ -61,12 +62,16 @@ func (this *IPListDAO) EnableIPList(tx *dbs.Tx, id int64) error {
} }
// DisableIPList 禁用条目 // DisableIPList 禁用条目
func (this *IPListDAO) DisableIPList(tx *dbs.Tx, id int64) error { func (this *IPListDAO) DisableIPList(tx *dbs.Tx, listId int64) error {
_, err := this.Query(tx). _, err := this.Query(tx).
Pk(id). Pk(listId).
Set("state", IPListStateDisabled). Set("state", IPListStateDisabled).
Update() Update()
if err != nil {
return err return err
}
return this.NotifyUpdate(tx, listId, NodeTaskTypeIPListDeleted+"@"+string(maps.Map{"listId": listId}.AsJSON()))
} }
// FindEnabledIPList 查找启用中的条目 // FindEnabledIPList 查找启用中的条目
@@ -258,11 +263,35 @@ func (this *IPListDAO) ExistsEnabledIPList(tx *dbs.Tx, listId int64) (bool, erro
// NotifyUpdate 通知更新 // NotifyUpdate 通知更新
func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskType) error { func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskType) error {
// WAF策略中的
httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId) httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId)
if err != nil { if err != nil {
return err return err
} }
resultClusterIds := []int64{}
// 规则集动作中使用此名单的策略
ruleSetIds, err := SharedHTTPFirewallRuleSetDAO.FindAllEnabledRuleSetIdsWithIPListId(tx, listId)
if err != nil {
return err
}
for _, ruleSetId := range ruleSetIds {
ruleGroupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, ruleSetId)
if err != nil {
return err
}
if ruleGroupId > 0 {
policyId, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdWithRuleGroupId(tx, ruleGroupId)
if err != nil {
return err
}
if policyId > 0 && !lists.ContainsInt64(httpFirewallPolicyIds, policyId) {
httpFirewallPolicyIds = append(httpFirewallPolicyIds, policyId)
}
}
}
// 查找集群
var resultClusterIds = []int64{}
for _, policyId := range httpFirewallPolicyIds { for _, policyId := range httpFirewallPolicyIds {
// 集群 // 集群
clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId) clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId)

View File

@@ -1,6 +1,7 @@
package models package models
import ( import (
"errors"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/dbs"
"runtime" "runtime"
@@ -27,7 +28,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
{ {
err := NewIPListDAO().CheckUserIPList(tx, 1, 100) err := NewIPListDAO().CheckUserIPList(tx, 1, 100)
if err == ErrNotFound { if err != nil && errors.Is(err, ErrNotFound) {
t.Log("not found") t.Log("not found")
} else { } else {
t.Log(err) t.Log(err)
@@ -36,7 +37,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
{ {
err := NewIPListDAO().CheckUserIPList(tx, 1, 85) err := NewIPListDAO().CheckUserIPList(tx, 1, 85)
if err == ErrNotFound { if err != nil && errors.Is(err, ErrNotFound) {
t.Log("not found") t.Log("not found")
} else { } else {
t.Log(err) t.Log(err)
@@ -45,7 +46,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
{ {
err := NewIPListDAO().CheckUserIPList(tx, 1, 17) err := NewIPListDAO().CheckUserIPList(tx, 1, 17)
if err == ErrNotFound { if err != nil && errors.Is(err, ErrNotFound) {
t.Log("not found") t.Log("not found")
} else { } else {
t.Log(err) t.Log(err)
@@ -53,6 +54,17 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
} }
} }
func TestIPListDAO_NotifyUpdate(t *testing.T) {
dbs.NotifyReady()
var dao = NewIPListDAO()
var tx *dbs.Tx
err := dao.NotifyUpdate(tx, 104, NodeTaskTypeIPListDeleted)
if err != nil {
t.Fatal(err)
}
}
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) { func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
runtime.GOMAXPROCS(1) runtime.GOMAXPROCS(1)
@@ -65,4 +77,3 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
_, _ = dao.IncreaseVersion(tx) _, _ = dao.IncreaseVersion(tx)
} }
} }

View File

@@ -19,6 +19,7 @@ const (
NodeTaskTypeConfigChanged NodeTaskType = "configChanged" // 节点整体配置变化 NodeTaskTypeConfigChanged NodeTaskType = "configChanged" // 节点整体配置变化
NodeTaskTypeDDosProtectionChanged NodeTaskType = "ddosProtectionChanged" // 节点DDoS配置变更 NodeTaskTypeDDosProtectionChanged NodeTaskType = "ddosProtectionChanged" // 节点DDoS配置变更
NodeTaskTypeGlobalServerConfigChanged NodeTaskType = "globalServerConfigChanged" // 全局服务设置变化 NodeTaskTypeGlobalServerConfigChanged NodeTaskType = "globalServerConfigChanged" // 全局服务设置变化
NodeTaskTypeIPListDeleted NodeTaskType = "ipListDeleted"
NodeTaskTypeIPItemChanged NodeTaskType = "ipItemChanged" // IP条目变更 NodeTaskTypeIPItemChanged NodeTaskType = "ipItemChanged" // IP条目变更
NodeTaskTypeNodeVersionChanged NodeTaskType = "nodeVersionChanged" // 节点版本变化 NodeTaskTypeNodeVersionChanged NodeTaskType = "nodeVersionChanged" // 节点版本变化
NodeTaskTypeScriptsChanged NodeTaskType = "scriptsChanged" // 脚本配置变化 NodeTaskTypeScriptsChanged NodeTaskType = "scriptsChanged" // 脚本配置变化

View File

@@ -268,7 +268,7 @@ func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) {
var dao = models.NewServerDAO() var dao = models.NewServerDAO()
var tx *dbs.Tx var tx *dbs.Tx
err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day") err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day", "traffic")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }