mirror of
https://github.com/TeaOSLab/EdgeAPI.git
synced 2025-11-03 15:00:27 +08:00
优化删除IP名单时操作
This commit is contained in:
@@ -92,7 +92,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
|
||||
if set == nil {
|
||||
return nil, nil
|
||||
}
|
||||
config := &firewallconfigs.HTTPFirewallRuleSet{}
|
||||
var config = &firewallconfigs.HTTPFirewallRuleSet{}
|
||||
config.Id = int64(set.Id)
|
||||
config.IsOn = set.IsOn
|
||||
config.Name = set.Name
|
||||
@@ -102,7 +102,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
|
||||
config.IgnoreLocal = set.IgnoreLocal == 1
|
||||
|
||||
if IsNotNull(set.Rules) {
|
||||
ruleRefs := []*firewallconfigs.HTTPFirewallRuleRef{}
|
||||
var ruleRefs = []*firewallconfigs.HTTPFirewallRuleRef{}
|
||||
err = json.Unmarshal(set.Rules, &ruleRefs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -128,6 +128,22 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
|
||||
config.Actions = actionConfigs
|
||||
}
|
||||
|
||||
// 检查各个选项
|
||||
for _, actionConfig := range actionConfigs {
|
||||
if actionConfig.Code == firewallconfigs.HTTPFirewallActionRecordIP { // 记录IP动作
|
||||
if actionConfig.Options != nil {
|
||||
var ipListId = actionConfig.Options.GetInt64("ipListId")
|
||||
exists, err := SharedIPListDAO.ExistsEnabledIPList(tx, ipListId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !exists {
|
||||
actionConfig.Options["ipListIsDeleted"] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -212,6 +228,28 @@ func (this *HTTPFirewallRuleSetDAO) FindEnabledRuleSetIdWithRuleId(tx *dbs.Tx, r
|
||||
FindInt64Col(0)
|
||||
}
|
||||
|
||||
// FindAllEnabledRuleSetIdsWithIPListId 根据IP名单ID查找对应动作的WAF规则集
|
||||
func (this *HTTPFirewallRuleSetDAO) FindAllEnabledRuleSetIdsWithIPListId(tx *dbs.Tx, ipListId int64) (setIds []int64, err error) {
|
||||
ones, err := this.Query(tx).
|
||||
State(HTTPFirewallRuleStateEnabled).
|
||||
Where("JSON_CONTAINS(actions, :jsonQuery)").
|
||||
Param("jsonQuery", maps.Map{
|
||||
"code": firewallconfigs.HTTPFirewallActionRecordIP,
|
||||
"options": maps.Map{
|
||||
"ipListId": ipListId,
|
||||
},
|
||||
}.AsJSON()).
|
||||
ResultPk().
|
||||
FindAll()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, one := range ones {
|
||||
setIds = append(setIds, int64(one.(*HTTPFirewallRuleSet).Id))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// CheckUserRuleSet 检查用户
|
||||
func (this *HTTPFirewallRuleSetDAO) CheckUserRuleSet(tx *dbs.Tx, userId int64, setId int64) error {
|
||||
groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, setId)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/iwind/TeaGo/Tea"
|
||||
"github.com/iwind/TeaGo/dbs"
|
||||
"github.com/iwind/TeaGo/lists"
|
||||
"github.com/iwind/TeaGo/maps"
|
||||
"github.com/iwind/TeaGo/types"
|
||||
)
|
||||
|
||||
@@ -61,14 +62,18 @@ func (this *IPListDAO) EnableIPList(tx *dbs.Tx, id int64) error {
|
||||
}
|
||||
|
||||
// DisableIPList 禁用条目
|
||||
func (this *IPListDAO) DisableIPList(tx *dbs.Tx, id int64) error {
|
||||
func (this *IPListDAO) DisableIPList(tx *dbs.Tx, listId int64) error {
|
||||
_, err := this.Query(tx).
|
||||
Pk(id).
|
||||
Pk(listId).
|
||||
Set("state", IPListStateDisabled).
|
||||
Update()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return this.NotifyUpdate(tx, listId, NodeTaskTypeIPListDeleted+"@"+string(maps.Map{"listId": listId}.AsJSON()))
|
||||
}
|
||||
|
||||
// FindEnabledIPList 查找启用中的条目
|
||||
func (this *IPListDAO) FindEnabledIPList(tx *dbs.Tx, id int64, cacheMap *utils.CacheMap) (*IPList, error) {
|
||||
if id == firewallconfigs.GlobalListId {
|
||||
@@ -258,11 +263,35 @@ func (this *IPListDAO) ExistsEnabledIPList(tx *dbs.Tx, listId int64) (bool, erro
|
||||
|
||||
// NotifyUpdate 通知更新
|
||||
func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskType) error {
|
||||
// WAF策略中的
|
||||
httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
resultClusterIds := []int64{}
|
||||
|
||||
// 规则集动作中使用此名单的策略
|
||||
ruleSetIds, err := SharedHTTPFirewallRuleSetDAO.FindAllEnabledRuleSetIdsWithIPListId(tx, listId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, ruleSetId := range ruleSetIds {
|
||||
ruleGroupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, ruleSetId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if ruleGroupId > 0 {
|
||||
policyId, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdWithRuleGroupId(tx, ruleGroupId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if policyId > 0 && !lists.ContainsInt64(httpFirewallPolicyIds, policyId) {
|
||||
httpFirewallPolicyIds = append(httpFirewallPolicyIds, policyId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 查找集群
|
||||
var resultClusterIds = []int64{}
|
||||
for _, policyId := range httpFirewallPolicyIds {
|
||||
// 集群
|
||||
clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"errors"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/iwind/TeaGo/dbs"
|
||||
"runtime"
|
||||
@@ -27,7 +28,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
|
||||
|
||||
{
|
||||
err := NewIPListDAO().CheckUserIPList(tx, 1, 100)
|
||||
if err == ErrNotFound {
|
||||
if err != nil && errors.Is(err, ErrNotFound) {
|
||||
t.Log("not found")
|
||||
} else {
|
||||
t.Log(err)
|
||||
@@ -36,7 +37,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
|
||||
|
||||
{
|
||||
err := NewIPListDAO().CheckUserIPList(tx, 1, 85)
|
||||
if err == ErrNotFound {
|
||||
if err != nil && errors.Is(err, ErrNotFound) {
|
||||
t.Log("not found")
|
||||
} else {
|
||||
t.Log(err)
|
||||
@@ -45,7 +46,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
|
||||
|
||||
{
|
||||
err := NewIPListDAO().CheckUserIPList(tx, 1, 17)
|
||||
if err == ErrNotFound {
|
||||
if err != nil && errors.Is(err, ErrNotFound) {
|
||||
t.Log("not found")
|
||||
} else {
|
||||
t.Log(err)
|
||||
@@ -53,6 +54,17 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPListDAO_NotifyUpdate(t *testing.T) {
|
||||
dbs.NotifyReady()
|
||||
|
||||
var dao = NewIPListDAO()
|
||||
var tx *dbs.Tx
|
||||
err := dao.NotifyUpdate(tx, 104, NodeTaskTypeIPListDeleted)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
||||
runtime.GOMAXPROCS(1)
|
||||
|
||||
@@ -65,4 +77,3 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
|
||||
_, _ = dao.IncreaseVersion(tx)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ const (
|
||||
NodeTaskTypeConfigChanged NodeTaskType = "configChanged" // 节点整体配置变化
|
||||
NodeTaskTypeDDosProtectionChanged NodeTaskType = "ddosProtectionChanged" // 节点DDoS配置变更
|
||||
NodeTaskTypeGlobalServerConfigChanged NodeTaskType = "globalServerConfigChanged" // 全局服务设置变化
|
||||
NodeTaskTypeIPListDeleted NodeTaskType = "ipListDeleted"
|
||||
NodeTaskTypeIPItemChanged NodeTaskType = "ipItemChanged" // IP条目变更
|
||||
NodeTaskTypeNodeVersionChanged NodeTaskType = "nodeVersionChanged" // 节点版本变化
|
||||
NodeTaskTypeScriptsChanged NodeTaskType = "scriptsChanged" // 脚本配置变化
|
||||
|
||||
@@ -268,7 +268,7 @@ func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) {
|
||||
|
||||
var dao = models.NewServerDAO()
|
||||
var tx *dbs.Tx
|
||||
err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day")
|
||||
err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day", "traffic")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user