diff --git a/internal/db/models/api_node_dao.go b/internal/db/models/api_node_dao.go index fdc4ceee..526c3bf7 100644 --- a/internal/db/models/api_node_dao.go +++ b/internal/db/models/api_node_dao.go @@ -12,6 +12,7 @@ import ( "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" "strconv" + "strings" ) const ( @@ -219,7 +220,6 @@ func (this *APINodeDAO) CountAllEnabledAndOnAPINodes(tx *dbs.Tx) (int64, error) Count() } - // CountAllEnabledAndOnOfflineAPINodes 计算API节点数量 func (this *APINodeDAO) CountAllEnabledAndOnOfflineAPINodes(tx *dbs.Tx) (int64, error) { return this.Query(tx). @@ -305,3 +305,19 @@ func (this *APINodeDAO) CountAllLowerVersionNodes(tx *dbs.Tx, version string) (i Param("version", utils.VersionToLong(version)). Count() } + +// CountAllEnabledAPINodesWithSSLPolicyIds 计算使用SSL策略的所有API节点数量 +func (this *APINodeDAO) CountAllEnabledAPINodesWithSSLPolicyIds(tx *dbs.Tx, sslPolicyIds []int64) (count int64, err error) { + if len(sslPolicyIds) == 0 { + return + } + policyStringIds := []string{} + for _, policyId := range sslPolicyIds { + policyStringIds = append(policyStringIds, strconv.FormatInt(policyId, 10)) + } + return this.Query(tx). + State(APINodeStateEnabled). + Where("(FIND_IN_SET(JSON_EXTRACT(https, '$.sslPolicyRef.sslPolicyId'), :policyIds) OR FIND_IN_SET(JSON_EXTRACT(restHTTPS, '$.sslPolicyRef.sslPolicyId'), :policyIds))"). + Param("policyIds", strings.Join(policyStringIds, ",")). + Count() +} diff --git a/internal/db/models/ssl_policy_dao.go b/internal/db/models/ssl_policy_dao.go index a2e8bbbf..94ec2b9b 100644 --- a/internal/db/models/ssl_policy_dao.go +++ b/internal/db/models/ssl_policy_dao.go @@ -9,7 +9,6 @@ import ( "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" - "strconv" ) const ( @@ -181,8 +180,8 @@ func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(tx *dbs.Tx, certId i ones, err := this.Query(tx). State(SSLPolicyStateEnabled). ResultPk(). - Where(`JSON_CONTAINS(certs, '{"certId": ` + strconv.FormatInt(certId, 10) + ` }')`). - Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 + Where("JSON_CONTAINS(certs, :certJSON)"). + Param("certJSON", maps.Map{"certId": certId}.AsJSON()). FindAll() if err != nil { return nil, err diff --git a/internal/db/models/user_node_dao.go b/internal/db/models/user_node_dao.go index f848c063..e97886da 100644 --- a/internal/db/models/user_node_dao.go +++ b/internal/db/models/user_node_dao.go @@ -12,6 +12,7 @@ import ( "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" "strconv" + "strings" ) const ( @@ -282,3 +283,19 @@ func (this *UserNodeDAO) CountAllEnabledAndOnOfflineNodes(tx *dbs.Tx) (int64, er Where("(status IS NULL OR JSON_EXTRACT(status, '$.updatedAt') 0 { - nodeId = req.NodeId + if req.UserNodeId > 0 { + nodeId = req.UserNodeId } if nodeId <= 0 { @@ -267,3 +267,26 @@ func (this *UserNodeService) UpdateUserNodeStatus(ctx context.Context, req *pb.U } return this.Success() } + +// CountAllEnabledUserNodesWithSSLCertId 计算使用某个SSL证书的用户节点数量 +func (this *UserNodeService) CountAllEnabledUserNodesWithSSLCertId(ctx context.Context, req *pb.CountAllEnabledUserNodesWithSSLCertIdRequest) (*pb.RPCCountResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(tx, req.SslCertId) + if err != nil { + return nil, err + } + if len(policyIds) == 0 { + return this.SuccessCount(0) + } + + count, err := models.SharedUserNodeDAO.CountAllEnabledUserNodesWithSSLPolicyIds(tx, policyIds) + if err != nil { + return nil, err + } + return this.SuccessCount(count) +}