diff --git a/src/vt/vrt/collection/balance/lb_data_restart_reader.cc b/src/vt/vrt/collection/balance/lb_data_restart_reader.cc
index 776fb8fa3c..a2326f8560 100644
--- a/src/vt/vrt/collection/balance/lb_data_restart_reader.cc
+++ b/src/vt/vrt/collection/balance/lb_data_restart_reader.cc
@@ -121,13 +121,85 @@ void LBDataRestartReader::readLBData(std::string const& file) {
   determinePhasesToMigrate();
 }
 
+/// Handler for a DepartMsg (sent to an element's home node): stores the
+/// message for its (phase, element) and checks whether both ends are known.
+void LBDataRestartReader::departing(DepartMsg* msg) {
+  auto& coord = coordinate_[msg->phase][msg->elm];
+  coord.depart = promoteMsg(msg);
+  checkBothEnds(coord);
+}
+
+/// Handler for an ArriveMsg (sent to an element's home node): stores the
+/// message for its (phase, element) and checks whether both ends are known.
+void LBDataRestartReader::arriving(ArriveMsg* msg) {
+  auto& coord = coordinate_[msg->phase][msg->elm];
+  coord.arrive = promoteMsg(msg);
+  checkBothEnds(coord);
+}
+
+/// Handler for an UpdateMsg: overwrites \c curr_node of the matching element
+/// in the history of the given phase. The element must already be present.
+void LBDataRestartReader::update(UpdateMsg* msg) {
+  auto& phase_history = history_[msg->phase];
+  auto iter = phase_history.find(msg->elm);
+  vtAssert(iter != phase_history.end(), "Must exist");
+  // Elements of a set cannot be modified in place: re-insert an updated copy.
+  auto elm = *iter;
+  elm.curr_node = msg->curr_node;
+  phase_history.erase(iter);
+  phase_history.insert(elm);
+}
+
+/// Once both the departure and the arrival message of an element are stored,
+/// send the departure node to the arrival node as an UpdateMsg and release
+/// both stored messages.
+void LBDataRestartReader::checkBothEnds(Coord& coord) {
+  if (coord.arrive != nullptr and coord.depart != nullptr) {
+    proxy[coord.arrive->arrive_node].send<
+      UpdateMsg, &LBDataRestartReader::update
+    >(coord.depart->depart_node, coord.arrive->phase, coord.arrive->elm);
+    coord.arrive = nullptr;
+    coord.depart = nullptr;
+  }
+}
+
 void LBDataRestartReader::determinePhasesToMigrate() {
   std::vector<bool> local_changed_distro;
   local_changed_distro.resize(num_phases_ - 1);
 
-  for (PhaseType i = 0; i < num_phases_ - 1; ++i) {
-    local_changed_distro[i] = history_[i] != history_[i+1];
-  }
+  runInEpochCollective("LBDataRestartReader::updateLocations", [&]{
+    for (PhaseType i = 0; i < num_phases_ - 1; ++i) {
+      local_changed_distro[i] = history_[i] != history_[i+1];
+      if (local_changed_distro[i]) {
+        std::set<ElementIDStruct> departing, arriving;
+
+        // Elements in history_[i+1] that are not in history_[i].
+        std::set_difference(
+          history_[i+1].begin(), history_[i+1].end(),
+          history_[i].begin(), history_[i].end(),
+          std::inserter(arriving, arriving.begin())
+        );
+
+        // Elements in history_[i] that are not in history_[i+1].
+        std::set_difference(
+          history_[i].begin(), history_[i].end(),
+          history_[i+1].begin(), history_[i+1].end(),
+          std::inserter(departing, departing.begin())
+        );
+
+        for (auto&& d : departing) {
+          proxy[d.getHomeNode()].send<
+            DepartMsg, &LBDataRestartReader::departing
+          >(this_node, i+1, d);
+        }
+        for (auto&& a : arriving) {
+          proxy[a.getHomeNode()].send<
+            ArriveMsg, &LBDataRestartReader::arriving
+          >(this_node, i+1, a);
+        }
+      }
+    }
+  });
 
   runInEpochCollective("LBDataRestartReader::computeDistributionChanges", [&]{
     auto cb = theCB()->makeBcast<
diff --git a/src/vt/vrt/collection/balance/lb_data_restart_reader.h b/src/vt/vrt/collection/balance/lb_data_restart_reader.h
index d78305fed6..e694db6b06 100644
--- a/src/vt/vrt/collection/balance/lb_data_restart_reader.h
+++ b/src/vt/vrt/collection/balance/lb_data_restart_reader.h
@@ -172,6 +172,63 @@ struct LBDataRestartReader : runtime::component::Component
   /// History of mapping that was read in from the data files
   std::unordered_map<PhaseType, std::set<ElementIDStruct>> history_;
 
+  /// Message carrying the node an element departs from, for a given phase
+  struct DepartMsg : vt::Message {
+    DepartMsg(NodeType in_depart_node, PhaseType in_phase, ElementIDStruct in_elm)
+      : depart_node(in_depart_node),
+        phase(in_phase),
+        elm(in_elm)
+    { }
+
+    NodeType depart_node = uninitialized_destination;
+    PhaseType phase = no_phase;
+    ElementIDStruct elm;
+  };
+
+  /// Message carrying the node an element arrives on, for a given phase
+  struct ArriveMsg : vt::Message {
+    ArriveMsg(NodeType in_arrive_node, PhaseType in_phase, ElementIDStruct in_elm)
+      : arrive_node(in_arrive_node),
+        phase(in_phase),
+        elm(in_elm)
+    { }
+
+    NodeType arrive_node = uninitialized_destination;
+    PhaseType phase = no_phase;
+    ElementIDStruct elm;
+  };
+
+  /// Message carrying the node to record as \c curr_node for an element
+  struct UpdateMsg : vt::Message {
+    UpdateMsg(NodeType in_curr_node, PhaseType in_phase, ElementIDStruct in_elm)
+      : curr_node(in_curr_node),
+        phase(in_phase),
+        elm(in_elm)
+    { }
+
+    NodeType curr_node = uninitialized_destination;
+    PhaseType phase = no_phase;
+    ElementIDStruct elm;
+  };
+
+  /// Arrival and departure messages collected for one (phase, element)
+  struct Coord {
+    MsgSharedPtr<ArriveMsg> arrive = nullptr;
+    MsgSharedPtr<DepartMsg> depart = nullptr;
+  };
+
+  /// Handlers for the messages declared above
+  void departing(DepartMsg* msg);
+  void arriving(ArriveMsg* msg);
+  void update(UpdateMsg* msg);
+  /// Send an UpdateMsg once both messages of a Coord have been received
+  void checkBothEnds(Coord& coord);
+
+  /// Pending arrival/departure messages, indexed by phase then element
+  std::unordered_map<
+    PhaseType, std::unordered_map<ElementIDStruct, Coord>
+  > coordinate_;
+
   /// Number of phases read in
   std::size_t num_phases_ = 0;
 };