Skip to content

Commit

Permalink
Generalize clustering code so it takes Results objects
Browse files Browse the repository at this point in the history
Duplicates the tests. We can remove the extra tests when re remove the ResultList object.
  • Loading branch information
jeremykubica committed Apr 17, 2024
1 parent ba7dc44 commit c77b95c
Show file tree
Hide file tree
Showing 2 changed files with 247 additions and 43 deletions.
97 changes: 67 additions & 30 deletions src/kbmod/filters/clustering_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@

from kbmod.filters.base_filter import BatchFilter
from kbmod.result_list import ResultList, ResultRow
from kbmod.results import Results
import kbmod.search as kb

logger = kb.Logging.getLogger(__name__)


class DBSCANFilter(BatchFilter):
Expand Down Expand Up @@ -33,12 +37,12 @@ def get_filter_name(self):
"""
return f"DBSCAN_{self.cluster_type}_{self.eps}"

def _build_clustering_data(self, result_list):
def _build_clustering_data(self, result_data):
"""Build the specific data set for this clustering approach.
Parameters
----------
result_list: ResultList
result_data: `Results`, `ResultList`, or `list[Trajectory]`
The set of results to filter.
Returns
Expand All @@ -49,20 +53,20 @@ def _build_clustering_data(self, result_list):
"""
raise NotImplementedError()

def keep_indices(self, result_list: ResultList):
"""Determine which of the ResultList's indices to keep.
def keep_indices(self, result_data):
"""Determine which of the results's indices to keep.
Parameters
----------
result_list: ResultList
result_data: `Results` or `ResultList`
The set of results to filter.
Returns
-------
list
`list`
A list of indices (int) indicating which rows to keep.
"""
data = self._build_clustering_data(result_list)
data = self._build_clustering_data(result_data)

# Set up the clustering algorithm
cluster = DBSCAN(**self.cluster_args)
Expand Down Expand Up @@ -98,12 +102,12 @@ def __init__(self, eps, height, width, *args, **kwargs):
self.width = width
self.cluster_type = "position"

def _build_clustering_data(self, result_list):
def _build_clustering_data(self, result_data):
"""Build the specific data set for this clustering approach.
Parameters
----------
result_list: ResultList
result_data: `Results` or `ResultList`
The set of results to filter.
Returns
Expand All @@ -112,8 +116,14 @@ def _build_clustering_data(self, result_list):
The N x D matrix to cluster where N is the number of results
and D is the number of attributes.
"""
x_arr = np.array(result_list.get_result_values("trajectory.x")) / self.width
y_arr = np.array(result_list.get_result_values("trajectory.y")) / self.height
if type(result_data) is ResultList:
x_arr = np.array(result_data.get_result_values("trajectory.x")) / self.width
y_arr = np.array(result_data.get_result_values("trajectory.y")) / self.height
elif type(result_data) is Results:
x_arr = np.array(result_data["x"]) / self.width
y_arr = np.array(result_data["y"]) / self.height
else:
raise TypeError("Unknown data type for clustering.")
return np.array([x_arr, y_arr])


Expand Down Expand Up @@ -157,12 +167,12 @@ def __init__(self, eps, height, width, vel_lims, ang_lims, *args, **kwargs):

self.cluster_type = "all"

def _build_clustering_data(self, result_list):
def _build_clustering_data(self, result_data):
"""Build the specific data set for this clustering approach.
Parameters
----------
result_list: ResultList
result_data: `Results` or `ResultList`
The set of results to filter.
Returns
Expand All @@ -172,10 +182,19 @@ def _build_clustering_data(self, result_list):
and D is the number of attributes.
"""
# Create arrays of each the trajectories information.
x_arr = np.array(result_list.get_result_values("trajectory.x"))
y_arr = np.array(result_list.get_result_values("trajectory.y"))
vx_arr = np.array(result_list.get_result_values("trajectory.vx"))
vy_arr = np.array(result_list.get_result_values("trajectory.vy"))
if type(result_data) is ResultList:
x_arr = np.array(result_data.get_result_values("trajectory.x"))
y_arr = np.array(result_data.get_result_values("trajectory.y"))
vx_arr = np.array(result_data.get_result_values("trajectory.vx"))
vy_arr = np.array(result_data.get_result_values("trajectory.vy"))
elif type(result_data) is Results:
x_arr = np.array(result_data["x"])
y_arr = np.array(result_data["y"])
vx_arr = np.array(result_data["vx"])
vy_arr = np.array(result_data["vy"])
else:
raise TypeError("Unknown data type for clustering.")

vel_arr = np.sqrt(np.square(vx_arr) + np.square(vy_arr))
ang_arr = np.arctan2(vy_arr, vx_arr)

Expand Down Expand Up @@ -211,12 +230,12 @@ def __init__(self, eps, height, width, times, *args, **kwargs):
self.midtime = np.median(zeroed_times)
self.cluster_type = "midpoint"

def _build_clustering_data(self, result_list):
def _build_clustering_data(self, result_data):
"""Build the specific data set for this clustering approach.
Parameters
----------
result_list: ResultList
result_data: `Results` or `ResultList`
The set of results to filter.
Returns
Expand All @@ -226,10 +245,18 @@ def _build_clustering_data(self, result_list):
and D is the number of attributes.
"""
# Create arrays of each the trajectories information.
x_arr = np.array(result_list.get_result_values("trajectory.x"))
y_arr = np.array(result_list.get_result_values("trajectory.y"))
vx_arr = np.array(result_list.get_result_values("trajectory.vx"))
vy_arr = np.array(result_list.get_result_values("trajectory.vy"))
if type(result_data) is ResultList:
x_arr = np.array(result_data.get_result_values("trajectory.x"))
y_arr = np.array(result_data.get_result_values("trajectory.y"))
vx_arr = np.array(result_data.get_result_values("trajectory.vx"))
vy_arr = np.array(result_data.get_result_values("trajectory.vy"))
elif type(result_data) is Results:
x_arr = np.array(result_data["x"])
y_arr = np.array(result_data["y"])
vx_arr = np.array(result_data["vx"])
vy_arr = np.array(result_data["vy"])
else:
raise TypeError("Unknown data type for clustering.")

# Scale the values.
scaled_mid_x = (x_arr + self.midtime * vx_arr) / self.width
Expand All @@ -238,30 +265,31 @@ def _build_clustering_data(self, result_list):
return np.array([scaled_mid_x, scaled_mid_y])


def apply_clustering(result_list, cluster_params):
def apply_clustering(result_data, cluster_params):
"""This function clusters results that have similar trajectories.
Parameters
----------
result_list : `ResultList`
The values from trajectories. This data gets modified directly by
result_data: `Results` or `ResultList`
The set of results to filter. This data gets modified directly by
the filtering.
cluster_params : dict
Contains values concerning the image and search settings including:
cluster_type, eps, height, width, vel_lims, ang_lims, and mjd.
Raises
------
ValueError if the parameters are not valid.
Raises a ValueError if the parameters are not valid.
Raises a TypeError if ``result_data`` is of an unsupported type.
"""
if "cluster_type" not in cluster_params:
raise ValueError("Missing cluster_type parameter")
cluster_type = cluster_params["cluster_type"]

# Skip clustering if there is nothing to cluster.
if result_list.num_results() == 0:
if len(result_data) == 0:
return
print("Clustering %i results" % result_list.num_results(), flush=True)
logger.info(f"Clustering {len(result_data)} results using {cluster_type}")

# Do the clustering and the filtering.
if cluster_type == "all":
Expand All @@ -272,4 +300,13 @@ def apply_clustering(result_list, cluster_params):
filt = ClusterMidPosFilter(**cluster_params)
else:
raise ValueError(f"Unknown clustering type: {cluster_type}")
result_list.apply_batch_filter(filt)

# Do the actual filtering.
indices_to_keep = filt.keep_indices(result_data)
if type(result_data) is ResultList:
result_data.apply_batch_filter(filt)
elif type(result_data) is Results:
indices_to_keep = filt.keep_indices(result_data)
result_data.filter_by_index(indices_to_keep, filt.get_filter_name())
else:
raise TypeError("Unknown data type for clustering.")
Loading

0 comments on commit c77b95c

Please sign in to comment.