Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add helper functions for evaluation #727

Merged
merged 3 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,86 @@ def trajectory_from_dict(trj_dict):
trj.lh = float(trj_dict["lh"])
trj.obs_count = int(trj_dict["obs_count"])
return trj


def fit_trajectory_from_pixels(x_vals, y_vals, times, centered=True):
"""Fit a linear trajectory from individual pixel values. This is not a pure best-fit
because we restrict the starting pixels to be integers.

Parameters
----------
x_vals : `numpy.ndarray`
The x pixel values.
y_vals : `numpy.ndarray`
The y pixel values.
times : `numpy.ndarray`
The times of each point.
centered : `bool`
Shift the center to start on a half pixel. Setting to ``True`` matches how
KBMOD does the predictions during the search: x = vx * t + x0 + 0.5.
Default: True

Returns
-------
trj : `Trajectory`
The trajectory object that best fits the observations of this fake.
"""
num_pts = len(times)
if len(x_vals) != num_pts or len(y_vals) != num_pts:
raise ValueError(f"Mismatched number of points x={len(x_vals)}, y={len(x_vals)}, times={num_pts}.")
if num_pts < 2:
raise ValueError("At least 2 points are needed to fit a linear trajectory.")

# Make sure the times are in sorted order.
if num_pts > 1 and np.any(times[:-1] >= times[1:]):
raise ValueError("Times are not in sorted order.")
dt = times - times[0]

# Use least squares to find the independent fits for the x and y velocities.
T_matrix = np.vstack([dt, np.ones(num_pts)]).T
if centered:
vx, x0 = np.linalg.lstsq(T_matrix, x_vals - 0.5, rcond=None)[0]
vy, y0 = np.linalg.lstsq(T_matrix, y_vals - 0.5, rcond=None)[0]
else:
vx, x0 = np.linalg.lstsq(T_matrix, x_vals, rcond=None)[0]
vy, y0 = np.linalg.lstsq(T_matrix, y_vals, rcond=None)[0]

return Trajectory(x=int(np.round(x0)), y=int(np.round(y0)), vx=vx, vy=vy)


def evaluate_trajectory_mse(trj, x_vals, y_vals, zeroed_times, centered=True):
"""Evaluate the mean squared error for the trajectory's predictions.

Parameters
----------
trj : `Trajectory`
The trajectory object to evaluate.
x_vals : `numpy.ndarray`
The observed x pixel values.
y_vals : `numpy.ndarray`
The observed y pixel values.
zeroed_times : `numpy.ndarray`
The times of each observed point aligned with the start time of the trajectory.
centered : `bool`
Shift the center to start on a half pixel. Setting to ``True`` matches how
KBMOD does the predictions during the search: x = vx * t + x0 + 0.5.
Default: True

Returns
-------
mse : `float`
The mean squared error.
"""
num_pts = len(zeroed_times)
if len(x_vals) != num_pts or len(y_vals) != num_pts:
raise ValueError(f"Mismatched number of points x={len(x_vals)}, y={len(x_vals)}, times={num_pts}.")
if num_pts == 0:
raise ValueError("At least one point is needed to compute the error.")

# Compute the predicted x and y values.
pred_x = np.vectorize(trj.get_x_pos)(zeroed_times, centered=centered)
pred_y = np.vectorize(trj.get_y_pos)(zeroed_times, centered=centered)

# Compute the errors.
sq_err = (x_vals - pred_x) ** 2 + (y_vals - pred_y) ** 2
return np.mean(sq_err)
63 changes: 60 additions & 3 deletions src/kbmod/work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,21 @@
import warnings
from pathlib import Path

from astropy.coordinates import SkyCoord, EarthLocation
from astropy.io import fits
from astropy.time import Time
from astropy.utils.exceptions import AstropyWarning
from astropy.wcs.utils import skycoord_to_pixel
from astropy.time import Time
from astropy.coordinates import SkyCoord, EarthLocation
import astropy.units as u

import numpy as np
from tqdm import tqdm

from kbmod import is_interactive
from kbmod.configuration import SearchConfiguration
from kbmod.reprojection_utils import invert_correct_parallax
from kbmod.search import ImageStack, LayeredImage, PSF, RawImage, Logging
from kbmod.util_functions import get_matched_obstimes
from kbmod.wcs_utils import (
append_wcs_to_hdu_header,
calc_ecliptic_angle,
Expand All @@ -21,7 +25,6 @@
wcs_from_dict,
wcs_to_dict,
)
from kbmod.reprojection_utils import invert_correct_parallax


_DEFAULT_WORKUNIT_TQDM_BAR = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]"
Expand Down Expand Up @@ -206,6 +209,60 @@ def get_wcs(self, img_num):

return per_img

def get_pixel_coordinates(self, ra, dec, times=None):
"""Get the pixel coordinates for pairs of (RA, dec) coordinates. Uses the global
WCS if one exists. Otherwise uses the per-image WCS. If times is provided, uses those values
to choose the per-image WCS.

Parameters
----------
ra : `numpy.ndarray`
The right ascension coordinates in degrees.
dec : `numpy.ndarray`
The declination coordinates in degrees.
times : `numpy.ndarray` or `None`, optional
The times to match.

Returns
-------
x_pos, y_pos: `numpy.ndarray`
Arrays of the X and Y pixel positions respectively.
"""
num_pts = len(ra)
if num_pts != len(dec):
raise ValueError(f"Mismatched array sizes RA={len(ra)} and dec={len(dec)}.")
if times is not None and len(times) != num_pts:
raise ValueError(f"Mismatched array sizes RA={len(ra)} and times={len(times)}.")

if self.wcs is not None:
# If we have a single global WCS, we can use it for all the conversions. No time matching needed.
x_pos, y_pos = self.wcs.world_to_pixel(SkyCoord(ra=ra * u.degree, dec=dec * u.degree))
else:
if times is None:
if len(self._obstimes) == num_pts:
inds = np.arange(num_pts)
else:
raise ValueError("No time information for a WorkUnit without a gloabl WCS.")
jeremykubica marked this conversation as resolved.
Show resolved Hide resolved
elif self._obstimes is not None:
inds = get_matched_obstimes(self._obstimes, times, threshold=0.02)
else:
raise ValueError("No times provided for images in WorkUnit.")

# TODO: Determine if there is a way to vectorize.
x_pos = np.zeros(num_pts)
y_pos = np.zeros(num_pts)
for i, index in enumerate(inds):
if index == -1:
raise ValueError(f"Unmatched time {times[i]}.")
current_wcs = self._per_image_wcs[index]
curr_x, curr_y = current_wcs.world_to_pixel(
SkyCoord(ra=ra[i] * u.degree, dec=dec[i] * u.degree)
)
x_pos[i] = curr_x
y_pos[i] = curr_y

return x_pos, y_pos

def compute_ecliptic_angle(self):
"""Return the ecliptic angle (in radians in pixel space) derived from the
images and WCS.
Expand Down
45 changes: 45 additions & 0 deletions tests/test_trajectory_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import unittest

from astropy.wcs import WCS
Expand Down Expand Up @@ -87,6 +88,50 @@ def test_trajectory_from_dict(self):
self.assertEqual(trj.lh, 6.0)
self.assertEqual(trj.obs_count, 7)

def test_fit_trajectory_from_pixels(self):
x_vals = np.array([5.0, 7.0, 9.0, 11.0])
y_vals = np.array([4.0, 3.0, 2.0, 1.0])
times = np.array([1.0, 2.0, 3.0, 4.0])

trj = fit_trajectory_from_pixels(x_vals, y_vals, times, centered=False)
self.assertAlmostEqual(trj.x, 5)
self.assertAlmostEqual(trj.y, 4)
self.assertAlmostEqual(trj.vx, 2.0)
self.assertAlmostEqual(trj.vy, -1.0)

# If the pixel values are centered, we need account for the 0.5 pixel shift.
x_vals = np.array([5.5, 7.5, 9.5, 11.5])
y_vals = np.array([4.5, 3.5, 2.5, 1.5])
times = np.array([1.0, 2.0, 3.0, 4.0])

trj = fit_trajectory_from_pixels(x_vals, y_vals, times, centered=True)
self.assertAlmostEqual(trj.x, 5)
self.assertAlmostEqual(trj.y, 4)
self.assertAlmostEqual(trj.vx, 2.0)
self.assertAlmostEqual(trj.vy, -1.0)

# We can't fit trajectories from a single point or mismatched array lengths.
self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0], [1.0], [1.0])
self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0, 2.0], [1.0, 2.0], [1.0])
self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0, 2.0], [1.0], [1.0, 2.0])

def test_evaluate_trajectory_mse(self):
trj = Trajectory(x=5, y=4, vx=2.0, vy=-1.0)
x_vals = np.array([5.5, 7.5, 9.7, 11.5])
y_vals = np.array([4.5, 3.4, 2.5, 1.5])
times = np.array([0.0, 1.0, 2.0, 3.0])

mse = evaluate_trajectory_mse(trj, x_vals, y_vals, times)
self.assertAlmostEqual(mse, (0.01 + 0.04) / 4.0)

mse = evaluate_trajectory_mse(trj, [5.0], [4.0], [0.0], centered=False)
self.assertAlmostEqual(mse, 0.0)

mse = evaluate_trajectory_mse(trj, [5.5], [4.1], [0.0], centered=False)
self.assertAlmostEqual(mse, 0.25 + 0.01)

self.assertRaises(ValueError, evaluate_trajectory_mse, trj, [], [], [])


if __name__ == "__main__":
unittest.main()
60 changes: 60 additions & 0 deletions tests/test_work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,66 @@ def test_get_unique_obstimes_and_indices(self):
assert len(indices) == 4
assert indices[3] == [3, 4]

def test_get_pixel_coordinates_global(self):
simple_wcs = make_fake_wcs(200.5, -7.5, 500, 700, 0.01)
work = WorkUnit(
im_stack=self.im_stack,
config=self.config,
wcs=simple_wcs,
)

# Compute the pixel locations of the SkyCoords.
ra = np.array([200.5, 200.55, 200.6])
dec = np.array([-7.5, -7.55, -7.60])
expected_x = np.array([249, 254, 259])
expected_y = np.array([349, 344, 339])

x_pos, y_pos = work.get_pixel_coordinates(ra, dec)
np.testing.assert_allclose(x_pos, expected_x, atol=0.2)
np.testing.assert_allclose(y_pos, expected_y, atol=0.2)

# We see an error if the arrays are the wrong length.
self.assertRaises(ValueError, work.get_pixel_coordinates, ra, np.array([-7.7888, -7.79015]))

def test_get_pixel_coordinates_per_image(self):
per_wcs = [make_fake_wcs(200.5 + 0.5 * i, -7.5, 500, 700, 0.01) for i in range(self.num_images)]
obstimes = [float(i) for i in range(self.num_images)]
work = WorkUnit(
im_stack=self.im_stack,
config=self.config,
per_image_wcs=per_wcs,
obstimes=obstimes,
)

# Compute the pixel locations of the SkyCoords.
ra = np.array([200.5 + 0.5 * i for i in range(self.num_images)])
dec = np.array([-7.5 + 0.05 * i for i in range(self.num_images)])

expected_x = np.full(self.num_images, 249)
expected_y = np.array([349 + 5 * i for i in range(self.num_images)])

x_pos, y_pos = work.get_pixel_coordinates(ra, dec)
np.testing.assert_allclose(x_pos, expected_x, atol=0.2)
np.testing.assert_allclose(y_pos, expected_y, atol=0.2)

# Test that we can query only a subset of the images.
x_pos, y_pos = work.get_pixel_coordinates(
np.array([201.0, 202.0]), # RA
np.array([-7.45, -7.35]), # dec
np.array([1.0, 3.0]), # time
)
np.testing.assert_allclose(x_pos, [249, 249], atol=0.2)
np.testing.assert_allclose(y_pos, [354, 364], atol=0.2)

# We see an error if a time is nowhere near any of the images.
self.assertRaises(
ValueError,
work.get_pixel_coordinates,
np.array([201.0, 202.0]), # RA
np.array([-7.45, -7.35]), # dec
np.array([1.0, 300.0]), # time
)


if __name__ == "__main__":
unittest.main()
Loading