diff --git a/src/kbmod/reprojection.py b/src/kbmod/reprojection.py index 6d7facc33..2c2fc68c3 100644 --- a/src/kbmod/reprojection.py +++ b/src/kbmod/reprojection.py @@ -94,9 +94,8 @@ def reproject_work_unit( The WCS to reproject all the images into. frame : `str` The WCS frame of reference to use when reprojecting. - Can either be 'original' or 'ebd' to specify whether to - use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs - respectively. + Can either be 'original' or 'ebd' to specify which WCS to access + from the WorkUnit. parallelize : `bool` If True, use multiprocessing to reproject the images in parallel. Default is True. @@ -175,9 +174,8 @@ def _reproject_work_unit( The WCS to reproject all the images into. frame : `str` The WCS frame of reference to use when reprojecting. - Can either be 'original' or 'ebd' to specify whether to - use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs - respectively. + Can either be 'original' or 'ebd' to specify which WCS to access + from the WorkUnit. write_output : `bool` Whether or not to write the reprojection results out as a sharded `WorkUnit`. directory : `str` @@ -195,6 +193,14 @@ def _reproject_work_unit( images = work_unit.im_stack.get_images() unique_obstimes, unique_obstime_indices = work_unit.get_unique_obstimes_and_indices() + # Create a list of the correct WCS. We do this extraction once and reuse for all images. 
+ if frame == "original": + wcs_list = work_unit.get_constituent_meta("original_wcs") + elif frame == "ebd": + wcs_list = work_unit.get_constituent_meta("ebd_wcs") + else: + raise ValueError("Invalid projection frame provided.") + stack = ImageStack() for obstime_index, o_i in tqdm( enumerate(zip(unique_obstimes, unique_obstime_indices)), @@ -214,13 +220,7 @@ def _reproject_work_unit( variance = image.get_variance() mask = image.get_mask() - if frame == "original": - original_wcs = work_unit.get_wcs(index) - elif frame == "ebd": - original_wcs = work_unit._per_image_ebd_wcs[index] - else: - raise ValueError("Invalid projection frame provided.") - + original_wcs = wcs_list[index] if original_wcs is None: raise ValueError(f"No WCS provided for index {index}") @@ -295,11 +295,11 @@ def _reproject_work_unit( im_stack=stack, config=work_unit.config, wcs=common_wcs, - constituent_images=work_unit.constituent_images, + constituent_images=work_unit.get_constituent_meta("data_loc"), per_image_wcs=work_unit._per_image_wcs, - per_image_ebd_wcs=work_unit._per_image_ebd_wcs, + per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"), per_image_indices=unique_obstime_indices, - geocentric_distances=work_unit.geocentric_distances, + geocentric_distances=work_unit.get_constituent_meta("geocentric_distance"), reprojected=True, ) @@ -328,9 +328,8 @@ def _reproject_work_unit_in_parallel( The WCS to reproject all the images into. frame : `str` The WCS frame of reference to use when reprojecting. - Can either be 'original' or 'ebd' to specify whether to - use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs - respectively. + Can either be 'original' or 'ebd' to specify which WCS to access + from the WorkUnit. max_parallel_processes : `int` The maximum number of parallel processes to use when reprojecting. Default is 8. 
For more see `concurrent.futures.ProcessPoolExecutor` in @@ -452,11 +451,11 @@ def _reproject_work_unit_in_parallel( im_stack=stack, config=work_unit.config, wcs=common_wcs, - constituent_images=work_unit.constituent_images, + constituent_images=work_unit.get_constituent_meta("data_loc"), per_image_wcs=work_unit._per_image_wcs, - per_image_ebd_wcs=work_unit._per_image_ebd_wcs, + per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"), per_image_indices=unique_obstimes_indices, - geocentric_distances=work_unit.geocentric_distances, + geocentric_distances=work_unit.get_constituent_meta("geocentric_distance"), reprojected=True, ) @@ -492,9 +491,8 @@ def reproject_lazy_work_unit( shards). frame : `str` The WCS frame of reference to use when reprojecting. - Can either be 'original' or 'ebd' to specify whether to - use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs - respectively. + Can either be 'original' or 'ebd' to specify which WCS to access + from the WorkUnit. max_parallel_processes : `int` The maximum number of parallel processes to use when reprojecting. Default is 8. For more see `concurrent.futures.ProcessPoolExecutor` in @@ -572,9 +570,8 @@ def _validate_original_wcs(work_unit, indices, frame="original"): The indices to be validated in work_unit. frame : `str` The WCS frame of reference to use when reprojecting. - Can either be 'original' or 'ebd' to specify whether to - use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs - respectively. + Can either be 'original' or 'ebd' to specify which WCS to access + from the WorkUnit. 
Returns ------- @@ -590,7 +587,7 @@ def _validate_original_wcs(work_unit, indices, frame="original"): if frame == "original": original_wcs = [work_unit.get_wcs(i) for i in indices] elif frame == "ebd": - original_wcs = [work_unit._per_image_ebd_wcs[i] for i in indices] + original_wcs = [work_unit.get_constituent_meta("ebd_wcs")[i] for i in indices] else: raise ValueError("Invalid projection frame provided.") diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py index 251a51ac1..f0320624c 100644 --- a/src/kbmod/work_unit.py +++ b/src/kbmod/work_unit.py @@ -4,6 +4,7 @@ from astropy.coordinates import SkyCoord, EarthLocation from astropy.io import fits +from astropy.table import Table from astropy.time import Time from astropy.utils.exceptions import AstropyWarning from astropy.wcs.utils import skycoord_to_pixel @@ -22,8 +23,6 @@ calc_ecliptic_angle, extract_wcs_from_hdu_header, wcs_fits_equal, - wcs_from_dict, - wcs_to_dict, ) @@ -40,6 +39,45 @@ class WorkUnit: Attributes ---------- + im_stack : `kbmod.search.ImageStack` + The image data for the KBMOD run. + config : `kbmod.configuration.SearchConfiguration` + The configuration for the KBMOD run. + n_constituents : `int` + The number of original images making up the data in this WorkUnit. This might be + different from the number of images stored in memory if the WorkUnit has been + reprojected. + org_img_meta : `astropy.table.Table` + The meta data for each constituent image. Includes columns: + * data_loc - the original location of the image + * ebd_wcs - Used to reproject the images into EBD space. + * geocentric_distance - The best fit geocentric distances used when creating + the per image EBD WCS. + * original_wcs - The original WCS of the image. + wcs : `astropy.wcs.WCS` + A global WCS for all images in the WorkUnit. Only exists + if all images have been projected to same pixel space. + per_image_wcs : `list` + A list with one WCS for each image in the WorkUnit. 
Used for when + the images have *not* been standardized to the same pixel space. + heliocentric_distance : `float` + The heliocentric distance that was used when creating the `per_image_ebd_wcs`. + reprojected : `bool` + Whether or not the WorkUnit image data has been reprojected. + per_image_indices : `list` of `list` + A list of lists containing the indicies of `constituent_images` at each layer + of the `ImageStack`. Used for finding corresponding original images when we + stitch images together during reprojection. + lazy : `bool` + Whether or not to load the image data for the `WorkUnit`. + file_paths : `list[str]` + The paths for the shard files, only created if the `WorkUnit` is loaded + in lazy mode. + obstimes : `list[float]` + The MJD obstimes of the images. + + Parameters + ---------- im_stack : `kbmod.search.ImageStack` The image data for the KBMOD run. config : `kbmod.configuration.SearchConfiguration` @@ -98,23 +136,29 @@ def __init__( self.file_paths = file_paths self._obstimes = obstimes - # Handle WCS input. If both the global and per-image WCS are provided, - # ensure they are consistent. - self.wcs = wcs + # Determine the number of constituent images. If we are given a list of constituent_images, + # use that. Otherwise use the size of the image stack. if constituent_images is None: - n_constituents = im_stack.img_count() - self.constituent_images = [None] * n_constituents + self.n_constituents = im_stack.img_count() else: - n_constituents = len(constituent_images) - self.constituent_images = constituent_images + self.n_constituents = len(constituent_images) + # Track the meta data for each constituent image in the WorkUnit. For the original + # WCS, we track the per-image WCS if it is provided and otherwise the global WCS. + self.org_img_meta = Table() + self.add_org_img_meta_data("data_loc", constituent_images) + self.add_org_img_meta_data("original_wcs", per_image_wcs, default=wcs) + + # Handle WCS input. 
If both the global and per-image WCS are provided, + # ensure they are consistent. + self.wcs = wcs if per_image_wcs is None: - self._per_image_wcs = [None] * n_constituents + self._per_image_wcs = [None] * self.n_constituents if self.wcs is None and per_image_ebd_wcs is None: warnings.warn("No WCS provided.", Warning) else: - if len(per_image_wcs) != n_constituents: - raise ValueError(f"Incorrect number of WCS provided. Expected {n_constituents}") + if len(per_image_wcs) != self.n_constituents: + raise ValueError(f"Incorrect number of WCS provided. Expected {self.n_constituents}") self._per_image_wcs = per_image_wcs # Check if all the per-image WCS are None. This can happen during a load. @@ -127,24 +171,18 @@ def __init__( self.wcs = self._per_image_wcs[0] self._per_image_wcs = [None] * im_stack.img_count() - # TODO: Refactor all of this code to make it cleaner - - if per_image_ebd_wcs is None: - self._per_image_ebd_wcs = [None] * n_constituents - else: - if len(per_image_ebd_wcs) != n_constituents: - raise ValueError(f"Incorrect number of EBD WCS provided. Expected {n_constituents}") - self._per_image_ebd_wcs = per_image_ebd_wcs - - if geocentric_distances is None: - self.geocentric_distances = [None] * n_constituents - else: - self.geocentric_distances = geocentric_distances - - self.heliocentric_distance = heliocentric_distance + # Add the meta data needed for reprojection, including: the reprojected WCS, the geocentric + # distances, and each images indices in the original constituent images. self.reprojected = reprojected + self.heliocentric_distance = heliocentric_distance + self.add_org_img_meta_data("geocentric_distance", geocentric_distances) + self.add_org_img_meta_data("ebd_wcs", per_image_ebd_wcs) + + # If we have mosaicked images, each image in the stack could link back + # to more than one constituents image. Build a mapping of image stack index + # to needed original image indices. 
if per_image_indices is None: - self._per_image_indices = [[i] for i in range(len(self.constituent_images))] + self._per_image_indices = [[i] for i in range(self.n_constituents)] else: self._per_image_indices = per_image_indices @@ -152,6 +190,33 @@ def __len__(self): """Returns the size of the WorkUnit in number of images.""" return self.im_stack.img_count() + def get_constituent_meta(self, column): + """Get the meta data values of a given column for all the constituent images.""" + return list(self.org_img_meta[column].data) + + def add_org_img_meta_data(self, column, data, default=None): + """Add a column of meta data for the constituent images. Adds a column of all + default values if data is None and the column does not already exist. + + Parameters + ---------- + column : `str` + The name of the meta data column. + data : list-like + The data for each constituent. If None then uses None for + each column. + """ + if data is None: + if column not in self.org_img_meta.colnames: + self.org_img_meta[column] = [default] * self.n_constituents + elif len(data) == self.n_constituents: + self.org_img_meta[column] = data + else: + raise ValueError( + f"Data mismatch size for WorkUnit metadata {column}. " + f"Expected {self.n_constituents} but found {len(data)}." + ) + def has_common_wcs(self): """Returns whether the WorkUnit has a common WCS for all images.""" return self.wcs is not None @@ -686,17 +751,16 @@ def metadata_to_primary_header(self, include_wcs=True): hdul = fits.HDUList() pri = fits.PrimaryHDU() pri.header["NUMIMG"] = self.get_num_images() + pri.header["NCON"] = self.n_constituents pri.header["REPRJCTD"] = self.reprojected pri.header["HELIO"] = self.heliocentric_distance - for i in range(len(self.constituent_images)): - pri.header[f"GEO_{i}"] = self.geocentric_distances[i] + for i in range(self.n_constituents): + pri.header[f"GEO_{i}"] = self.org_img_meta["geocentric_distance"][i] # If the global WCS exists, append the corresponding keys. 
if self.wcs is not None: append_wcs_to_hdu_header(self.wcs, pri.header) - pri.header["NCON"] = len(self.constituent_images) - hdul.append(pri) meta_hdu = fits.BinTableHDU() @@ -713,17 +777,17 @@ def metadata_to_primary_header(self, include_wcs=True): return hdul def append_all_wcs(self, hdul): - """Append the `_per_image_wcs` and - `_per_image_ebd_wcs` elements to a header. + """Append all the original WCS and EBD WCS to a header. Parameters ---------- hdul : `astropy.io.fits.HDUList` The HDU list. """ - n_constituents = len(self.constituent_images) - for i in range(n_constituents): - img_location = self.constituent_images[i] + all_ebd_wcs = self.get_constituent_meta("ebd_wcs") + + for i in range(self.n_constituents): + img_location = self.org_img_meta["data_loc"][i] orig_wcs = self._per_image_wcs[i] wcs_hdu = fits.TableHDU() @@ -732,9 +796,8 @@ def append_all_wcs(self, hdul): wcs_hdu.header["ILOC"] = img_location hdul.append(wcs_hdu) - im_ebd_wcs = self._per_image_ebd_wcs[i] ebd_hdu = fits.TableHDU() - append_wcs_to_hdu_header(im_ebd_wcs, ebd_hdu.header) + append_wcs_to_hdu_header(all_ebd_wcs[i], ebd_hdu.header) ebd_hdu.name = f"EBD_{i}" hdul.append(ebd_hdu) @@ -797,7 +860,7 @@ def image_positions_to_original_icrs( position_ebd_coords = radec_coords helio_dist = self.heliocentric_distance - geo_dists = [self.geocentric_distances[i] for i in image_indices] + geo_dists = [self.org_img_meta["geocentric_distance"][i] for i in image_indices] all_times = self.get_all_obstimes() obstimes = [all_times[i] for i in image_indices] @@ -824,7 +887,7 @@ def image_positions_to_original_icrs( coord = inverted_coords[i] pos = [] for j in inds: - con_image = self.constituent_images[j] + con_image = self.org_img_meta["data_loc"][j] con_wcs = self._per_image_wcs[j] height, width = con_wcs.array_shape x, y = skycoord_to_pixel(coord, con_wcs) diff --git a/tests/test_reprojection.py b/tests/test_reprojection.py index 6120a6f36..1135c7b50 100644 --- a/tests/test_reprojection.py +++ 
b/tests/test_reprojection.py @@ -75,7 +75,11 @@ def test_reproject(self): assert reprojected_wunit.wcs != None assert reprojected_wunit.im_stack.get_width() == 60 assert reprojected_wunit.im_stack.get_height() == 50 - assert reprojected_wunit.geocentric_distances == self.test_wunit.geocentric_distances + + test_dists = self.test_wunit.get_constituent_meta("geocentric_distance") + reproject_dists = reprojected_wunit.get_constituent_meta("geocentric_distance") + assert test_dists == reproject_dists + images = reprojected_wunit.im_stack.get_images() # will be 3 as opposed to the four in the original `WorkUnit`, diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py index b093e0153..abd61fa23 100644 --- a/tests/test_work_unit.py +++ b/tests/test_work_unit.py @@ -160,7 +160,13 @@ def test_save_and_load_fits(self): # Write out the existing WorkUnit with a different per-image wcs for all the entries. # work = WorkUnit(self.im_stack, self.config, None, self.diff_wcs) - work = WorkUnit(im_stack=self.im_stack, config=self.config, wcs=None, per_image_wcs=self.diff_wcs) + work = WorkUnit( + im_stack=self.im_stack, + config=self.config, + wcs=None, + per_image_wcs=self.diff_wcs, + constituent_images=self.constituent_images, + ) work.to_fits(file_path) self.assertTrue(Path(file_path).is_file()) @@ -208,6 +214,10 @@ def test_save_and_load_fits(self): self.assertEqual(work2.config["im_filepath"], "Here") self.assertEqual(work2.config["num_obs"], self.num_images) + # Check that we correctly retrieved the provenance information via “data_loc” + for index, value in enumerate(work2.org_img_meta["data_loc"]): + self.assertEqual(value, self.constituent_images[index]) + # We throw an error if we try to overwrite a file with overwrite=False self.assertRaises(FileExistsError, work.to_fits, file_path)