修复用户可能无法添加黑白名单IP的Bug

This commit is contained in:
GoEdgeLab
2022-03-10 15:59:00 +08:00
parent a38369733e
commit 76d692fe41
5 changed files with 41 additions and 7 deletions

View File

@@ -184,6 +184,10 @@ func (this *IPListDAO) IncreaseVersion(tx *dbs.Tx) (int64, error) {
// CheckUserIPList 检查用户权限 // CheckUserIPList 检查用户权限
func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) error { func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) error {
if userId == 0 || listId == 0 {
return ErrNotFound
}
ok, err := this.Query(tx). ok, err := this.Query(tx).
Pk(listId). Pk(listId).
Attr("userId", userId). Attr("userId", userId).
@@ -194,6 +198,18 @@ func (this *IPListDAO) CheckUserIPList(tx *dbs.Tx, userId int64, listId int64) e
if ok { if ok {
return nil return nil
} }
// 检查是否被用户的服务所使用
policyIds, err := SharedHTTPFirewallPolicyDAO.FindEnabledFirewallPolicyIdsWithIPListId(tx, listId)
if err != nil {
return err
}
for _, policyId := range policyIds {
if SharedHTTPFirewallPolicyDAO.CheckUserFirewallPolicy(tx, userId, policyId) == nil {
return nil
}
}
return ErrNotFound return ErrNotFound
} }

View File

@@ -1227,6 +1227,16 @@ func (this *ServerDAO) FindAllEnabledServerIdsWithSSLPolicyIds(tx *dbs.Tx, sslPo
return return
} }
// ExistEnabledUserServerWithSSLPolicyId 检查是否存在某个用户的策略
func (this *ServerDAO) ExistEnabledUserServerWithSSLPolicyId(tx *dbs.Tx, userId int64, sslPolicyId int64) (bool, error) {
return this.Query(tx).
State(ServerStateEnabled).
Attr("userId", userId).
Where("(JSON_CONTAINS(https, :jsonQuery) OR JSON_CONTAINS(tls, :jsonQuery))").
Param("jsonQuery", maps.Map{"sslPolicyRef": maps.Map{"sslPolicyId": sslPolicyId}}.AsJSON()).
Exist()
}
// CountEnabledServersWithWebIds 计算使用某个缓存策略的所有服务数量 // CountEnabledServersWithWebIds 计算使用某个缓存策略的所有服务数量
func (this *ServerDAO) CountEnabledServersWithWebIds(tx *dbs.Tx, webIds []int64) (count int64, err error) { func (this *ServerDAO) CountEnabledServersWithWebIds(tx *dbs.Tx, webIds []int64) (count int64, err error) {
if len(webIds) == 0 { if len(webIds) == 0 {

View File

@@ -281,9 +281,9 @@ func (this *SSLPolicyDAO) UpdatePolicy(tx *dbs.Tx, policyId int64, http2Enabled
} }
// CheckUserPolicy 检查是否为用户所属策略 // CheckUserPolicy 检查是否为用户所属策略
func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int64) error { func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, userId int64, policyId int64) error {
if policyId <= 0 || userId <= 0 { if policyId <= 0 || userId <= 0 {
return errors.New("not found") return ErrNotFound
} }
ok, err := this.Query(tx). ok, err := this.Query(tx).
State(SSLPolicyStateEnabled). State(SSLPolicyStateEnabled).
@@ -294,7 +294,14 @@ func (this *SSLPolicyDAO) CheckUserPolicy(tx *dbs.Tx, policyId int64, userId int
return err return err
} }
if !ok { if !ok {
return errors.New("not found") // 是否为当前用户的某个服务所用
exists, err := SharedServerDAO.ExistEnabledUserServerWithSSLPolicyId(tx, userId, policyId)
if err != nil {
return err
}
if !exists {
return ErrNotFound
}
} }
return nil return nil
} }

View File

@@ -42,7 +42,7 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe
return nil, err return nil, err
} }
if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 { if httpsConfig.SSLPolicyRef != nil && httpsConfig.SSLPolicyRef.SSLPolicyId > 0 {
err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, httpsConfig.SSLPolicyRef.SSLPolicyId, userId) err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, httpsConfig.SSLPolicyRef.SSLPolicyId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -57,7 +57,7 @@ func (this *ServerService) CreateServer(ctx context.Context, req *pb.CreateServe
return nil, err return nil, err
} }
if tlsConfig.SSLPolicyRef != nil && tlsConfig.SSLPolicyRef.SSLPolicyId > 0 { if tlsConfig.SSLPolicyRef != nil && tlsConfig.SSLPolicyRef.SSLPolicyId > 0 {
err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, tlsConfig.SSLPolicyRef.SSLPolicyId, userId) err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, tlsConfig.SSLPolicyRef.SSLPolicyId)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -3,6 +3,7 @@ package services
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models"
rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils" rpcutils "github.com/TeaOSLab/EdgeAPI/internal/rpc/utils"
"github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb"
@@ -62,9 +63,9 @@ func (this *SSLPolicyService) UpdateSSLPolicy(ctx context.Context, req *pb.Updat
tx := this.NullTx() tx := this.NullTx()
if userId > 0 { if userId > 0 {
err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, req.SslPolicyId, userId) err := models.SharedSSLPolicyDAO.CheckUserPolicy(tx, userId, req.SslPolicyId)
if err != nil { if err != nil {
return nil, err return nil, errors.New("check ssl policy failed: " + err.Error())
} }
} }