Files
EdgeCommon/pkg/rpc/dao/http_firewall_policy_dao.go
2021-01-03 20:18:21 +08:00

183 lines
5.5 KiB
Go

package dao
import (
"context"
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs"
)
var SharedHTTPFirewallPolicyDAO = new(HTTPFirewallPolicyDAO)
// WAF策略相关
type HTTPFirewallPolicyDAO struct {
BaseDAO
}
// 查找WAF策略基本信息
func (this *HTTPFirewallPolicyDAO) FindEnabledHTTPFirewallPolicy(ctx context.Context, policyId int64) (*pb.HTTPFirewallPolicy, error) {
resp, err := this.RPC().HTTPFirewallPolicyRPC().FindEnabledHTTPFirewallPolicy(ctx, &pb.FindEnabledHTTPFirewallPolicyRequest{HttpFirewallPolicyId: policyId})
if err != nil {
return nil, err
}
return resp.HttpFirewallPolicy, nil
}
// 查找WAF策略配置
func (this *HTTPFirewallPolicyDAO) FindEnabledHTTPFirewallPolicyConfig(ctx context.Context, policyId int64) (*firewallconfigs.HTTPFirewallPolicy, error) {
resp, err := this.RPC().HTTPFirewallPolicyRPC().FindEnabledHTTPFirewallPolicyConfig(ctx, &pb.FindEnabledHTTPFirewallPolicyConfigRequest{HttpFirewallPolicyId: policyId})
if err != nil {
return nil, err
}
if len(resp.HttpFirewallPolicyJSON) == 0 {
return nil, nil
}
firewallPolicy := &firewallconfigs.HTTPFirewallPolicy{}
err = json.Unmarshal(resp.HttpFirewallPolicyJSON, firewallPolicy)
if err != nil {
return nil, err
}
return firewallPolicy, nil
}
// 查找WAF的Inbound
func (this *HTTPFirewallPolicyDAO) FindEnabledHTTPFirewallPolicyInboundConfig(ctx context.Context, policyId int64) (*firewallconfigs.HTTPFirewallInboundConfig, error) {
config, err := this.FindEnabledHTTPFirewallPolicyConfig(ctx, policyId)
if err != nil {
return nil, err
}
if config == nil {
return nil, errors.New("not found")
}
return config.Inbound, nil
}
// 根据类型查找WAF的IP名单
func (this *HTTPFirewallPolicyDAO) FindEnabledPolicyIPListIdWithType(ctx context.Context, policyId int64, listType ipconfigs.IPListType) (int64, error) {
switch listType {
case ipconfigs.IPListTypeWhite:
return this.FindEnabledPolicyWhiteIPListId(ctx, policyId)
case ipconfigs.IPListTypeBlack:
return this.FindEnabledPolicyBlackIPListId(ctx, policyId)
default:
return 0, errors.New("invalid ip list type '" + listType + "'")
}
}
// 查找WAF的白名单
func (this *HTTPFirewallPolicyDAO) FindEnabledPolicyWhiteIPListId(ctx context.Context, policyId int64) (int64, error) {
config, err := this.FindEnabledHTTPFirewallPolicyConfig(ctx, policyId)
if err != nil {
return 0, err
}
if config == nil {
return 0, errors.New("not found")
}
if config.Inbound == nil {
config.Inbound = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true}
}
if config.Inbound.AllowListRef == nil || config.Inbound.AllowListRef.ListId == 0 {
createResp, err := this.RPC().IPListRPC().CreateIPList(ctx, &pb.CreateIPListRequest{
Type: "white",
Name: "白名单",
Code: "white",
TimeoutJSON: nil,
})
if err != nil {
return 0, err
}
listId := createResp.IpListId
config.Inbound.AllowListRef = &ipconfigs.IPListRef{
IsOn: true,
ListId: listId,
}
inboundJSON, err := json.Marshal(config.Inbound)
if err != nil {
return 0, err
}
_, err = this.RPC().HTTPFirewallPolicyRPC().UpdateHTTPFirewallInboundConfig(ctx, &pb.UpdateHTTPFirewallInboundConfigRequest{
HttpFirewallPolicyId: policyId,
InboundJSON: inboundJSON,
})
if err != nil {
return 0, err
}
return listId, nil
}
return config.Inbound.AllowListRef.ListId, nil
}
// 查找WAF的黑名单
func (this *HTTPFirewallPolicyDAO) FindEnabledPolicyBlackIPListId(ctx context.Context, policyId int64) (int64, error) {
config, err := this.FindEnabledHTTPFirewallPolicyConfig(ctx, policyId)
if err != nil {
return 0, err
}
if config == nil {
return 0, errors.New("not found")
}
if config.Inbound == nil {
config.Inbound = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true}
}
if config.Inbound.DenyListRef == nil || config.Inbound.DenyListRef.ListId == 0 {
createResp, err := this.RPC().IPListRPC().CreateIPList(ctx, &pb.CreateIPListRequest{
Type: "black",
Name: "黑名单",
Code: "black",
TimeoutJSON: nil,
})
if err != nil {
return 0, err
}
listId := createResp.IpListId
config.Inbound.DenyListRef = &ipconfigs.IPListRef{
IsOn: true,
ListId: listId,
}
inboundJSON, err := json.Marshal(config.Inbound)
if err != nil {
return 0, err
}
_, err = this.RPC().HTTPFirewallPolicyRPC().UpdateHTTPFirewallInboundConfig(ctx, &pb.UpdateHTTPFirewallInboundConfigRequest{
HttpFirewallPolicyId: policyId,
InboundJSON: inboundJSON,
})
if err != nil {
return 0, err
}
return listId, nil
}
return config.Inbound.DenyListRef.ListId, nil
}
// 根据服务Id查找WAF策略
func (this *HTTPFirewallPolicyDAO) FindEnabledHTTPFirewallPolicyWithServerId(ctx context.Context, serverId int64) (*pb.HTTPFirewallPolicy, error) {
serverResp, err := this.RPC().ServerRPC().FindEnabledServer(ctx, &pb.FindEnabledServerRequest{ServerId: serverId})
if err != nil {
return nil, err
}
server := serverResp.Server
if server == nil {
return nil, nil
}
if server.NodeCluster == nil {
return nil, nil
}
clusterId := server.NodeCluster.Id
cluster, err := SharedNodeClusterDAO.FindEnabledNodeCluster(ctx, clusterId)
if err != nil {
return nil, err
}
if cluster == nil {
return nil, nil
}
if cluster.HttpFirewallPolicyId == 0 {
return nil, nil
}
return SharedHTTPFirewallPolicyDAO.FindEnabledHTTPFirewallPolicy(ctx, cluster.HttpFirewallPolicyId)
}