diff --git a/scripts/JSON_data_files_validator.py b/scripts/JSON_data_files_validator.py index 05190ec508..06e9f5b7ef 100644 --- a/scripts/JSON_data_files_validator.py +++ b/scripts/JSON_data_files_validator.py @@ -435,12 +435,12 @@ def validate_comm_links(all_jsons): for data in all_jsons: if data["phases"][n].get("communications") is not None: comms = data["phases"][n]["communications"] - id_string = "id" if "id" in comms[0]["from"] else "seq_id" - comm_ids.update({int(comm["from"][id_string]) for comm in comms}) - comm_ids.update({int(comm["to"][id_string]) for comm in comms}) + id_key = "id" if "id" in comms[0]["from"] else "seq_id" + comm_ids.update({int(comm["from"][id_key]) for comm in comms}) + comm_ids.update({int(comm["to"][id_key]) for comm in comms}) tasks = data["phases"][n]["tasks"] - task_ids.update({int(task["entity"][id_string]) for task in tasks}) + task_ids.update({int(task["entity"][id_key]) for task in tasks}) if not comm_ids.issubset(task_ids): logging.error( diff --git a/src/vt/vrt/collection/balance/lb_data_holder.cc b/src/vt/vrt/collection/balance/lb_data_holder.cc index d47297ec9a..9712697d13 100644 --- a/src/vt/vrt/collection/balance/lb_data_holder.cc +++ b/src/vt/vrt/collection/balance/lb_data_holder.cc @@ -41,14 +41,68 @@ //@HEADER */ -#include "vt/vrt/collection/balance/lb_data_holder.h" #include "vt/context/context.h" #include "vt/elm/elm_id_bits.h" +#include "vt/vrt/collection/balance/lb_data_holder.h" #include namespace vt { namespace vrt { namespace collection { namespace balance { +void get_object_from_json_field_( + const nlohmann::json& field, nlohmann::json& object, bool& bitpacked) { + if (field.find("id") != field.end()) { + object = field["id"]; + bitpacked = true; + } else { + object = field["seq_id"]; + bitpacked = false; + } +} + +ElementIDStruct get_elm_from_object_info_( + const nlohmann::json& object, bool bitpacked, bool migratable, + const nlohmann::json& home, const nlohmann::json& node) { + using Field = uint64_t; + + Field object_id; + if (bitpacked) { + object_id = BitPackerType::getField< + vt::elm::eElmIDProxyBitsNonObjGroup::ID, vt::elm::elm_id_num_bits, Field>( + static_cast(object)); + } else { + object_id = static_cast(object); + } + + return elm::ElmIDBits::createCollectionImpl( + migratable, object_id, home, node); +} + +ElementIDStruct get_elm_from_comm_object_(const nlohmann::json& field, bool collection) { + // Get the object's id and determine if it is bit-encoded + nlohmann::json object; + bool bitpacked_id; + get_object_from_json_field_(field, object, bitpacked_id); + vtAssertExpr(object.is_number()); + + // Somehow will this information + int home = 0; + int node = 0; + bool migratable = false; + + // Create elm with encoded data + ElementIDStruct elm; + if (collection) { + elm = + get_elm_from_object_info_(object, bitpacked_id, migratable, home, node); + } else { + elm = ElementIDStruct{object, theContext()->getNode()}; + } + + return elm; +} + + void LBDataHolder::outputEntity(nlohmann::json& j, ElementIDStruct const& id) const { j["type"] = "object"; j["id"] = id.id; @@ -278,8 +332,6 @@ std::unique_ptr LBDataHolder::toJson(PhaseType phase) const { LBDataHolder::LBDataHolder(nlohmann::json const& j) { - auto this_node = theContext()->getNode(); - // read metadata for skipped and identical phases readMetadata(j); @@ -314,27 +366,13 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j) } vtAssertExpr(object.is_number()); - auto elm = ElementIDStruct{object, node}; + ElementIDStruct elm; if ( task["entity"].find("collection_id") != task["entity"].end() and - task["entity"].find("index") != task["entity"].end() - ) { - using Field = uint64_t; - Field object_id; - if (bitpacked_id) { - object_id = BitPackerType::getField< - vt::elm::eElmIDProxyBitsNonObjGroup::ID, - vt::elm::elm_id_num_bits, - Field - >(static_cast(object)); - } else { - object_id = static_cast(object); - } - elm = elm::ElmIDBits::createCollectionImpl(migratable, - object_id, - home, - node); + task["entity"].find("index") != task["entity"].end()) { + elm = get_elm_from_object_info_( + object, bitpacked_id, migratable, home, node); auto cid = task["entity"]["collection_id"]; auto idx = task["entity"]["index"]; if (cid.is_number() && idx.is_array()) { @@ -342,6 +380,8 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j) auto proxy = static_cast(cid); this->node_idx_[elm] = std::make_tuple(proxy, arr); } + } else { + elm = ElementIDStruct{object, node}; } this->node_data_[id][elm].whole_phase_load = time; @@ -409,13 +449,11 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j) vtAssertExpr(comm["from"]["type"] == "object"); vtAssertExpr(comm["to"]["type"] == "object"); - auto from_object = comm["from"]["id"]; - vtAssertExpr(from_object.is_number()); - auto from_elm = ElementIDStruct{from_object, this_node}; - - auto to_object = comm["to"]["id"]; - vtAssertExpr(to_object.is_number()); - auto to_elm = ElementIDStruct{to_object, this_node}; + // TODO: passing false here (and below) avoids encoding + // any information into the obj ids, which preserves + // the original behavior + auto from_elm = get_elm_from_comm_object_(comm["from"], false); + auto to_elm = get_elm_from_comm_object_(comm["to"], false); CommKey key( CommKey::CollectionTag{}, @@ -432,9 +470,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j) auto from_node = comm["from"]["id"]; vtAssertExpr(from_node.is_number()); - auto to_object = comm["to"]["id"]; - vtAssertExpr(to_object.is_number()); - auto to_elm = ElementIDStruct{to_object, this_node}; + auto to_elm = get_elm_from_comm_object_(comm["to"], false); CommKey key( CommKey::NodeToCollectionTag{}, @@ -449,9 +485,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j) vtAssertExpr(comm["from"]["type"] == "object"); vtAssertExpr(comm["to"]["type"] == "node"); - auto from_object = comm["from"]["id"]; - vtAssertExpr(from_object.is_number()); - auto from_elm = ElementIDStruct{from_object, this_node}; + auto from_elm = get_elm_from_comm_object_(comm["from"], false); auto to_node = comm["to"]["id"]; vtAssertExpr(to_node.is_number()); diff --git a/tests/unit/lb/test_lb_data_holder.cc b/tests/unit/lb/test_lb_data_holder.cc index 8d330d6484..cfcb3583af 100644 --- a/tests/unit/lb/test_lb_data_holder.cc +++ b/tests/unit/lb/test_lb_data_holder.cc @@ -106,12 +106,6 @@ void test_data_holder_elms(int seq_id, int home, int node, bool migratable) { } TEST_F(TestLBDataHolder, test_lb_data_holder_no_comms_object_id) { - // Initialize - int argc = 0; - char** argv = nullptr; - MPI_Comm comm = MPI_COMM_WORLD; - vt::initialize(argc, argv, &comm); - // Run a variety of test cases (seq_id, home, node, migratable) test_data_holder_elms(0,0,0,false); test_data_holder_elms(0,0,0,true); @@ -120,9 +114,6 @@ TEST_F(TestLBDataHolder, test_lb_data_holder_no_comms_object_id) { test_data_holder_elms(1,1,0,false); test_data_holder_elms(2,1,9,true); test_data_holder_elms(3,0,1,false); - - // Finalize - // vt::finalize(); } -}}}} // end namespace vt::tests::unit::lb \ No newline at end of file +}}}} // end namespace vt::tests::unit::lb