实现WAF策略部分功能

This commit is contained in:
GoEdgeLab
2020-10-06 21:02:15 +08:00
parent ae788f2e9f
commit 7e5869d5d5
10 changed files with 839 additions and 60 deletions

View File

@@ -1,9 +1,13 @@
package models
import (
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
)
const (
@@ -87,3 +91,159 @@ func (this *HTTPFirewallPolicyDAO) FindAllEnabledFirewallPolicies() (result []*H
FindAll()
return
}
// 创建策略
func (this *HTTPFirewallPolicyDAO) CreateFirewallPolicy(isOn bool, name string, description string, inboundJSON []byte, outboundJSON []byte) (int64, error) {
op := NewHTTPFirewallPolicyOperator()
op.State = HTTPFirewallPolicyStateEnabled
op.IsOn = isOn
op.Name = name
op.Description = description
if len(inboundJSON) > 0 {
op.Inbound = inboundJSON
}
if len(outboundJSON) > 0 {
op.Outbound = outboundJSON
}
_, err := this.Save(op)
return types.Int64(op.Id), err
}
// 修改策略的Inbound和Outbound
func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundAndOutbound(policyId int64, inboundJSON []byte, outboundJSON []byte) error {
if policyId <= 0 {
return errors.New("invalid policyId")
}
op := NewHTTPFirewallPolicyOperator()
op.Id = policyId
if len(inboundJSON) > 0 {
op.Inbound = inboundJSON
} else {
op.Inbound = "null"
}
if len(outboundJSON) > 0 {
op.Outbound = outboundJSON
} else {
op.Outbound = "null"
}
_, err := this.Save(op)
return err
}
// 修改策略
func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicy(policyId int64, isOn bool, name string, description string, inboundJSON []byte, outboundJSON []byte) error {
if policyId <= 0 {
return errors.New("invalid policyId")
}
op := NewHTTPFirewallPolicyOperator()
op.Id = policyId
op.IsOn = isOn
op.Name = name
op.Description = description
if len(inboundJSON) > 0 {
op.Inbound = inboundJSON
} else {
op.Inbound = "null"
}
if len(outboundJSON) > 0 {
op.Outbound = outboundJSON
} else {
op.Outbound = "null"
}
_, err := this.Save(op)
return err
}
// 计算所有可用的策略数量
func (this *HTTPFirewallPolicyDAO) CountAllEnabledFirewallPolicies() (int64, error) {
return this.Query().
State(HTTPFirewallPolicyStateEnabled).
Count()
}
// 列出单页的策略
func (this *HTTPFirewallPolicyDAO) ListEnabledFirewallPolicies(offset int64, size int64) (result []*HTTPFirewallPolicy, err error) {
_, err = this.Query().
State(HTTPFirewallPolicyStateEnabled).
Offset(offset).
Limit(size).
DescPk().
Slice(&result).
FindAll()
return
}
// 组合策略配置
func (this *HTTPFirewallPolicyDAO) ComposeFirewallPolicy(policyId int64) (*firewallconfigs.HTTPFirewallPolicy, error) {
policy, err := this.FindEnabledHTTPFirewallPolicy(policyId)
if err != nil {
return nil, err
}
if policy == nil {
return nil, nil
}
config := &firewallconfigs.HTTPFirewallPolicy{}
config.Id = int64(policy.Id)
config.IsOn = policy.IsOn == 1
config.Name = policy.Name
config.Description = policy.Description
// Inbound
if IsNotNull(policy.Inbound) {
inbound := &firewallconfigs.HTTPFirewallInboundConfig{}
err = json.Unmarshal([]byte(policy.Inbound), inbound)
if err != nil {
return nil, err
}
if len(inbound.GroupRefs) > 0 {
resultGroupRefs := []*firewallconfigs.HTTPFirewallRuleGroupRef{}
resultGroups := []*firewallconfigs.HTTPFirewallRuleGroup{}
for _, groupRef := range inbound.GroupRefs {
groupConfig, err := SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(groupRef.GroupId)
if err != nil {
return nil, err
}
if groupConfig != nil {
resultGroupRefs = append(resultGroupRefs, groupRef)
resultGroups = append(resultGroups, groupConfig)
}
}
inbound.GroupRefs = resultGroupRefs
inbound.Groups = resultGroups
}
config.Inbound = inbound
}
// Outbound
if IsNotNull(policy.Outbound) {
outbound := &firewallconfigs.HTTPFirewallOutboundConfig{}
err = json.Unmarshal([]byte(policy.Outbound), outbound)
if err != nil {
return nil, err
}
if len(outbound.GroupRefs) > 0 {
resultGroupRefs := []*firewallconfigs.HTTPFirewallRuleGroupRef{}
resultGroups := []*firewallconfigs.HTTPFirewallRuleGroup{}
for _, groupRef := range outbound.GroupRefs {
groupConfig, err := SharedHTTPFirewallRuleGroupDAO.ComposeFirewallRuleGroup(groupRef.GroupId)
if err != nil {
return nil, err
}
if groupConfig != nil {
resultGroupRefs = append(resultGroupRefs, groupRef)
resultGroups = append(resultGroups, groupConfig)
}
}
outbound.GroupRefs = resultGroupRefs
outbound.Groups = resultGroups
}
config.Outbound = outbound
}
return config, nil
}

View File

@@ -2,31 +2,31 @@ package models
// HTTP防火墙
type HTTPFirewallPolicy struct {
Id uint32 `field:"id"` // ID
TemplateId uint32 `field:"templateId"` // 模版ID
AdminId uint32 `field:"adminId"` // 管理员ID
UserId uint32 `field:"userId"` // 用户ID
State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间
IsOn uint8 `field:"isOn"` // 是否启用
Name string `field:"name"` // 名称
Inbound string `field:"inbound"` // 入站规则
Outbound string `field:"outbound"` // 站规则
Conds string `field:"conds"` // 条件
Id uint32 `field:"id"` // ID
TemplateId uint32 `field:"templateId"` // 模版ID
AdminId uint32 `field:"adminId"` // 管理员ID
UserId uint32 `field:"userId"` // 用户ID
State uint8 `field:"state"` // 状态
CreatedAt uint64 `field:"createdAt"` // 创建时间
IsOn uint8 `field:"isOn"` // 是否启用
Name string `field:"name"` // 名称
Description string `field:"description"` // 描述
Inbound string `field:"inbound"` // 站规则
Outbound string `field:"outbound"` // 出站规则
}
type HTTPFirewallPolicyOperator struct {
Id interface{} // ID
TemplateId interface{} // 模版ID
AdminId interface{} // 管理员ID
UserId interface{} // 用户ID
State interface{} // 状态
CreatedAt interface{} // 创建时间
IsOn interface{} // 是否启用
Name interface{} // 名称
Inbound interface{} // 入站规则
Outbound interface{} // 站规则
Conds interface{} // 条件
Id interface{} // ID
TemplateId interface{} // 模版ID
AdminId interface{} // 管理员ID
UserId interface{} // 用户ID
State interface{} // 状态
CreatedAt interface{} // 创建时间
IsOn interface{} // 是否启用
Name interface{} // 名称
Description interface{} // 描述
Inbound interface{} // 站规则
Outbound interface{} // 出站规则
}
func NewHTTPFirewallPolicyOperator() *HTTPFirewallPolicyOperator {

View File

@@ -1,9 +1,12 @@
package models
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
)
const (
@@ -41,7 +44,7 @@ func (this *HTTPFirewallRuleDAO) Init() {
}
// 启用条目
func (this *HTTPFirewallRuleDAO) EnableHTTPFirewallRule(id uint32) error {
func (this *HTTPFirewallRuleDAO) EnableHTTPFirewallRule(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", HTTPFirewallRuleStateEnabled).
@@ -50,7 +53,7 @@ func (this *HTTPFirewallRuleDAO) EnableHTTPFirewallRule(id uint32) error {
}
// 禁用条目
func (this *HTTPFirewallRuleDAO) DisableHTTPFirewallRule(id uint32) error {
func (this *HTTPFirewallRuleDAO) DisableHTTPFirewallRule(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", HTTPFirewallRuleStateDisabled).
@@ -59,7 +62,7 @@ func (this *HTTPFirewallRuleDAO) DisableHTTPFirewallRule(id uint32) error {
}
// 查找启用中的条目
func (this *HTTPFirewallRuleDAO) FindEnabledHTTPFirewallRule(id uint32) (*HTTPFirewallRule, error) {
func (this *HTTPFirewallRuleDAO) FindEnabledHTTPFirewallRule(id int64) (*HTTPFirewallRule, error) {
result, err := this.Query().
Pk(id).
Attr("state", HTTPFirewallRuleStateEnabled).
@@ -69,3 +72,59 @@ func (this *HTTPFirewallRuleDAO) FindEnabledHTTPFirewallRule(id uint32) (*HTTPFi
}
return result.(*HTTPFirewallRule), err
}
// 组合配置
func (this *HTTPFirewallRuleDAO) ComposeFirewallRule(ruleId int64) (*firewallconfigs.HTTPFirewallRule, error) {
rule, err := this.FindEnabledHTTPFirewallRule(ruleId)
if err != nil {
return nil, err
}
if rule == nil {
return nil, nil
}
config := &firewallconfigs.HTTPFirewallRule{}
config.Id = int64(rule.Id)
config.IsOn = rule.IsOn == 1
config.Param = rule.Param
config.Operator = rule.Operator
config.Value = rule.Value
config.IsCaseInsensitive = rule.IsCaseInsensitive == 1
if IsNotNull(rule.CheckpointOptions) {
checkpointOptions := map[string]interface{}{}
err = json.Unmarshal([]byte(rule.CheckpointOptions), &checkpointOptions)
if err != nil {
return nil, err
}
config.CheckpointOptions = checkpointOptions
}
config.Description = rule.Description
return config, nil
}
// 从配置中配置规则
func (this *HTTPFirewallRuleDAO) CreateRuleFromConfig(ruleConfig *firewallconfigs.HTTPFirewallRule) (int64, error) {
op := NewHTTPFirewallRuleOperator()
op.State = HTTPFirewallRuleStateEnabled
op.IsOn = ruleConfig.IsOn
op.Description = ruleConfig.Description
op.Param = ruleConfig.Param
op.Value = ruleConfig.Value
op.IsCaseInsensitive = ruleConfig.IsCaseInsensitive
op.Operator = ruleConfig.Operator
if ruleConfig.CheckpointOptions != nil {
checkpointOptionsJSON, err := json.Marshal(ruleConfig.CheckpointOptions)
if err != nil {
return 0, err
}
op.CheckpointOptions = checkpointOptionsJSON
}
_, err := this.Save(op)
if err != nil {
return 0, err
}
return types.Int64(op.Id), nil
}

View File

@@ -1,9 +1,12 @@
package models
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/types"
)
const (
@@ -41,7 +44,7 @@ func (this *HTTPFirewallRuleGroupDAO) Init() {
}
// 启用条目
func (this *HTTPFirewallRuleGroupDAO) EnableHTTPFirewallRuleGroup(id uint32) error {
func (this *HTTPFirewallRuleGroupDAO) EnableHTTPFirewallRuleGroup(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", HTTPFirewallRuleGroupStateEnabled).
@@ -50,7 +53,7 @@ func (this *HTTPFirewallRuleGroupDAO) EnableHTTPFirewallRuleGroup(id uint32) err
}
// 禁用条目
func (this *HTTPFirewallRuleGroupDAO) DisableHTTPFirewallRuleGroup(id uint32) error {
func (this *HTTPFirewallRuleGroupDAO) DisableHTTPFirewallRuleGroup(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", HTTPFirewallRuleGroupStateDisabled).
@@ -59,7 +62,7 @@ func (this *HTTPFirewallRuleGroupDAO) DisableHTTPFirewallRuleGroup(id uint32) er
}
// 查找启用中的条目
func (this *HTTPFirewallRuleGroupDAO) FindEnabledHTTPFirewallRuleGroup(id uint32) (*HTTPFirewallRuleGroup, error) {
func (this *HTTPFirewallRuleGroupDAO) FindEnabledHTTPFirewallRuleGroup(id int64) (*HTTPFirewallRuleGroup, error) {
result, err := this.Query().
Pk(id).
Attr("state", HTTPFirewallRuleGroupStateEnabled).
@@ -71,9 +74,88 @@ func (this *HTTPFirewallRuleGroupDAO) FindEnabledHTTPFirewallRuleGroup(id uint32
}
// 根据主键查找名称
func (this *HTTPFirewallRuleGroupDAO) FindHTTPFirewallRuleGroupName(id uint32) (string, error) {
func (this *HTTPFirewallRuleGroupDAO) FindHTTPFirewallRuleGroupName(id int64) (string, error) {
return this.Query().
Pk(id).
Result("name").
FindStringCol("")
}
// 组合配置
func (this *HTTPFirewallRuleGroupDAO) ComposeFirewallRuleGroup(groupId int64) (*firewallconfigs.HTTPFirewallRuleGroup, error) {
group, err := this.FindEnabledHTTPFirewallRuleGroup(groupId)
if err != nil {
return nil, err
}
if group == nil {
return nil, nil
}
config := &firewallconfigs.HTTPFirewallRuleGroup{}
config.Id = int64(group.Id)
config.IsOn = group.IsOn == 1
config.Name = group.Name
config.Description = group.Description
config.Code = group.Code
if IsNotNull(group.Sets) {
setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{}
err = json.Unmarshal([]byte(group.Sets), &setRefs)
if err != nil {
return nil, err
}
for _, setRef := range setRefs {
setConfig, err := SharedHTTPFirewallRuleSetDAO.ComposeFirewallRuleSet(setRef.SetId)
if err != nil {
return nil, err
}
if setConfig != nil {
config.SetRefs = append(config.SetRefs, setRef)
config.Sets = append(config.Sets, setConfig)
}
}
}
return config, nil
}
// 从配置中创建分组
func (this *HTTPFirewallRuleGroupDAO) CreateGroupFromConfig(groupConfig *firewallconfigs.HTTPFirewallRuleGroup) (int64, error) {
op := NewHTTPFirewallRuleGroupOperator()
op.IsOn = groupConfig.IsOn
op.Name = groupConfig.Name
op.Description = groupConfig.Description
op.State = HTTPFirewallRuleGroupStateEnabled
op.Code = groupConfig.Code
// sets
setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{}
for _, setConfig := range groupConfig.Sets {
setId, err := SharedHTTPFirewallRuleSetDAO.CreateSetFromConfig(setConfig)
if err != nil {
return 0, err
}
setRefs = append(setRefs, &firewallconfigs.HTTPFirewallRuleSetRef{
IsOn: true,
SetId: setId,
})
}
setRefsJSON, err := json.Marshal(setRefs)
if err != nil {
return 0, err
}
op.Sets = setRefsJSON
_, err = this.Save(op)
if err != nil {
return 0, err
}
return types.Int64(op.Id), nil
}
// 修改开启状态
func (this *HTTPFirewallRuleGroupDAO) UpdateGroupIsOn(groupId int64, isOn bool) error {
_, err := this.Query().
Pk(groupId).
Set("isOn", isOn).
Update()
return err
}

View File

@@ -4,6 +4,7 @@ package models
type HTTPFirewallRule struct {
Id uint32 `field:"id"` // ID
IsOn uint8 `field:"isOn"` // 是否启用
Description string `field:"description"` // 说明
Param string `field:"param"` // 参数
Operator string `field:"operator"` // 操作符
Value string `field:"value"` // 对比值
@@ -18,6 +19,7 @@ type HTTPFirewallRule struct {
type HTTPFirewallRuleOperator struct {
Id interface{} // ID
IsOn interface{} // 是否启用
Description interface{} // 说明
Param interface{} // 参数
Operator interface{} // 操作符
Value interface{} // 对比值

View File

@@ -1,9 +1,13 @@
package models
import (
"encoding/json"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
"github.com/iwind/TeaGo/dbs"
"github.com/iwind/TeaGo/maps"
"github.com/iwind/TeaGo/types"
)
const (
@@ -41,7 +45,7 @@ func (this *HTTPFirewallRuleSetDAO) Init() {
}
// 启用条目
func (this *HTTPFirewallRuleSetDAO) EnableHTTPFirewallRuleSet(id uint32) error {
func (this *HTTPFirewallRuleSetDAO) EnableHTTPFirewallRuleSet(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", HTTPFirewallRuleSetStateEnabled).
@@ -50,7 +54,7 @@ func (this *HTTPFirewallRuleSetDAO) EnableHTTPFirewallRuleSet(id uint32) error {
}
// 禁用条目
func (this *HTTPFirewallRuleSetDAO) DisableHTTPFirewallRuleSet(id uint32) error {
func (this *HTTPFirewallRuleSetDAO) DisableHTTPFirewallRuleSet(id int64) error {
_, err := this.Query().
Pk(id).
Set("state", HTTPFirewallRuleSetStateDisabled).
@@ -59,7 +63,7 @@ func (this *HTTPFirewallRuleSetDAO) DisableHTTPFirewallRuleSet(id uint32) error
}
// 查找启用中的条目
func (this *HTTPFirewallRuleSetDAO) FindEnabledHTTPFirewallRuleSet(id uint32) (*HTTPFirewallRuleSet, error) {
func (this *HTTPFirewallRuleSetDAO) FindEnabledHTTPFirewallRuleSet(id int64) (*HTTPFirewallRuleSet, error) {
result, err := this.Query().
Pk(id).
Attr("state", HTTPFirewallRuleSetStateEnabled).
@@ -71,9 +75,100 @@ func (this *HTTPFirewallRuleSetDAO) FindEnabledHTTPFirewallRuleSet(id uint32) (*
}
// 根据主键查找名称
func (this *HTTPFirewallRuleSetDAO) FindHTTPFirewallRuleSetName(id uint32) (string, error) {
func (this *HTTPFirewallRuleSetDAO) FindHTTPFirewallRuleSetName(id int64) (string, error) {
return this.Query().
Pk(id).
Result("name").
FindStringCol("")
}
// 组合配置
func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(setId int64) (*firewallconfigs.HTTPFirewallRuleSet, error) {
set, err := this.FindEnabledHTTPFirewallRuleSet(setId)
if err != nil {
return nil, err
}
if set == nil {
return nil, nil
}
config := &firewallconfigs.HTTPFirewallRuleSet{}
config.Id = int64(set.Id)
config.IsOn = set.IsOn == 1
config.Name = set.Name
config.Description = set.Description
config.Code = set.Code
config.Connector = set.Connector
if IsNotNull(set.Rules) {
ruleRefs := []*firewallconfigs.HTTPFirewallRuleRef{}
err = json.Unmarshal([]byte(set.Rules), &ruleRefs)
if err != nil {
return nil, err
}
for _, ruleRef := range ruleRefs {
ruleConfig, err := SharedHTTPFirewallRuleDAO.ComposeFirewallRule(ruleRef.RuleId)
if err != nil {
return nil, err
}
if ruleConfig != nil {
config.RuleRefs = append(config.RuleRefs, ruleRef)
config.Rules = append(config.Rules, ruleConfig)
}
}
}
config.Action = set.Action
if IsNotNull(set.ActionOptions) {
options := maps.Map{}
err = json.Unmarshal([]byte(set.ActionOptions), &options)
if err != nil {
return nil, err
}
config.ActionOptions = options
}
return config, nil
}
// 从配置中创建规则集
func (this *HTTPFirewallRuleSetDAO) CreateSetFromConfig(setConfig *firewallconfigs.HTTPFirewallRuleSet) (int64, error) {
op := NewHTTPFirewallRuleSetOperator()
op.State = HTTPFirewallRuleSetStateEnabled
op.IsOn = setConfig.IsOn
op.Name = setConfig.Name
op.Description = setConfig.Description
op.Connector = setConfig.Connector
op.Action = setConfig.Action
op.Code = setConfig.Code
if setConfig.ActionOptions != nil {
actionOptionsJSON, err := json.Marshal(setConfig.ActionOptions)
if err != nil {
return 0, err
}
op.ActionOptions = actionOptionsJSON
}
// rules
ruleRefs := []*firewallconfigs.HTTPFirewallRuleRef{}
for _, ruleConfig := range setConfig.Rules {
ruleId, err := SharedHTTPFirewallRuleDAO.CreateRuleFromConfig(ruleConfig)
if err != nil {
return 0, err
}
ruleRefs = append(ruleRefs, &firewallconfigs.HTTPFirewallRuleRef{
IsOn: true,
RuleId: ruleId,
})
}
ruleRefsJSON, err := json.Marshal(ruleRefs)
if err != nil {
return 0, err
}
op.Rules = ruleRefsJSON
_, err = this.Save(op)
if err != nil {
return 0, err
}
return types.Int64(op.Id), nil
}

View File

@@ -2,31 +2,35 @@ package models
// 防火墙规则集
type HTTPFirewallRuleSet struct {
Id uint32 `field:"id"` // ID
IsOn uint8 `field:"isOn"` // 是否启用
Code string `field:"code"` // 代号
Name string `field:"name"` // 名称
Description string `field:"description"` // 描述
CreatedAt uint64 `field:"createdAt"` // 创建时间
Rules string `field:"rules"` // 规则列表
Connector string `field:"connector"` // 规则之间的关系
State uint8 `field:"state"` // 状态
AdminId uint32 `field:"adminId"` // 管理员ID
UserId uint32 `field:"userId"` // 用户ID
Id uint32 `field:"id"` // ID
IsOn uint8 `field:"isOn"` // 是否启用
Code string `field:"code"` // 代号
Name string `field:"name"` // 名称
Description string `field:"description"` // 描述
CreatedAt uint64 `field:"createdAt"` // 创建时间
Rules string `field:"rules"` // 规则列表
Connector string `field:"connector"` // 规则之间的关系
State uint8 `field:"state"` // 状态
AdminId uint32 `field:"adminId"` // 管理员ID
UserId uint32 `field:"userId"` // 用户ID
Action string `field:"action"` // 执行的动作
ActionOptions string `field:"actionOptions"` // 动作的选项
}
type HTTPFirewallRuleSetOperator struct {
Id interface{} // ID
IsOn interface{} // 是否启用
Code interface{} // 代号
Name interface{} // 名称
Description interface{} // 描述
CreatedAt interface{} // 创建时间
Rules interface{} // 规则列表
Connector interface{} // 规则之间的关系
State interface{} // 状态
AdminId interface{} // 管理员ID
UserId interface{} // 用户ID
Id interface{} // ID
IsOn interface{} // 是否启用
Code interface{} // 代号
Name interface{} // 名称
Description interface{} // 描述
CreatedAt interface{} // 创建时间
Rules interface{} // 规则列表
Connector interface{} // 规则之间的关系
State interface{} // 状态
AdminId interface{} // 管理员ID
UserId interface{} // 用户ID
Action interface{} // 执行的动作
ActionOptions interface{} // 动作的选项
}
func NewHTTPFirewallRuleSetOperator() *HTTPFirewallRuleSetOperator {

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"errors"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/shared"
_ "github.com/go-sql-driver/mysql"
"github.com/iwind/TeaGo/Tea"
@@ -236,7 +237,7 @@ func (this *HTTPWebDAO) ComposeWebConfig(webId int64) (*serverconfigs.HTTPWebCon
// 防火墙配置
if IsNotNull(web.Firewall) {
firewallRef := &serverconfigs.HTTPFirewallRef{}
firewallRef := &firewallconfigs.HTTPFirewallRef{}
err = json.Unmarshal([]byte(web.Firewall), firewallRef)
if err != nil {
return nil, err
@@ -510,7 +511,51 @@ func (this *HTTPWebDAO) FindAllWebIdsWithCachePolicyId(cachePolicyId int64) ([]i
ones, err := this.Query().
State(HTTPWebStateEnabled).
ResultPk().
Where(`JSON_CONTAINS(cache, '{"cachePolicyId": ` + strconv.FormatInt(cachePolicyId, 10) + ` }')`).
Where(`JSON_CONTAINS(cache, '{"cachePolicyId": ` + strconv.FormatInt(cachePolicyId, 10) + ` }', '$.cacheRefs')`).
Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用
FindAll()
if err != nil {
return nil, err
}
result := []int64{}
for _, one := range ones {
webId := int64(one.(*HTTPWeb).Id)
// 判断是否为Location
for {
locationId, err := SharedHTTPLocationDAO.FindEnabledLocationIdWithWebId(webId)
if err != nil {
return nil, err
}
// 如果非Location
if locationId == 0 {
if !this.containsInt64(result, webId) {
result = append(result, webId)
}
break
}
// 查找包含此Location的Web
// TODO 需要支持嵌套的Location查询
webId, err = this.FindEnabledWebIdWithLocationId(locationId)
if err != nil {
return nil, err
}
if webId == 0 {
break
}
}
}
return result, nil
}
// 根据防火墙策略ID查找所有的WebId
func (this *HTTPWebDAO) FindAllWebIdsWithHTTPFirewallPolicyId(firewallPolicyId int64) ([]int64, error) {
ones, err := this.Query().
State(HTTPWebStateEnabled).
ResultPk().
Where(`JSON_CONTAINS(firewall, '{"firewallPolicyId": ` + strconv.FormatInt(firewallPolicyId, 10) + ` }')`).
Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用
FindAll()
if err != nil {

View File

@@ -2,11 +2,16 @@ package services
import (
"context"
"encoding/json"
"github.com/TeaOSLab/EdgeAPI/internal/db/models"
"github.com/TeaOSLab/EdgeAPI/internal/errors"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
"github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs"
"github.com/iwind/TeaGo/lists"
)
// HTTP防火墙WAF相关服务
type HTTPFirewallPolicyService struct {
}
@@ -26,11 +31,314 @@ func (this *HTTPFirewallPolicyService) FindAllEnabledHTTPFirewallPolicies(ctx co
result := []*pb.HTTPFirewallPolicy{}
for _, p := range policies {
result = append(result, &pb.HTTPFirewallPolicy{
Id: int64(p.Id),
Name: p.Name,
IsOn: p.IsOn == 1,
Id: int64(p.Id),
Name: p.Name,
Description: p.Description,
IsOn: p.IsOn == 1,
InboundJSON: []byte(p.Inbound),
OutboundJSON: []byte(p.Outbound),
})
}
return &pb.FindAllEnabledHTTPFirewallPoliciesResponse{FirewallPolicies: result}, nil
}
// 创建防火墙策略
func (this *HTTPFirewallPolicyService) CreateHTTPFirewallPolicy(ctx context.Context, req *pb.CreateHTTPFirewallPolicyRequest) (*pb.CreateHTTPFirewallPolicyResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(req.IsOn, req.Name, req.Description, nil, nil)
if err != nil {
return nil, err
}
// 初始化
inboundConfig := &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true}
outboundConfig := &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true}
templatePolicy := firewallconfigs.HTTPFirewallTemplate()
if templatePolicy.Inbound != nil {
for _, group := range templatePolicy.Inbound.Groups {
isOn := lists.ContainsString(req.FirewallGroupCodes, group.Code)
group.IsOn = isOn
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(group)
if err != nil {
return nil, err
}
inboundConfig.GroupRefs = append(inboundConfig.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
}
}
if templatePolicy.Outbound != nil {
for _, group := range templatePolicy.Outbound.Groups {
isOn := lists.ContainsString(req.FirewallGroupCodes, group.Code)
group.IsOn = isOn
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(group)
if err != nil {
return nil, err
}
outboundConfig.GroupRefs = append(outboundConfig.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
}
}
inboundConfigJSON, err := json.Marshal(inboundConfig)
if err != nil {
return nil, err
}
outboundConfigJSON, err := json.Marshal(outboundConfig)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(policyId, inboundConfigJSON, outboundConfigJSON)
if err != nil {
return nil, err
}
return &pb.CreateHTTPFirewallPolicyResponse{FirewallPolicyId: policyId}, nil
}
// 修改防火墙策略
func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicy(ctx context.Context, req *pb.UpdateHTTPFirewallPolicyRequest) (*pb.RPCUpdateSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
templatePolicy := firewallconfigs.HTTPFirewallTemplate()
// 已经有的数据
firewallPolicy, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(req.FirewallPolicyId)
if err != nil {
return nil, err
}
if firewallPolicy == nil {
return nil, errors.New("can not found firewall policy")
}
inboundConfig := firewallPolicy.Inbound
if inboundConfig == nil {
inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true}
}
outboundConfig := firewallPolicy.Outbound
if outboundConfig == nil {
outboundConfig = &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true}
}
// 更新老的
oldCodes := []string{}
if firewallPolicy.Inbound != nil {
for _, g := range firewallPolicy.Inbound.Groups {
if len(g.Code) > 0 {
oldCodes = append(oldCodes, g.Code)
if lists.ContainsString(req.FirewallGroupCodes, g.Code) {
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(g.Id, true)
if err != nil {
return nil, err
}
} else {
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(g.Id, false)
if err != nil {
return nil, err
}
}
}
}
}
if firewallPolicy.Outbound != nil {
for _, g := range firewallPolicy.Outbound.Groups {
if len(g.Code) > 0 {
oldCodes = append(oldCodes, g.Code)
if lists.ContainsString(req.FirewallGroupCodes, g.Code) {
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(g.Id, true)
if err != nil {
return nil, err
}
} else {
err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupIsOn(g.Id, false)
if err != nil {
return nil, err
}
}
}
}
}
// 加入新的
if templatePolicy.Inbound != nil {
for _, group := range templatePolicy.Inbound.Groups {
if lists.ContainsString(oldCodes, group.Code) {
continue
}
isOn := lists.ContainsString(req.FirewallGroupCodes, group.Code)
group.IsOn = isOn
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(group)
if err != nil {
return nil, err
}
inboundConfig.GroupRefs = append(inboundConfig.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
}
}
if templatePolicy.Outbound != nil {
for _, group := range templatePolicy.Outbound.Groups {
if lists.ContainsString(oldCodes, group.Code) {
continue
}
isOn := lists.ContainsString(req.FirewallGroupCodes, group.Code)
group.IsOn = isOn
groupId, err := models.SharedHTTPFirewallRuleGroupDAO.CreateGroupFromConfig(group)
if err != nil {
return nil, err
}
outboundConfig.GroupRefs = append(outboundConfig.GroupRefs, &firewallconfigs.HTTPFirewallRuleGroupRef{
IsOn: true,
GroupId: groupId,
})
}
}
inboundConfigJSON, err := json.Marshal(inboundConfig)
if err != nil {
return nil, err
}
outboundConfigJSON, err := json.Marshal(outboundConfig)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicy(req.FirewallPolicyId, req.IsOn, req.Name, req.Description, inboundConfigJSON, outboundConfigJSON)
if err != nil {
return nil, err
}
return rpcutils.RPCUpdateSuccess()
}
// 计算可用的防火墙策略数量
func (this *HTTPFirewallPolicyService) CountAllEnabledFirewallPolicies(ctx context.Context, req *pb.CountAllEnabledFirewallPoliciesRequest) (*pb.CountAllEnabledFirewallPoliciesResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
count, err := models.SharedHTTPFirewallPolicyDAO.CountAllEnabledFirewallPolicies()
if err != nil {
return nil, err
}
return &pb.CountAllEnabledFirewallPoliciesResponse{Count: count}, nil
}
// 列出单页的防火墙策略
func (this *HTTPFirewallPolicyService) ListEnabledFirewallPolicies(ctx context.Context, req *pb.ListEnabledFirewallPoliciesRequest) (*pb.ListEnabledFirewallPoliciesResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
policies, err := models.SharedHTTPFirewallPolicyDAO.ListEnabledFirewallPolicies(req.Offset, req.Size)
if err != nil {
return nil, err
}
result := []*pb.HTTPFirewallPolicy{}
for _, p := range policies {
result = append(result, &pb.HTTPFirewallPolicy{
Id: int64(p.Id),
Name: p.Name,
Description: p.Description,
IsOn: p.IsOn == 1,
InboundJSON: []byte(p.Inbound),
OutboundJSON: []byte(p.Outbound),
})
}
return &pb.ListEnabledFirewallPoliciesResponse{FirewallPolicies: result}, nil
}
// 删除某个防火墙策略
func (this *HTTPFirewallPolicyService) DeleteFirewallPolicy(ctx context.Context, req *pb.DeleteFirewallPolicyRequest) (*pb.RPCDeleteSuccess, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
err = models.SharedHTTPFirewallPolicyDAO.DisableHTTPFirewallPolicy(req.FirewallPolicyId)
if err != nil {
return nil, err
}
return rpcutils.RPCDeleteSuccess()
}
// 查找单个防火墙配置
func (this *HTTPFirewallPolicyService) FindEnabledFirewallPolicyConfig(ctx context.Context, req *pb.FindEnabledFirewallPolicyConfigRequest) (*pb.FindEnabledFirewallPolicyConfigResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
config, err := models.SharedHTTPFirewallPolicyDAO.ComposeFirewallPolicy(req.FirewallPolicyId)
if err != nil {
return nil, err
}
if config == nil {
return &pb.FindEnabledFirewallPolicyConfigResponse{FirewallPolicyJSON: nil}, nil
}
configJSON, err := json.Marshal(config)
if err != nil {
return nil, err
}
return &pb.FindEnabledFirewallPolicyConfigResponse{FirewallPolicyJSON: configJSON}, nil
}
// 获取防火墙的基本信息
func (this *HTTPFirewallPolicyService) FindEnabledFirewallPolicy(ctx context.Context, req *pb.FindEnabledFirewallPolicyRequest) (*pb.FindEnabledFirewallPolicyResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
policy, err := models.SharedHTTPFirewallPolicyDAO.FindEnabledHTTPFirewallPolicy(req.FirewallPolicyId)
if err != nil {
return nil, err
}
if policy == nil {
return &pb.FindEnabledFirewallPolicyResponse{FirewallPolicy: nil}, nil
}
return &pb.FindEnabledFirewallPolicyResponse{FirewallPolicy: &pb.HTTPFirewallPolicy{
Id: int64(policy.Id),
Name: policy.Name,
Description: policy.Description,
IsOn: policy.IsOn == 1,
InboundJSON: []byte(policy.Inbound),
OutboundJSON: []byte(policy.Outbound),
}}, nil
}

View File

@@ -693,3 +693,27 @@ func (this *ServerService) FindAllEnabledServersWithCachePolicyId(ctx context.Co
}
return &pb.FindAllEnabledServersWithCachePolicyIdResponse{Servers: result}, nil
}
// 计算使用某个WAF策略的服务数量
func (this *ServerService) CountAllEnabledServersWithHTTPFirewallPolicyId(ctx context.Context, req *pb.CountAllEnabledServersWithHTTPFirewallPolicyIdRequest) (*pb.CountAllEnabledServersWithHTTPFirewallPolicyIdResponse, error) {
// 校验请求
_, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin)
if err != nil {
return nil, err
}
webIds, err := models.SharedHTTPWebDAO.FindAllWebIdsWithCachePolicyId(req.FirewallPolicyId)
if err != nil {
return nil, err
}
if len(webIds) == 0 {
return &pb.CountAllEnabledServersWithHTTPFirewallPolicyIdResponse{Count: 0}, nil
}
countServers, err := models.SharedServerDAO.CountEnabledServersWithWebIds(webIds)
if err != nil {
return nil, err
}
return &pb.CountAllEnabledServersWithHTTPFirewallPolicyIdResponse{Count: countServers}, nil
}