Skip to content

Commit

Permalink
lock-free async poggers
Browse files Browse the repository at this point in the history
  • Loading branch information
Mishura4 committed Aug 18, 2023
1 parent 4ad1e63 commit 6e6ba3b
Showing 1 changed file with 146 additions and 100 deletions.
246 changes: 146 additions & 100 deletions include/dpp/coro/async.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,82 @@
#include <utility>
#include <type_traits>
#include <functional>
#include <atomic>
#include <cstddef>

namespace dpp {

namespace detail {

/**
* @brief Empty struct used for overload resolution.
*/
struct empty_tag_t{};

/**
* @brief Represents the step an std::async is at.
*/
enum class async_state_t {
sent, /* Request was sent but not co_await-ed. handle is nullptr, result_storage is not constructed */
waiting, /* Request was co_await-ed. handle is valid, result_storage is not constructed */
done, /* Request was completed. handle is unknown, result_storage is valid */
dangling /* Request was never co_await-ed. */
};

/**
* @brief State of the async and its callback.
*
* Defined outside of dpp::async because this seems to work better with Intellisense.
*/
template <typename R>
struct async_callback_data {
/**
* @brief Number of references to this callback state.
*/
std::atomic<int> ref_count{1};

/**
* @brief State of the awaitable and the API callback
*/
std::atomic<detail::async_state_t> state = detail::async_state_t::sent;

/**
* @brief The stored result of the API call, stored as an array of bytes to directly construct with copy constructor
*/
alignas(R) std::array<std::byte, sizeof(R)> result_storage;

/**
* @brief Handle to the coroutine co_await-ing on this API call
*
* @see <a href="https://en.cppreference.com/w/cpp/coroutine/coroutine_handle">std::coroutine_handle</a>
*/
std_coroutine::coroutine_handle<> coro_handle = nullptr;

/**
* @brief Convenience function to construct the result in the storage and initialize its lifetime
*
* @warning This is only a convenience function, ONLY CALL THIS IN THE CALLBACK, before setting state to done.
*/
template <typename... Ts>
void construct_result(Ts&&... ts) {
// Standard-compliant type punning yay
std::construct_at<R>(reinterpret_cast<R *>(result_storage.data()), std::forward<Ts>(ts)...);
}

/**
* @brief Destructor.
*
* Also destroys the result if present.
*/
~async_callback_data() {
if (state.load() == detail::async_state_t::done) {
std::destroy_at<R>(reinterpret_cast<R *>(result_storage.data()));
}
}
};

}

struct confirmation_callback_t;

/**
Expand All @@ -41,7 +114,6 @@ struct confirmation_callback_t;
* @remark - This object's methods, other than constructors and operators, should not be called directly. It is designed to be used with coroutine keywords such as co_await.
* @remark - The coroutine may be resumed in another thread, do not rely on thread_local variables.
* @warning - This feature is EXPERIMENTAL. The API may change at any time and there may be bugs. Please report any to <a href="https://github.com/brainboxdotcc/DPP/issues">GitHub issues</a> or to the <a href="https://discord.gg/dpp">D++ Discord server</a>.
* @warning - Using co_await on this object more than once is undefined behavior.
* @tparam R The return type of the API call. Defaults to confirmation_callback_t
*/
template <typename R>
Expand All @@ -50,97 +122,35 @@ class async {
* @brief Ref-counted callback, contains the callback logic and manages the lifetime of the callback data over multiple threads.
*/
struct shared_callback {
struct empty_tag_t{};

/**
* @brief State of the async and its callback.
*/
struct callback_state {
enum state_t {
waiting,
done,
dangling
};

/**
* @brief Mutex to ensure the API result isn't set at the same time the coroutine is awaited and its value is checked, or the async is destroyed
*/
std::mutex mutex{};

/**
* @brief Number of references to this callback state.
*/
int ref_count;

/**
* @brief State of the awaitable and the API callback
*/
state_t state = waiting;

/**
* @brief The stored result of the API call
*/
std::optional<R> result = std::nullopt;

/**
* @brief Handle to the coroutine co_await-ing on this API call
*
* @see <a href="https://en.cppreference.com/w/cpp/coroutine/coroutine_handle">std::coroutine_handle</a>
*/
detail::std_coroutine::coroutine_handle<> coro_handle = nullptr;
};

callback_state *state;
detail::async_callback_data<R> *state = new detail::async_callback_data<R>;

/**
* @brief Callback function.
*
* Constructs the callback data, and if the coroutine was awaiting, resume it
* @param cback The result of the API call.
*/
void operator()(const R &cback) const {
std::unique_lock lock{get_mutex()};

if (state->state == callback_state::dangling) // Async object is gone - likely an exception killed it or it was never co_await-ed
return;
state->result = cback;
state->state = callback_state::done;
if (state->coro_handle) {
auto handle = state->coro_handle;
state->coro_handle = nullptr;
lock.unlock();
handle.resume();
state->construct_result(cback);
if (state->state.exchange(detail::async_state_t::done) == detail::async_state_t::waiting) {
state->coro_handle.resume();
}
}

/**
* @brief Main constructor, allocates a new callback_state object.
*/
shared_callback() : state{new callback_state{.ref_count = 1}} {}

shared_callback(empty_tag_t) noexcept : state{nullptr} {}
shared_callback() = default;

/**
* @brief Destructor. Releases the held reference and destroys if no other references exist.
* @brief Empty constructor, holds no state.
*/
~shared_callback() {
if (!state) // Moved-from object
return;

std::unique_lock lock{state->mutex};

if (state->ref_count) {
--(state->ref_count);
if (state->ref_count <= 0) {;
lock.unlock();
delete state;
}
}
}
explicit shared_callback(detail::empty_tag_t) noexcept : state{nullptr} {}

/**
* @brief Copy constructor. Takes shared ownership of the callback state, increasing the reference count.
*/
shared_callback(const shared_callback &other) {
shared_callback(const shared_callback &other) noexcept {
this->operator=(other);
}

Expand All @@ -151,12 +161,23 @@ class async {
this->operator=(std::move(other));
}

/**
* @brief Destructor. Releases the held reference and destroys if no other references exist.
*/
~shared_callback() {
if (!state) // Moved-from object
return;

auto count = state->ref_count.fetch_sub(1);
if (count == 0) {
delete state;
}
}

/**
* @brief Copy assignment. Takes shared ownership of the callback state, increasing the reference count.
*/
shared_callback &operator=(const shared_callback &other) noexcept {
std::lock_guard lock{other.get_mutex()};

state = other.state;
++state->ref_count;
return *this;
Expand All @@ -166,36 +187,46 @@ class async {
* @brief Move assignment. Transfers ownership from another object, leaving intact the reference count. The other object releases the callback state.
*/
shared_callback &operator=(shared_callback &&other) noexcept {
std::lock_guard lock{other.get_mutex()};

state = std::exchange(other.state, nullptr);
return *this;
}

/**
* @brief Function called by the async when it is destroyed when it was never co_awaited, signals to the callback to abort.
*/
void set_dangling() {
void set_dangling() noexcept {
if (!state) // moved-from object
return;
std::lock_guard lock{get_mutex()};
/*
If the state is sent but not awaited, set it to dangling, in a relaxed memory order (we don't care if the callback thread actually sees it).
"sent" is the only state we care about to set it to dangling, as if it's done it's not dangling, and if it's waiting... Something went seriously wrong and shouldn't be happening.
*/
auto expected = detail::async_state_t::sent;
state->state.compare_exchange_strong(expected, detail::async_state_t::dangling, std::memory_order_seq_cst, std::memory_order_relaxed);
}

if (state->state == callback_state::waiting)
state->state = callback_state::dangling;
bool done(std::memory_order order = std::memory_order_seq_cst) const noexcept {
return (state->state.load(order) == detail::async_state_t::done);
}

/**
* @brief Convenience function to get the shared callback state's mutex.
* @brief Convenience function to get the shared callback state's result.
*
* @warning It is UB to call this on a callback whose state is anything else but async_state_t::done.
*/
std::mutex &get_mutex() const {
return (state->mutex);
R &get_result() noexcept {
assert(state && done());
return (*reinterpret_cast<R *>(state->result_storage.data()));
}

/**
* @brief Convenience function to get the shared callback state's result.
*
* @warning It is UB to call this on a callback whose state is anything else but async_state_t::done.
*/
std::optional<R> &get_result() const {
return (state->result);
const R &get_result() const noexcept {
assert(state && done());
return (*reinterpret_cast<R *>(state->result_storage.data()));
}
};

Expand Down Expand Up @@ -246,7 +277,7 @@ class async {
/**
* @brief Construct an empty async. Using `co_await` on an empty async is undefined behavior.
*/
async() noexcept : api_callback{typename shared_callback::empty_tag_t{}} {}
async() noexcept : api_callback{detail::empty_tag_t{}} {}

/**
* @brief Destructor. If any callback is pending it will be aborted.
Expand Down Expand Up @@ -294,9 +325,7 @@ class async {
* @return bool Whether we already have the result of the API call or not
*/
bool await_ready() noexcept {
std::lock_guard lock{api_callback.get_mutex()};

return api_callback.get_result().has_value();
return api_callback.done();
}

/**
Expand All @@ -307,25 +336,42 @@ class async {
* @remark Do not call this manually, use the co_await keyword instead.
* @param handle The handle to the coroutine co_await-ing and being suspended
*/
template <typename T>
bool await_suspend(detail::std_coroutine::coroutine_handle<T> caller) {
std::lock_guard lock{api_callback.get_mutex()};

if (api_callback.get_result().has_value())
return false; // immediately resume the coroutine as we already have the result of the api call
bool await_suspend(detail::std_coroutine::coroutine_handle<> caller) {
auto sent = detail::async_state_t::sent;
api_callback.state->coro_handle = caller;
return true; // suspend the caller, the callback will resume it
return api_callback.state->state.compare_exchange_strong(sent, detail::async_state_t::waiting); // true (suspend) if `sent` was replaced with `waiting` -- false (resume) if the value was not `sent` (`done` is the only other option)
}

/**
* @brief Function called by the standard library when the async is resumed. Its return value is what the whole co_await expression evaluates to
*
* @remark Do not call this manually, use the co_await keyword instead.
* @return R& The result of the API call as an lvalue reference.
*/
R& await_resume() & noexcept {
return api_callback.get_result();
}


/**
* @brief Function called by the standard library when the async is resumed. Its return value is what the whole co_await expression evaluates to
*
* @remark Do not call this manually, use the co_await keyword instead.
* @return const R& The result of the API call as a const lvalue reference.
*/
const R& await_resume() const& noexcept {
return api_callback.get_result();
}


/**
* @brief Function called by the standard library when the async is resumed. Its return value is what the whole co_await expression evaluates to
*
* @remark Do not call this manually, use the co_await keyword instead.
* @return R The result of the API call.
* @return R&& The result of the API call as an rvalue reference.
*/
R await_resume() {
// no locking needed here as the callback has already executed
return std::move(*api_callback.get_result());
R&& await_resume() && noexcept {
return std::move(api_callback.get_result());
}
};

Expand Down

0 comments on commit 6e6ba3b

Please sign in to comment.