Skip to content

Commit

Permalink
#2342: wip: add logic for supporting or creating encoded ids in comm …
Browse files Browse the repository at this point in the history
…field
  • Loading branch information
cwschilly authored and cz4rs committed Sep 20, 2024
1 parent 69df764 commit 6b52e07
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 48 deletions.
8 changes: 4 additions & 4 deletions scripts/JSON_data_files_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,12 +435,12 @@ def validate_comm_links(all_jsons):
for data in all_jsons:
if data["phases"][n].get("communications") is not None:
comms = data["phases"][n]["communications"]
id_string = "id" if "id" in comms[0]["from"] else "seq_id"
comm_ids.update({int(comm["from"][id_string]) for comm in comms})
comm_ids.update({int(comm["to"][id_string]) for comm in comms})
id_key = "id" if "id" in comms[0]["from"] else "seq_id"
comm_ids.update({int(comm["from"][id_key]) for comm in comms})
comm_ids.update({int(comm["to"][id_key]) for comm in comms})

tasks = data["phases"][n]["tasks"]
task_ids.update({int(task["entity"][id_string]) for task in tasks})
task_ids.update({int(task["entity"][id_key]) for task in tasks})

if not comm_ids.issubset(task_ids):
logging.error(
Expand Down
102 changes: 68 additions & 34 deletions src/vt/vrt/collection/balance/lb_data_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,68 @@
//@HEADER
*/

#include "vt/vrt/collection/balance/lb_data_holder.h"
#include "vt/context/context.h"
#include "vt/elm/elm_id_bits.h"
#include "vt/vrt/collection/balance/lb_data_holder.h"

#include <nlohmann/json.hpp>

namespace vt { namespace vrt { namespace collection { namespace balance {

void get_object_from_json_field_(
const nlohmann::json& field, nlohmann::json& object, bool& bitpacked) {
if (field.find("id") != field.end()) {
object = field["id"];
bitpacked = true;
} else {
object = field["seq_id"];
bitpacked = false;
}
}

ElementIDStruct get_elm_from_object_info_(
const nlohmann::json& object, bool bitpacked, bool migratable,
const nlohmann::json& home, const nlohmann::json& node) {
using Field = uint64_t;

Field object_id;
if (bitpacked) {
object_id = BitPackerType::getField<
vt::elm::eElmIDProxyBitsNonObjGroup::ID, vt::elm::elm_id_num_bits, Field>(
static_cast<Field>(object));
} else {
object_id = static_cast<Field>(object);
}

return elm::ElmIDBits::createCollectionImpl(
migratable, object_id, home, node);
}

ElementIDStruct get_elm_from_comm_object_(const nlohmann::json& field, bool collection) {
// Get the object's id and determine if it is bit-encoded
nlohmann::json object;
bool bitpacked_id;
get_object_from_json_field_(field, object, bitpacked_id);
vtAssertExpr(object.is_number());

// Somehow will this information
int home = 0;
int node = 0;
bool migratable = false;

// Create elm with encoded data
ElementIDStruct elm;
if (collection) {
elm =
get_elm_from_object_info_(object, bitpacked_id, migratable, home, node);
} else {
elm = ElementIDStruct{object, theContext()->getNode()};
}

return elm;
}


void LBDataHolder::outputEntity(nlohmann::json& j, ElementIDStruct const& id) const {
j["type"] = "object";
j["id"] = id.id;
Expand Down Expand Up @@ -278,8 +332,6 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {

LBDataHolder::LBDataHolder(nlohmann::json const& j)
{
auto this_node = theContext()->getNode();

// read metadata for skipped and identical phases
readMetadata(j);

Expand Down Expand Up @@ -314,34 +366,22 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
}
vtAssertExpr(object.is_number());

auto elm = ElementIDStruct{object, node};
ElementIDStruct elm;

if (
task["entity"].find("collection_id") != task["entity"].end() and
task["entity"].find("index") != task["entity"].end()
) {
using Field = uint64_t;
Field object_id;
if (bitpacked_id) {
object_id = BitPackerType::getField<
vt::elm::eElmIDProxyBitsNonObjGroup::ID,
vt::elm::elm_id_num_bits,
Field
>(static_cast<Field>(object));
} else {
object_id = static_cast<Field>(object);
}
elm = elm::ElmIDBits::createCollectionImpl(migratable,
object_id,
home,
node);
task["entity"].find("index") != task["entity"].end()) {
elm = get_elm_from_object_info_(
object, bitpacked_id, migratable, home, node);
auto cid = task["entity"]["collection_id"];
auto idx = task["entity"]["index"];
if (cid.is_number() && idx.is_array()) {
std::vector<uint64_t> arr = idx;
auto proxy = static_cast<VirtualProxyType>(cid);
this->node_idx_[elm] = std::make_tuple(proxy, arr);
}
} else {
elm = ElementIDStruct{object, node};
}

this->node_data_[id][elm].whole_phase_load = time;
Expand Down Expand Up @@ -409,13 +449,11 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
vtAssertExpr(comm["from"]["type"] == "object");
vtAssertExpr(comm["to"]["type"] == "object");

auto from_object = comm["from"]["id"];
vtAssertExpr(from_object.is_number());
auto from_elm = ElementIDStruct{from_object, this_node};

auto to_object = comm["to"]["id"];
vtAssertExpr(to_object.is_number());
auto to_elm = ElementIDStruct{to_object, this_node};
// TODO: passing false here (and below) avoids encoding
// any information into the obj ids, which preserves
// the original behavior
auto from_elm = get_elm_from_comm_object_(comm["from"], false);
auto to_elm = get_elm_from_comm_object_(comm["to"], false);

CommKey key(
CommKey::CollectionTag{},
Expand All @@ -432,9 +470,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
auto from_node = comm["from"]["id"];
vtAssertExpr(from_node.is_number());

auto to_object = comm["to"]["id"];
vtAssertExpr(to_object.is_number());
auto to_elm = ElementIDStruct{to_object, this_node};
auto to_elm = get_elm_from_comm_object_(comm["to"], false);

CommKey key(
CommKey::NodeToCollectionTag{},
Expand All @@ -449,9 +485,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
vtAssertExpr(comm["from"]["type"] == "object");
vtAssertExpr(comm["to"]["type"] == "node");

auto from_object = comm["from"]["id"];
vtAssertExpr(from_object.is_number());
auto from_elm = ElementIDStruct{from_object, this_node};
auto from_elm = get_elm_from_comm_object_(comm["from"], false);

auto to_node = comm["to"]["id"];
vtAssertExpr(to_node.is_number());
Expand Down
11 changes: 1 addition & 10 deletions tests/unit/lb/test_lb_data_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,6 @@ void test_data_holder_elms(int seq_id, int home, int node, bool migratable) {
}

TEST_F(TestLBDataHolder, test_lb_data_holder_no_comms_object_id) {
// Initialize
int argc = 0;
char** argv = nullptr;
MPI_Comm comm = MPI_COMM_WORLD;
vt::initialize(argc, argv, &comm);

// Run a variety of test cases (seq_id, home, node, migratable)
test_data_holder_elms(0,0,0,false);
test_data_holder_elms(0,0,0,true);
Expand All @@ -120,9 +114,6 @@ TEST_F(TestLBDataHolder, test_lb_data_holder_no_comms_object_id) {
test_data_holder_elms(1,1,0,false);
test_data_holder_elms(2,1,9,true);
test_data_holder_elms(3,0,1,false);

// Finalize
// vt::finalize();
}

}}}} // end namespace vt::tests::unit::lb
}}}} // end namespace vt::tests::unit::lb

0 comments on commit 6b52e07

Please sign in to comment.