diff --git a/nestcheck/data_processing.py b/nestcheck/data_processing.py index 35a102b..ef33a48 100644 --- a/nestcheck/data_processing.py +++ b/nestcheck/data_processing.py @@ -68,8 +68,8 @@ >= v3.11, and is used in the output processing for these packages via the ``birth_inds_given_contours`` and ``threads_given_birth_inds`` functions. Also sufficient is a list of the indexes of the point which was removed -at the step when each point was sampled ("birth indexes"), as this can be mapped -to the birth contours and vice versa. +at the step when each point was sampled ("birth indexes"), as this can be +mapped to the birth contours and vice versa. ``process_dynesty_run`` does not require the ``birth_inds_given_contours`` and ``threads_given_birth_inds`` functions as ``dynesty`` results objects @@ -95,6 +95,7 @@ import copy import numpy as np import nestcheck.io_utils +import nestcheck.ns_run_utils import nestcheck.parallel_utils @@ -103,8 +104,9 @@ def batch_process_data(file_roots, **kwargs): """Process output from many nested sampling runs in parallel with optional error handling and caching. - The result can be cached using the 'save_name', 'save' and 'load' kwargs (by - default this is not done). See save_load_result docstring for more details. + The result can be cached using the 'save_name', 'save' and 'load' kwargs + (by default this is not done). See save_load_result docstring for more + details. Remaining kwargs passed to parallel_utils.parallel_apply (see its docstring for more details). @@ -228,7 +230,7 @@ def process_polychord_run(file_root, base_dir, process_stats_file=True, don't have the .stats file (such as if PolyChord was run with write_stats=False). kwargs: dict, optional - Options passed to check_ns_run + Options passed to ns_run_utils.check_ns_run. Returns ------- @@ -269,7 +271,7 @@ def process_multinest_run(file_root, base_dir, **kwargs): Directory containing output files. When running MultiNest, this is determined by the nest_root parameter. 
kwargs: dict, optional - Passed to check_ns_run (via process_samples_array) + Passed to ns_run_utils.check_ns_run (via process_samples_array) Returns @@ -351,7 +353,7 @@ def process_dynesty_run(results): samples[final_ind, 2] = -1 assert np.all(~np.isnan(thread_min_max)) run = nestcheck.ns_run_utils.dict_given_run_array(samples, thread_min_max) - nestcheck.data_processing.check_ns_run(run) + nestcheck.ns_run_utils.check_ns_run(run) return run @@ -441,7 +443,8 @@ def process_samples_array(samples, **kwargs): birth_contours = samples[:, -1] # birth_contours, ns_run['theta'] = check_logls_unique( # samples[:, -2], samples[:, -1], samples[:, :-2]) - birth_inds = birth_inds_given_contours(birth_contours, ns_run['logl'], **kwargs) + birth_inds = birth_inds_given_contours( + birth_contours, ns_run['logl'], **kwargs) ns_run['thread_labels'] = threads_given_birth_inds(birth_inds) unique_threads = np.unique(ns_run['thread_labels']) assert np.array_equal(unique_threads, @@ -499,9 +502,9 @@ def birth_inds_given_contours(birth_logl_arr, logl_arr, **kwargs): Birth contours - i.e. logl values of the iso-likelihood contour from within each point was sampled (on which it was born). dup_assert: bool, optional - See check_ns_run_logls docstring. + See ns_run_utils.check_ns_run_logls docstring. dup_warn: bool, optional - See check_ns_run_logls docstring. + See ns_run_utils.check_ns_run_logls docstring. 
Returns ------- @@ -516,8 +519,8 @@ def birth_inds_given_contours(birth_logl_arr, logl_arr, **kwargs): assert logl_arr.ndim == 1, logl_arr.ndim assert birth_logl_arr.ndim == 1, birth_logl_arr.ndim # Check for duplicate logl values (if specified by dup_assert or dup_warn) - check_ns_run_logls({'logl': logl_arr}, dup_assert=dup_assert, - dup_warn=dup_warn) + nestcheck.ns_run_utils.check_ns_run_logls( + {'logl': logl_arr}, dup_assert=dup_assert, dup_warn=dup_warn) # Random seed so results are consistent if there are duplicate logls state = np.random.get_state() # Save random state before seeding np.random.seed(0) @@ -692,142 +695,3 @@ def threads_given_birth_inds(birth_inds): str(np.unique(thread_labels)) + ' is not equal to range(' + str(sum(thread_start_counts)) + ')') return thread_labels - - -# Functions for checking nestcheck format nested sampling run dictionaries to -# ensure they have the expected properties. - - -def check_ns_run(run, dup_assert=False, dup_warn=False): - """Checks a nestcheck format nested sampling run dictionary has the - expected properties (see the module docstring for more details). - - Parameters - ---------- - run: dict - nested sampling run to check. - dup_assert: bool, optional - See check_ns_run_logls docstring. - dup_warn: bool, optional - See check_ns_run_logls docstring. - - - Raises - ------ - AssertionError - if run does not have expected properties. - """ - assert isinstance(run, dict) - check_ns_run_members(run) - check_ns_run_logls(run, dup_assert=dup_assert, dup_warn=dup_warn) - check_ns_run_threads(run) - - -def check_ns_run_members(run): - """Check nested sampling run member keys and values. - - Parameters - ---------- - run: dict - nested sampling run to check. - - Raises - ------ - AssertionError - if run does not have expected properties. 
- """ - run_keys = list(run.keys()) - # Mandatory keys - for key in ['logl', 'nlive_array', 'theta', 'thread_labels', - 'thread_min_max']: - assert key in run_keys - run_keys.remove(key) - # Optional keys - for key in ['output']: - try: - run_keys.remove(key) - except ValueError: - pass - # Check for unexpected keys - assert not run_keys, 'Unexpected keys in ns_run: ' + str(run_keys) - # Check type of mandatory members - for key in ['logl', 'nlive_array', 'theta', 'thread_labels', - 'thread_min_max']: - assert isinstance(run[key], np.ndarray), ( - key + ' is type ' + type(run[key]).__name__) - # check shapes of keys - assert run['logl'].ndim == 1 - assert run['logl'].shape == run['nlive_array'].shape - assert run['logl'].shape == run['thread_labels'].shape - assert run['theta'].ndim == 2 - assert run['logl'].shape[0] == run['theta'].shape[0] - - -def check_ns_run_logls(run, dup_assert=False, dup_warn=False): - """Check run logls are unique and in the correct order. - - Parameters - ---------- - run: dict - nested sampling run to check. - dup_assert: bool, optional - Whether to raise and AssertionError if there are duplicate logl values. - dup_warn: bool, optional - Whether to give a UserWarning if there are duplicate logl values (only - used if dup_assert is False). - - Raises - ------ - AssertionError - if run does not have expected properties. - """ - assert np.array_equal(run['logl'], run['logl'][np.argsort(run['logl'])]) - if dup_assert or dup_warn: - unique_logls, counts = np.unique(run['logl'], return_counts=True) - repeat_logls = run['logl'].shape[0] - unique_logls.shape[0] - msg = ('{} duplicate logl values (out of a total of {}). This may be ' - 'caused by limited numerical precision in the output files.' 
- '\nrepeated logls = {}\ncounts = {}\npositions in list of {}' - ' unique logls = {}').format( - repeat_logls, run['logl'].shape[0], - unique_logls[counts != 1], counts[counts != 1], - unique_logls.shape[0], np.where(counts != 1)[0]) - if dup_assert: - assert repeat_logls == 0, msg - elif dup_warn: - if repeat_logls != 0: - warnings.warn(msg, UserWarning) - - -def check_ns_run_threads(run): - """Check thread labels and thread_min_max have expected properties. - - Parameters - ---------- - run: dict - Nested sampling run to check. - - Raises - ------ - AssertionError - If run does not have expected properties. - """ - assert run['thread_labels'].dtype == int - uniq_th = np.unique(run['thread_labels']) - assert np.array_equal( - np.asarray(range(run['thread_min_max'].shape[0])), uniq_th), \ - str(uniq_th) - # Check thread_min_max - assert np.any(run['thread_min_max'][:, 0] == -np.inf), ( - 'Run should have at least one thread which starts by sampling the ' + - 'whole prior') - for th_lab in uniq_th: - inds = np.where(run['thread_labels'] == th_lab)[0] - th_info = 'thread label={}, first_logl={}, thread_min_max={}'.format( - th_lab, run['logl'][inds[0]], run['thread_min_max'][th_lab, :]) - assert run['thread_min_max'][th_lab, 0] <= run['logl'][inds[0]], ( - 'First point in thread has logl less than thread min logl! ' + - th_info + ', difference={}'.format( - run['logl'][inds[0]] - run['thread_min_max'][th_lab, 0])) - assert run['thread_min_max'][th_lab, 1] == run['logl'][inds[-1]], ( - 'Last point in thread logl != thread end logl! 
' + th_info) diff --git a/nestcheck/ns_run_utils.py b/nestcheck/ns_run_utils.py index 92fa559..feb7f31 100644 --- a/nestcheck/ns_run_utils.py +++ b/nestcheck/ns_run_utils.py @@ -10,7 +10,6 @@ import warnings import numpy as np import scipy.special -import nestcheck.data_processing as dp def run_estimators(ns_run, estimator_list, simulate=False): @@ -195,7 +194,7 @@ def combine_ns_runs(run_list_in, **kwargs): else: nthread_tot = 0 for i, _ in enumerate(run_list): - dp.check_ns_run(run_list[i], **kwargs) + check_ns_run(run_list[i], **kwargs) run_list[i]['thread_labels'] += nthread_tot nthread_tot += run_list[i]['thread_min_max'].shape[0] thread_min_max = np.vstack([run['thread_min_max'] for run in run_list]) @@ -212,7 +211,7 @@ def combine_ns_runs(run_list_in, **kwargs): run_list_in]) except KeyError: pass - dp.check_ns_run(run, **kwargs) + check_ns_run(run, **kwargs) return run @@ -278,7 +277,7 @@ def combine_threads(threads, assert_birth_point=False): # make run ns_run = dict_given_run_array(samples_temp, thread_min_max) try: - dp.check_ns_run_threads(ns_run) + check_ns_run_threads(ns_run) except AssertionError: # If the threads are not valid (e.g. for bootstrap resamples) then # set them to None so they can't be accidentally used @@ -416,3 +415,143 @@ def log_subtract(loga, logb): log(a - b): float """ return loga + np.log(1 - np.exp(logb - loga)) + + +# Functions for checking nestcheck format nested sampling run dictionaries to +# ensure they have the expected properties. + + +def check_ns_run(run, dup_assert=False, dup_warn=False): + """Checks a nestcheck format nested sampling run dictionary has the + expected properties (see the data_processing module docstring for more + details). + + Parameters + ---------- + run: dict + nested sampling run to check. + dup_assert: bool, optional + See check_ns_run_logls docstring. + dup_warn: bool, optional + See check_ns_run_logls docstring. + + + Raises + ------ + AssertionError + if run does not have expected properties. 
+ """ + assert isinstance(run, dict) + check_ns_run_members(run) + check_ns_run_logls(run, dup_assert=dup_assert, dup_warn=dup_warn) + check_ns_run_threads(run) + + +def check_ns_run_members(run): + """Check nested sampling run member keys and values. + + Parameters + ---------- + run: dict + nested sampling run to check. + + Raises + ------ + AssertionError + if run does not have expected properties. + """ + run_keys = list(run.keys()) + # Mandatory keys + for key in ['logl', 'nlive_array', 'theta', 'thread_labels', + 'thread_min_max']: + assert key in run_keys + run_keys.remove(key) + # Optional keys + for key in ['output']: + try: + run_keys.remove(key) + except ValueError: + pass + # Check for unexpected keys + assert not run_keys, 'Unexpected keys in ns_run: ' + str(run_keys) + # Check type of mandatory members + for key in ['logl', 'nlive_array', 'theta', 'thread_labels', + 'thread_min_max']: + assert isinstance(run[key], np.ndarray), ( + key + ' is type ' + type(run[key]).__name__) + # check shapes of keys + assert run['logl'].ndim == 1 + assert run['logl'].shape == run['nlive_array'].shape + assert run['logl'].shape == run['thread_labels'].shape + assert run['theta'].ndim == 2 + assert run['logl'].shape[0] == run['theta'].shape[0] + + +def check_ns_run_logls(run, dup_assert=False, dup_warn=False): + """Check run logls are unique and in the correct order. + + Parameters + ---------- + run: dict + nested sampling run to check. + dup_assert: bool, optional + Whether to raise an AssertionError if there are duplicate logl values. + dup_warn: bool, optional + Whether to give a UserWarning if there are duplicate logl values (only + used if dup_assert is False). + + Raises + ------ + AssertionError + if run does not have expected properties. 
+ """ + assert np.array_equal(run['logl'], run['logl'][np.argsort(run['logl'])]) + if dup_assert or dup_warn: + unique_logls, counts = np.unique(run['logl'], return_counts=True) + repeat_logls = run['logl'].shape[0] - unique_logls.shape[0] + msg = ('{} duplicate logl values (out of a total of {}). This may be ' + 'caused by limited numerical precision in the output files.' + '\nrepeated logls = {}\ncounts = {}\npositions in list of {}' + ' unique logls = {}').format( + repeat_logls, run['logl'].shape[0], + unique_logls[counts != 1], counts[counts != 1], + unique_logls.shape[0], np.where(counts != 1)[0]) + if dup_assert: + assert repeat_logls == 0, msg + elif dup_warn: + if repeat_logls != 0: + warnings.warn(msg, UserWarning) + + +def check_ns_run_threads(run): + """Check thread labels and thread_min_max have expected properties. + + Parameters + ---------- + run: dict + Nested sampling run to check. + + Raises + ------ + AssertionError + If run does not have expected properties. + """ + assert run['thread_labels'].dtype == int + uniq_th = np.unique(run['thread_labels']) + assert np.array_equal( + np.asarray(range(run['thread_min_max'].shape[0])), uniq_th), \ + str(uniq_th) + # Check thread_min_max + assert np.any(run['thread_min_max'][:, 0] == -np.inf), ( + 'Run should have at least one thread which starts by sampling the ' + + 'whole prior') + for th_lab in uniq_th: + inds = np.where(run['thread_labels'] == th_lab)[0] + th_info = 'thread label={}, first_logl={}, thread_min_max={}'.format( + th_lab, run['logl'][inds[0]], run['thread_min_max'][th_lab, :]) + assert run['thread_min_max'][th_lab, 0] <= run['logl'][inds[0]], ( + 'First point in thread has logl less than thread min logl! ' + + th_info + ', difference={}'.format( + run['logl'][inds[0]] - run['thread_min_max'][th_lab, 0])) + assert run['thread_min_max'][th_lab, 1] == run['logl'][inds[-1]], ( + 'Last point in thread logl != thread end logl! 
' + th_info) diff --git a/nestcheck/write_polychord_output.py b/nestcheck/write_polychord_output.py index 644fa19..36ed10b 100644 --- a/nestcheck/write_polychord_output.py +++ b/nestcheck/write_polychord_output.py @@ -135,7 +135,7 @@ def run_dead_birth_array(run, **kwargs): Has #parameters + 2 columns: param_1, param_2, ... , logl, birth_logl """ - nestcheck.data_processing.check_ns_run(run, **kwargs) + nestcheck.ns_run_utils.check_ns_run(run, **kwargs) threads = nestcheck.ns_run_utils.get_run_threads(run) samp_arrays = [] ndim = run['theta'].shape[1] diff --git a/tests/test_core.py b/tests/test_core.py index c5781bc..ad067c7 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -266,6 +266,19 @@ def test_dict_given_run_array(self): numpy.testing.assert_array_equal( run_in['nlive_array'], run_out['nlive_array']) + def test_check_ns_run_logls(self): + """Ensure check_ns_run_logls raises error if and only if + warn_only=False""" + repeat_logl_run = {'logl': np.asarray([0, 0, 1])} + self.assertRaises( + AssertionError, nestcheck.ns_run_utils.check_ns_run_logls, + repeat_logl_run, dup_assert=True) + with warnings.catch_warnings(record=True) as war: + warnings.simplefilter("always") + nestcheck.ns_run_utils.check_ns_run_logls( + repeat_logl_run, dup_warn=True) + self.assertEqual(len(war), 1) + class TestErrorAnalysis(unittest.TestCase): diff --git a/tests/test_data_io.py b/tests/test_data_io.py index e718f30..f02be4f 100644 --- a/tests/test_data_io.py +++ b/tests/test_data_io.py @@ -95,20 +95,6 @@ def test_threads_given_birth_inds(self): np.array([0, 1, 1, 1, 1, 0]).astype(int)) self.assertEqual(len(war), 1) - - def test_check_ns_run_logls(self): - """Ensure check_ns_run_logls raises error if and only if - warn_only=False""" - repeat_logl_run = {'logl': np.asarray([0, 0, 1])} - self.assertRaises( - AssertionError, nestcheck.data_processing.check_ns_run_logls, - repeat_logl_run, dup_assert=True) - with warnings.catch_warnings(record=True) as war: - 
warnings.simplefilter("always") - nestcheck.data_processing.check_ns_run_logls( - repeat_logl_run, dup_warn=True) - self.assertEqual(len(war), 1) - def test_process_polychord_data(self): """Check processing some dummy PolyChord data.""" file_root = 'dummy_run' @@ -122,7 +108,7 @@ def test_process_polychord_data(self): processed_run = nestcheck.data_processing.process_polychord_run( file_root, TEST_CACHE_DIR) self.assertEqual(len(war), 1) - nestcheck.data_processing.check_ns_run(processed_run) + nestcheck.ns_run_utils.check_ns_run(processed_run) for key, value in processed_run.items(): if key not in ['output']: numpy.testing.assert_array_equal( @@ -157,7 +143,7 @@ def test_process_multinest_data(self): TEST_CACHE_DIR, file_root + '-phys_live-birth.txt'), live) processed_run = nestcheck.data_processing.process_multinest_run( file_root, TEST_CACHE_DIR) - nestcheck.data_processing.check_ns_run(processed_run) + nestcheck.ns_run_utils.check_ns_run(processed_run) for key, value in processed_run.items(): if key not in ['output']: numpy.testing.assert_array_equal( diff --git a/tests/test_plots.py b/tests/test_plots.py index 16a595d..c9b5207 100644 --- a/tests/test_plots.py +++ b/tests/test_plots.py @@ -8,7 +8,6 @@ import numpy.testing import pandas as pd import scipy.special -import nestcheck.data_processing import nestcheck.diagnostics_tables import nestcheck.dummy_data import nestcheck.error_analysis @@ -26,7 +25,7 @@ class TestPlots(unittest.TestCase): def setUp(self): """Get some dummy data to plot.""" self.ns_run = nestcheck.dummy_data.get_dummy_run(3, 10) - nestcheck.data_processing.check_ns_run(self.ns_run) + nestcheck.ns_run_utils.check_ns_run(self.ns_run) def test_alternate_helper(self): """Check alternate_helper."""