From f51ad97b9cf6e9cbb519c20b8f563c2e8e86fa2e Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Wed, 2 Dec 2020 16:09:15 +0800 Subject: [PATCH] =?UTF-8?q?[WAF]=E8=A7=84=E5=88=99=E6=94=AF=E6=8C=81?= =?UTF-8?q?=E5=AF=BC=E5=85=A5=E5=AF=BC=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../services/service_http_firewall_policy.go | 156 ++++++++++++++++++ 1 file changed, 156 insertions(+) diff --git a/internal/rpc/services/service_http_firewall_policy.go b/internal/rpc/services/service_http_firewall_policy.go index 31720c72..d03351bd 100644 --- a/internal/rpc/services/service_http_firewall_policy.go +++ b/internal/rpc/services/service_http_firewall_policy.go @@ -375,3 +375,159 @@ func (this *HTTPFirewallPolicyService) FindEnabledFirewallPolicy(ctx context.Con OutboundJSON: []byte(policy.Outbound), }}, nil } + +// 导入策略数据 +func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Context, req *pb.ImportHTTPFirewallPolicyRequest) (*pb.RPCSuccess, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + // TODO 检查权限 + + oldConfig, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(req.FirewallPolicyId) + if err != nil { + return nil, err + } + if oldConfig == nil { + return nil, errors.New("can not find policy") + } + + // 解析数据 + newConfig := &firewallconfigs.HTTPFirewallPolicy{} + err = json.Unmarshal(req.FirewallPolicyJSON, newConfig) + if err != nil { + return nil, err + } + + // 入站分组 + if newConfig.Inbound != nil { + for _, g := range newConfig.Inbound.Groups { + if len(g.Code) > 0 { + // 对于有代号的,覆盖或者添加 + oldGroup := oldConfig.FindRuleGroupWithCode(g.Code) + if oldGroup == nil { + // 新创建分组 + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + if err != nil { + return nil, err + } + oldConfig.Inbound.GroupRefs = append(oldConfig.Inbound.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{ + IsOn: true, + GroupId: groupId, + }) + } else { + setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{} + for _, set := range g.Sets { + setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(set) + if err != nil { + return nil, err + } + setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{ + IsOn: true, + SetId: setId, + }) + } + setsJSON, err := json.Marshal(setRefs) + if err != nil { + return nil, err + } + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(oldGroup.Id, true) + if err != nil { + return nil, err + } + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(oldGroup.Id, setsJSON) + if err != nil { + return nil, err + } + } + } else { + // 没有代号的直接创建 + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + if err != nil { + return nil, err + } + oldConfig.Inbound.GroupRefs = append(oldConfig.Inbound.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{ + IsOn: true, + GroupId: groupId, + }) + } + } + } + + // 出站分组 + if newConfig.Outbound != nil { + for _, g := range newConfig.Outbound.Groups { + if len(g.Code) > 0 { + // 对于有代号的,覆盖或者添加 + oldGroup := oldConfig.FindRuleGroupWithCode(g.Code) + if oldGroup == nil { + // 新创建分组 + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + if err != nil { + return nil, err + } + oldConfig.Outbound.GroupRefs = append(oldConfig.Outbound.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{ + IsOn: true, + GroupId: groupId, + }) + } else { + setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{} + for _, set := range g.Sets { + setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(set) + if err != nil { + return nil, err + } + setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{ + IsOn: true, + SetId: setId, + }) + } + setsJSON, err := json.Marshal(setRefs) + if err != nil { + return nil, err + } + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(oldGroup.Id, true) + if err != nil { + return nil, err + } + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(oldGroup.Id, setsJSON) + if err != nil { + return nil, err + } + } + } else { + // 没有代号的直接创建 + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + if err != nil { + return nil, err + } + oldConfig.Outbound.GroupRefs = append(oldConfig.Outbound.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{ + IsOn: true, + GroupId: groupId, + }) + } + } + } + + // 保存Inbound和Outbound + oldConfig.Inbound.Groups = nil + oldConfig.Outbound.Groups = nil + + inboundJSON, err := json.Marshal(oldConfig.Inbound) + if err != nil { + return nil, err + } + + outboundJSON, err := json.Marshal(oldConfig.Outbound) + if err != nil { + return nil, err + } + + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(req.FirewallPolicyId, inboundJSON, outboundJSON) + if err != nil { + return nil, err + } + + return this.Success() +}