Skip to content

Commit

Permalink
Merge branch 'dev' into db/macosx-13
Browse files Browse the repository at this point in the history
  • Loading branch information
dudoslav authored Dec 9, 2024
2 parents edad604 + 2637682 commit 50f9d16
Show file tree
Hide file tree
Showing 6 changed files with 244 additions and 32 deletions.
6 changes: 3 additions & 3 deletions tiledb/common/thread_pool/test/unit_thread_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ void wait_all(
ThreadPool& pool, bool use_wait, std::vector<ThreadPool::Task>& results) {
if (use_wait) {
for (auto& r : results) {
REQUIRE(pool.wait(r).ok());
REQUIRE(r.wait().ok());
}
} else {
REQUIRE(pool.wait_all(results).ok());
Expand All @@ -117,7 +117,7 @@ Status wait_all_status(
if (use_wait) {
Status ret;
for (auto& r : results) {
auto st = pool.wait(r);
auto st = r.wait();
if (ret.ok() && !st.ok()) {
ret = st;
}
Expand All @@ -139,7 +139,7 @@ uint64_t wait_all_num_status(
int num_ok = 0;
if (use_wait) {
for (auto& r : results) {
num_ok += pool.wait(r).ok() ? 1 : 0;
num_ok += r.wait().ok() ? 1 : 0;
}
} else {
std::vector<Status> statuses = pool.wait_all_status(results);
Expand Down
50 changes: 44 additions & 6 deletions tiledb/common/thread_pool/thread_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ void ThreadPool::shutdown() {
threads_.clear();
}

Status ThreadPool::wait_all(std::vector<Task>& tasks) {
Status ThreadPool::wait_all(std::vector<ThreadPoolTask*>& tasks) {
auto statuses = wait_all_status(tasks);
for (auto& st : statuses) {
if (!st.ok()) {
Expand All @@ -131,14 +131,33 @@ Status ThreadPool::wait_all(std::vector<Task>& tasks) {
return Status::Ok();
}

Status ThreadPool::wait_all(std::vector<Task>& tasks) {
std::vector<ThreadPoolTask*> task_ptrs;
for (auto& t : tasks) {
task_ptrs.emplace_back(&t);
}

return wait_all(task_ptrs);
}

Status ThreadPool::wait_all(std::vector<SharedTask>& tasks) {
std::vector<ThreadPoolTask*> task_ptrs;
for (auto& t : tasks) {
task_ptrs.emplace_back(&t);
}

return wait_all(task_ptrs);
}

// Return a vector of Status. If any task returns an error value or throws an
// exception, we save an error code in the corresponding location in the Status
// vector. All tasks are waited on before return. Multiple error statuses may
// be saved. We may call logger here because thread pool will not be used until
// context is fully constructed (which will include logger).
// Unfortunately, C++ does not have the notion of an aggregate exception, so we
// don't throw in the case of errors/exceptions.
std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
std::vector<Status> ThreadPool::wait_all_status(
std::vector<ThreadPoolTask*>& tasks) {
std::vector<Status> statuses(tasks.size());

std::queue<size_t> pending_tasks;
Expand All @@ -154,17 +173,17 @@ std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
pending_tasks.pop();
auto& task = tasks[task_id];

if (!task.valid()) {
if (task && !task->valid()) {
statuses[task_id] = Status_ThreadPoolError("Invalid task future");
LOG_STATUS_NO_RETURN_VALUE(statuses[task_id]);
} else if (
task.wait_for(std::chrono::milliseconds(0)) ==
task->wait_for(std::chrono::milliseconds(0)) ==
std::future_status::ready) {
// Task is completed, get result, handling possible exceptions

Status st = [&task] {
try {
return task.get();
return task->get();
} catch (const std::exception& e) {
return Status_TaskError(
"Caught std::exception: " + std::string(e.what()));
Expand Down Expand Up @@ -205,7 +224,26 @@ std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
return statuses;
}

Status ThreadPool::wait(Task& task) {
std::vector<Status> ThreadPool::wait_all_status(std::vector<Task>& tasks) {
std::vector<ThreadPoolTask*> task_ptrs;
for (auto& t : tasks) {
task_ptrs.emplace_back(&t);
}

return wait_all_status(task_ptrs);
}

std::vector<Status> ThreadPool::wait_all_status(
std::vector<SharedTask>& tasks) {
std::vector<ThreadPoolTask*> task_ptrs;
for (auto& t : tasks) {
task_ptrs.emplace_back(&t);
}

return wait_all_status(task_ptrs);
}

Status ThreadPool::wait(ThreadPoolTask& task) {
while (true) {
if (!task.valid()) {
return Status_ThreadPoolError("Invalid task future");
Expand Down
196 changes: 186 additions & 10 deletions tiledb/common/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,175 @@ namespace tiledb::common {

class ThreadPool {
public:
using Task = std::future<Status>;
/**
* @brief Abstract base class for tasks that can run in this threadpool.
*/
class ThreadPoolTask {
public:
ThreadPoolTask() = default;
ThreadPoolTask(ThreadPool* tp)
: tp_(tp){};

virtual ~ThreadPoolTask(){};

protected:
friend class ThreadPool;

/* C.67 A polymorphic class should suppress public copy/move to prevent
* slicing */
ThreadPoolTask(const ThreadPoolTask&) = default;
ThreadPoolTask& operator=(const ThreadPoolTask&) = default;
ThreadPoolTask(ThreadPoolTask&&) = default;
ThreadPoolTask& operator=(ThreadPoolTask&&) = default;

/**
* Pure virtual functions that tasks need to implement so that they can be
* run in the threadpool wait loop
*/
virtual std::future_status wait_for(
const std::chrono::milliseconds timeout_duration) const = 0;
virtual bool valid() const noexcept = 0;
virtual Status get() = 0;

ThreadPool* tp_{nullptr};
};

/**
* @brief Task class encapsulating std::future. Like std::future it's shared
* state can only be get once and thus only one thread. It can only be moved
* and not copied.
*/
class Task : public ThreadPoolTask {
public:
/**
* Default constructor
* @brief Default constructed SharedTask is possible but not valid().
*/
Task() = default;

/**
* Constructor from std::future
*/
Task(std::future<Status>&& f, ThreadPool* tp)
: ThreadPoolTask(tp)
, f_(std::move(f)){};

/**
* Wait in the threadpool for this task to be ready.
*/
Status wait() {
if (tp_ == nullptr) {
throw std::runtime_error("Cannot wait, threadpool is not initialized.");
} else if (!f_.valid()) {
throw std::runtime_error("Cannot wait, task is invalid.");
} else {
return tp_->wait(*this);
}
}

/**
* Is this task valid. Wait can only be called on vaid tasks.
*/
bool valid() const noexcept {
return f_.valid();
}

private:
friend class ThreadPool;

/**
* Wait for input milliseconds for this task to be ready.
*/
std::future_status wait_for(
const std::chrono::milliseconds timeout_duration) const {
return f_.wait_for(timeout_duration);
}

/**
* Get the result of that task. Can only be used once. Only accessible from
* within the threadpool `wait` loop.
*/
Status get() {
return f_.get();
}

/**
* The encapsulated std::shared_future
*/
std::future<Status> f_;
};

/**
* @brief SharedTask class encapsulating std::shared_future. Like
* std::shared_future multiple threads can wait/get on the shared state
* multiple times. It can be both moved and copied.
*/
class SharedTask : public ThreadPoolTask {
public:
/**
* Default constructor
* @brief Default constructed SharedTask is possible but not valid().
*/
SharedTask() = default;

/**
* Constructor from std::future or std::shared_future
*/
SharedTask(auto&& f, ThreadPool* tp)
: ThreadPoolTask(tp)
, f_(std::forward<decltype(f)>(f)){};

/**
* Move constructor from a Task
*/
SharedTask(Task&& t) noexcept
: ThreadPoolTask(t.tp_)
, f_(std::move(t.f_)){};

/**
* Wait in the threadpool for this task to be ready.
*/
Status wait() {
if (tp_ == nullptr) {
throw std::runtime_error("Cannot wait, threadpool is not initialized.");
} else if (!f_.valid()) {
throw std::runtime_error("Cannot wait, shared task is invalid.");
} else {
return tp_->wait(*this);
}
}

/**
* Is this task valid. Wait can only be called on vaid tasks.
*/
bool valid() const noexcept {
return f_.valid();
}

private:
friend class ThreadPool;

/**
* Wait for input milliseconds for this task to be ready.
*/
std::future_status wait_for(
const std::chrono::milliseconds timeout_duration) const {
return f_.wait_for(timeout_duration);
}

/**
* Get the result of that task. Can be called multiple times from multiple
* threads. Only accessible from within the threadpool `wait` loop.
*/
Status get() {
return f_.get();
}

/**
* The encapsulated std::shared_future
*/
std::shared_future<Status> f_;
};

/* ********************************* */
/* CONSTRUCTORS & DESTRUCTORS */
Expand Down Expand Up @@ -108,7 +276,7 @@ class ThreadPool {
return std::apply(std::move(f), std::move(args));
});

std::future<R> future = task->get_future();
Task future(task->get_future(), this);

task_queue_.push(task);

Expand All @@ -127,6 +295,19 @@ class ThreadPool {
return async(std::forward<Fn>(f), std::forward<Args>(args)...);
}

/* Helper functions for lists that consists purely of Tasks */
Status wait_all(std::vector<Task>& tasks);
std::vector<Status> wait_all_status(std::vector<Task>& tasks);

/* Helper functions for lists that consists purely of SharedTasks */
Status wait_all(std::vector<SharedTask>& tasks);
std::vector<Status> wait_all_status(std::vector<SharedTask>& tasks);

/* ********************************* */
/* PRIVATE ATTRIBUTES */
/* ********************************* */

private:
/**
* Wait on all the given tasks to complete. This function is safe to call
* recursively and may execute pending tasks on the calling thread while
Expand All @@ -136,7 +317,7 @@ class ThreadPool {
* @return Status::Ok if all tasks returned Status::Ok, otherwise the first
* error status is returned
*/
Status wait_all(std::vector<Task>& tasks);
Status wait_all(std::vector<ThreadPoolTask*>& tasks);

/**
* Wait on all the given tasks to complete, returning a vector of their return
Expand All @@ -151,7 +332,7 @@ class ThreadPool {
* @param tasks Task list to wait on
* @return Vector of each task's Status.
*/
std::vector<Status> wait_all_status(std::vector<Task>& tasks);
std::vector<Status> wait_all_status(std::vector<ThreadPoolTask*>& tasks);

/**
* Wait on a single tasks to complete. This function is safe to call
Expand All @@ -162,13 +343,8 @@ class ThreadPool {
* @return Status::Ok if the task returned Status::Ok, otherwise the error
* status is returned
*/
Status wait(Task& task);

/* ********************************* */
/* PRIVATE ATTRIBUTES */
/* ********************************* */
Status wait(ThreadPoolTask& task);

private:
/** The worker thread routine */
void worker();

Expand Down
Loading

0 comments on commit 50f9d16

Please sign in to comment.