diff --git a/src/vt/vrt/collection/balance/lb_data_restart_reader.cc b/src/vt/vrt/collection/balance/lb_data_restart_reader.cc index b241e380b4..386683f852 100644 --- a/src/vt/vrt/collection/balance/lb_data_restart_reader.cc +++ b/src/vt/vrt/collection/balance/lb_data_restart_reader.cc @@ -87,20 +87,17 @@ void LBDataRestartReader::readHistory(LBDataHolder const& lbdh) { last_found_phase = phase; for (auto const& obj : iter->second) { if (obj.first.isMigratable()) { - history_[phase].insert(obj.first); + if (history_[phase] == nullptr) { + history_[phase] = std::make_shared>(); + } + history_[phase]->insert(obj.first); } } } else if(lbdh.identical_phases_.find(phase) != lbdh.identical_phases_.end()) { - // Phase is identical to previous one, fill with data from previous phase - auto last_iter = lbdh.node_data_.find(last_found_phase); - for (auto const& obj : last_iter->second) { - if (obj.first.isMigratable()) { - history_[phase].insert(obj.first); - } - } + // Phase is identical to previous one, use the shared pointer to data from previous phase + addIdenticalPhase(phase, last_found_phase); } else if(lbdh.skipped_phases_.find(phase) == lbdh.skipped_phases_.end()) { - // Phases which are not present must be specified in metadata of the file - vtAbort("Could not find data: Unspecified phases needs to be present in skipped section of the file metadata"); + vtAbort("Could not find data: Skipped phases needs to be listed in file metadata."); } } } @@ -143,12 +140,12 @@ void LBDataRestartReader::arriving(ArriveMsg* msg) { } void LBDataRestartReader::update(UpdateMsg* msg) { - auto iter = history_[msg->phase].find(msg->elm); - vtAssert(iter != history_[msg->phase].end(), "Must exist"); + auto iter = history_[msg->phase]->find(msg->elm); + vtAssert(iter != history_[msg->phase]->end(), "Must exist"); auto elm = *iter; elm.curr_node = msg->curr_node; - history_[msg->phase].erase(iter); - history_[msg->phase].insert(elm); + history_[msg->phase]->erase(iter); + history_[msg->phase]->insert(elm); } void LBDataRestartReader::checkBothEnds(Coord& coord) { @@ -167,26 +164,21 @@ void LBDataRestartReader::determinePhasesToMigrate() { runInEpochCollective("LBDataRestartReader::updateLocations", [&]{ PhaseType curr = 0, next; for (;curr < num_phases_ - 1;) { - // find number of next Phase - for(next = curr + 1; next < num_phases_; ++next) { - if(history_.find(next) != history_.end()) { - break; - } - } + next = findNextPhase(curr); - local_changed_distro[curr] = history_[curr] != history_[next]; + local_changed_distro[curr] = *history_[curr] != *history_[next]; if (local_changed_distro[curr]) { std::set departing, arriving; std::set_difference( - history_[next].begin(), history_[next].end(), - history_[curr].begin(), history_[curr].end(), + history_[next]->begin(), history_[next]->end(), + history_[curr]->begin(), history_[curr]->end(), std::inserter(arriving, arriving.begin()) ); std::set_difference( - history_[curr].begin(), history_[curr].end(), - history_[next].begin(), history_[next].end(), + history_[curr]->begin(), history_[curr]->end(), + history_[next]->begin(), history_[next]->end(), std::inserter(departing, departing.begin()) ); diff --git a/src/vt/vrt/collection/balance/lb_data_restart_reader.h b/src/vt/vrt/collection/balance/lb_data_restart_reader.h index fd6640b6d3..02ed49e3f0 100644 --- a/src/vt/vrt/collection/balance/lb_data_restart_reader.h +++ b/src/vt/vrt/collection/balance/lb_data_restart_reader.h @@ -131,14 +131,9 @@ struct LBDataRestartReader : runtime::component::Component * \return the next phase */ PhaseType findNextPhase(PhaseType phase) const { - auto next = phase + 1; - for(; next < num_phases_; ++next) { - if(history_.find(next) != history_.end()) { - return next; - } - } - vtAssert(history_.find(next) != history_.end(), "Must have a valid phase"); - return next; + auto iter = history_.upper_bound(phase); + vtAssert(iter != history_.end(), "Must have a valid phase"); + return iter->first; } /** @@ -146,11 +141,11 @@ struct LBDataRestartReader : runtime::component::Component * * \param[in] phase the phase * - * \return element assigned to this node + * \return pointer to elements assigned to this node, guaranted to be not null */ - std::set const& getDistro(PhaseType phase) { + std::shared_ptr> getDistro(PhaseType phase) const { auto iter = history_.find(phase); - vtAssert(iter != history_.end(), "Must have a valid phase"); + vtAssert(iter != history_.end() && iter->second != nullptr, "Must have a valid phase"); return iter->second; } @@ -166,6 +161,30 @@ struct LBDataRestartReader : runtime::component::Component } } + /** + * \brief Add history for a given phase + * + * \param[in] phase the phase to be added + * \param[in] distro the distribution to be added + */ + void addDistro(PhaseType phase, const std::set& distro) { + if (history_[phase] == nullptr) { + history_[phase] = std::make_shared>(); + } + history_[phase]->insert(distro.begin(), distro.end()); + } + + /** + * \brief Add identical phase to one already present + * + * \param[in] phase the phase to be added + * \param[in] identical the identical phase to be used + */ + void addIdenticalPhase(PhaseType phase, PhaseType identical) { + vtAssert(history_.find(identical) != history_.end(), "Identical phase was not added to history map."); + history_[phase] = history_[identical]; + } + private: /** * \brief Reduce distribution changes globally to find where migrations need @@ -188,7 +207,7 @@ struct LBDataRestartReader : runtime::component::Component std::vector changed_distro_; /// History of mapping that was read in from the data files - std::unordered_map> history_; + std::map>> history_; struct DepartMsg : vt::Message { DepartMsg(NodeType in_depart_node, PhaseType in_phase, ElementIDStruct in_elm) diff --git a/src/vt/vrt/collection/balance/offlinelb/offlinelb.cc b/src/vt/vrt/collection/balance/offlinelb/offlinelb.cc index fad602f208..cbdf650779 100644 --- a/src/vt/vrt/collection/balance/offlinelb/offlinelb.cc +++ b/src/vt/vrt/collection/balance/offlinelb/offlinelb.cc @@ -56,8 +56,8 @@ void OfflineLB::init(objgroup::proxy::Proxy in_proxy) { void OfflineLB::runLB(TimeType) { auto nextPhase = theLBDataReader()->findNextPhase(phase_); - auto const& distro = theLBDataReader()->getDistro(nextPhase); - for (auto&& elm : distro) { + auto const distro = theLBDataReader()->getDistro(nextPhase); + for (auto&& elm : *distro) { migrateObjectTo(elm, theContext()->getNode()); } theLBDataReader()->clearDistro(nextPhase);