Skip to content

Commit

Permalink
Add DataflowBlock.ReceiveAllAsync extension method (#37876)
Browse files Browse the repository at this point in the history
  • Loading branch information
manandre authored Mar 6, 2021
1 parent 11cd272 commit 50494e0
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 5 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>netstandard2.0;netstandard1.0;netstandard1.1;net461</TargetFrameworks>
<TargetFrameworks>netstandard2.1;netstandard2.0;netstandard1.0;netstandard1.1;net461</TargetFrameworks>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<Compile Include="System.Threading.Tasks.Dataflow.cs" />
</ItemGroup>
</Project>
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.1'">
<Compile Include="System.Threading.Tasks.Dataflow.netstandard21.cs" />
</ItemGroup>
</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// ------------------------------------------------------------------------------
// Changes to this file must follow the https://aka.ms/api-review process.
// ------------------------------------------------------------------------------

namespace System.Threading.Tasks.Dataflow
{
public static partial class DataflowBlock
{
public static System.Collections.Generic.IAsyncEnumerable<TOutput> ReceiveAllAsync<TOutput>(this System.Threading.Tasks.Dataflow.IReceivableSourceBlock<TOutput> source, System.Threading.CancellationToken cancellationToken = default) { throw null; }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Runtime.CompilerServices;

namespace System.Threading.Tasks.Dataflow
{
public static partial class DataflowBlock
{
/// <summary>Creates an <see cref="IAsyncEnumerable{TOutput}"/> that enables receiving all of the data from the source.</summary>
/// <typeparam name="TOutput">Specifies the type of data contained in the source.</typeparam>
/// <param name="source">The source from which to asynchronously receive.</param>
/// <param name="cancellationToken">The <see cref="System.Threading.CancellationToken"/> which may be used to cancel the receive operation.</param>
/// <returns>The created async enumerable.</returns>
/// <exception cref="System.ArgumentNullException">The <paramref name="source"/> is null (Nothing in Visual Basic).</exception>
public static IAsyncEnumerable<TOutput> ReceiveAllAsync<TOutput>(this IReceivableSourceBlock<TOutput> source, CancellationToken cancellationToken = default)
{
if (source == null)
{
throw new ArgumentNullException(nameof(source));
}

return ReceiveAllAsyncCore(source, cancellationToken);

static async IAsyncEnumerable<TOutput> ReceiveAllAsyncCore(IReceivableSourceBlock<TOutput> source, [EnumeratorCancellation] CancellationToken cancellationToken)
{
while (await source.OutputAvailableAsync(cancellationToken).ConfigureAwait(false))
{
while (source.TryReceive(out TOutput? item))
{
yield return item;
}
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace System.Threading.Tasks.Dataflow
/// <summary>
/// Provides a set of static (Shared in Visual Basic) methods for working with dataflow blocks.
/// </summary>
public static class DataflowBlock
public static partial class DataflowBlock
{
#region LinkTo
/// <summary>Links the <see cref="ISourceBlock{TOutput}"/> to the specified <see cref="ITargetBlock{TOutput}"/>.</summary>
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFrameworks>$(NetCoreAppCurrent);netstandard2.0;netstandard1.0;netstandard1.1;net461</TargetFrameworks>
<TargetFrameworks>$(NetCoreAppCurrent);netstandard2.1;netstandard2.0;netstandard1.0;netstandard1.1;net461</TargetFrameworks>
<ExcludeCurrentNetCoreAppFromPackage>true</ExcludeCurrentNetCoreAppFromPackage>
<Nullable>enable</Nullable>
</PropertyGroup>
Expand Down Expand Up @@ -55,6 +55,9 @@
<Compile Include="Internal\ConcurrentQueue.cs" />
<Compile Include="Internal\IProducerConsumerCollection.cs" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'netstandard2.1' or '$(TargetFramework)' == '$(NetCoreAppCurrent)'">
<Compile Include="Base\DataflowBlock.IAsyncEnumerable.cs" />
</ItemGroup>
<ItemGroup>
<None Include="XmlDocs\CommonXmlDocComments.xml" />
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using Xunit;

namespace System.Threading.Tasks.Dataflow.Tests
{
public partial class DataflowBlockExtensionTests
{
[Fact]
public void ReceiveAllAsync_ArgumentValidation()
{
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IReceivableSourceBlock<int>)null).ReceiveAllAsync());
AssertExtensions.Throws<ArgumentNullException>("source", () => ((IReceivableSourceBlock<int>)null).ReceiveAllAsync(new CancellationToken(true)));
}

[Fact]
public void ReceiveAllAsync_NotIdempotent()
{
var source = new BufferBlock<int>();
IAsyncEnumerable<int> e = source.ReceiveAllAsync();
Assert.NotNull(e);
Assert.NotSame(e, source.ReceiveAllAsync());
}

[Fact]
public async Task ReceiveAllAsync_UseMoveNextAsyncAfterCompleted_ReturnsFalse()
{
var source = new BufferBlock<int>();
IAsyncEnumerator<int> e = source.ReceiveAllAsync().GetAsyncEnumerator();

ValueTask<bool> vt = e.MoveNextAsync();
Assert.False(vt.IsCompleted);
source.Complete();
Assert.False(await vt);

vt = e.MoveNextAsync();
Assert.True(vt.IsCompletedSuccessfully);
Assert.False(vt.Result);
}

[Fact]
public void ReceiveAllAsync_AvailableDataCompletesSynchronously()
{
var source = new BufferBlock<int>();

IAsyncEnumerator<int> e = source.ReceiveAllAsync().GetAsyncEnumerator();
try
{
for (int i = 100; i < 110; i++)
{
Assert.True(source.Post(i));
ValueTask<bool> vt = e.MoveNextAsync();
Assert.True(vt.IsCompletedSuccessfully);
Assert.True(vt.Result);
Assert.Equal(i, e.Current);
}
}
finally
{
ValueTask vt = e.DisposeAsync();
Assert.True(vt.IsCompletedSuccessfully);
vt.GetAwaiter().GetResult();
}
}

[Fact]
public async Task ReceiveAllAsync_UnavailableDataCompletesAsynchronously()
{
var source = new BufferBlock<int>();

IAsyncEnumerator<int> e = source.ReceiveAllAsync().GetAsyncEnumerator();
try
{
for (int i = 100; i < 110; i++)
{
ValueTask<bool> vt = e.MoveNextAsync();
Assert.False(vt.IsCompleted);
Task producer = Task.Run(() => source.Post(i));
Assert.True(await vt);
await producer;
Assert.Equal(i, e.Current);
}
}
finally
{
ValueTask vt = e.DisposeAsync();
Assert.True(vt.IsCompletedSuccessfully);
vt.GetAwaiter().GetResult();
}
}

[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(128)]
public async Task ReceiveAllAsync_ProducerConsumer_ConsumesAllData(int items)
{
var source = new BufferBlock<int>();

int producedTotal = 0, consumedTotal = 0;
await Task.WhenAll(
Task.Run(async () =>
{
for (int i = 0; i < items; i++)
{
await source.SendAsync(i);
producedTotal += i;
}
source.Complete();
}),
Task.Run(async () =>
{
IAsyncEnumerator<int> e = source.ReceiveAllAsync().GetAsyncEnumerator();
try
{
while (await e.MoveNextAsync())
{
consumedTotal += e.Current;
}
}
finally
{
await e.DisposeAsync();
}
}));

Assert.Equal(producedTotal, consumedTotal);
}

[Fact]
public async Task ReceiveAllAsync_MultipleEnumerationsToEnd()
{
var source = new BufferBlock<int>();

Assert.True(source.Post(42));
source.Complete();

IAsyncEnumerable<int> enumerable = source.ReceiveAllAsync();
IAsyncEnumerator<int> e = enumerable.GetAsyncEnumerator();

Assert.True(await e.MoveNextAsync());
Assert.Equal(42, e.Current);

Assert.False(await e.MoveNextAsync());
Assert.False(await e.MoveNextAsync());

await e.DisposeAsync();

e = enumerable.GetAsyncEnumerator();

Assert.False(await e.MoveNextAsync());
Assert.False(await e.MoveNextAsync());
}

[Theory]
[InlineData(false, false)]
[InlineData(false, true)]
[InlineData(true, false)]
[InlineData(true, true)]
public void ReceiveAllAsync_MultipleSingleElementEnumerations_AllItemsEnumerated(bool sameEnumerable, bool dispose)
{
var source = new BufferBlock<int>();
IAsyncEnumerable<int> enumerable = source.ReceiveAllAsync();

for (int i = 0; i < 10; i++)
{
Assert.True(source.Post(i));
IAsyncEnumerator<int> e = (sameEnumerable ? enumerable : source.ReceiveAllAsync()).GetAsyncEnumerator();
ValueTask<bool> vt = e.MoveNextAsync();
Assert.True(vt.IsCompletedSuccessfully);
Assert.True(vt.Result);
Assert.Equal(i, e.Current);
if (dispose)
{
ValueTask dvt = e.DisposeAsync();
Assert.True(dvt.IsCompletedSuccessfully);
dvt.GetAwaiter().GetResult();
}
}
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task ReceiveAllAsync_DualConcurrentEnumeration_AllItemsEnumerated(bool sameEnumerable)
{
var source = new BufferBlock<int>();

IAsyncEnumerable<int> enumerable = source.ReceiveAllAsync();

IAsyncEnumerator<int> e1 = enumerable.GetAsyncEnumerator();
IAsyncEnumerator<int> e2 = (sameEnumerable ? enumerable : source.ReceiveAllAsync()).GetAsyncEnumerator();
Assert.NotSame(e1, e2);

ValueTask<bool> vt1, vt2;
int producerTotal = 0, consumerTotal = 0;
for (int i = 0; i < 10; i++)
{
vt1 = e1.MoveNextAsync();
vt2 = e2.MoveNextAsync();

await source.SendAsync(i);
producerTotal += i;
await source.SendAsync(i * 2);
producerTotal += i * 2;

Assert.True(await vt1);
Assert.True(await vt2);
consumerTotal += e1.Current;
consumerTotal += e2.Current;
}

vt1 = e1.MoveNextAsync();
vt2 = e2.MoveNextAsync();
source.Complete();
Assert.False(await vt1);
Assert.False(await vt2);

Assert.Equal(producerTotal, consumerTotal);
}

[Theory]
[InlineData(false)]
[InlineData(true)]
public async Task ReceiveAllAsync_CanceledBeforeMoveNextAsync_Throws(bool dataAvailable)
{
var source = new BufferBlock<int>();
if (dataAvailable)
{
Assert.True(source.Post(42));
}

using var cts = new CancellationTokenSource();
cts.Cancel();

IAsyncEnumerator<int> e = source.ReceiveAllAsync(cts.Token).GetAsyncEnumerator();
ValueTask<bool> vt = e.MoveNextAsync();
Assert.True(vt.IsCompleted);
Assert.False(vt.IsCompletedSuccessfully);
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await vt);
Assert.Equal(cts.Token, oce.CancellationToken);
}

[Fact]
public async Task ReceiveAllAsync_CanceledAfterMoveNextAsync_Throws()
{
var source = new BufferBlock<int>();
using var cts = new CancellationTokenSource();

IAsyncEnumerator<int> e = source.ReceiveAllAsync(cts.Token).GetAsyncEnumerator();
ValueTask<bool> vt = e.MoveNextAsync();
Assert.False(vt.IsCompleted);

cts.Cancel();
OperationCanceledException oce = await Assert.ThrowsAnyAsync<OperationCanceledException>(async () => await vt);

vt = e.MoveNextAsync();
Assert.True(vt.IsCompletedSuccessfully);
Assert.False(vt.Result);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace System.Threading.Tasks.Dataflow.Tests
{
public class DataflowBlockExtensionsTests
public partial class DataflowBlockExtensionsTests
{
[Fact]
public void TestDataflowMessageHeader()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
<Compile Include="$(CommonTestPath)System\Diagnostics\DebuggerAttributes.cs"
Link="Common\System\Diagnostics\DebuggerAttributes.cs" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == '$(NetCoreAppCurrent)'">
<Compile Include="Dataflow\DataflowBlockExtensionTests.IAsyncEnumerable.cs" />
</ItemGroup>
<ItemGroup Condition="'$(TargetFramework)' == 'net461'">
<ProjectReference Include="..\src\System.Threading.Tasks.Dataflow.csproj" />
</ItemGroup>
Expand Down

0 comments on commit 50494e0

Please sign in to comment.