diff --git a/internal/db/models/node_dao.go b/internal/db/models/node_dao.go index cd092158..f2aa9e3c 100644 --- a/internal/db/models/node_dao.go +++ b/internal/db/models/node_dao.go @@ -498,6 +498,22 @@ func (this *NodeDAO) FindAllEnabledNodesWithClusterId(tx *dbs.Tx, clusterId int6 return } +// FindAllEnabledNodeIdsWithClusterId 获取一个集群的所有节点Ids +func (this *NodeDAO) FindAllEnabledNodeIdsWithClusterId(tx *dbs.Tx, clusterId int64) (result []int64, err error) { + ones, err := this.Query(tx). + ResultPk(). + State(NodeStateEnabled). + Attr("clusterId", clusterId). + FindAll() + if err != nil { + return nil, err + } + for _, one := range ones { + result = append(result, int64(one.(*Node).Id)) + } + return +} + // FindAllInactiveNodesWithClusterId 取得一个集群离线的节点 func (this *NodeDAO) FindAllInactiveNodesWithClusterId(tx *dbs.Tx, clusterId int64) (result []*Node, err error) { _, err = this.Query(tx). diff --git a/internal/db/models/server_dao.go b/internal/db/models/server_dao.go index c45d21d4..63c95bd4 100644 --- a/internal/db/models/server_dao.go +++ b/internal/db/models/server_dao.go @@ -1207,6 +1207,54 @@ func (this *ServerDAO) FindAllServersDNSWithClusterId(tx *dbs.Tx, clusterId int6 return } +// FindAllEnabledServersWithDomain 根据域名查找服务 +func (this *ServerDAO) FindAllEnabledServersWithDomain(tx *dbs.Tx, domain string) (result []*Server, err error) { + if len(domain) == 0 { + return + } + + _, err = this.Query(tx). + State(ServerStateEnabled). + Where("(JSON_CONTAINS(serverNames, :domain1) OR JSON_CONTAINS(serverNames, :domain2))"). + Param("domain1", maps.Map{"name": domain}.AsJSON()). + Param("domain2", maps.Map{"subNames": domain}.AsJSON()). + Slice(&result). + DescPk(). + FindAll() + + if err != nil { + return nil, err + } + + // 支持泛解析 + var countPieces = strings.Count(domain, ".") + for { + var index = strings.Index(domain, ".") + if index > 0 { + domain = domain[index+1:] + var search = strings.Repeat("*.", countPieces-strings.Count(domain, ".")) + domain + _, err = this.Query(tx). + State(ServerStateEnabled). + Where("(JSON_CONTAINS(serverNames, :domain1) OR JSON_CONTAINS(serverNames, :domain2))"). + Param("domain1", maps.Map{"name": search}.AsJSON()). + Param("domain2", maps.Map{"subNames": search}.AsJSON()). + Slice(&result). + DescPk(). + FindAll() + if err != nil { + return + } + if len(result) > 0 { + return + } + } else { + break + } + } + + return +} + // GenerateServerDNSName 重新生成子域名 func (this *ServerDAO) GenerateServerDNSName(tx *dbs.Tx, serverId int64) (string, error) { if serverId <= 0 { diff --git a/internal/db/models/server_dao_test.go b/internal/db/models/server_dao_test.go index aa76edfa..9ebab308 100644 --- a/internal/db/models/server_dao_test.go +++ b/internal/db/models/server_dao_test.go @@ -145,6 +145,22 @@ func TestServerDAO_FindAllEnabledServersWithNode(t *testing.T) { } } +func TestServerDAO_FindAllEnabledServersWithDomain(t *testing.T) { + for _, domain := range []string{"yun4s.cn", "teaos.cn", "teaos2.cn", "cdn.teaos.cn", "cdn100.teaos.cn"} { + servers, err := NewServerDAO().FindAllEnabledServersWithDomain(nil, domain) + if err != nil { + t.Fatal(err) + } + if len(servers) > 0 { + for _, server := range servers { + t.Log(domain + ": " + server.ServerNames) + } + } else { + t.Log(domain + ": not found") + } + } +} + func BenchmarkServerDAO_CountAllEnabledServers(b *testing.B) { SharedServerDAO = NewServerDAO() diff --git a/internal/rpc/services/service_node_stream.go b/internal/rpc/services/service_node_stream.go index 2ddf4d2c..6271a9eb 100644 --- a/internal/rpc/services/service_node_stream.go +++ b/internal/rpc/services/service_node_stream.go @@ -38,11 +38,11 @@ func (this *CommandRequestWaiting) Close() { close(this.Chan) } -var responseChanMap = map[int64]*CommandRequestWaiting{} // request id => response +var nodeResponseChanMap = map[int64]*CommandRequestWaiting{} // request id => response var commandRequestId = int64(0) var nodeLocker = &sync.Mutex{} -var requestChanMap = map[int64]chan *CommandRequest{} // node id => chan +var nodeRequestChanMap = map[int64]chan *CommandRequest{} // node id => chan func NextCommandRequestId() int64 { return atomic.AddInt64(&commandRequestId, 1) @@ -54,10 +54,10 @@ func init() { go func() { for range ticker.C { nodeLocker.Lock() - for requestId, request := range responseChanMap { + for requestId, request := range nodeResponseChanMap { if time.Now().Unix()-request.Timestamp > 3600 { - responseChanMap[requestId].Close() - delete(responseChanMap, requestId) + nodeResponseChanMap[requestId].Close() + delete(nodeResponseChanMap, requestId) } } nodeLocker.Unlock() @@ -127,16 +127,16 @@ func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) erro } nodeLocker.Lock() - requestChan, ok := requestChanMap[nodeId] + requestChan, ok := nodeRequestChanMap[nodeId] if !ok { requestChan = make(chan *CommandRequest, 1024) - requestChanMap[nodeId] = requestChan + nodeRequestChanMap[nodeId] = requestChan } nodeLocker.Unlock() defer func() { nodeLocker.Lock() - delete(requestChanMap, nodeId) + delete(nodeRequestChanMap, nodeId) nodeLocker.Unlock() }() @@ -189,7 +189,7 @@ func (this *NodeService) NodeStream(server pb.NodeService_NodeStreamServer) erro }() nodeLocker.Lock() - responseChan, ok := responseChanMap[req.RequestId] + responseChan, ok := nodeResponseChanMap[req.RequestId] if ok { select { case responseChan.Chan <- req: @@ -215,25 +215,37 @@ func (this *NodeService) SendCommandToNode(ctx context.Context, req *pb.NodeStre return nil, errors.New("node id should not be less than 0") } + return SendCommandToNode(req.NodeId, req.RequestId, req.Code, req.DataJSON, req.TimeoutSeconds, true) +} + +// SendCommandToNode 向节点发送命令 +func SendCommandToNode(nodeId int64, requestId int64, messageCode string, dataJSON []byte, timeoutSeconds int32, forceConnecting bool) (result *pb.NodeStreamMessage, err error) { nodeLocker.Lock() - requestChan, ok := requestChanMap[nodeId] + requestChan, ok := nodeRequestChanMap[nodeId] nodeLocker.Unlock() if !ok { - return &pb.NodeStreamMessage{ - RequestId: req.RequestId, - IsOk: false, - Message: "node '" + strconv.FormatInt(nodeId, 10) + "' not connected yet", - }, nil + if forceConnecting { + return &pb.NodeStreamMessage{ + RequestId: requestId, + IsOk: false, + Message: "node '" + strconv.FormatInt(nodeId, 10) + "' not connected yet", + }, nil + } else { + return &pb.NodeStreamMessage{ + RequestId: requestId, + IsOk: true, + }, nil + } } - req.RequestId = NextCommandRequestId() + requestId = NextCommandRequestId() select { case requestChan <- &CommandRequest{ - Id: req.RequestId, - Code: req.Code, - CommandJSON: req.DataJSON, + Id: requestId, + Code: messageCode, + CommandJSON: dataJSON, }: // 加入到等待队列中 respChan := make(chan *pb.NodeStreamMessage, 1) @@ -243,11 +255,10 @@ func (this *NodeService) SendCommandToNode(ctx context.Context, req *pb.NodeStre } nodeLocker.Lock() - responseChanMap[req.RequestId] = waiting + nodeResponseChanMap[requestId] = waiting nodeLocker.Unlock() // 等待响应 - timeoutSeconds := req.TimeoutSeconds if timeoutSeconds <= 0 { timeoutSeconds = 10 } @@ -256,14 +267,14 @@ func (this *NodeService) SendCommandToNode(ctx context.Context, req *pb.NodeStre case resp := <-respChan: // 从队列中删除 nodeLocker.Lock() - delete(responseChanMap, req.RequestId) + delete(nodeResponseChanMap, requestId) waiting.Close() nodeLocker.Unlock() if resp == nil { return &pb.NodeStreamMessage{ - RequestId: req.RequestId, - Code: req.Code, + RequestId: requestId, + Code: messageCode, Message: "response timeout", IsOk: false, }, nil @@ -273,21 +284,21 @@ func (this *NodeService) SendCommandToNode(ctx context.Context, req *pb.NodeStre case <-timeout.C: // 从队列中删除 nodeLocker.Lock() - delete(responseChanMap, req.RequestId) + delete(nodeResponseChanMap, requestId) waiting.Close() nodeLocker.Unlock() return &pb.NodeStreamMessage{ - RequestId: req.RequestId, - Code: req.Code, + RequestId: requestId, + Code: messageCode, Message: "response timeout over " + fmt.Sprintf("%d", timeoutSeconds) + " seconds", IsOk: false, }, nil } default: return &pb.NodeStreamMessage{ - RequestId: req.RequestId, - Code: req.Code, + RequestId: requestId, + Code: messageCode, Message: "command queue is full over " + strconv.Itoa(len(requestChan)), IsOk: false, }, nil diff --git a/internal/rpc/services/service_server.go b/internal/rpc/services/service_server.go index da02c14a..19e88a48 100644 --- a/internal/rpc/services/service_server.go +++ b/internal/rpc/services/service_server.go @@ -8,6 +8,7 @@ import ( "github.com/TeaOSLab/EdgeAPI/internal/db/models" "github.com/TeaOSLab/EdgeAPI/internal/db/models/dns" "github.com/TeaOSLab/EdgeAPI/internal/db/models/regions" + "github.com/TeaOSLab/EdgeCommon/pkg/messageconfigs" "github.com/TeaOSLab/EdgeCommon/pkg/rpc/pb" "github.com/TeaOSLab/EdgeCommon/pkg/serverconfigs" "github.com/iwind/TeaGo/maps" @@ -1559,3 +1560,91 @@ func (this *ServerService) FindNearbyServers(ctx context.Context, req *pb.FindNe Groups: []*pb.FindNearbyServersResponse_GroupInfo{pbGroup}, }, nil } + +// PurgeServerCache 清除缓存 +func (this *ServerService) PurgeServerCache(ctx context.Context, req *pb.PurgeServerCacheRequest) (*pb.PurgeServerCacheResponse, error) { + _, err := this.ValidateAdmin(ctx, 0) + if err != nil { + return nil, err + } + + if len(req.Domains) == 0 { + return nil, errors.New("'domains' field is required") + } + + if len(req.Keys) == 0 && len(req.Prefixes) == 0 { + return &pb.PurgeServerCacheResponse{IsOk: true}, nil + } + + var tx = this.NullTx() + var cacheMap = maps.Map{} + var purgeResponse = &pb.PurgeServerCacheResponse{} + + for _, domain := range req.Domains { + servers, err := models.SharedServerDAO.FindAllEnabledServersWithDomain(tx, domain) + if err != nil { + return nil, err + } + + for _, server := range servers { + clusterId := int64(server.ClusterId) + if clusterId > 0 { + nodeIds, err := models.SharedNodeDAO.FindAllEnabledNodeIdsWithClusterId(tx, clusterId) + if err != nil { + return nil, err + } + + cachePolicyId, err := models.SharedNodeClusterDAO.FindClusterHTTPCachePolicyId(tx, clusterId, cacheMap) + if err != nil { + return nil, err + } + if cachePolicyId == 0 { + continue + } + + cachePolicy, err := models.SharedHTTPCachePolicyDAO.ComposeCachePolicy(tx, cachePolicyId, cacheMap) + if err != nil { + return nil, err + } + if cachePolicy == nil { + continue + } + cachePolicyJSON, err := json.Marshal(cachePolicy) + if err != nil { + return nil, err + } + + for _, nodeId := range nodeIds { + msg := &messageconfigs.PurgeCacheMessage{ + CachePolicyJSON: cachePolicyJSON, + } + if len(req.Prefixes) > 0 { + msg.Type = messageconfigs.PurgeCacheMessageTypeDir + msg.Keys = req.Prefixes + } else { + msg.Type = messageconfigs.PurgeCacheMessageTypeFile + msg.Keys = req.Keys + } + msgJSON, err := json.Marshal(msg) + if err != nil { + return nil, err + } + + resp, err := SendCommandToNode(nodeId, NextCommandRequestId(), messageconfigs.MessageCodePurgeCache, msgJSON, 10, false) + if err != nil { + return nil, err + } + if !resp.IsOk { + purgeResponse.IsOk = false + purgeResponse.Message = resp.Message + return purgeResponse, nil + } + } + } + } + } + + purgeResponse.IsOk = true + + return purgeResponse, nil +}