-
Notifications
You must be signed in to change notification settings - Fork 4.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #44901 from makortel/edmAsync
Introduce edm::Async service, and use it in CUDA and Alpaka modules
- Loading branch information
Showing
24 changed files
with
1,052 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#ifndef FWCore_Concurrency_Async_h | ||
#define FWCore_Concurrency_Async_h | ||
|
||
#include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h" | ||
#include "FWCore/Concurrency/interface/WaitingThreadPool.h" | ||
|
||
namespace edm { | ||
// All member functions are thread safe | ||
class Async { | ||
public: | ||
Async() = default; | ||
virtual ~Async() noexcept; | ||
|
||
// prevent copying and moving | ||
Async(Async const&) = delete; | ||
Async(Async&&) = delete; | ||
Async& operator=(Async const&) = delete; | ||
Async& operator=(Async&&) = delete; | ||
|
||
template <typename F, typename G> | ||
void runAsync(WaitingTaskWithArenaHolder holder, F&& func, G&& errorContextFunc) { | ||
ensureAllowed(); | ||
pool_.runAsync(std::move(holder), std::forward<F>(func), std::forward<G>(errorContextFunc)); | ||
} | ||
|
||
protected: | ||
virtual void ensureAllowed() const = 0; | ||
|
||
private: | ||
WaitingThreadPool pool_; | ||
}; | ||
} // namespace edm | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
#ifndef FWCore_Concurrency_WaitingThreadPool_h | ||
#define FWCore_Concurrency_WaitingThreadPool_h | ||
|
||
#include "FWCore/Utilities/interface/ConvertException.h" | ||
#include "FWCore/Utilities/interface/ReusableObjectHolder.h" | ||
#include "FWCore/Concurrency/interface/WaitingTaskWithArenaHolder.h" | ||
|
||
#include <condition_variable> | ||
#include <mutex> | ||
#include <thread> | ||
|
||
namespace edm { | ||
namespace impl { | ||
class WaitingThread { | ||
public: | ||
WaitingThread(); | ||
~WaitingThread() noexcept; | ||
|
||
WaitingThread(WaitingThread const&) = delete; | ||
WaitingThread& operator=(WaitingThread&&) = delete; | ||
WaitingThread(WaitingThread&&) = delete; | ||
WaitingThread& operator=(WaitingThread const&) = delete; | ||
|
||
template <typename F, typename G> | ||
void run(WaitingTaskWithArenaHolder holder, | ||
F&& func, | ||
G&& errorContextFunc, | ||
std::shared_ptr<WaitingThread> thisPtr) { | ||
std::unique_lock lk(mutex_); | ||
func_ = [holder = std::move(holder), | ||
func = std::forward<F>(func), | ||
errorContext = std::forward<G>(errorContextFunc)]() mutable { | ||
try { | ||
convertException::wrap([&func]() { func(); }); | ||
} catch (cms::Exception& e) { | ||
e.addContext(errorContext()); | ||
holder.doneWaiting(std::current_exception()); | ||
} | ||
}; | ||
thisPtr_ = std::move(thisPtr); | ||
cond_.notify_one(); | ||
} | ||
|
||
private: | ||
void stopThread() { | ||
std::unique_lock lk(mutex_); | ||
stopThread_ = true; | ||
cond_.notify_one(); | ||
} | ||
|
||
void threadLoop() noexcept; | ||
|
||
std::thread thread_; | ||
std::mutex mutex_; | ||
std::condition_variable cond_; | ||
CMS_THREAD_GUARD(mutex_) std::function<void()> func_; | ||
// The purpose of thisPtr_ is to keep the WaitingThread object | ||
// outside of the WaitingThreadPool until the func_ has returned. | ||
CMS_THREAD_GUARD(mutex_) std::shared_ptr<WaitingThread> thisPtr_; | ||
CMS_THREAD_GUARD(mutex_) bool stopThread_ = false; | ||
}; | ||
} // namespace impl | ||
|
||
// Provides a mechanism to run the function 'func' asynchronously, | ||
// i.e. without the calling thread to wait for the func() to return. | ||
// The func should do as little work (outside of the TBB threadpool) | ||
// as possible. The func must terminate eventually. The intended use | ||
// case are blocking synchronization calls with external entities, | ||
// where the calling thread is suspended while waiting. | ||
// | ||
// The func() is run in a thread that belongs to a separate pool of | ||
// threads than the calling thread. Remotely similar to | ||
// std::async(), but instead of dealing with std::futures, takes an | ||
// edm::WaitingTaskWithArenaHolder object, that is signaled upon the | ||
// func() returning or throwing an exception. | ||
// | ||
// The caller is responsible for keeping the WaitingThreadPool | ||
// object alive at least as long as all asynchronous calls finish. | ||
class WaitingThreadPool { | ||
public: | ||
WaitingThreadPool() = default; | ||
WaitingThreadPool(WaitingThreadPool const&) = delete; | ||
WaitingThreadPool& operator=(WaitingThreadPool const&) = delete; | ||
WaitingThreadPool(WaitingThreadPool&&) = delete; | ||
WaitingThreadPool& operator=(WaitingThreadPool&&) = delete; | ||
|
||
/** | ||
* \param holder WaitingTaskWithArenaHolder object to signal the completion of 'func' | ||
* \param func Function to run in a separate thread | ||
* \param errorContextFunc Function returning a string-like object | ||
* that is added to the context of | ||
* cms::Exception in case 'func' throws an | ||
* exception | ||
*/ | ||
template <typename F, typename G> | ||
void runAsync(WaitingTaskWithArenaHolder holder, F&& func, G&& errorContextFunc) { | ||
auto thread = pool_.makeOrGet([]() { return std::make_unique<impl::WaitingThread>(); }); | ||
thread->run(std::move(holder), std::forward<F>(func), std::forward<G>(errorContextFunc), std::move(thread)); | ||
} | ||
|
||
private: | ||
edm::ReusableObjectHolder<impl::WaitingThread> pool_; | ||
}; | ||
} // namespace edm | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
#include "FWCore/Concurrency/interface/Async.h" | ||
|
||
namespace edm { | ||
Async::~Async() noexcept = default; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
#include "FWCore/Concurrency/interface/WaitingThreadPool.h" | ||
|
||
#include <cassert> | ||
#include <string_view> | ||
|
||
#include <pthread.h> | ||
|
||
namespace edm::impl { | ||
WaitingThread::WaitingThread() { | ||
thread_ = std::thread(&WaitingThread::threadLoop, this); | ||
static constexpr auto poolName = "edm async pool"; | ||
// pthread_setname_np() string length is limited to 16 characters, | ||
// including the null termination | ||
static_assert(std::string_view(poolName).size() < 16); | ||
|
||
int err = pthread_setname_np(thread_.native_handle(), poolName); | ||
// According to the glibc documentation, the only error | ||
// pthread_setname_np() can return is about the argument C-string | ||
// being too long. We already check above the C-string is shorter | ||
// than the limit was at the time of writing. In order to capture | ||
// if the limit shortens, or other error conditions get added, | ||
// let's assert() anyway (exception feels overkill) | ||
assert(err == 0); | ||
} | ||
|
||
WaitingThread::~WaitingThread() noexcept { | ||
// When we are shutting down, we don't care about any possible | ||
// system errors anymore | ||
CMS_SA_ALLOW try { | ||
stopThread(); | ||
thread_.join(); | ||
} catch (...) { | ||
} | ||
} | ||
|
||
void WaitingThread::threadLoop() noexcept { | ||
std::unique_lock lk(mutex_); | ||
|
||
while (true) { | ||
cond_.wait(lk, [this]() { return static_cast<bool>(func_) or stopThread_; }); | ||
if (stopThread_) { | ||
// There should be no way to stop the thread when it as the | ||
// func_ assigned, but let's make sure | ||
assert(not thisPtr_); | ||
break; | ||
} | ||
func_(); | ||
// Must return this WaitingThread to the ReusableObjectHolder in | ||
// the WaitingThreadPool before resettting func_ (that holds the | ||
// WaitingTaskWithArenaHolder, that enables the progress in the | ||
// TBB thread pool) in order to meet the requirement of | ||
// ReusableObjectHolder destructor that there are no outstanding | ||
// objects. | ||
thisPtr_.reset(); | ||
decltype(func_)().swap(func_); | ||
} | ||
} | ||
} // namespace edm::impl |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
#include "catch.hpp" | ||
|
||
#include <atomic> | ||
|
||
#include "oneapi/tbb/global_control.h" | ||
|
||
#include "FWCore/Concurrency/interface/chain_first.h" | ||
#include "FWCore/Concurrency/interface/FinalWaitingTask.h" | ||
#include "FWCore/Concurrency/interface/Async.h" | ||
|
||
namespace { | ||
constexpr char const* errorContext() { return "AsyncServiceTest"; } | ||
|
||
class AsyncServiceTest : public edm::Async { | ||
public: | ||
enum class State { kAllowed, kDisallowed, kShutdown }; | ||
|
||
AsyncServiceTest() = default; | ||
|
||
void setAllowed(bool allowed) noexcept { allowed_ = allowed; } | ||
|
||
private: | ||
void ensureAllowed() const final { | ||
if (not allowed_) { | ||
throw std::runtime_error("Calling run in this context is not allowed"); | ||
} | ||
} | ||
|
||
std::atomic<bool> allowed_ = true; | ||
}; | ||
} // namespace | ||
|
||
TEST_CASE("Test Async", "[edm::Async") { | ||
// Using parallelism 2 here because otherwise the | ||
// tbb::task_arena::enqueue() in WaitingTaskWithArenaHolder will | ||
// start a new TBB thread that "inherits" the name from the | ||
// WaitingThreadPool thread. | ||
oneapi::tbb::global_control control(oneapi::tbb::global_control::max_allowed_parallelism, 2); | ||
|
||
SECTION("Normal operation") { | ||
AsyncServiceTest service; | ||
std::atomic<int> count{0}; | ||
|
||
oneapi::tbb::task_group group; | ||
edm::FinalWaitingTask waitTask{group}; | ||
|
||
{ | ||
using namespace edm::waiting_task::chain; | ||
auto h1 = first([&service, &count](edm::WaitingTaskHolder h) { | ||
edm::WaitingTaskWithArenaHolder h2(std::move(h)); | ||
service.runAsync( | ||
h2, [&count]() { ++count; }, errorContext); | ||
}) | | ||
lastTask(edm::WaitingTaskHolder(group, &waitTask)); | ||
|
||
auto h2 = first([&service, &count](edm::WaitingTaskHolder h) { | ||
edm::WaitingTaskWithArenaHolder h2(std::move(h)); | ||
service.runAsync( | ||
h2, [&count]() { ++count; }, errorContext); | ||
}) | | ||
lastTask(edm::WaitingTaskHolder(group, &waitTask)); | ||
h2.doneWaiting(std::exception_ptr()); | ||
h1.doneWaiting(std::exception_ptr()); | ||
} | ||
waitTask.waitNoThrow(); | ||
REQUIRE(count.load() == 2); | ||
REQUIRE(waitTask.done()); | ||
REQUIRE(not waitTask.exceptionPtr()); | ||
} | ||
|
||
SECTION("Disallowed") { | ||
AsyncServiceTest service; | ||
std::atomic<int> count{0}; | ||
|
||
oneapi::tbb::task_group group; | ||
edm::FinalWaitingTask waitTask{group}; | ||
|
||
{ | ||
using namespace edm::waiting_task::chain; | ||
auto h = first([&service, &count](edm::WaitingTaskHolder h) { | ||
edm::WaitingTaskWithArenaHolder h2(std::move(h)); | ||
service.runAsync( | ||
h2, [&count]() { ++count; }, errorContext); | ||
service.setAllowed(false); | ||
}) | | ||
then([&service, &count](edm::WaitingTaskHolder h) { | ||
edm::WaitingTaskWithArenaHolder h2(std::move(h)); | ||
service.runAsync( | ||
h2, [&count]() { ++count; }, errorContext); | ||
}) | | ||
lastTask(edm::WaitingTaskHolder(group, &waitTask)); | ||
h.doneWaiting(std::exception_ptr()); | ||
} | ||
waitTask.waitNoThrow(); | ||
REQUIRE(count.load() == 1); | ||
REQUIRE(waitTask.done()); | ||
REQUIRE(waitTask.exceptionPtr()); | ||
REQUIRE_THROWS_WITH(std::rethrow_exception(waitTask.exceptionPtr()), | ||
Catch::Contains("Calling run in this context is not allowed")); | ||
} | ||
} |
Oops, something went wrong.