diff --git a/internal/rpc/services/service_http_firewall_policy.go b/internal/rpc/services/service_http_firewall_policy.go index 31720c72..d03351bd 100644 --- a/internal/rpc/services/service_http_firewall_policy.go +++ b/internal/rpc/services/service_http_firewall_policy.go @@ -375,3 +375,159 @@ func (this *HTTPFirewallPolicyService) FindEnabledFirewallPolicy(ctx context.Con OutboundJSON: []byte(policy.Outbound), }}, nil } + +// 导入策略数据 +func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Context, req *pb.ImportHTTPFirewallPolicyRequest) (*pb.RPCSuccess, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + // TODO 检查权限 + + oldConfig, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(req.FirewallPolicyId) + if err != nil { + return nil, err + } + if oldConfig == nil { + return nil, errors.New("can not find policy") + } + + // 解析数据 + newConfig := &firewallconfigs.HTTPFirewallPolicy{} + err = json.Unmarshal(req.FirewallPolicyJSON, newConfig) + if err != nil { + return nil, err + } + + // 入站分组 + if newConfig.Inbound != nil { + for _, g := range newConfig.Inbound.Groups { + if len(g.Code) > 0 { + // 对于有代号的,覆盖或者添加 + oldGroup := oldConfig.FindRuleGroupWithCode(g.Code) + if oldGroup == nil { + // 新创建分组 + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + if err != nil { + return nil, err + } + oldConfig.Inbound.GroupRefs = append(oldConfig.Inbound.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{ + IsOn: true, + GroupId: groupId, + }) + } else { + setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{} + for _, set := range g.Sets { + setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(set) + if err != nil { + return nil, err + } + setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{ + IsOn: true, + SetId: setId, + }) + } + setsJSON, err := json.Marshal(setRefs) + if err != nil { + return nil, err + } + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(oldGroup.Id, true) + if err != nil { + return nil, err + } + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(oldGroup.Id, setsJSON) + if err != nil { + return nil, err + } + } + } else { + // 没有代号的直接创建 + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + if err != nil { + return nil, err + } + oldConfig.Inbound.GroupRefs = append(oldConfig.Inbound.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{ + IsOn: true, + GroupId: groupId, + }) + } + } + } + + // 出站分组 + if newConfig.Outbound != nil { + for _, g := range newConfig.Outbound.Groups { + if len(g.Code) > 0 { + // 对于有代号的,覆盖或者添加 + oldGroup := oldConfig.FindRuleGroupWithCode(g.Code) + if oldGroup == nil { + // 新创建分组 + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + if err != nil { + return nil, err + } + oldConfig.Outbound.GroupRefs = append(oldConfig.Outbound.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{ + IsOn: true, + GroupId: groupId, + }) + } else { + setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{} + for _, set := range g.Sets { + setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(set) + if err != nil { + return nil, err + } + setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{ + IsOn: true, + SetId: setId, + }) + } + setsJSON, err := json.Marshal(setRefs) + if err != nil { + return nil, err + } + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(oldGroup.Id, true) + if err != nil { + return nil, err + } + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(oldGroup.Id, setsJSON) + if err != nil { + return nil, err + } + } + } else { + // 没有代号的直接创建 + groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(g) + if err != nil { + return nil, err + } + oldConfig.Outbound.GroupRefs = append(oldConfig.Outbound.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{ + IsOn: true, + GroupId: groupId, + }) + } + } + } + + // 保存Inbound和Outbound + oldConfig.Inbound.Groups = nil + oldConfig.Outbound.Groups = nil + + inboundJSON, err := json.Marshal(oldConfig.Inbound) + if err != nil { + return nil, err + } + + outboundJSON, err := json.Marshal(oldConfig.Outbound) + if err != nil { + return nil, err + } + + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(req.FirewallPolicyId, inboundJSON, outboundJSON) + if err != nil { + return nil, err + } + + return this.Success() +}