From eb22a90ea41bbed8245b522209e3943697771aad Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:28:15 -0400 Subject: [PATCH] Remove the BatchFilter abstract data type --- src/kbmod/filters/base_filter.py | 40 ------------------------- src/kbmod/filters/clustering_filters.py | 9 +++--- src/kbmod/result_list.py | 19 ------------ 3 files changed, 4 insertions(+), 64 deletions(-) diff --git a/src/kbmod/filters/base_filter.py b/src/kbmod/filters/base_filter.py index ff946b914..db62c32c7 100644 --- a/src/kbmod/filters/base_filter.py +++ b/src/kbmod/filters/base_filter.py @@ -37,43 +37,3 @@ def keep_row(self, row: ResultRow): An indicator of whether to keep the row. """ pass - - -class BatchFilter(abc.ABC): - """The base class for derived filters on the ResultList - that operate on the results in a single batch. - - Batching should be used when the user needs greater control - over how the filter is run, such as using aggregate statistics - from all candidates or running batch computations on GPUs. - """ - - def __init__(self, *args, **kwargs): - pass - - @abc.abstractmethod - def get_filter_name(self): - """Get the name of the filter. - - Returns - ------- - str - The filter name. - """ - pass - - @abc.abstractmethod - def keep_indices(self, results: ResultList): - """Determine which of the ResultList's indices to keep. - - Parameters - ---------- - results: ResultList - The set of results to filter. - - Returns - ------- - list - A list of indices (int) indicating which rows to keep. - """ - pass diff --git a/src/kbmod/filters/clustering_filters.py b/src/kbmod/filters/clustering_filters.py index d271868ea..6d93f142f 100644 --- a/src/kbmod/filters/clustering_filters.py +++ b/src/kbmod/filters/clustering_filters.py @@ -1,11 +1,10 @@ import numpy as np from sklearn.cluster import DBSCAN -from kbmod.filters.base_filter import BatchFilter from kbmod.result_list import ResultList, ResultRow -class DBSCANFilter(BatchFilter): +class DBSCANFilter: """Cluster the candidates using DBSCAN and only keep a single representative trajectory from each cluster.""" @@ -17,8 +16,6 @@ def __init__(self, eps, *args, **kwargs): eps : `float` The clustering threshold. """ - super().__init__(*args, **kwargs) - self.eps = eps self.cluster_type = "" self.cluster_args = dict(eps=self.eps, min_samples=1, n_jobs=-1) @@ -272,4 +269,6 @@ def apply_clustering(result_list, cluster_params): filt = ClusterMidPosFilter(**cluster_params) else: raise ValueError(f"Unknown clustering type: {cluster_type}") - result_list.apply_batch_filter(filt) + + indices_to_keep = filt.keep_indices(result_list) + result_list.filter_results(indices_to_keep, filt.get_filter_name()) diff --git a/src/kbmod/result_list.py b/src/kbmod/result_list.py index 5b3e158c5..fdd3fa3b3 100644 --- a/src/kbmod/result_list.py +++ b/src/kbmod/result_list.py @@ -800,25 +800,6 @@ def apply_filter(self, filter_obj, num_threads=1): return self - def apply_batch_filter(self, filter_obj): - """Apply the given batch filter object to the ResultList. - - Modifies the ResultList in place. - - Parameters - ---------- - filter_obj : BatchFilter - The filtering object to use. - - Returns - ------- - self : ResultList - Returns a reference to itself to allow chaining. - """ - indices_to_keep = filter_obj.keep_indices(self) - self.filter_results(indices_to_keep, filter_obj.get_filter_name()) - return self - def get_filtered(self, label=None): """Get the results filtered at a given stage or all stages.