diff --git a/src/kbmod/analysis_utils.py b/src/kbmod/analysis_utils.py
index 06697c591..4da73ca78 100644
--- a/src/kbmod/analysis_utils.py
+++ b/src/kbmod/analysis_utils.py
@@ -9,7 +9,7 @@
 
 from .file_utils import *
 from .filters.clustering_filters import DBSCANFilter
-from .filters.stats_filters import LHFilter, NumObsFilter
+from .filters.stats_filters import CombinedStatsFilter
 from .filters.sigma_g_filter import apply_clipped_sigma_g, SigmaGClipping
 from .result_list import ResultList, ResultRow
 
@@ -25,6 +25,7 @@ class PostProcess:
     def __init__(self, config, mjds):
         self.coeff = None
         self.num_cores = config["num_cores"]
+        self.num_obs = config["num_obs"]
         self.sigmaG_lims = config["sigmaG_lims"]
         self.eps = config["eps"]
         self.cluster_type = config["cluster_type"]
@@ -75,6 +76,12 @@ def load_and_filter_results(
             bnds = [25, 75]
         clipper = SigmaGClipping(bnds[0], bnds[1], 2, self.clip_negative)
 
+        # Set up the combined stats filter.
+        if lh_level > 0.0:
+            stats_filter = CombinedStatsFilter(min_obs=self.num_obs, min_lh=lh_level)
+        else:
+            stats_filter = CombinedStatsFilter(min_obs=self.num_obs)
+
         print("---------------------------------------")
         print("Retrieving Results")
         print("---------------------------------------")
@@ -94,6 +101,7 @@ def load_and_filter_results(
                 if trj.lh < lh_level:
                     likelihood_limit = True
                     break
+
                 if trj.lh < max_lh:
                     row = ResultRow(trj, len(self._mjds))
                     psi_curve = np.array(search.get_psi_curves(trj))
@@ -106,11 +114,7 @@ def load_and_filter_results(
             print("Extracted batch of %i results for total of %i" % (batch_size, total_count))
             if batch_size > 0:
                 apply_clipped_sigma_g(clipper, result_batch, self.num_cores)
-                result_batch.apply_filter(NumObsFilter(3))
-
-                # Apply the likelihood filter if one is provided.
-                if lh_level > 0.0:
-                    result_batch.apply_filter(LHFilter(lh_level, None))
+                result_batch.apply_filter(stats_filter)
 
                 # Add the results to the final set.
                 keep.extend(result_batch)
diff --git a/src/kbmod/filters/stats_filters.py b/src/kbmod/filters/stats_filters.py
index c335ad30d..12eb03683 100644
--- a/src/kbmod/filters/stats_filters.py
+++ b/src/kbmod/filters/stats_filters.py
@@ -1,3 +1,5 @@
+import numpy as np
+
 from kbmod.filters.base_filter import RowFilter
 from kbmod.result_list import ResultRow
 
@@ -112,3 +114,109 @@ def keep_row(self, row: ResultRow):
             An indicator of whether to keep the row.
         """
         return len(row.valid_indices) >= self.min_obs
+
+
+class CombinedStatsFilter(RowFilter):
+    """A filter for result's likelihood and number of observations."""
+
+    def __init__(self, min_obs=0, min_lh=-np.inf, max_lh=np.inf, *args, **kwargs):
+        """Create a CombinedStatsFilter.
+
+        Parameters
+        ----------
+        min_obs : ``int``
+            The minimum number of observations.
+        min_lh : ``float``
+            Minimal allowed likelihood.
+        max_lh : ``float``
+            Maximal allowed likelihood.
+        """
+        super().__init__(*args, **kwargs)
+
+        self.min_obs = min_obs
+        self.min_lh = min_lh
+        self.max_lh = max_lh
+
+    def get_filter_name(self):
+        """Get the name of the filter.
+
+        Returns
+        -------
+        str
+            The filter name.
+        """
+        return f"CombinedStats_{self.min_obs}_{self.min_lh}_to_{self.max_lh}"
+
+    def keep_row(self, row: ResultRow):
+        """Determine whether to keep an individual row based on
+        the likelihood and the number of valid observations.
+
+        Parameters
+        ----------
+        row : ResultRow
+            The row to evaluate.
+
+        Returns
+        -------
+        bool
+            An indicator of whether to keep the row.
+        """
+        if row.final_likelihood < self.min_lh or row.final_likelihood > self.max_lh:
+            return False
+        if len(row.valid_indices) < self.min_obs:
+            return False
+        return True
+
+
+class DurationFilter(RowFilter):
+    """A filter for the amount of time covered by the trajectory"""
+
+    def __init__(self, all_times, min_duration, *args, **kwargs):
+        """Create a DurationFilter.
+
+        Parameters
+        ----------
+        all_times : ``list``
+            The time stamps in increasing order.
+        min_duration : ``float``
+            The minimum duration in days for a valid result.
+        """
+        super().__init__(*args, **kwargs)
+
+        self.all_times = all_times
+        self.min_duration = min_duration
+
+    def get_filter_name(self):
+        """Get the name of the filter.
+
+        Returns
+        -------
+        str
+            The filter name.
+        """
+        return f"Duration_{self.min_duration}"
+
+    def keep_row(self, row: ResultRow):
+        """Determine whether to keep an individual row based on
+        the time spanned by its first and last valid observations.
+
+        Parameters
+        ----------
+        row : ResultRow
+            The row to evaluate.
+
+        Returns
+        -------
+        bool
+            An indicator of whether to keep the row.
+        """
+        # A row without valid observations spans no time, and np.min / np.max
+        # raise ValueError on an empty sequence, so reject it up front.
+        if len(row.valid_indices) == 0:
+            return False
+
+        min_index = np.min(row.valid_indices)
+        max_index = np.max(row.valid_indices)
+        if self.all_times[max_index] - self.all_times[min_index] < self.min_duration:
+            return False
+        return True
diff --git a/tests/test_analysis_utils.py b/tests/test_analysis_utils.py
index 2d5bf914d..72539f37d 100644
--- a/tests/test_analysis_utils.py
+++ b/tests/test_analysis_utils.py
@@ -4,20 +4,12 @@
 from kbmod.fake_data_creator import add_fake_object
 from kbmod.result_list import *
 from kbmod.search import *
+from kbmod.trajectory_utils import make_trajectory
 
 from utils.utils_for_tests import get_absolute_data_path
 
 
 class test_analysis_utils(unittest.TestCase):
-    def _make_trajectory(self, x0, y0, xv, yv, lh):
-        t = Trajectory()
-        t.x = x0
-        t.y = y0
-        t.vx = xv
-        t.vy = yv
-        t.lh = lh
-        return t
-
     def setUp(self):
         # The configuration parameters.
         self.default_mask_bits_dict = {
@@ -277,11 +269,11 @@ def test_clustering(self):
         cluster_params["mjd"] = np.array(self.stack.build_zeroed_times())
 
         trjs = [
-            self._make_trajectory(10, 11, 1, 2, 100.0),
-            self._make_trajectory(10, 11, 10, 20, 100.0),
-            self._make_trajectory(40, 5, -1, 2, 100.0),
-            self._make_trajectory(5, 0, 1, 2, 100.0),
-            self._make_trajectory(5, 1, 1, 2, 100.0),
+            make_trajectory(x=10, y=11, vx=1, vy=2, lh=100.0),
+            make_trajectory(x=10, y=11, vx=10, vy=20, lh=100.0),
+            make_trajectory(x=40, y=5, vx=-1, vy=2, lh=100.0),
+            make_trajectory(x=5, y=0, vx=1, vy=2, lh=100.0),
+            make_trajectory(x=5, y=1, vx=1, vy=2, lh=100.0),
         ]
 
         # Try clustering with positions, velocities, and angles.
@@ -306,15 +298,15 @@ def test_clustering(self):
         self.assertEqual(results2.num_results(), 3)
 
     def test_load_and_filter_results_lh(self):
-        # Create fake result trajectories with given initial likelihoods.
+        # Create fake result trajectories with given initial likelihoods. The 1st is
+        # filtered by max likelihood. The 4th and 5th are filtered by min likelihood.
         trjs = [
-            self._make_trajectory(20, 20, 0, 0, 9000.0),  # Filtered by max likelihood
-            self._make_trajectory(30, 30, 0, 0, 100.0),
-            self._make_trajectory(40, 40, 0, 0, 50.0),
-            self._make_trajectory(50, 50, 0, 0, 2.0),  # Filtered by min likelihood
-            self._make_trajectory(60, 60, 0, 0, 1.0),  # Filtered by min likelihood
+            make_trajectory(20, 20, 0, 0, 500.0, 9000.0, self.img_count),
+            make_trajectory(30, 30, 0, 0, 100.0, 100.0, self.img_count),
+            make_trajectory(40, 40, 0, 0, 50.0, 50.0, self.img_count),
+            make_trajectory(50, 50, 0, 0, 1.0, 2.0, self.img_count),
+            make_trajectory(60, 60, 0, 0, 1.0, 1.0, self.img_count),
         ]
-        fluxes = [500.0, 100.0, 50.0, 1.0, 0.1]
 
         # Create fake images with the objects in them.
         imlist = []
@@ -324,7 +316,7 @@ def test_load_and_filter_results_lh(self):
 
             # Add the objects.
             for j, trj in enumerate(trjs):
-                add_fake_object(im, trj.x, trj.y, fluxes[j], self.p)
+                add_fake_object(im, trj.x, trj.y, trj.flux, self.p)
 
             # Append the image.
             imlist.append(im)
@@ -334,7 +326,9 @@ def test_load_and_filter_results_lh(self):
         search.set_results(trjs)
 
         # Do the filtering.
+        self.config["num_obs"] = 5
         kb_post_process = PostProcess(self.config, self.time_list)
+
         results = kb_post_process.load_and_filter_results(
             search,
             10.0,  # min likelihood
@@ -342,7 +336,7 @@ def test_load_and_filter_results_lh(self):
             max_lh=1000.0,
         )
 
-        # Only the middle two results should pass the filtering.
+        # Only two of the middle results should pass the filtering.
         self.assertEqual(results.num_results(), 2)
         self.assertEqual(results.results[0].trajectory.y, 30)
         self.assertEqual(results.results[1].trajectory.y, 40)
diff --git a/tests/test_stats_filters.py b/tests/test_stats_filters.py
index bc45fffc5..a383082c0 100644
--- a/tests/test_stats_filters.py
+++ b/tests/test_stats_filters.py
@@ -97,6 +97,45 @@ def test_filter_valid_indices(self):
         for i in range(self.rs.num_results()):
             self.assertGreaterEqual(len(self.rs.results[i].valid_indices), 4)
 
+    def test_combined_stats_filter(self):
+        self.assertEqual(self.rs.num_results(), 10)
+
+        f = CombinedStatsFilter(min_obs=4, min_lh=5.1)
+        self.assertEqual(f.get_filter_name(), "CombinedStats_4_5.1_to_inf")
+
+        # Do the filtering and check we have the correct ones.
+        self.rs.apply_filter(f)
+        self.assertEqual(self.rs.num_results(), 4)
+        for row in self.rs.results:
+            self.assertGreaterEqual(len(row.valid_indices), 4)
+            self.assertGreaterEqual(row.final_likelihood, 5.1)
+
+    def test_duration_filter(self):
+        f = DurationFilter(self.times, 0.81)
+        self.assertEqual(f.get_filter_name(), "Duration_0.81")
+
+        res_list = ResultList(self.times, track_filtered=True)
+
+        # Add a full track
+        row0 = ResultRow(Trajectory(), self.num_times)
+        res_list.append_result(row0)
+
+        # Add a track with every 4th observation
+        row1 = ResultRow(Trajectory(), self.num_times)
+        row1.filter_indices([k for k in range(self.num_times) if k % 4 == 0])
+        res_list.append_result(row1)
+
+        # Add a track with a short burst in the middle.
+        row2 = ResultRow(Trajectory(), self.num_times)
+        row2.filter_indices([3, 4, 5, 6, 7, 8, 9])
+        res_list.append_result(row2)
+
+        res_list.apply_filter(f)
+        self.assertEqual(res_list.num_results(), 2)
+
+        self.assertGreaterEqual(len(res_list.results[0].valid_indices), self.num_times)
+        self.assertGreaterEqual(len(res_list.results[1].valid_indices), int(self.num_times / 4))
+
 
 if __name__ == "__main__":
     unittest.main()