Skip to content

Commit

Permalink
Updated test_simulation.py to properly account for averages with R_t
Browse files Browse the repository at this point in the history
  • Loading branch information
mghosh00 committed Mar 7, 2024
1 parent c6a9c58 commit 49cf5e7
Showing 1 changed file with 38 additions and 14 deletions.
52 changes: 38 additions & 14 deletions pyEpiabm/pyEpiabm/tests/test_unit/test_routine/test_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def setUpClass(cls) -> None:
cls.pop_params = {"population_size": 1, "cell_number": 1,
"microcell_number": 1, "household_number": 1}
cls.test_population = cls.pop_factory.make_pop(cls.pop_params)
cls.rt_pop_params = {"population_size": 3, "cell_number": 1,
"microcell_number": 1, "household_number": 1}
cls.rt_test_population = cls.pop_factory.make_pop(cls.rt_pop_params)
pe.Parameters.instance().time_steps_per_day = 1
cls.sim_params = {"simulation_start_time": 0,
"simulation_end_time": 1,
Expand Down Expand Up @@ -563,31 +566,52 @@ def test_write_to_Rt_file(self, mock_mkdir, time=1):
self.inf_history_params['secondary_infections_output'] = True
with patch('pyEpiabm.output._csv_dict_writer.open', mo):
test_sim = pe.routine.Simulation()
test_sim.configure(self.test_population, self.initial_sweeps,
test_sim.configure(self.rt_test_population, self.initial_sweeps,
self.sweeps, self.sim_params, self.file_params,
self.inf_history_params)
person = self.test_population.cells[0].persons[0]
person.num_times_infected = 1
person.infection_start_times = [1.0]
person.secondary_infections_counts = [5]
dict_1 = {"time": 0, "0.0.0.0": np.nan, "R_t": np.nan}
dict_2 = {"time": 1, "0.0.0.0": 5.0, "R_t": 5.0}
person1 = self.rt_test_population.cells[0].persons[0]
person1.num_times_infected = 1
person1.infection_start_times = [1.0]
person1.secondary_infections_counts = [5]
person2 = self.rt_test_population.cells[0].persons[1]
person2.num_times_infected = 1
person2.infection_start_times = [0.0]
person2.secondary_infections_counts = [7]
person3 = self.rt_test_population.cells[0].persons[2]
person3.num_times_infected = 1
person3.infection_start_times = [1.0]
person3.secondary_infections_counts = [8]
dict_1 = {"time": 0, "0.0.0.0": np.nan, "0.0.0.1": 7.0,
"0.0.0.2": np.nan, "R_t": 7.0}
dict_2 = {"time": 1, "0.0.0.0": 5.0, "0.0.0.1": np.nan,
"0.0.0.2": 8.0, "R_t": 6.5}
dict_3 = {"time": 2, "0.0.0.0": np.nan, "0.0.0.1": np.nan,
"0.0.0.2": np.nan, "R_t": np.nan}

with patch('pyEpiabm.output._csv_dict_writer'
'._CsvDictWriter.write') as mock_write:
test_sim.write_to_Rt_file(np.array([1]))
test_sim.write_to_Rt_file(np.array([1, 2]))
calls = mock_write.call_args_list
# Need to use np.testing for the NaNs
# Need to test keys and values separately in case we are using
# python 3.7 (for which np.testing.assert_equal will not work)
if sys.version_info[1] > 7:
actual_dict = calls[0].args[0]
if sys.version_info[0] >= 3 or sys.version_info[1] >= 8:
actual_dict_1 = calls[0].args[0]
for key in dict_1:
self.assertTrue(key in actual_dict.keys())
self.assertTrue(key in actual_dict_1)
np.testing.assert_array_equal(dict_1[key],
actual_dict[key])
self.assertEqual(calls[1], call(dict_2))
self.assertEqual(mock_write.call_count, 2)
actual_dict_1[key])
actual_dict_2 = calls[1].args[0]
for key in dict_2:
self.assertTrue(key in actual_dict_2)
np.testing.assert_array_equal(dict_2[key],
actual_dict_2[key])
actual_dict_3 = calls[2].args[0]
for key in dict_3:
self.assertTrue(key in actual_dict_3)
np.testing.assert_array_equal(dict_3[key],
actual_dict_3[key])
self.assertEqual(mock_write.call_count, 3)
mock_mkdir.assert_called_with(
os.path.join(os.getcwd(), self.inf_history_params["output_dir"]))

Expand Down

0 comments on commit 49cf5e7

Please sign in to comment.