Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make the trajectory prediction logic more centralized #532

Merged
merged 3 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/kbmod/fake_data/fake_data_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def insert_object(self, trj):

for i in range(self.num_times):
dt = self.times[i] - t0
px = trj.x + dt * trj.vx + 0.5
py = trj.y + dt * trj.vy + 0.5
px = trj.get_x_pos(dt)
py = trj.get_y_pos(dt)

# Get the image for the timestep, add the object, and
# re-set the image. This last step needs to be done
Expand Down
2 changes: 1 addition & 1 deletion src/kbmod/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def mask_trajectory(trj, stack, r):
for i in range(stack.img_count()):
img = stack.get_single_image(i)
time = img.get_obstime() - stack.get_single_image(0).get_obstime()
origin_of_mask = (int(trj.get_x_pos(time) + 0.5), int(trj.get_y_pos(time) + 0.5))
origin_of_mask = (trj.get_x_index(time), trj.get_y_index(time))

for dy in range(-r, r + 1):
for dx in range(-r, r + 1):
Expand Down
22 changes: 17 additions & 5 deletions src/kbmod/search/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,17 @@ struct Trajectory {
// Whether the trajectory is valid. Used for on-GPU filtering.
bool valid = true;

// Get pixel positions from a zero-shifted time.
float get_x_pos(float time) const { return x + time * vx; }
float get_y_pos(float time) const { return y + time * vy; }
// Get pixel positions from a zero-shifted time. Centered indicates whether
// the prediction starts from the center of the pixel (which it does in the search)
inline float get_x_pos(float time, bool centered = true) const {
return centered ? (x + time * vx + 0.5f) : (x + time * vx);
}
inline float get_y_pos(float time, bool centered = true) const {
return centered ? (y + time * vy + 0.5f) : (y + time * vy);
}

inline int get_x_index(float time) const { return (int)floor(get_x_pos(time, true)); }
inline int get_y_index(float time) const { return (int)floor(get_y_pos(time, true)); }
Comment on lines +56 to +66
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if this deserves a get_pos that returns an Index or an Point and then snaps them to the respective grids. (For some reason this feels like it was done at one point but there are no traces of this in the code at all?)

On the GPU we have a separate predict_index anyhow and we should use our own inventions for consistency at least.


// A helper function to test if two trajectories are close in pixel space.
bool is_close(Trajectory &trj_b, float pos_thresh, float vel_thresh) {
Expand Down Expand Up @@ -176,8 +184,12 @@ static void trajectory_bindings(py::module &m) {
.def_readwrite("y", &tj::y)
.def_readwrite("obs_count", &tj::obs_count)
.def_readwrite("valid", &tj::valid)
.def("get_x_pos", &tj::get_x_pos, pydocs::DOC_Trajectory_get_x_pos)
.def("get_y_pos", &tj::get_y_pos, pydocs::DOC_Trajectory_get_y_pos)
.def("get_x_pos", &tj::get_x_pos, py::arg("time"), py::arg("centered") = true,
pydocs::DOC_Trajectory_get_x_pos)
.def("get_y_pos", &tj::get_y_pos, py::arg("time"), py::arg("centered") = true,
pydocs::DOC_Trajectory_get_y_pos)
.def("get_x_index", &tj::get_x_index, pydocs::DOC_Trajectory_get_x_index)
.def("get_y_index", &tj::get_y_index, pydocs::DOC_Trajectory_get_y_index)
.def("is_close", &tj::is_close, pydocs::DOC_Trajectory_is_close)
.def("__repr__", [](const tj &t) { return "Trajectory(" + t.to_string() + ")"; })
.def("__str__", &tj::to_string)
Expand Down
12 changes: 8 additions & 4 deletions src/kbmod/search/kernels/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ namespace search {

__host__ __device__ bool device_pixel_valid(float value) { return isfinite(value); }

__host__ __device__ int predict_index(float pos0, float vel0, float time) {
return (int)(floor(pos0 + vel0 * time + 0.5f));
}

__host__ __device__ PsiPhi read_encoded_psi_phi(PsiPhiArrayMeta &params, void *psi_phi_vect, int time,
int row, int col) {
// Bounds checking.
Expand Down Expand Up @@ -148,8 +152,8 @@ extern "C" __device__ __host__ void evaluateTrajectory(PsiPhiArrayMeta psi_phi_m
for (int i = 0; i < psi_phi_meta.num_times; ++i) {
// Predict the trajectory's position.
float curr_time = image_times[i];
int current_x = candidate->x + int(candidate->vx * curr_time + 0.5);
int current_y = candidate->y + int(candidate->vy * curr_time + 0.5);
int current_x = predict_index(candidate->x, candidate->vx, curr_time);
int current_y = predict_index(candidate->y, candidate->vy, curr_time);

// Get the Psi and Phi pixel values. Skip invalid values, such as those marked NaN or NO_DATA.
PsiPhi pixel_vals = read_encoded_psi_phi(psi_phi_meta, psi_phi_vect, i, current_y, current_x);
Expand Down Expand Up @@ -361,8 +365,8 @@ __global__ void deviceGetCoaddStamp(int num_images, int width, int height, float

// Predict the trajectory's position.
float curr_time = image_times[t];
int current_x = int(trj.x + trj.vx * curr_time);
int current_y = int(trj.y + trj.vy * curr_time);
int current_x = predict_index(trj.x, trj.vx, curr_time);
int current_y = predict_index(trj.y, trj.vy, curr_time);

// Get the stamp and add it to the list of values.
int img_x = current_x - params.radius + stamp_x;
Expand Down
30 changes: 26 additions & 4 deletions src/kbmod/search/pydocs/common_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ static const auto DOC_Trajectory_get_x_pos = R"doc(
----------
time : `float`
A zero shifted time.
centered : `bool`
Shift the prediction to be at the center of the pixel
(e.g. xp = x + vx * time + 0.5f). Default = True.

Returns
-------
Expand All @@ -48,15 +51,34 @@ static const auto DOC_Trajectory_get_y_pos = R"doc(
----------
time : `float`
A zero shifted time.
centered : `bool`
Shift the prediction to be at the center of the pixel
(e.g. xp = x + vx * time + 0.5f). Default = True.

Returns
-------
`float`
The predicted y position (in pixels).
)doc";

static const auto DOC_Trajectory_get_pos = R"doc(
Returns the predicted (x, y) position of the trajectory.
static const auto DOC_Trajectory_get_x_index = R"doc(
Returns the predicted x position of the trajectory as an integer
(column) index.

Parameters
----------
time : `float`
A zero shifted time.

Returns
-------
`int`
The predicted column index.
)doc";

static const auto DOC_Trajectory_get_y_index = R"doc(
Returns the predicted x position of the trajectory as an integer
(row) index.

Parameters
----------
Expand All @@ -65,8 +87,8 @@ static const auto DOC_Trajectory_get_pos = R"doc(

Returns
-------
`PixelPos`
The predicted (x, y) position (in pixels).
`int`
The predicted row index.
)doc";

static const auto DOC_Trajectory_is_close = R"doc(
Expand Down
5 changes: 1 addition & 4 deletions src/kbmod/search/stack_search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,7 @@ std::vector<float> StackSearch::extract_psi_or_phi_curve(Trajectory& trj, bool e
float time = psi_phi_array.read_time(i);

// Query the center of the predicted location's pixel.
Point pred_pt = {trj.get_x_pos(time) + 0.5f, trj.get_y_pos(time) + 0.5f};
Index pred_idx = pred_pt.to_index();
PsiPhi psi_phi_val = psi_phi_array.read_psi_phi(i, pred_idx.i, pred_idx.j);

PsiPhi psi_phi_val = psi_phi_array.read_psi_phi(i, trj.get_y_index(time), trj.get_x_index(time));
float value = (extract_psi) ? psi_phi_val.psi : psi_phi_val.phi;
if (pixel_value_valid(value)) {
result[i] = value;
Expand Down
2 changes: 1 addition & 1 deletion src/kbmod/search/stamp_creator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ std::vector<RawImage> StampCreator::create_stamps(ImageStack& stack, const Traje
if (use_all_stamps || use_index[i]) {
// Calculate the trajectory position.
float time = stack.get_zeroed_time(i);
Point pos{trj.x + time * trj.vx, trj.y + time * trj.vy};
Point pos{trj.get_x_pos(time), trj.get_y_pos(time)};
RawImage& img = stack.get_single_image(i).get_science();
stamps.push_back(img.create_stamp(pos, radius, keep_no_data));
}
Expand Down
4 changes: 3 additions & 1 deletion tests/test_bilinear_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@

from kbmod.fake_data.fake_data_creator import add_fake_object, make_fake_layered_image
import kbmod.search as kb
from kbmod.trajectory_utils import make_trajectory


class test_bilinear_interp(unittest.TestCase):
def setUp(self):
self.im_count = 5
self.trj = make_trajectory(2, 2, 0.5, 0.5)
p = kb.PSF(0.05)
self.images = []
for c in range(self.im_count):
im = make_fake_layered_image(10, 10, 0.0, 1.0, c, p)
add_fake_object(im, 2 + c * 0.5 + 0.5, 2 + c * 0.5 + 0.5, 1, p)
add_fake_object(im, self.trj.get_x_pos(c), self.trj.get_y_pos(c), 1, p)
self.images.append(im)

def test_pixels(self):
Expand Down
27 changes: 21 additions & 6 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,27 @@ def test_pixel_value_valid(self):

def test_trajectory_predict(self):
trj = make_trajectory(x=5, y=10, vx=2.0, vy=-1.0)
self.assertEqual(trj.get_x_pos(0.0), 5.0)
self.assertEqual(trj.get_y_pos(0.0), 10.0)
self.assertEqual(trj.get_x_pos(1.0), 7.0)
self.assertEqual(trj.get_y_pos(1.0), 9.0)
self.assertEqual(trj.get_x_pos(2.0), 9.0)
self.assertEqual(trj.get_y_pos(2.0), 8.0)
# With centered=false the trajectories start at the pixel edge.
self.assertEqual(trj.get_x_pos(0.0, False), 5.0)
self.assertEqual(trj.get_y_pos(0.0, False), 10.0)
self.assertEqual(trj.get_x_pos(1.0, False), 7.0)
self.assertEqual(trj.get_y_pos(1.0, False), 9.0)
self.assertEqual(trj.get_x_pos(2.0, False), 9.0)
self.assertEqual(trj.get_y_pos(2.0, False), 8.0)

# Centering moves things by half a pixel.
self.assertEqual(trj.get_x_pos(0.0), 5.5)
self.assertEqual(trj.get_y_pos(0.0), 10.5)
self.assertEqual(trj.get_x_pos(1.0), 7.5)
self.assertEqual(trj.get_y_pos(1.0), 9.5)
self.assertEqual(trj.get_x_pos(2.0), 9.5)
self.assertEqual(trj.get_y_pos(2.0), 8.5)

# Predicting the index gives a floored integer of the centered prediction.
self.assertEqual(trj.get_x_index(0.0), 5)
self.assertEqual(trj.get_y_index(0.0), 10)
self.assertEqual(trj.get_x_index(1.0), 7)
self.assertEqual(trj.get_y_index(1.0), 9)

def test_trajectory_is_close(self):
trj = make_trajectory(x=5, y=10, vx=2.0, vy=-1.0)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fake_data_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def test_insert_object(self):
t0 = ds.stack.get_single_image(0).get_obstime()
for i in range(ds.stack.img_count()):
dt = ds.stack.get_single_image(i).get_obstime() - t0
px = int(trj.x + dt * trj.vx + 0.5)
py = int(trj.y + dt * trj.vy + 0.5)
px = trj.get_x_index(dt)
py = trj.get_y_index(dt)

# Check the trajectory stays in the image.
self.assertGreaterEqual(px, 0)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_apply_trajectory_mask(self):
for i in range(self.img_count):
time = self.stack.get_single_image(i).get_obstime() - self.stack.get_single_image(0).get_obstime()

origin_of_mask = (trj.get_x_pos(time), trj.get_y_pos(time))
origin_of_mask = (trj.get_x_index(time), trj.get_y_index(time))
msk = masked_stack.get_single_image(i).get_mask()

for x in range(self.dim_x):
Expand Down
47 changes: 23 additions & 24 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class test_search(unittest.TestCase):
def setUp(self):
# test pass thresholds
self.pixel_error = 0
self.pixel_error = 1
self.velocity_error = 0.1
self.flux_error = 0.15

Expand Down Expand Up @@ -65,8 +65,8 @@ def setUp(self):
)
add_fake_object(
im,
self.start_x + time * self.vxel + 0.5,
self.start_y + time * self.vyel + 0.5,
self.trj.get_x_pos(time),
self.trj.get_y_pos(time),
self.object_flux,
self.p,
)
Expand All @@ -89,8 +89,8 @@ def setUp(self):
self.params.center_thresh = 0.03
self.params.peak_offset_x = 1.5
self.params.peak_offset_y = 1.5
self.params.m01_limit = 0.6
self.params.m10_limit = 0.6
self.params.m01_limit = 1.0
self.params.m10_limit = 1.0
self.params.m11_limit = 2.0
self.params.m02_limit = 35.5
self.params.m20_limit = 35.5
Expand Down Expand Up @@ -278,8 +278,8 @@ def test_results_off_chip(self):
)
add_fake_object(
im,
trj.x + time * trj.vx + 0.5,
trj.y + time * trj.vy + 0.5,
trj.get_x_index(time),
trj.get_y_index(time),
self.object_flux,
self.p,
)
Expand Down Expand Up @@ -311,9 +311,8 @@ def test_sci_viz_stamps(self):
self.assertEqual(sci_stamps[i].height, 5)

# Compute the interpolated pixel value at the projected location.
t = times[i]
x = int(self.trj.x + self.trj.vx * t)
y = int(self.trj.y + self.trj.vy * t)
x = self.trj.get_x_index(times[i])
y = self.trj.get_y_index(times[i])
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if not pixel_value_valid(pixVal):
pivVal = 0.0
Expand All @@ -333,8 +332,8 @@ def test_stacked_sci(self):
sum_middle = 0.0
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t)
y = int(self.trj.y + self.trj.vy * t)
x = self.trj.get_x_index(t)
y = self.trj.get_y_index(t)
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if not pixel_value_valid(pixVal):
pivVal = 0.0
Expand Down Expand Up @@ -365,8 +364,8 @@ def test_median_stamps_trj(self):
pix_values1 = []
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t)
y = int(self.trj.y + self.trj.vy * t)
x = self.trj.get_x_index(t)
y = self.trj.get_y_index(t)
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if pixel_value_valid(pixVal) and goodIdx[0][i] == 1:
pix_values0.append(pixVal)
Expand Down Expand Up @@ -422,8 +421,8 @@ def test_mean_stamps_trj(self):
pix_count1 = 0.0
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t)
y = int(self.trj.y + self.trj.vy * t)
x = self.trj.get_x_index(t)
y = self.trj.get_y_index(t)
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if pixel_value_valid(pixVal) and goodIdx[0][i] == 1:
pix_sum0 += pixVal
Expand Down Expand Up @@ -681,8 +680,8 @@ def test_coadd_cpu(self):
pix_vals = []
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t) + x_offset
y = int(self.trj.y + self.trj.vy * t) + y_offset
x = self.trj.get_x_index(t) + x_offset
y = self.trj.get_y_index(t) + y_offset
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if pixel_value_valid(pixVal):
pix_sum += pixVal
Expand Down Expand Up @@ -738,8 +737,8 @@ def test_coadd_gpu(self):
pix_vals = []
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t) + x_offset
y = int(self.trj.y + self.trj.vy * t) + y_offset
x = self.trj.get_x_index(t) + x_offset
y = self.trj.get_y_index(t) + y_offset
pixVal = self.imlist[i].get_science().get_pixel(y, x)
if pixel_value_valid(pixVal):
pix_sum += pixVal
Expand Down Expand Up @@ -787,8 +786,8 @@ def test_coadd_cpu_use_inds(self):
count_1 = 0.0
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t) + x_offset
y = int(self.trj.y + self.trj.vy * t) + y_offset
x = self.trj.get_x_index(t) + x_offset
y = self.trj.get_y_index(t) + y_offset
pixVal = self.imlist[i].get_science().get_pixel(y, x)

if pixel_value_valid(pixVal) and inds[0][i] > 0:
Expand Down Expand Up @@ -838,8 +837,8 @@ def test_coadd_gpu_use_inds(self):
count_1 = 0.0
for i in range(self.img_count):
t = times[i]
x = int(self.trj.x + self.trj.vx * t) + x_offset
y = int(self.trj.y + self.trj.vy * t) + y_offset
x = self.trj.get_x_index(t) + x_offset
y = self.trj.get_y_index(t) + y_offset
pixVal = self.imlist[i].get_science().get_pixel(y, x)

if pixel_value_valid(pixVal) and inds[0][i] > 0:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_trajectory_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def setUp(self):
fake_ds.insert_object(self.trj)

# Remove at least observation from the trajectory.
pred_x = int(self.x0 + fake_times[10] * self.vx + 0.5)
pred_y = int(self.y0 + fake_times[10] * self.vy + 0.5)
pred_x = self.trj.get_x_index(fake_times[10])
pred_y = self.trj.get_y_index(fake_times[10])
sci_t10 = fake_ds.stack.get_single_image(10).get_science()
for dy in [-1, 0, 1]:
for dx in [-1, 0, 1]:
Expand Down
Loading