Skip to content

Commit

Permalink
Fix race condition for Thread Pool pending counter
Browse files Browse the repository at this point in the history
  • Loading branch information
lczech committed Jul 4, 2024
1 parent 597fb13 commit 588fd9b
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 66 deletions.
4 changes: 2 additions & 2 deletions lib/genesis/utils/containers/generic_input_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,8 @@ class GenericInputStream
assert( buffer_block_->size() == generator_->block_size_ );

// In order to use lambda captures by copy for class member variables in C++11, we first
// have to make local copies, and then capture those. Capturing the class members direclty
// was only introduced later. Bit cumbersome, but gets the job done.
// have to make local copies, and then capture those. Capturing the class members
// directly was only introduced later. Bit cumbersome, but gets the job done.
auto generator = generator_;
auto buffer_block = buffer_block_;
auto block_size = generator_->block_size_;
Expand Down
178 changes: 117 additions & 61 deletions lib/genesis/utils/threading/thread_pool.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -372,14 +372,13 @@ class ThreadPool
ThreadPool& operator= ( ThreadPool&& ) = delete;

/**
* @brief Destruct the thread pool,
* stopping and joining any workers that are potentially still running or waiting.
* @brief Destruct the thread pool, waiting for all unfinished tasks.
*/
~ThreadPool()
{
// Just in case, we wait for any unfinished work to be done, to avoid terminating when
// tasks are still doing work that needs to be finished.
wait();
// Just in case, we wait for any unfinished work to be done, to avoid terminating
// when tasks are still doing work that needs to be finished.
wait_for_all_pending_tasks();
assert( unfinished_tasks_.load() == 0 );

// Send the special stop task to the pool, once for each worker.
Expand Down Expand Up @@ -410,9 +409,22 @@ class ThreadPool
}

/**
* @brief Return the current number of queued tasks.
* @brief Return the current number of pending tasks.
*
* These is the number of tasks that have been enqueued, but not yet finished running.
* It hence includes both the number of waiting tasks and those that are currently being
* processed by a worker thread. The count is only reduced once a task is finished
* (or threw an exception). The counter can be used to wait for all enqueued tasks to be done,
* which is what wait_for_all_pending_tasks() does.
*
* Note that there is a very small window where it can happen that this counter is reduced
* after finishing the work of a task, but before setting the value of its associated promise.
* Hence, this counter might exclude a finished task for which the caller is still waiting
* for the future returned from the enqueue_and_retrieve() call. That should usually not be
* an issue though, as the caller will typically just wait for the future anyway, instead
* of checking this counter.
*/
size_t currently_enqueued_tasks() const
size_t pending_tasks_count() const
{
return unfinished_tasks_.load();
// return task_queue_.size_approx();
Expand Down Expand Up @@ -442,52 +454,34 @@ class ThreadPool
auto enqueue_and_retrieve( F&& f, Args&&... args )
-> ProactiveFuture<typename std::result_of<F(Args...)>::type>
{
// Check that we can enqueue a task at the moment, of if we need to wait and to work first.
// In a high-contention situation, this of course could fail, so that once the loop condition
// is checked, some other task already has finished the work. But that doesn't matter, the
// call to try_run_pending_task will catch that and just do nothing. Also, the other way round
// could happen, and the queue could in theory be overloaded if many threads try to enqueue
// at exactly the same time. But we probably never have enough threads for that to be a real
// issue - worst case, we exceed the max queue size by the number of threads, which is fine.
// All we want to avoid is to have an infinitely growing queue.
while( max_queue_size_ > 0 && currently_enqueued_tasks() >= max_queue_size_ ) {
try_run_pending_task();
}
using result_type = typename std::result_of<F(Args...)>::type;

// Make sure that we do not enqueue more tasks than the max size.
run_tasks_until_below_max_queue_size_();

// Prepare the task by binding the function to its arguments.
// Using a packaged task ensures that any exception thrown in the task function
// Prepare a promise and associated future of the task; the latter is our return value.
// Using a promise ensures that any exception thrown in the task function
// will be caught by the future, and re-thrown when its get() function is called,
// see e.g., https://stackoverflow.com/a/16345305/4184258
using result_type = typename std::result_of<F(Args...)>::type;
auto task_package = std::make_shared< std::packaged_task<result_type()> >(
std::bind( std::forward<F>(f), std::forward<Args>(args)... )
);
auto task_promise = std::make_shared<std::promise<result_type>>();
auto future_result = ProactiveFuture<result_type>( task_promise->get_future(), *this );

// Prepare the resulting future result of the task, which is our return value.
auto future_result = ProactiveFuture<result_type>( task_package->get_future(), *this );
// To make our lives easier for the helper functions used below, we just wrap
// the task in a function that can be called without arguments.
std::function<result_type()> task_function = std::bind(
std::forward<F>(f), std::forward<Args>(args)...
);

// Prepare the task that we want to submit, by wrapping the function to be called.
// Prepare the task that we want to submit.
// All this wrapping should be completely transparent to the compiler, and removed.
// The task captures the package including the promise that is needed for the future.
WrappedTask wrapped_task;
wrapped_task.function = [task_package, this]()
{
// Run the actual work task here.
// Once done, we can signal this to the unfinished list.
(*task_package)();
assert( this->unfinished_tasks_.load() > 0 );
--this->unfinished_tasks_;
};
wrapped_task.function = make_wrapped_task_with_promise_( task_promise, task_function );

// We add the task, incrementing the unfinished counter, and only decrementing it once the
// task has been fully processed. That way, the counter always tells us if there is still
// work going on. We capture a reference to `this` in the task above, which could be
// dangerous if the threads survive the lifetime of the pool, but given that their exit
// condition is only called from the pool destructor, this should never be able to happen.
// We first incrementi the unfinished counter, and only decrementing it once the task has
// been fully processed. Thus, the counter always tells us if there is still work going on.
++unfinished_tasks_;
task_queue_.enqueue(
std::move( wrapped_task )
);
task_queue_.enqueue( std::move( wrapped_task ));

// The task is submitted. Return its future for the caller to be able to wait for it.
return future_result;
Expand All @@ -507,17 +501,8 @@ class ThreadPool
template<class F, class... Args>
void enqueue_detached( F&& f, Args&&... args )
{
// Check that we can enqueue a task at the moment, of if we need to wait and to work first.
// In a high-contention situation, this of course could fail, so that once the loop condition
// is checked, some other task already has finished the work. But that doesn't matter, the
// call to try_run_pending_task will catch that and just do nothing. Also, the other way round
// could happen, and the queue could in theory be overloaded if many threads try to enqueue
// at exactly the same time. But we probably never have enough threads for that to be a real
// issue - worst case, we exceed the max queue size by the number of threads, which is fine.
// All we want to avoid is to have an infinitely growing queue.
while( max_queue_size_ > 0 && currently_enqueued_tasks() >= max_queue_size_ ) {
try_run_pending_task();
}
// Make sure that we do not enqueue more tasks than the max size.
run_tasks_until_below_max_queue_size_();

// Prepare the task that we want to submit, by wrapping the function to be called.
// All this wrapping should be completely transparent to the compiler, and removed.
Expand All @@ -526,8 +511,7 @@ class ThreadPool
auto task_function = std::bind( std::forward<F>(f), std::forward<Args>(args)... );
wrapped_task.function = [task_function, this]()
{
// Run the actual work task here.
// Once done, we can signal this to the unfinished list.
// Run the actual work task here. Once done, we can signal this to the unfinished list.
task_function();
assert( this->unfinished_tasks_.load() > 0 );
--this->unfinished_tasks_;
Expand Down Expand Up @@ -567,24 +551,26 @@ class ThreadPool
/**
* @brief Wait for all current tasks to be finished processing.
*
* This function simply calls try_run_pending_task() until there are no more tasks to process.
* This is an alternative mechanism for tasks whose future has not been captured when being
* enqueued. This can be used for instance by a main thread that keeps submitting work,
* and then later needs to wait for everything to be finished. In that case, it might make
* sense to set a max_queue_size when constructing the pool, to ensure that the pool does not
* grow indefinitely. See the main class description for details.
* and then later needs to wait for everything to be finished. For this use case, it might
* make sense to set a max_queue_size when constructing the pool, to ensure that the pool does
* not grow indefinitely. See the main class description for details.
*/
void wait()
void wait_for_all_pending_tasks()
{
// Wait for all pending tasks to be processed. While we wait, we can also help
// processing tasks! The loop stops once there are not more unfinished tasks.
while( unfinished_tasks_.load() > 0 ) {
while( try_run_pending_task() );
std::this_thread::yield();
}
assert( unfinished_tasks_.load() == 0 );
}

// -------------------------------------------------------------
// Internal Members
// Wrapped Task
// -------------------------------------------------------------

private:
Expand Down Expand Up @@ -650,6 +636,76 @@ class ThreadPool
}
}

inline void run_tasks_until_below_max_queue_size_()
{
// Check that we can enqueue a task at the moment, of if we need to wait and to work first.
// In a high-contention situation, this of course could fail, so that once the loop condition
// is checked, some other task already has finished the work. But that doesn't matter, the
// call to try_run_pending_task will catch that and just do nothing. Also, the other way round
// could happen, and the queue could in theory be overloaded if many threads try to enqueue
// at exactly the same time. But we probably never have enough threads for that to be a real
// issue - worst case, we exceed the max queue size by the number of threads, which is fine.
// All we want to avoid is to have an infinitely growing queue.
while( max_queue_size_ > 0 && pending_tasks_count() >= max_queue_size_ ) {
try_run_pending_task();
}
}

template<typename T>
inline std::function<void()> make_wrapped_task_with_promise_(
std::shared_ptr<std::promise<T>> task_promise,
std::function<T()> task_function
) {
// We capture a reference to `this` in the below lambda, which could be dangerous
// if the threads survive the lifetime of the pool, but given that the pool destructor
// waits for all of them to finish, this should never be able to happen.
return [this, task_promise, task_function]()
{
// Run the work task, and set the value of the associated promise.
// We need to delegate this here, as the std::promise::set_value() function
// differs for void and non-void return types. That is unfortunate.
try {
run_task_and_fulfill_promise_<T>(
task_promise, task_function
);
} catch (...) {
// TODO Here, we might or might not already have decremented unfinished_tasks_,
// depending on where exaclty the exception occurred. This is not clean,
// but for now, this is accetable, as we terminate upon exceptions anyway.
task_promise->set_exception( std::current_exception() );
}
};
}

template<typename T>
typename std::enable_if<!std::is_void<T>::value>::type
inline run_task_and_fulfill_promise_(
std::shared_ptr<std::promise<T>> task_promise,
std::function<T()> task_function
) {
// Run the actual work task here. Once done, we can signal this to the unfinished list.
// This bit is the only reason why the whole wrapping exists: We need to first decrement
// the unfinished tasks count, before setting the promise value, as otherwise, outside
// threads might deduce that there are more pending tasks, when in fact we are already done.
auto result = task_function();
assert( unfinished_tasks_.load() > 0 );
--unfinished_tasks_;
task_promise->set_value( std::move( result ));
}

template<typename T>
typename std::enable_if<std::is_void<T>::value>::type
inline run_task_and_fulfill_promise_(
std::shared_ptr<std::promise<T>> task_promise,
std::function<void()> task_function
) {
// Same as above, but for void functions, i.e., without setting a value for the promise.
task_function();
assert( unfinished_tasks_.load() > 0 );
--unfinished_tasks_;
task_promise->set_value();
}

// -------------------------------------------------------------
// Internal Members
// -------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions test/src/utils/containers/iterators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ void test_generic_input_stream_( size_t num_elements, size_t block_size )
// any more - the iterator should have waited for the end of everything before finishing.
// We are only using the global thread pool sequentially in the tests here, so there
// cannot be anything left from other places once we are done with the iteration.
EXPECT_EQ( 0, Options::get().global_thread_pool()->currently_enqueued_tasks() );
EXPECT_EQ( 0, Options::get().global_thread_pool()->pending_tasks_count() );

// We called the get element function in the lambda exactly one per data item,
// and one last time at the end to indicate that ther is no more data.
Expand All @@ -332,7 +332,7 @@ TEST( Containers, GenericInputStream )
LOG_SCOPE_LEVEL( genesis::utils::Logging::kInfo );

// Loop a few times, to have a higher chance of finding race conditions etc in the threading.
for( size_t i = 0; i < 250; ++i ) {
for( size_t i = 0; i < 500; ++i ) {

// No elements
test_generic_input_stream_( 0, 0 );
Expand Down
2 changes: 1 addition & 1 deletion test/src/utils/threading/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ void thread_pool_nested_fuzzy_work_()

// Run the function that recursively splits the tasks into blocks.
thread_pool_compute_nested_fuzzy_work_( pool, numbers, 0, num_tasks, 0, counter );
ASSERT_EQ( 0, pool->currently_enqueued_tasks() );
ASSERT_EQ( 0, pool->pending_tasks_count() );

// Aggregate the result and check that we got the correct sum.
auto const total = std::accumulate( numbers.begin(), numbers.end(), 0 );
Expand Down

0 comments on commit 588fd9b

Please sign in to comment.