diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index 68c54b917..b7f2cd65e 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -187,7 +187,7 @@ def do_gpu_search(self, config, stack, trj_generator): keep = self.load_and_filter_results(search, config) return keep - def run_search(self, config, stack, trj_generator=None, computed_ecliptic=None): + def run_search(self, config, stack, trj_generator=None): """This function serves as the highest-level python interface for starting a KBMOD search given an ImageStack and SearchConfiguration. @@ -199,11 +199,7 @@ def run_search(self, config, stack, trj_generator=None, computed_ecliptic=None): The stack before the masks have been applied. Modified in-place. trj_generator : `TrajectoryGenerator`, optional The object to generate the candidate trajectories for each pixel. - If None uses the default KBMODv1 grid search - computed_ecliptic : `float`, optional - The computed ecliptic angle in the data from a WCS (if present). - Uses ``None`` if the information needed to compute the angle is not - available. + If None uses the default EclipticCenteredSearch Returns ------- @@ -222,7 +218,7 @@ def run_search(self, config, stack, trj_generator=None, computed_ecliptic=None): # Perform the actual search. if trj_generator is None: - trj_generator = create_trajectory_generator(config, computed_ecliptic_angle=computed_ecliptic) + trj_generator = create_trajectory_generator(config, work_unit=None) keep = self.do_gpu_search(config, stack, trj_generator) if config["do_stamp_filter"]: @@ -291,8 +287,7 @@ def run_search_from_work_unit(self, work): keep : `Results` The results. """ - # If there is a WCS compute the ecliptic angle from it. - computed_ecliptic = work.compute_ecliptic_angle() + trj_generator = create_trajectory_generator(work.config, work_unit=work) # Run the search. return self.run_search(work.config, work.im_stack) diff --git a/src/kbmod/trajectory_generator.py b/src/kbmod/trajectory_generator.py index 388198f57..c949b2e92 100644 --- a/src/kbmod/trajectory_generator.py +++ b/src/kbmod/trajectory_generator.py @@ -10,7 +10,7 @@ from kbmod.search import Trajectory -def create_trajectory_generator(config, **kwargs): +def create_trajectory_generator(config, work_unit=None, **kwargs): """Create a TrajectoryGenerator object given a dictionary of configuration options. The generator class is specified by the 'name' entry, which must exist and match the class name of one @@ -20,6 +20,9 @@ def create_trajectory_generator(config, **kwargs): ---------- config : `dict` or `SearchConfiguration` The dictionary of generator parameters. + work_unit : `WorkUnit`, optional + A WorkUnit to provide additional information about the data that + can be used to derive parameters that depend on the input. Returns ------- @@ -51,7 +54,7 @@ def create_trajectory_generator(config, **kwargs): params.update(kwargs) logger.debug(str(params)) - return TrajectoryGenerator.generators[name](**params) + return TrajectoryGenerator.generators[name](**params, work_unit=work_unit) class TrajectoryGenerator(abc.ABC): @@ -65,7 +68,7 @@ class TrajectoryGenerator(abc.ABC): generators = {} # A mapping of class name to class object for subclasses. - def __init__(self, **kwargs): + def __init__(self, work_unit=None, **kwargs): pass def __init_subclass__(cls, **kwargs): @@ -301,7 +304,7 @@ def generate(self, *args, **kwargs): class KBMODV1SearchConfig(KBMODV1Search): """Search a grid defined by velocities and angles in the format of the legacy configuration file.""" - def __init__(self, v_arr, ang_arr, average_angle=None, computed_ecliptic_angle=None, **kwargs): + def __init__(self, v_arr, ang_arr, average_angle=None, work_unit=None, **kwargs): """Create a class KBMODV1SearchConfig. Parameters @@ -314,20 +317,20 @@ def __init__(self, v_arr, ang_arr, average_angle=None, computed_ecliptic_angle=N (in radians), and the number of angles to try. average_angle : `float`, optional The central angle to search around. Should align with the ecliptic in most cases. - computed_ecliptic_angle : `float`, optional - An override for the computed ecliptic from a WCS (in the units defined in - ``angle_units``). This parameter is ignored if ``force_ecliptic`` is given. + work_unit : `WorkUnit`, optional + A WorkUnit to provide additional information about the data that + can be used to derive parameters that depend on the input. """ if len(v_arr) != 3: raise ValueError("KBMODV1SearchConfig requires v_arr to be length 3") if len(ang_arr) != 3: raise ValueError("KBMODV1SearchConfig requires ang_arr to be length 3") if average_angle is None: - if computed_ecliptic_angle is None: + if work_unit is None: raise ValueError( - "KBMODV1SearchConfig requires a valid average_angle or computed_ecliptic_angle." + "KBMODV1SearchConfig requires a valid average_angle or a WorkUnit with a WCS." ) - average_angle = computed_ecliptic_angle + average_angle = work_unit.compute_ecliptic_angle() ang_min = average_angle - ang_arr[0] ang_max = average_angle + ang_arr[1] @@ -359,7 +362,7 @@ def __init__( angles=[0.0, 0.0, 0], angle_units="radians", given_ecliptic=None, - computed_ecliptic_angle=None, + work_unit=None, **kwargs, ): """Create a class EclipticCenteredSearch. @@ -378,16 +381,23 @@ def __init__( given_ecliptic : `float`, optional An override for the ecliptic as given in the config (in the units defined in ``angle_units``). This angle takes precedence over ``computed_ecliptic``. - computed_ecliptic_angle : `float`, optional - An override for the computed ecliptic from a WCS (in the units defined in - ``angle_units``). This parameter is ignored if ``force_ecliptic`` is given. + work_unit : `WorkUnit`, optional + A WorkUnit to provide additional information about the data that + can be used to derive parameters that depend on the input. """ super().__init__(**kwargs) if given_ecliptic is not None: - ecliptic_angle = given_ecliptic - elif computed_ecliptic_angle is not None: - ecliptic_angle = computed_ecliptic_angle + if angle_units[:3] == "deg": + ecliptic_angle = given_ecliptic * (math.pi / 180.0) + elif angle_units[:3] == "rad": + ecliptic_angle = given_ecliptic + else: + raise ValueError(f"Unknown angular units {angle_units}") + elif work_unit is not None: + # compute_ecliptic_angle() always produces radians. + ecliptic_angle = work_unit.compute_ecliptic_angle() + print(f"Using WU = {ecliptic_angle}") else: logger = logging.getLogger(__name__) logger.warning("No ecliptic angle provided. Using 0.0.") @@ -405,10 +415,8 @@ def __init__( self.angles = angles self.ecliptic_angle = ecliptic_angle if angle_units[:3] == "deg": - deg_to_rad = math.pi / 180.0 - self.ecliptic_angle = deg_to_rad * self.ecliptic_angle - self.angles[0] = deg_to_rad * self.angles[0] - self.angles[1] = deg_to_rad * self.angles[1] + self.angles[0] = (math.pi / 180.0) * self.angles[0] + self.angles[1] = (math.pi / 180.0) * self.angles[1] elif angle_units[:3] != "rad": raise ValueError(f"Unknown angular units {angle_units}") diff --git a/tests/test_trajectory_generator.py b/tests/test_trajectory_generator.py index 34b16d6c7..b20e13815 100644 --- a/tests/test_trajectory_generator.py +++ b/tests/test_trajectory_generator.py @@ -1,7 +1,10 @@ import numpy as np import unittest +from astropy.wcs import WCS + from kbmod.configuration import SearchConfiguration +from kbmod.fake_data.fake_data_creator import FakeDataSet from kbmod.trajectory_generator import ( KBMODV1Search, KBMODV1SearchConfig, @@ -11,6 +14,7 @@ VelocityGridSearch, create_trajectory_generator, ) +from kbmod.work_unit import WorkUnit class test_trajectory_generator(unittest.TestCase): @@ -97,7 +101,7 @@ def test_EclipticCenteredSearch(self): def test_KBMODV1SearchConfig(self): # Note that KBMOD v1's search will never include the upper bound of angle or velocity. - gen = KBMODV1SearchConfig([0.0, 3.0, 3], [0.25, 0.25, 2], 0.0) + gen = KBMODV1SearchConfig([0.0, 3.0, 3], [0.25, 0.25, 2], average_angle=0.0) expected_x = [0.0, 0.9689, 1.9378, 0.0, 1.0, 2.0] expected_y = [0.0, -0.247, -0.4948, 0.0, 0.0, 0.0] @@ -154,22 +158,39 @@ def test_create_trajectory_generator(self): self.assertEqual(gen2.vx, 1) self.assertEqual(gen2.vy, 2) + # Create a fake work unit with one image and a WCS with a non-zero ecliptic angle. + fake_wcs = WCS(naxis=2) + fake_wcs.wcs.crpix = [0.0, 0.0] + fake_wcs.wcs.cdelt = np.array([-0.1, 0.1]) + fake_wcs.wcs.crval = [0, -90] + fake_wcs.wcs.ctype = ["RA---TAN-SIP", "DEC--TAN-SIP"] + fake_wcs.wcs.crota = np.array([0.0, 0.0]) + + fake_data = FakeDataSet(10, 10, [0.0]) + base_config = SearchConfiguration() + fake_work = WorkUnit(im_stack=fake_data.stack, config=base_config, wcs=fake_wcs) + fake_ecliptic = fake_work.compute_ecliptic_angle() + self.assertGreater(fake_ecliptic, 1.0) + # Test we can create a trajectory generator with optional keyword parameters. config3 = { "name": "EclipticCenteredSearch", "angles": [0.0, 45.0, 2], "velocities": [0.0, 1.0, 2], "angle_units": "degrees", - "force_ecliptic": None, + "given_ecliptic": None, } - gen3 = create_trajectory_generator(config3, computed_ecliptic_angle=45.0) + + # Without a given ecliptic, we use the WCS. + gen3 = create_trajectory_generator(config3, work_unit=fake_work) self.assertTrue(type(gen3) is EclipticCenteredSearch) - self.assertAlmostEqual(gen3.ecliptic_angle, np.pi / 4.0) - self.assertEqual(gen3.min_ang, np.pi / 4.0) - self.assertEqual(gen3.max_ang, np.pi / 2.0) + self.assertAlmostEqual(gen3.ecliptic_angle, fake_ecliptic) + self.assertAlmostEqual(gen3.min_ang, fake_ecliptic) + self.assertAlmostEqual(gen3.max_ang, fake_ecliptic + np.pi / 4.0) + # The given_ecliptic has priority over the fake WCS. config3["given_ecliptic"] = 0.0 - gen4 = create_trajectory_generator(config3, computed_ecliptic_angle=45.0) + gen4 = create_trajectory_generator(config3, work_unit=fake_work) self.assertAlmostEqual(gen4.ecliptic_angle, 0.0) self.assertEqual(gen4.min_ang, 0.0) self.assertEqual(gen4.max_ang, np.pi / 4.0)