Skip to content

Commit

Permalink
stamp center filter c++ port (#407)
Browse files Browse the repository at this point in the history
* porting over first pass

* merge

* fix bug in array size check

* remove commented out code

* black changes

* fix benchmarking script

* code reorganization
  • Loading branch information
maxwest-uw authored Dec 20, 2023
1 parent 66165d5 commit 5ba2324
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 31 deletions.
21 changes: 15 additions & 6 deletions benchmarks/bench_filter_stamps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@

from kbmod.filters.stamp_filters import *
from kbmod.result_list import ResultRow
from kbmod.search import ImageStack, PSF, RawImage, StackSearch, StampParameters, StampType, Trajectory
from kbmod.search import (
ImageStack,
PSF,
RawImage,
StackSearch,
StampParameters,
StampType,
Trajectory,
StampCreator,
)


def setup_coadd_stamp(params):
Expand All @@ -30,7 +39,7 @@ def setup_coadd_stamp(params):
p = PSF(1.0)
psf_dim = p.get_dim()
psf_rad = p.get_radius()
for i in range(psf_dim):
for i in range(1, psf_dim):
for j in range(psf_dim):
stamp.set_pixel(
(params.radius - 1) - psf_rad + i, # x is one pixel off center
Expand All @@ -45,19 +54,19 @@ def run_search_benchmark(params):
stamp = setup_coadd_stamp(params)

# Create an empty search stack.
im_stack = ImageStack([])
search = StackSearch(im_stack)
# im_stack = ImageStack([])
sc = StampCreator()

# Do the timing runs.
tmr = timeit.Timer(stmt="search.filter_stamp(stamp, params)", globals=locals())
tmr = timeit.Timer(stmt="sc.filter_stamp(stamp, params)", globals=locals())
res_time = np.mean(tmr.repeat(repeat=10, number=20))
return res_time


def run_row_benchmark(params, create_filter=""):
stamp = setup_coadd_stamp(params)
row = ResultRow(Trajectory(), 10)
row.stamp = np.array(stamp.get_all_pixels())
row.stamp = stamp.image

filt = eval(create_filter)

Expand Down
28 changes: 4 additions & 24 deletions src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def keep_row(self, row: ResultRow):
return False

# Find the peak in the image.
stamp = row.stamp.reshape([self.width, self.width])
stamp = row.stamp
peak_pos = RawImage(stamp).find_peak(True)
return (
abs(peak_pos.i - self.stamp_radius) < self.x_thresh
Expand Down Expand Up @@ -179,7 +179,7 @@ def keep_row(self, row: ResultRow):
return False

# Find the peack in the image.
stamp = row.stamp.reshape([self.width, self.width])
stamp = row.stamp
moments = RawImage(stamp).find_central_moments()
return (
(abs(moments.m01) < self.m01_thresh)
Expand Down Expand Up @@ -235,25 +235,5 @@ def keep_row(self, row: ResultRow):
bool
An indicator of whether to keep the row.
"""
# Filter rows without a valid stamp.
if not self._check_row_valid(row):
return False

# Find the value of the center pixel.
stamp = row.stamp.flatten()
center_index = self.width * self.stamp_radius + self.stamp_radius
center_val = stamp[center_index]

# Find the total flux in the image and check for other local_maxima
flux_sum = 0.0
for i in range(self.width * self.width):
pix_val = stamp[i]
if pix_val != KB_NO_DATA:
flux_sum += pix_val
if i != center_index and self.local_max and (pix_val >= center_val):
return False

# Check the flux percentage.
if flux_sum == 0.0:
return False
return center_val / flux_sum >= self.flux_thresh
image = RawImage(row.stamp)
return image.center_is_local_max(self.flux_thresh, self.local_max)
20 changes: 20 additions & 0 deletions src/kbmod/search/pydocs/raw_image_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,26 @@ static const auto DOC_RawImage_find_central_moments = R"doc(
Image moments.
)doc";

static const auto DOC_RawImage_center_is_local_max = R"doc(
A filter on whether the center of the stamp is a local
maxima and the percentage of the stamp's total flux in this
pixel.
Parameters
----------
local_max : ``bool``
Require the central pixel to be a local maximum.
flux_thresh : ``float``
The fraction of the stamp's total flux that needs to be in
the center pixel [0.0, 1.0].
Returns
-------
keep_row : `bool`
Whether or not the stamp passes the check.
)doc";


static const auto DOC_RawImage_create_stamp = R"doc(
Create an image stamp around a given region.
Expand Down
26 changes: 25 additions & 1 deletion src/kbmod/search/raw_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,28 @@ ImageMoments RawImage::find_central_moments() const {
return res;
}

bool RawImage::center_is_local_max(double flux_thresh, bool local_max) const {
const int num_pixels = width * height;
int c_x = width / 2;
int c_y = height / 2;
int c_ind = c_y * width + c_x;

auto pixels = image.reshaped();
double center_val = pixels[c_ind];

// Find the sum of the zero-shifted (non-NO_DATA) pixels.
double sum = 0.0;
for (int p = 0; p < num_pixels; ++p) {
float pix_val = pixels[p];
if (p != c_ind && local_max && pix_val >= center_val) {
return false;
}
sum += (pix_val != NO_DATA) ? pix_val : 0.0;
}
if (sum == 0.0) return false;
return center_val / sum >= flux_thresh;
}

void RawImage::load_time_from_file(fitsfile* fptr) {
int mjd_status = 0;

Expand Down Expand Up @@ -603,7 +625,9 @@ static void raw_image_bindings(py::module& m) {
.def("compute_bounds", &rie::compute_bounds, pydocs::DOC_RawImage_compute_bounds)
.def("find_peak", &rie::find_peak, pydocs::DOC_RawImage_find_peak)
.def("find_central_moments", &rie::find_central_moments,
pydocs::DOC_RawImage_find_central_moments)
pydocs::DOC_RawImage_find_central_moments)
.def("center_is_local_max", &rie::center_is_local_max,
pydocs::DOC_RawImage_center_is_local_max)
.def("create_stamp", &rie::create_stamp, pydocs::DOC_RawImage_create_stamp)
.def("interpolate", &rie::interpolate, pydocs::DOC_RawImage_interpolate)
.def("interpolated_add", &rie::interpolated_add, pydocs::DOC_RawImage_interpolated_add)
Expand Down
2 changes: 2 additions & 0 deletions src/kbmod/search/raw_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class RawImage {
// Elements with NO_DATA are treated as zero.
ImageMoments find_central_moments() const;

bool center_is_local_max(double flux_thresh, bool local_max) const;

// Load the image data from a specific layer of a FITS file.
// Overwrites the current image data.
void from_fits(const std::string& file_path, int layer_num);
Expand Down

0 comments on commit 5ba2324

Please sign in to comment.