mirror of
				https://github.com/TeaOSLab/EdgeAPI.git
				synced 2025-11-04 07:50:25 +08:00 
			
		
		
		
	优化删除IP名单时操作
This commit is contained in:
		@@ -92,7 +92,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
 | 
			
		||||
	if set == nil {
 | 
			
		||||
		return nil, nil
 | 
			
		||||
	}
 | 
			
		||||
	config := &firewallconfigs.HTTPFirewallRuleSet{}
 | 
			
		||||
	var config = &firewallconfigs.HTTPFirewallRuleSet{}
 | 
			
		||||
	config.Id = int64(set.Id)
 | 
			
		||||
	config.IsOn = set.IsOn
 | 
			
		||||
	config.Name = set.Name
 | 
			
		||||
@@ -102,7 +102,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
 | 
			
		||||
	config.IgnoreLocal = set.IgnoreLocal == 1
 | 
			
		||||
 | 
			
		||||
	if IsNotNull(set.Rules) {
 | 
			
		||||
		ruleRefs := []*firewallconfigs.HTTPFirewallRuleRef{}
 | 
			
		||||
		var ruleRefs = []*firewallconfigs.HTTPFirewallRuleRef{}
 | 
			
		||||
		err = json.Unmarshal(set.Rules, &ruleRefs)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
@@ -128,6 +128,22 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int
 | 
			
		||||
		config.Actions = actionConfigs
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查各个选项
 | 
			
		||||
	for _, actionConfig := range actionConfigs {
 | 
			
		||||
		if actionConfig.Code == firewallconfigs.HTTPFirewallActionRecordIP { // 记录IP动作
 | 
			
		||||
			if actionConfig.Options != nil {
 | 
			
		||||
				var ipListId = actionConfig.Options.GetInt64("ipListId")
 | 
			
		||||
				exists, err := SharedIPListDAO.ExistsEnabledIPList(tx, ipListId)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					return nil, err
 | 
			
		||||
				}
 | 
			
		||||
				if !exists {
 | 
			
		||||
					actionConfig.Options["ipListIsDeleted"] = true
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return config, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -212,6 +228,28 @@ func (this *HTTPFirewallRuleSetDAO) FindEnabledRuleSetIdWithRuleId(tx *dbs.Tx, r
 | 
			
		||||
		FindInt64Col(0)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FindAllEnabledRuleSetIdsWithIPListId 根据IP名单ID查找对应动作的WAF规则集
 | 
			
		||||
func (this *HTTPFirewallRuleSetDAO) FindAllEnabledRuleSetIdsWithIPListId(tx *dbs.Tx, ipListId int64) (setIds []int64, err error) {
 | 
			
		||||
	ones, err := this.Query(tx).
 | 
			
		||||
		State(HTTPFirewallRuleStateEnabled).
 | 
			
		||||
		Where("JSON_CONTAINS(actions, :jsonQuery)").
 | 
			
		||||
		Param("jsonQuery", maps.Map{
 | 
			
		||||
			"code": firewallconfigs.HTTPFirewallActionRecordIP,
 | 
			
		||||
			"options": maps.Map{
 | 
			
		||||
				"ipListId": ipListId,
 | 
			
		||||
			},
 | 
			
		||||
		}.AsJSON()).
 | 
			
		||||
		ResultPk().
 | 
			
		||||
		FindAll()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	for _, one := range ones {
 | 
			
		||||
		setIds = append(setIds, int64(one.(*HTTPFirewallRuleSet).Id))
 | 
			
		||||
	}
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CheckUserRuleSet 检查用户
 | 
			
		||||
func (this *HTTPFirewallRuleSetDAO) CheckUserRuleSet(tx *dbs.Tx, userId int64, setId int64) error {
 | 
			
		||||
	groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, setId)
 | 
			
		||||
 
 | 
			
		||||
@@ -11,6 +11,7 @@ import (
 | 
			
		||||
	"github.com/iwind/TeaGo/Tea"
 | 
			
		||||
	"github.com/iwind/TeaGo/dbs"
 | 
			
		||||
	"github.com/iwind/TeaGo/lists"
 | 
			
		||||
	"github.com/iwind/TeaGo/maps"
 | 
			
		||||
	"github.com/iwind/TeaGo/types"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -61,12 +62,16 @@ func (this *IPListDAO) EnableIPList(tx *dbs.Tx, id int64) error {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DisableIPList 禁用条目
 | 
			
		||||
func (this *IPListDAO) DisableIPList(tx *dbs.Tx, id int64) error {
 | 
			
		||||
func (this *IPListDAO) DisableIPList(tx *dbs.Tx, listId int64) error {
 | 
			
		||||
	_, err := this.Query(tx).
 | 
			
		||||
		Pk(id).
 | 
			
		||||
		Pk(listId).
 | 
			
		||||
		Set("state", IPListStateDisabled).
 | 
			
		||||
		Update()
 | 
			
		||||
	return err
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return this.NotifyUpdate(tx, listId, NodeTaskTypeIPListDeleted+"@"+string(maps.Map{"listId": listId}.AsJSON()))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// FindEnabledIPList 查找启用中的条目
 | 
			
		||||
@@ -258,11 +263,35 @@ func (this *IPListDAO) ExistsEnabledIPList(tx *dbs.Tx, listId int64) (bool, erro
 | 
			
		||||
 | 
			
		||||
// NotifyUpdate 通知更新
 | 
			
		||||
func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskType) error {
 | 
			
		||||
	// WAF策略中的
 | 
			
		||||
	httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	resultClusterIds := []int64{}
 | 
			
		||||
 | 
			
		||||
	// 规则集动作中使用此名单的策略
 | 
			
		||||
	ruleSetIds, err := SharedHTTPFirewallRuleSetDAO.FindAllEnabledRuleSetIdsWithIPListId(tx, listId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	for _, ruleSetId := range ruleSetIds {
 | 
			
		||||
		ruleGroupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, ruleSetId)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return err
 | 
			
		||||
		}
 | 
			
		||||
		if ruleGroupId > 0 {
 | 
			
		||||
			policyId, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdWithRuleGroupId(tx, ruleGroupId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				return err
 | 
			
		||||
			}
 | 
			
		||||
			if policyId > 0 && !lists.ContainsInt64(httpFirewallPolicyIds, policyId) {
 | 
			
		||||
				httpFirewallPolicyIds = append(httpFirewallPolicyIds, policyId)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 查找集群
 | 
			
		||||
	var resultClusterIds = []int64{}
 | 
			
		||||
	for _, policyId := range httpFirewallPolicyIds {
 | 
			
		||||
		// 集群
 | 
			
		||||
		clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,7 @@
 | 
			
		||||
package models
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	_ "github.com/go-sql-driver/mysql"
 | 
			
		||||
	"github.com/iwind/TeaGo/dbs"
 | 
			
		||||
	"runtime"
 | 
			
		||||
@@ -27,7 +28,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		err := NewIPListDAO().CheckUserIPList(tx, 1, 100)
 | 
			
		||||
		if err == ErrNotFound {
 | 
			
		||||
		if err != nil && errors.Is(err, ErrNotFound) {
 | 
			
		||||
			t.Log("not found")
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Log(err)
 | 
			
		||||
@@ -36,7 +37,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		err := NewIPListDAO().CheckUserIPList(tx, 1, 85)
 | 
			
		||||
		if err == ErrNotFound {
 | 
			
		||||
		if err != nil && errors.Is(err, ErrNotFound) {
 | 
			
		||||
			t.Log("not found")
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Log(err)
 | 
			
		||||
@@ -45,7 +46,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		err := NewIPListDAO().CheckUserIPList(tx, 1, 17)
 | 
			
		||||
		if err == ErrNotFound {
 | 
			
		||||
		if err != nil && errors.Is(err, ErrNotFound) {
 | 
			
		||||
			t.Log("not found")
 | 
			
		||||
		} else {
 | 
			
		||||
			t.Log(err)
 | 
			
		||||
@@ -53,6 +54,17 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestIPListDAO_NotifyUpdate(t *testing.T) {
 | 
			
		||||
	dbs.NotifyReady()
 | 
			
		||||
 | 
			
		||||
	var dao = NewIPListDAO()
 | 
			
		||||
	var tx *dbs.Tx
 | 
			
		||||
	err := dao.NotifyUpdate(tx, 104, NodeTaskTypeIPListDeleted)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
 | 
			
		||||
	runtime.GOMAXPROCS(1)
 | 
			
		||||
 | 
			
		||||
@@ -65,4 +77,3 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) {
 | 
			
		||||
		_, _ = dao.IncreaseVersion(tx)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -19,17 +19,18 @@ const (
 | 
			
		||||
	NodeTaskTypeConfigChanged             NodeTaskType = "configChanged"             // 节点整体配置变化
 | 
			
		||||
	NodeTaskTypeDDosProtectionChanged     NodeTaskType = "ddosProtectionChanged"     // 节点DDoS配置变更
 | 
			
		||||
	NodeTaskTypeGlobalServerConfigChanged NodeTaskType = "globalServerConfigChanged" // 全局服务设置变化
 | 
			
		||||
	NodeTaskTypeIPItemChanged             NodeTaskType = "ipItemChanged"             // IP条目变更
 | 
			
		||||
	NodeTaskTypeNodeVersionChanged        NodeTaskType = "nodeVersionChanged"        // 节点版本变化
 | 
			
		||||
	NodeTaskTypeScriptsChanged            NodeTaskType = "scriptsChanged"            // 脚本配置变化
 | 
			
		||||
	NodeTaskTypeNodeLevelChanged          NodeTaskType = "nodeLevelChanged"          // 节点级别变化
 | 
			
		||||
	NodeTaskTypeUserServersStateChanged   NodeTaskType = "userServersStateChanged"   // 用户服务状态变化
 | 
			
		||||
	NodeTaskTypeUAMPolicyChanged          NodeTaskType = "uamPolicyChanged"          // UAM策略变化
 | 
			
		||||
	NodeTaskTypeHTTPPagesPolicyChanged    NodeTaskType = "httpPagesPolicyChanged"    // 自定义页面变化
 | 
			
		||||
	NodeTaskTypeHTTPCCPolicyChanged       NodeTaskType = "httpCCPolicyChanged"       // CC策略变化
 | 
			
		||||
	NodeTaskTypeHTTP3PolicyChanged        NodeTaskType = "http3PolicyChanged"        // HTTP3策略变化
 | 
			
		||||
	NodeTaskTypeUpdatingServers           NodeTaskType = "updatingServers"           // 更新一组服务
 | 
			
		||||
	NodeTaskTypeTOAChanged                NodeTaskType = "toaChanged"                // TOA配置变化
 | 
			
		||||
	NodeTaskTypeIPListDeleted             NodeTaskType = "ipListDeleted"
 | 
			
		||||
	NodeTaskTypeIPItemChanged             NodeTaskType = "ipItemChanged"           // IP条目变更
 | 
			
		||||
	NodeTaskTypeNodeVersionChanged        NodeTaskType = "nodeVersionChanged"      // 节点版本变化
 | 
			
		||||
	NodeTaskTypeScriptsChanged            NodeTaskType = "scriptsChanged"          // 脚本配置变化
 | 
			
		||||
	NodeTaskTypeNodeLevelChanged          NodeTaskType = "nodeLevelChanged"        // 节点级别变化
 | 
			
		||||
	NodeTaskTypeUserServersStateChanged   NodeTaskType = "userServersStateChanged" // 用户服务状态变化
 | 
			
		||||
	NodeTaskTypeUAMPolicyChanged          NodeTaskType = "uamPolicyChanged"        // UAM策略变化
 | 
			
		||||
	NodeTaskTypeHTTPPagesPolicyChanged    NodeTaskType = "httpPagesPolicyChanged"  // 自定义页面变化
 | 
			
		||||
	NodeTaskTypeHTTPCCPolicyChanged       NodeTaskType = "httpCCPolicyChanged"     // CC策略变化
 | 
			
		||||
	NodeTaskTypeHTTP3PolicyChanged        NodeTaskType = "http3PolicyChanged"      // HTTP3策略变化
 | 
			
		||||
	NodeTaskTypeUpdatingServers           NodeTaskType = "updatingServers"         // 更新一组服务
 | 
			
		||||
	NodeTaskTypeTOAChanged                NodeTaskType = "toaChanged"              // TOA配置变化
 | 
			
		||||
 | 
			
		||||
	// NS相关
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -268,7 +268,7 @@ func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) {
 | 
			
		||||
 | 
			
		||||
	var dao = models.NewServerDAO()
 | 
			
		||||
	var tx *dbs.Tx
 | 
			
		||||
	err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day")
 | 
			
		||||
	err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day", "traffic")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user