Skip to content

Commit

Permalink
Dispose the gRPC call underlying a Read if the read is only partially…
Browse files Browse the repository at this point in the history
… consumed

e.g. when early returning from an await foreach or calling .First().

The IAsyncEnumerable is disposed as before, but now that causes the cancellation of the ResponseStream read, so PumpMessages completes and the call is disposed.

The CancellationTokenSource need not be disposed because it is always cancelled (That is, as long as the user makes any attempt to consume the read - if they don't then
the call remains open as before)
  • Loading branch information
timothycoleman committed Dec 12, 2022
1 parent 8b40e89 commit 0acc029
Showing 1 changed file with 62 additions and 37 deletions.
99 changes: 62 additions & 37 deletions src/EventStore.Client.Streams/EventStoreClient.Read.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ public ReadAllStreamResult ReadAllAsync(
/// </summary>
public class ReadAllStreamResult : IAsyncEnumerable<ResolvedEvent> {
private readonly Channel<StreamMessage> _channel;
private readonly CancellationTokenSource _cts;

private int _messagesEnumerated;

Expand All @@ -85,12 +86,16 @@ async IAsyncEnumerable<StreamMessage> GetMessages() {
throw new InvalidOperationException("Messages may only be enumerated once.");
}

await foreach (var message in _channel.Reader.ReadAllAsync().ConfigureAwait(false)) {
if (message is StreamMessage.LastAllStreamPosition(var position)) {
LastPosition = position;
}
try {
await foreach (var message in _channel.Reader.ReadAllAsync().ConfigureAwait(false)) {
if (message is StreamMessage.LastAllStreamPosition(var position)) {
LastPosition = position;
}

yield return message;
yield return message;
}
} finally {
_cts.Cancel();
}
}
}
Expand All @@ -108,14 +113,17 @@ internal ReadAllStreamResult(Func<CancellationToken, Task<CallInvoker>> selectCa
AllowSynchronousContinuations = true
});

_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
var linkedCancellationToken = _cts.Token;

_ = PumpMessages();

async Task PumpMessages() {
try {
var callInvoker = await selectCallInvoker(cancellationToken).ConfigureAwait(false);
var callInvoker = await selectCallInvoker(linkedCancellationToken).ConfigureAwait(false);
var client = new Streams.Streams.StreamsClient(callInvoker);
using var call = client.Read(request, callOptions);
await foreach (var response in call.ResponseStream.ReadAllAsync(cancellationToken)
await foreach (var response in call.ResponseStream.ReadAllAsync(linkedCancellationToken)
.ConfigureAwait(false)) {
await _channel.Writer.WriteAsync(response.ContentCase switch {
StreamNotFound => StreamMessage.NotFound.Instance,
Expand All @@ -128,7 +136,7 @@ async Task PumpMessages() {
new Position(response.LastAllStreamPosition.CommitPosition,
response.LastAllStreamPosition.PreparePosition)),
_ => StreamMessage.Unknown.Instance
}, cancellationToken).ConfigureAwait(false);
}, linkedCancellationToken).ConfigureAwait(false);
}

_channel.Writer.Complete();
Expand All @@ -141,12 +149,17 @@ async Task PumpMessages() {
/// <inheritdoc />
public async IAsyncEnumerator<ResolvedEvent> GetAsyncEnumerator(
CancellationToken cancellationToken = default) {
await foreach (var message in _channel.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) {
if (message is not StreamMessage.Event e) {
continue;
}

yield return e.ResolvedEvent;
try {
await foreach (var message in _channel.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) {
if (message is not StreamMessage.Event e) {
continue;
}

yield return e.ResolvedEvent;
}
} finally {
_cts.Cancel();
}
}
}
Expand Down Expand Up @@ -205,6 +218,7 @@ public ReadStreamResult ReadStreamAsync(
/// </summary>
public class ReadStreamResult : IAsyncEnumerable<ResolvedEvent> {
private readonly Channel<StreamMessage> _channel;
private readonly CancellationTokenSource _cts;

private int _messagesEnumerated;

Expand Down Expand Up @@ -235,19 +249,23 @@ async IAsyncEnumerable<StreamMessage> GetMessages() {
throw new InvalidOperationException("Messages may only be enumerated once.");
}

await foreach (var message in _channel.Reader.ReadAllAsync().ConfigureAwait(false)) {
switch (message) {
case StreamMessage.FirstStreamPosition(var streamPosition):
FirstStreamPosition = streamPosition;
break;
case StreamMessage.LastStreamPosition(var lastStreamPosition):
LastStreamPosition = lastStreamPosition;
break;
default:
break;
}
try {
await foreach (var message in _channel.Reader.ReadAllAsync().ConfigureAwait(false)) {
switch (message) {
case StreamMessage.FirstStreamPosition(var streamPosition):
FirstStreamPosition = streamPosition;
break;
case StreamMessage.LastStreamPosition(var lastStreamPosition):
LastStreamPosition = lastStreamPosition;
break;
default:
break;
}

yield return message;
yield return message;
}
} finally {
_cts.Cancel();
}
}
}
Expand All @@ -273,6 +291,8 @@ internal ReadStreamResult(Func<CancellationToken, Task<CallInvoker>> selectCallI
StreamName = request.Options.Stream.StreamIdentifier!;

var tcs = new TaskCompletionSource<ReadState>();
_cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
var linkedCancellationToken = _cts.Token;
#pragma warning disable CS0612
ReadState = tcs.Task;
#pragma warning restore CS0612
Expand All @@ -283,17 +303,17 @@ async Task PumpMessages() {
var firstMessageRead = false;

try {
var callInvoker = await selectCallInvoker(cancellationToken).ConfigureAwait(false);
var callInvoker = await selectCallInvoker(linkedCancellationToken).ConfigureAwait(false);
var client = new Streams.Streams.StreamsClient(callInvoker);
using var call = client.Read(request, callOptions);

await foreach (var response in call.ResponseStream.ReadAllAsync(cancellationToken)
await foreach (var response in call.ResponseStream.ReadAllAsync(linkedCancellationToken)
.ConfigureAwait(false)) {
if (!firstMessageRead) {
firstMessageRead = true;

if (response.ContentCase != StreamNotFound || request.Options.Stream == null) {
await _channel.Writer.WriteAsync(StreamMessage.Ok.Instance, cancellationToken)
await _channel.Writer.WriteAsync(StreamMessage.Ok.Instance, linkedCancellationToken)
.ConfigureAwait(false);
tcs.SetResult(Client.ReadState.Ok);
} else {
Expand All @@ -312,7 +332,7 @@ await _channel.Writer.WriteAsync(StreamMessage.Ok.Instance, cancellationToken)
new Position(response.LastAllStreamPosition.CommitPosition,
response.LastAllStreamPosition.PreparePosition)),
_ => StreamMessage.Unknown.Instance
}, cancellationToken).ConfigureAwait(false);
}, linkedCancellationToken).ConfigureAwait(false);
}

_channel.Writer.Complete();
Expand All @@ -326,16 +346,21 @@ await _channel.Writer.WriteAsync(StreamMessage.Ok.Instance, cancellationToken)
/// <inheritdoc />
public async IAsyncEnumerator<ResolvedEvent> GetAsyncEnumerator(
CancellationToken cancellationToken = default) {
await foreach (var message in _channel.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) {
if (message is StreamMessage.NotFound) {
throw new StreamNotFoundException(StreamName);
}

if (message is not StreamMessage.Event e) {
continue;
}
try {
await foreach (var message in _channel.Reader.ReadAllAsync(cancellationToken).ConfigureAwait(false)) {
if (message is StreamMessage.NotFound) {
throw new StreamNotFoundException(StreamName);
}

if (message is not StreamMessage.Event e) {
continue;
}

yield return e.ResolvedEvent;
yield return e.ResolvedEvent;
}
} finally {
_cts.Cancel();
}
}
}
Expand Down

0 comments on commit 0acc029

Please sign in to comment.