From 747cdac7cf8d9909f01213c0f6428afbc5042b18 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Sun, 17 Jan 2021 16:48:00 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E8=8A=82=E7=82=B9=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E7=8A=B6=E6=80=81=E6=8F=90=E7=A4=BA=E5=92=8C=E4=BB=BB?= =?UTF-8?q?=E5=8A=A1=E5=88=97=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- go.mod | 5 +- go.sum | 5 - .../db/models/http_access_log_policy_dao.go | 11 +- internal/db/models/http_cache_policy_dao.go | 51 +++- .../db/models/http_firewall_policy_dao.go | 67 ++++- .../models/http_firewall_policy_dao_test.go | 33 +++ internal/db/models/http_firewall_rule_dao.go | 41 ++- .../db/models/http_firewall_rule_group_dao.go | 49 ++- .../http_firewall_rule_group_dao_test.go | 13 + .../db/models/http_firewall_rule_set_dao.go | 56 +++- .../models/http_firewall_rule_set_dao_test.go | 16 + internal/db/models/http_gzip_dao.go | 37 ++- internal/db/models/http_header_dao.go | 23 +- internal/db/models/http_header_policy_dao.go | 72 +++-- .../db/models/http_header_policy_dao_test.go | 12 + internal/db/models/http_location_dao.go | 47 ++- internal/db/models/http_page_dao.go | 38 ++- internal/db/models/http_rewrite_rule_dao.go | 37 ++- internal/db/models/http_web_dao.go | 178 ++++++++--- internal/db/models/http_web_dao_test.go | 58 +++- internal/db/models/ip_item_dao.go | 66 +++++ internal/db/models/ip_item_dao_test.go | 13 + internal/db/models/ip_list_dao.go | 56 ++++ internal/db/models/node_cluster_dao.go | 73 ++++- internal/db/models/node_dao.go | 113 ++++--- internal/db/models/node_dao_test.go | 2 +- internal/db/models/node_model_ext.go | 12 + internal/db/models/node_task_dao.go | 278 ++++++++++++++++++ internal/db/models/node_task_dao_test.go | 40 +++ internal/db/models/node_task_model.go | 32 ++ internal/db/models/node_task_model_ext.go | 1 + internal/db/models/origin_dao.go | 39 ++- internal/db/models/reverse_proxy_dao.go | 61 ++-- internal/db/models/reverse_proxy_dao_test.go | 11 + internal/db/models/server_dao.go | 201 ++++++++++--- internal/db/models/server_dao_test.go | 10 + internal/db/models/ssl_cert_dao.go | 51 +++- internal/db/models/ssl_policy_dao.go | 40 ++- internal/db/models/sys_event_types.go | 48 +-- internal/db/models/tcp_firewall_policy_dao.go | 11 +- internal/nodes/api_node.go | 1 + .../rpc/services/service_http_access_log.go | 4 +- internal/rpc/services/service_ip_item.go | 20 ++ internal/rpc/services/service_node.go | 52 ++-- internal/rpc/services/service_node_cluster.go | 68 +---- internal/rpc/services/service_node_task.go | 243 +++++++++++++++ internal/rpc/services/service_server.go | 71 ++--- internal/rpc/services/service_user.go | 22 +- internal/tasks/node_task_extractor.go | 51 ++++ 49 files changed, 1959 insertions(+), 580 deletions(-) create mode 100644 internal/db/models/node_task_dao.go create mode 100644 internal/db/models/node_task_dao_test.go create mode 100644 internal/db/models/node_task_model.go create mode 100644 internal/db/models/node_task_model_ext.go create mode 100644 internal/rpc/services/service_node_task.go create mode 100644 internal/tasks/node_task_extractor.go diff --git a/go.mod b/go.mod index 448a47e7..d09705f1 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.15 replace github.com/TeaOSLab/EdgeCommon => ../EdgeCommon +replace github.com/iwind/TeaGo => /Users/WorkSpace/TeaGo + require ( github.com/StackExchange/wmi v0.0.0-20190523213315-cbe66965904d // indirect github.com/TeaOSLab/EdgeCommon v0.0.0-00010101000000-000000000000 @@ -13,12 +15,13 @@ require ( github.com/go-sql-driver/mysql v1.5.0 github.com/go-yaml/yaml v2.1.0+incompatible github.com/golang/protobuf v1.4.2 - github.com/iwind/TeaGo v0.0.0-20210106152225-413a5aba30aa // indirect + github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e github.com/lionsoul2014/ip2region v2.2.0-release+incompatible github.com/mozillazg/go-pinyin v0.18.0 github.com/pkg/sftp v1.12.0 github.com/shirou/gopsutil v2.20.9+incompatible golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a + golang.org/x/sys v0.0.0-20200519105757-fe76b779f299 google.golang.org/grpc v1.32.0 google.golang.org/protobuf v1.25.0 gopkg.in/yaml.v2 v2.4.0 // indirect diff --git a/go.sum b/go.sum index 5fddb411..358d441c 100644 --- a/go.sum +++ b/go.sum @@ -173,11 +173,6 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/iij/doapi v0.0.0-20190504054126-0bbf12d6d7df/go.mod h1:QMZY7/J/KSQEhKWFeDesPjMj+wCHReeknARU3wqlyN4= -github.com/iwind/TeaGo v0.0.0-20200923021120-f5d76441fe9e/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc= -github.com/iwind/TeaGo v0.0.0-20210103021650-62acfa30bcea h1:ACgcrVeyHpKt8K6k92RrmPndxVq7Qx+zNXZBzRKv3+0= -github.com/iwind/TeaGo v0.0.0-20210103021650-62acfa30bcea/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc= -github.com/iwind/TeaGo v0.0.0-20210106152225-413a5aba30aa h1:kGf94NRhZik+SPBbY+NgQhwIjJjqv1E1nqgLQ0BtOzw= -github.com/iwind/TeaGo v0.0.0-20210106152225-413a5aba30aa/go.mod h1:KU4mS7QNiZ7QWEuDBk1zw0/Q2LrAPZv3tycEFBsuUwc= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.3.0 h1:OS12ieG61fsCg5+qLJ+SsW9NicxNkg3b25OyT2yCeUc= github.com/jmespath/go-jmespath v0.3.0/go.mod h1:9QtRXoHjLGCJ5IBSaohpXITPlowMeeYCZ7fLUTSywik= diff --git a/internal/db/models/http_access_log_policy_dao.go b/internal/db/models/http_access_log_policy_dao.go index 4adde44b..f056c79e 100644 --- a/internal/db/models/http_access_log_policy_dao.go +++ b/internal/db/models/http_access_log_policy_dao.go @@ -37,16 +37,7 @@ func init() { // 初始化 func (this *HTTPAccessLogPolicyDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 diff --git a/internal/db/models/http_cache_policy_dao.go b/internal/db/models/http_cache_policy_dao.go index fedb1dfc..bf7b5775 100644 --- a/internal/db/models/http_cache_policy_dao.go +++ b/internal/db/models/http_cache_policy_dao.go @@ -39,16 +39,7 @@ func init() { // 初始化 func (this *HTTPCachePolicyDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -61,12 +52,15 @@ func (this *HTTPCachePolicyDAO) EnableHTTPCachePolicy(tx *dbs.Tx, id int64) erro } // 禁用条目 -func (this *HTTPCachePolicyDAO) DisableHTTPCachePolicy(tx *dbs.Tx, id int64) error { +func (this *HTTPCachePolicyDAO) DisableHTTPCachePolicy(tx *dbs.Tx, policyId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(policyId). Set("state", HTTPCachePolicyStateDisabled). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } // 查找启用中的条目 @@ -147,7 +141,10 @@ func (this *HTTPCachePolicyDAO) UpdateCachePolicy(tx *dbs.Tx, policyId int64, is op.Options = storageOptionsJSON } err := this.Save(tx, op) - return errors.Wrap(err) + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } // 组合配置 @@ -239,3 +236,29 @@ func (this *HTTPCachePolicyDAO) ListEnabledHTTPCachePolicies(tx *dbs.Tx, offset } return cachePolicies, nil } + +// 通知更新 +func (this *HTTPCachePolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error { + webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithCachePolicyId(tx, policyId) + if err != nil { + return err + } + for _, webId := range webIds { + err := SharedHTTPWebDAO.NotifyUpdate(tx, webId) + if err != nil { + return err + } + } + + clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithCachePolicyId(tx, policyId) + if err != nil { + return err + } + for _, clusterId := range clusterIds { + err := SharedNodeClusterDAO.NotifyUpdate(tx, clusterId) + if err != nil { + return err + } + } + return nil +} diff --git a/internal/db/models/http_firewall_policy_dao.go b/internal/db/models/http_firewall_policy_dao.go index 20287103..4031ccdb 100644 --- a/internal/db/models/http_firewall_policy_dao.go +++ b/internal/db/models/http_firewall_policy_dao.go @@ -7,6 +7,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" ) @@ -38,16 +39,7 @@ func init() { // 初始化 func (this *HTTPFirewallPolicyDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -300,3 +292,58 @@ func (this *HTTPFirewallPolicyDAO) CheckUserFirewallPolicy(tx *dbs.Tx, userId in } return nil } + +// 查找包含某个IPList的所有策略 +func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdsWithIPListId(tx *dbs.Tx, ipListId int64) ([]int64, error) { + ones, err := this.Query(tx). + ResultPk(). + State(HTTPFirewallPolicyStateEnabled). + Where("(JSON_CONTAINS(inbound, :listQuery, '$.whiteListRef') OR JSON_CONTAINS(inbound, :listQuery, '$.blackListRef') )"). + Param("listQuery", maps.Map{"isOn": true, "listId": ipListId}.AsJSON()). + FindAll() + if err != nil { + return nil, err + } + result := []int64{} + for _, one := range ones { + result = append(result, int64(one.(*HTTPFirewallPolicy).Id)) + } + return result, nil +} + +// 查找包含某个规则分组的策略ID +func (this *HTTPFirewallPolicyDAO) FindEnabledFirewallPolicyIdWithRuleGroupId(tx *dbs.Tx, ruleGroupId int64) (int64, error) { + return this.Query(tx). + ResultPk(). + State(HTTPFirewallPolicyStateEnabled). + Where("(JSON_CONTAINS(inbound, :jsonQuery, '$.groupRefs') OR JSON_CONTAINS(outbound, :jsonQuery, '$.groupRefs'))"). + Param("jsonQuery", maps.Map{"groupId": ruleGroupId}.AsJSON()). + FindInt64Col(0) +} + +// 通知更新 +func (this *HTTPFirewallPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error { + webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, policyId) + if err != nil { + return err + } + for _, webId := range webIds { + err := SharedHTTPWebDAO.NotifyUpdate(tx, webId) + if err != nil { + return err + } + } + + clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId) + if err != nil { + return err + } + for _, clusterId := range clusterIds { + err := SharedNodeClusterDAO.NotifyUpdate(tx, clusterId) + if err != nil { + return err + } + } + + return nil +} diff --git a/internal/db/models/http_firewall_policy_dao_test.go b/internal/db/models/http_firewall_policy_dao_test.go index 97c24b56..5d78bd2a 100644 --- a/internal/db/models/http_firewall_policy_dao_test.go +++ b/internal/db/models/http_firewall_policy_dao_test.go @@ -2,4 +2,37 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" + "testing" ) + +func TestHTTPFirewallPolicyDAO_FindFirewallPolicyIdsContainsIPList(t *testing.T) { + dbs.NotifyReady() + + { + policyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(nil, 8) + if err != nil { + t.Fatal(err) + } + t.Log(policyIds) + } + + { + policyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(nil, 18) + if err != nil { + t.Fatal(err) + } + t.Log(policyIds) + } +} + +func TestHTTPFirewallPolicyDAO_FindEnabledFirewallPolicyIdWithRuleGroupId(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + policyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdWithRuleGroupId(tx, 160) + if err != nil { + t.Fatal(err) + } + t.Log("policyIds:", policyIds) +} diff --git a/internal/db/models/http_firewall_rule_dao.go b/internal/db/models/http_firewall_rule_dao.go index b1b60bde..d3570b46 100644 --- a/internal/db/models/http_firewall_rule_dao.go +++ b/internal/db/models/http_firewall_rule_dao.go @@ -37,16 +37,7 @@ func init() { // 初始化 func (this *HTTPFirewallRuleDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -59,12 +50,15 @@ func (this *HTTPFirewallRuleDAO) EnableHTTPFirewallRule(tx *dbs.Tx, id int64) er } // 禁用条目 -func (this *HTTPFirewallRuleDAO) DisableHTTPFirewallRule(tx *dbs.Tx, id int64) error { +func (this *HTTPFirewallRuleDAO) DisableHTTPFirewallRule(tx *dbs.Tx, ruleId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(ruleId). Set("state", HTTPFirewallRuleStateDisabled). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, ruleId) } // 查找启用中的条目 @@ -154,5 +148,26 @@ func (this *HTTPFirewallRuleDAO) CreateOrUpdateRuleFromConfig(tx *dbs.Tx, ruleCo if err != nil { return 0, err } + + // 通知更新 + if ruleConfig.Id > 0 { + err := this.NotifyUpdate(tx, ruleConfig.Id) + if err != nil { + return 0, err + } + } + return types.Int64(op.Id), nil } + +// 通知更新 +func (this *HTTPFirewallRuleDAO) NotifyUpdate(tx *dbs.Tx, ruleId int64) error { + setId, err := SharedHTTPFirewallRuleSetDAO.FindEnabledRuleSetIdWithRuleId(tx, ruleId) + if err != nil { + return err + } + if setId > 0 { + return SharedHTTPFirewallRuleSetDAO.NotifyUpdate(tx, setId) + } + return nil +} diff --git a/internal/db/models/http_firewall_rule_group_dao.go b/internal/db/models/http_firewall_rule_group_dao.go index b2c9cdc0..b04afe9a 100644 --- a/internal/db/models/http_firewall_rule_group_dao.go +++ b/internal/db/models/http_firewall_rule_group_dao.go @@ -7,6 +7,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" ) @@ -38,16 +39,7 @@ func init() { // 初始化 func (this *HTTPFirewallRuleGroupDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -164,7 +156,10 @@ func (this *HTTPFirewallRuleGroupDAO) UpdateGroupIsOn(tx *dbs.Tx, groupId int64, Pk(groupId). Set("isOn", isOn). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, groupId) } // 创建分组 @@ -192,7 +187,10 @@ func (this *HTTPFirewallRuleGroupDAO) UpdateGroup(tx *dbs.Tx, groupId int64, isO op.Name = name op.Description = description err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, groupId) } // 修改分组中的规则集 @@ -204,5 +202,30 @@ func (this *HTTPFirewallRuleGroupDAO) UpdateGroupSets(tx *dbs.Tx, groupId int64, op.Id = groupId op.Sets = setsJSON err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, groupId) +} + +// 根据规则集查找规则分组 +func (this *HTTPFirewallRuleGroupDAO) FindRuleGroupIdWithRuleSetId(tx *dbs.Tx, setId int64) (int64, error) { + return this.Query(tx). + State(HTTPFirewallRuleStateEnabled). + Where("JSON_CONTAINS(sets, :jsonQuery)"). + Param("jsonQuery", maps.Map{"setId": setId}.AsJSON()). + ResultPk(). + FindInt64Col(0) +} + +// 通知更新 +func (this *HTTPFirewallRuleGroupDAO) NotifyUpdate(tx *dbs.Tx, groupId int64) error { + policyId, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdWithRuleGroupId(tx, groupId) + if err != nil { + return err + } + if policyId > 0 { + return SharedHTTPFirewallPolicyDAO.NotifyUpdate(tx, policyId) + } + return nil } diff --git a/internal/db/models/http_firewall_rule_group_dao_test.go b/internal/db/models/http_firewall_rule_group_dao_test.go index 97c24b56..16e8dfe7 100644 --- a/internal/db/models/http_firewall_rule_group_dao_test.go +++ b/internal/db/models/http_firewall_rule_group_dao_test.go @@ -2,4 +2,17 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" + "testing" ) + +func TestHTTPFirewallRuleGroupDAO_FindRuleGroupIdWithRuleSetId(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, 22) + if err != nil { + t.Fatal(err) + } + t.Log("groupId:", groupId) +} \ No newline at end of file diff --git a/internal/db/models/http_firewall_rule_set_dao.go b/internal/db/models/http_firewall_rule_set_dao.go index d6589796..f00bf663 100644 --- a/internal/db/models/http_firewall_rule_set_dao.go +++ b/internal/db/models/http_firewall_rule_set_dao.go @@ -39,16 +39,7 @@ func init() { // 初始化 func (this *HTTPFirewallRuleSetDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -61,12 +52,15 @@ func (this *HTTPFirewallRuleSetDAO) EnableHTTPFirewallRuleSet(tx *dbs.Tx, id int } // 禁用条目 -func (this *HTTPFirewallRuleSetDAO) DisableHTTPFirewallRuleSet(tx *dbs.Tx, id int64) error { +func (this *HTTPFirewallRuleSetDAO) DisableHTTPFirewallRuleSet(tx *dbs.Tx, ruleSetId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(ruleSetId). Set("state", HTTPFirewallRuleSetStateDisabled). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, ruleSetId) } // 查找启用中的条目 @@ -180,6 +174,15 @@ func (this *HTTPFirewallRuleSetDAO) CreateOrUpdateSetFromConfig(tx *dbs.Tx, setC if err != nil { return 0, err } + + // 通知更新 + if setConfig.Id > 0 { + err := this.NotifyUpdate(tx, setConfig.Id) + if err != nil { + return 0, err + } + } + return types.Int64(op.Id), nil } @@ -192,5 +195,30 @@ func (this *HTTPFirewallRuleSetDAO) UpdateRuleSetIsOn(tx *dbs.Tx, ruleSetId int6 Pk(ruleSetId). Set("isOn", isOn). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, ruleSetId) +} + +// 根据规则查找规则集 +func (this *HTTPFirewallRuleSetDAO) FindEnabledRuleSetIdWithRuleId(tx *dbs.Tx, ruleId int64) (int64, error) { + return this.Query(tx). + State(HTTPFirewallRuleStateEnabled). + Where("JSON_CONTAINS(rules, :jsonQuery)"). + Param("jsonQuery", maps.Map{"ruleId": ruleId}.AsJSON()). + ResultPk(). + FindInt64Col(0) +} + +// 通知更新 +func (this *HTTPFirewallRuleSetDAO) NotifyUpdate(tx *dbs.Tx, setId int64) error { + groupId, err := SharedHTTPFirewallRuleGroupDAO.FindRuleGroupIdWithRuleSetId(tx, setId) + if err != nil { + return err + } + if groupId > 0 { + return SharedHTTPFirewallRuleGroupDAO.NotifyUpdate(tx, groupId) + } + return nil } diff --git a/internal/db/models/http_firewall_rule_set_dao_test.go b/internal/db/models/http_firewall_rule_set_dao_test.go index 97c24b56..f2a7dee7 100644 --- a/internal/db/models/http_firewall_rule_set_dao_test.go +++ b/internal/db/models/http_firewall_rule_set_dao_test.go @@ -2,4 +2,20 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" + "testing" + "time" ) + +func TestHTTPFirewallRuleSetDAO_FindRuleSetIdWithRuleId(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + before := time.Now() + setId, err := SharedHTTPFirewallRuleSetDAO.FindEnabledRuleSetIdWithRuleId(tx, 20) + if err != nil { + t.Fatal(err) + } + t.Log("setId:", setId) + t.Log(time.Since(before).Seconds()*1000, "ms") +} diff --git a/internal/db/models/http_gzip_dao.go b/internal/db/models/http_gzip_dao.go index 07a9c1e8..f2a68c8a 100644 --- a/internal/db/models/http_gzip_dao.go +++ b/internal/db/models/http_gzip_dao.go @@ -39,16 +39,7 @@ func init() { // 初始化 func (this *HTTPGzipDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -61,12 +52,15 @@ func (this *HTTPGzipDAO) EnableHTTPGzip(tx *dbs.Tx, id int64) error { } // 禁用条目 -func (this *HTTPGzipDAO) DisableHTTPGzip(tx *dbs.Tx, id int64) error { +func (this *HTTPGzipDAO) DisableHTTPGzip(tx *dbs.Tx, gzipId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(gzipId). Set("state", HTTPGzipStateDisabled). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, gzipId) } // 查找启用中的条目 @@ -165,5 +159,20 @@ func (this *HTTPGzipDAO) UpdateGzip(tx *dbs.Tx, gzipId int64, level int, minLeng op.Conds = JSONBytes(condsJSON) } err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, gzipId) +} + +// 通知更新 +func (this *HTTPGzipDAO) NotifyUpdate(tx *dbs.Tx, gzipId int64) error { + webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithGzipId(tx, gzipId) + if err != nil { + return err + } + if webId > 0 { + return SharedHTTPWebDAO.NotifyUpdate(tx, webId) + } + return nil } diff --git a/internal/db/models/http_header_dao.go b/internal/db/models/http_header_dao.go index c6deb92a..dd24ca45 100644 --- a/internal/db/models/http_header_dao.go +++ b/internal/db/models/http_header_dao.go @@ -38,16 +38,7 @@ func init() { // 初始化 func (this *HTTPHeaderDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -156,3 +147,15 @@ func (this *HTTPHeaderDAO) ComposeHeaderConfig(tx *dbs.Tx, headerId int64) (*sha return config, nil } + +// 通知更新 +func (this *HTTPHeaderDAO) NotifyUpdate(tx *dbs.Tx, headerId int64) error { + policyId, err := SharedHTTPHeaderPolicyDAO.FindHeaderPolicyIdWithHeaderId(tx, headerId) + if err != nil { + return err + } + if policyId > 0 { + return SharedHTTPHeaderPolicyDAO.NotifyUpdate(tx, policyId) + } + return nil +} diff --git a/internal/db/models/http_header_policy_dao.go b/internal/db/models/http_header_policy_dao.go index c4f95c6b..8b24e940 100644 --- a/internal/db/models/http_header_policy_dao.go +++ b/internal/db/models/http_header_policy_dao.go @@ -7,6 +7,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" ) @@ -38,16 +39,7 @@ func init() { // 初始化 func (this *HTTPHeaderPolicyDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -60,12 +52,15 @@ func (this *HTTPHeaderPolicyDAO) EnableHTTPHeaderPolicy(tx *dbs.Tx, id int64) er } // 禁用条目 -func (this *HTTPHeaderPolicyDAO) DisableHTTPHeaderPolicy(tx *dbs.Tx, id int64) error { +func (this *HTTPHeaderPolicyDAO) DisableHTTPHeaderPolicy(tx *dbs.Tx, policyId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(policyId). Set("state", HTTPHeaderPolicyStateDisabled). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } // 查找启用中的条目 @@ -102,8 +97,10 @@ func (this *HTTPHeaderPolicyDAO) UpdateAddingHeaders(tx *dbs.Tx, policyId int64, op.Id = policyId op.AddHeaders = headersJSON err := this.Save(tx, op) - - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } // 修改SetHeaders @@ -116,8 +113,10 @@ func (this *HTTPHeaderPolicyDAO) UpdateSettingHeaders(tx *dbs.Tx, policyId int64 op.Id = policyId op.SetHeaders = headersJSON err := this.Save(tx, op) - - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } // 修改ReplaceHeaders @@ -130,8 +129,10 @@ func (this *HTTPHeaderPolicyDAO) UpdateReplacingHeaders(tx *dbs.Tx, policyId int op.Id = policyId op.ReplaceHeaders = headersJSON err := this.Save(tx, op) - - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } // 修改AddTrailers @@ -144,8 +145,10 @@ func (this *HTTPHeaderPolicyDAO) UpdateAddingTrailers(tx *dbs.Tx, policyId int64 op.Id = policyId op.AddTrailers = headersJSON err := this.Save(tx, op) - - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } // 修改DeleteHeaders @@ -163,8 +166,10 @@ func (this *HTTPHeaderPolicyDAO) UpdateDeletingHeaders(tx *dbs.Tx, policyId int6 op.Id = policyId op.DeleteHeaders = string(namesJSON) err = this.Save(tx, op) - - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } // 组合配置 @@ -286,3 +291,24 @@ func (this *HTTPHeaderPolicyDAO) ComposeHeaderPolicyConfig(tx *dbs.Tx, headerPol return config, nil } + +// 查找Header所在Policy +func (this *HTTPHeaderPolicyDAO) FindHeaderPolicyIdWithHeaderId(tx *dbs.Tx, headerId int64) (int64, error) { + return this.Query(tx). + Where("(JSON_CONTAINS(addHeaders, :jsonQuery) OR JSON_CONTAINS(addTrailers, :jsonQuery) OR JSON_CONTAINS(setHeaders, :jsonQuery) OR JSON_CONTAINS(replaceHeaders, :jsonQuery))"). + Param("jsonQuery", maps.Map{"id": headerId}.AsJSON()). + ResultPk(). + FindInt64Col(0) +} + +// 通知更新 +func (this *HTTPHeaderPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error { + webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithHeaderPolicyId(tx, policyId) + if err != nil { + return err + } + if webId > 0 { + return SharedHTTPWebDAO.NotifyUpdate(tx, webId) + } + return nil +} diff --git a/internal/db/models/http_header_policy_dao_test.go b/internal/db/models/http_header_policy_dao_test.go index 97c24b56..ada06162 100644 --- a/internal/db/models/http_header_policy_dao_test.go +++ b/internal/db/models/http_header_policy_dao_test.go @@ -2,4 +2,16 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" + "testing" ) + +func TestHTTPHeaderPolicyDAO_FindHeaderPolicyIdWithHeaderId(t *testing.T) { + dbs.NotifyReady() + var tx *dbs.Tx + policyId, err := SharedHTTPHeaderPolicyDAO.FindHeaderPolicyIdWithHeaderId(tx, 15) + if err != nil { + t.Fatal(err) + } + t.Log("policyId:", policyId) +} diff --git a/internal/db/models/http_location_dao.go b/internal/db/models/http_location_dao.go index 71434e7d..69057640 100644 --- a/internal/db/models/http_location_dao.go +++ b/internal/db/models/http_location_dao.go @@ -38,16 +38,7 @@ func init() { // 初始化 func (this *HTTPLocationDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -60,12 +51,15 @@ func (this *HTTPLocationDAO) EnableHTTPLocation(tx *dbs.Tx, id int64) error { } // 禁用条目 -func (this *HTTPLocationDAO) DisableHTTPLocation(tx *dbs.Tx, id int64) error { +func (this *HTTPLocationDAO) DisableHTTPLocation(tx *dbs.Tx, locationId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(locationId). Set("state", HTTPLocationStateDisabled). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, locationId) } // 查找启用中的条目 @@ -118,7 +112,10 @@ func (this *HTTPLocationDAO) UpdateLocation(tx *dbs.Tx, locationId int64, name s op.IsOn = isOn op.IsBreak = isBreak err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, locationId) } // 组合配置 @@ -198,7 +195,10 @@ func (this *HTTPLocationDAO) UpdateLocationReverseProxy(tx *dbs.Tx, locationId i op.Id = locationId op.ReverseProxy = JSONBytes(reverseProxyJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, locationId) } // 查找WebId @@ -219,7 +219,10 @@ func (this *HTTPLocationDAO) UpdateLocationWeb(tx *dbs.Tx, locationId int64, web op.Id = locationId op.WebId = webId err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, locationId) } // 转换引用为配置 @@ -250,3 +253,15 @@ func (this *HTTPLocationDAO) FindEnabledLocationIdWithWebId(tx *dbs.Tx, webId in ResultPk(). FindInt64Col(0) } + +// 通知更新 +func (this *HTTPLocationDAO) NotifyUpdate(tx *dbs.Tx, locationId int64) error { + webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithLocationId(tx, locationId) + if err != nil { + return err + } + if webId > 0 { + return SharedHTTPWebDAO.NotifyUpdate(tx, webId) + } + return nil +} diff --git a/internal/db/models/http_page_dao.go b/internal/db/models/http_page_dao.go index c1603400..bef3e242 100644 --- a/internal/db/models/http_page_dao.go +++ b/internal/db/models/http_page_dao.go @@ -38,25 +38,19 @@ func init() { // 初始化 func (this *HTTPPageDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 -func (this *HTTPPageDAO) EnableHTTPPage(tx *dbs.Tx, id int64) error { +func (this *HTTPPageDAO) EnableHTTPPage(tx *dbs.Tx, pageId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(pageId). Set("state", HTTPPageStateEnabled). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, pageId) } // 禁用条目 @@ -126,8 +120,10 @@ func (this *HTTPPageDAO) UpdatePage(tx *dbs.Tx, pageId int64, statusList []strin op.Url = url op.NewStatus = newStatus err = this.Save(tx, op) - - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, pageId) } // 组合配置 @@ -160,3 +156,15 @@ func (this *HTTPPageDAO) ComposePageConfig(tx *dbs.Tx, pageId int64) (*servercon return config, nil } + +// 通知更新 +func (this *HTTPPageDAO) NotifyUpdate(tx *dbs.Tx, pageId int64) error { + webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithPageId(tx, pageId) + if err != nil { + return err + } + if webId > 0 { + return SharedHTTPWebDAO.NotifyUpdate(tx, webId) + } + return nil +} diff --git a/internal/db/models/http_rewrite_rule_dao.go b/internal/db/models/http_rewrite_rule_dao.go index 31d39a0a..ced47f47 100644 --- a/internal/db/models/http_rewrite_rule_dao.go +++ b/internal/db/models/http_rewrite_rule_dao.go @@ -37,16 +37,7 @@ func init() { // 初始化 func (this *HTTPRewriteRuleDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -59,12 +50,15 @@ func (this *HTTPRewriteRuleDAO) EnableHTTPRewriteRule(tx *dbs.Tx, id int64) erro } // 禁用条目 -func (this *HTTPRewriteRuleDAO) DisableHTTPRewriteRule(tx *dbs.Tx, id int64) error { +func (this *HTTPRewriteRuleDAO) DisableHTTPRewriteRule(tx *dbs.Tx, rewriteRuleId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(rewriteRuleId). Set("state", HTTPRewriteRuleStateDisabled). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, rewriteRuleId) } // 查找启用中的条目 @@ -135,5 +129,20 @@ func (this *HTTPRewriteRuleDAO) UpdateRewriteRule(tx *dbs.Tx, rewriteRuleId int6 op.WithQuery = withQuery op.ProxyHost = proxyHost err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, rewriteRuleId) +} + +// 通知更新 +func (this *HTTPRewriteRuleDAO) NotifyUpdate(tx *dbs.Tx, rewriteRuleId int64) error { + webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithRewriteRuleId(tx, rewriteRuleId) + if err != nil { + return err + } + if webId > 0 { + return SharedHTTPWebDAO.NotifyUpdate(tx, webId) + } + return nil } diff --git a/internal/db/models/http_web_dao.go b/internal/db/models/http_web_dao.go index 75809ee5..0c7cec6a 100644 --- a/internal/db/models/http_web_dao.go +++ b/internal/db/models/http_web_dao.go @@ -10,8 +10,8 @@ import ( "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/lists" + "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" - "strconv" ) const ( @@ -41,16 +41,7 @@ func init() { } func (this *HTTPWebDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -360,7 +351,11 @@ func (this *HTTPWebDAO) UpdateWeb(tx *dbs.Tx, webId int64, rootJSON []byte) erro op.Id = webId op.Root = JSONBytes(rootJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 修改Gzip配置 @@ -372,7 +367,11 @@ func (this *HTTPWebDAO) UpdateWebGzip(tx *dbs.Tx, webId int64, gzipJSON []byte) op.Id = webId op.Gzip = JSONBytes(gzipJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 修改字符编码 @@ -384,7 +383,11 @@ func (this *HTTPWebDAO) UpdateWebCharset(tx *dbs.Tx, webId int64, charsetJSON [] op.Id = webId op.Charset = JSONBytes(charsetJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 更改请求Header策略 @@ -396,7 +399,11 @@ func (this *HTTPWebDAO) UpdateWebRequestHeaderPolicy(tx *dbs.Tx, webId int64, he op.Id = webId op.RequestHeader = JSONBytes(headerPolicyJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 更改响应Header策略 @@ -408,7 +415,11 @@ func (this *HTTPWebDAO) UpdateWebResponseHeaderPolicy(tx *dbs.Tx, webId int64, h op.Id = webId op.ResponseHeader = JSONBytes(headerPolicyJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 更改特殊页面配置 @@ -420,7 +431,11 @@ func (this *HTTPWebDAO) UpdateWebPages(tx *dbs.Tx, webId int64, pagesJSON []byte op.Id = webId op.Pages = JSONBytes(pagesJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 更改Shutdown配置 @@ -432,7 +447,11 @@ func (this *HTTPWebDAO) UpdateWebShutdown(tx *dbs.Tx, webId int64, shutdownJSON op.Id = webId op.Shutdown = JSONBytes(shutdownJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 更改访问日志策略 @@ -444,7 +463,11 @@ func (this *HTTPWebDAO) UpdateWebAccessLogConfig(tx *dbs.Tx, webId int64, access op.Id = webId op.AccessLog = JSONBytes(accessLogJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 更改统计配置 @@ -456,7 +479,11 @@ func (this *HTTPWebDAO) UpdateWebStat(tx *dbs.Tx, webId int64, statJSON []byte) op.Id = webId op.Stat = JSONBytes(statJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 更改缓存配置 @@ -468,7 +495,11 @@ func (this *HTTPWebDAO) UpdateWebCache(tx *dbs.Tx, webId int64, cacheJSON []byte op.Id = webId op.Cache = JSONBytes(cacheJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 更改防火墙配置 @@ -480,7 +511,11 @@ func (this *HTTPWebDAO) UpdateWebFirewall(tx *dbs.Tx, webId int64, firewallJSON op.Id = webId op.Firewall = JSONBytes(firewallJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 更改路径规则配置 @@ -492,7 +527,11 @@ func (this *HTTPWebDAO) UpdateWebLocations(tx *dbs.Tx, webId int64, locationsJSO op.Id = webId op.Locations = JSONBytes(locationsJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 更改跳转到HTTPS设置 @@ -504,7 +543,11 @@ func (this *HTTPWebDAO) UpdateWebRedirectToHTTPS(tx *dbs.Tx, webId int64, redire op.Id = webId op.RedirectToHttps = JSONBytes(redirectToHTTPSJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 修改Websocket设置 @@ -516,7 +559,11 @@ func (this *HTTPWebDAO) UpdateWebsocket(tx *dbs.Tx, webId int64, websocketJSON [ op.Id = webId op.Websocket = JSONBytes(websocketJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 修改重写规则设置 @@ -528,7 +575,11 @@ func (this *HTTPWebDAO) UpdateWebRewriteRules(tx *dbs.Tx, webId int64, rewriteRu op.Id = webId op.RewriteRules = JSONBytes(rewriteRulesJSON) err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 根据缓存策略ID查找所有的WebId @@ -536,8 +587,8 @@ func (this *HTTPWebDAO) FindAllWebIdsWithCachePolicyId(tx *dbs.Tx, cachePolicyId ones, err := this.Query(tx). State(HTTPWebStateEnabled). ResultPk(). - Where(`JSON_CONTAINS(cache, '{"cachePolicyId": ` + strconv.FormatInt(cachePolicyId, 10) + ` }', '$.cacheRefs')`). - Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 + Where(`JSON_CONTAINS(cache, :jsonQuery, '$.cacheRefs')`). + Param("jsonQuery", maps.Map{"cachePolicyId": cachePolicyId}.AsJSON()). FindAll() if err != nil { return nil, err @@ -580,8 +631,11 @@ func (this *HTTPWebDAO) FindAllWebIdsWithHTTPFirewallPolicyId(tx *dbs.Tx, firewa ones, err := this.Query(tx). State(HTTPWebStateEnabled). ResultPk(). - Where(`JSON_CONTAINS(firewall, '{"isOn": true, "firewallPolicyId": ` + strconv.FormatInt(firewallPolicyId, 10) + ` }')`). - Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 + Where(`JSON_CONTAINS(firewall, :jsonQuery)`). + Param("jsonQuery", maps.Map{ + // 这里不加入isOn的判断,无论是否开启我们都同步 + "firewallPolicyId": firewallPolicyId, + }.AsJSON()). FindAll() if err != nil { return nil, err @@ -624,8 +678,48 @@ func (this *HTTPWebDAO) FindEnabledWebIdWithLocationId(tx *dbs.Tx, locationId in return this.Query(tx). State(HTTPWebStateEnabled). ResultPk(). - Where(`JSON_CONTAINS(locations, '{"locationId": ` + strconv.FormatInt(locationId, 10) + ` }')`). - Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 + Where("JSON_CONTAINS(locations, :jsonQuery)"). + Param("jsonQuery", maps.Map{"locationId": locationId}.AsJSON()). + FindInt64Col(0) +} + +// 查找包含某个重写规则的Web +func (this *HTTPWebDAO) FindEnabledWebIdWithRewriteRuleId(tx *dbs.Tx, rewriteRuleId int64) (webId int64, err error) { + return this.Query(tx). + State(HTTPWebStateEnabled). + ResultPk(). + Where("JSON_CONTAINS(rewriteRules, :jsonQuery)"). + Param("jsonQuery", maps.Map{"rewriteRuleId": rewriteRuleId}.AsJSON()). + FindInt64Col(0) +} + +// 查找包含某个页面的Web +func (this *HTTPWebDAO) FindEnabledWebIdWithPageId(tx *dbs.Tx, pageId int64) (webId int64, err error) { + return this.Query(tx). + State(HTTPWebStateEnabled). + ResultPk(). + Where("JSON_CONTAINS(pages, :jsonQuery)"). + Param("jsonQuery", maps.Map{"id": pageId}.AsJSON()). + FindInt64Col(0) +} + +// 查找包含某个Header的Web +func (this *HTTPWebDAO) FindEnabledWebIdWithHeaderPolicyId(tx *dbs.Tx, headerPolicyId int64) (webId int64, err error) { + return this.Query(tx). + State(HTTPWebStateEnabled). + ResultPk(). + Where("(JSON_CONTAINS(requestHeader, :jsonQuery) OR JSON_CONTAINS(responseHeader, :jsonQuery))"). + Param("jsonQuery", maps.Map{"headerPolicyId": headerPolicyId}.AsJSON()). + FindInt64Col(0) +} + +// 查找包含某个Gzip配置的Web +func (this *HTTPWebDAO) FindEnabledWebIdWithGzipId(tx *dbs.Tx, gzipId int64) (webId int64, err error) { + return this.Query(tx). + State(HTTPWebStateEnabled). + ResultPk(). + Where("JSON_CONTAINS(gzip, :jsonQuery)"). + Param("jsonQuery", maps.Map{"gzipId": gzipId}.AsJSON()). FindInt64Col(0) } @@ -671,7 +765,7 @@ func (this *HTTPWebDAO) CheckUserWeb(tx *dbs.Tx, userId int64, webId int64) erro if serverId == 0 { return ErrNotFound } - return SharedServerDAO.CheckUserServer(tx, serverId, userId) + return SharedServerDAO.CheckUserServer(tx, userId, serverId) } // 设置主机跳转 @@ -690,7 +784,11 @@ func (this *HTTPWebDAO) UpdateWebHostRedirects(tx *dbs.Tx, webId int64, hostRedi Pk(webId). Set("hostRedirects", hostRedirectsJSON). Update() - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, webId) } // 查找主机跳转 @@ -704,3 +802,15 @@ func (this *HTTPWebDAO) FindWebHostRedirects(tx *dbs.Tx, webId int64) ([]byte, e } return []byte(col), nil } + +// 通知更新 +func (this *HTTPWebDAO) NotifyUpdate(tx *dbs.Tx, webId int64) error { + serverId, err := this.FindWebServerId(tx, webId) + if err != nil { + return err + } + if serverId == 0 { + return nil + } + return SharedServerDAO.NotifyUpdate(tx, serverId) +} diff --git a/internal/db/models/http_web_dao_test.go b/internal/db/models/http_web_dao_test.go index d6863e00..313769ef 100644 --- a/internal/db/models/http_web_dao_test.go +++ b/internal/db/models/http_web_dao_test.go @@ -43,7 +43,6 @@ func TestHTTPWebDAO_FindAllWebIdsWithHTTPFirewallPolicyId(t *testing.T) { t.Log("count:", count) } - func TestHTTPWebDAO_FindWebServerId(t *testing.T) { dbs.NotifyReady() @@ -74,4 +73,59 @@ func TestHTTPWebDAO_FindWebServerId(t *testing.T) { } t.Log("serverId:", serverId) } -} \ No newline at end of file +} + +func TestHTTPWebDAO_FindEnabledWebIdWithLocationId(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithLocationId(tx, 17) + if err != nil { + t.Fatal(err) + } + t.Log("webId:", webId) +} + +func TestHTTPWebDAO_FindEnabledWebIdWithRewriteRuleId(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithRewriteRuleId(tx, 13) + if err != nil { + t.Fatal(err) + } + t.Log("webId:", webId) +} + +func TestHTTPWebDAO_FindEnabledWebIdWithPageId(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithPageId(tx, 15) + if err != nil { + t.Fatal(err) + } + t.Log("webId:", webId) +} + +func TestHTTPWebDAO_FindEnabledWebIdWithHeaderPolicyId(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithHeaderPolicyId(tx, 52) + if err != nil { + t.Fatal(err) + } + t.Log("webId:", webId) +} + +func TestHTTPWebDAO_FindEnabledWebIdWithGzip(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + webId, err := SharedHTTPWebDAO.FindEnabledWebIdWithGzipId(tx, 9) + if err != nil { + t.Fatal(err) + } + t.Log("webId:", webId) +} diff --git a/internal/db/models/ip_item_dao.go b/internal/db/models/ip_item_dao.go index 84d6a431..5ade77e7 100644 --- a/internal/db/models/ip_item_dao.go +++ b/internal/db/models/ip_item_dao.go @@ -5,6 +5,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" "time" ) @@ -174,3 +175,68 @@ func (this *IPItemDAO) FindItemListId(tx *dbs.Tx, itemId int64) (int64, error) { Result("listId"). FindInt64Col(0) } + +// 通知更新 +func (this *IPItemDAO) NotifyClustersUpdate(tx *dbs.Tx, itemId int64, taskType NodeTaskType) error { + // 获取ListId + listId, err := this.FindItemListId(tx, itemId) + if err != nil { + return err + } + + if listId == 0 { + return nil + } + + httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId) + if err != nil { + return err + } + resultClusterIds := []int64{} + for _, policyId := range httpFirewallPolicyIds { + // 集群 + clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId) + if err != nil { + return err + } + for _, clusterId := range clusterIds { + if !lists.ContainsInt64(resultClusterIds, clusterId) { + resultClusterIds = append(resultClusterIds, clusterId) + } + } + + // 服务 + webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, policyId) + if err != nil { + return err + } + if len(webIds) > 0 { + for _, webId := range webIds { + serverId, err := SharedServerDAO.FindEnabledServerIdWithWebId(tx, webId) + if err != nil { + return err + } + if serverId > 0 { + clusterId, err := SharedServerDAO.FindServerClusterId(tx, serverId) + if err != nil { + return err + } + if !lists.ContainsInt64(resultClusterIds, clusterId) { + resultClusterIds = append(resultClusterIds, clusterId) + } + } + } + } + } + + if len(resultClusterIds) > 0 { + for _, clusterId := range resultClusterIds { + err = SharedNodeTaskDAO.CreateClusterTask(tx, clusterId, taskType) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/internal/db/models/ip_item_dao_test.go b/internal/db/models/ip_item_dao_test.go index 97c24b56..4f23b690 100644 --- a/internal/db/models/ip_item_dao_test.go +++ b/internal/db/models/ip_item_dao_test.go @@ -2,4 +2,17 @@ package models import ( _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" + "testing" ) + +func TestIPItemDAO_NotifyClustersUpdate(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + err := SharedIPItemDAO.NotifyClustersUpdate(tx, 28, NodeTaskTypeIPItemChanged) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go index daaf1758..9e12c78b 100644 --- a/internal/db/models/ip_list_dao.go +++ b/internal/db/models/ip_list_dao.go @@ -7,6 +7,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/types" ) @@ -144,3 +145,58 @@ func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) e } return ErrNotFound } + +// 通知更新 +func (this *IPListDAO) NotifyUpdate(tx *dbs.Tx, listId int64, taskType NodeTaskType) error { + httpFirewallPolicyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId) + if err != nil { + return err + } + resultClusterIds := []int64{} + for _, policyId := range httpFirewallPolicyIds { + // 集群 + clusterIds, err := SharedNodeClusterDAO.FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx, policyId) + if err != nil { + return err + } + for _, clusterId := range clusterIds { + if !lists.ContainsInt64(resultClusterIds, clusterId) { + resultClusterIds = append(resultClusterIds, clusterId) + } + } + + // 服务 + webIds, err := SharedHTTPWebDAO.FindAllWebIdsWithHTTPFirewallPolicyId(tx, policyId) + if err != nil { + return err + } + if len(webIds) > 0 { + for _, webId := range webIds { + serverId, err := SharedServerDAO.FindEnabledServerIdWithWebId(tx, webId) + if err != nil { + return err + } + if serverId > 0 { + clusterId, err := SharedServerDAO.FindServerClusterId(tx, serverId) + if err != nil { + return err + } + if !lists.ContainsInt64(resultClusterIds, clusterId) { + resultClusterIds = append(resultClusterIds, clusterId) + } + } + } + } + } + + if len(resultClusterIds) > 0 { + for _, clusterId := range resultClusterIds { + err = SharedNodeTaskDAO.CreateClusterTask(tx, clusterId, taskType) + if err != nil { + return err + } + } + } + + return nil +} diff --git a/internal/db/models/node_cluster_dao.go b/internal/db/models/node_cluster_dao.go index 20bd3acc..ea9a5b22 100644 --- a/internal/db/models/node_cluster_dao.go +++ b/internal/db/models/node_cluster_dao.go @@ -102,6 +102,21 @@ func (this *NodeClusterDAO) FindAllEnableClusters(tx *dbs.Tx) (result []*NodeClu return } +// 查找所有可用的集群Ids +func (this *NodeClusterDAO) FindAllEnableClusterIds(tx *dbs.Tx) (result []int64, err error) { + ones, err := this.Query(tx). + State(NodeClusterStateEnabled). + ResultPk(). + FindAll() + if err != nil { + return nil, err + } + for _, one := range ones { + result = append(result, int64(one.(*NodeCluster).Id)) + } + return +} + // 创建集群 func (this *NodeClusterDAO) CreateCluster(tx *dbs.Tx, adminId int64, name string, grantId int64, installDir string, dnsDomainId int64, dnsName string, cachePolicyId int64, httpFirewallPolicyId int64, systemServices map[string]maps.Map) (clusterId int64, err error) { uniqueId, err := this.genUniqueId(tx) @@ -288,7 +303,10 @@ func (this *NodeClusterDAO) UpdateClusterHealthCheck(tx *dbs.Tx, clusterId int64 op.Id = clusterId op.HealthCheck = healthCheckJSON err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, clusterId) } // 计算使用某个认证的集群数量 @@ -432,7 +450,10 @@ func (this *NodeClusterDAO) UpdateClusterDNS(tx *dbs.Tx, clusterId int64, dnsNam op.Dns = dnsJSON err = this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, clusterId) } // 检查集群的DNS问题 @@ -585,7 +606,10 @@ func (this *NodeClusterDAO) UpdateClusterTOA(tx *dbs.Tx, clusterId int64, toaJSO op.Id = clusterId op.Toa = toaJSON err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, clusterId) } // 计算使用某个缓存策略的集群数量 @@ -626,6 +650,32 @@ func (this *NodeClusterDAO) FindAllEnabledNodeClustersWithHTTPFirewallPolicyId(t return } +// 查找使用WAF策略的所有集群Ids +func (this *NodeClusterDAO) FindAllEnabledNodeClusterIdsWithHTTPFirewallPolicyId(tx *dbs.Tx, httpFirewallPolicyId int64) (result []int64, err error) { + ones, err := this.Query(tx). + State(NodeClusterStateEnabled). + Attr("httpFirewallPolicyId", httpFirewallPolicyId). + ResultPk(). + FindAll() + for _, one := range ones { + result = append(result, int64(one.(*NodeCluster).Id)) + } + return +} + +// 查找使用缓存策略的所有集群Ids +func (this *NodeClusterDAO) FindAllEnabledNodeClusterIdsWithCachePolicyId(tx *dbs.Tx, cachePolicyId int64) (result []int64, err error) { + ones, err := this.Query(tx). + State(NodeClusterStateEnabled). + Attr("cachePolicyId", cachePolicyId). + ResultPk(). + FindAll() + for _, one := range ones { + result = append(result, int64(one.(*NodeCluster).Id)) + } + return +} + // 获取集群的WAF策略ID func (this *NodeClusterDAO) FindClusterHTTPFirewallPolicyId(tx *dbs.Tx, clusterId int64) (int64, error) { return this.Query(tx). @@ -640,7 +690,10 @@ func (this *NodeClusterDAO) UpdateNodeClusterHTTPCachePolicyId(tx *dbs.Tx, clust Pk(clusterId). Set("cachePolicyId", httpCachePolicyId). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, clusterId) } // 获取集群的缓存策略ID @@ -657,7 +710,10 @@ func (this *NodeClusterDAO) UpdateNodeClusterHTTPFirewallPolicyId(tx *dbs.Tx, cl Pk(clusterId). Set("httpFirewallPolicyId", httpFirewallPolicyId). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, clusterId) } // 修改集群的系统服务设置 @@ -696,7 +752,7 @@ func (this *NodeClusterDAO) UpdateNodeClusterSystemService(tx *dbs.Tx, clusterId if err != nil { return err } - return nil + return this.NotifyUpdate(tx, clusterId) } // 查找集群的系统服务设置 @@ -759,3 +815,8 @@ func (this *NodeClusterDAO) genUniqueId(tx *dbs.Tx) (string, error) { return uniqueId, nil } } + +// 通知更新 +func (this *NodeClusterDAO) NotifyUpdate(tx *dbs.Tx, clusterId int64) error { + return SharedNodeTaskDAO.CreateClusterTask(tx, clusterId, NodeTaskTypeConfigChanged) +} diff --git a/internal/db/models/node_dao.go b/internal/db/models/node_dao.go index 0bf84861..ef2e2f8b 100644 --- a/internal/db/models/node_dao.go +++ b/internal/db/models/node_dao.go @@ -3,6 +3,7 @@ package models import ( "encoding/json" "github.com/TeaOSLab/EdgeAPI/internal/errors" + "github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils" "github.com/TeaOSLab/EdgeCommon/pkg/configutils" "github.com/TeaOSLab/EdgeCommon/pkg/nodeconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" @@ -13,6 +14,7 @@ import ( "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" "strconv" + "strings" ) const ( @@ -50,12 +52,16 @@ func (this *NodeDAO) EnableNode(tx *dbs.Tx, id uint32) (rowsAffected int64, err } // 禁用条目 -func (this *NodeDAO) DisableNode(tx *dbs.Tx, id int64) (err error) { +func (this *NodeDAO) DisableNode(tx *dbs.Tx, nodeId int64) (err error) { _, err = this.Query(tx). - Pk(id). + Pk(nodeId). Set("state", NodeStateDisabled). Update() - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, nodeId) } // 查找启用中的条目 @@ -71,7 +77,7 @@ func (this *NodeDAO) FindEnabledNode(tx *dbs.Tx, id int64) (*Node, error) { } // 根据主键查找名称 -func (this *NodeDAO) FindNodeName(tx *dbs.Tx, id uint32) (string, error) { +func (this *NodeDAO) FindNodeName(tx *dbs.Tx, id int64) (string, error) { name, err := this.Query(tx). Pk(id). Result("name"). @@ -127,59 +133,11 @@ func (this *NodeDAO) UpdateNode(tx *dbs.Tx, nodeId int64, name string, clusterId op.MaxCPU = maxCPU op.IsOn = isOn err := this.Save(tx, op) - return err -} - -// 更新节点版本 -func (this *NodeDAO) UpdateNodeLatestVersion(tx *dbs.Tx, nodeId int64) error { - if nodeId <= 0 { - return errors.New("invalid nodeId") - } - op := NewNodeOperator() - op.Id = nodeId - op.LatestVersion = dbs.SQL("latestVersion+1") - err := this.Save(tx, op) - return err -} - -// 批量更新节点版本 -func (this *NodeDAO) IncreaseAllNodesLatestVersionMatch(tx *dbs.Tx, clusterId int64) error { - _, err := this.Query(tx). - Attr("clusterId", clusterId). - Set("latestVersion", dbs.SQL("latestVersion+1")). - Update() - - return err -} - -// 同步集群中的节点版本 -func (this *NodeDAO) SyncNodeVersionsWithCluster(tx *dbs.Tx, clusterId int64) error { - if clusterId <= 0 { - return errors.New("invalid cluster") - } - _, err := this.Query(tx). - Attr("clusterId", clusterId). - Set("version", dbs.SQL("latestVersion")). - Update() - return err -} - -// 取得有变更的集群 -func (this *NodeDAO) FindChangedClusterIds(tx *dbs.Tx) ([]int64, error) { - ones, _, err := this.Query(tx). - State(NodeStateEnabled). - Gt("latestVersion", 0). - Where("version!=latestVersion"). - Result("DISTINCT(clusterId) AS clusterId"). - FindOnes() if err != nil { - return nil, err + return err } - result := []int64{} - for _, one := range ones { - result = append(result, one.GetInt64("clusterId")) - } - return result, nil + + return this.NotifyUpdate(tx, nodeId) } // 计算所有节点数量 @@ -282,12 +240,17 @@ func (this *NodeDAO) FindNodeClusterId(tx *dbs.Tx, nodeId int64) (int64, error) } // 匹配节点并返回节点ID -func (this *NodeDAO) FindAllNodeIdsMatch(tx *dbs.Tx, clusterId int64) (result []int64, err error) { +func (this *NodeDAO) FindAllNodeIdsMatch(tx *dbs.Tx, clusterId int64, isOn configutils.BoolState) (result []int64, err error) { query := this.Query(tx) query.State(NodeStateEnabled) if clusterId > 0 { query.Attr("clusterId", clusterId) } + if isOn == configutils.BoolStateYes { + query.Attr("isOn", true) + } else if isOn == configutils.BoolStateNo { + query.Attr("isOn", false) + } query.Result("id") ones, _, err := query.FindOnes() if err != nil { @@ -737,7 +700,11 @@ func (this *NodeDAO) UpdateNodeDNS(tx *dbs.Tx, nodeId int64, routes map[int64][] op.Id = nodeId op.DnsRoutes = routesJSON err = this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, nodeId) } // 计算节点上线|下线状态 @@ -792,6 +759,7 @@ func (this *NodeDAO) UpdateNodeUp(tx *dbs.Tx, nodeId int64, isUp bool, maxUp int if err != nil { return false, err } + return } @@ -843,3 +811,34 @@ func (this *NodeDAO) genUniqueId(tx *dbs.Tx) (string, error) { return uniqueId, nil } } + +// 根据一组ID查找一组节点 +func (this *NodeDAO) FindEnabledNodesWithIds(tx *dbs.Tx, nodeIds []int64) (result []*Node, err error) { + if len(nodeIds) == 0 { + return nil, nil + } + idStrings := []string{} + for _, nodeId := range nodeIds { + idStrings = append(idStrings, numberutils.FormatInt64(nodeId)) + } + _, err = this.Query(tx). + State(NodeStateEnabled). + Where("id IN ("+strings.Join(idStrings, ", ")+")"). + Result("id", "connectedAPINodes", "isActive", "isOn"). + Slice(&result). + Reuse(false). + FindAll() + return +} + +// 通知更新 +func (this *NodeDAO) NotifyUpdate(tx *dbs.Tx, nodeId int64) error { + clusterId, err := this.FindNodeClusterId(tx, nodeId) + if err != nil { + return err + } + if clusterId > 0 { + return SharedNodeTaskDAO.CreateNodeTask(tx, clusterId, nodeId, NodeTaskTypeConfigChanged) + } + return nil +} diff --git a/internal/db/models/node_dao_test.go b/internal/db/models/node_dao_test.go index 20d23959..c777c793 100644 --- a/internal/db/models/node_dao_test.go +++ b/internal/db/models/node_dao_test.go @@ -8,7 +8,7 @@ import ( func TestNodeDAO_FindAllNodeIdsMatch(t *testing.T) { var tx *dbs.Tx - nodeIds, err := SharedNodeDAO.FindAllNodeIdsMatch(tx, 1) + nodeIds, err := SharedNodeDAO.FindAllNodeIdsMatch(tx, 1, 0) if err != nil { t.Fatal(err) } diff --git a/internal/db/models/node_model_ext.go b/internal/db/models/node_model_ext.go index b446abcf..eca59dc5 100644 --- a/internal/db/models/node_model_ext.go +++ b/internal/db/models/node_model_ext.go @@ -66,3 +66,15 @@ func (this *Node) DNSRouteCodesForDomainId(dnsDomainId int64) ([]string, error) domainRoutes, _ := routes[dnsDomainId] return domainRoutes, nil } + +// 连接的API +func (this *Node) DecodeConnectedAPINodeIds() ([]int64, error) { + apiNodeIds := []int64{} + if IsNotNull(this.ConnectedAPINodes) { + err := json.Unmarshal([]byte(this.ConnectedAPINodes), &apiNodeIds) + if err != nil { + return nil, err + } + } + return apiNodeIds, nil +} diff --git a/internal/db/models/node_task_dao.go b/internal/db/models/node_task_dao.go new file mode 100644 index 00000000..6d1c9a5d --- /dev/null +++ b/internal/db/models/node_task_dao.go @@ -0,0 +1,278 @@ +package models + +import ( + "github.com/TeaOSLab/EdgeAPI/internal/utils/numberutils" + "github.com/TeaOSLab/EdgeCommon/pkg/configutils" + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/Tea" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/maps" + "time" +) + +type NodeTaskType = string + +const ( + NodeTaskTypeConfigChanged NodeTaskType = "configChanged" + NodeTaskTypeIPItemChanged NodeTaskType = "ipItemChanged" +) + +type NodeTaskDAO dbs.DAO + +func NewNodeTaskDAO() *NodeTaskDAO { + return dbs.NewDAO(&NodeTaskDAO{ + DAOObject: dbs.DAOObject{ + DB: Tea.Env, + Table: "edgeNodeTasks", + Model: new(NodeTask), + PkName: "id", + }, + }).(*NodeTaskDAO) +} + +var SharedNodeTaskDAO *NodeTaskDAO + +func init() { + dbs.OnReady(func() { + SharedNodeTaskDAO = NewNodeTaskDAO() + }) +} + +// 创建单个节点任务 +func (this *NodeTaskDAO) CreateNodeTask(tx *dbs.Tx, clusterId int64, nodeId int64, taskType NodeTaskType) error { + if clusterId <= 0 || nodeId <= 0 { + return nil + } + uniqueId := numberutils.FormatInt64(nodeId) + "@node@" + taskType + updatedAt := time.Now().Unix() + _, _, err := this.Query(tx). + InsertOrUpdate(maps.Map{ + "clusterId": clusterId, + "nodeId": nodeId, + "type": taskType, + "uniqueId": uniqueId, + "updatedAt": updatedAt, + "isDone": 0, + "isOk": 0, + "error": "", + }, maps.Map{ + "clusterId": clusterId, + "updatedAt": updatedAt, + "isDone": 0, + "isOk": 0, + "error": "", + }) + return err +} + +// 创建集群任务 +func (this *NodeTaskDAO) CreateClusterTask(tx *dbs.Tx, clusterId int64, taskType NodeTaskType) error { + if clusterId <= 0 { + return nil + } + + uniqueId := numberutils.FormatInt64(clusterId) + "@cluster@" + taskType + updatedAt := time.Now().Unix() + _, _, err := this.Query(tx). + InsertOrUpdate(maps.Map{ + "clusterId": clusterId, + "nodeId": 0, + "type": taskType, + "uniqueId": uniqueId, + "updatedAt": updatedAt, + "isDone": 0, + "isOk": 0, + "isNotified": 0, + "error": "", + }, maps.Map{ + "updatedAt": updatedAt, + "isDone": 0, + "isOk": 0, + "isNotified": 0, + "error": "", + }) + return err +} + +// 分解集群任务 +func (this *NodeTaskDAO) ExtractClusterTask(tx *dbs.Tx, clusterId int64, taskType NodeTaskType) error { + nodeIds, err := SharedNodeDAO.FindAllNodeIdsMatch(tx, clusterId, configutils.BoolStateYes) + if err != nil { + return err + } + + _, err = this.Query(tx). + Attr("clusterId", clusterId). + Where("nodeId> 0"). + Attr("type", taskType). + Delete() + if err != nil { + return err + } + + for _, nodeId := range nodeIds { + err = this.CreateNodeTask(tx, clusterId, nodeId, taskType) + if err != nil { + return err + } + } + + _, err = this.Query(tx). + Attr("clusterId", clusterId). + Attr("nodeId", 0). + Attr("type", taskType). + Delete() + if err != nil { + return err + } + + return nil +} + +// 分解所有集群任务 +func (this *NodeTaskDAO) ExtractAllClusterTasks(tx *dbs.Tx) error { + ones, err := this.Query(tx). + Attr("nodeId", 0). + FindAll() + if err != nil { + return err + } + for _, one := range ones { + clusterId := int64(one.(*NodeTask).ClusterId) + err = this.ExtractClusterTask(tx, clusterId, one.(*NodeTask).Type) + if err != nil { + return err + } + } + return nil +} + +// 删除集群所有相关任务 +func (this *NodeTaskDAO) DeleteAllClusterTasks(tx *dbs.Tx, clusterId int64) error { + _, err := this.Query(tx). + Attr("clusterId", clusterId). + Delete() + return err +} + +// 删除节点相关任务 +func (this *NodeTaskDAO) DeleteNodeTasks(tx *dbs.Tx, nodeId int64) error { + _, err := this.Query(tx). + Attr("nodeId", nodeId). + Delete() + return err +} + +// 查询一个节点的所有任务 +func (this *NodeTaskDAO) FindDoingNodeTasks(tx *dbs.Tx, nodeId int64) (result []*NodeTask, err error) { + if nodeId <= 0 { + return + } + _, err = this.Query(tx). + Attr("nodeId", nodeId). + Where("(isDone=0 OR (isDone=1 AND isOk=0))"). + Slice(&result). + FindAll() + return +} + +// 修改节点任务的完成状态 +func (this *NodeTaskDAO) UpdateNodeTaskDone(tx *dbs.Tx, taskId int64, isOk bool, errorMessage string) error { + _, err := this.Query(tx). + Pk(taskId). + Set("isDone", 1). + Set("isOk", isOk). + Set("error", errorMessage). + Update() + return err +} + +// 查找正在更新的集群IDs +func (this *NodeTaskDAO) FindAllDoingTaskClusterIds(tx *dbs.Tx) ([]int64, error) { + ones, _, err := this.Query(tx). + Result("DISTINCT(clusterId) AS clusterId"). + Where("(nodeId=0 OR (isDone=0 OR (isDone=1 AND isOk=0)))"). + FindOnes() + if err != nil { + return nil, err + } + result := []int64{} + for _, one := range ones { + result = append(result, one.GetInt64("clusterId")) + } + return result, nil +} + +// 查询某个集群下所有的任务 +func (this *NodeTaskDAO) FindAllDoingNodeTasksWithClusterId(tx *dbs.Tx, clusterId int64) (result []*NodeTask, err error) { + _, err = this.Query(tx). + Attr("clusterId", clusterId). + Gt("nodeId", 0). + Where("(isDone=0 OR (isDone=1 AND isOk=0))"). + Desc("isDone"). + Asc(). + Asc("nodeId"). + Slice(&result). + FindAll() + return +} + +// 检查是否有正在执行的任务 +func (this *NodeTaskDAO) ExistsDoingNodeTasks(tx *dbs.Tx) (bool, error) { + return this.Query(tx). + Where("(isDone=0 OR (isDone=1 AND isOk=0))"). + Gt("nodeId", 0). + Exist() +} + +// 是否有错误的任务 +func (this *NodeTaskDAO) ExistsErrorNodeTasks(tx *dbs.Tx) (bool, error) { + return this.Query(tx). + Where("(isDone=1 AND isOk=0)"). + Exist() +} + +// 删除任务 +func (this *NodeTaskDAO) DeleteNodeTask(tx *dbs.Tx, taskId int64) error { + _, err := this.Query(tx). + Pk(taskId). + Delete() + return err +} + +// 计算正在执行的任务 +func (this *NodeTaskDAO) CountDoingNodeTasks(tx *dbs.Tx) (int64, error) { + return this.Query(tx). + Attr("isDone", 0). + Gt("nodeId", 0). + Count() +} + +// 查找需要通知的任务 +func (this *NodeTaskDAO) FindNotifyingNodeTasks(tx *dbs.Tx, size int64) (result []*NodeTask, err error) { + _, err = this.Query(tx). + Gt("nodeId", 0). + Attr("isNotified", 0). + Attr("isDone", 0). + Limit(size). + Slice(&result). + FindAll() + return +} + +// 设置任务已通知 +func (this *NodeTaskDAO) UpdateTasksNotified(tx *dbs.Tx, taskIds []int64) error { + if len(taskIds) == 0 { + return nil + } + for _, taskId := range taskIds { + _, err := this.Query(tx). + Pk(taskId). + Set("isNotified", 1). + Update() + if err != nil { + return err + } + } + return nil +} diff --git a/internal/db/models/node_task_dao_test.go b/internal/db/models/node_task_dao_test.go new file mode 100644 index 00000000..51603a52 --- /dev/null +++ b/internal/db/models/node_task_dao_test.go @@ -0,0 +1,40 @@ +package models + +import ( + _ "github.com/go-sql-driver/mysql" + "github.com/iwind/TeaGo/dbs" + "testing" +) + +func TestNodeTaskDAO_CreateNodeTask(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + err := SharedNodeTaskDAO.CreateNodeTask(tx, 1, 2, NodeTaskTypeConfigChanged) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} + +func TestNodeTaskDAO_CreateClusterTask(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + err := SharedNodeTaskDAO.CreateClusterTask(tx, 1, NodeTaskTypeConfigChanged) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} + +func TestNodeTaskDAO_ExtractClusterTask(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + err := SharedNodeTaskDAO.ExtractClusterTask(tx, 1, NodeTaskTypeConfigChanged) + if err != nil { + t.Fatal(err) + } + t.Log("ok") +} diff --git a/internal/db/models/node_task_model.go b/internal/db/models/node_task_model.go new file mode 100644 index 00000000..8774f886 --- /dev/null +++ b/internal/db/models/node_task_model.go @@ -0,0 +1,32 @@ +package models + +// 节点同步任务 +type NodeTask struct { + Id uint64 `field:"id"` // ID + NodeId uint32 `field:"nodeId"` // 节点ID + ClusterId uint32 `field:"clusterId"` // 集群ID + Type string `field:"type"` // 任务类型 + UniqueId string `field:"uniqueId"` // 唯一ID:nodeId@type + UpdatedAt uint64 `field:"updatedAt"` // 修改时间 + IsDone uint8 `field:"isDone"` // 是否已完成 + IsOk uint8 `field:"isOk"` // 是否已完成 + Error string `field:"error"` // 错误信息 + IsNotified uint8 `field:"isNotified"` // 是否已通知更新 +} + +type NodeTaskOperator struct { + Id interface{} // ID + NodeId interface{} // 节点ID + ClusterId interface{} // 集群ID + Type interface{} // 任务类型 + UniqueId interface{} // 唯一ID:nodeId@type + UpdatedAt interface{} // 修改时间 + IsDone interface{} // 是否已完成 + IsOk interface{} // 是否已完成 + Error interface{} // 错误信息 + IsNotified interface{} // 是否已通知更新 +} + +func NewNodeTaskOperator() *NodeTaskOperator { + return &NodeTaskOperator{} +} diff --git a/internal/db/models/node_task_model_ext.go b/internal/db/models/node_task_model_ext.go new file mode 100644 index 00000000..2640e7f9 --- /dev/null +++ b/internal/db/models/node_task_model_ext.go @@ -0,0 +1 @@ +package models diff --git a/internal/db/models/origin_dao.go b/internal/db/models/origin_dao.go index e0ef62b3..a2a73066 100644 --- a/internal/db/models/origin_dao.go +++ b/internal/db/models/origin_dao.go @@ -40,16 +40,7 @@ func init() { // 初始化 func (this *OriginDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -62,12 +53,16 @@ func (this *OriginDAO) EnableOrigin(tx *dbs.Tx, id int64) error { } // 禁用条目 -func (this *OriginDAO) DisableOrigin(tx *dbs.Tx, id int64) error { +func (this *OriginDAO) DisableOrigin(tx *dbs.Tx, originId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(originId). Set("state", OriginStateDisabled). Update() - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, originId) } // 查找启用中的条目 @@ -128,7 +123,11 @@ func (this *OriginDAO) UpdateOrigin(tx *dbs.Tx, originId int64, name string, add op.IsOn = isOn op.Version = dbs.SQL("version+1") err := this.Save(tx, op) - return err + if err != nil { + return err + } + + return this.NotifyUpdate(tx, originId) } // 将源站信息转换为配置 @@ -262,3 +261,15 @@ func (this *OriginDAO) ComposeOriginConfig(tx *dbs.Tx, originId int64) (*serverc return config, nil } + +// 通知更新 +func (this *OriginDAO) NotifyUpdate(tx *dbs.Tx, originId int64) error { + reverseProxyId, err := SharedReverseProxyDAO.FindReverseProxyContainsOriginId(tx, originId) + if err != nil { + return err + } + if reverseProxyId > 0 { + return SharedReverseProxyDAO.NotifyUpdate(tx, reverseProxyId) + } + return nil +} diff --git a/internal/db/models/reverse_proxy_dao.go b/internal/db/models/reverse_proxy_dao.go index e61b3b90..8e7ef6d1 100644 --- a/internal/db/models/reverse_proxy_dao.go +++ b/internal/db/models/reverse_proxy_dao.go @@ -7,6 +7,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" ) @@ -38,16 +39,7 @@ func init() { // 初始化 func (this *ReverseProxyDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -59,7 +51,7 @@ func (this *ReverseProxyDAO) EnableReverseProxy(tx *dbs.Tx, id int64) error { if err != nil { return err } - return this.CreateEvent() + return this.NotifyUpdate(tx, id) } // 禁用条目 @@ -71,7 +63,7 @@ func (this *ReverseProxyDAO) DisableReverseProxy(tx *dbs.Tx, id int64) error { if err != nil { return err } - return this.CreateEvent() + return this.NotifyUpdate(tx, id) } // 查找启用中的条目 @@ -188,8 +180,10 @@ func (this *ReverseProxyDAO) UpdateReverseProxyScheduling(tx *dbs.Tx, reversePro op.Scheduling = "null" } err := this.Save(tx, op) - - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, reverseProxyId) } // 修改主要源站 @@ -205,8 +199,10 @@ func (this *ReverseProxyDAO) UpdateReverseProxyPrimaryOrigins(tx *dbs.Tx, revers op.PrimaryOrigins = "[]" } err := this.Save(tx, op) - - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, reverseProxyId) } // 修改备用源站 @@ -222,8 +218,10 @@ func (this *ReverseProxyDAO) UpdateReverseProxyBackupOrigins(tx *dbs.Tx, reverse op.BackupOrigins = "[]" } err := this.Save(tx, op) - - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, reverseProxyId) } // 修改是否启用 @@ -245,10 +243,31 @@ func (this *ReverseProxyDAO) UpdateReverseProxy(tx *dbs.Tx, reverseProxyId int64 op.StripPrefix = stripPrefix op.AutoFlush = autoFlush err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, reverseProxyId) +} + +// 查找包含某个源站的反向代理ID +func (this *ReverseProxyDAO) FindReverseProxyContainsOriginId(tx *dbs.Tx, originId int64) (int64, error) { + return this.Query(tx). + ResultPk(). + Where("(JSON_CONTAINS(primaryOrigins, :jsonQuery) OR JSON_CONTAINS(backupOrigins, :jsonQuery))"). + Param("jsonQuery", maps.Map{ + "originId": originId, + }.AsJSON()). + FindInt64Col(0) } // 通知更新 -func (this *ReverseProxyDAO) CreateEvent() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) +func (this *ReverseProxyDAO) NotifyUpdate(tx *dbs.Tx, reverseProxyId int64) error { + serverId, err := SharedServerDAO.FindEnabledServerIdWithReverseProxyId(tx, reverseProxyId) + if err != nil { + return err + } + if serverId > 0 { + return SharedServerDAO.NotifyUpdate(tx, serverId) + } + return nil } diff --git a/internal/db/models/reverse_proxy_dao_test.go b/internal/db/models/reverse_proxy_dao_test.go index eb0d2444..1ee14ab4 100644 --- a/internal/db/models/reverse_proxy_dao_test.go +++ b/internal/db/models/reverse_proxy_dao_test.go @@ -14,3 +14,14 @@ func TestReverseProxyDAO_ComposeReverseProxyConfig(t *testing.T) { } t.Log(config) } + +func TestReverseProxyDAO_FindReverseProxyContainsOriginId(t *testing.T) { + dbs.NotifyReady() + + var tx *dbs.Tx + reverseProxyId, err := SharedReverseProxyDAO.FindReverseProxyContainsOriginId(tx, 68) + if err != nil { + t.Fatal(err) + } + t.Log("reverseProxyId:", reverseProxyId) +} diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index 483871e9..cc01109e 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -12,6 +12,7 @@ import ( _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/lists" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" @@ -67,7 +68,10 @@ func (this *ServerDAO) DisableServer(tx *dbs.Tx, id int64) (err error) { Pk(id). Set("state", ServerStateDisabled). Update() - return + if err != nil { + return err + } + return this.NotifyUpdate(tx, id) } // 查找启用中的服务 @@ -197,14 +201,9 @@ func (this *ServerDAO) CreateServer(tx *dbs.Tx, serverId = types.Int64(op.Id) - _, err = this.RenewServerConfig(tx, serverId, false) + err = this.NotifyUpdate(tx, serverId) if err != nil { - return serverId, err - } - - err = this.createEvent() - if err != nil { - return serverId, err + return 0, err } return serverId, nil @@ -237,12 +236,13 @@ func (this *ServerDAO) UpdateServerBasic(tx *dbs.Tx, serverId int64, name string return err } - _, err = this.RenewServerConfig(tx, serverId, false) + // 通知更新 + err = this.NotifyUpdate(tx, serverId) if err != nil { return err } - return this.createEvent() + return nil } // 设置用户相关的基本信息 @@ -259,12 +259,7 @@ func (this *ServerDAO) UpdateUserServerBasic(tx *dbs.Tx, serverId int64, name st return err } - _, err = this.RenewServerConfig(tx, serverId, false) - if err != nil { - return err - } - - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 修复服务是否启用 @@ -273,7 +268,16 @@ func (this *ServerDAO) UpdateServerIsOn(tx *dbs.Tx, serverId int64, isOn bool) e Pk(serverId). Set("isOn", isOn). Update() - return err + if err != nil { + return err + } + + err = this.NotifyUpdate(tx, serverId) + if err != nil { + return err + } + + return nil } // 修改服务配置 @@ -335,12 +339,7 @@ func (this *ServerDAO) UpdateServerHTTP(tx *dbs.Tx, serverId int64, config []byt return err } - _, err = this.RenewServerConfig(tx, serverId, false) - if err != nil { - return err - } - - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 修改HTTPS配置 @@ -359,12 +358,7 @@ func (this *ServerDAO) UpdateServerHTTPS(tx *dbs.Tx, serverId int64, httpsJSON [ return err } - _, err = this.RenewServerConfig(tx, serverId, false) - if err != nil { - return err - } - - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 修改TCP配置 @@ -383,7 +377,7 @@ func (this *ServerDAO) UpdateServerTCP(tx *dbs.Tx, serverId int64, config []byte return err } - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 修改TLS配置 @@ -402,7 +396,7 @@ func (this *ServerDAO) UpdateServerTLS(tx *dbs.Tx, serverId int64, config []byte return err } - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 修改Unix配置 @@ -421,7 +415,7 @@ func (this *ServerDAO) UpdateServerUnix(tx *dbs.Tx, serverId int64, config []byt return err } - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 修改UDP配置 @@ -440,7 +434,7 @@ func (this *ServerDAO) UpdateServerUDP(tx *dbs.Tx, serverId int64, config []byte return err } - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 修改Web配置 @@ -455,7 +449,7 @@ func (this *ServerDAO) UpdateServerWeb(tx *dbs.Tx, serverId int64, webId int64) if err != nil { return err } - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 初始化Web配置 @@ -482,7 +476,7 @@ func (this *ServerDAO) InitServerWeb(tx *dbs.Tx, serverId int64) (int64, error) return 0, err } - err = this.createEvent() + err = this.NotifyUpdate(tx, serverId) if err != nil { return webId, err } @@ -526,7 +520,7 @@ func (this *ServerDAO) UpdateServerNames(tx *dbs.Tx, serverId int64, serverNames if err != nil { return err } - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 修改域名审核 @@ -548,7 +542,7 @@ func (this *ServerDAO) UpdateAuditingServerNames(tx *dbs.Tx, serverId int64, isA if err != nil { return err } - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 修改域名审核结果 @@ -573,7 +567,12 @@ func (this *ServerDAO) UpdateServerAuditing(tx *dbs.Tx, serverId int64, result * if result.IsOk { op.ServerNames = dbs.SQL("auditingServerNames") } - return this.Save(tx, op) + err = this.Save(tx, op) + if err != nil { + return err + } + + return this.NotifyUpdate(tx, serverId) } // 修改反向代理配置 @@ -589,7 +588,7 @@ func (this *ServerDAO) UpdateServerReverseProxy(tx *dbs.Tx, serverId int64, conf return err } - return this.createEvent() + return this.NotifyUpdate(tx, serverId) } // 计算所有可用服务数量 @@ -690,6 +689,20 @@ func (this *ServerDAO) FindAllEnabledServerIds(tx *dbs.Tx) (serverIds []int64, e return } +// 获取某个用户的所有的服务ID +func (this *ServerDAO) FindAllEnabledServerIdsWithUserId(tx *dbs.Tx, userId int64) (serverIds []int64, err error) { + ones, err := this.Query(tx). + State(ServerStateEnabled). + Attr("userId", userId). + AscPk(). + ResultPk(). + FindAll() + for _, one := range ones { + serverIds = append(serverIds, int64(one.(*Server).Id)) + } + return +} + // 查找服务的搜索条件 func (this *ServerDAO) FindServerNodeFilters(tx *dbs.Tx, serverId int64) (isOk bool, clusterId int64, err error) { one, err := this.Query(tx). @@ -714,7 +727,7 @@ func (this *ServerDAO) ComposeServerConfig(tx *dbs.Tx, serverId int64) (*serverc return nil, err } if server == nil { - return nil, errors.New("server not found") + return nil, ErrNotFound } config := &serverconfigs.ServerConfig{} @@ -946,6 +959,32 @@ func (this *ServerDAO) FindAllEnabledServersWithSSLPolicyIds(tx *dbs.Tx, sslPoli return } +// 查找使用某个SSL策略的所有服务Id +func (this *ServerDAO) FindAllEnabledServerIdsWithSSLPolicyIds(tx *dbs.Tx, sslPolicyIds []int64) (result []int64, err error) { + if len(sslPolicyIds) == 0 { + return + } + + for _, policyId := range sslPolicyIds { + ones, err := this.Query(tx). + State(ServerStateEnabled). + ResultPk(). + Where("(JSON_CONTAINS(https, :jsonQuery) OR JSON_CONTAINS(tls, :jsonQuery))"). + Param("jsonQuery", maps.Map{"sslPolicyRef": maps.Map{"sslPolicyId": policyId}}.AsJSON()). + FindAll() + if err != nil { + return nil, err + } + for _, one := range ones { + serverId := int64(one.(*Server).Id) + if !lists.ContainsInt64(result, serverId) { + result = append(result, serverId) + } + } + } + return +} + // 计算使用某个缓存策略的所有服务数量 func (this *ServerDAO) CountEnabledServersWithWebIds(tx *dbs.Tx, webIds []int64) (count int64, err error) { if len(webIds) == 0 { @@ -1045,12 +1084,16 @@ func (this *ServerDAO) GenerateServerDNSName(tx *dbs.Tx, serverId int64) (string op.Id = serverId op.DnsName = dnsName err = this.Save(tx, op) - return dnsName, err -} + if err != nil { + return "", err + } -// 创建事件 -func (this *ServerDAO) createEvent() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) + err = this.NotifyUpdate(tx, serverId) + if err != nil { + return "", err + } + + return dnsName, nil } // 查询当前服务的集群ID @@ -1085,7 +1128,7 @@ func (this *ServerDAO) FindServerAdminIdAndUserId(tx *dbs.Tx, serverId int64) (a } // 检查用户服务 -func (this *ServerDAO) CheckUserServer(tx *dbs.Tx, serverId int64, userId int64) error { +func (this *ServerDAO) CheckUserServer(tx *dbs.Tx, userId int64, serverId int64) error { if serverId <= 0 || userId <= 0 { return ErrNotFound } @@ -1104,10 +1147,45 @@ func (this *ServerDAO) CheckUserServer(tx *dbs.Tx, serverId int64, userId int64) // 设置一个用户下的所有服务的所属集群 func (this *ServerDAO) UpdateUserServersClusterId(tx *dbs.Tx, userId int64, clusterId int64) error { - _, err := this.Query(tx). + // 之前的cluster + oldClusterId, err := SharedUserDAO.FindUserClusterId(tx, userId) + if err != nil { + return err + } + if oldClusterId == clusterId { + return nil + } + + _, err = this.Query(tx). Attr("userId", userId). Set("clusterId", clusterId). Update() + if err != nil { + return err + } + + if oldClusterId > 0 { + err = SharedNodeTaskDAO.CreateClusterTask(tx, oldClusterId, NodeTaskTypeConfigChanged) + if err != nil { + return err + } + err = SharedNodeTaskDAO.CreateClusterTask(tx, oldClusterId, NodeTaskTypeIPItemChanged) + if err != nil { + return err + } + } + + if clusterId > 0 { + err = SharedNodeTaskDAO.CreateClusterTask(tx, clusterId, NodeTaskTypeConfigChanged) + if err != nil { + return err + } + err = SharedNodeTaskDAO.CreateClusterTask(tx, clusterId, NodeTaskTypeIPItemChanged) + if err != nil { + return err + } + } + return err } @@ -1134,6 +1212,16 @@ func (this *ServerDAO) FindEnabledServerIdWithWebId(tx *dbs.Tx, webId int64) (se FindInt64Col(0) } +// 查找包含某个反向代理的Server +func (this *ServerDAO) FindEnabledServerIdWithReverseProxyId(tx *dbs.Tx, reverseProxyId int64) (serverId int64, err error) { + return this.Query(tx). + State(ServerStateEnabled). + Where("JSON_CONTAINS(reverseProxy, :jsonQuery)"). + Param("jsonQuery", maps.Map{"reverseProxyId": reverseProxyId}.AsJSON()). + ResultPk(). + FindInt64Col(0) +} + // 检查端口是否被使用 func (this *ServerDAO) CheckPortIsUsing(tx *dbs.Tx, clusterId int64, port int) (bool, error) { listen := maps.Map{ @@ -1148,6 +1236,25 @@ func (this *ServerDAO) CheckPortIsUsing(tx *dbs.Tx, clusterId int64, port int) ( Exist() } +// 同步集群 +func (this *ServerDAO) NotifyUpdate(tx *dbs.Tx, serverId int64) error { + // 更新配置 + _, err := this.RenewServerConfig(tx, serverId, true) + if err != nil && err != ErrNotFound { + return err + } + + // 创建任务 + clusterId, err := this.FindServerClusterId(tx, serverId) + if err != nil { + return err + } + if clusterId == 0 { + return nil + } + return SharedNodeTaskDAO.CreateClusterTask(tx, clusterId, NodeTaskTypeConfigChanged) +} + // 生成DNS Name func (this *ServerDAO) genDNSName(tx *dbs.Tx) (string, error) { for { diff --git a/internal/db/models/server_dao_test.go b/internal/db/models/server_dao_test.go index 505d8421..d378cedb 100644 --- a/internal/db/models/server_dao_test.go +++ b/internal/db/models/server_dao_test.go @@ -77,3 +77,13 @@ func TestServerDAO_FindAllServerDNSNamesWithDNSDomainId(t *testing.T) { } t.Log("dnsNames:", dnsNames) } + +func TestServerDAO_FindAllEnabledServerIdsWithSSLPolicyIds(t *testing.T) { + dbs.NotifyReady() + var tx *dbs.Tx + serverIds, err := SharedServerDAO.FindAllEnabledServerIdsWithSSLPolicyIds(tx, []int64{14}) + if err != nil { + t.Fatal(err) + } + t.Log("serverIds:", serverIds) +} \ No newline at end of file diff --git a/internal/db/models/ssl_cert_dao.go b/internal/db/models/ssl_cert_dao.go index 7018d1da..05e3d4cd 100644 --- a/internal/db/models/ssl_cert_dao.go +++ b/internal/db/models/ssl_cert_dao.go @@ -40,16 +40,7 @@ func init() { // 初始化 func (this *SSLCertDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -62,12 +53,15 @@ func (this *SSLCertDAO) EnableSSLCert(tx *dbs.Tx, id int64) error { } // 禁用条目 -func (this *SSLCertDAO) DisableSSLCert(tx *dbs.Tx, id int64) error { +func (this *SSLCertDAO) DisableSSLCert(tx *dbs.Tx, certId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(certId). Set("state", SSLCertStateDisabled). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, certId) } // 查找启用中的条目 @@ -162,7 +156,10 @@ func (this *SSLCertDAO) UpdateCert(tx *dbs.Tx, certId int64, isOn bool, name str op.CommonNames = commonNamesJSON err = this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, certId) } // 组合配置 @@ -344,3 +341,29 @@ func (this *SSLCertDAO) CheckUserCert(tx *dbs.Tx, certId int64, userId int64) er } return nil } + +// 通知更新 +func (this *SSLCertDAO) NotifyUpdate(tx *dbs.Tx, certId int64) error { + policyIds, err := SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(tx, certId) + if err != nil { + return err + } + if len(policyIds) == 0 { + return nil + } + + serverIds, err := SharedServerDAO.FindAllEnabledServerIdsWithSSLPolicyIds(tx, policyIds) + if err != nil { + return err + } + if len(serverIds) == 0 { + return nil + } + for _, serverId := range serverIds { + err := SharedServerDAO.NotifyUpdate(tx, serverId) + if err != nil { + return err + } + } + return nil +} diff --git a/internal/db/models/ssl_policy_dao.go b/internal/db/models/ssl_policy_dao.go index 8ece89f9..b9d2f9e2 100644 --- a/internal/db/models/ssl_policy_dao.go +++ b/internal/db/models/ssl_policy_dao.go @@ -39,16 +39,7 @@ func init() { // 初始化 func (this *SSLPolicyDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } // 启用条目 @@ -61,12 +52,15 @@ func (this *SSLPolicyDAO) EnableSSLPolicy(tx *dbs.Tx, id int64) error { } // 禁用条目 -func (this *SSLPolicyDAO) DisableSSLPolicy(tx *dbs.Tx, id int64) error { +func (this *SSLPolicyDAO) DisableSSLPolicy(tx *dbs.Tx, policyId int64) error { _, err := this.Query(tx). - Pk(id). + Pk(policyId). Set("state", SSLPolicyStateDisabled). Update() - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } // 查找启用中的条目 @@ -259,7 +253,10 @@ func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled op.CipherSuites = "[]" } err := this.Save(tx, op) - return err + if err != nil { + return err + } + return this.NotifyUpdate(tx, policyId) } // 检查是否为用户所属策略 @@ -280,3 +277,18 @@ func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int } return nil } + +// 通知更新 +func (this *SSLPolicyDAO) NotifyUpdate(tx *dbs.Tx, policyId int64) error { + serverIds, err := SharedServerDAO.FindAllEnabledServerIdsWithSSLPolicyIds(tx, []int64{policyId}) + if err != nil { + return err + } + for _, serverId := range serverIds { + err := SharedServerDAO.NotifyUpdate(tx, serverId) + if err != nil { + return err + } + } + return nil +} diff --git a/internal/db/models/sys_event_types.go b/internal/db/models/sys_event_types.go index 10a21321..5c42c6b7 100644 --- a/internal/db/models/sys_event_types.go +++ b/internal/db/models/sys_event_types.go @@ -1,7 +1,6 @@ package models import ( - "github.com/iwind/TeaGo/dbs" "reflect" ) @@ -9,7 +8,7 @@ var eventTypeMapping = map[string]reflect.Type{} // eventType => reflect type func init() { for _, event := range []EventInterface{ - NewServerChangeEvent(), + // Event列表 } { eventTypeMapping[event.Type()] = reflect.ValueOf(event).Elem().Type() } @@ -20,48 +19,3 @@ type EventInterface interface { Type() string Run() error } - -// 服务变化 -type ServerChangeEvent struct { -} - -func NewServerChangeEvent() *ServerChangeEvent { - return &ServerChangeEvent{} -} - -func (this *ServerChangeEvent) Type() string { - return "serverChange" -} - -func (this *ServerChangeEvent) Run() error { - var tx *dbs.Tx - - serverIds, err := SharedServerDAO.FindAllEnabledServerIds(tx) - if err != nil { - return err - } - for _, serverId := range serverIds { - isChanged, err := SharedServerDAO.RenewServerConfig(tx, serverId, true) - if err != nil { - return err - } - if !isChanged { - continue - } - - // 检查节点是否需要更新 - isOk, clusterId, err := SharedServerDAO.FindServerNodeFilters(tx, serverId) - if err != nil { - return err - } - if !isOk { - continue - } - err = SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, clusterId) - if err != nil { - return err - } - } - - return nil -} diff --git a/internal/db/models/tcp_firewall_policy_dao.go b/internal/db/models/tcp_firewall_policy_dao.go index 0d6aa71f..eb6755b0 100644 --- a/internal/db/models/tcp_firewall_policy_dao.go +++ b/internal/db/models/tcp_firewall_policy_dao.go @@ -29,14 +29,5 @@ func init() { // 初始化 func (this *TCPFirewallPolicyDAO) Init() { - this.DAOObject.Init() - this.DAOObject.OnUpdate(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnInsert(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) - this.DAOObject.OnDelete(func() error { - return SharedSysEventDAO.CreateEvent(nil, NewServerChangeEvent()) - }) + _ = this.DAOObject.Init() } diff --git a/internal/nodes/api_node.go b/internal/nodes/api_node.go index 45c91c49..c7feb5e7 100644 --- a/internal/nodes/api_node.go +++ b/internal/nodes/api_node.go @@ -224,6 +224,7 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err pb.RegisterLoginServiceServer(rpcServer, &services.LoginService{}) pb.RegisterUserAccessKeyServiceServer(rpcServer, &services.UserAccessKeyService{}) pb.RegisterSysLockerServiceServer(rpcServer, &services.SysLockerService{}) + pb.RegisterNodeTaskServiceServer(rpcServer, &services.NodeTaskService{}) err := rpcServer.Serve(listener) if err != nil { return errors.New("[API_NODE]start rpc failed: " + err.Error()) diff --git a/internal/rpc/services/service_http_access_log.go b/internal/rpc/services/service_http_access_log.go index 95a9fc92..b10710ae 100644 --- a/internal/rpc/services/service_http_access_log.go +++ b/internal/rpc/services/service_http_access_log.go @@ -51,7 +51,7 @@ func (this *HTTPAccessLogService) ListHTTPAccessLogs(ctx context.Context, req *p return nil, errors.New("invalid serverId") } - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -98,7 +98,7 @@ func (this *HTTPAccessLogService) FindHTTPAccessLog(ctx context.Context, req *pb // 检查权限 if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, int64(accessLog.ServerId), userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, int64(accessLog.ServerId)) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ip_item.go b/internal/rpc/services/service_ip_item.go index 34a743df..fe912cc7 100644 --- a/internal/rpc/services/service_ip_item.go +++ b/internal/rpc/services/service_ip_item.go @@ -52,6 +52,12 @@ func (this *IPItemService) CreateIPItem(ctx context.Context, req *pb.CreateIPIte return nil, err } + // 通知更新 + err = models.SharedIPListDAO.NotifyUpdate(tx, req.IpListId, models.NodeTaskTypeIPItemChanged) + if err != nil { + return nil, err + } + return &pb.CreateIPItemResponse{IpItemId: itemId}, nil } @@ -81,6 +87,13 @@ func (this *IPItemService) UpdateIPItem(ctx context.Context, req *pb.UpdateIPIte if err != nil { return nil, err } + + // 通知更新 + err = models.SharedIPItemDAO.NotifyClustersUpdate(tx, req.IpItemId, models.NodeTaskTypeIPItemChanged) + if err != nil { + return nil, err + } + return this.Success() } @@ -110,6 +123,13 @@ func (this *IPItemService) DeleteIPItem(ctx context.Context, req *pb.DeleteIPIte if err != nil { return nil, err } + + // 通知更新 + err = models.SharedIPItemDAO.NotifyClustersUpdate(tx, req.IpItemId, models.NodeTaskTypeIPItemChanged) + if err != nil { + return nil, err + } + return this.Success() } diff --git a/internal/rpc/services/service_node.go b/internal/rpc/services/service_node.go index 84d8cdfd..d68f74e5 100644 --- a/internal/rpc/services/service_node.go +++ b/internal/rpc/services/service_node.go @@ -351,6 +351,12 @@ func (this *NodeService) DeleteNode(ctx context.Context, req *pb.DeleteNodeReque } }() + // 删除节点相关任务 + err = models.SharedNodeTaskDAO.DeleteNodeTasks(tx, req.NodeId) + if err != nil { + return nil, err + } + return this.Success() } @@ -577,23 +583,6 @@ func (this *NodeService) UpdateNodeStatus(ctx context.Context, req *pb.UpdateNod return this.Success() } -// 同步集群中的节点版本 -func (this *NodeService) SyncNodesVersionWithCluster(ctx context.Context, req *pb.SyncNodesVersionWithClusterRequest) (*pb.SyncNodesVersionWithClusterResponse, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) - if err != nil { - return nil, err - } - - tx := this.NullTx() - - err = models.SharedNodeDAO.SyncNodeVersionsWithCluster(tx, req.NodeClusterId) - if err != nil { - return nil, err - } - - return &pb.SyncNodesVersionWithClusterResponse{}, nil -} - // 修改节点安装状态 func (this *NodeService) UpdateNodeIsInstalled(ctx context.Context, req *pb.UpdateNodeIsInstalledRequest) (*pb.RPCSuccess, error) { _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) @@ -1326,3 +1315,32 @@ func (this *NodeService) CountAllEnabledNodesWithNodeRegionId(ctx context.Contex } return this.SuccessCount(count) } + +// 根据一组ID获取节点信息 +func (this *NodeService) FindEnabledNodesWithIds(ctx context.Context, req *pb.FindEnabledNodesWithIdsRequest) (*pb.FindEnabledNodesWithIdsResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + tx := this.NullTx() + + nodes, err := models.SharedNodeDAO.FindEnabledNodesWithIds(tx, req.NodeIds) + if err != nil { + return nil, err + } + pbNodes := []*pb.Node{} + for _, node := range nodes { + connectedAPINodeIds, err := node.DecodeConnectedAPINodeIds() + if err != nil { + return nil, err + } + pbNodes = append(pbNodes, &pb.Node{ + Id: int64(node.Id), + IsOn: node.IsOn == 1, + IsActive: node.IsActive == 1, + ConnectedAPINodeIds: connectedAPINodeIds, + }) + } + return &pb.FindEnabledNodesWithIdsResponse{Nodes: pbNodes}, nil +} diff --git a/internal/rpc/services/service_node_cluster.go b/internal/rpc/services/service_node_cluster.go index 58c270bf..a856634f 100644 --- a/internal/rpc/services/service_node_cluster.go +++ b/internal/rpc/services/service_node_cluster.go @@ -83,6 +83,12 @@ func (this *NodeClusterService) DeleteNodeCluster(ctx context.Context, req *pb.D return nil, err } + // 删除相关任务 + err = models.SharedNodeTaskDAO.DeleteAllClusterTasks(tx, req.NodeClusterId) + if err != nil { + return nil, err + } + return this.Success() } @@ -208,44 +214,6 @@ func (this *NodeClusterService) FindAllEnabledNodeClusters(ctx context.Context, }, nil } -// 查找所有变更的集群 -func (this *NodeClusterService) FindAllChangedNodeClusters(ctx context.Context, req *pb.FindAllChangedNodeClustersRequest) (*pb.FindAllChangedNodeClustersResponse, error) { - _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) - if err != nil { - return nil, err - } - - tx := this.NullTx() - - clusterIds, err := models.SharedNodeDAO.FindChangedClusterIds(tx) - if err != nil { - return nil, err - } - if len(clusterIds) == 0 { - return &pb.FindAllChangedNodeClustersResponse{ - NodeClusters: []*pb.NodeCluster{}, - }, nil - } - result := []*pb.NodeCluster{} - for _, clusterId := range clusterIds { - cluster, err := models.SharedNodeClusterDAO.FindEnabledNodeCluster(tx, clusterId) - if err != nil { - return nil, err - } - if cluster == nil { - continue - } - result = append(result, &pb.NodeCluster{ - Id: int64(cluster.Id), - Name: cluster.Name, - CreatedAt: int64(cluster.CreatedAt), - UniqueId: cluster.UniqueId, - Secret: cluster.Secret, - }) - } - return &pb.FindAllChangedNodeClustersResponse{NodeClusters: result}, nil -} - // 计算所有集群数量 func (this *NodeClusterService) CountAllEnabledNodeClusters(ctx context.Context, req *pb.CountAllEnabledNodeClustersRequest) (*pb.RPCCountResponse, error) { _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) @@ -659,12 +627,6 @@ func (this *NodeClusterService) UpdateNodeClusterTOA(ctx context.Context, req *p return nil, err } - // 增加节点版本号 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) - if err != nil { - return nil, err - } - return this.Success() } @@ -764,12 +726,6 @@ func (this *NodeClusterService) UpdateNodeClusterHTTPCachePolicyId(ctx context.C return nil, err } - // 增加节点版本号 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) - if err != nil { - return nil, err - } - return this.Success() } @@ -787,12 +743,6 @@ func (this *NodeClusterService) UpdateNodeClusterHTTPFirewallPolicyId(ctx contex return nil, err } - // 增加节点版本号 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) - if err != nil { - return nil, err - } - return this.Success() } @@ -817,12 +767,6 @@ func (this *NodeClusterService) UpdateNodeClusterSystemService(ctx context.Conte return nil, err } - // 增加节点版本号 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) - if err != nil { - return nil, err - } - return this.Success() } diff --git a/internal/rpc/services/service_node_task.go b/internal/rpc/services/service_node_task.go new file mode 100644 index 00000000..57535e72 --- /dev/null +++ b/internal/rpc/services/service_node_task.go @@ -0,0 +1,243 @@ +package services + +import ( + "context" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/iwind/TeaGo/dbs" + "time" +) + +// 节点同步任务相关服务 +type NodeTaskService struct { + BaseService +} + +// 获取单节点同步任务 +func (this *NodeTaskService) FindNodeTasks(ctx context.Context, req *pb.FindNodeTasksRequest) (*pb.FindNodeTasksResponse, error) { + nodeId, err := this.ValidateNode(ctx) + if err != nil { + return nil, err + } + + _ = req + + var tx = this.NullTx() + tasks, err := models.SharedNodeTaskDAO.FindDoingNodeTasks(tx, nodeId) + if err != nil { + return nil, err + } + + pbTasks := []*pb.NodeTask{} + for _, task := range tasks { + pbTasks = append(pbTasks, &pb.NodeTask{ + Id: int64(task.Id), + Type: task.Type, + }) + } + + return &pb.FindNodeTasksResponse{NodeTasks: pbTasks}, nil +} + +// 报告同步任务结果 +func (this *NodeTaskService) ReportNodeTaskDone(ctx context.Context, req *pb.ReportNodeTaskDoneRequest) (*pb.RPCSuccess, error) { + _, err := this.ValidateNode(ctx) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + err = models.SharedNodeTaskDAO.UpdateNodeTaskDone(tx, req.NodeTaskId, req.IsOk, req.Error) + if err != nil { + return nil, err + } + + return this.Success() +} + +// 获取所有正在同步的集群信息 +func (this *NodeTaskService) FindNodeClusterTasks(ctx context.Context, req *pb.FindNodeClusterTasksRequest) (*pb.FindNodeClusterTasksResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + _ = req + + var tx = this.NullTx() + clusterIds, err := models.SharedNodeTaskDAO.FindAllDoingTaskClusterIds(tx) + if err != nil { + return nil, err + } + if len(clusterIds) == 0 { + return &pb.FindNodeClusterTasksResponse{ClusterTasks: []*pb.ClusterTask{}}, nil + } + + pbClusterTasks := []*pb.ClusterTask{} + for _, clusterId := range clusterIds { + pbClusterTask := &pb.ClusterTask{} + clusterName, err := models.SharedNodeClusterDAO.FindNodeClusterName(tx, clusterId) + if err != nil { + return nil, err + } + pbClusterTask.ClusterId = clusterId + pbClusterTask.ClusterName = clusterName + + // 错误的节点任务 + pbNodeTasks := []*pb.NodeTask{} + // TODO 考虑节点特别多的情形,比如只显示前100个 + tasks, err := models.SharedNodeTaskDAO.FindAllDoingNodeTasksWithClusterId(tx, clusterId) + if err != nil { + return nil, err + } + for _, task := range tasks { + // 节点 + nodeName, err := models.SharedNodeDAO.FindNodeName(tx, int64(task.NodeId)) + if err != nil { + return nil, err + } + + // 是否超时(N秒内没有更新) + if int64(task.UpdatedAt) < time.Now().Unix()-120 { + task.IsDone = 1 + task.IsOk = 0 + task.Error = "节点响应超时" + } + + pbNodeTasks = append(pbNodeTasks, &pb.NodeTask{ + Id: int64(task.Id), + Type: task.Type, + IsDone: task.IsDone == 1, + IsOk: task.IsOk == 1, + Error: task.Error, + UpdatedAt: int64(task.UpdatedAt), + Node: &pb.Node{ + Id: int64(task.NodeId), + Name: nodeName, + }, + }) + } + pbClusterTask.NodeTasks = pbNodeTasks + + pbClusterTasks = append(pbClusterTasks, pbClusterTask) + } + + return &pb.FindNodeClusterTasksResponse{ClusterTasks: pbClusterTasks}, nil +} + +// 检查是否有正在执行的任务 +func (this *NodeTaskService) ExistsNodeTasks(ctx context.Context, req *pb.ExistsNodeTasksRequest) (*pb.ExistsNodeTasksResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + _ = req + + var tx = this.NullTx() + + // 是否有任务 + existTask, err := models.SharedNodeTaskDAO.ExistsDoingNodeTasks(tx) + if err != nil { + return nil, err + } + + // 是否有错误 + existError, err := models.SharedNodeTaskDAO.ExistsErrorNodeTasks(tx) + if err != nil { + return nil, err + } + + return &pb.ExistsNodeTasksResponse{ + ExistTasks: existTask, + ExistError: existError, + }, nil +} + +// 删除任务 +func (this *NodeTaskService) DeleteNodeTask(ctx context.Context, req *pb.DeleteNodeTaskRequest) (*pb.RPCSuccess, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + err = models.SharedNodeTaskDAO.DeleteNodeTask(tx, req.NodeTaskId) + if err != nil { + return nil, err + } + + return this.Success() +} + +// 计算正在执行的任务数量 +func (this *NodeTaskService) CountDoingNodeTasks(ctx context.Context, req *pb.CountDoingNodeTasksRequest) (*pb.RPCCountResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + _ = req + + var tx = this.NullTx() + count, err := models.SharedNodeTaskDAO.CountDoingNodeTasks(tx) + if err != nil { + return nil, err + } + return this.SuccessCount(count) +} + +// 查找需要通知的任务 +func (this *NodeTaskService) FindNotifyingNodeTasks(ctx context.Context, req *pb.FindNotifyingNodeTasksRequest) (*pb.FindNotifyingNodeTasksResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + if req.Size <= 0 { + req.Size = 100 + } + if req.Size > 1000 { + req.Size = 1000 + } + + var tx = this.NullTx() + tasks, err := models.SharedNodeTaskDAO.FindNotifyingNodeTasks(tx, req.Size) + if err != nil { + return nil, err + } + + pbTasks := []*pb.NodeTask{} + for _, task := range tasks { + pbTasks = append(pbTasks, &pb.NodeTask{ + Id: int64(task.Id), + Type: task.Type, + IsDone: task.IsDone == 1, + IsOk: task.IsOk == 1, + Error: task.Error, + UpdatedAt: int64(task.UpdatedAt), + Node: &pb.Node{Id: int64(task.NodeId)}, + }) + } + + return &pb.FindNotifyingNodeTasksResponse{NodeTasks: pbTasks}, nil +} + +// 设置任务已通知 +func (this *NodeTaskService) UpdateNodeTasksNotified(ctx context.Context, req *pb.UpdateNodeTasksNotifiedRequest) (*pb.RPCSuccess, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + err = this.RunTx(func(tx *dbs.Tx) error { + err = models.SharedNodeTaskDAO.UpdateTasksNotified(tx, req.NodeTaskIds) + return err + }) + + if err != nil { + return nil, err + } + + return this.Success() +} diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index f4119469..26ce1015 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -84,12 +84,6 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe return nil, err } - // 更新节点版本 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) - if err != nil { - return nil, err - } - return &pb.CreateServerResponse{ServerId: serverId}, nil } @@ -132,20 +126,6 @@ func (this *ServerService) UpdateServerBasic(ctx context.Context, req *pb.Update }() } - // 更新老的节点版本 - if req.NodeClusterId != int64(server.ClusterId) { - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, int64(server.ClusterId)) - if err != nil { - return nil, err - } - } - - // 更新新的节点版本 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) - if err != nil { - return nil, err - } - return this.Success() } @@ -159,7 +139,7 @@ func (this *ServerService) UpdateServerIsOn(ctx context.Context, req *pb.UpdateS tx := this.NullTx() if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -182,7 +162,7 @@ func (this *ServerService) UpdateServerHTTP(ctx context.Context, req *pb.UpdateS tx := this.NullTx() if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -208,7 +188,7 @@ func (this *ServerService) UpdateServerHTTPS(ctx context.Context, req *pb.Update tx := this.NullTx() if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -326,7 +306,7 @@ func (this *ServerService) UpdateServerWeb(ctx context.Context, req *pb.UpdateSe tx := this.NullTx() if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -352,7 +332,7 @@ func (this *ServerService) UpdateServerReverseProxy(ctx context.Context, req *pb tx := this.NullTx() if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -377,7 +357,7 @@ func (this *ServerService) FindServerNames(ctx context.Context, req *pb.FindServ tx := this.NullTx() if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -625,33 +605,18 @@ func (this *ServerService) DeleteServer(ctx context.Context, req *pb.DeleteServe tx := this.NullTx() if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } } - // 查找服务 - server, err := models.SharedServerDAO.FindEnabledServer(tx, req.ServerId) - if err != nil { - return nil, err - } - if server == nil { - return nil, errors.New("can not find the server") - } - // 禁用服务 err = models.SharedServerDAO.DisableServer(tx, req.ServerId) if err != nil { return nil, err } - // 更新节点版本 - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, int64(server.ClusterId)) - if err != nil { - return nil, err - } - return this.Success() } @@ -667,7 +632,7 @@ func (this *ServerService) FindEnabledServer(ctx context.Context, req *pb.FindEn // 检查权限 if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -769,7 +734,7 @@ func (this *ServerService) FindEnabledServerConfig(ctx context.Context, req *pb. // 检查权限 if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -802,7 +767,7 @@ func (this *ServerService) FindEnabledServerType(ctx context.Context, req *pb.Fi // 检查权限 if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -880,7 +845,7 @@ func (this *ServerService) FindAndInitServerWebConfig(ctx context.Context, req * tx := this.NullTx() if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -1022,10 +987,16 @@ func (this *ServerService) NotifyServersChange(ctx context.Context, req *pb.Noti tx := this.NullTx() - err = models.SharedSysEventDAO.CreateEvent(tx, models.NewServerChangeEvent()) + clusterIds, err := models.SharedNodeClusterDAO.FindAllEnableClusterIds(tx) if err != nil { return nil, err } + for _, clusterId := range clusterIds { + err = models.SharedNodeClusterDAO.NotifyUpdate(tx, clusterId) + if err != nil { + return nil, err + } + } return &pb.NotifyServersChangeResponse{}, nil } @@ -1167,7 +1138,7 @@ func (this *ServerService) CheckUserServer(ctx context.Context, req *pb.CheckUse tx := this.NullTx() - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -1217,7 +1188,7 @@ func (this *ServerService) FindEnabledUserServerBasic(ctx context.Context, req * var tx = this.NullTx() if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } @@ -1250,7 +1221,7 @@ func (this *ServerService) UpdateEnabledUserServerBasic(ctx context.Context, req var tx = this.NullTx() if userId > 0 { - err = models.SharedServerDAO.CheckUserServer(tx, req.ServerId, userId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_user.go b/internal/rpc/services/service_user.go index 94cc92ea..af646d38 100644 --- a/internal/rpc/services/service_user.go +++ b/internal/rpc/services/service_user.go @@ -57,16 +57,6 @@ func (this *UserService) UpdateUser(ctx context.Context, req *pb.UpdateUserReque if err != nil { return nil, err } - - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, oldClusterId) - if err != nil { - return nil, err - } - - err = models.SharedNodeDAO.IncreaseAllNodesLatestVersionMatch(tx, req.NodeClusterId) - if err != nil { - return nil, err - } } return this.Success() @@ -81,6 +71,18 @@ func (this *UserService) DeleteUser(ctx context.Context, req *pb.DeleteUserReque tx := this.NullTx() + // 删除其下的Server + serverIds, err := models.SharedServerDAO.FindAllEnabledServerIdsWithUserId(tx, req.UserId) + if err != nil { + return nil, err + } + for _, serverId := range serverIds { + err := models.SharedServerDAO.DisableServer(tx, serverId) + if err != nil { + return nil, err + } + } + _, err = models.SharedUserDAO.DisableUser(tx, req.UserId) if err != nil { return nil, err diff --git a/internal/tasks/node_task_extractor.go b/internal/tasks/node_task_extractor.go new file mode 100644 index 00000000..bf1a6b2e --- /dev/null +++ b/internal/tasks/node_task_extractor.go @@ -0,0 +1,51 @@ +package tasks + +import ( + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/iwind/TeaGo/dbs" + "github.com/iwind/TeaGo/logs" + "time" +) + +func init() { + dbs.OnReady(func() { + go NewNodeTaskExtractor().Start() + }) +} + +// 节点任务 +type NodeTaskExtractor struct { +} + +func NewNodeTaskExtractor() *NodeTaskExtractor { + return &NodeTaskExtractor{} +} + +func (this *NodeTaskExtractor) Start() { + ticker := time.NewTicker(10 * time.Second) + for range ticker.C { + err := this.Loop() + if err != nil { + logs.Println("[TASK][NODE_TASK_EXTRACTOR]" + err.Error()) + } + } +} + +func (this *NodeTaskExtractor) Loop() error { + ok, err := models.SharedSysLockerDAO.Lock(nil, "node_task_extractor", 10-1) // 假设执行时间为1秒 + if err != nil { + return err + } + if !ok { + return nil + } + + // 这里不解锁,是为了让任务N秒钟之内只运行一次 + + err = models.SharedNodeTaskDAO.ExtractAllClusterTasks(nil) + if err != nil { + return err + } + + return nil +}