Skip to content

Commit

Permalink
#2342: simplify logic by creating elms directly from encoded ids
Browse files Browse the repository at this point in the history
  • Loading branch information
cwschilly committed Sep 9, 2024
1 parent 4741dc8 commit 411ccdb
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 219 deletions.
67 changes: 25 additions & 42 deletions src/vt/vrt/collection/balance/lb_data_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@

namespace vt { namespace vrt { namespace collection { namespace balance {

void get_object_from_json_field_(
const nlohmann::json& field, nlohmann::json& object, bool& is_bitpacked,
void LBDataHolder::getObjectFromJsonField_(
nlohmann::json const& field, nlohmann::json& object, bool& is_bitpacked,
bool& is_collection) {
if (field.find("id") != field.end()) {
object = field["id"];
Expand All @@ -59,49 +59,31 @@ void get_object_from_json_field_(
object = field["seq_id"];
is_bitpacked = false;
}
vtAssertExpr(object.is_number());
if (field.find("collection_id") != field.end()) {
is_collection = true;
} else {
is_collection = false;
}
}

ElementIDStruct get_elm_from_object_info_(
const nlohmann::json& object, bool is_bitpacked, bool is_migratable,
const nlohmann::json& home) {
using Field = uint64_t;

Field object_id;
if (is_bitpacked) {
object_id = BitPackerType::getField<
vt::elm::eElmIDProxyBitsNonObjGroup::ID, vt::elm::elm_id_num_bits, Field>(
static_cast<Field>(object));
} else {
object_id = static_cast<Field>(object);
}

return elm::ElmIDBits::createCollectionImpl(
is_migratable, object_id, home, theContext()->getNode());
}

ElementIDStruct
get_elm_from_comm_object_(const nlohmann::json& field) {
LBDataHolder::getElmFromCommObject_(
nlohmann::json const& field) const {
// Get the object's id and determine if it is bit-encoded
nlohmann::json object;
bool is_bitpacked;
bool is_collection;
get_object_from_json_field_(field, object, is_bitpacked, is_collection);
vtAssertExpr(object.is_number());
bool is_bitpacked, is_collection;
getObjectFromJsonField_(field, object, is_bitpacked, is_collection);

// Create elm with encoded data
ElementIDStruct elm;
if (is_collection) {
if (is_collection and not is_bitpacked) {
int home = field["home"];
bool is_migratable = field["migratable"];
elm = get_elm_from_object_info_(
object, is_bitpacked, is_migratable, home);
elm = elm::ElmIDBits::createCollectionImpl(
is_migratable, static_cast<ElementIDType>(object), home, this_node_);
} else {
elm = ElementIDStruct{object, theContext()->getNode()};
elm = ElementIDStruct{object, this_node_};
}

return elm;
Expand Down Expand Up @@ -336,6 +318,8 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {

LBDataHolder::LBDataHolder(nlohmann::json const& j)
{
this_node_ = theContext()->getNode();

// read metadata for skipped and identical phases
readMetadata(j);

Expand All @@ -362,14 +346,16 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
if (etype == "object") {
nlohmann::json object;
bool is_bitpacked, is_collection;
get_object_from_json_field_(task["entity"], object, is_bitpacked, is_collection);
vtAssertExpr(object.is_number());
getObjectFromJsonField_(task["entity"], object, is_bitpacked, is_collection);

// Create elm
ElementIDStruct elm = is_collection and not is_bitpacked
? elm::ElmIDBits::createCollectionImpl(
is_migratable, static_cast<ElementIDType>(object), home, this_node_)
: ElementIDStruct{object, this_node_};
this->node_data_[id][elm].whole_phase_load = time;

// Creating elm from `tasks` field
ElementIDStruct elm;
if (is_collection) {
elm = get_elm_from_object_info_(
object, is_bitpacked, is_migratable, home);
auto cid = task["entity"]["collection_id"];
if (task["entity"].find("index") != task["entity"].end()) {
auto idx = task["entity"]["index"];
Expand All @@ -379,11 +365,8 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
this->node_idx_[elm] = std::make_tuple(proxy, arr);
}
}
} else {
elm = ElementIDStruct{object, node};
}

this->node_data_[id][elm].whole_phase_load = time;

if (task.find("subphases") != task.end()) {
auto subphases = task["subphases"];
Expand Down Expand Up @@ -448,8 +431,8 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
vtAssertExpr(comm["from"]["type"] == "object");
vtAssertExpr(comm["to"]["type"] == "object");

auto from_elm = get_elm_from_comm_object_(comm["from"]);
auto to_elm = get_elm_from_comm_object_(comm["to"]);
auto from_elm = getElmFromCommObject_(comm["from"]);
auto to_elm = getElmFromCommObject_(comm["to"]);

CommKey key(
CommKey::CollectionTag{},
Expand All @@ -466,7 +449,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
auto from_node = comm["from"]["id"];
vtAssertExpr(from_node.is_number());

auto to_elm = get_elm_from_comm_object_(comm["to"]);
auto to_elm = getElmFromCommObject_(comm["to"]);

CommKey key(
CommKey::NodeToCollectionTag{},
Expand All @@ -481,7 +464,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
vtAssertExpr(comm["from"]["type"] == "object");
vtAssertExpr(comm["to"]["type"] == "node");

auto from_elm = get_elm_from_comm_object_(comm["from"]);
auto from_elm = getElmFromCommObject_(comm["from"]);

auto to_node = comm["to"]["id"];
vtAssertExpr(to_node.is_number());
Expand Down
25 changes: 25 additions & 0 deletions src/vt/vrt/collection/balance/lb_data_holder.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,29 @@ struct LBDataHolder {

void addInitialTask(nlohmann::json& j, std::size_t n) const;

/**
* \brief Determine the object ID from the tasks or communication field of
* input JSON
*
* \param[in] field the json field containing an object ID
* \param[in] object empty json object to be populated with the object's ID
* \param[in] is_bitpacked empty bool to be populated with whether or not
* the ID is bit-encoded
* \param[in] is_collection empty bool to be populated with whether
* or not the object belongs to a collection
*/
static void getObjectFromJsonField_(
nlohmann::json const& field, nlohmann::json& object,
bool& is_bitpacked, bool& is_collection);

/**
* \brief Create an ElementIDStruct for the communication object
*
* \param[in] field the communication field for the desired object
* e.g. communications["to"] or communications["from"]
*/
ElementIDStruct getElmFromCommObject_(nlohmann::json const& field) const;

/**
* \brief Read the LB phase's metadata
*
Expand All @@ -135,6 +158,8 @@ struct LBDataHolder {
void readMetadata(nlohmann::json const& j);

public:
/// The current node
NodeType this_node_;
/// Node attributes for the current rank
ElmUserDataType rank_attributes_;
/// Node timings for each local object
Expand Down
112 changes: 112 additions & 0 deletions tests/unit/collection/test_lb_data_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
#include "test_helpers.h"
#include "test_collection_common.h"

#include "vt/elm/elm_id_bits.h"
#include "vt/vrt/collection/manager.h"
#include "vt/vrt/collection/balance/lb_data_holder.h"

Expand Down Expand Up @@ -90,6 +91,104 @@ void addPhasesDataToJson(nlohmann::json& json, PhaseType amountOfPhasesToAdd, st
json["phases"] = phases;
}

nlohmann::json createEntity_(std::string id_type, int id, int home, bool is_migratable) {
nlohmann::json entity = {
{id_type, id},
{"type", "object"},
{"collection_id", 7},
{"index", {0}},
{"home", home},
{"migratable", is_migratable},
};
return entity;
}

nlohmann::json createJson_(
std::string id_type, int id_1, int id_2, int home, int node,
bool is_migratable) {

auto entity_1 = createEntity_(id_type, id_1, home, is_migratable);
auto entity_2 = createEntity_(id_type, id_2, home, is_migratable);

// Generate JSON
nlohmann::json j = {
{"metadata", {{"rank", 0}, {"type", "LBDatafile"}}},
{"phases",
{{{"communications",
{{{"bytes", 2.0},
{"from", entity_1},
{"messages", 1},
{"to", entity_2},
{"type", "SendRecv"}}}},
{"id", 0},
{"tasks",
{
{
{"entity", entity_1},
{"node", node},
{"resource", "cpu"},
{"time", 0.5},
},
{
{"entity", entity_2},
{"node", node},
{"resource", "cpu"},
{"time", 0.5},
}
}
}}}
}
};

return j;
}

void testDataHolderElms(int seq_id_1, int home, int node, bool is_migratable) {
// Create a second seq_id
auto seq_id_2 = seq_id_1 + 1;

// Determine encoded ID
auto elm_1 =
elm::ElmIDBits::createCollectionImpl(is_migratable, seq_id_1, home, node);
auto encoded_id_1 = elm_1.id;

// Create second encoded ID
auto elm_2 =
elm::ElmIDBits::createCollectionImpl(is_migratable, seq_id_2, home, node);
auto encoded_id_2 = elm_2.id;

// Create DataHolder and get resulting object elm
auto simple_json_id = createJson_(
"id", encoded_id_1, encoded_id_2, home, node, is_migratable);
auto dh_id = vt::vrt::collection::balance::LBDataHolder(simple_json_id);

// Create new DataHolder using "seq_id" and get elm
auto simple_json_seq =
createJson_("seq_id", seq_id_1, seq_id_2, home, node, is_migratable);
auto dh_seq = vt::vrt::collection::balance::LBDataHolder(simple_json_seq);

// Assert that both elms exist in both DataHolders
ASSERT_NE(dh_id.node_data_[0].find(elm_1), dh_id.node_data_[0].end());
ASSERT_NE(dh_seq.node_data_[0].find(elm_1), dh_seq.node_data_[0].end());
ASSERT_NE(dh_id.node_data_[0].find(elm_2), dh_id.node_data_[0].end());
ASSERT_NE(dh_seq.node_data_[0].find(elm_2), dh_seq.node_data_[0].end());

// Check the communication data
auto comm_id = dh_id.node_comm_[0];
auto comm_key_id = comm_id.begin()->first;
auto comm_seq = dh_seq.node_comm_[0];
auto comm_key_seq = comm_seq.begin()->first;

// Ensure that we get the same CommKey from both id types
EXPECT_EQ(comm_key_id, comm_key_seq);

// Assert that both elms are present in the communication data
EXPECT_EQ(comm_key_id.fromObj(), elm_1);
EXPECT_EQ(comm_key_id.toObj(), elm_2);
EXPECT_EQ(comm_key_seq.fromObj(), elm_1);
EXPECT_EQ(comm_key_seq.toObj(), elm_2);
}

TEST_F(TestLBDataHolder, test_no_metadata) {
using LBDataHolder = vt::vrt::collection::balance::LBDataHolder;

Expand Down Expand Up @@ -302,6 +401,19 @@ TEST_F(TestLBDataHolder, test_default_time_format) {
}
}

TEST_F(TestLBDataHolder, test_lb_data_holder_object_id) {
// Run a variety of test cases (seq_id, home, node, is_migratable)
auto current_node = theContext()->getNode();
testDataHolderElms(0, 0, current_node, false);
testDataHolderElms(0, 0, current_node, true);
testDataHolderElms(1, 0, current_node, false);
testDataHolderElms(1, 0, current_node, true);
testDataHolderElms(0, 1, current_node, false);
testDataHolderElms(0, 1, current_node, true);
testDataHolderElms(3, 1, current_node, false);
testDataHolderElms(2, 2, current_node, true);
}

}}}} // end namespace vt::tests::unit::lb

#endif /*vt_check_enabled(lblite)*/
Loading

0 comments on commit 411ccdb

Please sign in to comment.