diff --git a/internal/db/models/ip_list_dao.go b/internal/db/models/ip_list_dao.go index a39616e6..8770bef3 100644 --- a/internal/db/models/ip_list_dao.go +++ b/internal/db/models/ip_list_dao.go @@ -184,6 +184,10 @@ func (this *IPListDAO) IncreaseVersion(tx *dbs.Tx) (int64, error) { // CheckUserIPList 检查用户权限 func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) error { + if userId == 0 || listId == 0 { + return ErrNotFound + } + ok, err := this.Query(tx). Pk(listId). Attr("userId", userId). @@ -194,6 +198,18 @@ func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) e if ok { return nil } + + // 检查是否被用户的服务所使用 + policyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId) + if err != nil { + return err + } + for _, policyId := range policyIds { + if SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(tx, userId, policyId) == nil { + return nil + } + } + return ErrNotFound } diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index bf8690b6..8368a049 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -1227,6 +1227,16 @@ func (this *ServerDAO) FindAllEnabledServerIdsWithSSLPolicyIds(tx *dbs.Tx, sslPo return } +// ExistEnabledUserServerWithSSLPolicyId 检查是否存在某个用户的策略 +func (this *ServerDAO) ExistEnabledUserServerWithSSLPolicyId(tx *dbs.Tx, userId int64, sslPolicyId int64) (bool, error) { + return this.Query(tx). + State(ServerStateEnabled). + Attr("userId", userId). + Where("(JSON_CONTAINS(https, :jsonQuery) OR JSON_CONTAINS(tls, :jsonQuery))"). + Param("jsonQuery", maps.Map{"sslPolicyRef": maps.Map{"sslPolicyId": sslPolicyId}}.AsJSON()). + Exist() +} + // CountEnabledServersWithWebIds 计算使用某个缓存策略的所有服务数量 func (this *ServerDAO) CountEnabledServersWithWebIds(tx *dbs.Tx, webIds []int64) (count int64, err error) { if len(webIds) == 0 { diff --git a/internal/db/models/ssl_policy_dao.go b/internal/db/models/ssl_policy_dao.go index 7cffbac1..2c2ef7cc 100644 --- a/internal/db/models/ssl_policy_dao.go +++ b/internal/db/models/ssl_policy_dao.go @@ -281,9 +281,9 @@ func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled } // CheckUserPolicy 检查是否为用户所属策略 -func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int64) error { +func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, userId int64, policyId int64) error { if policyId <= 0 || userId <= 0 { - return errors.New("not found") + return ErrNotFound } ok, err := this.Query(tx). State(SSLPolicyStateEnabled). @@ -294,7 +294,14 @@ func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int return err } if !ok { - return errors.New("not found") + // 是否为当前用户的某个服务所用 + exists, err := SharedServerDAO.ExistEnabledUserServerWithSSLPolicyId(tx, userId, policyId) + if err != nil { + return err + } + if !exists { + return ErrNotFound + } } return nil } diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index 502aa2c4..a13c2640 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -42,7 +42,7 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe return nil, err } if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 { - err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, httpsConfig.SSLPolicyRef.SSLPolicyId, userId) + err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, httpsConfig.SSLPolicyRef.SSLPolicyId) if err != nil { return nil, err } @@ -57,7 +57,7 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe return nil, err } if tlsConfig.SSLPolicyRef != nil && tlsConfig.SSLPolicyRef.SSLPolicyId > 0 { - err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, tlsConfig.SSLPolicyRef.SSLPolicyId, userId) + err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, tlsConfig.SSLPolicyRef.SSLPolicyId) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_ssl_policy.go b/internal/rpc/services/service_ssl_policy.go index b6032331..4b0f42d2 100644 --- a/internal/rpc/services/service_ssl_policy.go +++ b/internal/rpc/services/service_ssl_policy.go @@ -3,6 +3,7 @@ package services import ( "context" "encoding/json" + "errors" "github.com/TeaOSLab/EdgeAPI/internal/db/models" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" @@ -62,9 +63,9 @@ func (this *SSLPolicyService) UpdateSSLPolicy(ctx context.Context, req *pb.Updat tx := this.NullTx() if userId > 0 { - err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, req.SslPolicyId, userId) + err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, req.SslPolicyId) if err != nil { - return nil, err + return nil, errors.New("check ssl policy failed: " + err.Error()) } }