From 0acc0293dbb54d40b7dd51b2086e890ac2ea73f0 Mon Sep 17 00:00:00 2001 From: Timothy Coleman Date: Fri, 9 Dec 2022 10:43:09 +0000 Subject: [PATCH] Dispose the gRPC call underlying a Read if the read is only partially consumed e.g. when early returning from an await foreach or calling .First(). The IAsyncEnumerable is disposed as before, but now that causes the cancellation of the ResponseStream read, so PumpMessages completes and the call is disposed. The CancellationTokenSource need not be disposed because it is always cancelled (That is, as long as the user makes any attempt to consume the read - if they don't then the call remains open as before) --- .../EventStoreClient.Read.cs | 99 ++++++++++++------- 1 file changed, 62 insertions(+), 37 deletions(-) diff --git a/src/EventStore.Client.Streams/EventStoreClient.Read.cs b/src/EventStore.Client.Streams/EventStoreClient.Read.cs index 777ed4c64..25635f26a 100644 --- a/src/EventStore.Client.Streams/EventStoreClient.Read.cs +++ b/src/EventStore.Client.Streams/EventStoreClient.Read.cs @@ -65,6 +65,7 @@ public ReadAllStreamResult ReadAllAsync( /// public class ReadAllStreamResult : IAsyncEnumerable { private readonly Channel _channel; + private readonly CancellationTokenSource _cts; private int _messagesEnumerated; @@ -85,12 +86,16 @@ async IAsyncEnumerable GetMessages() { throw new InvalidOperationException("Messages may only be enumerated once."); } - await foreach (var message in _channel.Reader.ReadAllAsync().ConfigureAwait(false)) { - if (message is StreamMessage.LastAllStreamPosition(var position)) { - LastPosition = position; - } + try { + await foreach (var message in _channel.Reader.ReadAllAsync().ConfigureAwait(false)) { + if (message is StreamMessage.LastAllStreamPosition(var position)) { + LastPosition = position; + } - yield return message; + yield return message; + } + } finally { + _cts.Cancel(); } } } @@ -108,14 +113,17 @@ internal ReadAllStreamResult(Func> selectCa AllowSynchronousContinuations = true }); + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var linkedCancellationToken = _cts.Token; + _ = PumpMessages(); async Task PumpMessages() { try { - var callInvoker = await selectCallInvoker(cancellationToken).ConfigureAwait(false); + var callInvoker = await selectCallInvoker(linkedCancellationToken).ConfigureAwait(false); var client = new Streams.Streams.StreamsClient(callInvoker); using var call = client.Read(request, callOptions); - await foreach (var response in call.ResponseStream.ReadAllAsync(cancellationToken) + await foreach (var response in call.ResponseStream.ReadAllAsync(linkedCancellationToken) .ConfigureAwait(false)) { await _channel.Writer.WriteAsync(response.ContentCase switch { StreamNotFound => StreamMessage.NotFound.Instance, @@ -128,7 +136,7 @@ async Task PumpMessages() { new Position(response.LastAllStreamPosition.CommitPosition, response.LastAllStreamPosition.PreparePosition)), _ => StreamMessage.Unknown.Instance - }, cancellationToken).ConfigureAwait(false); + }, linkedCancellationToken).ConfigureAwait(false); } _channel.Writer.Complete(); @@ -141,12 +149,17 @@ async Task PumpMessages() { /// public async IAsyncEnumerator GetAsyncEnumerator( CancellationToken cancellationToken = default) { - await foreach (var message in _channel.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) { - if (message is not StreamMessage.Event e) { - continue; - } - yield return e.ResolvedEvent; + try { + await foreach (var message in _channel.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) { + if (message is not StreamMessage.Event e) { + continue; + } + + yield return e.ResolvedEvent; + } + } finally { + _cts.Cancel(); } } } @@ -205,6 +218,7 @@ public ReadStreamResult ReadStreamAsync( /// public class ReadStreamResult : IAsyncEnumerable { private readonly Channel _channel; + private readonly CancellationTokenSource _cts; private int _messagesEnumerated; @@ -235,19 +249,23 @@ async IAsyncEnumerable GetMessages() { throw new InvalidOperationException("Messages may only be enumerated once."); } - await foreach (var message in _channel.Reader.ReadAllAsync().ConfigureAwait(false)) { - switch (message) { - case StreamMessage.FirstStreamPosition(var streamPosition): - FirstStreamPosition = streamPosition; - break; - case StreamMessage.LastStreamPosition(var lastStreamPosition): - LastStreamPosition = lastStreamPosition; - break; - default: - break; - } + try { + await foreach (var message in _channel.Reader.ReadAllAsync().ConfigureAwait(false)) { + switch (message) { + case StreamMessage.FirstStreamPosition(var streamPosition): + FirstStreamPosition = streamPosition; + break; + case StreamMessage.LastStreamPosition(var lastStreamPosition): + LastStreamPosition = lastStreamPosition; + break; + default: + break; + } - yield return message; + yield return message; + } + } finally { + _cts.Cancel(); } } } @@ -273,6 +291,8 @@ internal ReadStreamResult(Func> selectCallI StreamName = request.Options.Stream.StreamIdentifier!; var tcs = new TaskCompletionSource(); + _cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + var linkedCancellationToken = _cts.Token; #pragma warning disable CS0612 ReadState = tcs.Task; #pragma warning restore CS0612 @@ -283,17 +303,17 @@ async Task PumpMessages() { var firstMessageRead = false; try { - var callInvoker = await selectCallInvoker(cancellationToken).ConfigureAwait(false); + var callInvoker = await selectCallInvoker(linkedCancellationToken).ConfigureAwait(false); var client = new Streams.Streams.StreamsClient(callInvoker); using var call = client.Read(request, callOptions); - await foreach (var response in call.ResponseStream.ReadAllAsync(cancellationToken) + await foreach (var response in call.ResponseStream.ReadAllAsync(linkedCancellationToken) .ConfigureAwait(false)) { if (!firstMessageRead) { firstMessageRead = true; if (response.ContentCase != StreamNotFound || request.Options.Stream == null) { - await _channel.Writer.WriteAsync(StreamMessage.Ok.Instance, cancellationToken) + await _channel.Writer.WriteAsync(StreamMessage.Ok.Instance, linkedCancellationToken) .ConfigureAwait(false); tcs.SetResult(Client.ReadState.Ok); } else { @@ -312,7 +332,7 @@ await _channel.Writer.WriteAsync(StreamMessage.Ok.Instance, cancellationToken) new Position(response.LastAllStreamPosition.CommitPosition, response.LastAllStreamPosition.PreparePosition)), _ => StreamMessage.Unknown.Instance - }, cancellationToken).ConfigureAwait(false); + }, linkedCancellationToken).ConfigureAwait(false); } _channel.Writer.Complete(); @@ -326,16 +346,21 @@ await _channel.Writer.WriteAsync(StreamMessage.Ok.Instance, cancellationToken) /// public async IAsyncEnumerator GetAsyncEnumerator( CancellationToken cancellationToken = default) { - await foreach (var message in _channel.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) { - if (message is StreamMessage.NotFound) { - throw new StreamNotFoundException(StreamName); - } - if (message is not StreamMessage.Event e) { - continue; - } + try { + await foreach (var message in _channel.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) { + if (message is StreamMessage.NotFound) { + throw new StreamNotFoundException(StreamName); + } + + if (message is not StreamMessage.Event e) { + continue; + } - yield return e.ResolvedEvent; + yield return e.ResolvedEvent; + } + } finally { + _cts.Cancel(); } } }