From e5f42e649b39728b9700c794034ba1c5da18174e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Wed, 3 Apr 2024 10:08:34 -0400 Subject: [PATCH 1/2] Make the 0.5 offset the default --- src/kbmod/fake_data/fake_data_creator.py | 4 +- src/kbmod/masking.py | 2 +- src/kbmod/search/common.h | 22 ++++++++--- src/kbmod/search/kernels/kernels.cu | 12 ++++-- src/kbmod/search/pydocs/common_docs.h | 30 +++++++++++++-- src/kbmod/search/stack_search.cpp | 5 +-- src/kbmod/search/stamp_creator.cpp | 2 +- tests/test_bilinear_interp.py | 4 +- tests/test_common.py | 27 +++++++++++--- tests/test_fake_data_creator.py | 4 +- tests/test_masking.py | 2 +- tests/test_search.py | 47 ++++++++++++------------ tests/test_trajectory_explorer.py | 4 +- 13 files changed, 108 insertions(+), 57 deletions(-) diff --git a/src/kbmod/fake_data/fake_data_creator.py b/src/kbmod/fake_data/fake_data_creator.py index bc48b5dea..3053a094a 100644 --- a/src/kbmod/fake_data/fake_data_creator.py +++ b/src/kbmod/fake_data/fake_data_creator.py @@ -227,8 +227,8 @@ def insert_object(self, trj): for i in range(self.num_times): dt = self.times[i] - t0 - px = trj.x + dt * trj.vx + 0.5 - py = trj.y + dt * trj.vy + 0.5 + px = trj.get_x_pos(dt) + py = trj.get_y_pos(dt) # Get the image for the timestep, add the object, and # re-set the image. This last step needs to be done diff --git a/src/kbmod/masking.py b/src/kbmod/masking.py index 752978d89..99d6b21f2 100644 --- a/src/kbmod/masking.py +++ b/src/kbmod/masking.py @@ -30,7 +30,7 @@ def mask_trajectory(trj, stack, r): for i in range(stack.img_count()): img = stack.get_single_image(i) time = img.get_obstime() - stack.get_single_image(0).get_obstime() - origin_of_mask = (int(trj.get_x_pos(time) + 0.5), int(trj.get_y_pos(time) + 0.5)) + origin_of_mask = (trj.get_x_index(time), trj.get_y_index(time)) for dy in range(-r, r + 1): for dx in range(-r, r + 1): diff --git a/src/kbmod/search/common.h b/src/kbmod/search/common.h index 3570d7bdc..476e0733d 100644 --- a/src/kbmod/search/common.h +++ b/src/kbmod/search/common.h @@ -53,9 +53,17 @@ struct Trajectory { // Whether the trajectory is valid. Used for on-GPU filtering. bool valid = true; - // Get pixel positions from a zero-shifted time. - float get_x_pos(float time) const { return x + time * vx; } - float get_y_pos(float time) const { return y + time * vy; } + // Get pixel positions from a zero-shifted time. Centered indicates whether + // the prediction starts from the center of the pixel (which it does in the search) + inline float get_x_pos(float time, bool centered = true) const { + return centered ? (x + time * vx + 0.5f) : (x + time * vx); + } + inline float get_y_pos(float time, bool centered = true) const { + return centered ? (y + time * vy + 0.5f) : (y + time * vy); + } + + inline int get_x_index(float time) const { return (int)floor(get_x_pos(time, true)); } + inline int get_y_index(float time) const { return (int)floor(get_y_pos(time, true)); } // A helper function to test if two trajectories are close in pixel space. bool is_close(Trajectory &trj_b, float pos_thresh, float vel_thresh) { @@ -176,8 +184,12 @@ static void trajectory_bindings(py::module &m) { .def_readwrite("y", &tj::y) .def_readwrite("obs_count", &tj::obs_count) .def_readwrite("valid", &tj::valid) - .def("get_x_pos", &tj::get_x_pos, pydocs::DOC_Trajectory_get_x_pos) - .def("get_y_pos", &tj::get_y_pos, pydocs::DOC_Trajectory_get_y_pos) + .def("get_x_pos", &tj::get_x_pos, py::arg("time"), py::arg("centered") = true, + pydocs::DOC_Trajectory_get_x_pos) + .def("get_y_pos", &tj::get_y_pos, py::arg("time"), py::arg("centered") = true, + pydocs::DOC_Trajectory_get_y_pos) + .def("get_x_index", &tj::get_x_index, pydocs::DOC_Trajectory_get_x_index) + .def("get_y_index", &tj::get_y_index, pydocs::DOC_Trajectory_get_y_index) .def("is_close", &tj::is_close, pydocs::DOC_Trajectory_is_close) .def("__repr__", [](const tj &t) { return "Trajectory(" + t.to_string() + ")"; }) .def("__str__", &tj::to_string) diff --git a/src/kbmod/search/kernels/kernels.cu b/src/kbmod/search/kernels/kernels.cu index 0880865aa..497341497 100644 --- a/src/kbmod/search/kernels/kernels.cu +++ b/src/kbmod/search/kernels/kernels.cu @@ -32,6 +32,10 @@ namespace search { __host__ __device__ bool device_pixel_valid(float value) { return isfinite(value); } +__host__ __device__ int predict_index(float pos0, float vel0, float time) { + return (int)(floor(pos0 + vel0 * time + 0.5f)); +} + __host__ __device__ PsiPhi read_encoded_psi_phi(PsiPhiArrayMeta ¶ms, void *psi_phi_vect, int time, int row, int col) { // Bounds checking. @@ -148,8 +152,8 @@ extern "C" __device__ __host__ void evaluateTrajectory(PsiPhiArrayMeta psi_phi_m for (int i = 0; i < psi_phi_meta.num_times; ++i) { // Predict the trajectory's position. float curr_time = image_times[i]; - int current_x = candidate->x + int(candidate->vx * curr_time + 0.5); - int current_y = candidate->y + int(candidate->vy * curr_time + 0.5); + int current_x = predict_index(candidate->x, candidate->vx, curr_time); + int current_y = predict_index(candidate->y, candidate->vy, curr_time); // Get the Psi and Phi pixel values. Skip invalid values, such as those marked NaN or NO_DATA. PsiPhi pixel_vals = read_encoded_psi_phi(psi_phi_meta, psi_phi_vect, i, current_y, current_x); @@ -361,8 +365,8 @@ __global__ void deviceGetCoaddStamp(int num_images, int width, int height, float // Predict the trajectory's position. float curr_time = image_times[t]; - int current_x = int(trj.x + trj.vx * curr_time); - int current_y = int(trj.y + trj.vy * curr_time); + int current_x = predict_index(trj.x, trj.vx, curr_time); + int current_y = predict_index(trj.y, trj.vy, curr_time); // Get the stamp and add it to the list of values. int img_x = current_x - params.radius + stamp_x; diff --git a/src/kbmod/search/pydocs/common_docs.h b/src/kbmod/search/pydocs/common_docs.h index 9f2e00157..ed512bede 100644 --- a/src/kbmod/search/pydocs/common_docs.h +++ b/src/kbmod/search/pydocs/common_docs.h @@ -34,6 +34,9 @@ static const auto DOC_Trajectory_get_x_pos = R"doc( ---------- time : `float` A zero shifted time. + centered : `bool` + Shift the prediction to be at the center of the pixel + (e.g. xp = x + vx * time + 0.5f). Default = True. Returns ------- @@ -48,6 +51,9 @@ static const auto DOC_Trajectory_get_y_pos = R"doc( ---------- time : `float` A zero shifted time. + centered : `bool` + Shift the prediction to be at the center of the pixel + (e.g. xp = x + vx * time + 0.5f). Default = True. Returns ------- @@ -55,8 +61,24 @@ static const auto DOC_Trajectory_get_y_pos = R"doc( The predicted y position (in pixels). )doc"; -static const auto DOC_Trajectory_get_pos = R"doc( - Returns the predicted (x, y) position of the trajectory. +static const auto DOC_Trajectory_get_x_index = R"doc( + Returns the predicted x position of the trajectory as an integer + (column) index. + + Parameters + ---------- + time : `float` + A zero shifted time. + + Returns + ------- + `int` + The predicted column index. + )doc"; + +static const auto DOC_Trajectory_get_y_index = R"doc( + Returns the predicted x position of the trajectory as an integer + (row) index. Parameters ---------- @@ -65,8 +87,8 @@ static const auto DOC_Trajectory_get_pos = R"doc( Returns ------- - `PixelPos` - The predicted (x, y) position (in pixels). + `int` + The predicted row index. )doc"; static const auto DOC_Trajectory_is_close = R"doc( diff --git a/src/kbmod/search/stack_search.cpp b/src/kbmod/search/stack_search.cpp index 086015d42..15fdefdb4 100644 --- a/src/kbmod/search/stack_search.cpp +++ b/src/kbmod/search/stack_search.cpp @@ -222,10 +222,7 @@ std::vector StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool e float time = psi_phi_array.read_time(i); // Query the center of the predicted location's pixel. - Point pred_pt = {trj.get_x_pos(time) + 0.5f, trj.get_y_pos(time) + 0.5f}; - Index pred_idx = pred_pt.to_index(); - PsiPhi psi_phi_val = psi_phi_array.read_psi_phi(i, pred_idx.i, pred_idx.j); - + PsiPhi psi_phi_val = psi_phi_array.read_psi_phi(i, trj.get_y_index(time), trj.get_x_index(time)); float value = (extract_psi) ? psi_phi_val.psi : psi_phi_val.phi; if (pixel_value_valid(value)) { result[i] = value; diff --git a/src/kbmod/search/stamp_creator.cpp b/src/kbmod/search/stamp_creator.cpp index 8a33b16e8..90b41b8a9 100644 --- a/src/kbmod/search/stamp_creator.cpp +++ b/src/kbmod/search/stamp_creator.cpp @@ -23,7 +23,7 @@ std::vector StampCreator::create_stamps(ImageStack& stack, const Traje if (use_all_stamps || use_index[i]) { // Calculate the trajectory position. float time = stack.get_zeroed_time(i); - Point pos{trj.x + time * trj.vx, trj.y + time * trj.vy}; + Point pos{trj.get_x_pos(time), trj.get_y_pos(time)}; RawImage& img = stack.get_single_image(i).get_science(); stamps.push_back(img.create_stamp(pos, radius, keep_no_data)); } diff --git a/tests/test_bilinear_interp.py b/tests/test_bilinear_interp.py index be18e9965..8a0788159 100644 --- a/tests/test_bilinear_interp.py +++ b/tests/test_bilinear_interp.py @@ -4,16 +4,18 @@ from kbmod.fake_data.fake_data_creator import add_fake_object, make_fake_layered_image import kbmod.search as kb +from kbmod.trajectory_utils import make_trajectory class test_bilinear_interp(unittest.TestCase): def setUp(self): self.im_count = 5 + self.trj = make_trajectory(2, 2, 0.5, 0.5) p = kb.PSF(0.05) self.images = [] for c in range(self.im_count): im = make_fake_layered_image(10, 10, 0.0, 1.0, c, p) - add_fake_object(im, 2 + c * 0.5 + 0.5, 2 + c * 0.5 + 0.5, 1, p) + add_fake_object(im, self.trj.get_x_pos(c), self.trj.get_y_pos(c), 1, p) self.images.append(im) def test_pixels(self): diff --git a/tests/test_common.py b/tests/test_common.py index 5e0665529..5011f614a 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -16,12 +16,27 @@ def test_pixel_value_valid(self): def test_trajectory_predict(self): trj = make_trajectory(x=5, y=10, vx=2.0, vy=-1.0) - self.assertEqual(trj.get_x_pos(0.0), 5.0) - self.assertEqual(trj.get_y_pos(0.0), 10.0) - self.assertEqual(trj.get_x_pos(1.0), 7.0) - self.assertEqual(trj.get_y_pos(1.0), 9.0) - self.assertEqual(trj.get_x_pos(2.0), 9.0) - self.assertEqual(trj.get_y_pos(2.0), 8.0) + # With centered=false the trajectories start at the pixel edge. + self.assertEqual(trj.get_x_pos(0.0, False), 5.0) + self.assertEqual(trj.get_y_pos(0.0, False), 10.0) + self.assertEqual(trj.get_x_pos(1.0, False), 7.0) + self.assertEqual(trj.get_y_pos(1.0, False), 9.0) + self.assertEqual(trj.get_x_pos(2.0, False), 9.0) + self.assertEqual(trj.get_y_pos(2.0, False), 8.0) + + # Centering moves things by half a pixel. + self.assertEqual(trj.get_x_pos(0.0), 5.5) + self.assertEqual(trj.get_y_pos(0.0), 10.5) + self.assertEqual(trj.get_x_pos(1.0), 7.5) + self.assertEqual(trj.get_y_pos(1.0), 9.5) + self.assertEqual(trj.get_x_pos(2.0), 9.5) + self.assertEqual(trj.get_y_pos(2.0), 8.5) + + # Predicting the index gives a floored integer of the centered prediction. + self.assertEqual(trj.get_x_index(0.0), 5) + self.assertEqual(trj.get_y_index(0.0), 10) + self.assertEqual(trj.get_x_index(1.0), 7) + self.assertEqual(trj.get_y_index(1.0), 9) def test_trajectory_is_close(self): trj = make_trajectory(x=5, y=10, vx=2.0, vy=-1.0) diff --git a/tests/test_fake_data_creator.py b/tests/test_fake_data_creator.py index b58c20ea2..b6fa188a1 100644 --- a/tests/test_fake_data_creator.py +++ b/tests/test_fake_data_creator.py @@ -56,8 +56,8 @@ def test_insert_object(self): t0 = ds.stack.get_single_image(0).get_obstime() for i in range(ds.stack.img_count()): dt = ds.stack.get_single_image(i).get_obstime() - t0 - px = int(trj.x + dt * trj.vx + 0.5) - py = int(trj.y + dt * trj.vy + 0.5) + px = trj.get_x_index(dt) + py = trj.get_y_index(dt) # Check the trajectory stays in the image. self.assertGreaterEqual(px, 0) diff --git a/tests/test_masking.py b/tests/test_masking.py index c2c66374d..bdb2000b2 100644 --- a/tests/test_masking.py +++ b/tests/test_masking.py @@ -38,7 +38,7 @@ def test_apply_trajectory_mask(self): for i in range(self.img_count): time = self.stack.get_single_image(i).get_obstime() - self.stack.get_single_image(0).get_obstime() - origin_of_mask = (trj.get_x_pos(time), trj.get_y_pos(time)) + origin_of_mask = (trj.get_x_index(time), trj.get_y_index(time)) msk = masked_stack.get_single_image(i).get_mask() for x in range(self.dim_x): diff --git a/tests/test_search.py b/tests/test_search.py index 36cbeb94b..fb0d85c25 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -13,7 +13,7 @@ class test_search(unittest.TestCase): def setUp(self): # test pass thresholds - self.pixel_error = 0 + self.pixel_error = 1 self.velocity_error = 0.1 self.flux_error = 0.15 @@ -64,8 +64,8 @@ def setUp(self): ) add_fake_object( im, - self.start_x + time * self.vxel + 0.5, - self.start_y + time * self.vyel + 0.5, + self.trj.get_x_pos(time), + self.trj.get_y_pos(time), self.object_flux, self.p, ) @@ -88,8 +88,8 @@ def setUp(self): self.params.center_thresh = 0.03 self.params.peak_offset_x = 1.5 self.params.peak_offset_y = 1.5 - self.params.m01_limit = 0.6 - self.params.m10_limit = 0.6 + self.params.m01_limit = 1.0 + self.params.m10_limit = 1.0 self.params.m11_limit = 2.0 self.params.m02_limit = 35.5 self.params.m20_limit = 35.5 @@ -277,8 +277,8 @@ def test_results_off_chip(self): ) add_fake_object( im, - trj.x + time * trj.vx + 0.5, - trj.y + time * trj.vy + 0.5, + trj.get_x_index(time), + trj.get_y_index(time), self.object_flux, self.p, ) @@ -310,9 +310,8 @@ def test_sci_viz_stamps(self): self.assertEqual(sci_stamps[i].height, 5) # Compute the interpolated pixel value at the projected location. - t = times[i] - x = int(self.trj.x + self.trj.vx * t) - y = int(self.trj.y + self.trj.vy * t) + x = self.trj.get_x_index(times[i]) + y = self.trj.get_y_index(times[i]) pixVal = self.imlist[i].get_science().get_pixel(y, x) if not pixel_value_valid(pixVal): pivVal = 0.0 @@ -332,8 +331,8 @@ def test_stacked_sci(self): sum_middle = 0.0 for i in range(self.img_count): t = times[i] - x = int(self.trj.x + self.trj.vx * t) - y = int(self.trj.y + self.trj.vy * t) + x = self.trj.get_x_index(t) + y = self.trj.get_y_index(t) pixVal = self.imlist[i].get_science().get_pixel(y, x) if not pixel_value_valid(pixVal): pivVal = 0.0 @@ -364,8 +363,8 @@ def test_median_stamps_trj(self): pix_values1 = [] for i in range(self.img_count): t = times[i] - x = int(self.trj.x + self.trj.vx * t) - y = int(self.trj.y + self.trj.vy * t) + x = self.trj.get_x_index(t) + y = self.trj.get_y_index(t) pixVal = self.imlist[i].get_science().get_pixel(y, x) if pixel_value_valid(pixVal) and goodIdx[0][i] == 1: pix_values0.append(pixVal) @@ -421,8 +420,8 @@ def test_mean_stamps_trj(self): pix_count1 = 0.0 for i in range(self.img_count): t = times[i] - x = int(self.trj.x + self.trj.vx * t) - y = int(self.trj.y + self.trj.vy * t) + x = self.trj.get_x_index(t) + y = self.trj.get_y_index(t) pixVal = self.imlist[i].get_science().get_pixel(y, x) if pixel_value_valid(pixVal) and goodIdx[0][i] == 1: pix_sum0 += pixVal @@ -680,8 +679,8 @@ def test_coadd_cpu(self): pix_vals = [] for i in range(self.img_count): t = times[i] - x = int(self.trj.x + self.trj.vx * t) + x_offset - y = int(self.trj.y + self.trj.vy * t) + y_offset + x = self.trj.get_x_index(t) + x_offset + y = self.trj.get_y_index(t) + y_offset pixVal = self.imlist[i].get_science().get_pixel(y, x) if pixel_value_valid(pixVal): pix_sum += pixVal @@ -737,8 +736,8 @@ def test_coadd_gpu(self): pix_vals = [] for i in range(self.img_count): t = times[i] - x = int(self.trj.x + self.trj.vx * t) + x_offset - y = int(self.trj.y + self.trj.vy * t) + y_offset + x = self.trj.get_x_index(t) + x_offset + y = self.trj.get_y_index(t) + y_offset pixVal = self.imlist[i].get_science().get_pixel(y, x) if pixel_value_valid(pixVal): pix_sum += pixVal @@ -786,8 +785,8 @@ def test_coadd_cpu_use_inds(self): count_1 = 0.0 for i in range(self.img_count): t = times[i] - x = int(self.trj.x + self.trj.vx * t) + x_offset - y = int(self.trj.y + self.trj.vy * t) + y_offset + x = self.trj.get_x_index(t) + x_offset + y = self.trj.get_y_index(t) + y_offset pixVal = self.imlist[i].get_science().get_pixel(y, x) if pixel_value_valid(pixVal) and inds[0][i] > 0: @@ -837,8 +836,8 @@ def test_coadd_gpu_use_inds(self): count_1 = 0.0 for i in range(self.img_count): t = times[i] - x = int(self.trj.x + self.trj.vx * t) + x_offset - y = int(self.trj.y + self.trj.vy * t) + y_offset + x = self.trj.get_x_index(t) + x_offset + y = self.trj.get_y_index(t) + y_offset pixVal = self.imlist[i].get_science().get_pixel(y, x) if pixel_value_valid(pixVal) and inds[0][i] > 0: diff --git a/tests/test_trajectory_explorer.py b/tests/test_trajectory_explorer.py index 44f75066b..f3a754765 100644 --- a/tests/test_trajectory_explorer.py +++ b/tests/test_trajectory_explorer.py @@ -35,8 +35,8 @@ def setUp(self): fake_ds.insert_object(self.trj) # Remove at least observation from the trajectory. - pred_x = int(self.x0 + fake_times[10] * self.vx + 0.5) - pred_y = int(self.y0 + fake_times[10] * self.vy + 0.5) + pred_x = self.trj.get_x_index(fake_times[10]) + pred_y = self.trj.get_y_index(fake_times[10]) sci_t10 = fake_ds.stack.get_single_image(10).get_science() for dy in [-1, 0, 1]: for dx in [-1, 0, 1]: From eb22a90ea41bbed8245b522209e3943697771aad Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 18 Apr 2024 14:28:15 -0400 Subject: [PATCH 2/2] Remove the BatchFilter abstract data type --- src/kbmod/filters/base_filter.py | 40 ------------------------- src/kbmod/filters/clustering_filters.py | 9 +++--- src/kbmod/result_list.py | 19 ------------ 3 files changed, 4 insertions(+), 64 deletions(-) diff --git a/src/kbmod/filters/base_filter.py b/src/kbmod/filters/base_filter.py index ff946b914..db62c32c7 100644 --- a/src/kbmod/filters/base_filter.py +++ b/src/kbmod/filters/base_filter.py @@ -37,43 +37,3 @@ def keep_row(self, row: ResultRow): An indicator of whether to keep the row. """ pass - - -class BatchFilter(abc.ABC): - """The base class for derived filters on the ResultList - that operate on the results in a single batch. - - Batching should be used when the user needs greater control - over how the filter is run, such as using aggregate statistics - from all candidates or running batch computations on GPUs. - """ - - def __init__(self, *args, **kwargs): - pass - - @abc.abstractmethod - def get_filter_name(self): - """Get the name of the filter. - - Returns - ------- - str - The filter name. - """ - pass - - @abc.abstractmethod - def keep_indices(self, results: ResultList): - """Determine which of the ResultList's indices to keep. - - Parameters - ---------- - results: ResultList - The set of results to filter. - - Returns - ------- - list - A list of indices (int) indicating which rows to keep. - """ - pass diff --git a/src/kbmod/filters/clustering_filters.py b/src/kbmod/filters/clustering_filters.py index d271868ea..6d93f142f 100644 --- a/src/kbmod/filters/clustering_filters.py +++ b/src/kbmod/filters/clustering_filters.py @@ -1,11 +1,10 @@ import numpy as np from sklearn.cluster import DBSCAN -from kbmod.filters.base_filter import BatchFilter from kbmod.result_list import ResultList, ResultRow -class DBSCANFilter(BatchFilter): +class DBSCANFilter: """Cluster the candidates using DBSCAN and only keep a single representative trajectory from each cluster.""" @@ -17,8 +16,6 @@ def __init__(self, eps, *args, **kwargs): eps : `float` The clustering threshold. """ - super().__init__(*args, **kwargs) - self.eps = eps self.cluster_type = "" self.cluster_args = dict(eps=self.eps, min_samples=1, n_jobs=-1) @@ -272,4 +269,6 @@ def apply_clustering(result_list, cluster_params): filt = ClusterMidPosFilter(**cluster_params) else: raise ValueError(f"Unknown clustering type: {cluster_type}") - result_list.apply_batch_filter(filt) + + indices_to_keep = filt.keep_indices(result_list) + result_list.filter_results(indices_to_keep, filt.get_filter_name()) diff --git a/src/kbmod/result_list.py b/src/kbmod/result_list.py index 5b3e158c5..fdd3fa3b3 100644 --- a/src/kbmod/result_list.py +++ b/src/kbmod/result_list.py @@ -800,25 +800,6 @@ def apply_filter(self, filter_obj, num_threads=1): return self - def apply_batch_filter(self, filter_obj): - """Apply the given batch filter object to the ResultList. - - Modifies the ResultList in place. - - Parameters - ---------- - filter_obj : BatchFilter - The filtering object to use. - - Returns - ------- - self : ResultList - Returns a reference to itself to allow chaining. - """ - indices_to_keep = filter_obj.keep_indices(self) - self.filter_results(indices_to_keep, filter_obj.get_filter_name()) - return self - def get_filtered(self, label=None): """Get the results filtered at a given stage or all stages.