Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade filter logic #399

Merged
merged 3 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/kbmod/analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .file_utils import *
from .filters.clustering_filters import DBSCANFilter
from .filters.stats_filters import LHFilter, NumObsFilter
from .filters.stats_filters import CombinedStatsFilter
from .filters.sigma_g_filter import apply_clipped_sigma_g, SigmaGClipping
from .result_list import ResultList, ResultRow

Expand All @@ -25,6 +25,7 @@ class PostProcess:
def __init__(self, config, mjds):
self.coeff = None
self.num_cores = config["num_cores"]
self.num_obs = config["num_obs"]
self.sigmaG_lims = config["sigmaG_lims"]
self.eps = config["eps"]
self.cluster_type = config["cluster_type"]
Expand Down Expand Up @@ -75,6 +76,12 @@ def load_and_filter_results(
bnds = [25, 75]
clipper = SigmaGClipping(bnds[0], bnds[1], 2, self.clip_negative)

# Set up the combined stats filter.
if lh_level > 0.0:
stats_filter = CombinedStatsFilter(min_obs=self.num_obs, min_lh=lh_level)
else:
stats_filter = CombinedStatsFilter(min_obs=self.num_obs)

print("---------------------------------------")
print("Retrieving Results")
print("---------------------------------------")
Expand All @@ -94,6 +101,7 @@ def load_and_filter_results(
if trj.lh < lh_level:
likelihood_limit = True
break

if trj.lh < max_lh:
row = ResultRow(trj, len(self._mjds))
psi_curve = np.array(search.get_psi_curves(trj))
Expand All @@ -106,11 +114,7 @@ def load_and_filter_results(
print("Extracted batch of %i results for total of %i" % (batch_size, total_count))
if batch_size > 0:
apply_clipped_sigma_g(clipper, result_batch, self.num_cores)
result_batch.apply_filter(NumObsFilter(3))

# Apply the likelihood filter if one is provided.
if lh_level > 0.0:
result_batch.apply_filter(LHFilter(lh_level, None))
result_batch.apply_filter(stats_filter)

# Add the results to the final set.
keep.extend(result_batch)
Expand Down
103 changes: 103 additions & 0 deletions src/kbmod/filters/stats_filters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import numpy as np

from kbmod.filters.base_filter import RowFilter
from kbmod.result_list import ResultRow

Expand Down Expand Up @@ -112,3 +114,104 @@ def keep_row(self, row: ResultRow):
An indicator of whether to keep the row.
"""
return len(row.valid_indices) >= self.min_obs


class CombinedStatsFilter(RowFilter):
    """Filter rows on both likelihood range and observation count.

    A row is kept only when its final likelihood is not outside
    ``[min_lh, max_lh]`` and it has at least ``min_obs`` valid indices.
    """

    def __init__(self, min_obs=0, min_lh=-np.inf, max_lh=np.inf, *args, **kwargs):
        """Create a CombinedStatsFilter.

        Parameters
        ----------
        min_obs : ``int``
            The minimum number of valid observations a row must have.
        min_lh : ``float``
            The smallest final likelihood that is still accepted.
        max_lh : ``float``
            The largest final likelihood that is still accepted.
        *args, **kwargs
            Forwarded unchanged to the ``RowFilter`` constructor.
        """
        super().__init__(*args, **kwargs)

        self.min_obs = min_obs
        self.min_lh, self.max_lh = min_lh, max_lh

    def get_filter_name(self):
        """Build the name of the filter from its three thresholds.

        Returns
        -------
        str
            The filter name.
        """
        return f"CombinedStats_{self.min_obs}_{self.min_lh}_to_{self.max_lh}"

    def keep_row(self, row: ResultRow):
        """Decide whether a row passes the likelihood and observation checks.

        Parameters
        ----------
        row : ResultRow
            The row to evaluate.

        Returns
        -------
        bool
            ``True`` if the row should be kept, ``False`` otherwise.
        """
        lh = row.final_likelihood
        # The likelihood bounds are tested first; the observation count is
        # only examined for rows whose likelihood is not out of range.
        if not (lh < self.min_lh or lh > self.max_lh):
            return not (len(row.valid_indices) < self.min_obs)
        return False


class DurationFilter(RowFilter):
    """A filter for the amount of time covered by the trajectory."""

    def __init__(self, all_times, min_duration, *args, **kwargs):
        """Create a DurationFilter.

        Parameters
        ----------
        all_times : ``list``
            The time stamps in increasing order.
        min_duration : ``float``
            The minimum duration in days for a valid result.
        *args, **kwargs
            Forwarded unchanged to the ``RowFilter`` constructor.
        """
        super().__init__(*args, **kwargs)

        self.all_times = all_times
        self.min_duration = min_duration

    def get_filter_name(self):
        """Get the name of the filter.

        Returns
        -------
        str
            The filter name.
        """
        return f"Duration_{self.min_duration}"

    def keep_row(self, row: ResultRow):
        """Determine whether to keep an individual row based on the time
        spanned between its earliest and latest valid observations.

        Parameters
        ----------
        row : ResultRow
            The row to evaluate.

        Returns
        -------
        bool
            An indicator of whether to keep the row. Rows without any
            valid indices cover no time span and are always rejected.
        """
        valid_indices = row.valid_indices

        # np.min / np.max raise ValueError on an empty sequence, so a row
        # whose observations were all filtered out is rejected up front.
        if len(valid_indices) == 0:
            return False

        min_index = np.min(valid_indices)
        max_index = np.max(valid_indices)
        if self.all_times[max_index] - self.all_times[min_index] < self.min_duration:
            return False
        return True
40 changes: 17 additions & 23 deletions tests/test_analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,12 @@
from kbmod.fake_data_creator import add_fake_object
from kbmod.result_list import *
from kbmod.search import *
from kbmod.trajectory_utils import make_trajectory

from utils.utils_for_tests import get_absolute_data_path


class test_analysis_utils(unittest.TestCase):
def _make_trajectory(self, x0, y0, xv, yv, lh):
t = Trajectory()
t.x = x0
t.y = y0
t.vx = xv
t.vy = yv
t.lh = lh
return t

def setUp(self):
# The configuration parameters.
self.default_mask_bits_dict = {
Expand Down Expand Up @@ -277,11 +269,11 @@ def test_clustering(self):
cluster_params["mjd"] = np.array(self.stack.build_zeroed_times())

trjs = [
self._make_trajectory(10, 11, 1, 2, 100.0),
self._make_trajectory(10, 11, 10, 20, 100.0),
self._make_trajectory(40, 5, -1, 2, 100.0),
self._make_trajectory(5, 0, 1, 2, 100.0),
self._make_trajectory(5, 1, 1, 2, 100.0),
make_trajectory(x=10, y=11, vx=1, vy=2, lh=100.0),
make_trajectory(x=10, y=11, vx=10, vy=20, lh=100.0),
make_trajectory(x=40, y=5, vx=-1, vy=2, lh=100.0),
make_trajectory(x=5, y=0, vx=1, vy=2, lh=100.0),
make_trajectory(x=5, y=1, vx=1, vy=2, lh=100.0),
]

# Try clustering with positions, velocities, and angles.
Expand All @@ -306,15 +298,15 @@ def test_clustering(self):
self.assertEqual(results2.num_results(), 3)

def test_load_and_filter_results_lh(self):
# Create fake result trajectories with given initial likelihoods.
# Create fake result trajectories with given initial likelihoods. The 1st is
# filtered by max likelihood. The 4th and 5th are filtered by min likelihood.
trjs = [
self._make_trajectory(20, 20, 0, 0, 9000.0), # Filtered by max likelihood
self._make_trajectory(30, 30, 0, 0, 100.0),
self._make_trajectory(40, 40, 0, 0, 50.0),
self._make_trajectory(50, 50, 0, 0, 2.0), # Filtered by min likelihood
self._make_trajectory(60, 60, 0, 0, 1.0), # Filtered by min likelihood
make_trajectory(20, 20, 0, 0, 500.0, 9000.0, self.img_count),
make_trajectory(30, 30, 0, 0, 100.0, 100.0, self.img_count),
make_trajectory(40, 40, 0, 0, 50.0, 50.0, self.img_count),
make_trajectory(50, 50, 0, 0, 1.0, 2.0, self.img_count),
make_trajectory(60, 60, 0, 0, 1.0, 1.0, self.img_count),
]
fluxes = [500.0, 100.0, 50.0, 1.0, 0.1]

# Create fake images with the objects in them.
imlist = []
Expand All @@ -324,7 +316,7 @@ def test_load_and_filter_results_lh(self):

# Add the objects.
for j, trj in enumerate(trjs):
add_fake_object(im, trj.x, trj.y, fluxes[j], self.p)
add_fake_object(im, trj.x, trj.y, trj.flux, self.p)

# Append the image.
imlist.append(im)
Expand All @@ -334,15 +326,17 @@ def test_load_and_filter_results_lh(self):
search.set_results(trjs)

# Do the filtering.
self.config["num_obs"] = 5
kb_post_process = PostProcess(self.config, self.time_list)

results = kb_post_process.load_and_filter_results(
search,
10.0, # min likelihood
chunk_size=500000,
max_lh=1000.0,
)

# Only the middle two results should pass the filtering.
# Only two of the middle results should pass the filtering.
self.assertEqual(results.num_results(), 2)
self.assertEqual(results.results[0].trajectory.y, 30)
self.assertEqual(results.results[1].trajectory.y, 40)
Expand Down
39 changes: 39 additions & 0 deletions tests/test_stats_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,45 @@ def test_filter_valid_indices(self):
for i in range(self.rs.num_results()):
self.assertGreaterEqual(len(self.rs.results[i].valid_indices), 4)

def test_combined_stats_filter(self):
    """Check CombinedStatsFilter's name and the rows it leaves behind."""
    # The fixture is expected to start with ten results.
    self.assertEqual(self.rs.num_results(), 10)

    min_obs = 4
    min_lh = 5.1
    stats_filter = CombinedStatsFilter(min_obs=min_obs, min_lh=min_lh)
    self.assertEqual(stats_filter.get_filter_name(), "CombinedStats_4_5.1_to_inf")

    # Filter in place, then confirm every survivor meets both thresholds.
    self.rs.apply_filter(stats_filter)
    self.assertEqual(self.rs.num_results(), 4)
    for kept in self.rs.results:
        self.assertGreaterEqual(len(kept.valid_indices), min_obs)
        self.assertGreaterEqual(kept.final_likelihood, min_lh)

def test_duration_filter(self):
    """Check DurationFilter's name and that only long-enough tracks survive."""
    duration_filter = DurationFilter(self.times, 0.81)
    self.assertEqual(duration_filter.get_filter_name(), "Duration_0.81")

    # A track that keeps every observation.
    full_track = ResultRow(Trajectory(), self.num_times)

    # A track that keeps every 4th observation.
    sparse_track = ResultRow(Trajectory(), self.num_times)
    sparse_track.filter_indices(list(range(0, self.num_times, 4)))

    # A track with a short burst of observations in the middle.
    burst_track = ResultRow(Trajectory(), self.num_times)
    burst_track.filter_indices(list(range(3, 10)))

    res_list = ResultList(self.times, track_filtered=True)
    for track in (full_track, sparse_track, burst_track):
        res_list.append_result(track)

    # Two of the three tracks are expected to remain after filtering.
    res_list.apply_filter(duration_filter)
    self.assertEqual(res_list.num_results(), 2)

    kept_full, kept_sparse = res_list.results[0], res_list.results[1]
    self.assertGreaterEqual(len(kept_full.valid_indices), self.num_times)
    self.assertGreaterEqual(len(kept_sparse.valid_indices), int(self.num_times / 4))


if __name__ == "__main__":
unittest.main()