diff --git a/src/kbmod/reprojection_utils.py b/src/kbmod/reprojection_utils.py index 0f23bfdf8..a1c79b856 100644 --- a/src/kbmod/reprojection_utils.py +++ b/src/kbmod/reprojection_utils.py @@ -28,8 +28,8 @@ def correct_parallax( use_bounds=False, ): """Calculate the parallax corrected postions for a given object at a given - time, observation location on Earth, and user defined distance from th - e Sun. + time, observation location on Earth, and user defined distance from the Sun. + By default, this function will use the geometric solution for objects beyond 1au. If the distance is less than 1au, the function will use the scipy minimizer to find the best geocentric distance. @@ -66,7 +66,6 @@ def correct_parallax( ICRS, and the best fit geocentric distance (float). """ - if use_minimizer or heliocentric_distance < 1.02: return correct_parallax_with_minimizer( coord, obstime, point_on_earth, heliocentric_distance, geocentric_distance, method, use_bounds diff --git a/src/kbmod/results.py b/src/kbmod/results.py index e35b17307..1931bf56d 100644 --- a/src/kbmod/results.py +++ b/src/kbmod/results.py @@ -12,6 +12,7 @@ from kbmod.trajectory_utils import trajectory_from_np_object from kbmod.search import Trajectory +from kbmod.wcs_utils import deserialize_wcs, serialize_wcs logger = logging.getLogger(__name__) @@ -28,6 +29,9 @@ class Results: ---------- table : `astropy.table.Table` The stored results data. + wcs : `astropy.wcs.WCS` + A global WCS for all the results. This is optional and primarily used when saving + the results to a file so as to preserve the WCS for future analysis. track_filtered : `bool` Whether to track (save) the filtered trajectories. This will use more memory and is recommended only for analysis. @@ -53,16 +57,21 @@ class Results: ] _required_col_names = set([rq_col[0] for rq_col in required_cols]) - def __init__(self, data=None, track_filtered=False): + def __init__(self, data=None, track_filtered=False, wcs=None): """Create a ResultTable class. Parameters ---------- data : `dict`, `astropy.table.Table` + The data for the results table. track_filtered : `bool` Whether to track (save) the filtered trajectories. This will use more memory and is recommended only for analysis. + wcs : `astropy.wcs.WCS`, optional + A gloabl WCS for the results. """ + self.wcs = wcs + # Set up information to track which row is filtered at which round. self.track_filtered = track_filtered self.filtered = {} @@ -179,7 +188,14 @@ def read_table(cls, filename, track_filtered=False): if not Path(filename).is_file(): raise FileNotFoundError(f"File {filename} not found.") data = Table.read(filename) - return Results(data, track_filtered=track_filtered) + + # Check if we have stored a global WCS. + if "wcs" in data.meta: + wcs = deserialize_wcs(data.meta["wcs"]) + else: + wcs = None + + return Results(data, track_filtered=track_filtered, wcs=wcs) def remove_column(self, colname): """Remove a column from the results table. @@ -600,7 +616,7 @@ def revert_filter(self, label=None, add_column=None): return self - def write_table(self, filename, overwrite=True, cols_to_drop=[]): + def write_table(self, filename, overwrite=True, cols_to_drop=()): """Write the unfiltered results to a single (ecsv) file. Parameter @@ -609,26 +625,28 @@ def write_table(self, filename, overwrite=True, cols_to_drop=[]): The name of the result file. overwrite : `bool` Overwrite the file if it already exists. [default: True] - cols_to_drop : `list` - A list of columns to drop (to save space). [default: []] + cols_to_drop : `tuple` + A tuple of columns to drop (to save space). [default: ()] """ logger.info(f"Saving results to {filename}") - if len(cols_to_drop) > 0: - # Make a copy so we can modify the table - write_table = self.table.copy() + # Make a copy so we can modify the table + write_table = self.table.copy() - for col in cols_to_drop: - if col in write_table.colnames: - if col in self._required_col_names: - logger.debug(f"Unable to drop required column {col} for write.") - else: - write_table.remove_column(col) + # Drop the columns we need to drop. + for col in cols_to_drop: + if col in write_table.colnames: + if col in self._required_col_names: + logger.debug(f"Unable to drop required column {col} for write.") + else: + write_table.remove_column(col) - # Write out the table. - write_table.write(filename, overwrite=overwrite) - else: - self.table.write(filename, overwrite=overwrite) + # Add global meta data that we can retrieve. + if self.wcs is not None: + write_table.meta["wcs"] = serialize_wcs(self.wcs) + + # Write out the table. + write_table.write(filename, overwrite=overwrite) def write_trajectory_file(self, filename, overwrite=True): """Save the trajectories to a numpy file. @@ -693,6 +711,8 @@ def write_column(self, colname, filename): logger.info(f"Writing {colname} column data to {filename}") if colname not in self.table.colnames: raise KeyError(f"Column {colname} missing from data.") + + # Save the column. data = np.array(self.table[colname]) np.save(filename, data, allow_pickle=False) @@ -722,6 +742,7 @@ def load_column(self, filename, colname): raise ValueError( f"Error loading {filename}: expected {len(self.table)} entries, but found {len(data)}." ) + self.table[colname] = data def write_filtered_stats(self, filename): diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index 8c243d637..1ffa40a60 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -12,6 +12,7 @@ from .results import Results from .trajectory_generator import create_trajectory_generator +from .wcs_utils import wcs_to_dict from .work_unit import WorkUnit @@ -186,7 +187,7 @@ def do_gpu_search(self, config, stack, trj_generator): keep = self.load_and_filter_results(search, config) return keep - def run_search(self, config, stack, trj_generator=None): + def run_search(self, config, stack, trj_generator=None, wcs=None): """This function serves as the highest-level python interface for starting a KBMOD search given an ImageStack and SearchConfiguration. @@ -199,6 +200,8 @@ def run_search(self, config, stack, trj_generator=None): trj_generator : `TrajectoryGenerator`, optional The object to generate the candidate trajectories for each pixel. If None uses the default EclipticCenteredSearch + wcs : `astropy.wcs.WCS`, optional + A global WCS for all images in the search. Returns ------- @@ -247,6 +250,9 @@ def run_search(self, config, stack, trj_generator=None): if config["save_all_stamps"]: append_all_stamps(keep, stack, config["stamp_radius"]) + # Append the WCS information if it is provided. This will be saved with the results. + keep.table.wcs = wcs + logger.info(f"Found {len(keep)} potential trajectories.") # Save the results in as an ecsv file and/or a legacy text file. @@ -280,4 +286,4 @@ def run_search_from_work_unit(self, work): trj_generator = create_trajectory_generator(work.config, work_unit=work) # Run the search. - return self.run_search(work.config, work.im_stack, trj_generator=trj_generator) + return self.run_search(work.config, work.im_stack, trj_generator=trj_generator, wcs=work.wcs) diff --git a/src/kbmod/wcs_utils.py b/src/kbmod/wcs_utils.py index 60656e16f..890e6aaae 100644 --- a/src/kbmod/wcs_utils.py +++ b/src/kbmod/wcs_utils.py @@ -3,6 +3,7 @@ import astropy.coordinates import astropy.units import astropy.wcs +import json import numpy @@ -436,14 +437,52 @@ def wcs_to_dict(wcs): result : `dict` A dictionary containing the WCS header information. """ - result = {} - if wcs is not None: - wcs_header = wcs.to_header() - for key in wcs_header: - result[key] = wcs_header[key] + result = dict(wcs.to_header(relax=True)) + if wcs.pixel_shape is not None: + header["NAXIS1"], header["NAXIS2"] = wcs.pixel_shape + elif wcs.array_shape is not None: + header["NAXIS2"], header["NAXIS1"] = wcs.array_shape return result +def serialize_wcs(wcs): + """Convert a WCS into a JSON string. + + Parameters + ---------- + wcs : `astropy.wcs.WCS` + The WCS to convert. + + Returns + ------- + wcs_str : `str` + The serialized WCS. + """ + # Since AstroPy's WCS does not output NAXIS, we need to manually add those. + header = wcs.to_header(relax=True) + header["NAXIS1"], header["NAXIS2"] = wcs.pixel_shape + return json.dumps(dict(header)) + + +def deserialize_wcs(wcs_str): + """Convert a JSON string into a WCS object. + + Parameters + ---------- + wcs_str : `str` + The serialized WCS. + + Returns + ------- + wcs : `astropy.wcs.WCS` + The resulting WCS. + """ + wcs_dict = json.loads(wcs_str) + wcs = astropy.wcs.WCS(wcs_dict) + wcs.pixel_shape = (wcs_dict["NAXIS1"], wcs_dict["NAXIS2"]) + return wcs + + def make_fake_wcs_info(center_ra, center_dec, height, width, deg_per_pixel=None): """Create a fake WCS dictionary given basic information. This is not a realistic WCS in terms of astronomy, but can provide a place holder for many tests. @@ -548,4 +587,10 @@ def wcs_fits_equal(wcs_a, wcs_b): if header_a[key] != header_b[key]: return False + # Check that we correctly kept the shape of the matrix. + if wcs_a.array_shape != wcs_b.array_shape: + return False + if wcs_a.pixel_shape != wcs_b.pixel_shape: + return False + return True diff --git a/tests/test_results.py b/tests/test_results.py index c66948a9a..7af6cf712 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -10,6 +10,7 @@ from kbmod.results import Results from kbmod.search import Trajectory +from kbmod.wcs_utils import make_fake_wcs, wcs_fits_equal class test_results(unittest.TestCase): @@ -388,6 +389,10 @@ def test_to_from_table_file(self): table.table["other"] = [i for i in range(max_save)] self.assertEqual(len(table), max_save) + # Create a fake WCS to use for serialization tests. + fake_wcs = make_fake_wcs(25.0, -7.5, 800, 600, deg_per_pixel=0.01) + table.wcs = fake_wcs + # Test read/write to file. with tempfile.TemporaryDirectory() as dir_name: file_path = os.path.join(dir_name, "results.ecsv") @@ -402,6 +407,11 @@ def test_to_from_table_file(self): for col in ["x", "y", "vx", "vy", "likelihood", "flux", "obs_count", "other"]: self.assertTrue(np.allclose(table[col], table2[col])) + # Check that we reloaded the WCS's, including the correct shape. + self.assertIsNotNone(table2.wcs) + self.assertTrue(wcs_fits_equal(table2.wcs, fake_wcs)) + self.assertEqual(table2.wcs.pixel_shape, fake_wcs.pixel_shape) + # Cannot overwrite with it set to False with self.assertRaises(OSError): table.write_table(file_path, overwrite=False, cols_to_drop=["other"]) diff --git a/tests/test_wcs_utils.py b/tests/test_wcs_utils.py index 224920023..043d4e596 100644 --- a/tests/test_wcs_utils.py +++ b/tests/test_wcs_utils.py @@ -57,6 +57,16 @@ def test_wcs_to_dict(self): self.assertTrue(key in new_dict) self.assertAlmostEqual(new_dict[key], self.header_dict[key]) + def test_serialization(self): + self.wcs.pixel_shape = (200, 250) + wcs_str = serialize_wcs(self.wcs) + self.assertTrue(isinstance(wcs_str, str)) + + wcs2 = deserialize_wcs(wcs_str) + self.assertTrue(isinstance(wcs2, WCS)) + self.assertEqual(self.wcs.pixel_shape, wcs2.pixel_shape) + self.assertTrue(wcs_fits_equal(self.wcs, wcs2)) + def test_append_wcs_to_hdu_header(self): for use_dictionary in [True, False]: if use_dictionary: