From a6210557f796875a00bbdc30ae21a32508adbcad Mon Sep 17 00:00:00 2001 From: Arkadiusz Szczepkowicz Date: Fri, 16 Feb 2024 13:44:21 +0100 Subject: [PATCH] #2229: Change rank_attributes to std::variant --- .../vrt/collection/balance/lb_data_holder.cc | 27 +++++++++++++++++-- .../vrt/collection/balance/lb_data_holder.h | 9 ++++++- src/vt/vrt/collection/balance/node_lb_data.cc | 9 ++++--- src/vt/vrt/collection/balance/node_lb_data.h | 4 +-- tests/unit/collection/test_lb_data_holder.cc | 15 ++++++++--- 5 files changed, 52 insertions(+), 12 deletions(-) diff --git a/src/vt/vrt/collection/balance/lb_data_holder.cc b/src/vt/vrt/collection/balance/lb_data_holder.cc index 30356200ce..147b725f83 100644 --- a/src/vt/vrt/collection/balance/lb_data_holder.cc +++ b/src/vt/vrt/collection/balance/lb_data_holder.cc @@ -125,6 +125,22 @@ std::unique_ptr LBDataHolder::metadataToJson() const { return std::make_unique(std::move(j)); } +std::unique_ptr LBDataHolder::rankAttributesToJson() const { + nlohmann::json j; + + for (auto const& [key, value] : rank_attributes_) { + if (std::holds_alternative(value)) { + j["attributes"][key] = std::get(value); + } else if (std::holds_alternative(value)) { + j["attributes"][key] = std::get(value); + } else if (std::holds_alternative(value)) { + j["attributes"][key] = std::get(value); + } + } + + return std::make_unique(std::move(j)); +} + std::unique_ptr LBDataHolder::toJson(PhaseType phase) const { using json = nlohmann::json; @@ -469,8 +485,15 @@ void LBDataHolder::readMetadata(nlohmann::json const& j) { } // load rank user atrributes if (metadata.find("attributes") != metadata.end()) { - rank_attributes_ = std::make_shared(); - *(rank_attributes_) = metadata["attributes"]; + for (auto const& [key, value] : metadata["attributes"].items()) { + if (value.is_number_integer()) { + rank_attributes_[key] = value.get(); + } else if (value.is_number_float()) { + rank_attributes_[key] = value.get(); + } else if (value.is_string()) { + rank_attributes_[key] = value.get(); + } + } } } } diff --git a/src/vt/vrt/collection/balance/lb_data_holder.h b/src/vt/vrt/collection/balance/lb_data_holder.h index 7b75bc9477..caeb4e4cea 100644 --- a/src/vt/vrt/collection/balance/lb_data_holder.h +++ b/src/vt/vrt/collection/balance/lb_data_holder.h @@ -105,6 +105,13 @@ struct LBDataHolder { */ std::unique_ptr metadataToJson() const; + /** + * \brief Output a LB rank attributes metadata to JSON + * + * \return the json data structure + */ + std::unique_ptr rankAttributesToJson() const; + /** * \brief Clear all LB data */ @@ -128,7 +135,7 @@ struct LBDataHolder { public: /// Node attributes for the current rank - std::shared_ptr rank_attributes_; + ElmUserDataType rank_attributes_; /// Node timings for each local object std::unordered_map node_data_; /// Node communication graph for each local object diff --git a/src/vt/vrt/collection/balance/node_lb_data.cc b/src/vt/vrt/collection/balance/node_lb_data.cc index 4057495560..b1e04c0caf 100644 --- a/src/vt/vrt/collection/balance/node_lb_data.cc +++ b/src/vt/vrt/collection/balance/node_lb_data.cc @@ -115,8 +115,8 @@ std::unordered_map> con return &lb_data_->node_subphase_comm_; } -std::shared_ptr const NodeLBData::getNodeAttributes() const { - return lb_data_->rank_attributes_; +ElmUserDataType const* NodeLBData::getNodeAttributes() const { + return &lb_data_->rank_attributes_; } CommMapType* NodeLBData::getNodeComm(PhaseType phase) { @@ -225,8 +225,9 @@ void NodeLBData::createLBDataFile() { if(phasesMetadata) { metadata["phases"] = *phasesMetadata; } - if(lb_data_->rank_attributes_) { - metadata["attributes"] = *lb_data_->rank_attributes_; + auto attributesMetadata = lb_data_->rankAttributesToJson(); + if(attributesMetadata) { + metadata["attributes"] = *attributesMetadata; } lb_data_writer_ = std::make_unique( "phases", metadata, file_name, compress diff --git a/src/vt/vrt/collection/balance/node_lb_data.h b/src/vt/vrt/collection/balance/node_lb_data.h index fd8321c336..652330928d 100644 --- a/src/vt/vrt/collection/balance/node_lb_data.h +++ b/src/vt/vrt/collection/balance/node_lb_data.h @@ -214,9 +214,9 @@ struct NodeLBData : runtime::component::Component { /** * \internal \brief Get stored node attributes * - * \return an observer shared pointer to the node attributes + * \return an observer pointer to the node attributes */ - std::shared_ptr const getNodeAttributes() const; + ElmUserDataType const* getNodeAttributes() const; /** * \internal \brief Test if this node has an object to migrate diff --git a/tests/unit/collection/test_lb_data_holder.cc b/tests/unit/collection/test_lb_data_holder.cc index a680e10e4b..af95937b92 100644 --- a/tests/unit/collection/test_lb_data_holder.cc +++ b/tests/unit/collection/test_lb_data_holder.cc @@ -226,7 +226,9 @@ TEST_F(TestLBDataHolder, test_lb_rank_attributes) { "type": "LBDatafile", "metadata" : { "attributes": { - "some_val": 123 + "intSample": 123, + "doubleSample": 1.99, + "stringSample": "abc" } }, "phases" : [] @@ -234,8 +236,15 @@ TEST_F(TestLBDataHolder, test_lb_rank_attributes) { )"_json; LBDataHolder testObj(json); - EXPECT_TRUE(nullptr != testObj.rank_attributes_); - EXPECT_EQ(123, (*testObj.rank_attributes_)["some_val"]); + EXPECT_EQ(123, std::get(testObj.rank_attributes_["intSample"])); + EXPECT_EQ(1.99, std::get(testObj.rank_attributes_["doubleSample"])); + EXPECT_EQ("abc", std::get(testObj.rank_attributes_["stringSample"])); + + auto outAttributesJsonPtr = testObj.rankAttributesToJson(); + ASSERT_TRUE(outAttributesJsonPtr != nullptr); + EXPECT_EQ(123, (*outAttributesJsonPtr)["attributes"]["intSample"]); + EXPECT_EQ(1.99, (*outAttributesJsonPtr)["attributes"]["doubleSample"]); + EXPECT_EQ("abc", (*outAttributesJsonPtr)["attributes"]["stringSample"]); } TEST_F(TestLBDataHolder, test_lb_entity_attributes) {