Skip to content

Commit

Permalink
Dedup UnboundedChannel and UnboundedPriorityChannel (dotnet#101396)
Browse files Browse the repository at this point in the history
* Dedup UnboundedChannel and UnboundedPriorityChannel

We can use generic specialization to avoid duplicating all the code for the different queue types. This should also make it much simpler to add other queue types in the future.

* Address PR feedback
  • Loading branch information
stephentoub authored and michaelgsharp committed May 8, 2024
1 parent 4b297b3 commit b965caf
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 402 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ System.Threading.Channel&lt;T&gt;</PackageDescription>
<Compile Include="System\Threading\Channels\Channel_1.cs" />
<Compile Include="System\Threading\Channels\Channel_2.cs" />
<Compile Include="System\Threading\Channels\IDebugEnumerator.cs" />
<Compile Include="System\Threading\Channels\IUnboundedChannelQueue.cs" />
<Compile Include="System\Threading\Channels\SingleConsumerUnboundedChannel.cs" />
<Compile Include="System\Threading\Channels\UnboundedChannel.cs" />
<Compile Include="$(CommonPath)Internal\Padding.cs" Link="Common\Internal\Padding.cs" />
Expand All @@ -44,7 +45,6 @@ System.Threading.Channel&lt;T&gt;</PackageDescription>
<Compile Include="System\Threading\Channels\AsyncOperation.netcoreapp.cs" />
<Compile Include="System\Threading\Channels\Channel.netcoreapp.cs" />
<Compile Include="System\Threading\Channels\ChannelOptions.netcoreapp.cs" />
<Compile Include="System\Threading\Channels\UnboundedPriorityChannel.cs" />
</ItemGroup>

<ItemGroup Condition="'$(TargetFramework)' == '$(NetCoreAppCurrent)'">
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;

namespace System.Threading.Channels
{
/// <summary>Provides static methods for creating channels.</summary>
Expand All @@ -9,7 +13,7 @@ public static partial class Channel
/// <summary>Creates an unbounded channel usable by any number of readers and writers concurrently.</summary>
/// <returns>The created channel.</returns>
public static Channel<T> CreateUnbounded<T>() =>
new UnboundedChannel<T>(runContinuationsAsynchronously: true);
new UnboundedChannel<T, UnboundedChannelConcurrentQueue<T>>(new(new()), runContinuationsAsynchronously: true);

/// <summary>Creates an unbounded channel subject to the provided options.</summary>
/// <typeparam name="T">Specifies the type of data in the channel.</typeparam>
Expand All @@ -27,7 +31,7 @@ public static Channel<T> CreateUnbounded<T>(UnboundedChannelOptions options)
return new SingleConsumerUnboundedChannel<T>(!options.AllowSynchronousContinuations);
}

return new UnboundedChannel<T>(!options.AllowSynchronousContinuations);
return new UnboundedChannel<T, UnboundedChannelConcurrentQueue<T>>(new(new()), !options.AllowSynchronousContinuations);
}

/// <summary>Creates a channel with the specified maximum capacity.</summary>
Expand Down Expand Up @@ -71,5 +75,32 @@ public static Channel<T> CreateBounded<T>(BoundedChannelOptions options, Action<

return new BoundedChannel<T>(options.Capacity, options.FullMode, !options.AllowSynchronousContinuations, itemDropped);
}

/// <summary>Provides an <see cref="IUnboundedChannelQueue{T}"/> for a <see cref="ConcurrentQueue{T}"/>.</summary>
private readonly struct UnboundedChannelConcurrentQueue<T>(ConcurrentQueue<T> queue) : IUnboundedChannelQueue<T>
{
private readonly ConcurrentQueue<T> _queue = queue;

/// <inheritdoc/>
public bool IsThreadSafe => true;

/// <inheritdoc/>
public void Enqueue(T item) => _queue.Enqueue(item);

/// <inheritdoc/>
public bool TryDequeue([MaybeNullWhen(false)] out T item) => _queue.TryDequeue(out item);

/// <inheritdoc/>
public bool TryPeek([MaybeNullWhen(false)] out T item) => _queue.TryPeek(out item);

/// <inheritdoc/>
public int Count => _queue.Count;

/// <inheritdoc/>
public bool IsEmpty => _queue.IsEmpty;

/// <inheritdoc/>
public IEnumerator<T> GetEnumerator() => _queue.GetEnumerator();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,22 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;

namespace System.Threading.Channels
{
/// <summary>Provides static methods for creating channels.</summary>
public static partial class Channel
{
/// <summary>Creates an unbounded prioritized channel usable by any number of readers and writers concurrently.</summary>
/// <typeparam name="T">Specifies the type of data in the channel.</typeparam>
/// <returns>The created channel.</returns>
/// <remarks>
/// <see cref="Comparer{T}.Default"/> is used to determine priority of elements.
/// The next item read from the channel will be the element available in the channel with the lowest priority value.
/// </remarks>
public static Channel<T> CreateUnboundedPrioritized<T>() =>
new UnboundedPrioritizedChannel<T>(runContinuationsAsynchronously: true, comparer: null);
new UnboundedChannel<T, UnboundedChannelPriorityQueue<T>>(new(new()), runContinuationsAsynchronously: true);

/// <summary>Creates an unbounded prioritized channel subject to the provided options.</summary>
/// <typeparam name="T">Specifies the type of data in the channel.</typeparam>
Expand All @@ -30,7 +32,45 @@ public static Channel<T> CreateUnboundedPrioritized<T>(UnboundedPrioritizedChann
{
ArgumentNullException.ThrowIfNull(options);

return new UnboundedPrioritizedChannel<T>(!options.AllowSynchronousContinuations, options.Comparer);
return new UnboundedChannel<T, UnboundedChannelPriorityQueue<T>>(new(new(options.Comparer)), !options.AllowSynchronousContinuations);
}

/// <summary>Provides an <see cref="IUnboundedChannelQueue{T}"/> for a <see cref="PriorityQueue{TElement, TPriority}"/>.</summary>
private readonly struct UnboundedChannelPriorityQueue<T>(PriorityQueue<bool, T> queue) : IUnboundedChannelQueue<T>
{
private readonly PriorityQueue<bool, T> _queue = queue;

/// <inheritdoc/>
public bool IsThreadSafe => false;

/// <inheritdoc/>
public void Enqueue(T item) => _queue.Enqueue(true, item);

/// <inheritdoc/>
public bool TryDequeue([MaybeNullWhen(false)] out T item) => _queue.TryDequeue(out _, out item);

/// <inheritdoc/>
public bool TryPeek([MaybeNullWhen(false)] out T item) => _queue.TryPeek(out _, out item);

/// <inheritdoc/>
public int Count => _queue.Count;

/// <inheritdoc/>
public bool IsEmpty => _queue.Count == 0;

/// <inheritdoc/>
public IEnumerator<T> GetEnumerator()
{
List<T> list = [];
foreach ((bool _, T Priority) item in _queue.UnorderedItems)
{
list.Add(item.Priority);
}

list.Sort(_queue.Comparer);

return list.GetEnumerator();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ internal interface IDebugEnumerable<T>
IEnumerator<T> GetEnumerator();
}

internal sealed class DebugEnumeratorDebugView<T>
internal class DebugEnumeratorDebugView<T>
{
public DebugEnumeratorDebugView(IDebugEnumerable<T> enumerable)
{
Expand All @@ -26,4 +26,6 @@ public DebugEnumeratorDebugView(IDebugEnumerable<T> enumerable)
[DebuggerBrowsable(DebuggerBrowsableState.RootHidden)]
public T[] Items { get; }
}

internal sealed class DebugEnumeratorDebugView<T, TOther>(IDebugEnumerable<T> enumerable) : DebugEnumeratorDebugView<T>(enumerable);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;

namespace System.Threading.Channels
{
/// <summary>Representation of the queue data structure used by <see cref="UnboundedChannel{T, TQueue}"/>.</summary>
internal interface IUnboundedChannelQueue<T> : IDebugEnumerable<T>
{
/// <summary>Gets whether the other members are safe to use concurrently with each other and themselves.</summary>
bool IsThreadSafe { get; }

/// <summary>Enqueues an item into the queue.</summary>
/// <param name="item">The item to enqueue.</param>
void Enqueue(T item);

/// <summary>Dequeues an item from the queue, if possible.</summary>
/// <param name="item">The dequeued item, or default if the queue was empty.</param>
/// <returns>Whether an item was dequeued.</returns>
bool TryDequeue([MaybeNullWhen(false)] out T item);

/// <summary>Peeks at the next item from the queue that would be dequeued, if possible.</summary>
/// <param name="item">The peeked item, or default if the queue was empty.</param>
/// <returns>Whether an item was peeked.</returns>
bool TryPeek([MaybeNullWhen(false)] out T item);

/// <summary>Gets the number of elements in the queue.</summary>
int Count { get; }

/// <summary>Gets whether the queue is empty.</summary>
bool IsEmpty { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

namespace System.Threading.Channels
{
/// <summary>Provides a buffered channel of unbounded capacity.</summary>
[DebuggerDisplay("Items = {ItemsCountForDebugger}, Closed = {ChannelIsClosedForDebugger}")]
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))]
internal sealed class UnboundedChannel<T> : Channel<T>, IDebugEnumerable<T>
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<,>))]
internal sealed class UnboundedChannel<T, TQueue> : Channel<T>, IDebugEnumerable<T> where TQueue : struct, IUnboundedChannelQueue<T>
{
/// <summary>Task that indicates the channel has completed.</summary>
private readonly TaskCompletionSource _completion;
/// <summary>The items in the channel.</summary>
private readonly ConcurrentQueue<T> _items = new ConcurrentQueue<T>();
private readonly TQueue _items;
/// <summary>Readers blocked reading from the channel.</summary>
private readonly Deque<AsyncOperation<T>> _blockedReaders = new Deque<AsyncOperation<T>>();
/// <summary>Whether to force continuations to be executed asynchronously from producer writes.</summary>
Expand All @@ -29,23 +30,24 @@ internal sealed class UnboundedChannel<T> : Channel<T>, IDebugEnumerable<T>
private Exception? _doneWriting;

/// <summary>Initialize the channel.</summary>
internal UnboundedChannel(bool runContinuationsAsynchronously)
internal UnboundedChannel(TQueue items, bool runContinuationsAsynchronously)
{
_items = items;
_runContinuationsAsynchronously = runContinuationsAsynchronously;
_completion = new TaskCompletionSource(runContinuationsAsynchronously ? TaskCreationOptions.RunContinuationsAsynchronously : TaskCreationOptions.None);
Reader = new UnboundedChannelReader(this);
Writer = new UnboundedChannelWriter(this);
}

[DebuggerDisplay("Items = {Count}")]
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))]
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<,>))]
private sealed class UnboundedChannelReader : ChannelReader<T>, IDebugEnumerable<T>
{
internal readonly UnboundedChannel<T> _parent;
internal readonly UnboundedChannel<T, TQueue> _parent;
private readonly AsyncOperation<T> _readerSingleton;
private readonly AsyncOperation<bool> _waiterSingleton;

internal UnboundedChannelReader(UnboundedChannel<T> parent)
internal UnboundedChannelReader(UnboundedChannel<T, TQueue> parent)
{
_parent = parent;
_readerSingleton = new AsyncOperation<T>(parent._runContinuationsAsynchronously, pooled: true);
Expand All @@ -68,8 +70,8 @@ public override ValueTask<T> ReadAsync(CancellationToken cancellationToken)
}

// Dequeue an item if we can.
UnboundedChannel<T> parent = _parent;
if (parent._items.TryDequeue(out T? item))
UnboundedChannel<T, TQueue> parent = _parent;
if (parent._items.IsThreadSafe && parent._items.TryDequeue(out T? item))
{
CompleteIfDone(parent);
return new ValueTask<T>(item);
Expand Down Expand Up @@ -112,24 +114,60 @@ public override ValueTask<T> ReadAsync(CancellationToken cancellationToken)

public override bool TryRead([MaybeNullWhen(false)] out T item)
{
UnboundedChannel<T> parent = _parent;
UnboundedChannel<T, TQueue> parent = _parent;
return parent._items.IsThreadSafe ?
LockFree(parent, out item) :
Locked(parent, out item);

// Dequeue an item if we can
if (parent._items.TryDequeue(out item))
static bool LockFree(UnboundedChannel<T, TQueue> parent, [MaybeNullWhen(false)] out T item)
{
CompleteIfDone(parent);
return true;
if (parent._items.TryDequeue(out item))
{
CompleteIfDone(parent);
return true;
}

item = default;
return false;
}

item = default;
return false;
static bool Locked(UnboundedChannel<T, TQueue> parent, [MaybeNullWhen(false)] out T item)
{
lock (parent.SyncObj)
{
if (parent._items.TryDequeue(out item))
{
CompleteIfDone(parent);
return true;
}
}

item = default;
return false;
}
}

public override bool TryPeek([MaybeNullWhen(false)] out T item) =>
_parent._items.TryPeek(out item);
public override bool TryPeek([MaybeNullWhen(false)] out T item)
{
UnboundedChannel<T, TQueue> parent = _parent;
return parent._items.IsThreadSafe ?
parent._items.TryPeek(out item) :
Locked(parent, out item);

// Separated out to keep the try/finally from preventing TryPeek from being inlined
static bool Locked(UnboundedChannel<T, TQueue> parent, [MaybeNullWhen(false)] out T item)
{
lock (parent.SyncObj)
{
return parent._items.TryPeek(out item);
}
}
}

private static void CompleteIfDone(UnboundedChannel<T> parent)
private static void CompleteIfDone(UnboundedChannel<T, TQueue> parent)
{
Debug.Assert(parent._items.IsThreadSafe || Monitor.IsEntered(parent.SyncObj));

if (parent._doneWriting != null && parent._items.IsEmpty)
{
// If we've now emptied the items queue and we're not getting any more, complete.
Expand All @@ -144,12 +182,12 @@ public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationTo
return new ValueTask<bool>(Task.FromCanceled<bool>(cancellationToken));
}

if (!_parent._items.IsEmpty)
if (_parent._items.IsThreadSafe && !_parent._items.IsEmpty)
{
return new ValueTask<bool>(true);
}

UnboundedChannel<T> parent = _parent;
UnboundedChannel<T, TQueue> parent = _parent;

lock (parent.SyncObj)
{
Expand Down Expand Up @@ -192,15 +230,15 @@ public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationTo
}

[DebuggerDisplay("Items = {ItemsCountForDebugger}")]
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<>))]
[DebuggerTypeProxy(typeof(DebugEnumeratorDebugView<,>))]
private sealed class UnboundedChannelWriter : ChannelWriter<T>, IDebugEnumerable<T>
{
internal readonly UnboundedChannel<T> _parent;
internal UnboundedChannelWriter(UnboundedChannel<T> parent) => _parent = parent;
internal readonly UnboundedChannel<T, TQueue> _parent;
internal UnboundedChannelWriter(UnboundedChannel<T, TQueue> parent) => _parent = parent;

public override bool TryComplete(Exception? error)
{
UnboundedChannel<T> parent = _parent;
UnboundedChannel<T, TQueue> parent = _parent;
bool completeTask;

lock (parent.SyncObj)
Expand Down Expand Up @@ -240,7 +278,7 @@ public override bool TryComplete(Exception? error)

public override bool TryWrite(T item)
{
UnboundedChannel<T> parent = _parent;
UnboundedChannel<T, TQueue> parent = _parent;
while (true)
{
AsyncOperation<T>? blockedReader = null;
Expand Down Expand Up @@ -321,7 +359,7 @@ public override ValueTask WriteAsync(T item, CancellationToken cancellationToken
}

/// <summary>Gets the object used to synchronize access to all state on this instance.</summary>
private object SyncObj => _items;
private object SyncObj => _blockedReaders;

[Conditional("DEBUG")]
private void AssertInvariants()
Expand Down
Loading

0 comments on commit b965caf

Please sign in to comment.