diff --git a/pyproject.toml b/pyproject.toml index 97311fe5..5ad3bb59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ dynamic = ["version"] dependencies = [ "astropy>=5.1", "astroquery>=0.4.6", + "jax", "joblib>=1.4", "matplotlib>=3.5", "numpy<2.0", diff --git a/src/kbmod/filters/sigma_g_filter.py b/src/kbmod/filters/sigma_g_filter.py index c862cddf..87cdd67e 100644 --- a/src/kbmod/filters/sigma_g_filter.py +++ b/src/kbmod/filters/sigma_g_filter.py @@ -5,6 +5,9 @@ by Smotherman et. al. 2021 """ +from functools import partial +from jax import jit, vmap +import jax.numpy as jnp import logging import numpy as np from scipy.special import erfinv @@ -12,9 +15,50 @@ from kbmod.results import Results from kbmod.search import DebugTimer + logger = logging.getLogger(__name__) +def sigma_g_jax(data, low_bnd, high_bnd, n_sigma, coeff, clip_negative): + """The core function for performing a sigma filtering on a series of data points + with clipped_negative. These are typically likelihoods for KBMOD. + + Parameters + ---------- + data : `numpy.ndarray` + A length T matrix of data points for filtering. + low_bnd : `float` + The lower bound of the interval to use to estimate the standard deviation. + high_bnd : `float` + The upper bound of the interval to use to estimate the standard deviation. + n_sigma : `float` + The number of standard deviations to use for the bound. + coeff : `float` + The precomputed coefficient based on the given bounds. + clip_negative : `bool` + A Boolean indicating whether to use negative values when computing + standard deviation. + + Returns + ------- + index_valid : `numpy.ndarray` + A length T array of Booleans indicating if each point is valid (True) + or has been filtered (False). + """ + # Compute the percentiles for this array of values. If we are clipping the negatives then only + # use the positive points. + masked_data = jnp.where((not clip_negative) | (data > 0.0), data, jnp.nan) + lower_per, median, upper_per = jnp.nanpercentile(masked_data, jnp.array([low_bnd, 50, high_bnd])) + + # Compute the bounds for each row, enforcing a minimum gap in case all the + # points are identical (upper_per == lower_per). + delta = upper_per - lower_per + nSigmaG = n_sigma * coeff * jnp.where(delta > 1e-8, delta, 1e-8) + + index_valid = jnp.isfinite(data) & (data <= median + nSigmaG) & (data >= median - nSigmaG) + return index_valid + + class SigmaGClipping: """This class contains the basic information for performing SigmaG clipping. @@ -41,9 +85,21 @@ def __init__(self, low_bnd=25, high_bnd=75, n_sigma=2, clip_negative=False): self.low_bnd = low_bnd self.high_bnd = high_bnd - self.clip_negative = clip_negative self.n_sigma = n_sigma self.coeff = SigmaGClipping.find_sigma_g_coeff(low_bnd, high_bnd) + self.clip_negative = clip_negative + + # Create compiled vmapped functions that applies the Sigma G filtering + # with the given parameters. + base_fn = partial( + sigma_g_jax, + low_bnd=self.low_bnd, + high_bnd=self.high_bnd, + n_sigma=self.n_sigma, + coeff=self.coeff, + clip_negative=self.clip_negative, + ) + self.sigma_g_jax_fn = vmap(jit(base_fn)) @staticmethod def find_sigma_g_coeff(low_bnd, high_bnd): @@ -107,15 +163,7 @@ def compute_clipped_sigma_g(self, lh): sigmaG = self.coeff * delta nSigmaG = self.n_sigma * sigmaG - # Its unclear why we only filter zeros for one of the two cases, but leaving the logic in - # to stay consistent with the original code. - if self.clip_negative: - good_index = np.where( - np.logical_and(lh != 0, np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG)) - )[0] - else: - good_index = np.where(np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))[0] - + good_index = np.where(np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))[0] return good_index def compute_clipped_sigma_g_matrix(self, lh): @@ -134,33 +182,8 @@ def compute_clipped_sigma_g_matrix(self, lh): A N x T matrix of Booleans indicating if each point is valid (True) or has been filtered (False). """ - if self.clip_negative: - # We mask out the values less than zero so they are not used in the median computation. - masked_lh = np.copy(lh) - masked_lh[lh <= 0] = np.nan - lower_per, median, upper_per = np.nanpercentile( - masked_lh, [self.low_bnd, 50, self.high_bnd], axis=1 - ) - else: - lower_per, median, upper_per = np.nanpercentile(lh, [self.low_bnd, 50, self.high_bnd], axis=1) - - # Compute the bounds for each row, enforcing a minimum gap in case all the - # points are identical (upper_per == lower_per). - delta = upper_per - lower_per - delta[delta < 1e-8] = 1e-8 - nSigmaG = self.n_sigma * self.coeff * delta - - num_cols = lh.shape[1] - lower_bnd = np.repeat(np.array([median - nSigmaG]).T, num_cols, axis=1) - upper_bnd = np.repeat(np.array([median + nSigmaG]).T, num_cols, axis=1) - - # Its unclear why we only filter zeros for one of the two cases, but leaving the logic in - # to stay consistent with the original code. - if self.clip_negative: - index_valid = np.isfinite(lh) & (lh != 0) & (lh < upper_bnd) & (lh > lower_bnd) - else: - index_valid = np.isfinite(lh) & (lh < upper_bnd) & (lh > lower_bnd) - return index_valid + inds_valid = self.sigma_g_jax_fn(jnp.asarray(lh)).block_until_ready() + return inds_valid def apply_clipped_sigma_g(clipper, result_data): @@ -183,4 +206,3 @@ def apply_clipped_sigma_g(clipper, result_data): obs_valid = clipper.compute_clipped_sigma_g_matrix(lh) result_data.update_obs_valid(obs_valid) filter_timer.stop() - return diff --git a/tests/test_sigma_g_filter.py b/tests/test_sigma_g_filter.py index edddcffe..031e4e14 100644 --- a/tests/test_sigma_g_filter.py +++ b/tests/test_sigma_g_filter.py @@ -84,7 +84,7 @@ def test_sigma_g_negative_clipping(self): params = SigmaGClipping(clip_negative=True) result = params.compute_clipped_sigma_g(lh) for i in range(num_points): - self.assertEqual(i in result, i > 2 and i != 5 and i != 14) + self.assertEqual(i in result, i > 2 and i != 14) def test_sigma_g_all_negative_clipping(self): num_points = 10 @@ -106,10 +106,11 @@ def test_sigma_g_clipping_matrix_negative_clipping(self): expected = np.array( [ [True] * num_points, - [False, False, False, True, True, False] + [True] * (num_points - 6), + [False, False, False] + [True] * (num_points - 3), [False] * num_points, ] ) + sigma_g = SigmaGClipping(clip_negative=True) # Surpress the warning we get from encountering a row of all NaNs. @@ -144,6 +145,37 @@ def test_apply_clipped_sigma_g_results(self): for j in range(i, num_times): self.assertTrue(valid[j]) + def test_sigmag_parity(self): + """Test that we get the same results when using the batch and the non-batch methods.""" + num_tests = 20 + + # Run the test with differing numbers of points and with/without clipping. + for num_obs in [10, 20, 50]: + for clipped in [True, False]: + for num_extreme in [0, 1, 2, 3]: + with self.subTest( + num_obs_used=num_obs, use_clipped=clipped, num_extreme_used=num_extreme + ): + # Generate the data from a fixed random seed (same for every subtest). + rng = np.random.default_rng(100) + data = 10.0 * rng.random((num_tests, num_obs)) - 0.5 + + # Add extreme values for each row. + for row in range(num_tests): + for ext_num in range(num_extreme): + idx = int(num_obs * rng.random()) + data[row, idx] = 100.0 * rng.random() - 50.0 + + clipper = SigmaGClipping(25, 75, clip_negative=clipped) + + batch_res = clipper.compute_clipped_sigma_g_matrix(data) + for row in range(num_tests): + # Compute the individual results (as indices) and convert + # those into a vector of bools for comparison. + ind_res = clipper.compute_clipped_sigma_g(data[row]) + ind_bools = [(idx in ind_res) for idx in range(num_obs)] + self.assertTrue(np.array_equal(batch_res[row], ind_bools)) + def test_sigmag_computation(self): self.assertAlmostEqual(SigmaGClipping.find_sigma_g_coeff(25.0, 75.0), 0.7413, delta=0.001) self.assertRaises(ValueError, SigmaGClipping.find_sigma_g_coeff, -1.0, 75.0)