Skip to content

Commit

Permalink
Improve flow of sigmaG filtering for Results object
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Apr 19, 2024
1 parent 933cf9e commit 5fa1334
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 11 deletions.
30 changes: 19 additions & 11 deletions src/kbmod/filters/sigma_g_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from scipy.special import erfinv

from kbmod.result_list import ResultList, ResultRow
from kbmod.results import Results


class SigmaGClipping:
Expand Down Expand Up @@ -157,46 +158,53 @@ def compute_clipped_sigma_g_matrix(self, lh):
return index_valid


def apply_single_clipped_sigma_g(clipper, result):
    """Apply a clipped median filter to a single result from KBMOD, using
    sigmaG as a robust estimator of standard deviation.

    Parameters
    ----------
    clipper : `SigmaGClipping`
        The object to apply the SigmaG clipping.
    result : `ResultRow`
        The result details. This data gets modified directly by the filtering.
    """
    # Whatever the clipper returns for this likelihood curve is handed
    # straight to the row, which updates itself in place.
    single_res = clipper.compute_clipped_sigma_g(result.likelihood_curve)
    result.filter_indices(single_res)


def apply_clipped_sigma_g(clipper, result_data, num_threads=1):
    """Apply a clipped median filter to the results of a KBMOD search, using
    sigmaG as a robust estimator of standard deviation.

    Parameters
    ----------
    clipper : `SigmaGClipping`
        The object to apply the SigmaG clipping.
    result_data : `ResultList` or `Results`
        The values from trajectories. This data gets modified directly by the filtering.
    num_threads : `int`
        The number of threads to use. Only consulted for `ResultList` input;
        the `Results` path processes the whole likelihood matrix in one call.
    """
    # isinstance (rather than an exact type match) lets subclasses of Results
    # take the matrix-based path as well.
    if isinstance(result_data, Results):
        lh = result_data.compute_likelihood_curves()
        index_valid = clipper.compute_clipped_sigma_g_matrix(lh)
        result_data.update_index_valid(index_valid)
        return

    # TODO: Remove this logic once we have switched over to Results.
    rows = result_data.results
    if num_threads > 1:
        # starmap expects one argument tuple/list per call.
        lh_list = [[row.likelihood_curve] for row in rows]

        # The context manager guarantees the pool is cleaned up even if a
        # worker raises; starmap blocks until every result is available.
        with mp.Pool(processes=num_threads) as pool:
            keep_idx_results = pool.starmap(clipper.compute_clipped_sigma_g, lh_list)

        for row, keep_idx in zip(rows, keep_idx_results):
            row.filter_indices(keep_idx)
    else:
        for row in rows:
            apply_single_clipped_sigma_g(clipper, row)
28 changes: 28 additions & 0 deletions src/kbmod/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,34 @@ def add_psi_phi_data(self, psi_array, phi_array, index_valid=None):

return self

def update_index_valid(self, index_valid):
    """Updates or appends the 'index_valid' column.

    Parameters
    ----------
    index_valid : `numpy.ndarray`
        An array with one row per result and one column per timestamp
        with Booleans indicating whether the corresponding observation
        is valid.

    Returns
    -------
    self : `Results`
        Returns a reference to itself to allow chaining.

    Raises
    ------
    Raises a ValueError if the input array does not have the same number
    of rows as the table.
    """
    # NOTE(review): only the number of rows is validated here; the length of
    # each row is not compared against the existing per-row data.
    if len(index_valid) != len(self.table):
        raise ValueError("Wrong number of index_valid lists provided.")
    self.table["index_valid"] = index_valid

    # Update the track likelihoods given this new information.
    self._update_likelihood()
    return self

def filter_mask(self, mask, label=None):
"""Filter the rows in the ResultTable to only include those indices
that are marked True in the mask.
Expand Down
37 changes: 37 additions & 0 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,13 @@ def test_extend(self):
with self.assertRaises(ValueError):
table1.extend(table3)

# Test starting from an empty table.
table4 = Results([])
table4.extend(table1)
self.assertEqual(len(table1), len(table4))
for i in range(self.num_entries):
self.assertEqual(table1["x"][i], i)

def test_add_psi_phi(self):
num_to_use = 3
table = Results(self.trj_list[0:num_to_use])
Expand Down Expand Up @@ -153,6 +160,36 @@ def test_add_psi_phi(self):
self.assertAlmostEqual(table["flux"][i], exp_flux[i], delta=1e-5)
self.assertEqual(table["obs_count"][i], exp_obs[i])

def test_update_index_valid(self):
    """Check that update_index_valid() stores the mask and that the
    likelihood, flux, and obs_count columns reflect the masked data."""
    num_to_use = 3
    table = Results(self.trj_list[0:num_to_use])
    psi_array = np.array([[1.0, 1.1, 1.2, 1.3]] * num_to_use)
    phi_array = np.array([[1.0, 1.0, 0.0, 2.0]] * num_to_use)
    table.add_psi_phi_data(psi_array, phi_array)

    # Without a mask, every row is scored from all four observations.
    for row in range(num_to_use):
        self.assertAlmostEqual(table["likelihood"][row], 2.3, delta=1e-5)
        self.assertAlmostEqual(table["flux"][row], 1.15, delta=1e-5)
        self.assertEqual(table["obs_count"][row], 4)

    # Add the index_valid column later to simulate sigmaG clipping:
    # row 0 keeps everything, row 1 drops one point, row 2 drops all.
    index_valid = np.array(
        [
            [True, True, True, True],
            [True, False, True, True],
            [False, False, False, False],
        ]
    )
    table.update_index_valid(index_valid)

    expected = zip([2.3, 2.020725, 0.0], [1.15, 1.1666667, 0.0], [4, 3, 0])
    for row, (exp_lh, exp_flux, exp_obs) in enumerate(expected):
        self.assertEqual(len(table["index_valid"][row]), 4)
        self.assertAlmostEqual(table["likelihood"][row], exp_lh, delta=1e-5)
        self.assertAlmostEqual(table["flux"][row], exp_flux, delta=1e-5)
        self.assertEqual(table["obs_count"][row], exp_obs)

def test_compute_likelihood_curves(self):
num_to_use = 3
table = Results(self.trj_list[0:num_to_use])
Expand Down
27 changes: 27 additions & 0 deletions tests/test_sigma_g_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from kbmod.filters.sigma_g_filter import SigmaGClipping, apply_clipped_sigma_g
from kbmod.result_list import ResultRow, ResultList
from kbmod.results import Results
from kbmod.search import Trajectory


Expand Down Expand Up @@ -110,6 +111,32 @@ def test_apply_clipped_sigma_g(self):
for i in range(5):
self.assertEqual(len(r_set.results[i].valid_indices), num_times - i)

def test_apply_clipped_sigma_g_results(self):
    """Confirm the clipped sigmaG filter works when used with a Results object."""
    num_times = 20
    num_results = 5
    trj_all = [Trajectory() for _ in range(num_results)]
    table = Results(trj_all)

    # Row i gets i outlier psi values at the start of its curve.
    phi_all = np.full((num_results, num_times), 0.1)
    psi_all = np.full((num_results, num_times), 1.0)
    for i in range(num_results):
        for j in range(i):
            psi_all[i, j] = 100.0
    table.add_psi_phi_data(psi_all, phi_all)

    clipper = SigmaGClipping(10, 90)
    apply_clipped_sigma_g(clipper, table)
    self.assertEqual(len(table), num_results)

    # Confirm that the Results table was modified in place: only the
    # injected outliers are marked invalid.
    for i in range(num_results):
        valid = table["index_valid"][i]
        for j in range(i):
            self.assertFalse(valid[j])
        for j in range(i, num_times):
            self.assertTrue(valid[j])

def test_sigmag_computation(self):
self.assertAlmostEqual(SigmaGClipping.find_sigma_g_coeff(25.0, 75.0), 0.7413, delta=0.001)
self.assertRaises(ValueError, SigmaGClipping.find_sigma_g_coeff, -1.0, 75.0)
Expand Down

0 comments on commit 5fa1334

Please sign in to comment.