diff --git a/internal/db/models/http_firewall_rule_dao.go b/internal/db/models/http_firewall_rule_dao.go index 3d5a2158..17430f99 100644 --- a/internal/db/models/http_firewall_rule_dao.go +++ b/internal/db/models/http_firewall_rule_dao.go @@ -105,8 +105,9 @@ func (this *HTTPFirewallRuleDAO) ComposeFirewallRule(ruleId int64) (*firewallcon } // 从配置中配置规则 -func (this *HTTPFirewallRuleDAO) CreateRuleFromConfig(ruleConfig *firewallconfigs.HTTPFirewallRule) (int64, error) { +func (this *HTTPFirewallRuleDAO) CreateOrUpdateRuleFromConfig(ruleConfig *firewallconfigs.HTTPFirewallRule) (int64, error) { op := NewHTTPFirewallRuleOperator() + op.Id = ruleConfig.Id op.State = HTTPFirewallRuleStateEnabled op.IsOn = ruleConfig.IsOn op.Description = ruleConfig.Description diff --git a/internal/db/models/http_firewall_rule_group_dao.go b/internal/db/models/http_firewall_rule_group_dao.go index 6a7af1cc..57f11c5d 100644 --- a/internal/db/models/http_firewall_rule_group_dao.go +++ b/internal/db/models/http_firewall_rule_group_dao.go @@ -131,7 +131,7 @@ func (this *HTTPFirewallRuleGroupDAO) CreateGroupFromConfig(groupConfig *firewal // sets setRefs := []*firewallconfigs.HTTPFirewallRuleSetRef{} for _, setConfig := range groupConfig.Sets { - setId, err := SharedHTTPFirewallRuleSetDAO.CreateSetFromConfig(setConfig) + setId, err := SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(setConfig) if err != nil { return 0, err } @@ -188,3 +188,15 @@ func (this *HTTPFirewallRuleGroupDAO) UpdateGroup(groupId int64, isOn bool, name _, err := this.Save(op) return err } + +// 修改分组中的规则集 +func (this *HTTPFirewallRuleGroupDAO) UpdateGroupSets(groupId int64, setsJSON []byte) error { + if groupId <= 0 { + return errors.New("invalid groupId") + } + op := NewHTTPFirewallRuleGroupOperator() + op.Id = groupId + op.Sets = setsJSON + _, err := this.Save(op) + return err +} diff --git a/internal/db/models/http_firewall_rule_set_dao.go b/internal/db/models/http_firewall_rule_set_dao.go index 06bca229..095c95d6 100644 --- a/internal/db/models/http_firewall_rule_set_dao.go +++ b/internal/db/models/http_firewall_rule_set_dao.go @@ -2,6 +2,7 @@ package models import ( "encoding/json" + "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" @@ -131,9 +132,10 @@ func (this *HTTPFirewallRuleSetDAO) ComposeFirewallRuleSet(setId int64) (*firewa } // 从配置中创建规则集 -func (this *HTTPFirewallRuleSetDAO) CreateSetFromConfig(setConfig *firewallconfigs.HTTPFirewallRuleSet) (int64, error) { +func (this *HTTPFirewallRuleSetDAO) CreateOrUpdateSetFromConfig(setConfig *firewallconfigs.HTTPFirewallRuleSet) (int64, error) { op := NewHTTPFirewallRuleSetOperator() op.State = HTTPFirewallRuleSetStateEnabled + op.Id = setConfig.Id op.IsOn = setConfig.IsOn op.Name = setConfig.Name op.Description = setConfig.Description @@ -147,12 +149,14 @@ func (this *HTTPFirewallRuleSetDAO) CreateSetFromConfig(setConfig *firewallconfi return 0, err } op.ActionOptions = actionOptionsJSON + } else { + op.ActionOptions = "{}" } // rules ruleRefs := []*firewallconfigs.HTTPFirewallRuleRef{} for _, ruleConfig := range setConfig.Rules { - ruleId, err := SharedHTTPFirewallRuleDAO.CreateRuleFromConfig(ruleConfig) + ruleId, err := SharedHTTPFirewallRuleDAO.CreateOrUpdateRuleFromConfig(ruleConfig) if err != nil { return 0, err } @@ -172,3 +176,15 @@ func (this *HTTPFirewallRuleSetDAO) CreateSetFromConfig(setConfig *firewallconfi } return types.Int64(op.Id), nil } + +// 设置是否启用 +func (this *HTTPFirewallRuleSetDAO) UpdateRuleSetIsOn(ruleSetId int64, isOn bool) error { + if ruleSetId <= 0 { + return errors.New("invalid ruleSetId") + } + _, err := this.Query(). + Pk(ruleSetId). + Set("isOn", isOn). + Update() + return err +} diff --git a/internal/nodes/api_node.go b/internal/nodes/api_node.go index ae064c7f..c4791fb1 100644 --- a/internal/nodes/api_node.go +++ b/internal/nodes/api_node.go @@ -162,6 +162,7 @@ func (this *APINode) listenRPC(listener net.Listener, tlsConfig *tls.Config) err pb.RegisterSSLPolicyServiceServer(rpcServer, &services.SSLPolicyService{}) pb.RegisterSysSettingServiceServer(rpcServer, &services.SysSettingService{}) pb.RegisterHTTPFirewallRuleGroupServiceServer(rpcServer, &services.HTTPFirewallRuleGroupService{}) + pb.RegisterHTTPFirewallRuleSetServiceServer(rpcServer, &services.HTTPFirewallRuleSetService{}) err := rpcServer.Serve(listener) if err != nil { return errors.New("[API]start rpc failed: " + err.Error()) diff --git a/internal/rpc/services/service_http_firewall_rule_group.go b/internal/rpc/services/service_http_firewall_rule_group.go index 3efdb66d..71370f91 100644 --- a/internal/rpc/services/service_http_firewall_rule_group.go +++ b/internal/rpc/services/service_http_firewall_rule_group.go @@ -80,3 +80,18 @@ func (this *HTTPFirewallRuleGroupService) FindHTTPFirewallRuleGroupConfig(ctx co } return &pb.FindHTTPFirewallRuleGroupConfigResponse{FirewallRuleGroupJSON: groupConfigJSON}, nil } + +// 修改分组的规则集 +func (this *HTTPFirewallRuleGroupService) UpdateHTTPFirewallRuleGroupSets(ctx context.Context, req *pb.UpdateHTTPFirewallRuleGroupSetsRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedHTTPFirewallRuleGroupDAO.UpdateGroupSets(req.GetFirewallRuleGroupId(), req.FirewallRuleSetsJSON) + if err != nil { + return nil, err + } + return rpcutils.RPCUpdateSuccess() +} diff --git a/internal/rpc/services/service_http_firewall_rule_set.go b/internal/rpc/services/service_http_firewall_rule_set.go new file mode 100644 index 00000000..a051425e --- /dev/null +++ b/internal/rpc/services/service_http_firewall_rule_set.go @@ -0,0 +1,74 @@ +package services + +import ( + "context" + "encoding/json" + "github.com/TeaOSLab/EdgeAPI/internal/db/models" + rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" + "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" +) + +// 规则集相关服务 +type HTTPFirewallRuleSetService struct { +} + +// 根据配置创建规则集 +func (this *HTTPFirewallRuleSetService) CreateOrUpdateHTTPFirewallRuleSetFromConfig(ctx context.Context, req *pb.CreateOrUpdateHTTPFirewallRuleSetFromConfigRequest) (*pb.CreateOrUpdateHTTPFirewallRuleSetFromConfigResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + setConfig := &firewallconfigs.HTTPFirewallRuleSet{} + err = json.Unmarshal(req.FirewallRuleSetConfigJSON, setConfig) + if err != nil { + return nil, err + } + + setId, err := models.SharedHTTPFirewallRuleSetDAO.CreateOrUpdateSetFromConfig(setConfig) + if err != nil { + return nil, err + } + + return &pb.CreateOrUpdateHTTPFirewallRuleSetFromConfigResponse{FirewallRuleSetId: setId}, nil +} + +// 修改是否开启 +func (this *HTTPFirewallRuleSetService) UpdateHTTPFirewallRuleSetIsOn(ctx context.Context, req *pb.UpdateHTTPFirewallRuleSetIsOnRequest) (*pb.RPCUpdateSuccess, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + err = models.SharedHTTPFirewallRuleSetDAO.UpdateRuleSetIsOn(req.FirewallRuleSetId, req.IsOn) + if err != nil { + return nil, err + } + + return rpcutils.RPCUpdateSuccess() +} + +// 查找规则集配置 +func (this *HTTPFirewallRuleSetService) FindHTTPFirewallRuleSetConfig(ctx context.Context, req *pb.FindHTTPFirewallRuleSetConfigRequest) (*pb.FindHTTPFirewallRuleSetConfigResponse, error) { + // 校验请求 + _, _, err := rpcutils.ValidateRequest(ctx, rpcutils.UserTypeAdmin) + if err != nil { + return nil, err + } + + config, err := models.SharedHTTPFirewallRuleSetDAO.ComposeFirewallRuleSet(req.FirewallRuleSetId) + if err != nil { + return nil, err + } + if config == nil { + return &pb.FindHTTPFirewallRuleSetConfigResponse{FirewallRuleSetJSON: nil}, nil + } + configJSON, err := json.Marshal(config) + if err != nil { + return nil, err + } + return &pb.FindHTTPFirewallRuleSetConfigResponse{FirewallRuleSetJSON: configJSON}, nil +}