Commit

#482: add unit test for communications data
cwschilly committed Dec 19, 2023
1 parent 2a9ac0f commit 8599231
Showing 4 changed files with 119 additions and 3 deletions.
5 changes: 4 additions & 1 deletion src/lbaf/IO/lbsVTDataReader.py
@@ -186,7 +186,10 @@ def _populate_rank(self, phase_id: int, rank_id: int) -> tuple:
         rank_comm = {}
         communications = phase.get("communications") # pylint:disable=W0631:undefined-loop-variable
         if communications:
-            self.__communications_dict[phase_id] = {rank_id: communications}
+            if phase_id in self.__communications_dict:
+                self.__communications_dict[phase_id][rank_id] = communications
+            else:
+                self.__communications_dict[phase_id] = {rank_id: communications}
             for num, comm in enumerate(communications):
                 # Retrieve communication attributes
                 c_type = comm.get("type")
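The removed line above rebuilt the whole per-phase entry on every call, so when several ranks contributed communications to the same phase only the last rank survived; the added branch merges into the existing mapping instead. A minimal sketch of the difference, with hypothetical phase and rank values (dict.setdefault would be an equivalent idiom):

communications_dict = {}

def add_rank_comms(phase_id, rank_id, comms):
    # Merge into the existing per-phase mapping instead of replacing it
    if phase_id in communications_dict:
        communications_dict[phase_id][rank_id] = comms
    else:
        communications_dict[phase_id] = {rank_id: comms}
    # An equivalent one-liner:
    # communications_dict.setdefault(phase_id, {})[rank_id] = comms

add_rank_comms(0, 0, ["c0"])
add_rank_comms(0, 1, ["c1"])
print(communications_dict)  # {0: {0: ['c0'], 1: ['c1']}} -- both ranks kept
# The removed assignment would have left only {0: {1: ['c1']}} for phase 0.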
2 changes: 0 additions & 2 deletions src/lbaf/IO/lbsVTDataWriter.py
@@ -163,13 +163,11 @@ def _json_serializer(self, rank_phases_double) -> str:
         # Get metadata
         if current_phase.get_metadata()[r_id]:
             metadata = current_phase.get_metadata()[r_id]
-            print(f"phase.get_metadata(): {metadata}")
         else:
             metadata = {
                 "type": "LBDatafile",
                 "rank": r_id
             }
-            print(metadata)

         # Initialize output dict
         output = {
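For context, the surrounding writer code keeps a rank's recorded metadata when the phase has it and otherwise synthesizes the minimal LBDatafile header shown above. A compact sketch of that fallback, where rank_metadata and its arguments are illustrative rather than part of the writer's API:

def rank_metadata(recorded, r_id):
    # Use the recorded per-rank metadata when present, else the minimal header
    return recorded.get(r_id) or {"type": "LBDatafile", "rank": r_id}

recorded = {0: {"type": "LBDatafile", "rank": 0}}
print(rank_metadata(recorded, 0))  # uses the recorded entry
print(rank_metadata(recorded, 3))  # falls back to {'type': 'LBDatafile', 'rank': 3}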
78 changes: 78 additions & 0 deletions tests/unit/IO/test_vt_writer.py
@@ -185,6 +185,84 @@ def test_vt_writer_required_fields_output(self):
        # the following is an assert alternative to compare json encoded data instead of dictionaries
        # self.assertEqual(json.dumps(input_data, indent=4), json.dumps(output_data, indent=4))

    def test_vt_writer_communications_output(self):
        """Tests that LBAF writes out the correct communications data."""

        # run LBAF
        config_file = os.path.join(os.path.dirname(__file__), "config", "conf_vt_writer_communications_test.yaml")
        proc = subprocess.run(["python", "src/lbaf", "-c", config_file], check=True)
        self.assertEqual(0, proc.returncode)

        # LBAF config useful information
        with open(config_file, "rt", encoding="utf-8") as file_io:
            config = yaml.safe_load(file_io)
            data_stem = config.get("from_data").get("data_stem")

        # input information
        input_dir = abspath(f"{os.sep}".join(data_stem.split(os.sep)[:-1]), os.path.dirname(config_file))
        input_file_prefix = data_stem.split(os.sep)[-1]
        n_ranks = len([name for name in os.listdir(input_dir)])

        # output information
        output_dir = abspath(config.get("output_dir", '.'), os.path.dirname(config_file))
        output_file_prefix = config.get("output_file_stem")

        # count total communications
        input_communication_count = 0
        output_communication_count = 0

        # compare input/output files (at each rank)
        for i in range(0, n_ranks):
            input_file_name = f"{input_file_prefix}.{i}.json"
            input_file = os.path.join(input_dir, input_file_name)
            output_file_name = f"{output_file_prefix}.{i}.json"
            output_file = os.path.join(output_dir, output_file_name)

            print(f"[{__loader__.name}] Compare input file ({input_file_name}) and output file ({output_file_name})...")

            # validate that output file exists at rank i
            self.assertTrue(
                os.path.isfile(output_file),
                f"File {output_file} not generated at {output_dir}"
            )

            # read input and output files
            input_data = self.__read_data_file(input_file)
            output_data = self.__read_data_file(output_file)

            # validate output against the JSON schema validator
            schema_validator = SchemaValidator(schema_type="LBDatafile")
            self.assertTrue(
                schema_validator.validate(output_data),
                f"Schema not valid for generated file at {output_file_name}"
            )

            # get the first phase dict (this config only has one phase)
            input_phase_dict = input_data["phases"][0]
            output_phase_dict = output_data["phases"][0]

            # increment the input communication counter
            if "communications" in input_phase_dict:
                input_communication_data = input_phase_dict["communications"]
                input_communication_count += len(input_communication_data)

            # increment the output communication counter
            if "communications" in output_phase_dict:
                output_communication_data = output_phase_dict["communications"]
                output_communication_count += len(output_communication_data)

                # get list of all objects on this rank
                rank_objs = []
                tasks = output_phase_dict["tasks"]
                for task in tasks:
                    rank_objs.append(task["entity"].get("id"))

                # Make sure all communicating objects belong on this rank
                for comm_dict in output_communication_data:
                    comm_obj = comm_dict["from"]["id"]
                    self.assertIn(comm_obj, rank_objs, f"Object {comm_obj} is not on rank {i}")

        self.assertEqual(input_communication_count, output_communication_count)

if __name__ == "__main__":
    unittest.main()
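The per-rank checks above reduce to two things: tally the communications in each file and verify that every sending object is hosted on the rank that reports it. A standalone sketch of that check (the file path and helper function are illustrative; the test itself uses its own reader and the project's SchemaValidator):

import json

def check_rank_communications(file_path):
    # Load one uncompressed per-rank LBDatafile JSON (written with compressed: false)
    with open(file_path, "rt", encoding="utf-8") as f:
        data = json.load(f)

    phase = data["phases"][0]  # this configuration produces a single phase
    comms = phase.get("communications", [])
    rank_objs = {task["entity"].get("id") for task in phase["tasks"]}

    # Senders not hosted on this rank would indicate a writer bug
    foreign_senders = [c["from"]["id"] for c in comms if c["from"]["id"] not in rank_objs]
    return len(comms), foreign_senders

# Hypothetical usage against one generated file:
# n_comms, foreign = check_rank_communications("output/vt_writer_communications_test/output_file.0.json")

The test itself runs with any standard unittest or pytest invocation from the repository root.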
37 changes: 37 additions & 0 deletions tests/unit/config/conf_vt_writer_communications_test.yaml
@@ -0,0 +1,37 @@
# Specify input
from_data:
  data_stem: ../../../data/synthetic_lb_data/data
  phase_ids:
    - 0
check_schema: false

# Specify work model
work_model:
  name: AffineCombination
  parameters:
    alpha: 0.0
    beta: 1.0
    gamma: 0.0

# Specify algorithm
algorithm:
  name: InformAndTransfer
  phase_id: 0
  parameters:
    n_iterations: 8
    n_rounds: 2
    fanout: 2
    order_strategy: arbitrary
    transfer_strategy: Recursive
    criterion: Tempered
    max_objects_per_transfer: 8
    deterministic_transfer: true

# Specify output
output_dir: ../output/vt_writer_communications_test
output_file_stem: output_file
write_JSON:
  compressed: false
  suffix: json
  communications: true
  offline_LB_compatible: false

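The relative paths in this configuration (data_stem and output_dir) are resolved by the test against the directory containing the YAML file. A small sketch of that resolution using PyYAML and the standard library (the test relies on a project-level abspath helper instead):

import os
import yaml

config_file = "tests/unit/config/conf_vt_writer_communications_test.yaml"
with open(config_file, "rt", encoding="utf-8") as f:
    config = yaml.safe_load(f)

config_dir = os.path.dirname(os.path.abspath(config_file))
data_stem = config["from_data"]["data_stem"]

# Directory holding the per-rank input files and their common prefix
input_dir = os.path.normpath(os.path.join(config_dir, os.path.dirname(data_stem)))
input_file_prefix = os.path.basename(data_stem)

# Directory where LBAF writes output_file.<rank>.json
output_dir = os.path.normpath(os.path.join(config_dir, config.get("output_dir", ".")))

print(input_dir, input_file_prefix, output_dir)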