Skip to content

Commit

Permalink
Make (Ordered)ThreadSafeQueue safer and easier to use (#1023)
Browse files Browse the repository at this point in the history
In particular, the `finish` method is now `noexcept`. To use the queue, use the function `queueManager`, which takes a queue size, the number of threads, and a producer function, and returns a value generator.
  • Loading branch information
joka921 authored Jul 11, 2023
1 parent 104cf1a commit b57869e
Show file tree
Hide file tree
Showing 2 changed files with 277 additions and 20 deletions.
130 changes: 110 additions & 20 deletions src/util/ThreadSafeQueue.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
#include <mutex>
#include <optional>
#include <queue>
#include <ranges>

#include "absl/cleanup/cleanup.h"
#include "util/Exception.h"
#include "util/Generator.h"
#include "util/jthread.h"

namespace ad_utility::data_structures {

Expand All @@ -20,10 +26,15 @@ class ThreadSafeQueue {
std::mutex mutex_;
std::condition_variable pushNotification_;
std::condition_variable popNotification_;
bool finish_ = false;
// Note: Although this class is generally synchronized via `std::mutex`, we
// still use `std::atomic` for the information whether it has finished. This
// allows the `finish()` function to be noexcept which allows a safe way to
// prevent deadlocks.
std::atomic_flag finish_ = ATOMIC_FLAG_INIT;
size_t maxSize_;

public:
using value_type = T;
explicit ThreadSafeQueue(size_t maxSize) : maxSize_{maxSize} {}

// We can neither copy nor move this class
Expand All @@ -39,8 +50,8 @@ class ThreadSafeQueue {
bool push(T value) {
std::unique_lock lock{mutex_};
popNotification_.wait(
lock, [this] { return queue_.size() < maxSize_ || finish_; });
if (finish_) {
lock, [this] { return queue_.size() < maxSize_ || finish_.test(); });
if (finish_.test()) {
return false;
}
queue_.push(std::move(value));
Expand All @@ -55,7 +66,7 @@ class ThreadSafeQueue {
void pushException(std::exception_ptr exception) {
std::unique_lock lock{mutex_};
pushedException_ = std::move(exception);
finish_ = true;
finish_.test_and_set();
lock.unlock();
pushNotification_.notify_all();
popNotification_.notify_all();
Expand All @@ -68,10 +79,11 @@ class ThreadSafeQueue {
// function can be called from the producing/pushing threads to signal that
// all elements have been pushed, or from the consumers to signal that they
// will not pop further elements from the queue.
void finish() {
std::unique_lock lock{mutex_};
finish_ = true;
lock.unlock();
void finish() noexcept {
// It is crucial that this function never throws, so that we can safely call
// it unconditionally in destructors to prevent deadlocks. Should the
// implementation ever change, make sure that it is still `noexcept`.
finish_.test_and_set();
pushNotification_.notify_all();
popNotification_.notify_all();
}
Expand All @@ -88,12 +100,12 @@ class ThreadSafeQueue {
std::optional<T> pop() {
std::unique_lock lock{mutex_};
pushNotification_.wait(lock, [this] {
return !queue_.empty() || finish_ || pushedException_;
return !queue_.empty() || finish_.test() || pushedException_;
});
if (pushedException_) {
std::rethrow_exception(pushedException_);
}
if (finish_ && queue_.empty()) {
if (finish_.test() && queue_.empty()) {
return {};
}
std::optional<T> value = std::move(queue_.front());
Expand All @@ -120,9 +132,12 @@ class OrderedThreadSafeQueue {
std::condition_variable cv_;
ThreadSafeQueue<T> queue_;
size_t nextIndex_ = 0;
bool finish_ = false;
// For the reason why this is `atomic_flag`, see the same member in
// `ThreadSafeQueue`.
std::atomic_flag finish_ = ATOMIC_FLAG_INIT;

public:
using value_type = T;
// Construct from the maximal queue size (see `ThreadSafeQueue` for details).
explicit OrderedThreadSafeQueue(size_t maxSize) : queue_{maxSize} {}

Expand All @@ -139,8 +154,9 @@ class OrderedThreadSafeQueue {
// equal to `ThreadSafeQueue::push`.
bool push(size_t index, T value) {
std::unique_lock lock{mutex_};
cv_.wait(lock, [this, index]() { return index == nextIndex_ || finish_; });
if (finish_) {
cv_.wait(lock,
[this, index]() { return index == nextIndex_ || finish_.test(); });
if (finish_.test()) {
return false;
}
++nextIndex_;
Expand All @@ -150,21 +166,26 @@ class OrderedThreadSafeQueue {
return result;
}

// Same as the function above, but the two arguments are passed in as a
// `std::pair`.
bool push(std::pair<size_t, T> indexAndValue) {
return push(indexAndValue.first, std::move(indexAndValue.second));
}

// See `ThreadSafeQueue` for details.
void pushException(std::exception_ptr exception) {
std::unique_lock l{mutex_};
queue_.pushException(std::move(exception));
finish_ = true;
l.unlock();
finish_.test_and_set();
cv_.notify_all();
}

// See `ThreadSafeQueue` for details.
void finish() {
void finish() noexcept {
// It is crucial that this function never throws, so that we can safely call
// it unconditionally in destructors to prevent deadlocks. Should the
// implementation ever change, make sure that it is still `noexcept`.
queue_.finish();
std::unique_lock lock{mutex_};
finish_ = true;
lock.unlock();
finish_.test_and_set();
cv_.notify_all();
}

Expand All @@ -176,4 +197,73 @@ class OrderedThreadSafeQueue {
std::optional<T> pop() { return queue_.pop(); }
};

// Concept that is satisfied if `Queue` is, as determined by
// `ad_utility::similarToInstantiation`, an instantiation of either the
// `ThreadSafeQueue` or the `OrderedThreadSafeQueue` class template.
template <typename Queue>
concept IsThreadsafeQueue =
    ad_utility::similarToInstantiation<Queue, ThreadSafeQueue> ||
    ad_utility::similarToInstantiation<Queue, OrderedThreadSafeQueue>;

namespace detail {
// Create the task that is run by a single producer thread of one of the
// thread-safe queues above.
//
// Arguments:
//   `queue`      The queue to which the produced values are pushed. It is
//                captured by reference and thus has to outlive the task.
//   `producer`   A callable without arguments that must return
//                `std::optional<somethingThatCanBePushedToTheQueue>`. It is
//                moved into the returned task.
//   `numThreads` The number of producer tasks that have not yet completed. It
//                is shared between all tasks of the same queue, captured by
//                reference, and thus also has to outlive the task.
//
// The returned task calls `producer` repeatedly and pushes the results to the
// queue until either `producer` returns `nullopt`, or `push` returns `false`
// (which means that the queue was finished in the meantime). Every exception
// thrown while producing or pushing is handed to the queue via
// `pushException`; should that call throw as well, the queue is finished via
// the `noexcept` function `finish` instead. Finally `numThreads` is
// decremented and the queue is finished by the task that brings the counter
// down to zero (or below).
template <IsThreadsafeQueue Queue, std::invocable Producer>
auto makeQueueTask(Queue& queue, Producer producer,
                   std::atomic<int64_t>& numThreads) {
  return [&queue, producer = std::move(producer), &numThreads] {
    try {
      while (auto opt = producer()) {
        if (!queue.push(std::move(opt.value()))) {
          // The queue was finished, so further values are not needed.
          break;
        }
      }
    } catch (...) {
      try {
        queue.pushException(std::current_exception());
      } catch (...) {
        // Propagating the exception failed; at least make sure that nobody
        // waits for this task forever.
        queue.finish();
      }
    }
    // Decrement and test in a single atomic read-modify-write operation. That
    // way exactly the task that completes last observes the value zero and
    // finishes the queue; a separate decrement followed by a separate load
    // would let several tasks observe `<= 0` concurrently.
    if (--numThreads <= 0) {
      queue.finish();
    }
  };
}
}  // namespace detail

// This helper function makes the usage of the (Ordered)ThreadSafeQueue above
// much easier.
//
// Template parameter:
//   `Queue`      The queue type to use. It must be constructible from the
//                queue size and provide `value_type`, `pop`, and `finish`.
// Arguments:
//   `queueSize`  The size that is passed to the constructor of the queue.
//   `numThreads` The number of concurrent producer threads, must be > 0
//                (checked via `AD_CONTRACT_CHECK`).
//   `producer`   A callable that produces the values. A copy of it is called
//                repeatedly in each of the `numThreads` threads. It needs to
//                return `std::optional<SomethingThatCanBePushedToTheQueue>`.
//                Returning `nullopt` means that the calling thread is done.
//
// Returns a generator that yields all the values that have been pushed to the
// queue. The queue is finished when all the producer threads have finished by
// yielding `nullopt`, or if any call to `producer` in any thread throws an
// exception. In that case the exception is propagated to the resulting
// generator.
//
// The order of the local declarations is crucial because they are destroyed
// in reverse order: First the `Cleanup` finishes the queue (so blocked
// producers return from `push`), then the worker threads are destroyed
// (presumably joined by the `JThread` destructor), and only then the counter
// and the queue, which the workers reference, go out of scope.
template <typename Queue>
cppcoro::generator<typename Queue::value_type> queueManager(size_t queueSize,
                                                            size_t numThreads,
                                                            auto producer) {
  Queue queue{queueSize};
  AD_CONTRACT_CHECK(numThreads > 0u);
  // The counter is referenced by the tasks that run inside `threads`, so it
  // has to be declared before `threads` to make sure that it is still alive
  // until all of the worker threads have been destroyed.
  std::atomic<int64_t> numUnfinishedThreads{static_cast<int64_t>(numThreads)};
  std::vector<ad_utility::JThread> threads;
  absl::Cleanup queueFinisher{[&queue] { queue.finish(); }};
  // Use `size_t` for both bounds so that the loop variable has the same type
  // as `numThreads`.
  for ([[maybe_unused]] auto i : std::views::iota(size_t{0}, numThreads)) {
    threads.emplace_back(
        detail::makeQueueTask(queue, producer, numUnfinishedThreads));
  }

  // `pop` returns `nullopt` as soon as the queue is finished and drained, and
  // rethrows an exception that was pushed by one of the producers.
  while (auto opt = queue.pop()) {
    co_yield (opt.value());
  }
}

} // namespace ad_utility::data_structures
167 changes: 167 additions & 0 deletions test/ThreadSafeQueueTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <atomic>
#include <ranges>

#include "./util/GTestHelpers.h"
#include "absl/cleanup/cleanup.h"
#include "util/ThreadSafeQueue.h"
#include "util/TypeTraits.h"
#include "util/jthread.h"
Expand All @@ -31,6 +33,20 @@ auto makePush(Queue& queue) {
};
}

// Counterpart of `makePush` for tests of the `queueManager` template: The
// returned lambda does not touch any queue, it only builds the value that can
// afterwards be pushed to a queue of type `Queue`. For a `ThreadSafeQueue`
// this is the plain number `i`, for every other queue type it is the pair
// `{i, i}`.
template <typename Queue>
auto makeQueueValue() {
  return [](size_t number) {
    constexpr bool needsPair =
        !ad_utility::similarToInstantiation<Queue, ThreadSafeQueue>;
    if constexpr (needsPair) {
      return std::pair{number, number};
    } else {
      return number;
    }
  };
}

// Some constants that are used in almost every test case.
constexpr size_t queueSize = 5;
constexpr size_t numThreads = 20;
Expand Down Expand Up @@ -248,3 +264,154 @@ TEST(ThreadSafeQueue, DisablePush) {
};
runWithBothQueueTypes(runTest);
}

// Demonstrate the safe way to handle exceptions and early destruction in the
// worker threads as well as in the consumer threads. By `safe` we mean that the
// program is neither terminated nor does it run into a deadlock.
// The test is run via `runWithBothQueueTypes`, once with `workerThrows == true`
// (the exception message is then expected to start with "Producer") and once
// with `workerThrows == false` (the message is expected to start with
// "Consumer").
TEST(ThreadSafeQueue, SafeExceptionHandling) {
auto runTest = []<typename Queue>(bool workerThrows, Queue&& queue) {
auto throwingProcedure = [&]() {
auto threadFunction = [&queue, workerThrows] {
try {
auto push = makePush(queue);
size_t numPushed = 0;
// We have to finish the thread as soon as `push` returns false.
while (push(numPushed++)) {
// Manually throw an exception if `workerThrows` was specified.
if (numPushed >= numValues / 2 && workerThrows) {
throw std::runtime_error{"Producer died"};
}
}
} catch (...) {
// We have to catch all exceptions in the worker thread(s), otherwise
// the program will immediately terminate. When there was an exception
// and the queue still expects results from this worker thread
// (especially if the queue is ordered), we have to finish the queue.
// If we just call `finish`, then the consumer will not notice the error
// when popping from the queue (`pop` simply returns `nullopt` once the
// queue is drained). When we use `pushException` the call to `pop` will
// rethrow the exception.
try {
// In theory, `pushException` might throw if something goes really
// wrong with the underlying mutex. In practice this should never
// happen, but we demonstrate the really safe way here.
queue.pushException(std::current_exception());
} catch (...) {
// `finish()` can never fail.
queue.finish();
}
}
};
ad_utility::JThread thread{threadFunction};
// This cleanup is important in case the consumer throws an exception. We
// then first have to `finish` the queue, s.t. the producer threads can
// join. We then can join and destroy the worker threads and finally
// destroy the queue. So the order of declaration is important:
// 1. Queue, 2. WorkerThreads, 3. `Cleanup` that finishes the queue.
absl::Cleanup cleanup{[&queue] { queue.finish(); }};

// Consume up to `numValues` elements. An empty optional means that the
// queue was finished before that many elements could be popped.
for ([[maybe_unused]] auto i : std::views::iota(0u, numValues)) {
auto opt = queue.pop();
if (!opt) {
return;
}
}
// When throwing, the `Cleanup` calls `finish` and the producers can run
// to completion because their calls to `push` will return false.
throw std::runtime_error{"Consumer died"};
};
// Check that the exception of the expected origin reaches the caller.
if (workerThrows) {
AD_EXPECT_THROW_WITH_MESSAGE(throwingProcedure(),
::testing::StartsWith("Producer"));
} else {
AD_EXPECT_THROW_WITH_MESSAGE(throwingProcedure(),
::testing::StartsWith("Consumer"));
}
};
runWithBothQueueTypes(std::bind_front(runTest, true));
runWithBothQueueTypes(std::bind_front(runTest, false));
}

// ________________________________________________________________
// Test the `queueManager` helper for both queue types in several scenarios:
// the producer throws, the consumer throws, everything runs to completion, the
// consumer stops early, and the consumer throws before it even starts to
// consume.
TEST(ThreadSafeQueue, queueManager) {
  enum class TestType {
    producerThrows,
    consumerThrows,
    normalExecution,
    consumerFinishesEarly,
    bothThrowImmediately
  };
  auto runTest = []<typename Queue>(TestType testType, Queue&&) {
    // The type of the elements that the producer hands to the queue.
    using Produced = decltype(makeQueueValue<Queue>()(3));
    std::atomic<size_t> numPushed = 0;
    // The producer that is run concurrently by all the worker threads. The
    // shared counter `numPushed` makes sure that each value is produced
    // exactly once.
    auto producer = [&numPushed, &testType]() -> std::optional<Produced> {
      if (testType == TestType::bothThrowImmediately) {
        throw std::runtime_error{"Producer"};
      }
      const auto next = numPushed++;
      if (testType == TestType::producerThrows && next > numValues / 2) {
        throw std::runtime_error{"Producer"};
      }
      if (!(next < numValues)) {
        return std::nullopt;
      }
      return makeQueueValue<Queue>()(next);
    };

    std::vector<size_t> collected;
    size_t numPopped = 0;
    try {
      if (testType == TestType::bothThrowImmediately) {
        throw std::runtime_error{"Consumer"};
      }
      for (size_t element :
           queueManager<Queue>(queueSize, numThreads, producer)) {
        ++numPopped;
        // After a third of the values the consumer possibly bails out.
        const bool pastCutoff = numPopped > numValues / 3;
        if (pastCutoff && testType == TestType::consumerThrows) {
          throw std::runtime_error{"Consumer"};
        }
        if (pastCutoff && testType == TestType::consumerFinishesEarly) {
          break;
        }
        collected.push_back(element);
        // The producers must never run too far ahead of the consumer.
        EXPECT_LE(numPushed, numPopped + queueSize + 1 + numThreads);
      }
      const bool exceptionExpected = testType == TestType::consumerThrows ||
                                     testType == TestType::producerThrows;
      if (exceptionExpected) {
        FAIL() << "Should have thrown" << static_cast<unsigned>(testType);
      }
    } catch (const std::runtime_error& e) {
      switch (testType) {
        case TestType::consumerThrows:
        case TestType::bothThrowImmediately:
          EXPECT_STREQ(e.what(), "Consumer");
          break;
        case TestType::producerThrows:
          EXPECT_STREQ(e.what(), "Producer");
          break;
        default:
          FAIL() << "Should not have thrown";
      }
    }

    if (testType == TestType::normalExecution) {
      EXPECT_EQ(collected.size(), numValues);
      // An `OrderedThreadSafeQueue` has to deliver the values in order
      // already. For the `ThreadSafeQueue` the order is unspecified, so only
      // the content is checked after sorting.
      if (ad_utility::isInstantiation<Queue, ThreadSafeQueue>) {
        std::ranges::sort(collected);
      }
      EXPECT_THAT(collected, ::testing::ElementsAreArray(
                                 std::views::iota(0UL, numValues)));
    } else if (testType == TestType::consumerFinishesEarly) {
      EXPECT_EQ(collected.size(), numValues / 3);
    }
    // Probably the most important check of all is implicit: The destructors
    // that run at the following closing brace must never lead to a deadlock.
  };
  using enum TestType;
  runWithBothQueueTypes(std::bind_front(runTest, consumerThrows));
  runWithBothQueueTypes(std::bind_front(runTest, producerThrows));
  runWithBothQueueTypes(std::bind_front(runTest, consumerFinishesEarly));
  runWithBothQueueTypes(std::bind_front(runTest, normalExecution));
  runWithBothQueueTypes(std::bind_front(runTest, bothThrowImmediately));
}

0 comments on commit b57869e

Please sign in to comment.