Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement SigmaG filtering in JAX #750

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dynamic = ["version"]
dependencies = [
"astropy>=5.1",
"astroquery>=0.4.6",
"jax",
"joblib>=1.4",
"matplotlib>=3.5",
"numpy<2.0",
Expand Down
98 changes: 60 additions & 38 deletions src/kbmod/filters/sigma_g_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,60 @@
by Smotherman et. al. 2021
"""

from functools import partial
from jax import jit, vmap
import jax.numpy as jnp
import logging
import numpy as np
from scipy.special import erfinv

from kbmod.results import Results
from kbmod.search import DebugTimer


logger = logging.getLogger(__name__)


def sigma_g_jax(data, low_bnd, high_bnd, n_sigma, coeff, clip_negative):
"""The core function for performing a sigma filtering on a series of data points
with clipped_negative. These are typically likelihoods for KBMOD.

Parameters
----------
data : `numpy.ndarray`
A length T matrix of data points for filtering.
low_bnd : `float`
The lower bound of the interval to use to estimate the standard deviation.
high_bnd : `float`
The upper bound of the interval to use to estimate the standard deviation.
n_sigma : `float`
The number of standard deviations to use for the bound.
coeff : `float`
The precomputed coefficient based on the given bounds.
clip_negative : `bool`
A Boolean indicating whether to use negative values when computing
standard deviation.

Returns
-------
index_valid : `numpy.ndarray`
A length T array of Booleans indicating if each point is valid (True)
or has been filtered (False).
"""
# Compute the percentiles for this array of values. If we are clipping the negatives then only
# use the positive points.
masked_data = jnp.where((not clip_negative) | (data > 0.0), data, jnp.nan)
lower_per, median, upper_per = jnp.nanpercentile(masked_data, jnp.array([low_bnd, 50, high_bnd]))

# Compute the bounds for each row, enforcing a minimum gap in case all the
# points are identical (upper_per == lower_per).
delta = upper_per - lower_per
nSigmaG = n_sigma * coeff * jnp.where(delta > 1e-8, delta, 1e-8)

index_valid = jnp.isfinite(data) & (data <= median + nSigmaG) & (data >= median - nSigmaG)
return index_valid


class SigmaGClipping:
"""This class contains the basic information for performing SigmaG clipping.

Expand All @@ -41,9 +85,21 @@ def __init__(self, low_bnd=25, high_bnd=75, n_sigma=2, clip_negative=False):

self.low_bnd = low_bnd
self.high_bnd = high_bnd
self.clip_negative = clip_negative
self.n_sigma = n_sigma
self.coeff = SigmaGClipping.find_sigma_g_coeff(low_bnd, high_bnd)
self.clip_negative = clip_negative

# Create compiled vmapped functions that applies the Sigma G filtering
# with the given parameters.
base_fn = partial(
sigma_g_jax,
low_bnd=self.low_bnd,
high_bnd=self.high_bnd,
n_sigma=self.n_sigma,
coeff=self.coeff,
clip_negative=self.clip_negative,
)
self.sigma_g_jax_fn = vmap(jit(base_fn))

@staticmethod
def find_sigma_g_coeff(low_bnd, high_bnd):
Expand Down Expand Up @@ -107,15 +163,7 @@ def compute_clipped_sigma_g(self, lh):
sigmaG = self.coeff * delta
nSigmaG = self.n_sigma * sigmaG

# Its unclear why we only filter zeros for one of the two cases, but leaving the logic in
# to stay consistent with the original code.
if self.clip_negative:
good_index = np.where(
np.logical_and(lh != 0, np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))
)[0]
else:
good_index = np.where(np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))[0]

good_index = np.where(np.logical_and(lh > median - nSigmaG, lh < median + nSigmaG))[0]
return good_index

def compute_clipped_sigma_g_matrix(self, lh):
Expand All @@ -134,33 +182,8 @@ def compute_clipped_sigma_g_matrix(self, lh):
A N x T matrix of Booleans indicating if each point is valid (True)
or has been filtered (False).
"""
if self.clip_negative:
# We mask out the values less than zero so they are not used in the median computation.
masked_lh = np.copy(lh)
masked_lh[lh <= 0] = np.nan
lower_per, median, upper_per = np.nanpercentile(
masked_lh, [self.low_bnd, 50, self.high_bnd], axis=1
)
else:
lower_per, median, upper_per = np.nanpercentile(lh, [self.low_bnd, 50, self.high_bnd], axis=1)

# Compute the bounds for each row, enforcing a minimum gap in case all the
# points are identical (upper_per == lower_per).
delta = upper_per - lower_per
delta[delta < 1e-8] = 1e-8
nSigmaG = self.n_sigma * self.coeff * delta

num_cols = lh.shape[1]
lower_bnd = np.repeat(np.array([median - nSigmaG]).T, num_cols, axis=1)
upper_bnd = np.repeat(np.array([median + nSigmaG]).T, num_cols, axis=1)

# Its unclear why we only filter zeros for one of the two cases, but leaving the logic in
# to stay consistent with the original code.
if self.clip_negative:
index_valid = np.isfinite(lh) & (lh != 0) & (lh < upper_bnd) & (lh > lower_bnd)
else:
index_valid = np.isfinite(lh) & (lh < upper_bnd) & (lh > lower_bnd)
return index_valid
inds_valid = self.sigma_g_jax_fn(jnp.asarray(lh)).block_until_ready()
return inds_valid


def apply_clipped_sigma_g(clipper, result_data):
Expand All @@ -183,4 +206,3 @@ def apply_clipped_sigma_g(clipper, result_data):
obs_valid = clipper.compute_clipped_sigma_g_matrix(lh)
result_data.update_obs_valid(obs_valid)
filter_timer.stop()
return
36 changes: 34 additions & 2 deletions tests/test_sigma_g_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_sigma_g_negative_clipping(self):
params = SigmaGClipping(clip_negative=True)
result = params.compute_clipped_sigma_g(lh)
for i in range(num_points):
self.assertEqual(i in result, i > 2 and i != 5 and i != 14)
self.assertEqual(i in result, i > 2 and i != 14)

def test_sigma_g_all_negative_clipping(self):
num_points = 10
Expand All @@ -106,10 +106,11 @@ def test_sigma_g_clipping_matrix_negative_clipping(self):
expected = np.array(
[
[True] * num_points,
[False, False, False, True, True, False] + [True] * (num_points - 6),
[False, False, False] + [True] * (num_points - 3),
[False] * num_points,
]
)

sigma_g = SigmaGClipping(clip_negative=True)

# Surpress the warning we get from encountering a row of all NaNs.
Expand Down Expand Up @@ -144,6 +145,37 @@ def test_apply_clipped_sigma_g_results(self):
for j in range(i, num_times):
self.assertTrue(valid[j])

def test_sigmag_parity(self):
"""Test that we get the same results when using the batch and the non-batch methods."""
num_tests = 20

# Run the test with differing numbers of points and with/without clipping.
for num_obs in [10, 20, 50]:
for clipped in [True, False]:
for num_extreme in [0, 1, 2, 3]:
with self.subTest(
num_obs_used=num_obs, use_clipped=clipped, num_extreme_used=num_extreme
):
# Generate the data from a fixed random seed (same for every subtest).
rng = np.random.default_rng(100)
data = 10.0 * rng.random((num_tests, num_obs)) - 0.5

# Add extreme values for each row.
for row in range(num_tests):
for ext_num in range(num_extreme):
idx = int(num_obs * rng.random())
data[row, idx] = 100.0 * rng.random() - 50.0

clipper = SigmaGClipping(25, 75, clip_negative=clipped)

batch_res = clipper.compute_clipped_sigma_g_matrix(data)
for row in range(num_tests):
# Compute the individual results (as indices) and convert
# those into a vector of bools for comparison.
ind_res = clipper.compute_clipped_sigma_g(data[row])
ind_bools = [(idx in ind_res) for idx in range(num_obs)]
self.assertTrue(np.array_equal(batch_res[row], ind_bools))

def test_sigmag_computation(self):
self.assertAlmostEqual(SigmaGClipping.find_sigma_g_coeff(25.0, 75.0), 0.7413, delta=0.001)
self.assertRaises(ValueError, SigmaGClipping.find_sigma_g_coeff, -1.0, 75.0)
Expand Down