From 5fa133493dd2de035f98cb6771c306f6f1e62e8c Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 19 Apr 2024 08:59:42 -0400 Subject: [PATCH] Improve flow of sigmaG filtering for Results object --- src/kbmod/filters/sigma_g_filter.py | 30 ++++++++++++++--------- src/kbmod/results.py | 28 ++++++++++++++++++++++ tests/test_results.py | 37 +++++++++++++++++++++++++++++ tests/test_sigma_g_filter.py | 27 +++++++++++++++++++++ 4 files changed, 111 insertions(+), 11 deletions(-) diff --git a/src/kbmod/filters/sigma_g_filter.py b/src/kbmod/filters/sigma_g_filter.py index e00141bdb..e8f01ddfa 100644 --- a/src/kbmod/filters/sigma_g_filter.py +++ b/src/kbmod/filters/sigma_g_filter.py @@ -10,6 +10,7 @@ from scipy.special import erfinv from kbmod.result_list import ResultList, ResultRow +from kbmod.results import Results class SigmaGClipping: @@ -157,46 +158,53 @@ def compute_clipped_sigma_g_matrix(self, lh): return index_valid -def apply_single_clipped_sigma_g(params, result): +def apply_single_clipped_sigma_g(clipper, result): """This function applies a clipped median filter to a single result from KBMOD using sigmaG as a robust estimater of standard deviation. Parameters ---------- - params : `SigmaGClipping` + clipper : `SigmaGClipping` The object to apply the SigmaG clipping. result : `ResultRow` The result details. This data gets modified directly by the filtering. """ - single_res = params.compute_clipped_sigma_g(result.likelihood_curve) + single_res = clipper.compute_clipped_sigma_g(result.likelihood_curve) result.filter_indices(single_res) -def apply_clipped_sigma_g(params, result_list, num_threads=1): +def apply_clipped_sigma_g(clipper, result_data, num_threads=1): """This function applies a clipped median filter to the results of a KBMOD search using sigmaG as a robust estimater of standard deviation. 
Parameters ---------- - params : `SigmaGClipping` + clipper : `SigmaGClipping` The object to apply the SigmaG clipping. - result_list : `ResultList` + result_data : `ResultList` or `Results` The values from trajectories. This data gets modified directly by the filtering. num_threads : `int` The number of threads to use. """ + if type(result_data) is Results: + lh = result_data.compute_likelihood_curves() + index_valid = clipper.compute_clipped_sigma_g_matrix(lh) + result_data.update_index_valid(index_valid) + return + + # TODO: Remove this logic once we have switched over to Results. if num_threads > 1: - lh_list = [[row.likelihood_curve] for row in result_list.results] + lh_list = [[row.likelihood_curve] for row in result_data.results] keep_idx_results = [] pool = mp.Pool(processes=num_threads) - keep_idx_results = pool.starmap_async(params.compute_clipped_sigma_g, lh_list) + keep_idx_results = pool.starmap_async(clipper.compute_clipped_sigma_g, lh_list) pool.close() pool.join() keep_idx_results = keep_idx_results.get() for i, res in enumerate(keep_idx_results): - result_list.results[i].filter_indices(res) + result_data.results[i].filter_indices(res) else: - for row in result_list.results: - apply_single_clipped_sigma_g(params, row) + for row in result_data.results: + apply_single_clipped_sigma_g(clipper, row) diff --git a/src/kbmod/results.py b/src/kbmod/results.py index 05493c8f4..2d458e8ec 100644 --- a/src/kbmod/results.py +++ b/src/kbmod/results.py @@ -330,6 +330,34 @@ def add_psi_phi_data(self, psi_array, phi_array, index_valid=None): return self + def update_index_valid(self, index_valid): + """Updates or appends the 'index_valid' column. + + Parameters + ---------- + index_valid : `numpy.ndarray` + An array with one row per result and one column per timestamp + with Booleans indicating whether the corresponding observation + is valid. + + Returns + ------- + self : `Results` + Returns a reference to itself to allow chaining. 
+ + Raises + ------ + Raises a ValueError if the input array is not the same size as the table + or a given pair of rows in the arrays are not the same length. + """ + if len(index_valid) != len(self.table): + raise ValueError("Wrong number of index_valid lists provided.") + self.table["index_valid"] = index_valid + + # Update the track likelihoods given this new information. + self._update_likelihood() + return self + def filter_mask(self, mask, label=None): """Filter the rows in the ResultTable to only include those indices that are marked True in the mask. diff --git a/tests/test_results.py b/tests/test_results.py index 3b4ec0755..2afc73c0b 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -125,6 +125,13 @@ def test_extend(self): with self.assertRaises(ValueError): table1.extend(table3) + # Test starting from an empty table. + table4 = Results([]) + table4.extend(table1) + self.assertEqual(len(table1), len(table4)) + for i in range(self.num_entries): + self.assertEqual(table1["x"][i], i) + def test_add_psi_phi(self): num_to_use = 3 table = Results(self.trj_list[0:num_to_use]) @@ -153,6 +160,36 @@ def test_add_psi_phi(self): self.assertAlmostEqual(table["flux"][i], exp_flux[i], delta=1e-5) self.assertEqual(table["obs_count"][i], exp_obs[i]) + def test_update_index_valid(self): + num_to_use = 3 + table = Results(self.trj_list[0:num_to_use]) + psi_array = np.array([[1.0, 1.1, 1.2, 1.3] for i in range(num_to_use)]) + phi_array = np.array([[1.0, 1.0, 0.0, 2.0] for i in range(num_to_use)]) + table.add_psi_phi_data(psi_array, phi_array) + for i in range(num_to_use): + self.assertAlmostEqual(table["likelihood"][i], 2.3, delta=1e-5) + self.assertAlmostEqual(table["flux"][i], 1.15, delta=1e-5) + self.assertEqual(table["obs_count"][i], 4) + + # Add the index_valid column later to simulate sigmaG clipping. 
+ index_valid = np.array( + [ + [True, True, True, True], + [True, False, True, True], + [False, False, False, False], + ] + ) + table.update_index_valid(index_valid) + + exp_lh = [2.3, 2.020725, 0.0] + exp_flux = [1.15, 1.1666667, 0.0] + exp_obs = [4, 3, 0] + for i in range(num_to_use): + self.assertEqual(len(table["index_valid"][i]), 4) + self.assertAlmostEqual(table["likelihood"][i], exp_lh[i], delta=1e-5) + self.assertAlmostEqual(table["flux"][i], exp_flux[i], delta=1e-5) + self.assertEqual(table["obs_count"][i], exp_obs[i]) + def test_compute_likelihood_curves(self): num_to_use = 3 table = Results(self.trj_list[0:num_to_use]) diff --git a/tests/test_sigma_g_filter.py b/tests/test_sigma_g_filter.py index 2563e073d..477988da6 100644 --- a/tests/test_sigma_g_filter.py +++ b/tests/test_sigma_g_filter.py @@ -3,6 +3,7 @@ from kbmod.filters.sigma_g_filter import SigmaGClipping, apply_clipped_sigma_g from kbmod.result_list import ResultRow, ResultList +from kbmod.results import Results from kbmod.search import Trajectory @@ -110,6 +111,32 @@ def test_apply_clipped_sigma_g(self): for i in range(5): self.assertEqual(len(r_set.results[i].valid_indices), num_times - i) + def test_apply_clipped_sigma_g_results(self): + """Confirm the clipped sigmaG filter works when used with a Results object.""" + num_times = 20 + num_results = 5 + trj_all = [Trajectory() for _ in range(num_results)] + table = Results(trj_all) + + phi_all = np.full((num_results, num_times), 0.1) + psi_all = np.full((num_results, num_times), 1.0) + for i in range(5): + for j in range(i): + psi_all[i, j] = 100.0 + table.add_psi_phi_data(psi_all, phi_all) + + clipper = SigmaGClipping(10, 90) + apply_clipped_sigma_g(clipper, table) + self.assertEqual(len(table), 5) + + # Confirm that the Results table's index_valid column was updated in place. 
+ for i in range(num_results): + valid = table["index_valid"][i] + for j in range(i): + self.assertFalse(valid[j]) + for j in range(i, num_times): + self.assertTrue(valid[j]) + def test_sigmag_computation(self): self.assertAlmostEqual(SigmaGClipping.find_sigma_g_coeff(25.0, 75.0), 0.7413, delta=0.001) self.assertRaises(ValueError, SigmaGClipping.find_sigma_g_coeff, -1.0, 75.0)