Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache the set of active transfers in receive arbiter #319

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions include/receive_arbiter.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#pragma once

Check failure on line 1 in include/receive_arbiter.h

View workflow job for this annotation

GitHub Actions / celerity-ci-report-pr

include/receive_arbiter.h#L1

File is not formatted according to `.clang-format`.

#include "async_event.h"
#include "communicator.h"
Expand All @@ -20,6 +20,8 @@
struct incoming_region_fragment {
detail::box<3> box;
async_event communication; ///< async communicator event for receiving this fragment

bool is_complete() const { return communication.is_complete(); }
};

/// State for a single incomplete `receive` operation or a `begin_split_receive` / `await_split_receive_subregion` tree.
Expand All @@ -28,9 +30,11 @@
box<3> allocated_box;
region<3> incomplete_region;
std::vector<incoming_region_fragment> incoming_fragments;
bool may_await_subregion;

region_request(region<3> requested_region, void* const allocation, const box<3>& allocated_bounding_box)
: allocation(allocation), allocated_box(allocated_bounding_box), incomplete_region(std::move(requested_region)) {}
region_request(region<3> requested_region, void* const allocation, const box<3>& allocated_bounding_box, const bool may_await_subregion)
: allocation(allocation), allocated_box(allocated_bounding_box), incomplete_region(std::move(requested_region)),
may_await_subregion(may_await_subregion) {}
bool do_complete();
};

Expand Down Expand Up @@ -143,9 +147,12 @@
/// the same transfer id that did not temporally overlap with the original ones.
std::unordered_map<transfer_id, receive_arbiter_detail::transfer> m_transfers;

/// Initiates a new `region_request` for which the caller can construct events to await either the entire region or sub-regions.
/// Cache for all transfer ids in m_transfers that are not unassigned_transfers. Bounds complexity of iterating to poll all transfer events.
std::vector<transfer_id> m_active_transfers;

/// Initiates a new `region_request` for which the caller can construct events to await either the entire region or sub-regions (may_await_subregion = true).
receive_arbiter_detail::stable_region_request& initiate_region_request(
const transfer_id& trid, const region<3>& request, void* allocation, const box<3>& allocated_box, size_t elem_size);
const transfer_id& trid, const region<3>& request, void* allocation, const box<3>& allocated_box, size_t elem_size, bool may_await_subregion);

/// Updates the state of an active `region_request` from receiving an inbound pilot.
void handle_region_request_pilot(receive_arbiter_detail::region_request& rr, const inbound_pilot& pilot, size_t elem_size);
Expand Down
6 changes: 0 additions & 6 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@

namespace celerity::detail::utils {

/// Like std::move, but move-constructs the result so it does not reference the argument after returning.
template <typename T>
T take(T& from) {
return std::move(from);
}

template <typename T, typename P>
bool isa(const P* p) {
return dynamic_cast<const T*>(p) != nullptr;
Expand Down
85 changes: 47 additions & 38 deletions src/receive_arbiter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,30 +64,30 @@ class gather_receive_event final : public async_event_impl {
};

bool region_request::do_complete() {
const auto complete_fragment = [&](const incoming_region_fragment& fragment) {
// Fast path: Avoid polling the entire fragment set when we know that neither the request as a whole nor any subregion-request will complete at this time
if(!may_await_subregion && !incoming_fragments.empty() && !incoming_fragments.front().communication.is_complete()) return false;

std::erase_if(incoming_fragments, [&](const incoming_region_fragment& fragment) {
if(!fragment.communication.is_complete()) return false;
incomplete_region = region_difference(incomplete_region, fragment.box);
return true;
};
incoming_fragments.erase(std::remove_if(incoming_fragments.begin(), incoming_fragments.end(), complete_fragment), incoming_fragments.end());
});
assert(!incomplete_region.empty() || incoming_fragments.empty());
return incomplete_region.empty();
}

bool multi_region_transfer::do_complete() {
const auto complete_request = [](stable_region_request& rr) { return rr->do_complete(); };
active_requests.erase(std::remove_if(active_requests.begin(), active_requests.end(), complete_request), active_requests.end());
std::erase_if(active_requests, [](stable_region_request& rr) { return rr->do_complete(); });
return active_requests.empty() && unassigned_pilots.empty();
}

bool gather_request::do_complete() {
const auto complete_chunk = [&](const incoming_gather_chunk& chunk) {
std::erase_if(incoming_chunks, [&](const incoming_gather_chunk& chunk) {
if(!chunk.communication.is_complete()) return false;
assert(num_incomplete_chunks > 0);
num_incomplete_chunks -= 1;
return true;
};
incoming_chunks.erase(std::remove_if(incoming_chunks.begin(), incoming_chunks.end(), complete_chunk), incoming_chunks.end());
});
return num_incomplete_chunks == 0;
}

Expand All @@ -109,45 +109,49 @@ receive_arbiter::receive_arbiter(communicator& comm) : m_comm(&comm), m_num_node

receive_arbiter::~receive_arbiter() { assert(std::uncaught_exceptions() > 0 || m_transfers.empty()); }

receive_arbiter_detail::stable_region_request& receive_arbiter::initiate_region_request(
const transfer_id& trid, const region<3>& request, void* const allocation, const box<3>& allocated_box, const size_t elem_size) {
receive_arbiter_detail::stable_region_request& receive_arbiter::initiate_region_request(const transfer_id& trid, const region<3>& request,
void* const allocation, const box<3>& allocated_box, const size_t elem_size, const bool may_await_subregion) //
{
assert(allocated_box.covers(bounding_box(request)));

// Ensure there is a multi_region_transfer present - if there is none, create it by consuming unassigned pilots
multi_region_transfer* mrt = nullptr;
if(const auto entry = m_transfers.find(trid); entry != m_transfers.end()) {
matchbox::match(
entry->second, //
[&](unassigned_transfer& ut) { mrt = &entry->second.emplace<multi_region_transfer>(elem_size, utils::take(ut.pilots)); },
[&](multi_region_transfer& existing_mrt) { mrt = &existing_mrt; },
[&](gather_transfer& gt) { utils::panic("calling receive_arbiter::begin_receive on an active gather transfer"); });
} else {
mrt = &m_transfers[trid].emplace<multi_region_transfer>(elem_size);
}
auto& transfer = m_transfers[trid]; // allow default-insert as unassigned_transfer
auto& mrt = matchbox::match(
transfer,
[&](unassigned_transfer& ut) -> multi_region_transfer& {
auto pilots = std::move(ut.pilots);
m_active_transfers.push_back(trid);
return transfer.emplace<multi_region_transfer>(elem_size, std::move(pilots));
},
[&](multi_region_transfer& existing_mrt) -> multi_region_transfer& { //
return existing_mrt;
},
[&](gather_transfer& gt) -> multi_region_transfer& { //
utils::panic("calling receive_arbiter::begin_receive on an active gather transfer");
});

// Add a new region_request to the `mrt` (transfers have transfer_id granularity, but there might be multiple receives from independent range mappers
assert(std::all_of(mrt->active_requests.begin(), mrt->active_requests.end(),
assert(std::all_of(mrt.active_requests.begin(), mrt.active_requests.end(),
[&](const stable_region_request& rr) { return region_intersection(rr->incomplete_region, request).empty(); }));
auto& rr = mrt->active_requests.emplace_back(std::make_shared<region_request>(request, allocation, allocated_box));
auto& rr = mrt.active_requests.emplace_back(std::make_shared<region_request>(request, allocation, allocated_box, may_await_subregion));

// If the new region_request matches any of the still-unassigned pilots associated with `mrt`, immediately initiate the appropriate payload-receives
const auto assign_pilot = [&](const inbound_pilot& pilot) {
std::erase_if(mrt.unassigned_pilots, [&](const inbound_pilot& pilot) {
assert((region_intersection(rr->incomplete_region, pilot.message.box) != pilot.message.box)
== region_intersection(rr->incomplete_region, pilot.message.box).empty());
if(region_intersection(rr->incomplete_region, pilot.message.box) == pilot.message.box) {
handle_region_request_pilot(*rr, pilot, elem_size);
return true;
}
return false;
};
mrt->unassigned_pilots.erase(std::remove_if(mrt->unassigned_pilots.begin(), mrt->unassigned_pilots.end(), assign_pilot), mrt->unassigned_pilots.end());
});

return rr;
}

void receive_arbiter::begin_split_receive(
const transfer_id& trid, const region<3>& request, void* const allocation, const box<3>& allocated_box, const size_t elem_size) {
initiate_region_request(trid, request, allocation, allocated_box, elem_size);
initiate_region_request(trid, request, allocation, allocated_box, elem_size, true /* may_await_subregion */);
}

async_event receive_arbiter::await_split_receive_subregion(const transfer_id& trid, const region<3>& subregion) {
Expand All @@ -168,16 +172,20 @@ async_event receive_arbiter::await_split_receive_subregion(const transfer_id& tr
#endif

// If the transfer (by transfer_id) as a whole has not completed yet but the subregion is, this "await" also completes immediately.
const auto req_it = std::find_if(mrt.active_requests.begin(), mrt.active_requests.end(),
const auto rr_it = std::find_if(mrt.active_requests.begin(), mrt.active_requests.end(),
[&](const stable_region_request& rr) { return !region_intersection(rr->incomplete_region, subregion).empty(); });
if(req_it == mrt.active_requests.end()) { return make_complete_event(); }
if(rr_it == mrt.active_requests.end()) { return make_complete_event(); }

return make_async_event<subregion_receive_event>(*req_it, subregion);
auto& rr = *rr_it;
assert(rr->may_await_subregion && "attempting await_split_receive_subregion() on region that was not initiated with begin_split_receive()");
return make_async_event<subregion_receive_event>(rr, subregion);
}

async_event receive_arbiter::receive(
const transfer_id& trid, const region<3>& request, void* const allocation, const box<3>& allocated_box, const size_t elem_size) {
return make_async_event<region_receive_event>(initiate_region_request(trid, request, allocation, allocated_box, elem_size));
const transfer_id& trid, const region<3>& request, void* const allocation, const box<3>& allocated_box, const size_t elem_size) //
{
auto& rr = initiate_region_request(trid, request, allocation, allocated_box, elem_size, false /* may_await_subregion */);
return make_async_event<region_receive_event>(rr);
}

async_event receive_arbiter::gather_receive(const transfer_id& trid, void* const allocation, const size_t node_chunk_size) {
Expand All @@ -194,19 +202,20 @@ async_event receive_arbiter::gather_receive(const transfer_id& trid, void* const
// Otherwise, we insert the transfer as pending and wait for the first pilots to arrive.
m_transfers.emplace(trid, gather_transfer{gr});
}
m_active_transfers.push_back(trid);

return make_async_event<gather_receive_event>(gr);
}

void receive_arbiter::poll_communicator() {
// Try completing all pending payload sends / receives by polling their communicator events
for(auto entry = m_transfers.begin(); entry != m_transfers.end();) {
if(std::visit([](auto& transfer) { return transfer.do_complete(); }, entry->second)) {
entry = m_transfers.erase(entry);
} else {
++entry;
}
}
std::erase_if(m_active_transfers, [&](const transfer_id& trid) {
const auto entry = m_transfers.find(trid);
assert(entry != m_transfers.end());
const bool is_complete = std::visit([](auto& transfer) { return transfer.do_complete(); }, entry->second);
if(is_complete) { m_transfers.erase(entry); }
return is_complete;
});

for(const auto& pilot : m_comm->poll_inbound_pilots()) {
if(const auto entry = m_transfers.find(pilot.message.transfer_id); entry != m_transfers.end()) {
Expand Down
8 changes: 5 additions & 3 deletions test/receive_arbiter_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,9 +219,11 @@ TEST_CASE("receive_arbiter aggregates receives from multiple incoming fragments"
REQUIRE(receive.has_value());
CHECK(receive->is_complete());

// it is legal to `await` a transfer that has already been completed and is not tracked by the receive_arbiter anymore
CHECK(ra.await_split_receive_subregion(trid, requested_regions[0]).is_complete());
CHECK(ra.await_split_receive_subregion(trid, incoming_fragments[0]).is_complete());
if(receive_method == "split_await") {
// it is legal to `await` a transfer that has already been completed and is not tracked by the receive_arbiter anymore
CHECK(ra.await_split_receive_subregion(trid, requested_regions[0]).is_complete());
CHECK(ra.await_split_receive_subregion(trid, incoming_fragments[0]).is_complete());
}

std::vector<int> expected_allocation(alloc_box.get_range().size());
for(size_t which = 0; which < incoming_fragments.size(); ++which) {
Expand Down
Loading