Skip to content

Commit

Permalink
#2087: LB: start writing fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lifflander authored and cz4rs committed Mar 27, 2023
1 parent 7c6878b commit fa93302
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 3 deletions.
54 changes: 51 additions & 3 deletions src/vt/vrt/collection/balance/lb_data_restart_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,61 @@ void LBDataRestartReader::readLBData(std::string const& file) {
determinePhasesToMigrate();
}

/**
 * \brief Handler for a departure notification for an element in a phase.
 *
 * Stores the promoted message in the \c depart slot of the coordination
 * record keyed by (phase, element) and then checks whether both ends of the
 * exchange have been recorded.
 *
 * \param[in] msg the departure message; its \c phase and \c elm select the
 * coordination record
 */
void LBDataRestartReader::departing(DepartMsg* msg) {
  auto& coord = coordinate_[msg->phase][msg->elm];
  // A DepartMsg must land in the depart slot: Coord::arrive is typed for
  // ArriveMsg, so storing it there mismatches the slot and would leave
  // the depart side permanently empty.
  coord.depart = promoteMsg(msg);
  checkBothEnds(coord);
}

/**
 * \brief Handler for an arrival notification for an element in a phase.
 *
 * Stores the promoted message in the \c arrive slot of the coordination
 * record keyed by (phase, element) and then checks whether both ends of the
 * exchange have been recorded.
 *
 * \param[in] msg the arrival message; its \c phase and \c elm select the
 * coordination record
 */
void LBDataRestartReader::arriving(ArriveMsg* msg) {
  auto& coord = coordinate_[msg->phase][msg->elm];
  // An ArriveMsg must land in the arrive slot: Coord::depart is typed for
  // DepartMsg, so storing it there mismatches the slot and would leave
  // the arrive side permanently empty.
  coord.arrive = promoteMsg(msg);
  checkBothEnds(coord);
}

/**
 * \brief Handler that records a new current node for an element in a phase.
 *
 * Looks the element up in the history for \c msg->phase and replaces it with
 * a copy whose \c curr_node is \c msg->curr_node.
 *
 * \param[in] msg the update message (phase, element, and new current node)
 */
void LBDataRestartReader::update(UpdateMsg* msg) {
  // The member is history_ (a per-phase std::set), not `history`.
  auto& phase_elms = history_[msg->phase];
  auto iter = phase_elms.find(msg->elm);
  vtAssert(iter != phase_elms.end(), "Must exist");
  if (iter == phase_elms.end()) {
    // Dereferencing/erasing end() is undefined; bail out if the assertion
    // above is compiled out.
    return;
  }

  // std::set iterators yield const elements (and have no `second`), so the
  // element cannot be edited in place: copy it, modify the copy, and swap it
  // back in.
  // NOTE(review): assumes ElementIDStruct has an assignable curr_node field —
  // confirm against its definition.
  auto elm = *iter;
  elm.curr_node = msg->curr_node;
  phase_elms.erase(iter);
  phase_elms.insert(elm);
}

/**
 * \brief Send an update once both the arrive and depart messages are present.
 *
 * Does nothing unless both slots of \c coord are non-null. When they are, an
 * \c UpdateMsg built from the depart node, the arrive message's phase, and
 * the arrive message's element is sent to the arrive node, targeting
 * \c LBDataRestartReader::update.
 *
 * \param[in] coord the coordination record to inspect
 */
void LBDataRestartReader::checkBothEnds(Coord& coord) {
  bool const have_arrive = coord.arrive != nullptr;
  bool const have_depart = coord.depart != nullptr;
  if (!(have_arrive && have_depart)) {
    // Still waiting on one side of the exchange.
    return;
  }

  // NOTE(review): `proxy` is not declared in the visible part of the class —
  // confirm which member/proxy this refers to.
  proxy[coord.arrive->arrive_node].send<
    UpdateMsg, &LBDataRestartReader::update
  >(coord.depart->depart_node, coord.arrive->phase, coord.arrive->elm);
}

void LBDataRestartReader::determinePhasesToMigrate() {
std::vector<bool> local_changed_distro;
local_changed_distro.resize(num_phases_ - 1);

for (PhaseType i = 0; i < num_phases_ - 1; ++i) {
local_changed_distro[i] = history_[i] != history_[i+1];
}
// For every pair of consecutive phases whose element sets differ, notify the
// home node of each element that left (DepartMsg) or appeared (ArriveMsg),
// all inside one collective epoch.
// NOTE(review): `proxy` and `this_node` are not declared in the visible code —
// confirm they are in scope here.
runInEpochCollective("LBDataRestartReader::updateLocations", [&]{
  for (PhaseType i = 0; i < num_phases_ - 1; ++i) {
    local_changed_distro[i] = history_[i] != history_[i+1];
    if (local_changed_distro[i]) {
      // Named to avoid shadowing the departing()/arriving() handlers.
      std::set<ElementIDStruct> departed, arrived;

      // Present in phase i+1 but not in phase i
      std::set_difference(
        history_[i+1].begin(), history_[i+1].end(),
        history_[i].begin(), history_[i].end(),
        std::inserter(arrived, arrived.begin())
      );

      // Present in phase i but not in phase i+1
      std::set_difference(
        history_[i].begin(), history_[i].end(),
        history_[i+1].begin(), history_[i+1].end(),
        std::inserter(departed, departed.begin())
      );

      for (auto&& d : departed) {
        proxy[d.getHomeNode()].send<DepartMsg, &LBDataRestartReader::departing>(this_node, i+1, d);
      }
      for (auto&& a : arrived) {
        // Route by the arriving element's own home node; the departing loop's
        // variable is out of scope here.
        proxy[a.getHomeNode()].send<ArriveMsg, &LBDataRestartReader::arriving>(this_node, i+1, a);
      }
    }
  }
});

runInEpochCollective("LBDataRestartReader::computeDistributionChanges", [&]{
auto cb = theCB()->makeBcast<
Expand Down
49 changes: 49 additions & 0 deletions src/vt/vrt/collection/balance/lb_data_restart_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,55 @@ struct LBDataRestartReader : runtime::component::Component<LBDataRestartReader>
/// History of mapping that was read in from the data files, keyed by phase;
/// each phase holds the ordered set of element IDs recorded for it
std::unordered_map<PhaseType, std::set<ElementIDStruct>> history_;

/// Message announcing that an element is recorded in one phase but not in the
/// next. determinePhasesToMigrate sends it to the element's home node with the
/// sender's node, the later phase (i+1), and the element.
struct DepartMsg : vt::Message {
/// \param[in] in_depart_node node stored as \c depart_node
/// \param[in] in_phase phase stored as \c phase
/// \param[in] in_elm element stored as \c elm
DepartMsg(NodeType in_depart_node, PhaseType in_phase, ElementIDStruct in_elm)
: depart_node(in_depart_node),
phase(in_phase),
elm(in_elm)
{ }

/// Node on the departing side of the exchange
NodeType depart_node = uninitialized_destination;
/// Phase the notification applies to
PhaseType phase = no_phase;
/// Element the notification is about
ElementIDStruct elm;
};

/// Message announcing that an element is recorded in a phase but was not in
/// the previous one. determinePhasesToMigrate sends it with the sender's node,
/// the later phase (i+1), and the element.
struct ArriveMsg : vt::Message {
/// \param[in] in_arrive_node node stored as \c arrive_node
/// \param[in] in_phase phase stored as \c phase
/// \param[in] in_elm element stored as \c elm
ArriveMsg(NodeType in_arrive_node, PhaseType in_phase, ElementIDStruct in_elm)
: arrive_node(in_arrive_node),
phase(in_phase),
elm(in_elm)
{ }

/// Node on the arriving side of the exchange; checkBothEnds sends the
/// follow-up UpdateMsg to this node
NodeType arrive_node = uninitialized_destination;
/// Phase the notification applies to
PhaseType phase = no_phase;
/// Element the notification is about
ElementIDStruct elm;
};

/// Message handled by LBDataRestartReader::update. checkBothEnds builds it
/// from the depart node plus the arrive message's phase and element.
struct UpdateMsg : vt::Message {
/// \param[in] in_curr_node node stored as \c curr_node
/// \param[in] in_phase phase stored as \c phase
/// \param[in] in_elm element stored as \c elm
UpdateMsg(NodeType in_curr_node, PhaseType in_phase, ElementIDStruct in_elm)
: curr_node(in_curr_node),
phase(in_phase),
elm(in_elm)
{ }

/// Node value that update() writes into the element's history entry
NodeType curr_node = uninitialized_destination;
/// Phase whose history entry is updated
PhaseType phase = no_phase;
/// Element whose history entry is updated
ElementIDStruct elm;
};

/// Pairing record for one (phase, element): holds the arrive and depart
/// messages until both are present. Both slots start out null.
struct Coord {
/// Arrival message for this (phase, element), or null if not yet recorded
MsgSharedPtr<ArriveMsg> arrive = nullptr;
/// Departure message for this (phase, element), or null if not yet recorded
MsgSharedPtr<DepartMsg> depart = nullptr;
};

void departing(DepartMsg* msg);
void arrive(ArriveMsg* msg);
void checkBothEnds(Coord& coord);

/// Pairing state for in-flight depart/arrive notifications, indexed first by
/// phase and then by element
std::unordered_map<
PhaseType, std::unordered_map<ElementIDStruct, Coord>
> coordinate_;

/// Number of phases read in. Unsigned and defaults to 0, so expressions such
/// as `num_phases_ - 1` wrap around while no phases have been read.
std::size_t num_phases_ = 0;
};
Expand Down

0 comments on commit fa93302

Please sign in to comment.