Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stamp center filter c++ port #407

Merged
merged 8 commits into from
Dec 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions benchmarks/bench_filter_stamps.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@

from kbmod.filters.stamp_filters import *
from kbmod.result_list import ResultRow
from kbmod.search import ImageStack, PSF, RawImage, StackSearch, StampParameters, StampType, Trajectory
from kbmod.search import (
ImageStack,
PSF,
RawImage,
StackSearch,
StampParameters,
StampType,
Trajectory,
StampCreator,
)


def setup_coadd_stamp(params):
Expand All @@ -30,7 +39,7 @@ def setup_coadd_stamp(params):
p = PSF(1.0)
psf_dim = p.get_dim()
psf_rad = p.get_radius()
for i in range(psf_dim):
for i in range(1, psf_dim):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why this is changing. The calculation below moves the center pixel one off, but shouldn't the convolution with the PSF still cover the full PSF?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

before the range of the outer for loop would include -1 when i is 0 (due to the one pixel offset), giving an out of bound error. I just changed the range to only include dimensions that fall within the PSF.

for j in range(psf_dim):
stamp.set_pixel(
(params.radius - 1) - psf_rad + i, # x is one pixel off center
Expand All @@ -45,19 +54,19 @@ def run_search_benchmark(params):
stamp = setup_coadd_stamp(params)

# Create an empty search stack.
im_stack = ImageStack([])
search = StackSearch(im_stack)
# im_stack = ImageStack([])
sc = StampCreator()

# Do the timing runs.
tmr = timeit.Timer(stmt="search.filter_stamp(stamp, params)", globals=locals())
tmr = timeit.Timer(stmt="sc.filter_stamp(stamp, params)", globals=locals())
res_time = np.mean(tmr.repeat(repeat=10, number=20))
return res_time


def run_row_benchmark(params, create_filter=""):
stamp = setup_coadd_stamp(params)
row = ResultRow(Trajectory(), 10)
row.stamp = np.array(stamp.get_all_pixels())
row.stamp = stamp.image

filt = eval(create_filter)

Expand Down
28 changes: 4 additions & 24 deletions src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def keep_row(self, row: ResultRow):
return False

# Find the peak in the image.
stamp = row.stamp.reshape([self.width, self.width])
stamp = row.stamp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change because stamp is already in the correct shape? Or is there some other reason?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep!

peak_pos = RawImage(stamp).find_peak(True)
return (
abs(peak_pos.i - self.stamp_radius) < self.x_thresh
Expand Down Expand Up @@ -179,7 +179,7 @@ def keep_row(self, row: ResultRow):
return False

# Find the peack in the image.
stamp = row.stamp.reshape([self.width, self.width])
stamp = row.stamp
moments = RawImage(stamp).find_central_moments()
return (
(abs(moments.m01) < self.m01_thresh)
Expand Down Expand Up @@ -235,25 +235,5 @@ def keep_row(self, row: ResultRow):
bool
An indicator of whether to keep the row.
"""
# Filter rows without a valid stamp.
if not self._check_row_valid(row):
return False

# Find the value of the center pixel.
stamp = row.stamp.flatten()
center_index = self.width * self.stamp_radius + self.stamp_radius
center_val = stamp[center_index]

# Find the total flux in the image and check for other local_maxima
flux_sum = 0.0
for i in range(self.width * self.width):
pix_val = stamp[i]
if pix_val != KB_NO_DATA:
flux_sum += pix_val
if i != center_index and self.local_max and (pix_val >= center_val):
return False

# Check the flux percentage.
if flux_sum == 0.0:
return False
return center_val / flux_sum >= self.flux_thresh
image = RawImage(row.stamp)
return image.center_is_local_max(self.flux_thresh, self.local_max)
20 changes: 20 additions & 0 deletions src/kbmod/search/pydocs/raw_image_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,26 @@ static const auto DOC_RawImage_find_central_moments = R"doc(
Image moments.
)doc";

static const auto DOC_RawImage_center_is_local_max = R"doc(
A filter on whether the center of the stamp is a local
maxima and the percentage of the stamp's total flux in this
pixel.

Parameters
----------
local_max : ``bool``
Require the central pixel to be a local maximum.
flux_thresh : ``float``
The fraction of the stamp's total flux that needs to be in
the center pixel [0.0, 1.0].

Returns
-------
keep_row : `bool`
Whether or not the stamp passes the check.
)doc";


static const auto DOC_RawImage_create_stamp = R"doc(
Create an image stamp around a given region.

Expand Down
26 changes: 25 additions & 1 deletion src/kbmod/search/raw_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,28 @@ ImageMoments RawImage::find_central_moments() const {
return res;
}

bool RawImage::center_is_local_max(double flux_thresh, bool local_max) const {
const int num_pixels = width * height;
int c_x = width / 2;
int c_y = height / 2;
int c_ind = c_y * width + c_x;

auto pixels = image.reshaped();
double center_val = pixels[c_ind];

// Find the sum of the zero-shifted (non-NO_DATA) pixels.
double sum = 0.0;
for (int p = 0; p < num_pixels; ++p) {
float pix_val = pixels[p];
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
if (p != c_ind && local_max && pix_val >= center_val) {
return false;
}
sum += (pix_val != NO_DATA) ? pix_val : 0.0;
}
if (sum == 0.0) return false;
return center_val / sum >= flux_thresh;
}

void RawImage::load_time_from_file(fitsfile* fptr) {
int mjd_status = 0;

Expand Down Expand Up @@ -603,7 +625,9 @@ static void raw_image_bindings(py::module& m) {
.def("compute_bounds", &rie::compute_bounds, pydocs::DOC_RawImage_compute_bounds)
.def("find_peak", &rie::find_peak, pydocs::DOC_RawImage_find_peak)
.def("find_central_moments", &rie::find_central_moments,
pydocs::DOC_RawImage_find_central_moments)
pydocs::DOC_RawImage_find_central_moments)
.def("center_is_local_max", &rie::center_is_local_max,
pydocs::DOC_RawImage_center_is_local_max)
.def("create_stamp", &rie::create_stamp, pydocs::DOC_RawImage_create_stamp)
.def("interpolate", &rie::interpolate, pydocs::DOC_RawImage_interpolate)
.def("interpolated_add", &rie::interpolated_add, pydocs::DOC_RawImage_interpolated_add)
Expand Down
2 changes: 2 additions & 0 deletions src/kbmod/search/raw_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class RawImage {
// Elements with NO_DATA are treated as zero.
ImageMoments find_central_moments() const;

bool center_is_local_max(double flux_thresh, bool local_max) const;

// Load the image data from a specific layer of a FITS file.
// Overwrites the current image data.
void from_fits(const std::string& file_path, int layer_num);
Expand Down