diff --git a/internal/nodes/api_stream.go b/internal/nodes/api_stream.go index 7450bb8..0d083a5 100644 --- a/internal/nodes/api_stream.go +++ b/internal/nodes/api_stream.go @@ -31,6 +31,9 @@ import ( type APIStream struct { stream pb.NodeService_NodeStreamClient + + isQuiting bool + cancelFunc context.CancelFunc } func NewAPIStream() *APIStream { @@ -38,12 +41,14 @@ func NewAPIStream() *APIStream { } func (this *APIStream) Start() { - isQuiting := false events.On(events.EventQuit, func() { - isQuiting = true + this.isQuiting = true + if this.cancelFunc != nil { + this.cancelFunc() + } }) for { - if isQuiting { + if this.isQuiting { return } err := this.loop() @@ -61,34 +66,29 @@ func (this *APIStream) loop() error { if err != nil { return errors.Wrap(err) } - isQuiting := false - ctx, cancelFunc := context.WithCancel(rpcClient.Context()) - nodeStream, err := rpcClient.NodeRPC().NodeStream(ctx) - events.On(events.EventQuit, func() { - isQuiting = true - remotelogs.Println("API_STREAM", "quiting") - if nodeStream != nil { - cancelFunc() - } - }) + ctx, cancelFunc := context.WithCancel(rpcClient.Context()) + this.cancelFunc = cancelFunc + + defer func() { + cancelFunc() + }() + + nodeStream, err := rpcClient.NodeRPC().NodeStream(ctx) if err != nil { - if isQuiting { - return nil - } return errors.Wrap(err) } this.stream = nodeStream for { - if isQuiting { + if this.isQuiting { remotelogs.Println("API_STREAM", "quit") break } message, err := nodeStream.Recv() if err != nil { - if isQuiting { + if this.isQuiting { remotelogs.Println("API_STREAM", "quit") return nil }