Skip to content

Commit

Permalink
Merge pull request #561 from dirac-institute/batch
Browse files Browse the repository at this point in the history
Remove the BatchFilter abstract data type
  • Loading branch information
jeremykubica authored Apr 18, 2024
2 parents ae4b2e4 + eb22a90 commit 6011686
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 64 deletions.
40 changes: 0 additions & 40 deletions src/kbmod/filters/base_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,43 +37,3 @@ def keep_row(self, row: ResultRow):
An indicator of whether to keep the row.
"""
pass


class BatchFilter(abc.ABC):
    """Abstract interface for filters that judge a whole ResultList at once.

    Use a batch filter when the decision needs more control over how the
    filter runs than a per-row check gives, for example when it relies on
    aggregate statistics over all candidates or on batched GPU computation.
    """

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments; the base class keeps no state."""

    @abc.abstractmethod
    def get_filter_name(self):
        """Report the name of this filter.

        Returns
        -------
        str
            The filter name.
        """

    @abc.abstractmethod
    def keep_indices(self, results: ResultList):
        """Select the indices of the ResultList rows that should be kept.

        Parameters
        ----------
        results : ResultList
            The set of results to filter.

        Returns
        -------
        list
            The indices (int) of the rows to keep.
        """
9 changes: 4 additions & 5 deletions src/kbmod/filters/clustering_filters.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import numpy as np
from sklearn.cluster import DBSCAN

from kbmod.filters.base_filter import BatchFilter
from kbmod.result_list import ResultList, ResultRow


class DBSCANFilter(BatchFilter):
class DBSCANFilter:
"""Cluster the candidates using DBSCAN and only keep a
single representative trajectory from each cluster."""

Expand All @@ -17,8 +16,6 @@ def __init__(self, eps, *args, **kwargs):
eps : `float`
The clustering threshold.
"""
super().__init__(*args, **kwargs)

self.eps = eps
self.cluster_type = ""
self.cluster_args = dict(eps=self.eps, min_samples=1, n_jobs=-1)
Expand Down Expand Up @@ -272,4 +269,6 @@ def apply_clustering(result_list, cluster_params):
filt = ClusterMidPosFilter(**cluster_params)
else:
raise ValueError(f"Unknown clustering type: {cluster_type}")
result_list.apply_batch_filter(filt)

indices_to_keep = filt.keep_indices(result_list)
result_list.filter_results(indices_to_keep, filt.get_filter_name())
19 changes: 0 additions & 19 deletions src/kbmod/result_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -800,25 +800,6 @@ def apply_filter(self, filter_obj, num_threads=1):

return self

def apply_batch_filter(self, filter_obj):
    """Run a batch filter over this ResultList, modifying it in place.

    The filter is asked which indices to keep and what it is called; both
    answers are then handed to ``filter_results``.

    Parameters
    ----------
    filter_obj : BatchFilter
        The batch filter to run. It must provide ``keep_indices`` and
        ``get_filter_name``.

    Returns
    -------
    self : ResultList
        This same object, so that calls can be chained.
    """
    kept = filter_obj.keep_indices(self)
    label = filter_obj.get_filter_name()
    self.filter_results(kept, label)
    return self

def get_filtered(self, label=None):
"""Get the results filtered at a given stage or all stages.
Expand Down

0 comments on commit 6011686

Please sign in to comment.