diff --git a/internal/db/models/http_firewall_rule_set_dao.go b/internal/db/models/http_firewall_rule_set_dao.go index 44429141..38719365 100644 --- a/internal/db/models/http_firewall_rule_set_dao.go +++ b/internal/db/models/http_firewall_rule_set_dao.go @@ -92,7 +92,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int if set == nil { return nil, nil } - config := &firewallconfigs.HTTPFirewallRuleSet{} + var config = &firewallconfigs.HTTPFirewallRuleSet{} config.Id = int64(set.Id) config.IsOn = set.IsOn config.Name = set.Name @@ -102,7 +102,7 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int config.IgnoreLocal = set.IgnoreLocal == 1 if IsNotNull(set.Rules) { - ruleRefs := []*firewallconfigs.HTTPFirewallRuleRef{} + var ruleRefs = []*firewallconfigs.HTTPFirewallRuleRef{} err = json.Unmarshal(set.Rules, &ruleRefs) if err != nil { return nil, err @@ -128,6 +128,22 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(tx *dbs.Tx, setId int config.Actions = actionConfigs } + // 检查各个选项 + for _, actionConfig := range actionConfigs { + if actionConfig.Code == firewallconfigs.HTTPFirewallActionRecordIP { // 记录IP动作 + if actionConfig.Options != nil { + var ipListId = actionConfig.Options.GetInt64("ipListId") + exists, err := SharedIPListDAO.ExistsEnabledIPList(tx, ipListId) + if err != nil { + return nil, err + } + if !exists { + actionConfig.Options["ipListIsDeleted"] = true + } + } + } + } + return config, nil } @@ -212,6 +228,28 @@ func (this *HTTPFirewallRuleSetDAO) FindEnabledRuleSetIdWithRuleId(tx *dbs.Tx, r FindInt64Col(0) } +// FindAllEnabledRuleSetIdsWithIPListId 根据IP名单ID查找对应动作的WAF规则集 +func (this *HTTPFirewallRuleSetDAO) FindAllEnabledRuleSetIdsWithIPListId(tx *dbs.Tx, ipListId int64) (setIds []int64, err error) { + ones, err := this.Query(tx). + State(HTTPFirewallRuleStateEnabled). + Where("JSON_CONTAINS(actions, :jsonQuery)"). + Param("jsonQuery", maps.Map{ + "code": firewallconfigs.HTTPFirewallActionRecordIP, + "options": maps.Map{ + "ipListId": ipListId, + }, + }.AsJSON()). + ResultPk(). + FindAll() + if err != nil { + return nil, err + } + for _, one := range ones { + setIds = append(setIds, int64(one.(*HTTPFirewallRuleSet).Id)) + } + return +} + // CheckUserRuleSet 检查用户 func (this *HTTPFirewallRuleSetDAO) CheckUserRuleSet(tx *dbs.Tx, userId int64, setId int64) error { groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, setId) diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go index d2093204..41e19638 100644 --- a/internal/db/models/ip_list_dao.go +++ b/internal/db/models/ip_list_dao.go @@ -11,6 +11,7 @@ import ( "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" ) @@ -61,12 +62,16 @@ func (this *IPListDAO) EnableIPList(tx *dbs.Tx, id int64) error { } // DisableIPList 禁用条目 -func (this *IPListDAO) DisableIPList(tx *dbs.Tx, id int64) error { +func (this *IPListDAO) DisableIPList(tx *dbs.Tx, listId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(listId). Set("state", IPListStateDisabled). Update() - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, listId, NodeTaskTypeIPListDeleted+"@"+string(maps.Map{"listId": listId}.AsJSON())) } // FindEnabledIPList 查找启用中的条目 @@ -258,11 +263,35 @@ func (this *IPListDAO) ExistsEnabledIPList(tx *dbs.Tx, listId int64) (bool, erro // NotifyUpdate 通知更新 func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskType) error { + // WAF策略中的 httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId) if err != nil { return err } - resultClusterIds := []int64{} + + // 规则集动作中使用此名单的策略 + ruleSetIds, err := SharedHTTPFirewallRuleSetDAO.FindAllEnabledRuleSetIdsWithIPListId(tx, listId) + if err != nil { + return err + } + for _, ruleSetId := range ruleSetIds { + ruleGroupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, ruleSetId) + if err != nil { + return err + } + if ruleGroupId > 0 { + policyId, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdWithRuleGroupId(tx, ruleGroupId) + if err != nil { + return err + } + if policyId > 0 && !lists.ContainsInt64(httpFirewallPolicyIds, policyId) { + httpFirewallPolicyIds = append(httpFirewallPolicyIds, policyId) + } + } + } + + // 查找集群 + var resultClusterIds = []int64{} for _, policyId := range httpFirewallPolicyIds { // 集群 clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId) diff --git a/internal/db/models/ip_list_dao_test.go b/internal/db/models/ip_list_dao_test.go index 68abbc1a..159ba72c 100644 --- a/internal/db/models/ip_list_dao_test.go +++ b/internal/db/models/ip_list_dao_test.go @@ -1,6 +1,7 @@ package models import ( + "errors" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/dbs" "runtime" @@ -27,7 +28,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) { { err := NewIPListDAO().CheckUserIPList(tx, 1, 100) - if err == ErrNotFound { + if err != nil && errors.Is(err, ErrNotFound) { t.Log("not found") } else { t.Log(err) @@ -36,7 +37,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) { { err := NewIPListDAO().CheckUserIPList(tx, 1, 85) - if err == ErrNotFound { + if err != nil && errors.Is(err, ErrNotFound) { t.Log("not found") } else { t.Log(err) @@ -45,7 +46,7 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) { { err := NewIPListDAO().CheckUserIPList(tx, 1, 17) - if err == ErrNotFound { + if err != nil && errors.Is(err, ErrNotFound) { t.Log("not found") } else { t.Log(err) @@ -53,6 +54,17 @@ func TestIPListDAO_CheckUserIPList(t *testing.T) { } } +func TestIPListDAO_NotifyUpdate(t *testing.T) { + dbs.NotifyReady() + + var dao = NewIPListDAO() + var tx *dbs.Tx + err := dao.NotifyUpdate(tx, 104, NodeTaskTypeIPListDeleted) + if err != nil { + t.Fatal(err) + } +} + func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) { runtime.GOMAXPROCS(1) @@ -65,4 +77,3 @@ func BenchmarkIPListDAO_IncreaseVersion(b *testing.B) { _, _ = dao.IncreaseVersion(tx) } } - diff --git a/internal/db/models/node_task_dao.go b/internal/db/models/node_task_dao.go index 43c5036d..863f8a08 100644 --- a/internal/db/models/node_task_dao.go +++ b/internal/db/models/node_task_dao.go @@ -19,17 +19,18 @@ const ( NodeTaskTypeConfigChanged NodeTaskType = "configChanged" // 节点整体配置变化 NodeTaskTypeDDosProtectionChanged NodeTaskType = "ddosProtectionChanged" // 节点DDoS配置变更 NodeTaskTypeGlobalServerConfigChanged NodeTaskType = "globalServerConfigChanged" // 全局服务设置变化 - NodeTaskTypeIPItemChanged NodeTaskType = "ipItemChanged" // IP条目变更 - NodeTaskTypeNodeVersionChanged NodeTaskType = "nodeVersionChanged" // 节点版本变化 - NodeTaskTypeScriptsChanged NodeTaskType = "scriptsChanged" // 脚本配置变化 - NodeTaskTypeNodeLevelChanged NodeTaskType = "nodeLevelChanged" // 节点级别变化 - NodeTaskTypeUserServersStateChanged NodeTaskType = "userServersStateChanged" // 用户服务状态变化 - NodeTaskTypeUAMPolicyChanged NodeTaskType = "uamPolicyChanged" // UAM策略变化 - NodeTaskTypeHTTPPagesPolicyChanged NodeTaskType = "httpPagesPolicyChanged" // 自定义页面变化 - NodeTaskTypeHTTPCCPolicyChanged NodeTaskType = "httpCCPolicyChanged" // CC策略变化 - NodeTaskTypeHTTP3PolicyChanged NodeTaskType = "http3PolicyChanged" // HTTP3策略变化 - NodeTaskTypeUpdatingServers NodeTaskType = "updatingServers" // 更新一组服务 - NodeTaskTypeTOAChanged NodeTaskType = "toaChanged" // TOA配置变化 + NodeTaskTypeIPListDeleted NodeTaskType = "ipListDeleted" + NodeTaskTypeIPItemChanged NodeTaskType = "ipItemChanged" // IP条目变更 + NodeTaskTypeNodeVersionChanged NodeTaskType = "nodeVersionChanged" // 节点版本变化 + NodeTaskTypeScriptsChanged NodeTaskType = "scriptsChanged" // 脚本配置变化 + NodeTaskTypeNodeLevelChanged NodeTaskType = "nodeLevelChanged" // 节点级别变化 + NodeTaskTypeUserServersStateChanged NodeTaskType = "userServersStateChanged" // 用户服务状态变化 + NodeTaskTypeUAMPolicyChanged NodeTaskType = "uamPolicyChanged" // UAM策略变化 + NodeTaskTypeHTTPPagesPolicyChanged NodeTaskType = "httpPagesPolicyChanged" // 自定义页面变化 + NodeTaskTypeHTTPCCPolicyChanged NodeTaskType = "httpCCPolicyChanged" // CC策略变化 + NodeTaskTypeHTTP3PolicyChanged NodeTaskType = "http3PolicyChanged" // HTTP3策略变化 + NodeTaskTypeUpdatingServers NodeTaskType = "updatingServers" // 更新一组服务 + NodeTaskTypeTOAChanged NodeTaskType = "toaChanged" // TOA配置变化 // NS相关 diff --git a/internal/db/models/server_dao_test.go b/internal/db/models/server_dao_test.go index 889f451f..52368620 100644 --- a/internal/db/models/server_dao_test.go +++ b/internal/db/models/server_dao_test.go @@ -268,7 +268,7 @@ func TestServerDAO_UpdateServerTrafficLimitStatus(t *testing.T) { var dao = models.NewServerDAO() var tx *dbs.Tx - err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day") + err := dao.UpdateServerTrafficLimitStatus(tx, 23, timeutil.Format("Ymd", time.Now().AddDate(0, 0, 20)), 14, "day", "traffic") if err != nil { t.Fatal(err) }