Skip to content

Commit

Permalink
Merge pull request #2088 from DARMA-tasking/2087-fix-curr_node-on-res…
Browse files Browse the repository at this point in the history
…tart-reader

2087 fix curr node on restart reader
  • Loading branch information
lifflander authored Feb 10, 2023
2 parents 5444a54 + 57b0117 commit 54ad63d
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 5 deletions.
61 changes: 58 additions & 3 deletions src/vt/vrt/collection/balance/lb_data_restart_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,68 @@ void LBDataRestartReader::readLBData(std::string const& file) {
determinePhasesToMigrate();
}

/// Handler for a DepartMsg: files the message under its phase and element in
/// the coordination table, then checks whether the matching arrival has
/// already been recorded.
///
/// \param[in] msg the departure message; its phase and elm select the slot
void LBDataRestartReader::departing(DepartMsg* msg) {
  auto promoted = promoteMsg(msg);

  // Fetch (or create) the slot a single time and reuse it for both steps
  Coord& slot = coordinate_[msg->phase][msg->elm];
  slot.depart = promoted;

  checkBothEnds(slot);
}

/// Handler for an ArriveMsg: files the message under its phase and element in
/// the coordination table, then checks whether the matching departure has
/// already been recorded.
///
/// \param[in] msg the arrival message; its phase and elm select the slot
void LBDataRestartReader::arriving(ArriveMsg* msg) {
  auto promoted = promoteMsg(msg);

  // Fetch (or create) the slot a single time and reuse it for both steps
  Coord& slot = coordinate_[msg->phase][msg->elm];
  slot.arrive = promoted;

  checkBothEnds(slot);
}

/// Handler for an UpdateMsg: overwrites the curr_node of the element stored in
/// history_ for the phase named in the message.
///
/// The element is expected to be present already; this is asserted.
///
/// \param[in] msg carries the phase, the element to find, and the new curr_node
void LBDataRestartReader::update(UpdateMsg* msg) {
  // Resolve the per-phase set once rather than re-hashing the phase on every
  // statement below.
  auto& phase_elms = history_[msg->phase];

  auto iter = phase_elms.find(msg->elm);
  vtAssert(iter != phase_elms.end(), "Must exist");

  // Elements of a std::set cannot be modified in place, so take a copy, patch
  // it, and replace the stored entry with the patched copy.
  auto elm = *iter;
  elm.curr_node = msg->curr_node;
  phase_elms.erase(iter);
  phase_elms.insert(elm);
}

/// Once both the arrival and the departure message of a slot are present,
/// send an UpdateMsg to the arrival node carrying the departure node together
/// with the phase and element taken from the arrival message.
///
/// Does nothing while either end is still missing.
///
/// \param[in] coord the arrive/depart pair to inspect
void LBDataRestartReader::checkBothEnds(Coord& coord) {
  bool const both_known = coord.arrive != nullptr and coord.depart != nullptr;
  if (not both_known) {
    return;
  }

  auto const target_node = coord.arrive->arrive_node;
  auto const origin_node = coord.depart->depart_node;

  proxy_[target_node].send<UpdateMsg, &LBDataRestartReader::update>(
    origin_node, coord.arrive->phase, coord.arrive->elm
  );
}

void LBDataRestartReader::determinePhasesToMigrate() {
std::vector<bool> local_changed_distro;
local_changed_distro.resize(num_phases_ - 1);

for (PhaseType i = 0; i < num_phases_ - 1; ++i) {
local_changed_distro[i] = history_[i] != history_[i+1];
}
auto const this_node = theContext()->getNode();

runInEpochCollective("LBDataRestartReader::updateLocations", [&]{
for (PhaseType i = 0; i < num_phases_ - 1; ++i) {
local_changed_distro[i] = history_[i] != history_[i+1];
if (local_changed_distro[i]) {
std::set<ElementIDStruct> departing, arriving;

std::set_difference(
history_[i+1].begin(), history_[i+1].end(),
history_[i].begin(), history_[i].end(),
std::inserter(arriving, arriving.begin())
);

std::set_difference(
history_[i].begin(), history_[i].end(),
history_[i+1].begin(), history_[i+1].end(),
std::inserter(departing, departing.begin())
);

for (auto&& d : departing) {
proxy_[d.getHomeNode()].send<DepartMsg, &LBDataRestartReader::departing>(this_node, i+1, d);
}
for (auto&& a : arriving) {
proxy_[a.getHomeNode()].send<ArriveMsg, &LBDataRestartReader::arriving>(this_node, i+1, a);
}
}
}
});

runInEpochCollective("LBDataRestartReader::computeDistributionChanges", [&]{
auto cb = theCB()->makeBcast<
Expand Down
50 changes: 50 additions & 0 deletions src/vt/vrt/collection/balance/lb_data_restart_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,56 @@ struct LBDataRestartReader : runtime::component::Component<LBDataRestartReader>
/// History of mapping that was read in from the data files
std::unordered_map<PhaseType, std::set<ElementIDStruct>> history_;

struct DepartMsg : vt::Message {
DepartMsg(NodeType in_depart_node, PhaseType in_phase, ElementIDStruct in_elm)
: depart_node(in_depart_node),
phase(in_phase),
elm(in_elm)
{ }

NodeType depart_node = uninitialized_destination;
PhaseType phase = no_lb_phase;
ElementIDStruct elm;
};

struct ArriveMsg : vt::Message {
ArriveMsg(NodeType in_arrive_node, PhaseType in_phase, ElementIDStruct in_elm)
: arrive_node(in_arrive_node),
phase(in_phase),
elm(in_elm)
{ }

NodeType arrive_node = uninitialized_destination;
PhaseType phase = no_lb_phase;
ElementIDStruct elm;
};

struct UpdateMsg : vt::Message {
UpdateMsg(NodeType in_curr_node, PhaseType in_phase, ElementIDStruct in_elm)
: curr_node(in_curr_node),
phase(in_phase),
elm(in_elm)
{ }

NodeType curr_node = uninitialized_destination;
PhaseType phase = no_lb_phase;
ElementIDStruct elm;
};

/// Pairs the arrival and departure messages recorded for one element in one
/// phase; each pointer starts out null and is filled in by its handler.
struct Coord {
/// Arrival message, null until arriving() stores one
MsgSharedPtr<ArriveMsg> arrive = nullptr;
/// Departure message, null until departing() stores one
MsgSharedPtr<DepartMsg> depart = nullptr;
};

/// Store a departure message in coordinate_ and check for its matching arrival
void departing(DepartMsg* msg);
/// Store an arrival message in coordinate_ and check for its matching departure
void arriving(ArriveMsg* msg);
/// Replace the curr_node of the element recorded in history_ for msg->phase
void update(UpdateMsg* msg);
/// Send an UpdateMsg to the arrival node once both ends of coord are present
void checkBothEnds(Coord& coord);

/// Arrival/departure messages collected so far, keyed by phase, then element
std::unordered_map<
PhaseType, std::unordered_map<ElementIDStruct, Coord>
> coordinate_;

/// Number of phases read in
std::size_t num_phases_ = 0;
};
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/lb/test_offlinelb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ TEST_F(TestOfflineLB, test_offlinelb_1) {
}

for (int i = 0; i < len; i++) {
auto pid = elm::ElmIDBits::createCollectionImpl(true, i+1, prev_node, prev_node);
auto nid = elm::ElmIDBits::createCollectionImpl(true, i+1, next_node, next_node);
auto pid = elm::ElmIDBits::createCollectionImpl(true, i+1, prev_node, this_node);
auto nid = elm::ElmIDBits::createCollectionImpl(true, i+1, next_node, this_node);
ids[1].push_back(pid);
ids[2].push_back(pid);
ids[4].push_back(nid);
Expand Down

0 comments on commit 54ad63d

Please sign in to comment.