From 144b9b9519f624c46737adc82dda4825a0320a26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E7=A5=A5=E8=B6=85?= Date: Fri, 5 Nov 2021 17:56:35 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=A4=9A=E4=B8=AAAPI/?= =?UTF-8?q?=E8=A7=84=E8=8C=83=E5=91=BD=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/db/models/api_node_dao.go | 18 ++++++++- internal/db/models/ssl_policy_dao.go | 5 +-- internal/db/models/user_node_dao.go | 17 ++++++++ internal/rpc/services/service_api_node.go | 23 +++++++++++ internal/rpc/services/service_user_node.go | 47 ++++++++++++++++------ 5 files changed, 94 insertions(+), 16 deletions(-) diff --git a/internal/db/models/api_node_dao.go b/internal/db/models/api_node_dao.go index fdc4ceee..526c3bf7 100644 --- a/internal/db/models/api_node_dao.go +++ b/internal/db/models/api_node_dao.go @@ -12,6 +12,7 @@ import ( "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" "strconv" + "strings" ) const ( @@ -219,7 +220,6 @@ func (this *APINodeDAO) CountAllEnabledAndOnAPINodes(tx *dbs.Tx) (int64, error) Count() } - // CountAllEnabledAndOnOfflineAPINodes 计算API节点数量 func (this *APINodeDAO) CountAllEnabledAndOnOfflineAPINodes(tx *dbs.Tx) (int64, error) { return this.Query(tx). @@ -305,3 +305,19 @@ func (this *APINodeDAO) CountAllLowerVersionNodes(tx *dbs.Tx, version string) (i Param("version", utils.VersionToLong(version)). Count() } + +// CountAllEnabledAPINodesWithSSLPolicyIds 计算使用SSL策略的所有API节点数量 +func (this *APINodeDAO) CountAllEnabledAPINodesWithSSLPolicyIds(tx *dbs.Tx, sslPolicyIds []int64) (count int64, err error) { + if len(sslPolicyIds) == 0 { + return + } + policyStringIds := []string{} + for _, policyId := range sslPolicyIds { + policyStringIds = append(policyStringIds, strconv.FormatInt(policyId, 10)) + } + return this.Query(tx). + State(APINodeStateEnabled). + Where("(FIND_IN_SET(JSON_EXTRACT(https, '$.sslPolicyRef.sslPolicyId'), :policyIds) OR FIND_IN_SET(JSON_EXTRACT(restHTTPS, '$.sslPolicyRef.sslPolicyId'), :policyIds))"). + Param("policyIds", strings.Join(policyStringIds, ",")). + Count() +} diff --git a/internal/db/models/ssl_policy_dao.go b/internal/db/models/ssl_policy_dao.go index a2e8bbbf..94ec2b9b 100644 --- a/internal/db/models/ssl_policy_dao.go +++ b/internal/db/models/ssl_policy_dao.go @@ -9,7 +9,6 @@ import ( "github.com/iwind/TeaGo/dbs" "github.com/iwind/TeaGo/maps" "github.com/iwind/TeaGo/types" - "strconv" ) const ( @@ -181,8 +180,8 @@ func (this *SSLPolicyDAO) FindAllEnabledPolicyIdsWithCertId(tx *dbs.Tx, certId i ones, err := this.Query(tx). State(SSLPolicyStateEnabled). ResultPk(). - Where(`JSON_CONTAINS(certs, '{"certId": ` + strconv.FormatInt(certId, 10) + ` }')`). - Reuse(false). // 由于我们在JSON_CONTAINS()直接使用了变量,所以不能重用 + Where("JSON_CONTAINS(certs, :certJSON)"). + Param("certJSON", maps.Map{"certId": certId}.AsJSON()). FindAll() if err != nil { return nil, err diff --git a/internal/db/models/user_node_dao.go b/internal/db/models/user_node_dao.go index f848c063..e97886da 100644 --- a/internal/db/models/user_node_dao.go +++ b/internal/db/models/user_node_dao.go @@ -12,6 +12,7 @@ import ( "github.com/iwind/TeaGo/rands" "github.com/iwind/TeaGo/types" "strconv" + "strings" ) const ( @@ -282,3 +283,19 @@ func (this *UserNodeDAO) CountAllEnabledAndOnOfflineNodes(tx *dbs.Tx) (int64, er Where("(status IS NULL OR JSON_EXTRACT(status, '$.updatedAt') 0 { - nodeId = req.NodeId + if req.UserNodeId > 0 { + nodeId = req.UserNodeId } if nodeId <= 0 { @@ -267,3 +267,26 @@ func (this *UserNodeService) UpdateUserNodeStatus(ctx context.Context, req *pb.U } return this.Success() } + +// CountAllEnabledUserNodesWithSSLCertId 计算使用某个SSL证书的用户节点数量 +func (this *UserNodeService) CountAllEnabledUserNodesWithSSLCertId(ctx context.Context, req *pb.CountAllEnabledUserNodesWithSSLCertIdRequest) (*pb.RPCCountResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + var tx = this.NullTx() + policyIds, err := models.SharedSSLPolicyDAO.FindAllEnabledPolicyIdsWithCertId(tx, req.SslCertId) + if err != nil { + return nil, err + } + if len(policyIds) == 0 { + return this.SuccessCount(0) + } + + count, err := models.SharedUserNodeDAO.CountAllEnabledUserNodesWithSSLPolicyIds(tx, policyIds) + if err != nil { + return nil, err + } + return this.SuccessCount(count) +}