Skip to content

Commit

Permalink
Merge pull request #532 from dirac-institute/prediction_consistency
Browse files Browse the repository at this point in the history
Make the trajectory prediction logic more centralized
  • Loading branch information
jeremykubica authored Apr 16, 2024
2 parents 08548b7 + d402de6 commit ae4b2e4
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 57 deletions.
4 changes: 2 additions & 2 deletions src/kbmod/fake_data/fake_data_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def insert_object(self, trj):

for i in range(self.num_times):
dt = self.times[i] - t0
px = trj.x + dt * trj.vx + 0.5
py = trj.y + dt * trj.vy + 0.5
px = trj.get_x_pos(dt)
py = trj.get_y_pos(dt)

# Get the image for the timestep, add the object, and
# re-set the image. This last step needs to be done
Expand Down
2 changes: 1 addition & 1 deletion src/kbmod/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def mask_trajectory(trj, stack, r):
for i in range(stack.img_count()):
img = stack.get_single_image(i)
time = img.get_obstime() - stack.get_single_image(0).get_obstime()
origin_of_mask = (int(trj.get_x_pos(time) + 0.5), int(trj.get_y_pos(time) + 0.5))
origin_of_mask = (trj.get_x_index(time), trj.get_y_index(time))

for dy in range(-r, r + 1):
for dx in range(-r, r + 1):
Expand Down
22 changes: 17 additions & 5 deletions src/kbmod/search/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,17 @@ struct Trajectory {
// Whether the trajectory is valid. Used for on-GPU filtering.
bool valid = true;

// Get pixel positions from a zero-shifted time.
float get_x_pos(float time) const { return x + time * vx; }
float get_y_pos(float time) const { return y + time * vy; }
// Get pixel positions from a zero-shifted time. Centered indicates whether
// the prediction starts from the center of the pixel (which it does in the search)
inline float get_x_pos(float time, bool centered = true) const {
return centered ? (x + time * vx + 0.5f) : (x + time * vx);
}
inline float get_y_pos(float time, bool centered = true) const {
return centered ? (y + time * vy + 0.5f) : (y + time * vy);
}

inline int get_x_index(float time) const { return (int)floor(get_x_pos(time, true)); }
inline int get_y_index(float time) const { return (int)floor(get_y_pos(time, true)); }

// A helper function to test if two trajectories are close in pixel space.
bool is_close(Trajectory &trj_b, float pos_thresh, float vel_thresh) {
Expand Down Expand Up @@ -176,8 +184,12 @@ static void trajectory_bindings(py::module &m) {
.def_readwrite("y", &tj::y)
.def_readwrite("obs_count", &tj::obs_count)
.def_readwrite("valid", &tj::valid)
.def("get_x_pos", &tj::get_x_pos, pydocs::DOC_Trajectory_get_x_pos)
.def("get_y_pos", &tj::get_y_pos, pydocs::DOC_Trajectory_get_y_pos)
.def("get_x_pos", &tj::get_x_pos, py::arg("time"), py::arg("centered") = true,
pydocs::DOC_Trajectory_get_x_pos)
.def("get_y_pos", &tj::get_y_pos, py::arg("time"), py::arg("centered") = true,
pydocs::DOC_Trajectory_get_y_pos)
.def("get_x_index", &tj::get_x_index, pydocs::DOC_Trajectory_get_x_index)
.def("get_y_index", &tj::get_y_index, pydocs::DOC_Trajectory_get_y_index)
.def("is_close", &tj::is_close, pydocs::DOC_Trajectory_is_close)
.def("__repr__", [](const tj &t) { return "Trajectory(" + t.to_string() + ")"; })
.def("__str__", &tj::to_string)
Expand Down
12 changes: 8 additions & 4 deletions src/kbmod/search/kernels/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ namespace search {

__host__ __device__ bool device_pixel_valid(float value) { return isfinite(value); }

__host__ __device__ int predict_index(float pos0, float vel0, float time) {
return (int)(floor(pos0 + vel0 * time + 0.5f));
}

__host__ __device__ PsiPhi read_encoded_psi_phi(PsiPhiArrayMeta &params, void *psi_phi_vect, int time,
int row, int col) {
// Bounds checking.
Expand Down Expand Up @@ -148,8 +152,8 @@ extern "C" __device__ __host__ void evaluateTrajectory(PsiPhiArrayMeta psi_phi_m
for (int i = 0; i < psi_phi_meta.num_times; ++i) {
// Predict the trajectory's position.
float curr_time = image_times[i];
int current_x = candidate->x + int(candidate->vx * curr_time + 0.5);
int current_y = candidate->y + int(candidate->vy * curr_time + 0.5);
int current_x = predict_index(candidate->x, candidate->vx, curr_time);
int current_y = predict_index(candidate->y, candidate->vy, curr_time);

// Get the Psi and Phi pixel values. Skip invalid values, such as those marked NaN or NO_DATA.
PsiPhi pixel_vals = read_encoded_psi_phi(psi_phi_meta, psi_phi_vect, i, current_y, current_x);
Expand Down Expand Up @@ -361,8 +365,8 @@ __global__ void deviceGetCoaddStamp(int num_images, int width, int height, float

// Predict the trajectory's position.
float curr_time = image_times[t];
int current_x = int(trj.x + trj.vx * curr_time);
int current_y = int(trj.y + trj.vy * curr_time);
int current_x = predict_index(trj.x, trj.vx, curr_time);
int current_y = predict_index(trj.y, trj.vy, curr_time);

// Get the stamp and add it to the list of values.
int img_x = current_x - params.radius + stamp_x;
Expand Down
30 changes: 26 additions & 4 deletions src/kbmod/search/pydocs/common_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ static const auto DOC_Trajectory_get_x_pos = R"doc(
----------
time : `float`
A zero shifted time.
centered : `bool`
Shift the prediction to be at the center of the pixel
(e.g. xp = x + vx * time + 0.5f). Default = True.
Returns
-------
Expand All @@ -48,15 +51,34 @@ static const auto DOC_Trajectory_get_y_pos = R"doc(
----------
time : `float`
A zero shifted time.
centered : `bool`
Shift the prediction to be at the center of the pixel
(e.g. xp = x + vx * time + 0.5f). Default = True.
Returns
-------
`float`
The predicted y position (in pixels).
)doc";

static const auto DOC_Trajectory_get_pos = R"doc(
Returns the predicted (x, y) position of the trajectory.
static const auto DOC_Trajectory_get_x_index = R"doc(
Returns the predicted x position of the trajectory as an integer
(column) index.
Parameters
----------
time : `float`
A zero shifted time.
Returns
-------
`int`
The predicted column index.
)doc";

static const auto DOC_Trajectory_get_y_index = R"doc(
Returns the predicted x position of the trajectory as an integer
(row) index.
Parameters
----------
Expand All @@ -65,8 +87,8 @@ static const auto DOC_Trajectory_get_pos = R"doc(
Returns
-------
`PixelPos`
The predicted (x, y) position (in pixels).
`int`
The predicted row index.
)doc";

static const auto DOC_Trajectory_is_close = R"doc(
Expand Down
5 changes: 1 addition & 4 deletions src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,7 @@ std::vector<float> StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool e
float time = psi_phi_array.read_time(i);

// Query the center of the predicted location's pixel.
Point pred_pt = {trj.get_x_pos(time) + 0.5f, trj.get_y_pos(time) + 0.5f};
Index pred_idx = pred_pt.to_index();
PsiPhi psi_phi_val = psi_phi_array.read_psi_phi(i, pred_idx.i, pred_idx.j);

PsiPhi psi_phi_val = psi_phi_array.read_psi_phi(i, trj.get_y_index(time), trj.get_x_index(time));
float value = (extract_psi) ? psi_phi_val.psi : psi_phi_val.phi;
if (pixel_value_valid(value)) {
result[i] = value;
Expand Down
2 changes: 1 addition & 1 deletion src/kbmod/search/stamp_creator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ std::vector<RawImage> StampCreator::create_stamps(ImageStack& stack, const Traje
if (use_all_stamps || use_index[i]) {
// Calculate the trajectory position.
float time = stack.get_zeroed_time(i);
Point pos{trj.x + time * trj.vx, trj.y + time * trj.vy};
Point pos{trj.get_x_pos(time), trj.get_y_pos(time)};
RawImage& img = stack.get_single_image(i).get_science();
stamps.push_back(img.create_stamp(pos, radius, keep_no_data));
}
Expand Down
4 changes: 3 additions & 1 deletion tests/test_bilinear_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@

from kbmod.fake_data.fake_data_creator import add_fake_object, make_fake_layered_image
import kbmod.search as kb
from kbmod.trajectory_utils import make_trajectory


class test_bilinear_interp(unittest.TestCase):
def setUp(self):
self.im_count = 5
self.trj = make_trajectory(2, 2, 0.5, 0.5)
p = kb.PSF(0.05)
self.images = []
for c in range(self.im_count):
im = make_fake_layered_image(10, 10, 0.0, 1.0, c, p)
add_fake_object(im, 2 + c * 0.5 + 0.5, 2 + c * 0.5 + 0.5, 1, p)
add_fake_object(im, self.trj.get_x_pos(c), self.trj.get_y_pos(c), 1, p)
self.images.append(im)

def test_pixels(self):
Expand Down
27 changes: 21 additions & 6 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,27 @@ def test_pixel_value_valid(self):

def test_trajectory_predict(self):
trj = make_trajectory(x=5, y=10, vx=2.0, vy=-1.0)
self.assertEqual(trj.get_x_pos(0.0), 5.0)
self.assertEqual(trj.get_y_pos(0.0), 10.0)
self.assertEqual(trj.get_x_pos(1.0), 7.0)
self.assertEqual(trj.get_y_pos(1.0), 9.0)
self.assertEqual(trj.get_x_pos(2.0), 9.0)
self.assertEqual(trj.get_y_pos(2.0), 8.0)
# With centered=false the trajectories start at the pixel edge.
self.assertEqual(trj.get_x_pos(0.0, False), 5.0)
self.assertEqual(trj.get_y_pos(0.0, False), 10.0)
self.assertEqual(trj.get_x_pos(1.0, False), 7.0)
self.assertEqual(trj.get_y_pos(1.0, False), 9.0)
self.assertEqual(trj.get_x_pos(2.0, False), 9.0)
self.assertEqual(trj.get_y_pos(2.0, False), 8.0)

# Centering moves things by half a pixel.
self.assertEqual(trj.get_x_pos(0.0), 5.5)
self.assertEqual(trj.get_y_pos(0.0), 10.5)
self.assertEqual(trj.get_x_pos(1.0), 7.5)
self.assertEqual(trj.get_y_pos(1.0), 9.5)
self.assertEqual(trj.get_x_pos(2.0), 9.5)
self.assertEqual(trj.get_y_pos(2.0), 8.5)

# Predicting the index gives a floored integer of the centered prediction.
self.assertEqual(trj.get_x_index(0.0), 5)
self.assertEqual(trj.get_y_index(0.0), 10)
self.assertEqual(trj.get_x_index(1.0), 7)
self.assertEqual(trj.get_y_index(1.0), 9)

def test_trajectory_is_close(self):
trj = make_trajectory(x=5, y=10, vx=2.0, vy=-1.0)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fake_data_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_insert_object(self):
t0 = ds.stack.get_single_image(0).get_obstime()
for i in range(ds.stack.img_count()):
dt = ds.stack.get_single_image(i).get_obstime() - t0
px = int(trj.x + dt * trj.vx + 0.5)
py = int(trj.y + dt * trj.vy + 0.5)
px = trj.get_x_index(dt)
py = trj.get_y_index(dt)

# Check the trajectory stays in the image.
self.assertGreaterEqual(px, 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_apply_trajectory_mask(self):
for i in range(self.img_count):
time = self.stack.get_single_image(i).get_obstime() - self.stack.get_single_image(0).get_obstime()

origin_of_mask = (trj.get_x_pos(time), trj.get_y_pos(time))
origin_of_mask = (trj.get_x_index(time), trj.get_y_index(time))
msk = masked_stack.get_single_image(i).get_mask()

for x in range(self.dim_x):
Expand Down
47 changes: 23 additions & 24 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class test_search(unittest.TestCase):
def setUp(self):
# test pass thresholds
self.pixel_error = 0
self.pixel_error = 1
self.velocity_error = 0.1
self.flux_error = 0.15

Expand Down Expand Up @@ -65,8 +65,8 @@ def setUp(self):
)
add_fake_object(
im,
self.start_x + time * self.vxel + 0.5,
self.start_y + time * self.vyel + 0.5,
self.trj.get_x_pos(time),
self.trj.get_y_pos(time),
self.object_flux,
self.p,
)
Expand All @@ -89,8 +89,8 @@ def setUp(self):
self.params.center_thresh = 0.03
self.params.peak_offset_x = 1.5
self.params.peak_offset_y = 1.5
self.params.m01_limit = 0.6
self.params.m10_limit = 0.6
self.params.m01_limit = 1.0
self.params.m10_limit = 1.0
self.params.m11_limit = 2.0
self.params.m02_limit = 35.5
self.params.m20_limit = 35.5
Expand Down Expand Up @@ -278,8 +278,8 @@ def test_results_off_chip(self):
)
add_fake_object(
im,
trj.x + time * trj.vx + 0.5,
trj.y + time * trj.vy + 0.5,
trj.get_x_index(time),
trj.get_y_index(time),
self.object_flux,
self.p,
)
Expand Down Expand Up @@ -311,9 +311,8 @@ def test_sci_viz_stamps(self):
self.assertEqual(sci_stamps[i].height, 5)

# Compute the interpolated pixel value at the projected location.
t = times[i]
x = int(self.trj.x + self.trj.vx * t)
y = int(self.trj.y + self.trj.vy * t)
x = self.trj.get_x_index(times[i])
y = self.trj.get_y_index(times[i])
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if not pixel_value_valid(pixVal):
pivVal = 0.0
Expand All @@ -333,8 +332,8 @@ def test_stacked_sci(self):
sum_middle = 0.0
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t)
y = int(self.trj.y + self.trj.vy * t)
x = self.trj.get_x_index(t)
y = self.trj.get_y_index(t)
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if not pixel_value_valid(pixVal):
pivVal = 0.0
Expand Down Expand Up @@ -365,8 +364,8 @@ def test_median_stamps_trj(self):
pix_values1 = []
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t)
y = int(self.trj.y + self.trj.vy * t)
x = self.trj.get_x_index(t)
y = self.trj.get_y_index(t)
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if pixel_value_valid(pixVal) and goodIdx[0][i] == 1:
pix_values0.append(pixVal)
Expand Down Expand Up @@ -422,8 +421,8 @@ def test_mean_stamps_trj(self):
pix_count1 = 0.0
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t)
y = int(self.trj.y + self.trj.vy * t)
x = self.trj.get_x_index(t)
y = self.trj.get_y_index(t)
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if pixel_value_valid(pixVal) and goodIdx[0][i] == 1:
pix_sum0 += pixVal
Expand Down Expand Up @@ -681,8 +680,8 @@ def test_coadd_cpu(self):
pix_vals = []
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t) + x_offset
y = int(self.trj.y + self.trj.vy * t) + y_offset
x = self.trj.get_x_index(t) + x_offset
y = self.trj.get_y_index(t) + y_offset
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if pixel_value_valid(pixVal):
pix_sum += pixVal
Expand Down Expand Up @@ -738,8 +737,8 @@ def test_coadd_gpu(self):
pix_vals = []
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t) + x_offset
y = int(self.trj.y + self.trj.vy * t) + y_offset
x = self.trj.get_x_index(t) + x_offset
y = self.trj.get_y_index(t) + y_offset
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if pixel_value_valid(pixVal):
pix_sum += pixVal
Expand Down Expand Up @@ -787,8 +786,8 @@ def test_coadd_cpu_use_inds(self):
count_1 = 0.0
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t) + x_offset
y = int(self.trj.y + self.trj.vy * t) + y_offset
x = self.trj.get_x_index(t) + x_offset
y = self.trj.get_y_index(t) + y_offset
pixVal = self.imlist[i].get_science().get_pixel(y, x)

if pixel_value_valid(pixVal) and inds[0][i] > 0:
Expand Down Expand Up @@ -838,8 +837,8 @@ def test_coadd_gpu_use_inds(self):
count_1 = 0.0
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t) + x_offset
y = int(self.trj.y + self.trj.vy * t) + y_offset
x = self.trj.get_x_index(t) + x_offset
y = self.trj.get_y_index(t) + y_offset
pixVal = self.imlist[i].get_science().get_pixel(y, x)

if pixel_value_valid(pixVal) and inds[0][i] > 0:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_trajectory_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def setUp(self):
fake_ds.insert_object(self.trj)

# Remove at least observation from the trajectory.
pred_x = int(self.x0 + fake_times[10] * self.vx + 0.5)
pred_y = int(self.y0 + fake_times[10] * self.vy + 0.5)
pred_x = self.trj.get_x_index(fake_times[10])
pred_y = self.trj.get_y_index(fake_times[10])
sci_t10 = fake_ds.stack.get_single_image(10).get_science()
for dy in [-1, 0, 1]:
for dx in [-1, 0, 1]:
Expand Down

0 comments on commit ae4b2e4

Please sign in to comment.