Skip to content

Commit

Permalink
Merge pull request #493 from dirac-institute/gpu_nan
Browse files Browse the repository at this point in the history
Add support for NaNs in kernel functions
  • Loading branch information
jeremykubica authored Mar 1, 2024
2 parents 44c94aa + be591f3 commit 6581790
Show file tree
Hide file tree
Showing 13 changed files with 182 additions and 68 deletions.
2 changes: 2 additions & 0 deletions src/kbmod/search/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ PYBIND11_MODULE(search, m) {
search::stamp_parameters_bindings(m);
search::psi_phi_array_binding(m);
search::debug_timer_binding(m);
// Helper function from common.h
m.def("pixel_value_valid", &search::pixel_value_valid);
// Functions from raw_image.cpp
m.def("create_median_image", &search::create_median_image);
m.def("create_summed_image", &search::create_summed_image);
Expand Down
5 changes: 5 additions & 0 deletions src/kbmod/search/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ constexpr float NO_DATA = -9999.0;

enum StampType { STAMP_SUM = 0, STAMP_MEAN, STAMP_MEDIAN };

// Reports whether a pixel value is usable: it must be neither NaN nor
// the NO_DATA sentinel used to mark masked pixels.
inline bool pixel_value_valid(float value) {
    if (std::isnan(value)) {
        return false;
    }
    return value != NO_DATA;
}

/*
* Data structure to represent an object's trajectory
* through a stack of images
Expand Down
11 changes: 7 additions & 4 deletions src/kbmod/search/image_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@

namespace search {

// Validity check for a single pixel value, usable from host and device code.
// This is defined in kernels.cu.
// NOTE(review): because the definition lives in a different translation unit,
// device-side calls from this file presumably rely on relocatable device code
// (separate compilation) being enabled -- confirm against the build flags.
__host__ __device__ bool device_pixel_valid(float value);

/*
* Device kernel that convolves the provided image with the psf
*/
Expand All @@ -30,13 +33,13 @@ __global__ void convolve_psf(int width, int height, float *source_img, float *re
float sum = 0.0;
float psf_portion = 0.0;
float center = source_img[y * width + x];
if (center != NO_DATA) {
if (device_pixel_valid(center)) {
for (int j = -psf_radius; j <= psf_radius; j++) {
// #pragma unroll
for (int i = -psf_radius; i <= psf_radius; i++) {
if ((x + i >= 0) && (x + i < width) && (y + j >= 0) && (y + j < height)) {
float current_pix = source_img[(y + j) * width + (x + i)];
if (current_pix != NO_DATA) {
if (device_pixel_valid(current_pix)) {
float current_psf = psf[(j + psf_radius) * psf_dim + (i + psf_radius)];
psf_portion += current_psf;
sum += current_pix * current_psf;
Expand All @@ -47,8 +50,8 @@ __global__ void convolve_psf(int width, int height, float *source_img, float *re

result_img[y * width + x] = (sum * psf_sum) / psf_portion;
} else {
// Leave masked pixel alone (these could be replaced here with zero)
result_img[y * width + x] = NO_DATA; // 0.0
// Leave masked and NaN pixels alone (these could be replaced here with zero)
result_img[y * width + x] = center; // 0.0
}
}

Expand Down
8 changes: 5 additions & 3 deletions src/kbmod/search/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

namespace search {

// Validity check callable from both host and device code: a value is valid
// when it is neither NaN nor equal to the NO_DATA sentinel.
__host__ __device__ bool device_pixel_valid(float value) {
    if (isnan(value)) {
        return false;
    }
    return value != NO_DATA;
}

extern "C" void device_allocate_psi_phi_arrays(PsiPhiArray *data) {
if (data == nullptr) {
throw std::runtime_error("No data given.");
Expand Down Expand Up @@ -179,9 +181,9 @@ extern "C" __device__ __host__ void evaluateTrajectory(PsiPhiArrayMeta psi_phi_m
int current_x = candidate->x + int(candidate->vx * curr_time + 0.5);
int current_y = candidate->y + int(candidate->vy * curr_time + 0.5);

// Get the Psi and Phi pixel values.
// Get the Psi and Phi pixel values. Skip invalid values, such as those marked NaN or NO_DATA.
PsiPhi pixel_vals = read_encoded_psi_phi(psi_phi_meta, psi_phi_vect, i, current_y, current_x);
if (pixel_vals.psi != NO_DATA && pixel_vals.phi != NO_DATA) {
if (device_pixel_valid(pixel_vals.psi) && device_pixel_valid(pixel_vals.phi)) {
psi_sum += pixel_vals.psi;
phi_sum += pixel_vals.phi;
psi_array[num_seen] = pixel_vals.psi;
Expand Down Expand Up @@ -414,7 +416,7 @@ __global__ void deviceGetCoaddStamp(int num_images, int width, int height, float
int img_y = current_y - params.radius + stamp_y;
if ((img_x >= 0) && (img_x < width) && (img_y >= 0) && (img_y < height)) {
int pixel_index = width * height * t + img_y * width + img_x;
if (image_vect[pixel_index] != NO_DATA) {
if (device_pixel_valid(image_vect[pixel_index])) {
values[num_values] = image_vect[pixel_index];
++num_values;
}
Expand Down
8 changes: 5 additions & 3 deletions src/kbmod/search/layered_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,10 @@ void LayeredImage::subtract_template(RawImage& sub_template) {
float* sci_pixels = science.data();
float* tem_pixels = sub_template.data();
for (unsigned i = 0; i < num_pixels; ++i) {
if ((sci_pixels[i] != NO_DATA) && (tem_pixels[i] != NO_DATA)) {
if (pixel_value_valid(sci_pixels[i]) && pixel_value_valid(tem_pixels[i])) {
sci_pixels[i] -= tem_pixels[i];
} else {
sci_pixels[i] = NO_DATA;
}
}
}
Expand Down Expand Up @@ -154,7 +156,7 @@ RawImage LayeredImage::generate_psi_image() {
const int num_pixels = get_npixels();
for (int p = 0; p < num_pixels; ++p) {
float var_pix = var_array[p];
if (var_pix != NO_DATA && var_pix != 0.0 && sci_array[p] != NO_DATA) {
if (pixel_value_valid(var_pix) && var_pix != 0.0 && pixel_value_valid(sci_array[p])) {
result_arr[p] = sci_array[p] / var_pix;
} else {
result_arr[p] = NO_DATA;
Expand All @@ -176,7 +178,7 @@ RawImage LayeredImage::generate_phi_image() {
const int num_pixels = get_npixels();
for (int p = 0; p < num_pixels; ++p) {
float var_pix = var_array[p];
if (var_pix != NO_DATA && var_pix != 0.0) {
if (pixel_value_valid(var_pix) && var_pix != 0.0) {
result_arr[p] = 1.0 / var_pix;
} else {
result_arr[p] = NO_DATA;
Expand Down
2 changes: 1 addition & 1 deletion src/kbmod/search/psi_phi_array_ds.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ struct PsiPhi {

// Helper utility functions.
inline float encode_uint_scalar(float value, float min_val, float max_val, float scale) {
return (value == NO_DATA) ? 0 : (std::max(std::min(value, max_val), min_val) - min_val) / scale + 1.0;
return !pixel_value_valid(value) ? 0 : (std::max(std::min(value, max_val), min_val) - min_val) / scale + 1.0;
}

inline float decode_uint_scalar(float value, float min_val, float scale) {
Expand Down
50 changes: 26 additions & 24 deletions src/kbmod/search/raw_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ std::array<float, 2> RawImage::compute_bounds() const {
float max_val = -FLT_MAX;

for (auto elem : image.reshaped())
if (elem != NO_DATA) {
if (pixel_value_valid(elem)) {
min_val = std::min(min_val, elem);
max_val = std::max(max_val, elem);
}
Expand All @@ -173,9 +173,9 @@ void RawImage::convolve_cpu(PSF& psf) {

for (int y = 0; y < height; ++y) {
for (int x = 0; x < width; ++x) {
// Pixels with NO_DATA remain NO_DATA.
if (image(y, x) == NO_DATA) {
result(y, x) = NO_DATA;
// Pixels with invalid data (e.g. NO_DATA or NaN) do not change.
if (!pixel_value_valid(image(y, x))) {
result(y, x) = image(y, x);
continue;
}

Expand All @@ -186,7 +186,7 @@ void RawImage::convolve_cpu(PSF& psf) {
if ((x + i >= 0) && (x + i < width) && (y + j >= 0) && (y + j < height)) {
float current_pixel = image(y + j, x + i);
// note that convention for index access is flipped for PSF
if (current_pixel != NO_DATA) {
if (pixel_value_valid(current_pixel)) {
float current_psf = psf.get_value(i + psf_rad, j + psf_rad);
psf_portion += current_psf;
sum += current_pixel * current_psf;
Expand Down Expand Up @@ -240,22 +240,23 @@ Index RawImage::find_peak(bool furthest_from_center) const {

// Initialize the variables for tracking the peak's location.
Index result = {0, 0};
float max_val = image(0, 0);
float max_val = NO_DATA;
float dist2 = c_x * c_x + c_y * c_y;

// Search each pixel for the peak.
for (int y = 0; y < height; ++y) {
for (int x = 0; x < width; ++x) {
if (image(y, x) > max_val) {
max_val = image(y, x);
float pix_val = image(y, x);
if (pixel_value_valid(pix_val) && (pix_val > max_val)) {
max_val = pix_val;
result.i = y;
result.j = x;
dist2 = (c_x - x) * (c_x - x) + (c_y - y) * (c_y - y);
} else if (image(y, x) == max_val) {
} else if (pixel_value_valid(pix_val) && (pix_val == max_val)) {
int new_dist2 = (c_x - x) * (c_x - x) + (c_y - y) * (c_y - y);
if ((furthest_from_center && (new_dist2 > dist2)) ||
(!furthest_from_center && (new_dist2 < dist2))) {
max_val = image(y, x);
max_val = pix_val;
result.i = y;
result.j = x;
dist2 = new_dist2;
Expand All @@ -282,21 +283,21 @@ ImageMoments RawImage::find_central_moments() const {
// Find the min valid (non-NO_DATA, non-NaN) value to subtract off.
float min_val = FLT_MAX;
for (int p = 0; p < num_pixels; ++p) {
min_val = ((pixels[p] != NO_DATA) && (pixels[p] < min_val)) ? pixels[p] : min_val;
min_val = (pixel_value_valid(pixels[p]) && (pixels[p] < min_val)) ? pixels[p] : min_val;
}

// Find the sum of the zero-shifted valid (non-NO_DATA, non-NaN) pixels.
double sum = 0.0;
for (int p = 0; p < num_pixels; ++p) {
sum += (pixels[p] != NO_DATA) ? (pixels[p] - min_val) : 0.0;
sum += pixel_value_valid(pixels[p]) ? (pixels[p] - min_val) : 0.0;
}
if (sum == 0.0) return res;

// Compute the rest of the moments.
for (int y = 0; y < height; ++y) {
for (int x = 0; x < width; ++x) {
int ind = y * width + x;
float pix_val = (pixels[ind] != NO_DATA) ? (pixels[ind] - min_val) / sum : 0.0;
float pix_val = pixel_value_valid(pixels[ind]) ? (pixels[ind] - min_val) / sum : 0.0;

res.m00 += pix_val;
res.m10 += (x - c_x) * pix_val;
Expand Down Expand Up @@ -326,7 +327,7 @@ bool RawImage::center_is_local_max(double flux_thresh, bool local_max) const {
if (p != c_ind && local_max && pix_val >= center_val) {
return false;
}
sum += (pix_val != NO_DATA) ? pix_val : 0.0;
sum += pixel_value_valid(pixels[p]) ? pix_val : 0.0;
}
if (sum == 0.0) return false;
return center_val / sum >= flux_thresh;
Expand Down Expand Up @@ -354,11 +355,8 @@ RawImage create_median_image(const std::vector<RawImage>& images) {
int num_unmasked = 0;
for (auto& img : images) {
// Only use the pixels that hold valid data.
// we have a get_pixel and pixel_has_data, but we don't use them
// here in the original code, so I go to get_image()() too...
if ((img.get_image()(y, x) != NO_DATA) &&
(!std::isnan(img.get_image()(y, x)))) { // why are we suddenly checking nans?!
pix_array[num_unmasked] = img.get_image()(y, x);
if (img.pixel_has_data({y, x})) {
pix_array[num_unmasked] = img.get_pixel({y, x});
num_unmasked += 1;
}
}
Expand Down Expand Up @@ -395,7 +393,13 @@ RawImage create_summed_image(const std::vector<RawImage>& images) {

Image result = Image::Zero(height, width);
for (auto& img : images) {
result += (img.get_image().array() == NO_DATA).select(0, img.get_image());
for (int y = 0; y < height; ++y) {
for (int x = 0; x < width; ++x) {
if (img.pixel_has_data({y, x})) {
result(y, x) += img.get_pixel({y, x});
}
}
}
}
return RawImage(result);
}
Expand All @@ -416,11 +420,9 @@ RawImage create_mean_image(const std::vector<RawImage>& images) {
float sum = 0.0;
float count = 0.0;
for (auto& img : images) {
// we have a get_pixel and pixel_has_data, but we don't use them
// here in the original code, so I go to get_image()() too...
if ((img.get_image()(y, x) != NO_DATA) && (!std::isnan(img.get_image()(y, x)))) {
if (img.pixel_has_data({y, x})) {
count += 1.0;
sum += img.get_image()(y, x);
sum += img.get_pixel({y, x});
}
}

Expand Down
6 changes: 4 additions & 2 deletions src/kbmod/search/raw_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ class RawImage {

// Returns the pixel value at idx, or NO_DATA when contains(idx) is false.
inline float get_pixel(const Index& idx) const {
    if (!contains(idx)) {
        return NO_DATA;
    }
    return image(idx.i, idx.j);
}

inline bool pixel_has_data(const Index& idx) const { return get_pixel(idx) != NO_DATA ? true : false; }
// True when the value get_pixel() returns for idx passes pixel_value_valid().
inline bool pixel_has_data(const Index& idx) const {
    const float pixel = get_pixel(idx);
    return pixel_value_valid(pixel);
}

inline void set_pixel(const Index& idx, float value) {
if (!contains(idx)) throw std::runtime_error("Index out of bounds!");
Expand Down Expand Up @@ -104,7 +106,7 @@ class RawImage {
// Find the basic image moments in order to test if stamps have a gaussian shape.
// It computes the moments on the "normalized" image where the minimum
// value has been shifted to zero and the sum of all elements is 1.0.
// Elements with NO_DATA are treated as zero.
// Elements with NO_DATA, NaN, etc. are treated as zero.
ImageMoments find_central_moments() const;

bool center_is_local_max(double flux_thresh, bool local_max) const;
Expand Down
2 changes: 1 addition & 1 deletion src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ std::vector<float> StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool e
PsiPhi psi_phi_val = psi_phi_array.read_psi_phi(i, pred_idx.i, pred_idx.j);

float value = (extract_psi) ? psi_phi_val.psi : psi_phi_val.phi;
if (value != NO_DATA) {
if (pixel_value_valid(value)) {
result[i] = value;
}
}
Expand Down
49 changes: 33 additions & 16 deletions tests/test_layered_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import math
import numpy as np
import os
import tempfile
import unittest
Expand Down Expand Up @@ -309,10 +311,14 @@ def test_psi_and_phi_image(self):
sci.set_pixel(y, x, float(x))
var.set_pixel(y, x, float(y + 1))

# Mask a single pixel and set another to variance of zero.
# Mask a single pixel, set another to variance of zero,
# and mark two as NaN.
sci.set_pixel(3, 1, KB_NO_DATA)
var.set_pixel(3, 1, KB_NO_DATA)
var.set_pixel(3, 2, 0.0)
var.set_pixel(3, 0, np.nan)
sci.set_pixel(3, 3, math.nan)
sci.set_pixel(3, 4, np.nan)

# Generate and check psi and phi images.
psi = img.generate_psi_image()
Expand All @@ -325,36 +331,47 @@ def test_psi_and_phi_image(self):

for y in range(5):
for x in range(6):
has_data = y != 3 or x == 0 or x > 2
self.assertEqual(psi.pixel_has_data(y, x), has_data)
self.assertEqual(phi.pixel_has_data(y, x), has_data)
if has_data:
self.assertAlmostEqual(psi.get_pixel(y, x), x / (y + 1))
self.assertAlmostEqual(phi.get_pixel(y, x), 1.0 / (y + 1))
psi_has_data = y != 3 or x > 4
self.assertEqual(psi.pixel_has_data(y, x), psi_has_data)
if psi_has_data:
self.assertAlmostEqual(psi.get_pixel(y, x), x / (y + 1), delta=1e-5)
else:
self.assertEqual(psi.get_pixel(y, x), KB_NO_DATA)

phi_has_data = y != 3 or x > 2
self.assertEqual(phi.pixel_has_data(y, x), phi_has_data)
if phi_has_data:
self.assertAlmostEqual(phi.get_pixel(y, x), 1.0 / (y + 1), delta=1e-5)
else:
self.assertEqual(phi.get_pixel(y, x), KB_NO_DATA)

def test_subtract_template(self):
sci = self.image.get_science()
sci.set_pixel(7, 10, KB_NO_DATA)
sci.set_pixel(21, 10, KB_NO_DATA)
sci.set_pixel(7, 11, KB_NO_DATA)
sci.set_pixel(7, 12, math.nan)
sci.set_pixel(7, 13, np.nan)
old_sci = RawImage(sci.image.copy()) # Make a copy.

template = RawImage(self.image.get_width(), self.image.get_height())
template.set_all(0.0)
for h in range(sci.height):
template.set_pixel(h, 10, 0.01 * h)
for w in range(4, sci.width):
template.set_pixel(h, w, 0.01 * h)
self.image.sub_template(template)

for x in range(sci.width):
for y in range(sci.height):
val1 = old_sci.get_pixel(y, x)
val2 = sci.get_pixel(y, x)
if x == 10 and y != 7 and y != 21:
self.assertAlmostEqual(val2, val1 - 0.01 * y, delta=1e-6)
else:
for y in range(sci.height):
for x in range(sci.width):
if y == 7 and (x >= 10 and x <= 13):
self.assertFalse(sci.pixel_has_data(y, x))
elif x < 4:
val1 = old_sci.get_pixel(y, x)
val2 = sci.get_pixel(y, x)
self.assertEqual(val1, val2)
else:
val1 = old_sci.get_pixel(y, x) - 0.01 * y
val2 = sci.get_pixel(y, x)
self.assertAlmostEqual(val1, val2, delta=1e-5)

def test_read_write_files(self):
with tempfile.TemporaryDirectory() as dir_name:
Expand Down
4 changes: 4 additions & 0 deletions tests/test_psi_phi_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def encode_uint_scalar(self):
# NO_DATA always encodes to 0.0.
self.assertAlmostEqual(encode_uint_scalar(KB_NO_DATA, 0.0, 10.0, 0.1), 0.0)

# NAN always encodes to 0.0.
self.assertAlmostEqual(encode_uint_scalar(math.nan, 0.0, 10.0, 0.1), 0.0)
self.assertAlmostEqual(encode_uint_scalar(np.nan, 0.0, 10.0, 0.1), 0.0)

# Test clipping.
self.assertAlmostEqual(encode_uint_scalar(11.0, 0.0, 10.0, 0.1), 100.0)
self.assertAlmostEqual(encode_uint_scalar(-100.0, 0.0, 10.0, 0.1), 1.0)
Expand Down
Loading

0 comments on commit 6581790

Please sign in to comment.