Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make (Ordered)ThreadSafeQueue safer and easier to use #1023

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 110 additions & 20 deletions src/util/ThreadSafeQueue.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
#include <mutex>
#include <optional>
#include <queue>
#include <ranges>

#include "absl/cleanup/cleanup.h"
#include "util/Exception.h"
#include "util/Generator.h"
#include "util/jthread.h"

namespace ad_utility::data_structures {

Expand All @@ -20,10 +26,15 @@ class ThreadSafeQueue {
std::mutex mutex_;
std::condition_variable pushNotification_;
std::condition_variable popNotification_;
bool finish_ = false;
// Note: Although this class is generally synchronized via `std::mutex`, we
// still use `std::atomic` for the information whether it has finished. This
// allows the `finish()` function to be noexcept which allows a safe way to
// prevent deadlocks.
std::atomic_flag finish_ = ATOMIC_FLAG_INIT;
size_t maxSize_;

public:
using value_type = T;
explicit ThreadSafeQueue(size_t maxSize) : maxSize_{maxSize} {}

// We can neither copy nor move this class
Expand All @@ -39,8 +50,8 @@ class ThreadSafeQueue {
bool push(T value) {
std::unique_lock lock{mutex_};
popNotification_.wait(
lock, [this] { return queue_.size() < maxSize_ || finish_; });
if (finish_) {
lock, [this] { return queue_.size() < maxSize_ || finish_.test(); });
if (finish_.test()) {
return false;
}
queue_.push(std::move(value));
Expand All @@ -55,7 +66,7 @@ class ThreadSafeQueue {
void pushException(std::exception_ptr exception) {
std::unique_lock lock{mutex_};
pushedException_ = std::move(exception);
finish_ = true;
finish_.test_and_set();
lock.unlock();
pushNotification_.notify_all();
popNotification_.notify_all();
Expand All @@ -68,10 +79,11 @@ class ThreadSafeQueue {
// function can be called from the producing/pushing threads to signal that
// all elements have been pushed, or from the consumers to signal that they
// will not pop further elements from the queue.
void finish() {
std::unique_lock lock{mutex_};
finish_ = true;
lock.unlock();
void finish() noexcept {
// It is crucial that this function never throws, so that we can safely call
// it unconditionally in destructors to prevent deadlocks. Should the
// implementation ever change, make sure that it is still `noexcept`.
finish_.test_and_set();
pushNotification_.notify_all();
popNotification_.notify_all();
}
Expand All @@ -88,12 +100,12 @@ class ThreadSafeQueue {
std::optional<T> pop() {
std::unique_lock lock{mutex_};
pushNotification_.wait(lock, [this] {
return !queue_.empty() || finish_ || pushedException_;
return !queue_.empty() || finish_.test() || pushedException_;
});
if (pushedException_) {
std::rethrow_exception(pushedException_);
}
if (finish_ && queue_.empty()) {
if (finish_.test() && queue_.empty()) {
return {};
}
std::optional<T> value = std::move(queue_.front());
Expand All @@ -120,9 +132,12 @@ class OrderedThreadSafeQueue {
std::condition_variable cv_;
ThreadSafeQueue<T> queue_;
size_t nextIndex_ = 0;
bool finish_ = false;
// For the reason why this is `atomic_flag`, see the same member in
// `ThreadSafeQueue`.
std::atomic_flag finish_ = ATOMIC_FLAG_INIT;

public:
using value_type = T;
// Construct from the maximal queue size (see `ThreadSafeQueue` for details).
explicit OrderedThreadSafeQueue(size_t maxSize) : queue_{maxSize} {}

Expand All @@ -139,8 +154,9 @@ class OrderedThreadSafeQueue {
// equal to `ThreadSafeQueue::push`.
bool push(size_t index, T value) {
std::unique_lock lock{mutex_};
cv_.wait(lock, [this, index]() { return index == nextIndex_ || finish_; });
if (finish_) {
cv_.wait(lock,
[this, index]() { return index == nextIndex_ || finish_.test(); });
if (finish_.test()) {
return false;
}
++nextIndex_;
Expand All @@ -150,21 +166,26 @@ class OrderedThreadSafeQueue {
return result;
}

// Same as the function above, but the two arguments are passed in as a
// `std::pair`.
bool push(std::pair<size_t, T> indexAndValue) {
return push(indexAndValue.first, std::move(indexAndValue.second));
}

// See `ThreadSafeQueue` for details.
void pushException(std::exception_ptr exception) {
std::unique_lock l{mutex_};
queue_.pushException(std::move(exception));
finish_ = true;
l.unlock();
finish_.test_and_set();
cv_.notify_all();
}

// See `ThreadSafeQueue` for details.
void finish() {
void finish() noexcept {
// It is crucial that this function never throws, so that we can safely call
// it unconditionally in destructors to prevent deadlocks. Should the
// implementation ever change, make sure that it is still `noexcept`.
queue_.finish();
std::unique_lock lock{mutex_};
finish_ = true;
lock.unlock();
finish_.test_and_set();
cv_.notify_all();
}

Expand All @@ -176,4 +197,73 @@ class OrderedThreadSafeQueue {
std::optional<T> pop() { return queue_.pop(); }
};

// A concept for one of the thread-safe queue types above. It is satisfied if
// `T` is similar to an instantiation of either `ThreadSafeQueue` or
// `OrderedThreadSafeQueue`, as determined by
// `ad_utility::similarToInstantiation`.
template <typename T>
concept IsThreadsafeQueue =
ad_utility::similarToInstantiation<T, ThreadSafeQueue> ||
ad_utility::similarToInstantiation<T, OrderedThreadSafeQueue>;

namespace detail {
// A helper function for setting up a producer task for one of the thread-safe
// queues above. Takes a reference to the `queue`, a `producer`, and a reference
// to `numThreads`, the shared counter of producer tasks that have not finished
// yet. The `producer` must return
// `std::optional<somethingThatCanBePushedToTheQueue>`.
//
// The returned task calls `producer` repeatedly and pushes the resulting values
// to the queue until either `producer` returns `nullopt` or `push` returns
// `false`. All exceptions that are thrown while doing so are propagated to the
// queue via `pushException`; if that call throws itself, the queue is finished
// instead. In every case `numThreads` is decremented at the end of the task,
// and the queue is finished by the task whose decrement makes the counter
// `<= 0`.
//
// NOTE: The returned task stores references to `queue` and `numThreads`, so
// both objects have to outlive every execution of the task.
template <IsThreadsafeQueue Queue, std::invocable Producer>
auto makeQueueTask(Queue& queue, Producer producer,
                   std::atomic<int64_t>& numThreads) {
  return [&queue, producer = std::move(producer), &numThreads] {
    try {
      while (auto opt = producer()) {
        // `push` returns `false` as soon as the queue was finished, then there
        // is no point in producing further values.
        if (!queue.push(std::move(opt.value()))) {
          break;
        }
      }
    } catch (...) {
      try {
        queue.pushException(std::current_exception());
      } catch (...) {
        // `finish()` is `noexcept`, so it is a safe fallback.
        queue.finish();
      }
    }
    // Decrement and test in a single atomic operation. With a separate
    // decrement and load, several tasks could observe a value `<= 0` and each
    // of them would consider itself the last one.
    if (--numThreads <= 0) {
      queue.finish();
    }
  };
}
} // namespace detail

// This helper function makes the usage of the (Ordered)ThreadSafeQueue above
// much easier. It takes the size of the queue (`queueSize`), the number of
// producer threads (`numThreads`, must be greater than zero, otherwise
// `AD_CONTRACT_CHECK` fails), and a `producer` (a callable that produces
// values). The `producer` is called repeatedly in `numThreads` many concurrent
// threads. It needs to return
// `std::optional<SomethingThatCanBePushedToTheQueue>` and has the following
// semantics: If `nullopt` is returned, then the thread is finished. The queue
// is finished when all the producer threads have finished by yielding
// `nullopt`, or if any call to `producer` in any thread throws an exception. In
// that case the exception is propagated to the resulting generator. The
// resulting generator yields all the values that have been pushed to the queue.
template <typename Queue>
cppcoro::generator<typename Queue::value_type> queueManager(size_t queueSize,
                                                            size_t numThreads,
                                                            auto producer) {
  // The order of the following declarations is crucial, because local objects
  // are destroyed in reverse order of declaration:
  // 1. The `queueFinisher` finishes the queue, s.t. no producer is blocked in
  //    `push` anymore.
  // 2. The `threads` are destroyed (the `JThread`s are expected to join here).
  // 3. Only then the counter and the queue, which the producer tasks access
  //    via reference, cease to exist.
  Queue queue{queueSize};
  AD_CONTRACT_CHECK(numThreads > 0u);
  // The counter has to be declared before `threads`, so that it outlives all
  // the tasks that decrement it.
  std::atomic<int64_t> numUnfinishedThreads{static_cast<int64_t>(numThreads)};
  std::vector<ad_utility::JThread> threads;
  threads.reserve(numThreads);
  absl::Cleanup queueFinisher{[&queue] { queue.finish(); }};
  for ([[maybe_unused]] auto i : std::views::iota(size_t{0}, numThreads)) {
    threads.emplace_back(
        detail::makeQueueTask(queue, producer, numUnfinishedThreads));
  }

  // `pop` returns `nullopt` once the queue is finished and empty, and rethrows
  // an exception that was pushed by one of the producers.
  while (auto opt = queue.pop()) {
    co_yield (opt.value());
  }
}

} // namespace ad_utility::data_structures
167 changes: 167 additions & 0 deletions test/ThreadSafeQueueTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <atomic>
#include <ranges>

#include "./util/GTestHelpers.h"
#include "absl/cleanup/cleanup.h"
#include "util/ThreadSafeQueue.h"
#include "util/TypeTraits.h"
#include "util/jthread.h"
Expand All @@ -31,6 +33,20 @@ auto makePush(Queue& queue) {
};
}

// Counterpart of `makePush` above for tests of the `queueManager` template:
// The returned lambda does not push anything itself, it only builds the value
// that can then be pushed to a queue of type `Queue`. For a `ThreadSafeQueue`
// this is the plain index `i`, for every other queue type it is the pair
// `{i, i}`.
template <typename Queue>
auto makeQueueValue() {
  return [](size_t index) {
    constexpr bool isPlainQueue =
        ad_utility::similarToInstantiation<Queue, ThreadSafeQueue>;
    if constexpr (!isPlainQueue) {
      return std::pair{index, index};
    } else {
      return index;
    }
  };
}

// Some constants that are used in almost every test case.
constexpr size_t queueSize = 5;
constexpr size_t numThreads = 20;
Expand Down Expand Up @@ -248,3 +264,154 @@ TEST(ThreadSafeQueue, DisablePush) {
};
runWithBothQueueTypes(runTest);
}

// Demonstrate the safe way to handle exceptions and early destruction in the
// worker threads as well as in the consumer threads. By `safe` we mean that the
// program is neither terminated nor does it run into a deadlock.
// The test is run twice for each queue type: once with a worker that throws
// (`workerThrows == true`, the expected exception message starts with
// "Producer"), and once with a worker that does not throw, in which case the
// consumer throws (the expected message starts with "Consumer").
TEST(ThreadSafeQueue, SafeExceptionHandling) {
auto runTest = []<typename Queue>(bool workerThrows, Queue&& queue) {
auto throwingProcedure = [&]() {
auto threadFunction = [&queue, workerThrows] {
try {
auto push = makePush(queue);
size_t numPushed = 0;
// We have to finish the thread as soon as `push` returns false.
while (push(numPushed++)) {
// Manually throw an exception if `workerThrows` was specified.
if (numPushed >= numValues / 2 && workerThrows) {
throw std::runtime_error{"Producer died"};
}
}
} catch (...) {
// We have to catch all exceptions in the worker thread(s), otherwise
// the program will immediately terminate. When there was an exception
// and the queue still expects results from this worker thread
// (especially if the queue is ordered), we have to finish the queue.
// If we just call `finish`, then the consumer will simply get a
// `nullopt` when popping from the finished and empty queue. When we use
// `pushException`, the call to `pop` will rethrow the exception.
try {
// In theory, `pushException` might throw if something goes really
// wrong with the underlying mutex. In practice this should never
// happen, but we demonstrate the really safe way here.
queue.pushException(std::current_exception());
} catch (...) {
// `finish()` can never fail.
queue.finish();
}
}
};
ad_utility::JThread thread{threadFunction};
// This cleanup is important in case the consumer throws an exception. We
// then first have to `finish` the queue, s.t. the producer threads can
// join. We then can join and destroy the worker threads and finally
// destroy the queue. So the order of declaration is important:
// 1. Queue, 2. WorkerThreads, 3. `Cleanup` that finishes the queue.
absl::Cleanup cleanup{[&queue] { queue.finish(); }};

// The consumer: pop up to `numValues` elements and stop early if the queue
// reports that it is finished. If the worker pushed an exception, it is
// rethrown by `pop` and leaves this lambda.
for ([[maybe_unused]] auto i : std::views::iota(0u, numValues)) {
auto opt = queue.pop();
if (!opt) {
return;
}
}
// When throwing, the `Cleanup` calls `finish` and the producers can run
// to completion because their calls to `push` will return false.
throw std::runtime_error{"Consumer died"};
};
// Check that the exception of the expected origin arrives at the caller.
if (workerThrows) {
AD_EXPECT_THROW_WITH_MESSAGE(throwingProcedure(),
::testing::StartsWith("Producer"));
} else {
AD_EXPECT_THROW_WITH_MESSAGE(throwingProcedure(),
::testing::StartsWith("Consumer"));
}
};
runWithBothQueueTypes(std::bind_front(runTest, true));
runWithBothQueueTypes(std::bind_front(runTest, false));
}

// ________________________________________________________________
// Test the `queueManager` template with both queue types in several scenarios
// (see `TestType`): exceptions in the producer, exceptions in the consumer,
// a consumer that stops early, and a normal run to completion.
TEST(ThreadSafeQueue, queueManager) {
// The scenario that a single run of `runTest` simulates.
enum class TestType {
producerThrows,
consumerThrows,
normalExecution,
consumerFinishesEarly,
bothThrowImmediately
};
// The second argument is only used to deduce the type `Queue`, the queue
// itself is created inside of `queueManager`.
auto runTest = []<typename Queue>(TestType testType, Queue&&) {
std::atomic<size_t> numPushed = 0;
// The producer that is passed to `queueManager`. Each call claims the next
// value from the shared counter `numPushed` and returns `nullopt` as soon as
// `numValues` values have been handed out. Depending on `testType` it throws
// either immediately or after more than `numValues / 2` values.
auto task =
[&numPushed,
&testType]() -> std::optional<decltype(makeQueueValue<Queue>()(3))> {
auto makeValue = makeQueueValue<Queue>();
if (testType == TestType::bothThrowImmediately) {
throw std::runtime_error{"Producer"};
}
auto value = numPushed++;
if (testType == TestType::producerThrows && value > numValues / 2) {
throw std::runtime_error{"Producer"};
}
if (value < numValues) {
return makeValue(value);
} else {
return std::nullopt;
}
};
std::vector<size_t> result;
size_t numPopped = 0;
try {
// In this scenario the consumer throws before the generator is even
// created.
if (testType == TestType::bothThrowImmediately) {
throw std::runtime_error{"Consumer"};
}
for (size_t value : queueManager<Queue>(queueSize, numThreads, task)) {
++numPopped;
// After a third of the values the consumer either throws or stops
// consuming, depending on the scenario.
if (numPopped > numValues / 3) {
if (testType == TestType::consumerThrows) {
throw std::runtime_error{"Consumer"};
} else if (testType == TestType::consumerFinishesEarly) {
break;
}
}
result.push_back(value);
// Upper bound on how far the producers may have run ahead of the
// consumer.
EXPECT_LE(numPushed, numPopped + queueSize + 1 + numThreads);
}
if (testType == TestType::consumerThrows ||
testType == TestType::producerThrows) {
FAIL() << "Should have thrown" << static_cast<unsigned>(testType);
}
} catch (const std::runtime_error& e) {
// Check that the exception has the expected origin for the scenario.
if (testType == TestType::consumerThrows ||
testType == TestType::bothThrowImmediately) {
EXPECT_STREQ(e.what(), "Consumer");
} else if (testType == TestType::producerThrows) {
EXPECT_STREQ(e.what(), "Producer");
} else {
FAIL() << "Should not have thrown";
}
}

if (testType == TestType::consumerFinishesEarly) {
EXPECT_EQ(result.size(), numValues / 3);
} else if (testType == TestType::normalExecution) {
EXPECT_EQ(result.size(), numValues);
// For the `OrderedThreadSafeQueue` we expect the result to already be in
// order, for the `ThreadSafeQueue` the order is unspecified and we only
// check the content.
// NOTE(review): `isInstantiation` is applied to the deduced `Queue`; this
// presumably relies on `runWithBothQueueTypes` passing rvalues, s.t. `Queue`
// is not a reference type — confirm.
if (ad_utility::isInstantiation<Queue, ThreadSafeQueue>) {
std::ranges::sort(result);
}
EXPECT_THAT(result, ::testing::ElementsAreArray(
std::views::iota(0UL, numValues)));
}
// The probably most important test of all is that the destructors which are
// run at the following closing brace never lead to a deadlock.
};
using enum TestType;
runWithBothQueueTypes(std::bind_front(runTest, consumerThrows));
runWithBothQueueTypes(std::bind_front(runTest, producerThrows));
runWithBothQueueTypes(std::bind_front(runTest, consumerFinishesEarly));
runWithBothQueueTypes(std::bind_front(runTest, normalExecution));
runWithBothQueueTypes(std::bind_front(runTest, bothThrowImmediately));
}