Skip to content

Commit

Permalink
#2229: Change attributes to use std::variant
Browse files Browse the repository at this point in the history
  • Loading branch information
thearusable authored and cwschilly committed Sep 20, 2024
1 parent f8ddf86 commit 4d57917
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 32 deletions.
6 changes: 3 additions & 3 deletions scripts/JSON_data_files_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ def _get_valid_schema(self) -> Schema:
Optional('index'): [int],
'type': str,
'migratable': bool,
Optional('objgroup_id'): int,
Optional('attributes'): dict
Optional('objgroup_id'): int
},
'node': int,
'resource': str,
Expand All @@ -90,7 +89,8 @@ def _get_valid_schema(self) -> Schema:
}
],
'time': float,
Optional('user_defined'): dict
Optional('user_defined'): dict,
Optional('attributes'): dict
},
],
Optional('communications'): [
Expand Down
49 changes: 32 additions & 17 deletions src/vt/vrt/collection/balance/lb_data_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

namespace vt { namespace vrt { namespace collection { namespace balance {

void LBDataHolder::outputEntity(PhaseType phase, nlohmann::json& j, ElementIDStruct const& id) const {
void LBDataHolder::outputEntity(nlohmann::json& j, ElementIDStruct const& id) const {
j["type"] = "object";
j["id"] = id.id;
j["home"] = id.getHomeNode();
Expand All @@ -66,11 +66,6 @@ void LBDataHolder::outputEntity(PhaseType phase, nlohmann::json& j, ElementIDStr
} else {
// bare handler
}
if (node_user_attributes_.find(phase) != node_user_attributes_.end()) {
if (node_user_attributes_.at(phase).find(id) != node_user_attributes_.at(phase).end()) {
j["attributes"] = *(node_user_attributes_.at(phase).at(id));
}
}
}

std::unique_ptr<nlohmann::json> LBDataHolder::metadataToJson() const {
Expand Down Expand Up @@ -153,7 +148,21 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {
}
}
}
outputEntity(phase, j["tasks"][i]["entity"], id);
outputEntity(j["tasks"][i]["entity"], id);

if (node_user_attributes_.find(phase) != node_user_attributes_.end()) {
if (node_user_attributes_.at(phase).find(id) != node_user_attributes_.at(phase).end()) {
for (auto const& [key, value] : node_user_attributes_.at(phase).at(id)) {
if (std::holds_alternative<int>(value)) {
j["tasks"][i]["attributes"][key] = std::get<int>(value);
} else if (std::holds_alternative<double>(value)) {
j["tasks"][i]["attributes"][key] = std::get<double>(value);
} else if (std::holds_alternative<std::string>(value)) {
j["tasks"][i]["attributes"][key] = std::get<std::string>(value);
}
}
}
}

auto const& subphase_times = elm.second.subphase_loads;
std::size_t const subphases = subphase_times.size();
Expand Down Expand Up @@ -184,8 +193,8 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {
} else {
j["communications"][i]["type"] = "Broadcast";
}
outputEntity(phase, j["communications"][i]["from"], key.fromObj());
outputEntity(phase, j["communications"][i]["to"], key.toObj());
outputEntity(j["communications"][i]["from"], key.fromObj());
outputEntity(j["communications"][i]["to"], key.toObj());
break;
}
case elm::CommCategory::NodeToCollection:
Expand All @@ -198,7 +207,7 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {

j["communications"][i]["from"]["type"] = "node";
j["communications"][i]["from"]["id"] = key.fromNode();
outputEntity(phase, j["communications"][i]["to"], key.toObj());
outputEntity(j["communications"][i]["to"], key.toObj());
break;
}
case elm::CommCategory::CollectionToNode:
Expand All @@ -211,7 +220,7 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {

j["communications"][i]["to"]["type"] = "node";
j["communications"][i]["to"]["id"] = key.toNode();
outputEntity(phase, j["communications"][i]["from"], key.fromObj());
outputEntity(j["communications"][i]["from"], key.fromObj());
break;
}
case elm::CommCategory::LocalInvoke:
Expand Down Expand Up @@ -279,12 +288,6 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
}
}

if (task["entity"].find("attributes") != task["entity"].end()) {
auto attrs = task["entity"]["attributes"];
node_user_attributes_[id][elm] = std::make_shared<nlohmann::json>();
*(node_user_attributes_[id][elm]) = attrs;
}

if (task.find("subphases") != task.end()) {
auto subphases = task["subphases"];
if (subphases.is_array()) {
Expand Down Expand Up @@ -314,6 +317,18 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
}
}
}

if (task.find("attributes") != task.end()) {
for (auto const& [key, value] : task["attributes"].items()) {
if (value.is_number_integer()) {
node_user_attributes_[id][elm][key] = value.get<int>();
} else if (value.is_number_float()) {
node_user_attributes_[id][elm][key] = value.get<double>();
} else if (value.is_string()) {
node_user_attributes_[id][elm][key] = value.get<std::string>();
}
}
}
}
}
}
Expand Down
9 changes: 4 additions & 5 deletions src/vt/vrt/collection/balance/lb_data_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,10 @@ struct LBDataHolder {
/**
* \brief Output an entity to json
*
* \param[in] phase the phase
* \param[in] j the json
* \param[in] elm_id the element to output
*/
void outputEntity(PhaseType phase, nlohmann::json& j, ElementIDStruct const& elm_id) const;
void outputEntity(nlohmann::json& j, ElementIDStruct const& elm_id) const;

/**
* \brief Read the LB phase's metadata
Expand All @@ -128,6 +127,7 @@ struct LBDataHolder {
void readMetadata(nlohmann::json const& j);

public:
/// Node attributes for the current rank
std::shared_ptr<nlohmann::json> rank_attributes_;
/// Node timings for each local object
std::unordered_map<PhaseType, LoadMapType> node_data_;
Expand All @@ -139,13 +139,12 @@ struct LBDataHolder {
std::unordered_map<PhaseType, std::unordered_map<
ElementIDStruct, std::shared_ptr<nlohmann::json>
>> user_defined_json_;
std::unordered_map<PhaseType, std::unordered_map<
ElementIDStruct, std::shared_ptr<nlohmann::json>
>> node_user_attributes_;

std::unordered_map<PhaseType, std::shared_ptr<nlohmann::json>> user_per_phase_json_;
/// User-defined data from each phase for LB
std::unordered_map<PhaseType, DataMapType> user_defined_lb_info_;
/// User-defined attributes from each phase
std::unordered_map<PhaseType, DataMapType> node_user_attributes_;
/// Node indices for each ID along with the proxy ID
std::unordered_map<ElementIDStruct, std::tuple<VirtualProxyType, std::vector<uint64_t>>> node_idx_;
/// Map from id to objgroup proxy
Expand Down
5 changes: 5 additions & 0 deletions src/vt/vrt/collection/balance/node_lb_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ NodeLBData::getUserData() const {
return &lb_data_->user_defined_lb_info_;
}

std::unordered_map<PhaseType, DataMapType> const*
NodeLBData::getUserAttributes() const {
return &lb_data_->node_user_attributes_;
}

std::unordered_map<PhaseType, CommMapType> const* NodeLBData::getNodeComm() const {
return &lb_data_->node_comm_;
}
Expand Down
7 changes: 7 additions & 0 deletions src/vt/vrt/collection/balance/node_lb_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,13 @@ struct NodeLBData : runtime::component::Component<NodeLBData> {
*/
std::unordered_map<PhaseType, DataMapType> const* getUserData() const;

/**
* \internal \brief Get the user-defined attributes
*
* \return an observer pointer to the user-defined attributes
*/
std::unordered_map<PhaseType, DataMapType> const* getUserAttributes() const;

/**
* \internal \brief Get stored object comm data for a specific phase
*
Expand Down
21 changes: 14 additions & 7 deletions tests/unit/collection/test_lb_data_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,16 @@ TEST_F(TestLBDataHolder, test_lb_entity_attributes) {
"home": 0,
"id": 524291,
"type": "object",
"migratable": true,
"attributes": {
"some_val": 123
}
"migratable": true
},
"node": 0,
"resource": "cpu",
"time": 3.0
"time": 3.0,
"attributes": {
"intSample": 123,
"doubleSample": 1.99,
"stringSample": "abc"
}
}
]
}
Expand All @@ -272,11 +274,16 @@ TEST_F(TestLBDataHolder, test_lb_entity_attributes) {
LBDataHolder testObj(json);
EXPECT_TRUE(testObj.node_user_attributes_.find(0) != testObj.node_user_attributes_.end());
EXPECT_TRUE(testObj.node_user_attributes_[0].find(id) != testObj.node_user_attributes_[0].end());
EXPECT_EQ(123, (*testObj.node_user_attributes_[0][id])["some_val"]);
auto attributes = testObj.node_user_attributes_[0][id];
EXPECT_EQ(123, std::get<int>(attributes["intSample"]));
EXPECT_EQ(1.99, std::get<double>(attributes["doubleSample"]));
EXPECT_EQ("abc", std::get<std::string>(attributes["stringSample"]));

auto outJsonPtr = testObj.toJson(0);
ASSERT_TRUE(outJsonPtr != nullptr);
EXPECT_EQ(123, (*outJsonPtr)["tasks"][0]["entity"]["attributes"]["some_val"]);
EXPECT_EQ(123, (*outJsonPtr)["tasks"][0]["attributes"]["intSample"]);
EXPECT_EQ(1.99, (*outJsonPtr)["tasks"][0]["attributes"]["doubleSample"]);
EXPECT_EQ("abc", (*outJsonPtr)["tasks"][0]["attributes"]["stringSample"]);
}

}}}} // end namespace vt::tests::unit::lb
Expand Down

0 comments on commit 4d57917

Please sign in to comment.