Skip to content

Commit

Permalink
TrajectoryGenerators take a WorkUnit for extra information
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Sep 17, 2024
1 parent e2d523e commit c87c1ef
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 37 deletions.
13 changes: 4 additions & 9 deletions src/kbmod/run_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def do_gpu_search(self, config, stack, trj_generator):
keep = self.load_and_filter_results(search, config)
return keep

def run_search(self, config, stack, trj_generator=None, computed_ecliptic=None):
def run_search(self, config, stack, trj_generator=None):
"""This function serves as the highest-level python interface for starting
a KBMOD search given an ImageStack and SearchConfiguration.
Expand All @@ -199,11 +199,7 @@ def run_search(self, config, stack, trj_generator=None, computed_ecliptic=None):
The stack before the masks have been applied. Modified in-place.
trj_generator : `TrajectoryGenerator`, optional
The object to generate the candidate trajectories for each pixel.
If None uses the default KBMODv1 grid search
computed_ecliptic : `float`, optional
The computed ecliptic angle in the data from a WCS (if present).
Uses ``None`` if the information needed to compute the angle is not
available.
If None uses the default EclipticCenteredSearch
Returns
-------
Expand All @@ -222,7 +218,7 @@ def run_search(self, config, stack, trj_generator=None, computed_ecliptic=None):

# Perform the actual search.
if trj_generator is None:
trj_generator = create_trajectory_generator(config, computed_ecliptic_angle=computed_ecliptic)
trj_generator = create_trajectory_generator(config, work_unit=None)
keep = self.do_gpu_search(config, stack, trj_generator)

if config["do_stamp_filter"]:
Expand Down Expand Up @@ -291,8 +287,7 @@ def run_search_from_work_unit(self, work):
keep : `Results`
The results.
"""
# If there is a WCS compute the ecliptic angle from it.
computed_ecliptic = work.compute_ecliptic_angle()
trj_generator = create_trajectory_generator(work.config, work_unit=work)

# Run the search.
return self.run_search(work.config, work.im_stack)
50 changes: 29 additions & 21 deletions src/kbmod/trajectory_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from kbmod.search import Trajectory


def create_trajectory_generator(config, **kwargs):
def create_trajectory_generator(config, work_unit=None, **kwargs):
"""Create a TrajectoryGenerator object given a dictionary
of configuration options. The generator class is specified by
the 'name' entry, which must exist and match the class name of one
Expand All @@ -20,6 +20,9 @@ def create_trajectory_generator(config, **kwargs):
----------
config : `dict` or `SearchConfiguration`
The dictionary of generator parameters.
work_unit : `WorkUnit`, optional
A WorkUnit to provide additional information about the data that
can be used to derive parameters that depend on the input.
Returns
-------
Expand Down Expand Up @@ -51,7 +54,7 @@ def create_trajectory_generator(config, **kwargs):
params.update(kwargs)
logger.debug(str(params))

return TrajectoryGenerator.generators[name](**params)
return TrajectoryGenerator.generators[name](**params, work_unit=work_unit)


class TrajectoryGenerator(abc.ABC):
Expand All @@ -65,7 +68,7 @@ class TrajectoryGenerator(abc.ABC):

generators = {} # A mapping of class name to class object for subclasses.

def __init__(self, **kwargs):
def __init__(self, work_unit=None, **kwargs):
pass

def __init_subclass__(cls, **kwargs):
Expand Down Expand Up @@ -301,7 +304,7 @@ def generate(self, *args, **kwargs):
class KBMODV1SearchConfig(KBMODV1Search):
"""Search a grid defined by velocities and angles in the format of the legacy configuration file."""

def __init__(self, v_arr, ang_arr, average_angle=None, computed_ecliptic_angle=None, **kwargs):
def __init__(self, v_arr, ang_arr, average_angle=None, work_unit=None, **kwargs):
"""Create a class KBMODV1SearchConfig.
Parameters
Expand All @@ -314,20 +317,20 @@ def __init__(self, v_arr, ang_arr, average_angle=None, computed_ecliptic_angle=N
(in radians), and the number of angles to try.
average_angle : `float`, optional
The central angle to search around. Should align with the ecliptic in most cases.
computed_ecliptic_angle : `float`, optional
An override for the computed ecliptic from a WCS (in the units defined in
``angle_units``). This parameter is ignored if ``force_ecliptic`` is given.
work_unit : `WorkUnit`, optional
A WorkUnit to provide additional information about the data that
can be used to derive parameters that depend on the input.
"""
if len(v_arr) != 3:
raise ValueError("KBMODV1SearchConfig requires v_arr to be length 3")
if len(ang_arr) != 3:
raise ValueError("KBMODV1SearchConfig requires ang_arr to be length 3")
if average_angle is None:
if computed_ecliptic_angle is None:
if work_unit is None:
raise ValueError(
"KBMODV1SearchConfig requires a valid average_angle or computed_ecliptic_angle."
"KBMODV1SearchConfig requires a valid average_angle or a WorkUnit with a WCS."
)
average_angle = computed_ecliptic_angle
average_angle = work_unit.compute_ecliptic_angle()

ang_min = average_angle - ang_arr[0]
ang_max = average_angle + ang_arr[1]
Expand Down Expand Up @@ -359,7 +362,7 @@ def __init__(
angles=[0.0, 0.0, 0],
angle_units="radians",
given_ecliptic=None,
computed_ecliptic_angle=None,
work_unit=None,
**kwargs,
):
"""Create a class EclipticCenteredSearch.
Expand All @@ -378,16 +381,23 @@ def __init__(
given_ecliptic : `float`, optional
An override for the ecliptic as given in the config (in the units defined in
``angle_units``). This angle takes precedence over ``computed_ecliptic``.
computed_ecliptic_angle : `float`, optional
An override for the computed ecliptic from a WCS (in the units defined in
``angle_units``). This parameter is ignored if ``force_ecliptic`` is given.
work_unit : `WorkUnit`, optional
A WorkUnit to provide additional information about the data that
can be used to derive parameters that depend on the input.
"""
super().__init__(**kwargs)

if given_ecliptic is not None:
ecliptic_angle = given_ecliptic
elif computed_ecliptic_angle is not None:
ecliptic_angle = computed_ecliptic_angle
if angle_units[:3] == "deg":
ecliptic_angle = given_ecliptic * (math.pi / 180.0)
elif angle_units[:3] == "rad":
ecliptic_angle = given_ecliptic
else:
raise ValueError(f"Unknown angular units {angle_units}")
elif work_unit is not None:
# compute_ecliptic_angle() always produces radians.
ecliptic_angle = work_unit.compute_ecliptic_angle()
print(f"Using WU = {ecliptic_angle}")
else:
logger = logging.getLogger(__name__)
logger.warning("No ecliptic angle provided. Using 0.0.")
Expand All @@ -405,10 +415,8 @@ def __init__(
self.angles = angles
self.ecliptic_angle = ecliptic_angle
if angle_units[:3] == "deg":
deg_to_rad = math.pi / 180.0
self.ecliptic_angle = deg_to_rad * self.ecliptic_angle
self.angles[0] = deg_to_rad * self.angles[0]
self.angles[1] = deg_to_rad * self.angles[1]
self.angles[0] = (math.pi / 180.0) * self.angles[0]
self.angles[1] = (math.pi / 180.0) * self.angles[1]
elif angle_units[:3] != "rad":
raise ValueError(f"Unknown angular units {angle_units}")

Expand Down
35 changes: 28 additions & 7 deletions tests/test_trajectory_generator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import numpy as np
import unittest

from astropy.wcs import WCS

from kbmod.configuration import SearchConfiguration
from kbmod.fake_data.fake_data_creator import FakeDataSet
from kbmod.trajectory_generator import (
KBMODV1Search,
KBMODV1SearchConfig,
Expand All @@ -11,6 +14,7 @@
VelocityGridSearch,
create_trajectory_generator,
)
from kbmod.work_unit import WorkUnit


class test_trajectory_generator(unittest.TestCase):
Expand Down Expand Up @@ -97,7 +101,7 @@ def test_EclipticCenteredSearch(self):

def test_KBMODV1SearchConfig(self):
# Note that KBMOD v1's search will never include the upper bound of angle or velocity.
gen = KBMODV1SearchConfig([0.0, 3.0, 3], [0.25, 0.25, 2], 0.0)
gen = KBMODV1SearchConfig([0.0, 3.0, 3], [0.25, 0.25, 2], average_angle=0.0)
expected_x = [0.0, 0.9689, 1.9378, 0.0, 1.0, 2.0]
expected_y = [0.0, -0.247, -0.4948, 0.0, 0.0, 0.0]

Expand Down Expand Up @@ -154,22 +158,39 @@ def test_create_trajectory_generator(self):
self.assertEqual(gen2.vx, 1)
self.assertEqual(gen2.vy, 2)

# Create a fake work unit with one image and a WCS with a non-zero ecliptic angle.
fake_wcs = WCS(naxis=2)
fake_wcs.wcs.crpix = [0.0, 0.0]
fake_wcs.wcs.cdelt = np.array([-0.1, 0.1])
fake_wcs.wcs.crval = [0, -90]
fake_wcs.wcs.ctype = ["RA---TAN-SIP", "DEC--TAN-SIP"]
fake_wcs.wcs.crota = np.array([0.0, 0.0])

fake_data = FakeDataSet(10, 10, [0.0])
base_config = SearchConfiguration()
fake_work = WorkUnit(im_stack=fake_data.stack, config=base_config, wcs=fake_wcs)
fake_ecliptic = fake_work.compute_ecliptic_angle()
self.assertGreater(fake_ecliptic, 1.0)

# Test we can create a trajectory generator with optional keyword parameters.
config3 = {
"name": "EclipticCenteredSearch",
"angles": [0.0, 45.0, 2],
"velocities": [0.0, 1.0, 2],
"angle_units": "degrees",
"force_ecliptic": None,
"given_ecliptic": None,
}
gen3 = create_trajectory_generator(config3, computed_ecliptic_angle=45.0)

# Without a given ecliptic, we use the WCS.
gen3 = create_trajectory_generator(config3, work_unit=fake_work)
self.assertTrue(type(gen3) is EclipticCenteredSearch)
self.assertAlmostEqual(gen3.ecliptic_angle, np.pi / 4.0)
self.assertEqual(gen3.min_ang, np.pi / 4.0)
self.assertEqual(gen3.max_ang, np.pi / 2.0)
self.assertAlmostEqual(gen3.ecliptic_angle, fake_ecliptic)
self.assertAlmostEqual(gen3.min_ang, fake_ecliptic)
self.assertAlmostEqual(gen3.max_ang, fake_ecliptic + np.pi / 4.0)

# The given_ecliptic has priority over the fake WCS.
config3["given_ecliptic"] = 0.0
gen4 = create_trajectory_generator(config3, computed_ecliptic_angle=45.0)
gen4 = create_trajectory_generator(config3, work_unit=fake_work)
self.assertAlmostEqual(gen4.ecliptic_angle, 0.0)
self.assertEqual(gen4.min_ang, 0.0)
self.assertEqual(gen4.max_ang, np.pi / 4.0)
Expand Down

0 comments on commit c87c1ef

Please sign in to comment.