From 1f288d7dd0c9cefc21c108e2a37d1a5c39f50ab3 Mon Sep 17 00:00:00 2001 From: GoEdgeLab Date: Sun, 28 May 2023 17:44:27 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=88=9B=E5=BB=BA=E7=BC=93?= =?UTF-8?q?=E5=AD=98=E4=BB=BB=E5=8A=A1=E6=97=B6=E5=9F=9F=E5=90=8D=E6=A3=80?= =?UTF-8?q?=E6=9F=A5=E9=80=9F=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- internal/db/models/server_dao.go | 14 +++++++++++--- .../rpc/services/service_http_cache_task.go | 2 +- .../services/service_http_cache_task_key.go | 19 +++++++++++++++---- internal/rpc/services/service_server.go | 2 +- 4 files changed, 28 insertions(+), 9 deletions(-) diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index b5eab000..bd2fd6a2 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -1594,12 +1594,16 @@ func (this *ServerDAO) FindAllEnabledServersWithDomain(tx *dbs.Tx, domain string } // FindEnabledServerWithDomain 根据域名查找服务集群ID -func (this *ServerDAO) FindEnabledServerWithDomain(tx *dbs.Tx, domain string) (server *Server, err error) { +func (this *ServerDAO) FindEnabledServerWithDomain(tx *dbs.Tx, userId int64, domain string) (server *Server, err error) { if len(domain) == 0 { return } - one, err := this.Query(tx). + var query = this.Query(tx) + if userId > 0 { + query.Attr("userId", userId) + } + one, err := query. State(ServerStateEnabled). Where("JSON_CONTAINS(plainServerNames, :domain)"). Param("domain", strconv.Quote(domain)). @@ -1618,7 +1622,11 @@ func (this *ServerDAO) FindEnabledServerWithDomain(tx *dbs.Tx, domain string) (s var dotIndex = strings.Index(domain, ".") if dotIndex > 0 { var wildcardDomain = "*." + domain[dotIndex+1:] - one, err = this.Query(tx). + var wildcardQuery = this.Query(tx) + if userId > 0 { + wildcardQuery.Attr("userId", userId) + } + one, err = wildcardQuery. State(ServerStateEnabled). Where("JSON_CONTAINS(plainServerNames, :domain)"). Param("domain", strconv.Quote(wildcardDomain)). diff --git a/internal/rpc/services/service_http_cache_task.go b/internal/rpc/services/service_http_cache_task.go index a5d88faf..0420df63 100644 --- a/internal/rpc/services/service_http_cache_task.go +++ b/internal/rpc/services/service_http_cache_task.go @@ -131,7 +131,7 @@ func (this *HTTPCacheTaskService) CreateHTTPCacheTask(ctx context.Context, req * // 查询所在集群 server, ok := domainMap[domain] if !ok { - server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, domain) + server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, userId, domain) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_cache_task_key.go b/internal/rpc/services/service_http_cache_task_key.go index 29bd6413..7bad00ec 100644 --- a/internal/rpc/services/service_http_cache_task_key.go +++ b/internal/rpc/services/service_http_cache_task_key.go @@ -35,7 +35,8 @@ func (this *HTTPCacheTaskKeyService) ValidateHTTPCacheTaskKeys(ctx context.Conte } var pbFailResults = []*pb.ValidateHTTPCacheTaskKeysResponse_FailKey{} - var domainMap = map[string]*models.Server{} // domain name => *Server + var foundDomainMap = map[string]*models.Server{} // domain name => *Server + var missingDomainMap = map[string]bool{} // domain name => true for _, key := range req.Keys { if len(key) == 0 { pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{ @@ -55,21 +56,31 @@ func (this *HTTPCacheTaskKeyService) ValidateHTTPCacheTaskKeys(ctx context.Conte continue } + // 是否不存在 + if missingDomainMap[domain] { + pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{ + Key: key, + ReasonCode: "requireServer", + }) + continue + } + // 查询所在集群 - server, ok := domainMap[domain] + server, ok := foundDomainMap[domain] if !ok { - server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, domain) + server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, userId, domain) if err != nil { return nil, err } if server == nil { + missingDomainMap[domain] = true pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{ Key: key, ReasonCode: "requireServer", }) continue } - domainMap[domain] = server + foundDomainMap[domain] = server } // 检查用户 diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index c37e496c..76b05391 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -2010,7 +2010,7 @@ func (this *ServerService) PurgeServerCache(ctx context.Context, req *pb.PurgeSe // 查询所在集群 server, ok := domainMap[domain] if !ok { - server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, domain) + server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, 0, domain) if err != nil { return nil, err }