diff --git a/internal/db/models/http_firewall_rule_group_dao.go b/internal/db/models/http_firewall_rule_group_dao.go index 549c510c..7f2d4dba 100644 --- a/internal/db/models/http_firewall_rule_group_dao.go +++ b/internal/db/models/http_firewall_rule_group_dao.go @@ -194,13 +194,13 @@ func (this *HTTPFirewallRuleGroupDAO) UpdateGroup(tx *dbs.Tx, groupId int64, isO } // UpdateGroupSets 修改分组中的规则集 -func (this *HTTPFirewallRuleGroupDAO) UpdateGroupSets(tx *dbs.Tx, groupId int64, setsJSON []byte) error { +func (this *HTTPFirewallRuleGroupDAO) UpdateGroupSets(tx *dbs.Tx, groupId int64, setRefsJSON []byte) error { if groupId <= 0 { return errors.New("invalid groupId") } op := NewHTTPFirewallRuleGroupOperator() op.Id = groupId - op.Sets = setsJSON + op.Sets = setRefsJSON err := this.Save(tx, op) if err != nil { return err diff --git a/internal/rpc/services/service_http_firewall_rule_group.go b/internal/rpc/services/service_http_firewall_rule_group.go index 7b4aad22..134b11ed 100644 --- a/internal/rpc/services/service_http_firewall_rule_group.go +++ b/internal/rpc/services/service_http_firewall_rule_group.go @@ -4,7 +4,9 @@ import ( "context" "encoding/json" "github.com/TeaOSLab/EdgeAPI/internal/db/models" + "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" ) // HTTPFirewallRuleGroupService WAF规则分组相关服务 @@ -167,7 +169,7 @@ func (this *HTTPFirewallRuleGroupService) UpdateHTTPFirewallRuleGroupSets(ctx co return nil, err } } - + tx := this.NullTx() err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(tx, req.GetFirewallRuleGroupId(), req.FirewallRuleSetsJSON) @@ -176,3 +178,65 @@ func (this *HTTPFirewallRuleGroupService) UpdateHTTPFirewallRuleGroupSets(ctx co } return this.Success() } + +// AddHTTPFirewallRuleGroupSet 添加规则集 +func (this *HTTPFirewallRuleGroupService) AddHTTPFirewallRuleGroupSet(ctx context.Context, req *pb.AddHTTPFirewallRuleGroupSetRequest) (*pb.RPCSuccess, error) { + // 校验请求 + _, userId, err := this.ValidateAdminAndUser(ctx, 0, 0) + if err != nil { + return nil, err + } + + if userId > 0 { + // 校验权限 + err = models.SharedHTTPFirewallRuleGroupDAO.CheckUserRuleGroup(nil, userId, req.FirewallRuleGroupId) + if err != nil { + return nil, err + } + } + + tx := this.NullTx() + + // 已经有的规则 + config, err := models.SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(tx, req.FirewallRuleGroupId) + if err != nil { + return nil, err + } + if config == nil { + return nil, errors.New("can not find group") + } + var setRefs = config.SetRefs + + var set = &firewallconfigs.HTTPFirewallRuleSet{} + err = json.Unmarshal(req.FirewallRuleSetConfigJSON, set) + if err != nil { + return nil, err + } + + if set.Id > 0 { + setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{ + IsOn: true, + SetId: set.Id, + }) + } else { + setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(tx, set) + if err != nil { + return nil, err + } + setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{ + IsOn: true, + SetId: setId, + }) + } + + setRefsJSON, err := json.Marshal(setRefs) + if err != nil { + return nil, err + } + + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(tx, req.FirewallRuleGroupId, setRefsJSON) + if err != nil { + return nil, err + } + return this.Success() +} diff --git a/internal/setup/sql_upgrade.go b/internal/setup/sql_upgrade.go index 53e2b2f0..14aabf32 100644 --- a/internal/setup/sql_upgrade.go +++ b/internal/setup/sql_upgrade.go @@ -56,6 +56,9 @@ var upgradeFuncs = []*upgradeVersion{ { "0.3.2", upgradeV0_3_2, }, + { + "0.3.3", upgradeV0_3_3, + }, } // UpgradeSQLData 升级SQL数据 @@ -511,3 +514,14 @@ func upgradeV0_3_2(db *dbs.DB) error { return nil } + +// v0.3.3 +func upgradeV0_3_3(db *dbs.DB) error { + // 升级CC请求数Code + _, err := db.Exec("UPDATE edgeHTTPFirewallRuleSets SET code='8002' WHERE name='CC请求数' AND code='8001'") + if err != nil { + return err + } + + return nil +} diff --git a/internal/setup/sql_upgrade_test.go b/internal/setup/sql_upgrade_test.go index f1aa278c..56e0d620 100644 --- a/internal/setup/sql_upgrade_test.go +++ b/internal/setup/sql_upgrade_test.go @@ -22,7 +22,7 @@ func TestUpgradeSQLData(t *testing.T) { } -func TestUpgradeSQLData_v1_3_1(t *testing.T) { +func TestUpgradeSQLData_v0_3_1(t *testing.T) { db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{ Driver: "mysql", Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge_new?charset=utf8mb4&timeout=30s", @@ -38,7 +38,7 @@ func TestUpgradeSQLData_v1_3_1(t *testing.T) { t.Log("ok") } -func TestUpgradeSQLData_v1_3_2(t *testing.T) { +func TestUpgradeSQLData_v0_3_2(t *testing.T) { db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{ Driver: "mysql", Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s", @@ -52,4 +52,20 @@ func TestUpgradeSQLData_v1_3_2(t *testing.T) { t.Fatal(err) } t.Log("ok") +} + +func TestUpgradeSQLData_v0_3_3(t *testing.T) { + db, err := dbs.NewInstanceFromConfig(&dbs.DBConfig{ + Driver: "mysql", + Dsn: "root:123456@tcp(127.0.0.1:3306)/db_edge?charset=utf8mb4&timeout=30s", + Prefix: "edge", + }) + if err != nil { + t.Fatal(err) + } + err = upgradeV0_3_3(db) + if err != nil { + t.Fatal(err) + } + t.Log("ok") } \ No newline at end of file