Skip to content

Commit

Permalink
ARROW-6983: [C++] Fix ThreadedTaskGroup lifetime issue
Browse files Browse the repository at this point in the history
Together with a test.

Closes #5724 from pitrou/ARROW-6983-threaded-task-group-lifetime and squashes the following commits:

451c687 <Antoine Pitrou> ARROW-6983:  Fix ThreadedTaskGroup lifetime issue

Authored-by: Antoine Pitrou <[email protected]>
Signed-off-by: Benjamin Kietzman <[email protected]>
  • Loading branch information
pitrou authored and kszucs committed Oct 24, 2019
1 parent 4142ed5 commit a83a0f0
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 5 deletions.
11 changes: 7 additions & 4 deletions cpp/src/arrow/util/task_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <mutex>
#include <utility>

#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
#include "arrow/util/thread_pool.h"

Expand Down Expand Up @@ -88,13 +89,15 @@ class ThreadedTaskGroup : public TaskGroup {
// Only if an error occurs is the lock taken
if (ok_.load(std::memory_order_acquire)) {
nremaining_.fetch_add(1, std::memory_order_acquire);
Status st = thread_pool_->Spawn([this, task]() {
if (ok_.load(std::memory_order_acquire)) {

auto self = checked_pointer_cast<ThreadedTaskGroup>(shared_from_this());
Status st = thread_pool_->Spawn([self, task]() {
if (self->ok_.load(std::memory_order_acquire)) {
// XXX what about exceptions?
Status st = task();
UpdateStatus(std::move(st));
self->UpdateStatus(std::move(st));
}
OneTaskDone();
self->OneTaskDone();
});
UpdateStatus(std::move(st));
}
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/util/task_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ThreadPool;
/// implementation. When Finish() returns, it is guaranteed that all
/// tasks have finished, or at least one has errored.
///
class ARROW_EXPORT TaskGroup {
class ARROW_EXPORT TaskGroup : public std::enable_shared_from_this<TaskGroup> {
public:
/// Add a Status-returning function to execute. Execution order is
/// undefined. The function may be executed immediately or later.
Expand Down
84 changes: 84 additions & 0 deletions cpp/src/arrow/util/task_group_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,75 @@ void TestTasksSpawnTasks(std::shared_ptr<TaskGroup> task_group) {
ASSERT_EQ(count.load(), (1 << (N + 1)) - 1);
}

// A task that keeps recursing until a barrier is set.
// Using a lambda for this doesn't play well with Thread Sanitizer.
struct BarrierTask {
std::atomic<bool>* barrier_;
std::weak_ptr<TaskGroup> weak_group_ptr_;
Status final_status_;

Status operator()() {
if (!barrier_->load()) {
sleep_for(1e-5);
// Note the TaskGroup should be kept alive by the fact this task
// is still running...
weak_group_ptr_.lock()->Append(*this);
}
return final_status_;
}
};

// Try to replicate subtle lifetime issues when destroying a TaskGroup
// where all tasks may not have finished running.
void StressTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
const int NTASKS = 100;
auto task_group = factory();
auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);

std::atomic<bool> barrier(false);

BarrierTask task{&barrier, weak_group_ptr, Status::OK()};

for (int i = 0; i < NTASKS; ++i) {
task_group->Append(task);
}

// Lose strong reference
barrier.store(true);
task_group.reset();

// Wait for finish
while (!weak_group_ptr.expired()) {
sleep_for(1e-5);
}
}

// Same, but with also a failing task
void StressFailingTaskGroupLifetime(std::function<std::shared_ptr<TaskGroup>()> factory) {
const int NTASKS = 100;
auto task_group = factory();
auto weak_group_ptr = std::weak_ptr<TaskGroup>(task_group);

std::atomic<bool> barrier(false);

BarrierTask task{&barrier, weak_group_ptr, Status::OK()};
BarrierTask failing_task{&barrier, weak_group_ptr, Status::Invalid("XXX")};

for (int i = 0; i < NTASKS; ++i) {
task_group->Append(task);
}
task_group->Append(failing_task);

// Lose strong reference
barrier.store(true);
task_group.reset();

// Wait for finish
while (!weak_group_ptr.expired()) {
sleep_for(1e-5);
}
}

TEST(SerialTaskGroup, Success) { TestTaskGroupSuccess(TaskGroup::MakeSerial()); }

TEST(SerialTaskGroup, Errors) { TestTaskGroupErrors(TaskGroup::MakeSerial()); }
Expand Down Expand Up @@ -259,5 +328,20 @@ TEST(ThreadedTaskGroup, SubGroupsErrors) {
TestTaskSubGroupsErrors(TaskGroup::MakeThreaded(thread_pool.get()));
}

TEST(ThreadedTaskGroup, StressTaskGroupLifetime) {
std::shared_ptr<ThreadPool> thread_pool;
ASSERT_OK(ThreadPool::Make(16, &thread_pool));

StressTaskGroupLifetime([&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
}

TEST(ThreadedTaskGroup, StressFailingTaskGroupLifetime) {
std::shared_ptr<ThreadPool> thread_pool;
ASSERT_OK(ThreadPool::Make(16, &thread_pool));

StressFailingTaskGroupLifetime(
[&] { return TaskGroup::MakeThreaded(thread_pool.get()); });
}

} // namespace internal
} // namespace arrow

0 comments on commit a83a0f0

Please sign in to comment.