Skip to content

Commit

Permalink
Add helper functions for fitting/evaluating a trajectory from points
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Oct 25, 2024
1 parent 7c7ab64 commit ab71c59
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 0 deletions.
83 changes: 83 additions & 0 deletions src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,86 @@ def trajectory_from_dict(trj_dict):
trj.lh = float(trj_dict["lh"])
trj.obs_count = int(trj_dict["obs_count"])
return trj


def fit_trajectory_from_pixels(x_vals, y_vals, times, centered=True):
"""Fit a linear trajectory from individual pixel values. This is not a pure best-fit
because we restrict the starting pixels to be integers.
Parameters
----------
x_vals : `numpy.ndarray`
The x pixel values.
y_vals : `numpy.ndarray`
The y pixel values.
times : `numpy.ndarray`
The times of each point.
centered : `bool`
Shift the center to start on a half pixel. Setting to ``True`` matches how
KBMOD does the predictions during the search: x = vx * t + x0 + 0.5.
Default: True
Returns
-------
trj : `Trajectory`
The trajectory object that best fits the observations of this fake.
"""
num_pts = len(times)
if len(x_vals) != num_pts or len(y_vals) != num_pts:
raise ValueError(f"Mismatched number of points x={len(x_vals)}, y={len(x_vals)}, times={num_pts}.")
if num_pts < 2:
raise ValueError("At least 2 points are needed to fit a linear trajectory.")

# Make sure the times are in sorted order.
if num_pts > 1 and np.any(times[:-1] >= times[1:]):
raise ValueError("Times are not in sorted order.")
dt = times - times[0]

# Use least squares to find the independent fits for the x and y velocities.
T_matrix = np.vstack([dt, np.ones(num_pts)]).T
if centered:
vx, x0 = np.linalg.lstsq(T_matrix, x_vals - 0.5, rcond=None)[0]
vy, y0 = np.linalg.lstsq(T_matrix, y_vals - 0.5, rcond=None)[0]
else:
vx, x0 = np.linalg.lstsq(T_matrix, x_vals, rcond=None)[0]
vy, y0 = np.linalg.lstsq(T_matrix, y_vals, rcond=None)[0]

return Trajectory(x=int(np.round(x0)), y=int(np.round(y0)), vx=vx, vy=vy)


def evaluate_trajectory_mse(trj, x_vals, y_vals, zeroed_times, centered=True):
"""Evaluate the mean squared error for the trajectory's predictions.
Parameters
----------
trj : `Trajectory`
The trajectory object to evaluate.
x_vals : `numpy.ndarray`
The observed x pixel values.
y_vals : `numpy.ndarray`
The observed y pixel values.
zeroed_times : `numpy.ndarray`
The times of each observed point aligned with the start time of the trajectory.
centered : `bool`
Shift the center to start on a half pixel. Setting to ``True`` matches how
KBMOD does the predictions during the search: x = vx * t + x0 + 0.5.
Default: True
Returns
-------
mse : `float`
The mean squared error.
"""
num_pts = len(zeroed_times)
if len(x_vals) != num_pts or len(y_vals) != num_pts:
raise ValueError(f"Mismatched number of points x={len(x_vals)}, y={len(x_vals)}, times={num_pts}.")
if num_pts == 0:
raise ValueError("At least one point is needed to compute the error.")

# Compute the predicted x and y values.
pred_x = np.vectorize(trj.get_x_pos)(zeroed_times, centered=centered)
pred_y = np.vectorize(trj.get_y_pos)(zeroed_times, centered=centered)

# Compute the errors.
sq_err = (x_vals - pred_x) ** 2 + (y_vals - pred_y) ** 2
return np.mean(sq_err)
45 changes: 45 additions & 0 deletions tests/test_trajectory_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import unittest

from astropy.wcs import WCS
Expand Down Expand Up @@ -87,6 +88,50 @@ def test_trajectory_from_dict(self):
self.assertEqual(trj.lh, 6.0)
self.assertEqual(trj.obs_count, 7)

def test_fit_trajectory_from_pixels(self):
x_vals = np.array([5.0, 7.0, 9.0, 11.0])
y_vals = np.array([4.0, 3.0, 2.0, 1.0])
times = np.array([1.0, 2.0, 3.0, 4.0])

trj = fit_trajectory_from_pixels(x_vals, y_vals, times, centered=False)
self.assertAlmostEqual(trj.x, 5)
self.assertAlmostEqual(trj.y, 4)
self.assertAlmostEqual(trj.vx, 2.0)
self.assertAlmostEqual(trj.vy, -1.0)

# If the pixel values are centered, we need account for the 0.5 pixel shift.
x_vals = np.array([5.5, 7.5, 9.5, 11.5])
y_vals = np.array([4.5, 3.5, 2.5, 1.5])
times = np.array([1.0, 2.0, 3.0, 4.0])

trj = fit_trajectory_from_pixels(x_vals, y_vals, times, centered=True)
self.assertAlmostEqual(trj.x, 5)
self.assertAlmostEqual(trj.y, 4)
self.assertAlmostEqual(trj.vx, 2.0)
self.assertAlmostEqual(trj.vy, -1.0)

# We can't fit trajectories from a single point or mismatched array lengths.
self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0], [1.0], [1.0])
self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0, 2.0], [1.0, 2.0], [1.0])
self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0, 2.0], [1.0], [1.0, 2.0])

def test_evaluate_trajectory_mse(self):
trj = Trajectory(x=5, y=4, vx=2.0, vy=-1.0)
x_vals = np.array([5.5, 7.5, 9.7, 11.5])
y_vals = np.array([4.5, 3.4, 2.5, 1.5])
times = np.array([0.0, 1.0, 2.0, 3.0])

mse = evaluate_trajectory_mse(trj, x_vals, y_vals, times)
self.assertAlmostEqual(mse, (0.01 + 0.04) / 4.0)

mse = evaluate_trajectory_mse(trj, [5.0], [4.0], [0.0], centered=False)
self.assertAlmostEqual(mse, 0.0)

mse = evaluate_trajectory_mse(trj, [5.5], [4.1], [0.0], centered=False)
self.assertAlmostEqual(mse, 0.25 + 0.01)

self.assertRaises(ValueError, evaluate_trajectory_mse, trj, [], [], [])


if __name__ == "__main__":
unittest.main()

0 comments on commit ab71c59

Please sign in to comment.