mirror of
				https://github.com/TeaOSLab/EdgeAPI.git
				synced 2025-11-04 16:00:24 +08:00 
			
		
		
		
	优化删除IP名单时操作
This commit is contained in:
		@@ -92,7 +92,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
 | 
				
			|||||||
	if set == nil {
 | 
						if set == nil {
 | 
				
			||||||
		return nil, nil
 | 
							return nil, nil
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	config := &firewallconfigs.HTTPFirewallRuleSet{}
 | 
						var config = &firewallconfigs.HTTPFirewallRuleSet{}
 | 
				
			||||||
	config.Id = int64(set.Id)
 | 
						config.Id = int64(set.Id)
 | 
				
			||||||
	config.IsOn = set.IsOn
 | 
						config.IsOn = set.IsOn
 | 
				
			||||||
	config.Name = set.Name
 | 
						config.Name = set.Name
 | 
				
			||||||
@@ -102,7 +102,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
 | 
				
			|||||||
	config.IgnoreLocal = set.IgnoreLocal == 1
 | 
						config.IgnoreLocal = set.IgnoreLocal == 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if IsNotNull(set.Rules) {
 | 
						if IsNotNull(set.Rules) {
 | 
				
			||||||
		ruleRefs := []*firewallconfigs.HTTPFirewallRuleRef{}
 | 
							var ruleRefs = []*firewallconfigs.HTTPFirewallRuleRef{}
 | 
				
			||||||
		err = json.Unmarshal(set.Rules, &ruleRefs)
 | 
							err = json.Unmarshal(set.Rules, &ruleRefs)
 | 
				
			||||||
		if err != nil {
 | 
							if err != nil {
 | 
				
			||||||
			return nil, err
 | 
								return nil, err
 | 
				
			||||||
@@ -128,6 +128,22 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
 | 
				
			|||||||
		config.Actions = actionConfigs
 | 
							config.Actions = actionConfigs
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 检查各个选项
 | 
				
			||||||
 | 
						for _, actionConfig := range actionConfigs {
 | 
				
			||||||
 | 
							if actionConfig.Code == firewallconfigs.HTTPFirewallActionRecordIP { // 记录IP动作
 | 
				
			||||||
 | 
								if actionConfig.Options != nil {
 | 
				
			||||||
 | 
									var ipListId = actionConfig.Options.GetInt64("ipListId")
 | 
				
			||||||
 | 
									exists, err := SharedIPListDAO.ExistsEnabledIPList(tx, ipListId)
 | 
				
			||||||
 | 
									if err != nil {
 | 
				
			||||||
 | 
										return nil, err
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
									if !exists {
 | 
				
			||||||
 | 
										actionConfig.Options["ipListIsDeleted"] = true
 | 
				
			||||||
 | 
									}
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	return config, nil
 | 
						return config, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -212,6 +228,28 @@ func (this *HTTPFirewallRuleSetDAO) FindEnabledRuleSetIdWithRuleId(tx *dbs.Tx, r
 | 
				
			|||||||
		FindInt64Col(0)
 | 
							FindInt64Col(0)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// FindAllEnabledRuleSetIdsWithIPListId 根据IP名单ID查找对应动作的WAF规则集
 | 
				
			||||||
 | 
					func (this *HTTPFirewallRuleSetDAO) FindAllEnabledRuleSetIdsWithIPListId(tx *dbs.Tx, ipListId int64) (setIds []int64, err error) {
 | 
				
			||||||
 | 
						ones, err := this.Query(tx).
 | 
				
			||||||
 | 
							State(HTTPFirewallRuleStateEnabled).
 | 
				
			||||||
 | 
							Where("JSON_CONTAINS(actions, :jsonQuery)").
 | 
				
			||||||
 | 
							Param("jsonQuery", maps.Map{
 | 
				
			||||||
 | 
								"code": firewallconfigs.HTTPFirewallActionRecordIP,
 | 
				
			||||||
 | 
								"options": maps.Map{
 | 
				
			||||||
 | 
									"ipListId": ipListId,
 | 
				
			||||||
 | 
								},
 | 
				
			||||||
 | 
							}.AsJSON()).
 | 
				
			||||||
 | 
							ResultPk().
 | 
				
			||||||
 | 
							FindAll()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return nil, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, one := range ones {
 | 
				
			||||||
 | 
							setIds = append(setIds, int64(one.(*HTTPFirewallRuleSet).Id))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// CheckUserRuleSet 检查用户
 | 
					// CheckUserRuleSet 检查用户
 | 
				
			||||||
func (this *HTTPFirewallRuleSetDAO) CheckUserRuleSet(tx *dbs.Tx, userId int64, setId int64) error {
 | 
					func (this *HTTPFirewallRuleSetDAO) CheckUserRuleSet(tx *dbs.Tx, userId int64, setId int64) error {
 | 
				
			||||||
	groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, setId)
 | 
						groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, setId)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -11,6 +11,7 @@ import (
 | 
				
			|||||||
	"github.com/iwind/TeaGo/Tea"
 | 
						"github.com/iwind/TeaGo/Tea"
 | 
				
			||||||
	"github.com/iwind/TeaGo/dbs"
 | 
						"github.com/iwind/TeaGo/dbs"
 | 
				
			||||||
	"github.com/iwind/TeaGo/lists"
 | 
						"github.com/iwind/TeaGo/lists"
 | 
				
			||||||
 | 
						"github.com/iwind/TeaGo/maps"
 | 
				
			||||||
	"github.com/iwind/TeaGo/types"
 | 
						"github.com/iwind/TeaGo/types"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -61,12 +62,16 @@ func (this *IPListDAO) EnableIPList(tx *dbs.Tx, id int64) error {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// DisableIPList 禁用条目
 | 
					// DisableIPList 禁用条目
 | 
				
			||||||
func (this *IPListDAO) DisableIPList(tx *dbs.Tx, id int64) error {
 | 
					func (this *IPListDAO) DisableIPList(tx *dbs.Tx, listId int64) error {
 | 
				
			||||||
	_, err := this.Query(tx).
 | 
						_, err := this.Query(tx).
 | 
				
			||||||
		Pk(id).
 | 
							Pk(listId).
 | 
				
			||||||
		Set("state", IPListStateDisabled).
 | 
							Set("state", IPListStateDisabled).
 | 
				
			||||||
		Update()
 | 
							Update()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						return this.NotifyUpdate(tx, listId, NodeTaskTypeIPListDeleted+"@"+string(maps.Map{"listId": listId}.AsJSON()))
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// FindEnabledIPList 查找启用中的条目
 | 
					// FindEnabledIPList 查找启用中的条目
 | 
				
			||||||
@@ -258,11 +263,35 @@ func (this *IPListDAO) ExistsEnabledIPList(tx *dbs.Tx, listId int64) (bool, erro
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
// NotifyUpdate 通知更新
 | 
					// NotifyUpdate 通知更新
 | 
				
			||||||
func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskType) error {
 | 
					func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskType) error {
 | 
				
			||||||
 | 
						// WAF策略中的
 | 
				
			||||||
	httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId)
 | 
						httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	resultClusterIds := []int64{}
 | 
					
 | 
				
			||||||
 | 
						// 规则集动作中使用此名单的策略
 | 
				
			||||||
 | 
						ruleSetIds, err := SharedHTTPFirewallRuleSetDAO.FindAllEnabledRuleSetIdsWithIPListId(tx, listId)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						for _, ruleSetId := range ruleSetIds {
 | 
				
			||||||
 | 
							ruleGroupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, ruleSetId)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							if ruleGroupId > 0 {
 | 
				
			||||||
 | 
								policyId, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdWithRuleGroupId(tx, ruleGroupId)
 | 
				
			||||||
 | 
								if err != nil {
 | 
				
			||||||
 | 
									return err
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if policyId > 0 && !lists.ContainsInt64(httpFirewallPolicyIds, policyId) {
 | 
				
			||||||
 | 
									httpFirewallPolicyIds = append(httpFirewallPolicyIds, policyId)
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// 查找集群
 | 
				
			||||||
 | 
						var resultClusterIds = []int64{}
 | 
				
			||||||
	for _, policyId := range httpFirewallPolicyIds {
 | 
						for _, policyId := range httpFirewallPolicyIds {
 | 
				
			||||||
		// 集群
 | 
							// 集群
 | 
				
			||||||
		clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId)
 | 
							clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,6 +1,7 @@
 | 
				
			|||||||
package models
 | 
					package models
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
 | 
						"errors"
 | 
				
			||||||
	_ "github.com/go-sql-driver/mysql"
 | 
						_ "github.com/go-sql-driver/mysql"
 | 
				
			||||||
	"github.com/iwind/TeaGo/dbs"
 | 
						"github.com/iwind/TeaGo/dbs"
 | 
				
			||||||
	"runtime"
 | 
						"runtime"
 | 
				
			||||||
@@ -27,7 +28,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		err := NewIPListDAO().CheckUserIPList(tx, 1, 100)
 | 
							err := NewIPListDAO().CheckUserIPList(tx, 1, 100)
 | 
				
			||||||
		if err == ErrNotFound {
 | 
							if err != nil && errors.Is(err, ErrNotFound) {
 | 
				
			||||||
			t.Log("not found")
 | 
								t.Log("not found")
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			t.Log(err)
 | 
								t.Log(err)
 | 
				
			||||||
@@ -36,7 +37,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		err := NewIPListDAO().CheckUserIPList(tx, 1, 85)
 | 
							err := NewIPListDAO().CheckUserIPList(tx, 1, 85)
 | 
				
			||||||
		if err == ErrNotFound {
 | 
							if err != nil && errors.Is(err, ErrNotFound) {
 | 
				
			||||||
			t.Log("not found")
 | 
								t.Log("not found")
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			t.Log(err)
 | 
								t.Log(err)
 | 
				
			||||||
@@ -45,7 +46,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	{
 | 
						{
 | 
				
			||||||
		err := NewIPListDAO().CheckUserIPList(tx, 1, 17)
 | 
							err := NewIPListDAO().CheckUserIPList(tx, 1, 17)
 | 
				
			||||||
		if err == ErrNotFound {
 | 
							if err != nil && errors.Is(err, ErrNotFound) {
 | 
				
			||||||
			t.Log("not found")
 | 
								t.Log("not found")
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			t.Log(err)
 | 
								t.Log(err)
 | 
				
			||||||
@@ -53,6 +54,17 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestIPListDAO_NotifyUpdate(t *testing.T) {
 | 
				
			||||||
 | 
						dbs.NotifyReady()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var dao = NewIPListDAO()
 | 
				
			||||||
 | 
						var tx *dbs.Tx
 | 
				
			||||||
 | 
						err := dao.NotifyUpdate(tx, 104, NodeTaskTypeIPListDeleted)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							t.Fatal(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
 | 
					func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
 | 
				
			||||||
	runtime.GOMAXPROCS(1)
 | 
						runtime.GOMAXPROCS(1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -65,4 +77,3 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
 | 
				
			|||||||
		_, _ = dao.IncreaseVersion(tx)
 | 
							_, _ = dao.IncreaseVersion(tx)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,6 +19,7 @@ const (
 | 
				
			|||||||
	NodeTaskTypeConfigChanged             NodeTaskType = "configChanged"             // 节点整体配置变化
 | 
						NodeTaskTypeConfigChanged             NodeTaskType = "configChanged"             // 节点整体配置变化
 | 
				
			||||||
	NodeTaskTypeDDosProtectionChanged     NodeTaskType = "ddosProtectionChanged"     // 节点DDoS配置变更
 | 
						NodeTaskTypeDDosProtectionChanged     NodeTaskType = "ddosProtectionChanged"     // 节点DDoS配置变更
 | 
				
			||||||
	NodeTaskTypeGlobalServerConfigChanged NodeTaskType = "globalServerConfigChanged" // 全局服务设置变化
 | 
						NodeTaskTypeGlobalServerConfigChanged NodeTaskType = "globalServerConfigChanged" // 全局服务设置变化
 | 
				
			||||||
 | 
						NodeTaskTypeIPListDeleted             NodeTaskType = "ipListDeleted"
 | 
				
			||||||
	NodeTaskTypeIPItemChanged             NodeTaskType = "ipItemChanged"           // IP条目变更
 | 
						NodeTaskTypeIPItemChanged             NodeTaskType = "ipItemChanged"           // IP条目变更
 | 
				
			||||||
	NodeTaskTypeNodeVersionChanged        NodeTaskType = "nodeVersionChanged"      // 节点版本变化
 | 
						NodeTaskTypeNodeVersionChanged        NodeTaskType = "nodeVersionChanged"      // 节点版本变化
 | 
				
			||||||
	NodeTaskTypeScriptsChanged            NodeTaskType = "scriptsChanged"          // 脚本配置变化
 | 
						NodeTaskTypeScriptsChanged            NodeTaskType = "scriptsChanged"          // 脚本配置变化
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -268,7 +268,7 @@ func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
	var dao = models.NewServerDAO()
 | 
						var dao = models.NewServerDAO()
 | 
				
			||||||
	var tx *dbs.Tx
 | 
						var tx *dbs.Tx
 | 
				
			||||||
	err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day")
 | 
						err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day", "traffic")
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatal(err)
 | 
							t.Fatal(err)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user