Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#2342: Change schema to allow for id or seq_id #2343

Merged
merged 10 commits into from
Sep 9, 2024
7 changes: 4 additions & 3 deletions scripts/JSON_data_files_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,11 +435,12 @@ def validate_comm_links(all_jsons):
for data in all_jsons:
if data["phases"][n].get("communications") is not None:
comms = data["phases"][n]["communications"]
comm_ids.update({int(comm["from"]["id"]) for comm in comms})
comm_ids.update({int(comm["to"]["id"]) for comm in comms})
id_key = "id" if "id" in comms[0]["from"] else "seq_id"
comm_ids.update({int(comm["from"][id_key]) for comm in comms})
comm_ids.update({int(comm["to"][id_key]) for comm in comms})

tasks = data["phases"][n]["tasks"]
task_ids.update({int(task["entity"]["id"]) for task in tasks})
task_ids.update({int(task["entity"][id_key]) for task in tasks})

if not comm_ids.issubset(task_ids):
logging.error(
Expand Down
27 changes: 18 additions & 9 deletions scripts/LBDatafile_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from schema import And, Optional, Schema

def validate_id_and_seq_id(field):
"""Ensure that either seq_id or id is provided."""
if 'seq_id' not in field and 'id' not in field:
raise ValueError('Either id (bit-encoded) or seq_id must be provided.')
return field

LBDatafile_schema = Schema(
{
Optional('type'): And(str, "LBDatafile", error="'LBDatafile' must be chosen."),
Expand Down Expand Up @@ -30,15 +36,16 @@
'id': int,
'tasks': [
{
'entity': {
'entity': And({
Optional('collection_id'): int,
'home': int,
'id': int,
Optional('id'): int,
Optional('seq_id'): int,
Optional('index'): [int],
'type': str,
'migratable': bool,
Optional('objgroup_id'): int
},
}, validate_id_and_seq_id),
'node': int,
'resource': str,
Optional('subphases'): [
Expand All @@ -55,25 +62,27 @@
Optional('communications'): [
{
'type': str,
'to': {
'to': And({
'type': str,
'id': int,
Optional('id'): int,
Optional('seq_id'): int,
Optional('home'): int,
Optional('collection_id'): int,
Optional('migratable'): bool,
Optional('index'): [int],
Optional('objgroup_id'): int,
},
}, validate_id_and_seq_id),
'messages': int,
'from': {
'from': And({
'type': str,
'id': int,
Optional('id'): int,
Optional('seq_id'): int,
Optional('home'): int,
Optional('collection_id'): int,
Optional('migratable'): bool,
Optional('index'): [int],
Optional('objgroup_id'): int,
},
}, validate_id_and_seq_id),
'bytes': float
}
],
Expand Down
120 changes: 81 additions & 39 deletions src/vt/vrt/collection/balance/lb_data_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,72 @@
//@HEADER
*/

#include "vt/vrt/collection/balance/lb_data_holder.h"
#include "vt/context/context.h"
#include "vt/elm/elm_id_bits.h"
#include "vt/vrt/collection/balance/lb_data_holder.h"

#include <nlohmann/json.hpp>

namespace vt { namespace vrt { namespace collection { namespace balance {

void get_object_from_json_field_(
const nlohmann::json& field, nlohmann::json& object, bool& is_bitpacked,
bool& is_collection) {
if (field.find("id") != field.end()) {
object = field["id"];
is_bitpacked = true;
} else {
object = field["seq_id"];
is_bitpacked = false;
}
if (field.find("collection_id") != field.end()) {
is_collection = true;
} else {
is_collection = false;
}
}
nlslatt marked this conversation as resolved.
Show resolved Hide resolved

ElementIDStruct get_elm_from_object_info_(
const nlohmann::json& object, bool is_bitpacked, bool is_migratable,
const nlohmann::json& home) {
using Field = uint64_t;
nlslatt marked this conversation as resolved.
Show resolved Hide resolved

Field object_id;
if (is_bitpacked) {
object_id = BitPackerType::getField<
vt::elm::eElmIDProxyBitsNonObjGroup::ID, vt::elm::elm_id_num_bits, Field>(
static_cast<Field>(object));
} else {
object_id = static_cast<Field>(object);
}

return elm::ElmIDBits::createCollectionImpl(
is_migratable, object_id, home, theContext()->getNode());
nlslatt marked this conversation as resolved.
Show resolved Hide resolved
}

ElementIDStruct
get_elm_from_comm_object_(const nlohmann::json& field) {
// Get the object's id and determine if it is bit-encoded
nlohmann::json object;
bool is_bitpacked;
bool is_collection;
get_object_from_json_field_(field, object, is_bitpacked, is_collection);
vtAssertExpr(object.is_number());

// Create elm with encoded data
ElementIDStruct elm;
if (is_collection) {
int home = field["home"];
bool is_migratable = field["migratable"];
elm = get_elm_from_object_info_(
object, is_bitpacked, is_migratable, home);
} else {
elm = ElementIDStruct{object, theContext()->getNode()};
}

return elm;
}

void LBDataHolder::outputEntity(nlohmann::json& j, ElementIDStruct const& id) const {
j["type"] = "object";
j["id"] = id.id;
Expand Down Expand Up @@ -278,8 +336,6 @@ std::unique_ptr<nlohmann::json> LBDataHolder::toJson(PhaseType phase) const {

LBDataHolder::LBDataHolder(nlohmann::json const& j)
{
auto this_node = theContext()->getNode();

// read metadata for skipped and identical phases
readMetadata(j);

Expand All @@ -298,38 +354,33 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
auto time = task["time"];
auto etype = task["entity"]["type"];
auto home = task["entity"]["home"];
bool migratable = task["entity"]["migratable"];
bool is_migratable = task["entity"]["migratable"];

vtAssertExpr(time.is_number());
vtAssertExpr(node.is_number());

if (etype == "object") {
auto object = task["entity"]["id"];
nlohmann::json object;
bool is_bitpacked, is_collection;
get_object_from_json_field_(task["entity"], object, is_bitpacked, is_collection);
vtAssertExpr(object.is_number());

auto elm = ElementIDStruct{object, node};

if (
task["entity"].find("collection_id") != task["entity"].end() and
task["entity"].find("index") != task["entity"].end()
) {
using Field = uint64_t;
auto strippedObject = BitPackerType::getField<
vt::elm::eElmIDProxyBitsNonObjGroup::ID,
vt::elm::elm_id_num_bits,
Field
>(static_cast<Field>(object));
elm = elm::ElmIDBits::createCollectionImpl(migratable,
strippedObject,
home,
node);
// Creating elm from `tasks` field
ElementIDStruct elm;
if (is_collection) {
elm = get_elm_from_object_info_(
object, is_bitpacked, is_migratable, home);
auto cid = task["entity"]["collection_id"];
auto idx = task["entity"]["index"];
if (cid.is_number() && idx.is_array()) {
std::vector<uint64_t> arr = idx;
auto proxy = static_cast<VirtualProxyType>(cid);
this->node_idx_[elm] = std::make_tuple(proxy, arr);
if (task["entity"].find("index") != task["entity"].end()) {
auto idx = task["entity"]["index"];
if (cid.is_number() && idx.is_array()) {
std::vector<uint64_t> arr = idx;
auto proxy = static_cast<VirtualProxyType>(cid);
this->node_idx_[elm] = std::make_tuple(proxy, arr);
}
}
} else {
elm = ElementIDStruct{object, node};
}

this->node_data_[id][elm].whole_phase_load = time;
Expand Down Expand Up @@ -397,13 +448,8 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
vtAssertExpr(comm["from"]["type"] == "object");
vtAssertExpr(comm["to"]["type"] == "object");

auto from_object = comm["from"]["id"];
vtAssertExpr(from_object.is_number());
auto from_elm = ElementIDStruct{from_object, this_node};

auto to_object = comm["to"]["id"];
vtAssertExpr(to_object.is_number());
auto to_elm = ElementIDStruct{to_object, this_node};
auto from_elm = get_elm_from_comm_object_(comm["from"]);
auto to_elm = get_elm_from_comm_object_(comm["to"]);

CommKey key(
CommKey::CollectionTag{},
Expand All @@ -420,9 +466,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
auto from_node = comm["from"]["id"];
vtAssertExpr(from_node.is_number());

auto to_object = comm["to"]["id"];
vtAssertExpr(to_object.is_number());
auto to_elm = ElementIDStruct{to_object, this_node};
auto to_elm = get_elm_from_comm_object_(comm["to"]);

CommKey key(
CommKey::NodeToCollectionTag{},
Expand All @@ -437,9 +481,7 @@ LBDataHolder::LBDataHolder(nlohmann::json const& j)
vtAssertExpr(comm["from"]["type"] == "object");
vtAssertExpr(comm["to"]["type"] == "node");

auto from_object = comm["from"]["id"];
vtAssertExpr(from_object.is_number());
auto from_elm = ElementIDStruct{from_object, this_node};
auto from_elm = get_elm_from_comm_object_(comm["from"]);

auto to_node = comm["to"]["id"];
vtAssertExpr(to_node.is_number());
Expand Down
Loading
Loading