Skip to content

Commit

Permalink
Merge pull request #399 from dirac-institute/add_filters
Browse files Browse the repository at this point in the history
Upgrade filter logic
jeremykubica authored Dec 5, 2023

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents 58e2283 + e1143b9 commit 0ff88d0
Showing 4 changed files with 169 additions and 29 deletions.
16 changes: 10 additions & 6 deletions src/kbmod/analysis_utils.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@

from .file_utils import *
from .filters.clustering_filters import DBSCANFilter
from .filters.stats_filters import LHFilter, NumObsFilter
from .filters.stats_filters import CombinedStatsFilter
from .filters.sigma_g_filter import apply_clipped_sigma_g, SigmaGClipping
from .result_list import ResultList, ResultRow

@@ -25,6 +25,7 @@ class PostProcess:
def __init__(self, config, mjds):
self.coeff = None
self.num_cores = config["num_cores"]
self.num_obs = config["num_obs"]
self.sigmaG_lims = config["sigmaG_lims"]
self.eps = config["eps"]
self.cluster_type = config["cluster_type"]
@@ -75,6 +76,12 @@ def load_and_filter_results(
bnds = [25, 75]
clipper = SigmaGClipping(bnds[0], bnds[1], 2, self.clip_negative)

# Set up the combined stats filter.
if lh_level > 0.0:
stats_filter = CombinedStatsFilter(min_obs=self.num_obs, min_lh=lh_level)
else:
stats_filter = CombinedStatsFilter(min_obs=self.num_obs)

print("---------------------------------------")
print("Retrieving Results")
print("---------------------------------------")
@@ -94,6 +101,7 @@ def load_and_filter_results(
if trj.lh < lh_level:
likelihood_limit = True
break

if trj.lh < max_lh:
row = ResultRow(trj, len(self._mjds))
psi_curve = np.array(search.get_psi_curves(trj))
@@ -106,11 +114,7 @@ def load_and_filter_results(
print("Extracted batch of %i results for total of %i" % (batch_size, total_count))
if batch_size > 0:
apply_clipped_sigma_g(clipper, result_batch, self.num_cores)
result_batch.apply_filter(NumObsFilter(3))

# Apply the likelihood filter if one is provided.
if lh_level > 0.0:
result_batch.apply_filter(LHFilter(lh_level, None))
result_batch.apply_filter(stats_filter)

# Add the results to the final set.
keep.extend(result_batch)
103 changes: 103 additions & 0 deletions src/kbmod/filters/stats_filters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from kbmod.filters.base_filter import RowFilter
from kbmod.result_list import ResultRow

@@ -112,3 +114,104 @@ def keep_row(self, row: ResultRow):
An indicator of whether to keep the row.
"""
return len(row.valid_indices) >= self.min_obs


class CombinedStatsFilter(RowFilter):
    """A row filter that checks both likelihood and observation count.

    A row passes only if its final likelihood lies within
    [``min_lh``, ``max_lh``] and it has at least ``min_obs`` valid indices.
    """

    def __init__(self, min_obs=0, min_lh=-np.inf, max_lh=np.inf, *args, **kwargs):
        """Create a CombinedStatsFilter.

        Parameters
        ----------
        min_obs : ``int``
            The minimum number of observations.
        min_lh : ``float``
            Minimal allowed likelihood.
        max_lh : ``float``
            Maximal allowed likelihood.
        *args, **kwargs
            Forwarded unchanged to the ``RowFilter`` constructor.
        """
        super().__init__(*args, **kwargs)

        self.max_lh = max_lh
        self.min_lh = min_lh
        self.min_obs = min_obs

    def get_filter_name(self):
        """Get the name of the filter.

        Returns
        -------
        str
            The filter name, built from the observation and likelihood limits.
        """
        return f"CombinedStats_{self.min_obs}_{self.min_lh}_to_{self.max_lh}"

    def keep_row(self, row: ResultRow):
        """Decide whether a row satisfies both the likelihood bounds and
        the minimum observation count.

        Parameters
        ----------
        row : ResultRow
            The row to evaluate.

        Returns
        -------
        bool
            An indicator of whether to keep the row.
        """
        likelihood = row.final_likelihood
        outside_lh_range = likelihood < self.min_lh or likelihood > self.max_lh
        # The observation count is only consulted when the likelihood check passes.
        return not outside_lh_range and not (len(row.valid_indices) < self.min_obs)


class DurationFilter(RowFilter):
    """A filter for the amount of time covered by the trajectory.

    The duration of a row is the difference between the time stamps at its
    largest and smallest valid indices.
    """

    def __init__(self, all_times, min_duration, *args, **kwargs):
        """Create a DurationFilter.

        Parameters
        ----------
        all_times : ``list``
            The time stamps in increasing order.
        min_duration : ``float``
            The minimum duration in days for a valid result.
        *args, **kwargs
            Forwarded unchanged to the ``RowFilter`` constructor.
        """
        super().__init__(*args, **kwargs)

        self.all_times = all_times
        self.min_duration = min_duration

    def get_filter_name(self):
        """Get the name of the filter.

        Returns
        -------
        str
            The filter name.
        """
        return f"Duration_{self.min_duration}"

    def keep_row(self, row: ResultRow):
        """Determine whether to keep an individual row based on
        the time spanned by its valid observations.

        Parameters
        ----------
        row : ResultRow
            The row to evaluate.

        Returns
        -------
        bool
            An indicator of whether to keep the row. A row with no valid
            indices is treated as covering zero duration.
        """
        if len(row.valid_indices) == 0:
            # np.min / np.max raise ValueError on empty input, so a row with
            # no remaining observations is handled as a zero-length track.
            duration = 0.0
        else:
            min_index = np.min(row.valid_indices)
            max_index = np.max(row.valid_indices)
            duration = self.all_times[max_index] - self.all_times[min_index]

        if duration < self.min_duration:
            return False
        return True
40 changes: 17 additions & 23 deletions tests/test_analysis_utils.py
Original file line number Diff line number Diff line change
@@ -4,20 +4,12 @@
from kbmod.fake_data_creator import add_fake_object
from kbmod.result_list import *
from kbmod.search import *
from kbmod.trajectory_utils import make_trajectory

from utils.utils_for_tests import get_absolute_data_path


class test_analysis_utils(unittest.TestCase):
def _make_trajectory(self, x0, y0, xv, yv, lh):
t = Trajectory()
t.x = x0
t.y = y0
t.vx = xv
t.vy = yv
t.lh = lh
return t

def setUp(self):
# The configuration parameters.
self.default_mask_bits_dict = {
@@ -277,11 +269,11 @@ def test_clustering(self):
cluster_params["mjd"] = np.array(self.stack.build_zeroed_times())

trjs = [
self._make_trajectory(10, 11, 1, 2, 100.0),
self._make_trajectory(10, 11, 10, 20, 100.0),
self._make_trajectory(40, 5, -1, 2, 100.0),
self._make_trajectory(5, 0, 1, 2, 100.0),
self._make_trajectory(5, 1, 1, 2, 100.0),
make_trajectory(x=10, y=11, vx=1, vy=2, lh=100.0),
make_trajectory(x=10, y=11, vx=10, vy=20, lh=100.0),
make_trajectory(x=40, y=5, vx=-1, vy=2, lh=100.0),
make_trajectory(x=5, y=0, vx=1, vy=2, lh=100.0),
make_trajectory(x=5, y=1, vx=1, vy=2, lh=100.0),
]

# Try clustering with positions, velocities, and angles.
@@ -306,15 +298,15 @@ def test_clustering(self):
self.assertEqual(results2.num_results(), 3)

def test_load_and_filter_results_lh(self):
# Create fake result trajectories with given initial likelihoods.
# Create fake result trajectories with given initial likelihoods. The 1st is
# filtered by max likelihood. The 4th and 5th are filtered by min likelihood.
trjs = [
self._make_trajectory(20, 20, 0, 0, 9000.0), # Filtered by max likelihood
self._make_trajectory(30, 30, 0, 0, 100.0),
self._make_trajectory(40, 40, 0, 0, 50.0),
self._make_trajectory(50, 50, 0, 0, 2.0), # Filtered by min likelihood
self._make_trajectory(60, 60, 0, 0, 1.0), # Filtered by min likelihood
make_trajectory(20, 20, 0, 0, 500.0, 9000.0, self.img_count),
make_trajectory(30, 30, 0, 0, 100.0, 100.0, self.img_count),
make_trajectory(40, 40, 0, 0, 50.0, 50.0, self.img_count),
make_trajectory(50, 50, 0, 0, 1.0, 2.0, self.img_count),
make_trajectory(60, 60, 0, 0, 1.0, 1.0, self.img_count),
]
fluxes = [500.0, 100.0, 50.0, 1.0, 0.1]

# Create fake images with the objects in them.
imlist = []
@@ -324,7 +316,7 @@ def test_load_and_filter_results_lh(self):

# Add the objects.
for j, trj in enumerate(trjs):
add_fake_object(im, trj.x, trj.y, fluxes[j], self.p)
add_fake_object(im, trj.x, trj.y, trj.flux, self.p)

# Append the image.
imlist.append(im)
@@ -334,15 +326,17 @@ def test_load_and_filter_results_lh(self):
search.set_results(trjs)

# Do the filtering.
self.config["num_obs"] = 5
kb_post_process = PostProcess(self.config, self.time_list)

results = kb_post_process.load_and_filter_results(
search,
10.0, # min likelihood
chunk_size=500000,
max_lh=1000.0,
)

# Only the middle two results should pass the filtering.
# Only two of the middle results should pass the filtering.
self.assertEqual(results.num_results(), 2)
self.assertEqual(results.results[0].trajectory.y, 30)
self.assertEqual(results.results[1].trajectory.y, 40)
39 changes: 39 additions & 0 deletions tests/test_stats_filters.py
Original file line number Diff line number Diff line change
@@ -97,6 +97,45 @@ def test_filter_valid_indices(self):
for i in range(self.rs.num_results()):
self.assertGreaterEqual(len(self.rs.results[i].valid_indices), 4)

def test_combined_stats_filter(self):
    """Check CombinedStatsFilter naming and its effect on the shared result set."""
    self.assertEqual(self.rs.num_results(), 10)

    stats_filter = CombinedStatsFilter(min_obs=4, min_lh=5.1)
    filter_name = stats_filter.get_filter_name()
    self.assertEqual(filter_name, "CombinedStats_4_5.1_to_inf")

    # Filter in place, then confirm every surviving row meets both limits.
    self.rs.apply_filter(stats_filter)
    self.assertEqual(self.rs.num_results(), 4)
    for kept_row in self.rs.results:
        self.assertGreaterEqual(len(kept_row.valid_indices), 4)
        self.assertGreaterEqual(kept_row.final_likelihood, 5.1)

def test_duration_filter(self):
    """Check DurationFilter naming and that it removes one of three tracks."""
    duration_filter = DurationFilter(self.times, 0.81)
    self.assertEqual(duration_filter.get_filter_name(), "Duration_0.81")

    candidates = ResultList(self.times, track_filtered=True)

    # A full track.
    full_row = ResultRow(Trajectory(), self.num_times)
    candidates.append_result(full_row)

    # A track restricted to every 4th observation.
    sparse_row = ResultRow(Trajectory(), self.num_times)
    sparse_row.filter_indices(list(range(0, self.num_times, 4)))
    candidates.append_result(sparse_row)

    # A track restricted to a short burst in the middle (indices 3 through 9).
    burst_row = ResultRow(Trajectory(), self.num_times)
    burst_row.filter_indices(list(range(3, 10)))
    candidates.append_result(burst_row)

    candidates.apply_filter(duration_filter)
    self.assertEqual(candidates.num_results(), 2)

    # Presumably the short burst is the row that gets dropped; the two
    # remaining rows are checked against the sizes of the first two tracks.
    survivors = candidates.results
    self.assertGreaterEqual(len(survivors[0].valid_indices), self.num_times)
    self.assertGreaterEqual(len(survivors[1].valid_indices), int(self.num_times / 4))


if __name__ == "__main__":
unittest.main()

0 comments on commit 0ff88d0

Please sign in to comment.