diff --git a/scripts/JSON_data_files_validator.py b/scripts/JSON_data_files_validator.py index 67abb5f688..72a5c51c38 100644 --- a/scripts/JSON_data_files_validator.py +++ b/scripts/JSON_data_files_validator.py @@ -433,13 +433,14 @@ def validate_comm_links(all_jsons): task_ids = set() for data in all_jsons: + tasks = data["phases"][n]["tasks"] + id_key = "id" if "id" in tasks[0]["entity"] else "seq_id" + task_ids.update({int(task["entity"][id_key]) for task in tasks}) + if data["phases"][n].get("communications") is not None: comms = data["phases"][n]["communications"] - comm_ids.update({int(comm["from"]["id"]) for comm in comms}) - comm_ids.update({int(comm["to"]["id"]) for comm in comms}) - - tasks = data["phases"][n]["tasks"] - task_ids.update({int(task["entity"]["id"]) for task in tasks}) + comm_ids.update({int(comm["from"][id_key]) for comm in comms}) + comm_ids.update({int(comm["to"][id_key]) for comm in comms}) if not comm_ids.issubset(task_ids): logging.error( diff --git a/scripts/LBDatafile_schema.py b/scripts/LBDatafile_schema.py index bdbbeb84d2..743fff574e 100644 --- a/scripts/LBDatafile_schema.py +++ b/scripts/LBDatafile_schema.py @@ -1,5 +1,11 @@ from schema import And, Optional, Schema +def validate_id_and_seq_id(field): + """Ensure that either seq_id or id is provided.""" + if 'seq_id' not in field and 'id' not in field: + raise ValueError('Either id (bit-encoded) or seq_id must be provided.') + return field + LBDatafile_schema = Schema( { Optional('type'): And(str, "LBDatafile", error="'LBDatafile' must be chosen."), @@ -30,15 +36,16 @@ 'id': int, 'tasks': [ { - 'entity': { + 'entity': And({ Optional('collection_id'): int, 'home': int, - 'id': int, + Optional('id'): int, + Optional('seq_id'): int, Optional('index'): [int], 'type': str, 'migratable': bool, Optional('objgroup_id'): int - }, + }, validate_id_and_seq_id), 'node': int, 'resource': str, Optional('subphases'): [ @@ -55,25 +62,27 @@ Optional('communications'): [ { 'type': str, - 'to': { + 'to': And({ 'type': str, - 'id': int, + Optional('id'): int, + Optional('seq_id'): int, Optional('home'): int, Optional('collection_id'): int, Optional('migratable'): bool, Optional('index'): [int], Optional('objgroup_id'): int, - }, + }, validate_id_and_seq_id), 'messages': int, - 'from': { + 'from': And({ 'type': str, - 'id': int, + Optional('id'): int, + Optional('seq_id'): int, Optional('home'): int, Optional('collection_id'): int, Optional('migratable'): bool, Optional('index'): [int], Optional('objgroup_id'): int, - }, + }, validate_id_and_seq_id), 'bytes': float } ], diff --git a/src/vt/vrt/collection/balance/lb_data_holder.cc b/src/vt/vrt/collection/balance/lb_data_holder.cc index 333dc1babd..701d3d27bb 100644 --- a/src/vt/vrt/collection/balance/lb_data_holder.cc +++ b/src/vt/vrt/collection/balance/lb_data_holder.cc @@ -41,14 +41,54 @@ //@HEADER */ -#include "vt/vrt/collection/balance/lb_data_holder.h" #include "vt/context/context.h" #include "vt/elm/elm_id_bits.h" +#include "vt/vrt/collection/balance/lb_data_holder.h" #include namespace vt { namespace vrt { namespace collection { namespace balance { +void LBDataHolder::getObjectFromJsonField_( + nlohmann::json const& field, nlohmann::json& object, bool& is_bitpacked, + bool& is_collection) { + if (field.find("id") != field.end()) { + object = field["id"]; + is_bitpacked = true; + } else { + object = field["seq_id"]; + is_bitpacked = false; + } + vtAssertExpr(object.is_number()); + if (field.find("collection_id") != field.end()) { + is_collection = true; + } else { + is_collection = false; + } +} + +ElementIDStruct +LBDataHolder::getElmFromCommObject_( + nlohmann::json const& field) const { + // Get the object's id and determine if it is bit-encoded + nlohmann::json object; + bool is_bitpacked, is_collection; + getObjectFromJsonField_(field, object, is_bitpacked, is_collection); + + // Create elm with encoded data + ElementIDStruct elm; + if (is_collection and not is_bitpacked) { + int home = field["home"]; + bool is_migratable = field["migratable"]; + elm = elm::ElmIDBits::createCollectionImpl( + is_migratable, static_cast(object), home, this_node_); + } else { + elm = ElementIDStruct{object, this_node_}; + } + + return elm; +} + void LBDataHolder::outputEntity(nlohmann::json& j, ElementIDStruct const& id) const { j["type"] = "object"; j["id"] = id.id; @@ -278,7 +318,7 @@ std::unique_ptr LBDataHolder::toJson(PhaseType phase) const { LBDataHolder::LBDataHolder(nlohmann::json const& j) { - auto this_node = theContext()->getNode(); + this_node_ = theContext()->getNode(); // read metadata for skipped and identical phases readMetadata(j); @@ -298,41 +338,35 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j) auto time = task["time"]; auto etype = task["entity"]["type"]; auto home = task["entity"]["home"]; - bool migratable = task["entity"]["migratable"]; + bool is_migratable = task["entity"]["migratable"]; vtAssertExpr(time.is_number()); vtAssertExpr(node.is_number()); if (etype == "object") { - auto object = task["entity"]["id"]; - vtAssertExpr(object.is_number()); - - auto elm = ElementIDStruct{object, node}; + nlohmann::json object; + bool is_bitpacked, is_collection; + getObjectFromJsonField_(task["entity"], object, is_bitpacked, is_collection); + + // Create elm + ElementIDStruct elm = is_collection and not is_bitpacked + ? elm::ElmIDBits::createCollectionImpl( + is_migratable, static_cast(object), home, this_node_) + : ElementIDStruct{object, this_node_}; + this->node_data_[id][elm].whole_phase_load = time; - if ( - task["entity"].find("collection_id") != task["entity"].end() and - task["entity"].find("index") != task["entity"].end() - ) { - using Field = uint64_t; - auto strippedObject = BitPackerType::getField< - vt::elm::eElmIDProxyBitsNonObjGroup::ID, - vt::elm::elm_id_num_bits, - Field - >(static_cast(object)); - elm = elm::ElmIDBits::createCollectionImpl(migratable, - strippedObject, - home, - node); + if (is_collection) { auto cid = task["entity"]["collection_id"]; - auto idx = task["entity"]["index"]; - if (cid.is_number() && idx.is_array()) { - std::vector arr = idx; - auto proxy = static_cast(cid); - this->node_idx_[elm] = std::make_tuple(proxy, arr); + if (task["entity"].find("index") != task["entity"].end()) { + auto idx = task["entity"]["index"]; + if (cid.is_number() && idx.is_array()) { + std::vector arr = idx; + auto proxy = static_cast(cid); + this->node_idx_[elm] = std::make_tuple(proxy, arr); + } } } - this->node_data_[id][elm].whole_phase_load = time; if (task.find("subphases") != task.end()) { auto subphases = task["subphases"]; @@ -397,13 +431,8 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j) vtAssertExpr(comm["from"]["type"] == "object"); vtAssertExpr(comm["to"]["type"] == "object"); - auto from_object = comm["from"]["id"]; - vtAssertExpr(from_object.is_number()); - auto from_elm = ElementIDStruct{from_object, this_node}; - - auto to_object = comm["to"]["id"]; - vtAssertExpr(to_object.is_number()); - auto to_elm = ElementIDStruct{to_object, this_node}; + auto from_elm = getElmFromCommObject_(comm["from"]); + auto to_elm = getElmFromCommObject_(comm["to"]); CommKey key( CommKey::CollectionTag{}, @@ -420,9 +449,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j) auto from_node = comm["from"]["id"]; vtAssertExpr(from_node.is_number()); - auto to_object = comm["to"]["id"]; - vtAssertExpr(to_object.is_number()); - auto to_elm = ElementIDStruct{to_object, this_node}; + auto to_elm = getElmFromCommObject_(comm["to"]); CommKey key( CommKey::NodeToCollectionTag{}, @@ -437,9 +464,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j) vtAssertExpr(comm["from"]["type"] == "object"); vtAssertExpr(comm["to"]["type"] == "node"); - auto from_object = comm["from"]["id"]; - vtAssertExpr(from_object.is_number()); - auto from_elm = ElementIDStruct{from_object, this_node}; + auto from_elm = getElmFromCommObject_(comm["from"]); auto to_node = comm["to"]["id"]; vtAssertExpr(to_node.is_number()); diff --git a/src/vt/vrt/collection/balance/lb_data_holder.h b/src/vt/vrt/collection/balance/lb_data_holder.h index a3badad166..fb2c9fce48 100644 --- a/src/vt/vrt/collection/balance/lb_data_holder.h +++ b/src/vt/vrt/collection/balance/lb_data_holder.h @@ -127,6 +127,29 @@ struct LBDataHolder { void addInitialTask(nlohmann::json& j, std::size_t n) const; + /** + * \brief Determine the object ID from the tasks or communication field of + * input JSON + * + * \param[in] field the json field containing an object ID + * \param[in] object empty json object to be populated with the object's ID + * \param[in] is_bitpacked empty bool to be populated with whether or not + * the ID is bit-encoded + * \param[in] is_collection empty bool to be populated with whether + * or not the object belongs to a collection + */ + static void getObjectFromJsonField_( + nlohmann::json const& field, nlohmann::json& object, + bool& is_bitpacked, bool& is_collection); + + /** + * \brief Create an ElementIDStruct for the communication object + * + * \param[in] field the communication field for the desired object + * e.g. communications["to"] or communications["from"] + */ + ElementIDStruct getElmFromCommObject_(nlohmann::json const& field) const; + /** * \brief Read the LB phase's metadata * @@ -135,6 +158,8 @@ struct LBDataHolder { void readMetadata(nlohmann::json const& j); public: + /// The current node + NodeType this_node_ = vt::uninitialized_destination; /// Node attributes for the current rank ElmUserDataType rank_attributes_; /// Node timings for each local object diff --git a/tests/unit/collection/test_lb_data_holder.cc b/tests/unit/collection/test_lb_data_holder.cc index d13f51fc6b..03a0016d1d 100644 --- a/tests/unit/collection/test_lb_data_holder.cc +++ b/tests/unit/collection/test_lb_data_holder.cc @@ -47,6 +47,7 @@ #include "test_helpers.h" #include "test_collection_common.h" +#include "vt/elm/elm_id_bits.h" #include "vt/vrt/collection/manager.h" #include "vt/vrt/collection/balance/lb_data_holder.h" @@ -90,6 +91,104 @@ void addPhasesDataToJson(nlohmann::json& json, PhaseType amountOfPhasesToAdd, st json["phases"] = phases; } +nlohmann::json createEntity_(std::string id_type, int id, int home, bool is_migratable) { + nlohmann::json entity = { + {id_type, id}, + {"type", "object"}, + {"collection_id", 7}, + {"index", {0}}, + {"home", home}, + {"migratable", is_migratable}, + }; + return entity; +} + +nlohmann::json createJson_( + std::string id_type, int id_1, int id_2, int home, int node, + bool is_migratable) { + + auto entity_1 = createEntity_(id_type, id_1, home, is_migratable); + auto entity_2 = createEntity_(id_type, id_2, home, is_migratable); + + // Generate JSON + nlohmann::json j = { + {"metadata", {{"rank", 0}, {"type", "LBDatafile"}}}, + {"phases", + {{{"communications", + {{{"bytes", 2.0}, + {"from", entity_1}, + {"messages", 1}, + {"to", entity_2}, + {"type", "SendRecv"}}}}, + {"id", 0}, + {"tasks", + { + { + {"entity", entity_1}, + {"node", node}, + {"resource", "cpu"}, + {"time", 0.5}, + }, + { + {"entity", entity_2}, + {"node", node}, + {"resource", "cpu"}, + {"time", 0.5}, + } + } + }}} + } + }; + + return j; +} + +void testDataHolderElms(int seq_id_1, int home, int node, bool is_migratable) { + // Create a second seq_id + auto seq_id_2 = seq_id_1 + 1; + + // Determine encoded ID + auto elm_1 = + elm::ElmIDBits::createCollectionImpl(is_migratable, seq_id_1, home, node); + auto encoded_id_1 = elm_1.id; + + // Create second encoded ID + auto elm_2 = + elm::ElmIDBits::createCollectionImpl(is_migratable, seq_id_2, home, node); + auto encoded_id_2 = elm_2.id; + + // Create DataHolder and get resulting object elm + auto simple_json_id = createJson_( + "id", encoded_id_1, encoded_id_2, home, node, is_migratable); + auto dh_id = vt::vrt::collection::balance::LBDataHolder(simple_json_id); + + // Create new DataHolder using "seq_id" and get elm + auto simple_json_seq = + createJson_("seq_id", seq_id_1, seq_id_2, home, node, is_migratable); + auto dh_seq = vt::vrt::collection::balance::LBDataHolder(simple_json_seq); + + // Assert that both elms exist in both DataHolders + ASSERT_NE(dh_id.node_data_[0].find(elm_1), dh_id.node_data_[0].end()); + ASSERT_NE(dh_seq.node_data_[0].find(elm_1), dh_seq.node_data_[0].end()); + ASSERT_NE(dh_id.node_data_[0].find(elm_2), dh_id.node_data_[0].end()); + ASSERT_NE(dh_seq.node_data_[0].find(elm_2), dh_seq.node_data_[0].end()); + + // Check the communication data + auto comm_id = dh_id.node_comm_[0]; + auto comm_key_id = comm_id.begin()->first; + auto comm_seq = dh_seq.node_comm_[0]; + auto comm_key_seq = comm_seq.begin()->first; + + // Ensure that we get the same CommKey from both id types + EXPECT_EQ(comm_key_id, comm_key_seq); + + // Assert that both elms are present in the communication data + EXPECT_EQ(comm_key_id.fromObj(), elm_1); + EXPECT_EQ(comm_key_id.toObj(), elm_2); + EXPECT_EQ(comm_key_seq.fromObj(), elm_1); + EXPECT_EQ(comm_key_seq.toObj(), elm_2); +} + TEST_F(TestLBDataHolder, test_no_metadata) { using LBDataHolder = vt::vrt::collection::balance::LBDataHolder; @@ -302,6 +401,19 @@ TEST_F(TestLBDataHolder, test_default_time_format) { } } +TEST_F(TestLBDataHolder, test_lb_data_holder_object_id) { + // Run a variety of test cases (seq_id, home, node, is_migratable) + auto current_node = theContext()->getNode(); + testDataHolderElms(0, 0, current_node, false); + testDataHolderElms(0, 0, current_node, true); + testDataHolderElms(1, 0, current_node, false); + testDataHolderElms(1, 0, current_node, true); + testDataHolderElms(0, 1, current_node, false); + testDataHolderElms(0, 1, current_node, true); + testDataHolderElms(3, 1, current_node, false); + testDataHolderElms(2, 2, current_node, true); +} + }}}} // end namespace vt::tests::unit::lb #endif /*vt_check_enabled(lblite)*/