From 76d692fe41c572595a340cd1bec8e0d297a94e6a Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Thu, 10 Mar 2022 15:59:00 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E7=94=A8=E6=88=B7=E5=8F=AF?= =?UTF-8?q?=E8=83=BD=E6=97=A0=E6=B3=95=E6=B7=BB=E5=8A=A0=E9=BB=91=E7=99=BD?= =?UTF-8?q?=E5=90=8D=E5=8D=95IP=E7=9A=84Bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/db/models/ip_list_dao.go | 16 ++++++++++++++++ internal/db/models/server_dao.go | 10 ++++++++++ internal/db/models/ssl_policy_dao.go | 13 ++++++++++--- internal/rpc/services/service_server.go | 4 ++-- internal/rpc/services/service_ssl_policy.go | 5 +++-- 5 files changed, 41 insertions(+), 7 deletions(-) diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go index a39616e6..8770bef3 100644 --- a/internal/db/models/ip_list_dao.go +++ b/internal/db/models/ip_list_dao.go @@ -184,6 +184,10 @@ func (this *IPListDAO) IncreaseVersion(tx *dbs.Tx) (int64, error) { // CheckUserIPList 检查用户权限 func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) error { + if userId == 0 || listId == 0 { + return ErrNotFound + } + ok, err := this.Query(tx). Pk(listId). Attr("userId", userId). @@ -194,6 +198,18 @@ func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) e if ok { return nil } + + // 检查是否被用户的服务所使用 + policyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId) + if err != nil { + return err + } + for _, policyId := range policyIds { + if SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(tx, userId, policyId) == nil { + return nil + } + } + return ErrNotFound } diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index bf8690b6..8368a049 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -1227,6 +1227,16 @@ func (this *ServerDAO) FindAllEnabledServerIdsWithSSLPolicyIds(tx *dbs.Tx, sslPo return } +// ExistEnabledUserServerWithSSLPolicyId 检查是否存在某个用户的策略 +func (this *ServerDAO) ExistEnabledUserServerWithSSLPolicyId(tx *dbs.Tx, userId int64, sslPolicyId int64) (bool, error) { + return this.Query(tx). + State(ServerStateEnabled). + Attr("userId", userId). + Where("(JSON_CONTAINS(https, :jsonQuery) OR JSON_CONTAINS(tls, :jsonQuery))"). + Param("jsonQuery", maps.Map{"sslPolicyRef": maps.Map{"sslPolicyId": sslPolicyId}}.AsJSON()). + Exist() +} + // CountEnabledServersWithWebIds 计算使用某个缓存策略的所有服务数量 func (this *ServerDAO) CountEnabledServersWithWebIds(tx *dbs.Tx, webIds []int64) (count int64, err error) { if len(webIds) == 0 { diff --git a/internal/db/models/ssl_policy_dao.go b/internal/db/models/ssl_policy_dao.go index 7cffbac1..2c2ef7cc 100644 --- a/internal/db/models/ssl_policy_dao.go +++ b/internal/db/models/ssl_policy_dao.go @@ -281,9 +281,9 @@ func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled } // CheckUserPolicy 检查是否为用户所属策略 -func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int64) error { +func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, userId int64, policyId int64) error { if policyId <= 0 || userId <= 0 { - return errors.New("not found") + return ErrNotFound } ok, err := this.Query(tx). State(SSLPolicyStateEnabled). @@ -294,7 +294,14 @@ func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int return err } if !ok { - return errors.New("not found") + // 是否为当前用户的某个服务所用 + exists, err := SharedServerDAO.ExistEnabledUserServerWithSSLPolicyId(tx, userId, policyId) + if err != nil { + return err + } + if !exists { + return ErrNotFound + } } return nil } diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index 502aa2c4..a13c2640 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -42,7 +42,7 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe return nil, err } if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 { - err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, httpsConfig.SSLPolicyRef.SSLPolicyId, userId) + err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, httpsConfig.SSLPolicyRef.SSLPolicyId) if err != nil { return nil, err } @@ -57,7 +57,7 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe return nil, err } if tlsConfig.SSLPolicyRef != nil && tlsConfig.SSLPolicyRef.SSLPolicyId > 0 { - err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, tlsConfig.SSLPolicyRef.SSLPolicyId, userId) + err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, tlsConfig.SSLPolicyRef.SSLPolicyId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ssl_policy.go b/internal/rpc/services/service_ssl_policy.go index b6032331..4b0f42d2 100644 --- a/internal/rpc/services/service_ssl_policy.go +++ b/internal/rpc/services/service_ssl_policy.go @@ -3,6 +3,7 @@ package services import ( "context" "encoding/json" + "errors" "github.com/TeaOSLab/EdgeAPI/internal/db/models" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" @@ -62,9 +63,9 @@ func (this *SSLPolicyService) UpdateSSLPolicy(ctx context.Context, req *pb.Updat tx := this.NullTx() if userId > 0 { - err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, req.SslPolicyId, userId) + err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, req.SslPolicyId) if err != nil { - return nil, err + return nil, errors.New("check ssl policy failed: " + err.Error()) } }