diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index b5eab000..bd2fd6a2 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -1594,12 +1594,16 @@ func (this *ServerDAO) FindAllEnabledServersWithDomain(tx *dbs.Tx, domain string } // FindEnabledServerWithDomain 根据域名查找服务集群ID -func (this *ServerDAO) FindEnabledServerWithDomain(tx *dbs.Tx, domain string) (server *Server, err error) { +func (this *ServerDAO) FindEnabledServerWithDomain(tx *dbs.Tx, userId int64, domain string) (server *Server, err error) { if len(domain) == 0 { return } - one, err := this.Query(tx). + var query = this.Query(tx) + if userId > 0 { + query.Attr("userId", userId) + } + one, err := query. State(ServerStateEnabled). Where("JSON_CONTAINS(plainServerNames, :domain)"). Param("domain", strconv.Quote(domain)). @@ -1618,7 +1622,11 @@ func (this *ServerDAO) FindEnabledServerWithDomain(tx *dbs.Tx, domain string) (s var dotIndex = strings.Index(domain, ".") if dotIndex > 0 { var wildcardDomain = "*." + domain[dotIndex+1:] - one, err = this.Query(tx). + var wildcardQuery = this.Query(tx) + if userId > 0 { + wildcardQuery.Attr("userId", userId) + } + one, err = wildcardQuery. State(ServerStateEnabled). Where("JSON_CONTAINS(plainServerNames, :domain)"). Param("domain", strconv.Quote(wildcardDomain)). diff --git a/internal/rpc/services/service_http_cache_task.go b/internal/rpc/services/service_http_cache_task.go index a5d88faf..0420df63 100644 --- a/internal/rpc/services/service_http_cache_task.go +++ b/internal/rpc/services/service_http_cache_task.go @@ -131,7 +131,7 @@ func (this *HTTPCacheTaskService) CreateHTTPCacheTask(ctx context.Context, req * // 查询所在集群 server, ok := domainMap[domain] if !ok { - server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, domain) + server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, userId, domain) if err != nil { return nil, err } diff --git a/internal/rpc/services/service_http_cache_task_key.go b/internal/rpc/services/service_http_cache_task_key.go index 29bd6413..7bad00ec 100644 --- a/internal/rpc/services/service_http_cache_task_key.go +++ b/internal/rpc/services/service_http_cache_task_key.go @@ -35,7 +35,8 @@ func (this *HTTPCacheTaskKeyService) ValidateHTTPCacheTaskKeys(ctx context.Conte } var pbFailResults = []*pb.ValidateHTTPCacheTaskKeysResponse_FailKey{} - var domainMap = map[string]*models.Server{} // domain name => *Server + var foundDomainMap = map[string]*models.Server{} // domain name => *Server + var missingDomainMap = map[string]bool{} // domain name => true for _, key := range req.Keys { if len(key) == 0 { pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{ @@ -55,21 +56,31 @@ func (this *HTTPCacheTaskKeyService) ValidateHTTPCacheTaskKeys(ctx context.Conte continue } + // 是否不存在 + if missingDomainMap[domain] { + pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{ + Key: key, + ReasonCode: "requireServer", + }) + continue + } + // 查询所在集群 - server, ok := domainMap[domain] + server, ok := foundDomainMap[domain] if !ok { - server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, domain) + server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, userId, domain) if err != nil { return nil, err } if server == nil { + missingDomainMap[domain] = true pbFailResults = append(pbFailResults, &pb.ValidateHTTPCacheTaskKeysResponse_FailKey{ Key: key, ReasonCode: "requireServer", }) continue } - domainMap[domain] = server + foundDomainMap[domain] = server } // 检查用户 diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index c37e496c..76b05391 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -2010,7 +2010,7 @@ func (this *ServerService) PurgeServerCache(ctx context.Context, req *pb.PurgeSe // 查询所在集群 server, ok := domainMap[domain] if !ok { - server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, domain) + server, err = models.SharedServerDAO.FindEnabledServerWithDomain(tx, 0, domain) if err != nil { return nil, err }