Skip to content

Commit

Permalink
#2229: Add reading and writing of user attributes in LBDataHolder
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable committed Feb 16, 2024
1 parent 8a0f594 commit 71150d1
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
32 changes: 26 additions & 6 deletions src/vt/vrt/collection/balance/lb_data_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

namespace vt { namespace vrt { namespace collection { namespace balance {

void LBDataHolder::outputEntity(nlohmann::json& j, ElementIDStruct const& id) const {
void LBDataHolder::outputEntity(PhaseType phase, nlohmann::json& j, ElementIDStruct const& id) const {
j["type"] = "object";
j["id"] = id.id;
j["home"] = id.getHomeNode();
Expand All @@ -66,6 +66,11 @@ void LBDataHolder::outputEntity(nlohmann::json& j, ElementIDStruct const& id) co
} else {
// bare handler
}
if (node_user_attributes_.find(phase) != node_user_attributes_.end()) {
if (node_user_attributes_.at(phase).find(id) != node_user_attributes_.at(phase).end()) {
j["attributes"] = *(node_user_attributes_.at(phase).at(id));
}
}
}

std::unique_ptr<nlohmann::json> LBDataHolder::metadataToJson() const {
Expand Down Expand Up @@ -117,6 +122,10 @@ std::unique_ptr<nlohmann::json> LBDataHolder::metadataToJson() const {
}
}

if (rank_attributes_) {
j["attributes"] = *rank_attributes_;
}

// Save metadata
j["skipped"]["list"] = skipped_list;
j["skipped"]["range"] = skipped_ranges;
Expand Down Expand Up @@ -148,7 +157,7 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {
}
}
}
outputEntity(j["tasks"][i]["entity"], id);
outputEntity(phase, j["tasks"][i]["entity"], id);

auto const& subphase_times = elm.second.subphase_loads;
std::size_t const subphases = subphase_times.size();
Expand Down Expand Up @@ -179,8 +188,8 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {
} else {
j["communications"][i]["type"] = "Broadcast";
}
outputEntity(j["communications"][i]["from"], key.fromObj());
outputEntity(j["communications"][i]["to"], key.toObj());
outputEntity(phase, j["communications"][i]["from"], key.fromObj());
outputEntity(phase, j["communications"][i]["to"], key.toObj());
break;
}
case elm::CommCategory::NodeToCollection:
Expand All @@ -193,7 +202,7 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {

j["communications"][i]["from"]["type"] = "node";
j["communications"][i]["from"]["id"] = key.fromNode();
outputEntity(j["communications"][i]["to"], key.toObj());
outputEntity(phase, j["communications"][i]["to"], key.toObj());
break;
}
case elm::CommCategory::CollectionToNode:
Expand All @@ -206,7 +215,7 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {

j["communications"][i]["to"]["type"] = "node";
j["communications"][i]["to"]["id"] = key.toNode();
outputEntity(j["communications"][i]["from"], key.fromObj());
outputEntity(phase, j["communications"][i]["from"], key.fromObj());
break;
}
case elm::CommCategory::LocalInvoke:
Expand Down Expand Up @@ -274,6 +283,12 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
}
}

if (task["entity"].find("attributes") != task["entity"].end()) {
auto attrs = task["entity"]["attributes"];
node_user_attributes_[id][elm] = std::make_shared<nlohmann::json>();
*(node_user_attributes_[id][elm]) = attrs;
}

if (task.find("subphases") != task.end()) {
auto subphases = task["subphases"];
if (subphases.is_array()) {
Expand Down Expand Up @@ -441,6 +456,11 @@ void LBDataHolder::readMetadata(nlohmann::json const& j) {
}
}
}
// load rank user atrributes
if (metadata.find("attributes") != metadata.end()) {
rank_attributes_ = std::make_shared<nlohmann::json>();
*(rank_attributes_) = metadata["attributes"];
}
}
}

Expand Down
11 changes: 9 additions & 2 deletions src/vt/vrt/collection/balance/lb_data_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ struct LBDataHolder {
s | skipped_phases_;
s | identical_phases_;
s | user_defined_lb_info_;
s | node_user_attributes_;
s | rank_attributes_;
}

/**
Expand All @@ -97,7 +99,7 @@ struct LBDataHolder {
std::unique_ptr<nlohmann::json> toJson(PhaseType phase) const;

/**
* \brief Output a LB phase's metdadata to JSON
* \brief Output a LB phase's metadata to JSON
*
* \return the json data structure
*/
Expand All @@ -112,10 +114,11 @@ struct LBDataHolder {
/**
* \brief Output an entity to json
*
* \param[in] phase the phase
* \param[in] j the json
* \param[in] elm_id the element to output
*/
void outputEntity(nlohmann::json& j, ElementIDStruct const& elm_id) const;
void outputEntity(PhaseType phase, nlohmann::json& j, ElementIDStruct const& elm_id) const;

/**
* \brief Read the LB phase's metadata
Expand All @@ -125,6 +128,7 @@ struct LBDataHolder {
void readMetadata(nlohmann::json const& j);

public:
std::shared_ptr<nlohmann::json> rank_attributes_;
/// Node timings for each local object
std::unordered_map<PhaseType, LoadMapType> node_data_;
/// Node communication graph for each local object
Expand All @@ -135,6 +139,9 @@ struct LBDataHolder {
std::unordered_map<PhaseType, std::unordered_map<
ElementIDStruct, std::shared_ptr<nlohmann::json>
>> user_defined_json_;
std::unordered_map<PhaseType, std::unordered_map<
ElementIDStruct, std::shared_ptr<nlohmann::json>
>> node_user_attributes_;

std::unordered_map<PhaseType, std::shared_ptr<nlohmann::json>> user_per_phase_json_;
/// User-defined data from each phase for LB
Expand Down

0 comments on commit 71150d1

Please sign in to comment.