Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix System.Text.Json IAsyncEnumerator disposal on cancellation #57505

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,22 @@ protected override bool OnWriteResume(Utf8JsonWriter writer, TAsyncEnumerable va
IAsyncEnumerator<TElement> enumerator;
ValueTask<bool> moveNextTask;

if (state.Current.AsyncEnumerator is null)
if (state.Current.AsyncDisposable is null)
{
enumerator = value.GetAsyncEnumerator(state.CancellationToken);
// async enumerators can only be disposed asynchronously;
// store in the WriteStack for future disposal
// by the root async serialization context.
state.Current.AsyncDisposable = enumerator;
// enumerator.MoveNextAsync() calls can throw,
// ensure the enumerator already is stored
// in the WriteStack for proper disposal.
moveNextTask = enumerator.MoveNextAsync();
// we always need to attach the enumerator to the stack
// since it will need to be disposed asynchronously.
state.Current.AsyncEnumerator = enumerator;
}
else
{
Debug.Assert(state.Current.AsyncEnumerator is IAsyncEnumerator<TElement>);
enumerator = (IAsyncEnumerator<TElement>)state.Current.AsyncEnumerator;
Debug.Assert(state.Current.AsyncDisposable is IAsyncEnumerator<TElement>);
enumerator = (IAsyncEnumerator<TElement>)state.Current.AsyncDisposable;

if (state.Current.AsyncEnumeratorIsPendingCompletion)
{
Expand All @@ -84,6 +88,10 @@ protected override bool OnWriteResume(Utf8JsonWriter writer, TAsyncEnumerable va
{
if (!moveNextTask.Result)
{
// we have completed serialization for the enumerator,
// clear from the stack and schedule for async disposal.
state.Current.AsyncDisposable = null;
state.AddCompletedAsyncDisposable(enumerator);
return true;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,35 +325,38 @@ private static async Task WriteStreamAsync<TValue>(
try
{
isFinalBlock = WriteCore(converter, writer, value, options, ref state);
await bufferWriter.WriteToStreamAsync(utf8Json, cancellationToken).ConfigureAwait(false);
bufferWriter.Clear();
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
}
finally
{
if (state.PendingAsyncDisposables?.Count > 0)
// Await any pending resumable converter tasks (currently these can only be IAsyncEnumerator.MoveNextAsync() tasks).
// Note that pending tasks are always awaited, even if an exception has been thrown or the cancellation token has fired.
if (state.PendingTask is not null)
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
{
await state.DisposePendingAsyncDisposables().ConfigureAwait(false);
try
{
await state.PendingTask.ConfigureAwait(false);
}
catch
{
// Exceptions should only be propagated by the resuming converter
// TODO https://github.com/dotnet/runtime/issues/22144
}
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
}
}

await bufferWriter.WriteToStreamAsync(utf8Json, cancellationToken).ConfigureAwait(false);
bufferWriter.Clear();

if (state.PendingTask is not null)
{
try
{
await state.PendingTask.ConfigureAwait(false);
}
catch
// Dispose any pending async disposables (currently these can only be completed IAsyncEnumerators).
if (state.CompletedAsyncDisposables?.Count > 0)
{
// Exceptions will be propagated elsewhere
// TODO https://github.com/dotnet/runtime/issues/22144
await state.DisposeCompletedAsyncDisposables().ConfigureAwait(false);
}
}

} while (!isFinalBlock);
}
catch
{
// On exception, walk the WriteStack for any orphaned disposables and try to dispose them.
await state.DisposePendingDisposablesOnExceptionAsync().ConfigureAwait(false);
throw;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ internal struct WriteStack
public Task? PendingTask;

/// <summary>
/// List of IAsyncDisposables that have been scheduled for disposal by converters.
/// List of completed IAsyncDisposables that have been scheduled for disposal by converters.
/// </summary>
public List<IAsyncDisposable>? PendingAsyncDisposables;
public List<IAsyncDisposable>? CompletedAsyncDisposables;

/// <summary>
/// The amount of bytes to write before the underlying Stream should be flushed and the
Expand Down Expand Up @@ -196,28 +196,23 @@ public void Pop(bool success)
{
Debug.Assert(_continuationCount == 0);

if (Current.AsyncEnumerator is not null)
{
// we have completed serialization of an AsyncEnumerator,
// pop from the stack and schedule for async disposal.
PendingAsyncDisposables ??= new List<IAsyncDisposable>();
PendingAsyncDisposables.Add(Current.AsyncEnumerator);
}

if (--_count > 0)
{
Current = _stack[_count - 1];
}
}
}

public void AddCompletedAsyncDisposable(IAsyncDisposable asyncDisposable)
=> (CompletedAsyncDisposables ??= new List<IAsyncDisposable>()).Add(asyncDisposable);

// Asynchronously dispose of any AsyncDisposables that have been scheduled for disposal
public async ValueTask DisposePendingAsyncDisposables()
public async ValueTask DisposeCompletedAsyncDisposables()
{
Debug.Assert(PendingAsyncDisposables?.Count > 0);
Debug.Assert(CompletedAsyncDisposables?.Count > 0);
Exception? exception = null;

foreach (IAsyncDisposable asyncDisposable in PendingAsyncDisposables)
foreach (IAsyncDisposable asyncDisposable in CompletedAsyncDisposables)
{
try
{
Expand All @@ -234,7 +229,7 @@ public async ValueTask DisposePendingAsyncDisposables()
ExceptionDispatchInfo.Capture(exception).Throw();
}

PendingAsyncDisposables.Clear();
CompletedAsyncDisposables.Clear();
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
}

/// <summary>
Expand All @@ -245,13 +240,13 @@ public void DisposePendingDisposablesOnException()
{
Exception? exception = null;

Debug.Assert(Current.AsyncEnumerator is null);
Debug.Assert(Current.AsyncDisposable is null);
DisposeFrame(Current.CollectionEnumerator, ref exception);

int stackSize = Math.Max(_count, _continuationCount);
for (int i = 0; i < stackSize - 1; i++)
{
Debug.Assert(_stack[i].AsyncEnumerator is null);
Debug.Assert(_stack[i].AsyncDisposable is null);
DisposeFrame(_stack[i].CollectionEnumerator, ref exception);
}

Expand Down Expand Up @@ -284,12 +279,12 @@ public async ValueTask DisposePendingDisposablesOnExceptionAsync()
{
Exception? exception = null;

exception = await DisposeFrame(Current.CollectionEnumerator, Current.AsyncEnumerator, exception).ConfigureAwait(false);
exception = await DisposeFrame(Current.CollectionEnumerator, Current.AsyncDisposable, exception).ConfigureAwait(false);

int stackSize = Math.Max(_count, _continuationCount);
for (int i = 0; i < stackSize - 1; i++)
{
exception = await DisposeFrame(_stack[i].CollectionEnumerator, _stack[i].AsyncEnumerator, exception).ConfigureAwait(false);
exception = await DisposeFrame(_stack[i].CollectionEnumerator, _stack[i].AsyncDisposable, exception).ConfigureAwait(false);
}

if (exception is not null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ internal struct WriteStackFrame
/// <summary>
/// The enumerator for resumable async disposables.
/// </summary>
public IAsyncDisposable? AsyncEnumerator;
public IAsyncDisposable? AsyncDisposable;

/// <summary>
/// The current stackframe has suspended serialization due to a pending task,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,24 @@ public async Task WriteNestedAsyncEnumerable_DTO<TElement>(IEnumerable<TElement>
Assert.Equal(1, asyncEnumerable.TotalDisposedEnumerators);
}

[Fact, OuterLoop]
public async Task WriteAsyncEnumerable_LongRunningEnumeration_Cancellation()
[Theory, OuterLoop]
[InlineData(5000, 1000, true)]
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
[InlineData(5000, 1000, false)]
[InlineData(1000, 10_000, true)]
[InlineData(1000, 10_000, false)]
public async Task WriteAsyncEnumerable_LongRunningEnumeration_Cancellation(
int cancellationTokenSourceDelayMilliseconds,
int enumeratorDelayMilliseconds,
bool passCancellationTokenToDelayTask)
{
var longRunningEnumerable = new MockedAsyncEnumerable<int>(
source: Enumerable.Range(1, 100),
source: Enumerable.Range(1, 1000),
delayInterval: 1,
delay: TimeSpan.FromMinutes(1));
delay: TimeSpan.FromMilliseconds(enumeratorDelayMilliseconds),
passCancellationTokenToDelayTask);

using var utf8Stream = new Utf8MemoryStream();
using var cts = new CancellationTokenSource(delay: TimeSpan.FromSeconds(5));
using var cts = new CancellationTokenSource(delay: TimeSpan.FromMilliseconds(cancellationTokenSourceDelayMilliseconds));
await Assert.ThrowsAsync<TaskCanceledException>(async () =>
await JsonSerializer.SerializeAsync(utf8Stream, longRunningEnumerable, cancellationToken: cts.Token));

Expand Down Expand Up @@ -225,21 +233,42 @@ public static IEnumerable<object[]> GetAsyncEnumerableSources()
static object[] WrapArgs<TSource>(IEnumerable<TSource> source, int delayInterval, int bufferSize) => new object[]{ source, delayInterval, bufferSize };
}

private class MockedAsyncEnumerable<TElement> : IAsyncEnumerable<TElement>, IEnumerable<TElement>
[Fact]
public async Task RegressionTest_DisposingEnumeratorOnPendingMoveNextAsyncOperation()
{
// Regression test for https://github.com/dotnet/runtime/issues/57360
using var stream = new Utf8MemoryStream();
using var cts = new CancellationTokenSource(millisecondsDelay: 1000);
await Assert.ThrowsAsync<TaskCanceledException>(async () => await JsonSerializer.SerializeAsync(stream, GetNumbersAsync(), cancellationToken: cts.Token));

static async IAsyncEnumerable<int> GetNumbersAsync()
{
int i = 0;
while (true)
{
await Task.Delay(100);
yield return i++;
}
}
}

public class MockedAsyncEnumerable<TElement> : IAsyncEnumerable<TElement>, IEnumerable<TElement>
{
private readonly IEnumerable<TElement> _source;
private readonly TimeSpan _delay;
private readonly int _delayInterval;
private readonly bool _passCancellationTokenToDelayTask;

internal int TotalCreatedEnumerators { get; private set; }
internal int TotalDisposedEnumerators { get; private set; }
internal int TotalEnumeratedElements { get; private set; }
public int TotalCreatedEnumerators { get; private set; }
public int TotalDisposedEnumerators { get; private set; }
public int TotalEnumeratedElements { get; private set; }

public MockedAsyncEnumerable(IEnumerable<TElement> source, int delayInterval = 0, TimeSpan? delay = null)
public MockedAsyncEnumerable(IEnumerable<TElement> source, int delayInterval = 0, TimeSpan? delay = null, bool passCancellationTokenToDelayTask = true)
{
_source = source;
_delay = delay ?? TimeSpan.FromMilliseconds(20);
_delayInterval = delayInterval;
_passCancellationTokenToDelayTask = passCancellationTokenToDelayTask;
}

public IAsyncEnumerator<TElement> GetAsyncEnumerator(CancellationToken cancellationToken = default)
Expand Down Expand Up @@ -277,7 +306,7 @@ private async IAsyncEnumerator<TElement> GetAsyncEnumeratorInner(CancellationTok
{
if (i > 0 && _delayInterval > 0 && i % _delayInterval == 0)
{
await Task.Delay(_delay, cancellationToken);
await Task.Delay(_delay, _passCancellationTokenToDelayTask ? cancellationToken : default);
}

if (cancellationToken.IsCancellationRequested)
Expand Down