diff --git a/swmmanywhere/paper/experimenter.py b/swmmanywhere/paper/experimenter.py index aae3e03a..55947887 100644 --- a/swmmanywhere/paper/experimenter.py +++ b/swmmanywhere/paper/experimenter.py @@ -118,7 +118,12 @@ def process_parameters(jobid: int, """Generate and run parameter samples for the sensitivity analysis. This function generates parameter samples and runs the swmmanywhere model - for each sample. It is designed to be run in parallel as a jobarray. + for each sample. It is designed to be run in parallel as a jobarray. It + selects parameters values from the generated ones based on the jobid and + the number of processors. It copies the config file and passes these + parameters into swmmanywhere via the `parameter_overrides` property. Existing + overrides that are not being sampled are retained, existing overrides that + are being sampled are overwritten by the sampled value. Args: jobid (int): The job id. @@ -160,11 +165,16 @@ def process_parameters(jobid: int, "param", "value"]].itertuples(index=False, name=None): - if grp not in overrides: + + # Experimenter overrides take precedence over the config file + if grp in config.get('parameter_overrides',{}): + overrides[grp] = config['parameter_overrides'][grp] + elif grp not in overrides: overrides[grp] = {} - overrides[grp][param] = val - config['parameter_overrides'].update(overrides) + overrides[grp][param] = val + config['parameter_overrides'] = overrides + # Run the model config['model_number'] = ix logger.info(f"Running swmmanywhere for model {ix}") diff --git a/tests/test_experimenter.py b/tests/test_experimenter.py index 37c40bc6..afd5cfb3 100644 --- a/tests/test_experimenter.py +++ b/tests/test_experimenter.py @@ -1,6 +1,8 @@ """Tests for the main experimenter.""" from __future__ import annotations +from unittest import mock + import numpy as np from swmmanywhere import parameters @@ -41,4 +43,38 @@ def test_generate_samples(): seed = 1, groups = True) assert len(samples) == 36 - \ No newline at end of file + +def test_process_parameters(): + """Test process_parameters.""" + config = {'parameters_to_sample' : ['min_v','max_v'], + 'sample_magnitude' : 3, + } + + # Test standard + with mock.patch('swmmanywhere.paper.experimenter.swmmanywhere.swmmanywhere', + return_value=('fake_path',{'fake_metric' : 1})) as mock_sa: + result = experimenter.process_parameters(0,1,config) + + assert len(result[0]) == 48 + assert_close(result[0][0]['min_v'], 0.310930) + + # Test experimenter takes precedence over overrides + config['parameter_overrides'] = {'hydraulic_design': {'min_v': 1.0}} + with mock.patch('swmmanywhere.paper.experimenter.swmmanywhere.swmmanywhere', + return_value=('fake_path',{'fake_metric' : 1})) as mock_sa: + result = experimenter.process_parameters(0,1,config) + + assert len(result[0]) == 48 + assert_close(result[0][0]['min_v'], 0.310930) + + # Test non experimenter overrides still work + config['parameter_overrides'] = {'hydraulic_design': {'max_fr': 0.5}} + with mock.patch('swmmanywhere.paper.experimenter.swmmanywhere.swmmanywhere', + return_value=('fake_path',{'fake_metric' : 1})) as mock_sa: + result = experimenter.process_parameters(0,1,config) + + for call in mock_sa.mock_calls: + assert call.args[0]['parameter_overrides']['hydraulic_design']['max_fr'] == 0.5 + + assert len(result[0]) == 48 + assert_close(result[0][0]['min_v'], 0.310930) \ No newline at end of file