Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#868: Use Permanent IDs everywhere in LB #1017

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 11 additions & 48 deletions src/vt/vrt/collection/balance/node_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,6 @@ void NodeStats::setProxy(objgroup::proxy::Proxy<NodeStats> in_proxy) {
return ptr;
}

/// Translate a temporary element ID into the permanent ID recorded for it in
/// node_temp_to_perm_; yields no_element_id when no entry exists.
ElementIDType NodeStats::tempToPerm(ElementIDType temp_id) const {
  auto const found = node_temp_to_perm_.find(temp_id);
  bool const is_known = found != node_temp_to_perm_.end();
  return is_known ? found->second : no_element_id;
}

/// Translate a permanent element ID into the temporary ID recorded for it in
/// node_perm_to_temp_; yields no_element_id when no entry exists.
ElementIDType NodeStats::permToTemp(ElementIDType perm_id) const {
  auto const found = node_perm_to_temp_.find(perm_id);
  bool const is_known = found != node_perm_to_temp_.end();
  return is_known ? found->second : no_element_id;
}

bool NodeStats::hasObjectToMigrate(ElementIDType obj_id) const {
auto iter = node_migrate_.find(obj_id);
return iter != node_migrate_.end();
Expand Down Expand Up @@ -122,36 +106,20 @@ void NodeStats::clearStats() {
NodeStats::node_data_.clear();
NodeStats::node_subphase_data_.clear();
NodeStats::node_migrate_.clear();
NodeStats::node_temp_to_perm_.clear();
NodeStats::node_perm_to_temp_.clear();
next_elm_ = 1;
}

void NodeStats::startIterCleanup(PhaseType phase, int look_back) {
// TODO: Add in subphase support here too

// Convert the temp ID node_data_ for the last iteration into perm ID for
// stats output
auto const prev_data = std::move(node_data_[phase]);
std::unordered_map<ElementIDType,TimeType> new_data;
for (auto& elm : prev_data) {
auto iter = node_temp_to_perm_.find(elm.first);
vtAssert(iter != node_temp_to_perm_.end(), "Temp ID must exist");
auto perm_id = iter->second;
new_data[perm_id] = elm.second;
}
node_data_[phase] = std::move(new_data);

if (phase - look_back >= 0) {
node_data_.erase(phase - look_back);
node_subphase_data_.erase(phase - look_back);
node_comm_.erase(phase - look_back);
}

// Create migrate lambdas and temp to perm map since LB is complete
// Create migrate lambdas since LB is complete
NodeStats::node_migrate_.clear();
NodeStats::node_temp_to_perm_.clear();
NodeStats::node_perm_to_temp_.clear();
node_collection_lookup_.clear();
}

Expand Down Expand Up @@ -275,7 +243,7 @@ void NodeStats::outputStatsFile() {
closeStatsFile();
}

ElementIDType NodeStats::addNodeStats(
void NodeStats::addNodeStats(
Migratable* col_elm,
PhaseType const& phase, TimeType const& time,
std::vector<TimeType> const& subphase_time, CommMapType const& comm
Expand All @@ -292,20 +260,20 @@ ElementIDType NodeStats::addNodeStats(
);

auto &phase_data = node_data_[phase];
auto elm_iter = phase_data.find(temp_id);
auto elm_iter = phase_data.find(perm_id);
vtAssert(elm_iter == phase_data.end(), "Must not exist");
phase_data.emplace(
std::piecewise_construct,
std::forward_as_tuple(temp_id),
std::forward_as_tuple(perm_id),
std::forward_as_tuple(time)
);

auto subphase_data = node_subphase_data_[phase];
auto elm_subphase_iter = subphase_data.find(temp_id);
auto elm_subphase_iter = subphase_data.find(perm_id);
vtAssert(elm_subphase_iter == subphase_data.end(), "Must not exist");
subphase_data.emplace(
std::piecewise_construct,
std::forward_as_tuple(temp_id),
std::forward_as_tuple(perm_id),
std::forward_as_tuple(subphase_time)
);

Expand All @@ -314,30 +282,25 @@ ElementIDType NodeStats::addNodeStats(
comm_data[c.first] += c.second;
}

node_temp_to_perm_[temp_id] = perm_id;
node_perm_to_temp_[perm_id] = temp_id;

auto migrate_iter = node_migrate_.find(temp_id);
auto migrate_iter = node_migrate_.find(perm_id);
if (migrate_iter == node_migrate_.end()) {
node_migrate_.emplace(
std::piecewise_construct,
std::forward_as_tuple(temp_id),
std::forward_as_tuple(perm_id),
std::forward_as_tuple([col_elm](NodeType node){
col_elm->migrate(node);
})
);
}

auto const col_proxy = col_elm->getProxy();
node_collection_lookup_[temp_id] = col_proxy;

return temp_id;
node_collection_lookup_[perm_id] = col_proxy;
}

VirtualProxyType NodeStats::getCollectionProxyForElement(
ElementIDType temp_id
ElementIDType perm_id
) const {
auto iter = node_collection_lookup_.find(temp_id);
auto iter = node_collection_lookup_.find(perm_id);
if (iter == node_collection_lookup_.end()) {
return no_vrt_proxy;
}
Expand Down
48 changes: 14 additions & 34 deletions src/vt/vrt/collection/balance/node_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ struct NodeStats : runtime::component::Component<NodeStats> {
*
* \return the temporary ID for the object assigned for this phase
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment outdated.

*/
ElementIDType addNodeStats(
void addNodeStats(
Migratable* col_elm,
PhaseType const& phase, TimeType const& time,
std::vector<TimeType> const& subphase_time, CommMapType const& comm
Expand Down Expand Up @@ -192,7 +192,7 @@ struct NodeStats : runtime::component::Component<NodeStats> {
/**
* \internal \brief Test if this node has an object to migrate
*
* \param[in] obj_id the object temporary ID
* \param[in] obj_id the object permanent ID
*
* \return whether this node has the object
*/
Expand All @@ -201,41 +201,22 @@ struct NodeStats : runtime::component::Component<NodeStats> {
/**
* \internal \brief Migrate a local object to another node
*
* \param[in] obj_id the object temporary ID
* \param[in] obj_id the object permanent ID
* \param[in] to_node the node to migrate to
*
* \return whether this node has the object
*/
bool migrateObjTo(ElementIDType obj_id, NodeType to_node);

/**
* \internal \brief Convert temporary element ID to permanent. Returns
* \c no_element_id if not found.
* \param[in] temp_id temporary ID
*
* \return permanent ID
*/
ElementIDType tempToPerm(ElementIDType temp_id) const;

/**
* \internal \brief Convert permanent element ID to temporary. Returns
* \c no_element_id if not found.
*
* \param[in] perm_id permanent ID
*
* \return temporary ID
*/
ElementIDType permToTemp(ElementIDType perm_id) const;

/**
* \internal \brief Get the collection proxy for a given element ID
*
* \param[in] temp_id the temporary ID for the element for a given phase
* \param[in] perm_id the permanent ID for the element
*
* \return the virtual proxy if the element is part of the collection;
* otherwise \c no_vrt_proxy
*/
VirtualProxyType getCollectionProxyForElement(ElementIDType temp_id) const;
VirtualProxyType getCollectionProxyForElement(ElementIDType perm_id) const;

private:
/**
Expand All @@ -249,22 +230,21 @@ struct NodeStats : runtime::component::Component<NodeStats> {
void closeStatsFile();

private:
/// Local proxy to objgroup
objgroup::proxy::Proxy<NodeStats> proxy_;

/// Node timings for each local object
std::unordered_map<PhaseType, LoadMapType> node_data_;
/// Node subphase timings for each local object
std::unordered_map<PhaseType, SubphaseLoadMapType> node_subphase_data_;
/// Local migration type-free lambdas for each object
std::unordered_map<ElementIDType,MigrateFnType> node_migrate_;
/// Map of temporary ID to permanent ID
std::unordered_map<ElementIDType,ElementIDType> node_temp_to_perm_;
/// Map of permanent ID to temporary ID
std::unordered_map<ElementIDType,ElementIDType> node_perm_to_temp_;
/// Map from element temporary ID to the collection's virtual proxy (untyped)
std::unordered_map<ElementIDType,VirtualProxyType> node_collection_lookup_;
/// Node communication graph for each local object
std::unordered_map<PhaseType, CommMapType> node_comm_;

/// Local migration type-free lambdas for each object (from perm ID)
std::unordered_map<ElementIDType,MigrateFnType> node_migrate_;
/// Map from element permanent ID to the collection's virtual proxy (untyped)
std::unordered_map<ElementIDType,VirtualProxyType> node_collection_lookup_;

/// Local proxy to objgroup
objgroup::proxy::Proxy<NodeStats> proxy_;
/// The current element ID
ElementIDType next_elm_;
/// The stats file name for outputting instrumentation
Expand Down
6 changes: 3 additions & 3 deletions src/vt/vrt/collection/balance/statsmaplb/statsmaplb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ void StatsMapLB::init(objgroup::proxy::Proxy<StatsMapLB> in_proxy) {
void StatsMapLB::runLB() {
auto const& myNewList = theStatsReader()->getMoveList(phase_);
for (size_t in = 0; in < myNewList.size(); in += 2) {
auto temp_id = theNodeStats()->permToTemp(myNewList[in]);
auto perm_id = myNewList[in];

vtAssert(temp_id != balance::no_element_id, "Must have valid ID here");
vtAssert(perm_id != balance::no_element_id, "Must have valid ID here");

migrateObjectTo(temp_id, myNewList[in+1]);
migrateObjectTo(perm_id, myNewList[in+1]);
}

theStatsReader()->clearMoveList(phase_);
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/collection/test_model_per_collection.extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ std::unordered_map<ElementIDType, VirtualProxyType> id_proxy_map;

template <typename ColT>
void colHandler(MyMsg<ColT>*, ColT* col) {
// do nothing, except setting up our map using the temp ID, which will hit
// do nothing, except setting up our map using the object ID, which will hit
// every node
id_proxy_map[col->getTempID()] = col->getProxy();
id_proxy_map[col->getElmID()] = col->getProxy();
}

TEST_F(TestModelPerCollection, test_model_per_collection_1) {
Expand Down