diff --git a/internal/db/models/http_firewall_policy_dao.go b/internal/db/models/http_firewall_policy_dao.go index 74e41026..97fef162 100644 --- a/internal/db/models/http_firewall_policy_dao.go +++ b/internal/db/models/http_firewall_policy_dao.go @@ -6,6 +6,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/errors" "github.com/TeaOSLab/EdgeAPI/internal/utils" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/firewallconfigs" + "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs/ipconfigs" _ "github.com/go-sql-driver/mysql" "github.com/iwind/TeaGo/Tea" "github.com/iwind/TeaGo/dbs" @@ -238,7 +239,7 @@ func (this *HTTPFirewallPolicyDAO) CreateDefaultFirewallPolicy(tx *dbs.Tx, name return 0, err } - err = this.UpdateFirewallPolicyInboundAndOutbound(tx, policyId, inboundConfigJSON, outboundConfigJSON, false) + err = this.UpdateFirewallPolicyInboundAndOutbound(tx, policyId, 0, 0, inboundConfigJSON, outboundConfigJSON, false) if err != nil { return 0, err } @@ -247,10 +248,60 @@ func (this *HTTPFirewallPolicyDAO) CreateDefaultFirewallPolicy(tx *dbs.Tx, name } // UpdateFirewallPolicyInboundAndOutbound 修改策略的Inbound和Outbound -func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundAndOutbound(tx *dbs.Tx, policyId int64, inboundJSON []byte, outboundJSON []byte, shouldNotify bool) error { +func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundAndOutbound(tx *dbs.Tx, policyId int64, userId int64, serverId int64, inboundJSON []byte, outboundJSON []byte, shouldNotify bool) error { if policyId <= 0 { return errors.New("invalid policyId") } + + // 创建默认的Inbound + var inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true} + if inboundJSON != nil { + err := json.Unmarshal(inboundJSON, inboundConfig) + if err != nil { + return err + } + } + + // IP名单 + if inboundConfig.AllowListRef == nil { + listId, createListErr := SharedIPListDAO.CreateIPList(tx, userId, serverId, ipconfigs.IPListTypeWhite, "白名单", "", nil, "", false, false) + if createListErr != nil { + return createListErr + } + inboundConfig.AllowListRef = &ipconfigs.IPListRef{ + IsOn: true, + ListId: listId, + } + } + + if inboundConfig.DenyListRef == nil { + listId, createListErr := SharedIPListDAO.CreateIPList(tx, userId, serverId, ipconfigs.IPListTypeBlack, "黑名单", "", nil, "", false, false) + if createListErr != nil { + return createListErr + } + inboundConfig.DenyListRef = &ipconfigs.IPListRef{ + IsOn: true, + ListId: listId, + } + } + + if inboundConfig.GreyListRef == nil { + listId, createListErr := SharedIPListDAO.CreateIPList(tx, userId, serverId, ipconfigs.IPListTypeGrey, "灰名单", "", nil, "", false, false) + if createListErr != nil { + return createListErr + } + inboundConfig.GreyListRef = &ipconfigs.IPListRef{ + IsOn: true, + ListId: listId, + } + } + + var err error + inboundJSON, err = json.Marshal(inboundConfig) + if err != nil { + return err + } + var op = NewHTTPFirewallPolicyOperator() op.Id = policyId if len(inboundJSON) > 0 { @@ -263,7 +314,7 @@ func (this *HTTPFirewallPolicyDAO) UpdateFirewallPolicyInboundAndOutbound(tx *db } else { op.Outbound = "null" } - err := this.Save(tx, op) + err = this.Save(tx, op) if err != nil { return err } diff --git a/internal/rpc/services/service_http_firewall_policy.go b/internal/rpc/services/service_http_firewall_policy.go index 1f4b2f21..e075d0e8 100644 --- a/internal/rpc/services/service_http_firewall_policy.go +++ b/internal/rpc/services/service_http_firewall_policy.go @@ -65,9 +65,9 @@ func (this *HTTPFirewallPolicyService) CreateHTTPFirewallPolicy(ctx context.Cont } // 初始化 - inboundConfig := &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true} - outboundConfig := &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true} - templatePolicy := firewallconfigs.HTTPFirewallTemplate() + var inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true} + var outboundConfig = &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true} + var templatePolicy = firewallconfigs.HTTPFirewallTemplate() if templatePolicy.Inbound != nil { for _, group := range templatePolicy.Inbound.Groups { isOn := lists.ContainsString(req.HttpFirewallGroupCodes, group.Code) @@ -109,7 +109,7 @@ func (this *HTTPFirewallPolicyService) CreateHTTPFirewallPolicy(ctx context.Cont return nil, err } - err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, policyId, inboundConfigJSON, outboundConfigJSON, false) + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, policyId, userId, req.ServerId, inboundConfigJSON, outboundConfigJSON, false) if err != nil { return nil, err } @@ -125,26 +125,33 @@ func (this *HTTPFirewallPolicyService) CreateEmptyHTTPFirewallPolicy(ctx context return nil, err } + var tx = this.NullTx() + + var sourceUserId = userId if userId > 0 { if req.ServerId > 0 { - err = models.SharedServerDAO.CheckUserServer(nil, userId, req.ServerId) + err = models.SharedServerDAO.CheckUserServer(tx, userId, req.ServerId) if err != nil { return nil, err } } + } else if req.ServerId > 0 { + sourceUserId, err = models.SharedServerDAO.FindServerUserId(tx, req.ServerId) + if err != nil { + return nil, err + } } - var tx = this.NullTx() - - policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(tx, userId, req.ServerGroupId, req.ServerId, req.IsOn, req.Name, req.Description, nil, nil) + policyId, err := models.SharedHTTPFirewallPolicyDAO.CreateFirewallPolicy(tx, sourceUserId, req.ServerGroupId, req.ServerId, req.IsOn, req.Name, req.Description, nil, nil) if err != nil { return nil, err } // 初始化 - inboundConfig := &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true} - outboundConfig := &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true} + var inboundConfig = &firewallconfigs.HTTPFirewallInboundConfig{IsOn: true} + var outboundConfig = &firewallconfigs.HTTPFirewallOutboundConfig{IsOn: true} + // 准备保存 inboundConfigJSON, err := json.Marshal(inboundConfig) if err != nil { return nil, err @@ -155,7 +162,7 @@ func (this *HTTPFirewallPolicyService) CreateEmptyHTTPFirewallPolicy(ctx context return nil, err } - err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, policyId, inboundConfigJSON, outboundConfigJSON, false) + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, policyId, sourceUserId, req.ServerId, inboundConfigJSON, outboundConfigJSON, false) if err != nil { return nil, err } @@ -329,7 +336,7 @@ func (this *HTTPFirewallPolicyService) UpdateHTTPFirewallPolicyGroups(ctx contex var tx = this.NullTx() - err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, req.HttpFirewallPolicyId, req.InboundJSON, req.OutboundJSON, true) + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, req.HttpFirewallPolicyId, userId, 0, req.InboundJSON, req.OutboundJSON, true) if err != nil { return nil, err } @@ -653,7 +660,7 @@ func (this *HTTPFirewallPolicyService) ImportHTTPFirewallPolicy(ctx context.Cont return nil, err } - err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, req.HttpFirewallPolicyId, inboundJSON, outboundJSON, true) + err = models.SharedHTTPFirewallPolicyDAO.UpdateFirewallPolicyInboundAndOutbound(tx, req.HttpFirewallPolicyId, 0, 0, inboundJSON, outboundJSON, true) if err != nil { return nil, err }