fixing import bug in process_dypolychord_run, and moving check_ns_run functions to ns_run_utils, so data processing imports ns_run_utils rather than the reverse.
ejhigson committed Sep 12, 2018
1 parent 3b8f49d · commit d2879b9
Showing 6 changed files with 175 additions and 174 deletions.
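In outline, the commit inverts the dependency between the two modules: the run-checking functions now live in nestcheck.ns_run_utils, and nestcheck.data_processing calls them from there. A minimal sketch of the resulting call pattern (module and function names as in this diff; the file root and base directory are hypothetical):

    import nestcheck.data_processing
    import nestcheck.ns_run_utils  # data_processing now imports this, not the reverse

    # Process a PolyChord run from hypothetical output files.
    run = nestcheck.data_processing.process_polychord_run('my_run', 'chains')
    # After this commit the format checks are called from ns_run_utils.
    nestcheck.ns_run_utils.check_ns_run(run, dup_warn=True)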
166 changes: 15 additions & 151 deletions nestcheck/data_processing.py
@@ -68,8 +68,8 @@
 >= v3.11, and is used in the output processing for these packages via the
 ``birth_inds_given_contours`` and ``threads_given_birth_inds`` functions.
 Also sufficient is a list of the indexes of the point which was removed
-at the step when each point was sampled ("birth indexes"), as this can be mapped
-to the birth contours and vice versa.
+at the step when each point was sampled ("birth indexes"), as this can be
+mapped to the birth contours and vice versa.
 
 ``process_dynesty_run`` does not require the ``birth_inds_given_contours`` and
 ``threads_given_birth_inds`` functions as ``dynesty`` results objects
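For illustration, a toy sketch (not from this commit) of the birth contour / birth index correspondence described above, with made-up values:

    import numpy as np

    logl = np.array([-10., -8., -5., -3.])  # logls of dead points, sorted ascending
    birth_logl = np.array([-np.inf, -np.inf, -10., -8.])  # contour each point was born on
    # A point born on the contour of dead point i has birth index i; -inf marks
    # points sampled from the whole prior (flagged with -1 in this sketch only).
    birth_inds = np.array([np.where(logl == b)[0][0] if b != -np.inf else -1
                           for b in birth_logl])
    print(birth_inds)  # [-1 -1  0  1]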
@@ -95,6 +95,7 @@
 import copy
 import numpy as np
 import nestcheck.io_utils
+import nestcheck.ns_run_utils
 import nestcheck.parallel_utils


@@ -103,8 +104,9 @@ def batch_process_data(file_roots, **kwargs):
"""Process output from many nested sampling runs in parallel with optional
error handling and caching.
The result can be cached using the 'save_name', 'save' and 'load' kwargs (by
default this is not done). See save_load_result docstring for more details.
The result can be cached using the 'save_name', 'save' and 'load' kwargs
(by default this is not done). See save_load_result docstring for more
details.
Remaining kwargs passed to parallel_utils.parallel_apply (see its
docstring for more details).
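A hedged usage sketch of the caching kwargs named in this docstring (the file roots, base_dir value and cache path are hypothetical; base_dir is assumed to be among the kwargs passed through):

    import nestcheck.data_processing

    runs = nestcheck.data_processing.batch_process_data(
        ['gaussian_run_1', 'gaussian_run_2'],  # hypothetical file roots
        base_dir='chains',                     # assumed kwarg: directory holding output files
        save_name='cache/processed_runs',      # hypothetical cache location
        save=True, load=True)                  # write the cache and reuse it next time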
@@ -228,7 +230,7 @@ def process_polychord_run(file_root, base_dir, process_stats_file=True,
         don't have the <root>.stats file (such as if PolyChord was run with
         write_stats=False).
     kwargs: dict, optional
-        Options passed to check_ns_run
+        Options passed to ns_run_utils.check_ns_run.
 
     Returns
     -------
@@ -269,7 +271,7 @@ def process_multinest_run(file_root, base_dir, **kwargs):
         Directory containing output files. When running MultiNest, this is
         determined by the nest_root parameter.
     kwargs: dict, optional
-        Passed to check_ns_run (via process_samples_array)
+        Passed to ns_run_utils.check_ns_run (via process_samples_array)
 
     Returns
@@ -351,7 +353,7 @@ def process_dynesty_run(results):
     samples[final_ind, 2] = -1
     assert np.all(~np.isnan(thread_min_max))
     run = nestcheck.ns_run_utils.dict_given_run_array(samples, thread_min_max)
-    nestcheck.data_processing.check_ns_run(run)
+    nestcheck.ns_run_utils.check_ns_run(run)
     return run


@@ -441,7 +443,8 @@ def process_samples_array(samples, **kwargs):
     birth_contours = samples[:, -1]
     # birth_contours, ns_run['theta'] = check_logls_unique(
     #     samples[:, -2], samples[:, -1], samples[:, :-2])
-    birth_inds = birth_inds_given_contours(birth_contours, ns_run['logl'], **kwargs)
+    birth_inds = birth_inds_given_contours(
+        birth_contours, ns_run['logl'], **kwargs)
     ns_run['thread_labels'] = threads_given_birth_inds(birth_inds)
     unique_threads = np.unique(ns_run['thread_labels'])
     assert np.array_equal(unique_threads,
@@ -499,9 +502,9 @@ def birth_inds_given_contours(birth_logl_arr, logl_arr, **kwargs):
         Birth contours - i.e. logl values of the iso-likelihood contour from
         within which each point was sampled (on which it was born).
     dup_assert: bool, optional
-        See check_ns_run_logls docstring.
+        See ns_run_utils.check_ns_run_logls docstring.
     dup_warn: bool, optional
-        See check_ns_run_logls docstring.
+        See ns_run_utils.check_ns_run_logls docstring.
 
     Returns
     -------
@@ -516,8 +519,8 @@ def birth_inds_given_contours(birth_logl_arr, logl_arr, **kwargs):
     assert logl_arr.ndim == 1, logl_arr.ndim
     assert birth_logl_arr.ndim == 1, birth_logl_arr.ndim
     # Check for duplicate logl values (if specified by dup_assert or dup_warn)
-    check_ns_run_logls({'logl': logl_arr}, dup_assert=dup_assert,
-                       dup_warn=dup_warn)
+    nestcheck.ns_run_utils.check_ns_run_logls(
+        {'logl': logl_arr}, dup_assert=dup_assert, dup_warn=dup_warn)
     # Random seed so results are consistent if there are duplicate logls
     state = np.random.get_state()  # Save random state before seeding
     np.random.seed(0)
@@ -692,142 +695,3 @@ def threads_given_birth_inds(birth_inds):
             str(np.unique(thread_labels)) + ' is not equal to range('
             + str(sum(thread_start_counts)) + ')')
     return thread_labels
-
-
-# Functions for checking nestcheck format nested sampling run dictionaries to
-# ensure they have the expected properties.
-
-
-def check_ns_run(run, dup_assert=False, dup_warn=False):
-    """Checks a nestcheck format nested sampling run dictionary has the
-    expected properties (see the module docstring for more details).
-
-    Parameters
-    ----------
-    run: dict
-        nested sampling run to check.
-    dup_assert: bool, optional
-        See check_ns_run_logls docstring.
-    dup_warn: bool, optional
-        See check_ns_run_logls docstring.
-
-    Raises
-    ------
-    AssertionError
-        if run does not have expected properties.
-    """
-    assert isinstance(run, dict)
-    check_ns_run_members(run)
-    check_ns_run_logls(run, dup_assert=dup_assert, dup_warn=dup_warn)
-    check_ns_run_threads(run)
-
-
-def check_ns_run_members(run):
-    """Check nested sampling run member keys and values.
-
-    Parameters
-    ----------
-    run: dict
-        nested sampling run to check.
-
-    Raises
-    ------
-    AssertionError
-        if run does not have expected properties.
-    """
-    run_keys = list(run.keys())
-    # Mandatory keys
-    for key in ['logl', 'nlive_array', 'theta', 'thread_labels',
-                'thread_min_max']:
-        assert key in run_keys
-        run_keys.remove(key)
-    # Optional keys
-    for key in ['output']:
-        try:
-            run_keys.remove(key)
-        except ValueError:
-            pass
-    # Check for unexpected keys
-    assert not run_keys, 'Unexpected keys in ns_run: ' + str(run_keys)
-    # Check type of mandatory members
-    for key in ['logl', 'nlive_array', 'theta', 'thread_labels',
-                'thread_min_max']:
-        assert isinstance(run[key], np.ndarray), (
-            key + ' is type ' + type(run[key]).__name__)
-    # check shapes of keys
-    assert run['logl'].ndim == 1
-    assert run['logl'].shape == run['nlive_array'].shape
-    assert run['logl'].shape == run['thread_labels'].shape
-    assert run['theta'].ndim == 2
-    assert run['logl'].shape[0] == run['theta'].shape[0]
-
-
-def check_ns_run_logls(run, dup_assert=False, dup_warn=False):
-    """Check run logls are unique and in the correct order.
-
-    Parameters
-    ----------
-    run: dict
-        nested sampling run to check.
-    dup_assert: bool, optional
-        Whether to raise an AssertionError if there are duplicate logl values.
-    dup_warn: bool, optional
-        Whether to give a UserWarning if there are duplicate logl values (only
-        used if dup_assert is False).
-
-    Raises
-    ------
-    AssertionError
-        if run does not have expected properties.
-    """
-    assert np.array_equal(run['logl'], run['logl'][np.argsort(run['logl'])])
-    if dup_assert or dup_warn:
-        unique_logls, counts = np.unique(run['logl'], return_counts=True)
-        repeat_logls = run['logl'].shape[0] - unique_logls.shape[0]
-        msg = ('{} duplicate logl values (out of a total of {}). This may be '
-               'caused by limited numerical precision in the output files.'
-               '\nrepeated logls = {}\ncounts = {}\npositions in list of {}'
-               ' unique logls = {}').format(
-                   repeat_logls, run['logl'].shape[0],
-                   unique_logls[counts != 1], counts[counts != 1],
-                   unique_logls.shape[0], np.where(counts != 1)[0])
-        if dup_assert:
-            assert repeat_logls == 0, msg
-        elif dup_warn:
-            if repeat_logls != 0:
-                warnings.warn(msg, UserWarning)
-
-
-def check_ns_run_threads(run):
-    """Check thread labels and thread_min_max have expected properties.
-
-    Parameters
-    ----------
-    run: dict
-        Nested sampling run to check.
-
-    Raises
-    ------
-    AssertionError
-        If run does not have expected properties.
-    """
-    assert run['thread_labels'].dtype == int
-    uniq_th = np.unique(run['thread_labels'])
-    assert np.array_equal(
-        np.asarray(range(run['thread_min_max'].shape[0])), uniq_th), \
-        str(uniq_th)
-    # Check thread_min_max
-    assert np.any(run['thread_min_max'][:, 0] == -np.inf), (
-        'Run should have at least one thread which starts by sampling the ' +
-        'whole prior')
-    for th_lab in uniq_th:
-        inds = np.where(run['thread_labels'] == th_lab)[0]
-        th_info = 'thread label={}, first_logl={}, thread_min_max={}'.format(
-            th_lab, run['logl'][inds[0]], run['thread_min_max'][th_lab, :])
-        assert run['thread_min_max'][th_lab, 0] <= run['logl'][inds[0]], (
-            'First point in thread has logl less than thread min logl! ' +
-            th_info + ', difference={}'.format(
-                run['logl'][inds[0]] - run['thread_min_max'][th_lab, 0]))
-        assert run['thread_min_max'][th_lab, 1] == run['logl'][inds[-1]], (
-            'Last point in thread logl != thread end logl! ' + th_info)
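The moved functions keep their behaviour in their new home; for example, the duplicate-logl check shown above can warn instead of raising. A toy sketch of calling it after this commit (made-up logl values; the dict form matches the call in birth_inds_given_contours):

    import numpy as np
    import nestcheck.ns_run_utils

    # A run fragment whose logl values contain a duplicate, e.g. due to
    # limited numerical precision in the output files.
    logls = np.array([-3.0, -2.0, -2.0, -1.0])
    # dup_warn=True gives a UserWarning listing the repeated values rather
    # than raising an AssertionError (which dup_assert=True would do).
    nestcheck.ns_run_utils.check_ns_run_logls({'logl': logls}, dup_warn=True)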
[Diffs for the remaining 5 changed files are not shown.]
