diff --git a/src/vt/vrt/collection/balance/node_stats.cc b/src/vt/vrt/collection/balance/node_stats.cc index 924c33b08e..50110664d3 100644 --- a/src/vt/vrt/collection/balance/node_stats.cc +++ b/src/vt/vrt/collection/balance/node_stats.cc @@ -70,22 +70,6 @@ void NodeStats::setProxy(objgroup::proxy::Proxy in_proxy) { return ptr; } -ElementIDType NodeStats::tempToPerm(ElementIDType temp_id) const { - auto iter = node_temp_to_perm_.find(temp_id); - if (iter == node_temp_to_perm_.end()) { - return no_element_id; - } - return iter->second; -} - -ElementIDType NodeStats::permToTemp(ElementIDType perm_id) const { - auto iter = node_perm_to_temp_.find(perm_id); - if (iter == node_perm_to_temp_.end()) { - return no_element_id; - } - return iter->second; -} - bool NodeStats::hasObjectToMigrate(ElementIDType obj_id) const { auto iter = node_migrate_.find(obj_id); return iter != node_migrate_.end(); @@ -122,36 +106,20 @@ void NodeStats::clearStats() { NodeStats::node_data_.clear(); NodeStats::node_subphase_data_.clear(); NodeStats::node_migrate_.clear(); - NodeStats::node_temp_to_perm_.clear(); - NodeStats::node_perm_to_temp_.clear(); next_elm_ = 1; } void NodeStats::startIterCleanup(PhaseType phase, int look_back) { // TODO: Add in subphase support here too - // Convert the temp ID node_data_ for the last iteration into perm ID for - // stats output - auto const prev_data = std::move(node_data_[phase]); - std::unordered_map new_data; - for (auto& elm : prev_data) { - auto iter = node_temp_to_perm_.find(elm.first); - vtAssert(iter != node_temp_to_perm_.end(), "Temp ID must exist"); - auto perm_id = iter->second; - new_data[perm_id] = elm.second; - } - node_data_[phase] = std::move(new_data); - if (phase - look_back >= 0) { node_data_.erase(phase - look_back); node_subphase_data_.erase(phase - look_back); node_comm_.erase(phase - look_back); } - // Create migrate lambdas and temp to perm map since LB is complete + // Create migrate lambdas since LB is complete NodeStats::node_migrate_.clear(); - NodeStats::node_temp_to_perm_.clear(); - NodeStats::node_perm_to_temp_.clear(); node_collection_lookup_.clear(); } @@ -275,7 +243,7 @@ void NodeStats::outputStatsFile() { closeStatsFile(); } -ElementIDType NodeStats::addNodeStats( +void NodeStats::addNodeStats( Migratable* col_elm, PhaseType const& phase, TimeType const& time, std::vector const& subphase_time, CommMapType const& comm @@ -292,20 +260,20 @@ ElementIDType NodeStats::addNodeStats( ); auto &phase_data = node_data_[phase]; - auto elm_iter = phase_data.find(temp_id); + auto elm_iter = phase_data.find(perm_id); vtAssert(elm_iter == phase_data.end(), "Must not exist"); phase_data.emplace( std::piecewise_construct, - std::forward_as_tuple(temp_id), + std::forward_as_tuple(perm_id), std::forward_as_tuple(time) ); auto subphase_data = node_subphase_data_[phase]; - auto elm_subphase_iter = subphase_data.find(temp_id); + auto elm_subphase_iter = subphase_data.find(perm_id); vtAssert(elm_subphase_iter == subphase_data.end(), "Must not exist"); subphase_data.emplace( std::piecewise_construct, - std::forward_as_tuple(temp_id), + std::forward_as_tuple(perm_id), std::forward_as_tuple(subphase_time) ); @@ -314,14 +282,11 @@ ElementIDType NodeStats::addNodeStats( comm_data[c.first] += c.second; } - node_temp_to_perm_[temp_id] = perm_id; - node_perm_to_temp_[perm_id] = temp_id; - - auto migrate_iter = node_migrate_.find(temp_id); + auto migrate_iter = node_migrate_.find(perm_id); if (migrate_iter == node_migrate_.end()) { node_migrate_.emplace( std::piecewise_construct, - std::forward_as_tuple(temp_id), + std::forward_as_tuple(perm_id), std::forward_as_tuple([col_elm](NodeType node){ col_elm->migrate(node); }) @@ -329,15 +294,13 @@ ElementIDType NodeStats::addNodeStats( } auto const col_proxy = col_elm->getProxy(); - node_collection_lookup_[temp_id] = col_proxy; - - return temp_id; + node_collection_lookup_[perm_id] = col_proxy; } VirtualProxyType NodeStats::getCollectionProxyForElement( - ElementIDType temp_id + ElementIDType perm_id ) const { - auto iter = node_collection_lookup_.find(temp_id); + auto iter = node_collection_lookup_.find(perm_id); if (iter == node_collection_lookup_.end()) { return no_vrt_proxy; } diff --git a/src/vt/vrt/collection/balance/node_stats.h b/src/vt/vrt/collection/balance/node_stats.h index a4a697871d..cdb64c189a 100644 --- a/src/vt/vrt/collection/balance/node_stats.h +++ b/src/vt/vrt/collection/balance/node_stats.h @@ -111,7 +111,7 @@ struct NodeStats : runtime::component::Component { * * \return the temporary ID for the object assigned for this phase */ - ElementIDType addNodeStats( + void addNodeStats( Migratable* col_elm, PhaseType const& phase, TimeType const& time, std::vector const& subphase_time, CommMapType const& comm @@ -192,7 +192,7 @@ struct NodeStats : runtime::component::Component { /** * \internal \brief Test if this node has an object to migrate * - * \param[in] obj_id the object temporary ID + * \param[in] obj_id the object permanent ID * * \return whether this node has the object */ @@ -201,41 +201,22 @@ struct NodeStats : runtime::component::Component { /** * \internal \brief Migrate an local object to another node * - * \param[in] obj_id the object temporary ID + * \param[in] obj_id the object permanent ID * \param[in] to_node the node to migrate to * * \return whether this node has the object */ bool migrateObjTo(ElementIDType obj_id, NodeType to_node); - /** - * \internal \brief Convert temporary element ID to permanent Returns - * \c no_element_id if not found. - * \param[in] temp_id temporary ID - * - * \return permanent ID - */ - ElementIDType tempToPerm(ElementIDType temp_id) const; - - /** - * \internal \brief Convert permanent element ID to temporary. Returns - * \c no_element_id if not found. - * - * \param[in] perm_id permanent ID - * - * \return temporary ID - */ - ElementIDType permToTemp(ElementIDType perm_id) const; - /** * \internal \brief Get the collection proxy for a given element ID * - * \param[in] temp_id the temporary ID for the element for a given phase + * \param[in] perm_id the temporary ID for the element for a given phase * * \return the virtual proxy if the element is part of the collection; * otherwise \c no_vrt_proxy */ - VirtualProxyType getCollectionProxyForElement(ElementIDType temp_id) const; + VirtualProxyType getCollectionProxyForElement(ElementIDType perm_id) const; private: /** @@ -249,22 +230,21 @@ struct NodeStats : runtime::component::Component { void closeStatsFile(); private: - /// Local proxy to objgroup - objgroup::proxy::Proxy proxy_; + /// Node timings for each local object std::unordered_map node_data_; /// Node subphase timings for each local object std::unordered_map node_subphase_data_; - /// Local migration type-free lambdas for each object - std::unordered_map node_migrate_; - /// Map of temporary ID to permanent ID - std::unordered_map node_temp_to_perm_; - /// Map of permanent ID to temporary ID - std::unordered_map node_perm_to_temp_; - /// Map from element temporary ID to the collection's virtual proxy (untyped) - std::unordered_map node_collection_lookup_; /// Node communication graph for each local object std::unordered_map node_comm_; + + /// Local migration type-free lambdas for each object (from perm ID) + std::unordered_map node_migrate_; + /// Map from element permanent ID to the collection's virtual proxy (untyped) + std::unordered_map node_collection_lookup_; + + /// Local proxy to objgroup + objgroup::proxy::Proxy proxy_; /// The current element ID ElementIDType next_elm_; /// The stats file name for outputting instrumentation diff --git a/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc b/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc index 4848c32d45..fb589d6cb8 100644 --- a/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc +++ b/src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc @@ -58,11 +58,11 @@ void StatsMapLB::init(objgroup::proxy::Proxy in_proxy) { void StatsMapLB::runLB() { auto const& myNewList = theStatsReader()->getMoveList(phase_); for (size_t in = 0; in < myNewList.size(); in += 2) { - auto temp_id = theNodeStats()->permToTemp(myNewList[in]); + auto perm_id = myNewList[in]; - vtAssert(temp_id != balance::no_element_id, "Must have valid ID here"); + vtAssert(perm_id != balance::no_element_id, "Must have valid ID here"); - migrateObjectTo(temp_id, myNewList[in+1]); + migrateObjectTo(perm_id, myNewList[in+1]); } theStatsReader()->clearMoveList(phase_); diff --git a/tests/unit/collection/test_model_per_collection.extended.cc b/tests/unit/collection/test_model_per_collection.extended.cc index 83d0755919..174a2ef262 100644 --- a/tests/unit/collection/test_model_per_collection.extended.cc +++ b/tests/unit/collection/test_model_per_collection.extended.cc @@ -90,9 +90,9 @@ std::unordered_map id_proxy_map; template void colHandler(MyMsg*, ColT* col) { - // do nothing, except setting up our map using the temp ID, which will hit + // do nothing, except setting up our map using the object ID, which will hit // every node - id_proxy_map[col->getTempID()] = col->getProxy(); + id_proxy_map[col->getElmID()] = col->getProxy(); } TEST_F(TestModelPerCollection, test_model_per_collection_1) {