Skip to content

Commit

Permalink
#78: fix max load and max volume bugs for multiple phases
Browse files Browse the repository at this point in the history
  • Loading branch information
cwschilly committed Jun 18, 2024
1 parent 827784c commit f999a09
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 17 deletions.
58 changes: 41 additions & 17 deletions src/vt-tv/api/info.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ struct Info {
ranks_.try_emplace(r.getRankID(), std::move(r));
}

void setSelectedPhase(PhaseType selected_phase) {
selected_phase_ = selected_phase;
}

/**
* \brief Get all object info
*
Expand Down Expand Up @@ -395,14 +399,23 @@ struct Info {
the maximum volume by iterated through object ids.
*/
auto n_phases = this->getNumPhases();
for (PhaseType phase = 0; phase < n_phases; phase++) {
auto const& objects = this->getPhaseObjects(phase);
for (auto const& [obj_id, obj_work] : objects) {
auto obj_max_v = obj_work.getMaxVolume();
if (obj_max_v > ov_max) ov_max = obj_max_v;

if (selected_phase_ != std::numeric_limits<PhaseType>::max()) {
auto const& objects = this->getPhaseObjects(selected_phase_);
for (auto const& [obj_id, obj_work] : objects) {
auto obj_max_v = obj_work.getMaxVolume();
if (obj_max_v > ov_max) ov_max = obj_max_v;
}
} else {
for (PhaseType phase = 0; phase < n_phases; phase++) {{
auto const& objects = this->getPhaseObjects(phase);
for (auto const& [obj_id, obj_work] : objects) {
auto obj_max_v = obj_work.getMaxVolume();
if (obj_max_v > ov_max) ov_max = obj_max_v;
}
}
}
}

return ov_max;
}

Expand All @@ -415,14 +428,22 @@ struct Info {
double ol_max = 0.;

auto n_phases = this->getNumPhases();
for (PhaseType phase = 0; phase < n_phases; phase++) {
auto const& objects = this->getPhaseObjects(phase);

if (selected_phase_ != std::numeric_limits<PhaseType>::max()) {
auto const& objects = this->getPhaseObjects(selected_phase_);
for (auto const& [obj_id, obj_work] : objects) {
auto obj_load = obj_work.getLoad();
if (obj_load > ol_max) ol_max = obj_load;
}
} else {
for (PhaseType phase = 0; phase < n_phases; phase++) {
auto const& objects = this->getPhaseObjects(phase);
for (auto const& [obj_id, obj_work] : objects) {
auto obj_load = obj_work.getLoad();
if (obj_load > ol_max) ol_max = obj_load;
}
}
}

return ol_max;
}

Expand Down Expand Up @@ -569,25 +590,25 @@ struct Info {
}

// loop through ranks and add communications
fmt::print("Updating communications for phase {}.\n", phase);
// fmt::print("Updating communications for phase {}.\n", phase);
for (auto &[rank_id, rank]: ranks_) {
fmt::print(" Checking objects in rank {}.\n", rank_id);
// fmt::print(" Checking objects in rank {}.\n", rank_id);
auto &phaseWork = rank.getPhaseWork();
auto &phaseWorkAtPhase = phaseWork.at(phase);
auto &objects = phaseWorkAtPhase.getObjectWork();
for (auto &[obj_id, obj_work]: objects) {
fmt::print(" Checking if object {} needs to be updated.\n", obj_id);
fmt::print(" Communications to update:\n");
// fmt::print(" Checking if object {} needs to be updated.\n", obj_id);
// fmt::print(" Communications to update:\n");
uint64_t i = 0;
for (auto &[object_to_update, sender_id, recipient_id, bytes]: communications_to_add) {
fmt::print(" {} needs to be updated in {} -> {} communication of {} bytes.\n", object_to_update,
sender_id, recipient_id, bytes);
// fmt::print(" {} needs to be updated in {} -> {} communication of {} bytes.\n", object_to_update,
// sender_id, recipient_id, bytes);
if (object_to_update == "sender" && sender_id == obj_id) {
fmt::print(" Sender to be updated is object on this rank. Updating.\n");
// fmt::print(" Sender to be updated is object on this rank. Updating.\n");
rank.addObjectSentCommunicationAtPhase(phase, obj_id, recipient_id, bytes);
communications_to_add.erase(communications_to_add.begin() + i);
} else if (object_to_update == "recipient" && recipient_id == obj_id) {
fmt::print(" Recipient to be updated is object on this rank. Updating.\n");
// fmt::print(" Recipient to be updated is object on this rank. Updating.\n");
rank.addObjectReceivedCommunicationAtPhase(phase, obj_id, sender_id, bytes);
communications_to_add.erase(communications_to_add.begin() + i);
}
Expand Down Expand Up @@ -875,6 +896,9 @@ struct Info {

/// Work for each rank across phases
std::unordered_map<NodeType, Rank> ranks_;

/// The current phase (or indication to use all phases)
PhaseType selected_phase_;
};

} /* end namespace vt::tv */
Expand Down
6 changes: 6 additions & 0 deletions src/vt-tv/render/render.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ Render::Render(Info in_info)
}
max_o_per_dim_ = 0;

// Set the info selected_phase
this->info_.setSelectedPhase(selected_phase_);

// Normalize communication edges
for(PhaseType phase = 0; phase < this->n_phases_; phase++) {
if ( selected_phase_ == std::numeric_limits<PhaseType>::max() or
Expand Down Expand Up @@ -144,6 +147,9 @@ Render::Render(
}
max_o_per_dim_ = 0;

// Set the info selected_phase
this->info_.setSelectedPhase(selected_phase_);

// Normalize communication edges
for(PhaseType phase = 0; phase < this->n_phases_; phase++) {
if (
Expand Down

0 comments on commit f999a09

Please sign in to comment.