Ensure ConcurrentBag's TryTake is linearizable (dotnet#30947)
For .NET Core 2.0, I ported the ThreadPool's work-stealing implementation to ConcurrentBag, leading to significant throughput and allocation improvements.  However, there's a subtle difference between the concurrency guarantees the ThreadPool's implementation provides and those ConcurrentBag needs, which ends up breaking certain usage patterns on top of ConcurrentBag.

Specifically, ThreadPool's "steal" implementation need not be fully linearizable.  It's possible for a thread to see the bag's count as 1, and then, while the thread is doing a take/steal, for the count to never drop below 1, but for the steal to still fail, even though there was always an item available.  This is ok for the thread pool because it manages a known count of work items in the queues separately, and if it sees that there are still items available after a steal has failed, it'll try again.  That "try again" logic lives above the work-stealing queue and thus didn't make it over to ConcurrentBag, which breaks some usages of ConcurrentBag, in particular cases where a type like BlockingCollection is wrapping the bag and managing its own count.  It's possible now for BlockingCollection to know that there's an item in the bag but to then fail to take it, which causes problems such as exceptions being thrown.
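To make the affected usage pattern concrete, here is a minimal sketch of a wrapper that tracks availability with its own synchronization. `CountedBag<T>` is a hypothetical type invented for illustration; it is not BlockingCollection's actual implementation, only the same general shape: a separate count says an item exists, so a failed TryTake on the bag is treated as an error.

```csharp
using System;
using System.Collections.Concurrent;
using System.Threading;

// Hypothetical wrapper (an assumption for illustration, not BlockingCollection itself).
// It keeps its own count of available items in a semaphore and expects the
// underlying bag's TryTake to succeed whenever that count is > 0.
public sealed class CountedBag<T>
{
    private readonly ConcurrentBag<T> _bag = new ConcurrentBag<T>();
    private readonly SemaphoreSlim _available = new SemaphoreSlim(0);

    public void Add(T item)
    {
        _bag.Add(item);       // publish the item first...
        _available.Release(); // ...then advertise that it exists
    }

    public bool TryTake(out T item)
    {
        // Claim one unit of the separately tracked count without blocking.
        if (!_available.Wait(0))
        {
            item = default(T);
            return false;
        }

        // The count said an item exists, so a linearizable bag must hand one over.
        // With a non-linearizable steal this call can fail, and the wrapper throws.
        if (!_bag.TryTake(out item))
        {
            throw new InvalidOperationException(
                "Count said an item was available, but TryTake failed.");
        }
        return true;
    }
}
```

With a single thread this never throws; the failure only shows up when several threads add and take concurrently, which is the situation the test added in this commit exercises.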

The fix is to port back the relevant portion of ConcurrentBag from .NET Core 1.x / .NET Framework, where local push operations on a list track the number of times the list transitions from empty to non-empty.  A steal operation then looks at those counts prior to doing the steal, and if the steal fails, it looks again after: if the count has increased, it retries.  This unfortunately means that local pushes on small lists are now more expensive than in .NET Core 2.0/2.1, because a push takes the lock whenever there are <= 2 items in the list, but this seems unavoidable given the work-stealing design.
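The check-before, check-after idea can be sketched in a much simpler model than the real work-stealing queue. The `MultiListBag<T>` type below is hypothetical and deliberately naive: every list is a `Queue<T>` guarded by a plain lock, and the caller picks the list index explicitly instead of using per-thread queues. It only illustrates the retry rule, under the assumption that every empty to non-empty transition is counted while holding that list's lock.

```csharp
using System.Collections.Generic;
using System.Threading;

// Hypothetical simplified model (not the ConcurrentBag implementation):
// several independently locked lists plus one shared transition counter.
public sealed class MultiListBag<T>
{
    private readonly Queue<T>[] _lists;
    private long _emptyToNonEmptyTransitions;

    public MultiListBag(int listCount)
    {
        _lists = new Queue<T>[listCount];
        for (int i = 0; i < listCount; i++)
        {
            _lists[i] = new Queue<T>();
        }
    }

    public void Add(int listIndex, T item)
    {
        Queue<T> list = _lists[listIndex];
        lock (list)
        {
            bool wasEmpty = list.Count == 0;
            list.Enqueue(item);
            if (wasEmpty)
            {
                // Record the empty -> non-empty transition while still holding the lock.
                Interlocked.Increment(ref _emptyToNonEmptyTransitions);
            }
        }
    }

    public bool TryTake(out T item)
    {
        while (true)
        {
            // Snapshot the transition count before scanning the lists.
            long before = Interlocked.Read(ref _emptyToNonEmptyTransitions);

            foreach (Queue<T> list in _lists)
            {
                lock (list)
                {
                    if (list.Count > 0)
                    {
                        item = list.Dequeue();
                        return true;
                    }
                }
            }

            // Nothing found. If no list became non-empty during the scan,
            // every list was empty at the moment the last one was checked,
            // so reporting failure is consistent with some point in time.
            if (Interlocked.Read(ref _emptyToNonEmptyTransitions) == before)
            {
                item = default(T);
                return false;
            }
            // Otherwise an item may have landed in a list already scanned: retry.
        }
    }
}
```

The model ignores the lock-free fast paths entirely; the cost described above comes from the real implementation having to give up its fast path for nearly empty lists so that the same counting stays accurate.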
stephentoub committed Jul 12, 2018
1 parent 44454c1 commit 9f19219
Showing 2 changed files with 176 additions and 57 deletions.
@@ -32,9 +32,11 @@ namespace System.Collections.Concurrent
public class ConcurrentBag<T> : IProducerConsumerCollection<T>, IReadOnlyCollection<T>
{
/// <summary>The per-bag, per-thread work-stealing queues.</summary>
private ThreadLocal<WorkStealingQueue> _locals;
private readonly ThreadLocal<WorkStealingQueue> _locals;
/// <summary>The head work stealing queue in a linked list of queues.</summary>
private volatile WorkStealingQueue _workStealingQueues;
/// <summary>Number of times any list transitions from empty to non-empty.</summary>
private long _emptyToNonEmptyListTransitionCount;

/// <summary>Initializes a new instance of the <see cref="ConcurrentBag{T}"/> class.</summary>
public ConcurrentBag()
@@ -62,7 +64,7 @@ public ConcurrentBag(IEnumerable<T> collection)
WorkStealingQueue queue = GetCurrentThreadWorkStealingQueue(forceCreate: true);
foreach (T item in collection)
{
queue.LocalPush(item);
queue.LocalPush(item, ref _emptyToNonEmptyListTransitionCount);
}
}

@@ -72,7 +74,9 @@ public ConcurrentBag(IEnumerable<T> collection)
/// <param name="item">The object to be added to the
/// <see cref="ConcurrentBag{T}"/>. The value can be a null reference
/// (Nothing in Visual Basic) for reference types.</param>
public void Add(T item) => GetCurrentThreadWorkStealingQueue(forceCreate: true).LocalPush(item);
public void Add(T item) =>
GetCurrentThreadWorkStealingQueue(forceCreate: true)
.LocalPush(item, ref _emptyToNonEmptyListTransitionCount);

/// <summary>
/// Attempts to add an object to the <see cref="ConcurrentBag{T}"/>.
@@ -176,22 +180,55 @@ private bool TrySteal(out T result, bool take)
CDSCollectionETWBCLProvider.Log.ConcurrentBag_TryPeekSteals();
}

// If there's no local queue for this thread, just start from the head queue
// and try to steal from each queue until we get a result.
WorkStealingQueue localQueue = GetCurrentThreadWorkStealingQueue(forceCreate: false);
if (localQueue == null)
while (true)
{
return TryStealFromTo(_workStealingQueues, null, out result, take);
}
// We need to track whether any lists transition from empty to non-empty both before
// and after we attempt the steal in case we don't get an item:
//
// If we don't get an item, we need to handle the possibility of a race condition that led to
// an item being added to a list after we already looked at it in a way that breaks
// linearizability. For example, say there are three threads 0, 1, and 2, each with their own
// list that's currently empty. We could then have the following series of operations:
// - Thread 2 adds an item, such that there's now 1 item in the bag.
// - Thread 1 sees that the count is 1 and does a Take. Its local list is empty, so it tries to
// steal from list 0, but it's empty. Before it can steal from Thread 2, it's pre-empted.
// - Thread 0 adds an item. The count is now 2.
// - Thread 2 takes an item, which comes from its local queue. The count is now 1.
// - Thread 1 continues to try to steal from 2, finds it's empty, and fails its take, even though
// at any given time during its take the count was >= 1. Oops.
// This is particularly problematic for wrapper types that track count using their own synchronization,
// e.g. BlockingCollection, and thus expect that a take will always be successful if the number of items
// is known to be > 0.
//
// We work around this by looking at the number of times any list transitions from == 0 to > 0,
// checking that before and after the steal attempts. We don't care about > 0 to > 0 transitions,
// because a steal from a list with > 0 elements would have been successful.
long initialEmptyToNonEmptyCounts = Interlocked.Read(ref _emptyToNonEmptyListTransitionCount);

// If there's no local queue for this thread, just start from the head queue
// and try to steal from each queue until we get a result. If there is a local queue from this thread,
// then start from the next queue after it, and then iterate around back from the head to this queue,
// not including it.
WorkStealingQueue localQueue = GetCurrentThreadWorkStealingQueue(forceCreate: false);
bool gotItem = localQueue == null ?
TryStealFromTo(_workStealingQueues, null, out result, take) :
(TryStealFromTo(localQueue._nextQueue, null, out result, take) || TryStealFromTo(_workStealingQueues, localQueue, out result, take));
if (gotItem)
{
return true;
}

// If there is a local queue from this thread, then start from the next queue
// after it, and then iterate around back from the head to this queue, not including it.
return
TryStealFromTo(localQueue._nextQueue, null, out result, take) ||
TryStealFromTo(_workStealingQueues, localQueue, out result, take);
if (Interlocked.Read(ref _emptyToNonEmptyListTransitionCount) == initialEmptyToNonEmptyCounts)
{
// The version number matched, so we didn't get an item and we're confident enough
// in our steal attempt to say so.
return false;
}

// TODO: Investigate storing the queues in an array instead of a linked list, and then
// randomly choosing a starting location from which to start iterating.
// Some list transitioned from empty to non-empty between just before the steal and now.
// Since we don't know if it caused a race condition like the above description, we
// have little choice but to try to steal again.
}
}

/// <summary>
@@ -684,7 +721,7 @@ internal bool IsEmpty
/// Add new item to the tail of the queue.
/// </summary>
/// <param name="item">The item to add.</param>
internal void LocalPush(T item)
internal void LocalPush(T item, ref long emptyToNonEmptyListTransitionCount)
{
Debug.Assert(Environment.CurrentManagedThreadId == _ownerThreadId);
bool lockTaken = false;
@@ -701,7 +738,7 @@ internal void LocalPush(T item)
_currentOp = (int)Operation.None; // set back to None temporarily to avoid a deadlock
lock (this)
{
Debug.Assert(_tailIndex == int.MaxValue, "No other thread should be changing _tailIndex");
Debug.Assert(_tailIndex == tail, "No other thread should be changing _tailIndex");

// Rather than resetting to zero, we'll just mask off the bits we don't care about.
// This way we don't need to rearrange the items already in the queue; they'll be found
@@ -711,22 +748,31 @@ internal void LocalPush(T item)
// bits are set, so all of the bits we're keeping will also be set. Thus it's impossible
// for the head to end up > than the tail, since you can't set any more bits than all of them.
_headIndex = _headIndex & _mask;
_tailIndex = tail = _tailIndex & _mask;
_tailIndex = tail = tail & _mask;
Debug.Assert(_headIndex <= _tailIndex);

_currentOp = (int)Operation.Add;
Interlocked.Exchange(ref _currentOp, (int)Operation.Add); // ensure subsequent reads aren't reordered before this
}
}

// We'd like to take the fast path that doesn't require locking, if possible. It's not possible if another
// thread is currently requesting that the whole bag synchronize, e.g. a ToArray operation. It's also
// not possible if there are fewer than two spaces available. One space is necessary for obvious reasons:
// to store the element we're trying to push. The other is necessary due to synchronization with steals.
// A stealing thread first increments _headIndex to reserve the slot at its old value, and then tries to
// read from that slot. We could potentially have a race condition whereby _headIndex is incremented just
// before this check, in which case we could overwrite the element being stolen as that slot would appear
// to be empty. Thus, we only allow the fast path if there are two empty slots.
if (!_frozen && tail < (_headIndex + _mask))
// We'd like to take the fast path that doesn't require locking, if possible. It's not possible if:
// - another thread is currently requesting that the whole bag synchronize, e.g. a ToArray operation
// - if there are fewer than two spaces available. One space is necessary for obvious reasons:
// to store the element we're trying to push. The other is necessary due to synchronization with steals.
// A stealing thread first increments _headIndex to reserve the slot at its old value, and then tries to
// read from that slot. We could potentially have a race condition whereby _headIndex is incremented just
// before this check, in which case we could overwrite the element being stolen as that slot would appear
// to be empty. Thus, we only allow the fast path if there are two empty slots.
// - if there are <= 1 elements in the list. We need to be able to successfully track transitions from
// empty to non-empty in a way that other threads can check, and we can only do that tracking
// correctly if we synchronize with steals when it's possible a steal could take the last item
// in the list just as we're adding. It's possible at this point that there's currently an active steal
// operation happening but that it hasn't yet incremented the head index, such that we could read a smaller
// than accurate by 1 value for the head. However, only one steal could possibly be doing so, as steals
// take the lock, and another steal couldn't then increment the head further because it'll see that
// there's currently an add operation in progress and wait until the add completes.
int head = _headIndex; // read after _currentOp set to Add
if (!_frozen && head < tail - 1 & tail < (head + _mask))
{
_array[tail & _mask] = item;
_tailIndex = tail + 1;
@@ -737,8 +783,8 @@ internal void LocalPush(T item)
_currentOp = (int)Operation.None; // set back to None to avoid a deadlock
Monitor.Enter(this, ref lockTaken);

int head = _headIndex;
int count = _tailIndex - _headIndex;
head = _headIndex;
int count = tail - head; // this count is stable, as we're holding the lock

// If we're full, expand the array.
if (count >= _mask)
@@ -767,6 +813,14 @@ internal void LocalPush(T item)
_array[tail & _mask] = item;
_tailIndex = tail + 1;

// Now that the item has been added, if we were at 0 (now at 1) item,
// increase the empty to non-empty transition count.
if (count == 0)
{
// We just transitioned from empty to non-empty, so increment the transition count.
Interlocked.Increment(ref emptyToNonEmptyListTransitionCount);
}

// Update the count to avoid overflow. We can trust _stealCount here,
// as we're inside the lock and it's only manipulated there.
_addTakeCount -= _stealCount;
@@ -908,41 +962,50 @@ internal bool TryLocalPeek(out T result)
/// <param name="take">true to take the item; false to simply peek at it</param>
internal bool TrySteal(out T result, bool take)
{
// Fast-path check to see if the queue is empty.
if (_headIndex < _tailIndex)
lock (this)
{
// Anything other than empty requires synchronization.
lock (this)
int head = _headIndex; // _headIndex is only manipulated under the lock
if (take)
{
int head = _headIndex;
if (take)
// If there are <= 2 items in the list, we need to ensure no add operation
// is in progress, as add operations need to accurately count transitions
// from empty to non-empty, and they can only do that if there are no concurrent
// steal operations happening at the time.
if (head < _tailIndex - 1 && _currentOp != (int)Operation.Add)
{
// Increment head to tentatively take an element: a full fence is used to ensure the read
// of _tailIndex doesn't move earlier, as otherwise we could potentially end up stealing
// the same element that's being popped locally.
Interlocked.Exchange(ref _headIndex, unchecked(head + 1));

// If there's an element to steal, do it.
if (head < _tailIndex)
var spinner = new SpinWait();
do
{
int idx = head & _mask;
result = _array[idx];
_array[idx] = default(T);
_stealCount++;
return true;
}
else
{
// We contended with the local thread and lost the race, so restore the head.
_headIndex = head;
spinner.SpinOnce();
}
while (_currentOp == (int)Operation.Add);
}
else if (head < _tailIndex)

// Increment head to tentatively take an element: a full fence is used to ensure the read
// of _tailIndex doesn't move earlier, as otherwise we could potentially end up stealing
// the same element that's being popped locally.
Interlocked.Exchange(ref _headIndex, unchecked(head + 1));

// If there's an element to steal, do it.
if (head < _tailIndex)
{
// Peek, if there's an element available
result = _array[head & _mask];
int idx = head & _mask;
result = _array[idx];
_array[idx] = default(T);
_stealCount++;
return true;
}
else
{
// We contended with the local thread and lost the race, so restore the head.
_headIndex = head;
}
}
else if (head < _tailIndex)
{
// Peek, if there's an element available
result = _array[head & _mask];
return true;
}
}

56 changes: 56 additions & 0 deletions src/System.Collections.Concurrent/tests/ConcurrentBagTests.cs
@@ -92,6 +92,62 @@ public static void AddManyItems_ThenTakeOnDifferentThread_ItemsOutputInExpectedO
}, CancellationToken.None, TaskCreationOptions.LongRunning, TaskScheduler.Default).GetAwaiter().GetResult();
}

[Fact]
public static void SingleProducerAdding_MultiConsumerTaking_SemaphoreThrottling_AllTakesSucceed()
{
var bag = new ConcurrentBag<int>();
var s = new SemaphoreSlim(0);
CountdownEvent ce = null;
const int ItemCount = 200_000;

int producerNextValue = 0;
Action producer = null;
producer = delegate
{
ThreadPool.QueueUserWorkItem(delegate
{
bag.Add(producerNextValue++);
s.Release();
if (producerNextValue < ItemCount)
{
producer();
}
else
{
ce.Signal();
}
});
};

int consumed = 0;
Action consumer = null;
consumer = delegate
{
ThreadPool.QueueUserWorkItem(delegate
{
if (s.Wait(0))
{
Assert.True(bag.TryTake(out _), "There's an item available, but we couldn't take it.");
Interlocked.Increment(ref consumed);
}
else if (Volatile.Read(ref consumed) >= ItemCount)
{
ce.Signal();
return;
}

consumer();
});
};

// one producer, two consumers
ce = new CountdownEvent(3);
producer();
consumer();
consumer();
ce.Wait();
}

[Theory]
[InlineData(0)]
[InlineData(1)]