From e5e6e701bad0a198ddaa8c0584b86eccb4d0fff0 Mon Sep 17 00:00:00 2001
From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com>
Date: Fri, 15 Nov 2024 09:42:35 -0500
Subject: [PATCH] Revert "Bug fixes"

This reverts commit 236e2256aa502a8421bd1d1c3d2771e424d9de17.
---
 src/kbmod/image_collection.py |   7 +-
 src/kbmod/reprojection.py     |  29 +--
 src/kbmod/work_unit.py        | 461 ++++++++++++++++------------------
 tests/test_reprojection.py    |   3 +
 tests/test_work_unit.py       | 140 ++++-------
 5 files changed, 288 insertions(+), 352 deletions(-)

diff --git a/src/kbmod/image_collection.py b/src/kbmod/image_collection.py
index 4e75155f5..dce44354d 100644
--- a/src/kbmod/image_collection.py
+++ b/src/kbmod/image_collection.py
@@ -888,10 +888,7 @@ def toWorkUnit(self, search_config=None, **kwargs):
         for std in self.get_standardizers(**kwargs):
             for img in std["std"].toLayeredImage():
                 layeredImages.append(img)
-        imgstack = ImageStack(layeredImages)
-        img_metadata = Table()
-        if None not in self.wcs:
-            img_metadata["per_image_wcs"] = list(self.wcs)
-        return WorkUnit(imgstack, search_config, org_image_meta=img_metadata)
+        imgstack = ImageStack(layeredImages)
+        if None not in self.wcs:
+            return WorkUnit(imgstack, search_config, per_image_wcs=list(self.wcs))
+        return WorkUnit(imgstack, search_config)
diff --git a/src/kbmod/reprojection.py b/src/kbmod/reprojection.py
index 49e4a0102..2c2fc68c3 100644
--- a/src/kbmod/reprojection.py
+++ b/src/kbmod/reprojection.py
@@ -118,9 +118,6 @@ def reproject_work_unit(
         A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the
         case where `write_output` is set to True.
     """
-    if work_unit.reprojected:
-        raise ValueError("Unable to reproject a reprojected WorkUnit.")
-
     show_progress = is_interactive() if show_progress is None else show_progress
     if (work_unit.lazy or write_output) and (directory is None or filename is None):
         raise ValueError("can't write output to sharded fits without directory and filename provided.")
@@ -198,7 +195,7 @@ def _reproject_work_unit(

     # Create a list of the correct WCS. We do this extraction once and reuse for all images.
     if frame == "original":
-        wcs_list = work_unit.get_constituent_meta("per_image_wcs")
+        wcs_list = work_unit.get_constituent_meta("original_wcs")
     elif frame == "ebd":
         wcs_list = work_unit.get_constituent_meta("ebd_wcs")
     else:
@@ -286,25 +283,24 @@ def _reproject_work_unit(
             stack.append_image(new_layered_image, force_move=True)

     if write_output:
-        # Create a copy of the WorkUnit to write the global metadata.
-        # We preserve the metadata for the constituent images.
         new_work_unit = copy(work_unit)
         new_work_unit._per_image_indices = unique_obstime_indices
         new_work_unit.wcs = common_wcs
         new_work_unit.reprojected = True

-        hdul = new_work_unit.metadata_to_hdul()
+        hdul = new_work_unit.metadata_to_primary_header()
         hdul.writeto(os.path.join(directory, filename))
     else:
-        # Create a new WorkUnit with the new ImageStack and global WCS.
-        # We preserve the metadata for the constituent images.
         new_wunit = WorkUnit(
             im_stack=stack,
             config=work_unit.config,
             wcs=common_wcs,
+            constituent_images=work_unit.get_constituent_meta("data_loc"),
+            per_image_wcs=work_unit._per_image_wcs,
+            per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"),
             per_image_indices=unique_obstime_indices,
+            geocentric_distances=work_unit.get_constituent_meta("geocentric_distance"),
             reprojected=True,
-            org_image_meta=work_unit.org_img_meta,
         )

         return new_wunit
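For context, the call pattern implied by the `reproject_work_unit` docstring above — a minimal, hypothetical sketch (assumes an already-loaded `WorkUnit` named `wu` and a target `astropy.wcs.WCS` named `common_wcs`; the keyword names are taken from the signatures in these hunks):

    from kbmod.reprojection import reproject_work_unit

    # In-memory path: returns a new WorkUnit resampled to the common WCS.
    new_wu = reproject_work_unit(wu, common_wcs, frame="original")

    # Write-to-disk path: returns None and writes sharded FITS instead, so
    # directory and filename are required (per the ValueError check above).
    reproject_work_unit(wu, common_wcs, write_output=True,
                        directory="/tmp/out", filename="wu.fits")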
@@ -429,7 +425,7 @@ def _reproject_work_unit_in_parallel(
         new_work_unit.wcs = common_wcs
         new_work_unit.reprojected = True

-        hdul = new_work_unit.metadata_to_hdul()
+        hdul = new_work_unit.metadata_to_primary_header()
         hdul.writeto(os.path.join(directory, filename))
     else:
         stack = ImageStack([])
@@ -450,15 +446,17 @@ def _reproject_work_unit_in_parallel(
         # sort by the time_stamp
         stack.sort_by_time()

-        # Add the imageStack to a new WorkUnit and return it. We preserve the metadata
-        # for the constituent images.
+        # Add the imageStack to a new WorkUnit and return it.
         new_wunit = WorkUnit(
             im_stack=stack,
             config=work_unit.config,
             wcs=common_wcs,
+            constituent_images=work_unit.get_constituent_meta("data_loc"),
+            per_image_wcs=work_unit._per_image_wcs,
+            per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"),
             per_image_indices=unique_obstimes_indices,
+            geocentric_distances=work_unit.get_constituent_meta("geocentric_distances"),
             reprojected=True,
-            org_image_meta=work_unit.org_img_meta,
         )
         return new_wunit
@@ -551,7 +549,6 @@ def reproject_lazy_work_unit(
             if not result.result():
                 raise RuntimeError("one or more jobs failed.")

-    # We use new metadata for the new images and the same metadata for the original images.
     new_work_unit = copy(work_unit)
     new_work_unit._per_image_indices = unique_obstimes_indices
     new_work_unit.wcs = common_wcs
@@ -594,8 +591,6 @@ def _validate_original_wcs(work_unit, indices, frame="original"):
     else:
         raise ValueError("Invalid projection frame provided.")

-    if len(original_wcs) == 0:
-        raise ValueError(f"No WCS found for frame {frame}")
     if np.any(original_wcs) is None:
         # find indices where the wcs is None
         bad_indices = np.where(original_wcs == None)
diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py
index 05bf542da..f0320624c 100644
--- a/src/kbmod/work_unit.py
+++ b/src/kbmod/work_unit.py
@@ -7,7 +7,6 @@
 from astropy.table import Table
 from astropy.time import Time
 from astropy.utils.exceptions import AstropyWarning
-from astropy.wcs import WCS
 from astropy.wcs.utils import skycoord_to_pixel
 import astropy.units as u

@@ -22,9 +21,7 @@
 from kbmod.wcs_utils import (
     append_wcs_to_hdu_header,
     calc_ecliptic_angle,
-    deserialize_wcs,
     extract_wcs_from_hdu_header,
-    serialize_wcs,
     wcs_fits_equal,
 )

@@ -46,9 +43,6 @@ class WorkUnit:
         The image data for the KBMOD run.
     config : `kbmod.configuration.SearchConfiguration`
         The configuration for the KBMOD run.
-    n_images : `int`
-        The number of images. This may differ from the length of the ImageStack due
-        to lazy loading.
     n_constituents : `int`
         The number of original images making up the data in this WorkUnit. This might be
         different from the number of images stored in memory if the WorkUnit has been
         reprojected.
     org_img_meta : `astropy.table.Table`
         The meta data for each constituent image. Includes columns:
         * data_loc - the original location of the image
         * ebd_wcs - Used to reproject the images into EBD space.
         * geocentric_distance - The best fit geocentric distances used when
           creating the per image EBD WCS.
-        * original_wcs - The original per-image WCS of the image.
+        * original_wcs - The original WCS of the image.
     wcs : `astropy.wcs.WCS`
         A global WCS for all images in the WorkUnit.
         Only exists if all images have been projected to same pixel space.
+    per_image_wcs : `list`
+        A list with one WCS for each image in the WorkUnit. Used for when
+        the images have *not* been standardized to the same pixel space.
     heliocentric_distance : `float`
         The heliocentric distance that was used when creating the `per_image_ebd_wcs`.
     reprojected : `bool`
         Whether or not the WorkUnit image data has been reprojected.
@@ -85,43 +82,53 @@ class WorkUnit:
         The image data for the KBMOD run.
     config : `kbmod.configuration.SearchConfiguration`
         The configuration for the KBMOD run.
-    num_images : `int`, optional
-        The number of images. This may differ from the length of the ImageStack due
-        to lazy loading.
-    wcs : `astropy.wcs.WCS`, optional
+    wcs : `astropy.wcs.WCS`
         A global WCS for all images in the WorkUnit. Only exists
         if all images have been projected to same pixel space.
-    reprojected : `bool`, optional
+    constituent_images : `list`
+        A list of strings with the original locations of images used
+        to construct the WorkUnit. This is necessary to maintain as metadata
+        because after reprojection we may stitch multiple images into one.
+    per_image_wcs : `list`
+        A list with one WCS for each image in the WorkUnit. Used for when
+        the images have *not* been standardized to the same pixel space.
+    per_image_ebd_wcs : `list`
+        A list with one WCS for each image in the WorkUnit. Used to reproject the images
+        into EBD space.
+    heliocentric_distance : `float`
+        The heliocentric distance that was used when creating the `per_image_ebd_wcs`.
+    geocentric_distances : `list`
+        The best fit geocentric distances used when creating the `per_image_ebd_wcs`.
+    reprojected : `bool`
         Whether or not the WorkUnit image data has been reprojected.
-    per_image_indices : `list` of `list`, optional
+    per_image_indices : `list` of `list`
         A list of lists containing the indices of `constituent_images` at each layer
         of the `ImageStack`. Used for finding corresponding original images when we
         stitch images together during reprojection.
-    heliocentric_distance : `float`, optional
-        The heliocentric distance that was used when creating the `per_image_ebd_wcs`.
-    lazy : `bool`, optional
+    lazy : `bool`
         Whether or not to load the image data for the `WorkUnit`.
-    file_paths : `list[str]`, optional
+    file_paths : `list[str]`
         The paths for the shard files, only created if the `WorkUnit` is loaded
         in lazy mode.
     obstimes : `list[float]`
         The MJD obstimes of the images.
-    org_image_meta : `dict` or `astropy.table.Table`, optional
-        A table of per-image data for the constituent images.
     """

     def __init__(
         self,
-        im_stack,
-        config,
+        im_stack=None,
+        config=None,
         wcs=None,
+        constituent_images=None,
+        per_image_wcs=None,
+        per_image_ebd_wcs=None,
+        heliocentric_distance=None,
+        geocentric_distances=None,
         reprojected=False,
         per_image_indices=None,
-        heliocentric_distance=None,
         lazy=False,
         file_paths=None,
         obstimes=None,
-        org_image_meta=None,
     ):
         self.im_stack = im_stack
         self.config = config
@@ -129,28 +136,47 @@ def __init__(
         self.file_paths = file_paths
         self._obstimes = obstimes

-        # Determine the number of constituent images. If we are given metadata for
-        # the constituent_images, use that. Otherwise use the size of the image stack.
-        if org_image_meta is None:
+        # Determine the number of constituent images. If we are given a list of
+        # constituent_images, use that. Otherwise use the size of the image stack.
+        if constituent_images is None:
             self.n_constituents = im_stack.img_count()
         else:
-            self.n_constituents = len(org_image_meta)
+            self.n_constituents = len(constituent_images)

-        # Track the metadata for each constituent image in the WorkUnit. If no constituent
-        # data is provided, this will create an empty array the same size as the original.
-        self.org_img_meta = create_image_metadata(self.n_constituents, data=org_image_meta)
+        # Track the metadata for each constituent image in the WorkUnit. For the original
+        # WCS, we track the per-image WCS if it is provided and otherwise the global WCS.
+        self.org_img_meta = Table()
+        self.add_org_img_meta_data("data_loc", constituent_images)
+        self.add_org_img_meta_data("original_wcs", per_image_wcs, default=wcs)

         # Handle WCS input. If both the global and per-image WCS are provided,
         # ensure they are consistent.
         self.wcs = wcs
-        if np.all(self.org_img_meta["per_image_wcs"] == None):
-            self.org_img_meta["per_image_wcs"] = np.full(self.n_constituents, self.wcs)
-        if np.any(self.org_img_meta["per_image_wcs"] == None):
-            warnings.warn("At least one image does not have a WCS.", Warning)
-
-        # Set the global metadata for reprojection.
+        if per_image_wcs is None:
+            self._per_image_wcs = [None] * self.n_constituents
+            if self.wcs is None and per_image_ebd_wcs is None:
+                warnings.warn("No WCS provided.", Warning)
+        else:
+            if len(per_image_wcs) != self.n_constituents:
+                raise ValueError(f"Incorrect number of WCS provided. Expected {self.n_constituents}")
+            self._per_image_wcs = per_image_wcs
+
+            # Check if all the per-image WCS are None. This can happen during a load.
+            all_none = self.per_image_wcs_all_match(None)
+            if self.wcs is None and all_none:
+                warnings.warn("No WCS provided.", Warning)
+
+            # See if we can compress the per-image WCS into a global one.
+            if self.wcs is None and not all_none and self.per_image_wcs_all_match(self._per_image_wcs[0]):
+                self.wcs = self._per_image_wcs[0]
+                self._per_image_wcs = [None] * im_stack.img_count()
+
+        # Add the metadata needed for reprojection, including: the reprojected WCS, the
+        # geocentric distances, and each image's indices in the original constituent images.
         self.reprojected = reprojected
         self.heliocentric_distance = heliocentric_distance
+        self.add_org_img_meta_data("geocentric_distance", geocentric_distances)
+        self.add_org_img_meta_data("ebd_wcs", per_image_ebd_wcs)

         # If we have mosaicked images, each image in the stack could link back
         # to more than one constituent image. Build a mapping of image stack index
@@ -160,36 +186,59 @@ def __init__(
         else:
             self._per_image_indices = per_image_indices

-        # Run some basic validity checks.
-        if self.reprojected and self.wcs is None:
-            raise ValueError("Global WCS required for reprojected data.")
-        for inds in self._per_image_indices:
-            if np.max(inds) >= self.n_constituents:
-                raise ValueError(
-                    f"Found pointer to constituent image {np.max(inds)} of {self.n_constituents}"
-                )
-
     def __len__(self):
         """Returns the size of the WorkUnit in number of images."""
         return self.im_stack.img_count()

-    def get_num_images(self):
-        return len(self._per_image_indices)
-
     def get_constituent_meta(self, column):
-        """Get the meta data values of a given column for all the constituent images.
+        """Get the meta data values of a given column for all the constituent images."""
+        return list(self.org_img_meta[column].data)
+
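A minimal sketch of how these constructor arguments and `get_constituent_meta` fit together — `im_stack` (an `ImageStack`), `config` (a `SearchConfiguration`), and `wcs_list` (one `astropy.wcs.WCS` per image) are assumed to exist, and the file names are hypothetical:

    from kbmod.work_unit import WorkUnit

    wu = WorkUnit(
        im_stack=im_stack,
        config=config,
        per_image_wcs=wcs_list,                       # one WCS per constituent image
        constituent_images=["one.fits", "two.fits"],  # provenance of each image
    )
    print(wu.get_constituent_meta("data_loc"))        # -> ["one.fits", "two.fits"]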
+    def add_org_img_meta_data(self, column, data, default=None):
+        """Add a column of meta data for the constituent images. If data is None
+        and the column does not already exist, adds a column of default values.

         Parameters
         ----------
         column : `str`
-            The column name to fetch.
+            The name of the meta data column.
+        data : list-like
+            The data for each constituent image. If None, the column is
+            filled with `default`.
+        """
+        if data is None:
+            if column not in self.org_img_meta.colnames:
+                self.org_img_meta[column] = [default] * self.n_constituents
+        elif len(data) == self.n_constituents:
+            self.org_img_meta[column] = data
+        else:
+            raise ValueError(
+                f"Data size mismatch for WorkUnit metadata {column}. "
+                f"Expected {self.n_constituents} but found {len(data)}."
+            )
+
+    def has_common_wcs(self):
+        """Returns whether the WorkUnit has a common WCS for all images."""
+        return self.wcs is not None
+
+    def per_image_wcs_all_match(self, target=None):
+        """Check if all the per-image WCS are the same as a given target value.
+
+        Parameters
+        ----------
+        target : `astropy.wcs.WCS`, optional
+            The WCS to which to compare the per-image WCS. If None, checks that
+            all of the per-image WCS are None.

         Returns
         -------
-        data : `list`
-            A list of the meta-data for each constituent image.
+        result : `bool`
+            A Boolean indicating whether all the per-image WCS match the target.
         """
-        return list(self.org_img_meta[column].data)
+        for current in self._per_image_wcs:
+            if not wcs_fits_equal(current, target):
+                return False
+        return True

     def get_wcs(self, img_num):
         """Return the WCS for a given image. Always prioritizes
@@ -204,12 +253,26 @@ def get_wcs(self, img_num):
         -------
         wcs : `astropy.wcs.WCS`
             The image's WCS if one exists. Otherwise None.
+
+        Raises
+        ------
+        IndexError if an invalid index is given.
         """
+        if img_num < 0 or img_num >= len(self._per_image_wcs):
+            raise IndexError(f"Invalid image number {img_num}")
+
+        # Extract the per-image WCS if one exists.
+        if self._per_image_wcs is not None and img_num < len(self._per_image_wcs):
+            per_img = self._per_image_wcs[img_num]
+        else:
+            per_img = None
+
         if self.wcs is not None:
+            if per_img is not None and not wcs_fits_equal(self.wcs, per_img):
+                warnings.warn("Both a global and per-image WCS given. Using global WCS.", Warning)
             return self.wcs
-        else:
-            # If there is no common WCS, use the original per-image one.
-            return self.org_img_meta["per_image_wcs"][img_num]
+
+        return per_img

     def get_pixel_coordinates(self, ra, dec, times=None):
         """Get the pixel coordinates for pairs of (RA, dec) coordinates. Uses the global
@@ -256,7 +319,7 @@ def get_pixel_coordinates(self, ra, dec, times=None):
             for i, index in enumerate(inds):
                 if index == -1:
                     raise ValueError(f"Unmatched time {times[i]}.")
-                current_wcs = self.org_img_meta["per_image_wcs"][index]
+                current_wcs = self._per_image_wcs[index]
                 curr_x, curr_y = current_wcs.world_to_pixel(
                     SkyCoord(ra=ra[i] * u.degree, dec=dec[i] * u.degree)
                 )
@@ -300,6 +363,9 @@ def get_unique_obstimes_and_indices(self):
         unique_indices = [list(np.where(all_obstimes == time)[0]) for time in unique_obstimes]
         return unique_obstimes, unique_indices

+    def get_num_images(self):
+        return len(self._per_image_indices)
+
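For context on the time-matching and grouping behavior above — a rough sketch in which `wu` is an assumed WorkUnit and the coordinate and time values are made up:

    import numpy as np

    # Map (RA, Dec) pairs to pixel coordinates, matching each entry against the
    # image obstimes; unmatched times raise a ValueError per the check above.
    x_pos, y_pos = wu.get_pixel_coordinates(
        ra=np.array([215.60, 215.61]),
        dec=np.array([-13.19, -13.20]),
        times=np.array([60253.1, 60253.2]),
    )

    # Group stacked images that share an observation time (e.g. mosaicked shards).
    obstimes, indices = wu.get_unique_obstimes_and_indices()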
     @classmethod
     def from_fits(cls, filename, show_progress=None):
         """Create a WorkUnit from a single FITS file.
@@ -337,22 +403,11 @@ def from_fits(cls, filename, show_progress=None):
             if num_layers < 5:
                 raise ValueError(f"WorkUnit file has too few extensions {len(hdul)}.")

+            # TODO - Read in provenance metadata from extension #1.
+
             # Read in the search parameters from the 'kbmod_config' extension.
             config = SearchConfiguration.from_hdu(hdul["kbmod_config"])

-            # Read the size and order information from the primary header.
-            num_images = hdul[0].header["NUMIMG"]
-            n_constituents = hdul[0].header["NCON"] if "NCON" in hdul[0].header else num_images
-            logger.info(f"Loading {num_images} images.")
-
-            # Read in the per-image metadata for the constituent images.
-            if "IMG_META" in hdul:
-                logger.debug("Reading original image metadata from IMG_META.")
-                hdu_meta = hdu_to_image_metadata_table(hdul["IMG_META"])
-            else:
-                hdu_meta = None
-            org_image_meta = create_image_metadata(n_constituents, data=hdu_meta)
-
             # Read in the global WCS from extension 0 if the information exists.
             # We filter the warning that the image dimension does not match the WCS dimension
             # since the primary header does not have an image.
@@ -360,18 +415,23 @@ def from_fits(cls, filename, show_progress=None):
                 warnings.simplefilter("ignore", AstropyWarning)
                 global_wcs = extract_wcs_from_hdu_header(hdul[0].header)

+            # Read the size and order information from the primary header.
+            num_images = hdul[0].header["NUMIMG"]
+            n_constituents = hdul[0].header["NCON"]
+            expected_num_images = (4 * num_images) + (2 * n_constituents) + 3
+            if len(hdul) != expected_num_images:
+                raise ValueError(f"WorkUnit wrong number of extensions. Expected {expected_num_images}.")
+            logger.info(f"Loading {num_images} images and {expected_num_images} total layers.")
+
             # Misc. reprojection metadata
             reprojected = hdul[0].header["REPRJCTD"]
             heliocentric_distance = hdul[0].header["HELIO"]
+            geocentric_distances = []
+            for i in range(num_images):
+                geocentric_distances.append(hdul[0].header[f"GEO_{i}"])

-            # If there are geocentric distances in the header information
-            # (legacy approach), read those in.
-            for i in range(n_constituents):
-                if f"GEO_{i}" in hdul[0].header:
-                    org_image_meta["geocentric_distance"][i] = hdul[0].header[f"GEO_{i}"]
-
-            # Read in all the image files.
             per_image_indices = []
+            # Read in all the image files.
             for i in tqdm(
                 range(num_images),
                 bar_format=_DEFAULT_WORKUNIT_TQDM_BAR,
@@ -392,39 +452,37 @@ def from_fits(cls, filename, show_progress=None):
                 # force_move destroys img object, but avoids a copy.
                 im_stack.append_image(img, force_move=True)

-                # Read the mapping of current image to constituent image from the header info.
-                # TODO: Serialize this into its own table.
                 n_indices = sci_hdu.header["NIND"]
                 sub_indices = []
                 for j in range(n_indices):
                     sub_indices.append(sci_hdu.header[f"IND_{j}"])
                 per_image_indices.append(sub_indices)

-            # Extract the per-image data from header information if needed. This happens
-            # when the WorkUnit was saved before metadata tables were saved as layers and
-            # all the information is in header values.
+            per_image_wcs = []
+            per_image_ebd_wcs = []
+            constituent_images = []
             for i in tqdm(
                 range(n_constituents),
                 bar_format=_DEFAULT_WORKUNIT_TQDM_BAR,
                 desc="Loading WCS",
                 disable=not show_progress,
             ):
-                if f"WCS_{i}" in hdul:
-                    wcs_header = hdul[f"WCS_{i}"].header
-                    org_image_meta["per_image_wcs"][i] = extract_wcs_from_hdu_header(wcs_header)
-                    if "ILOC" in wcs_header:
-                        org_image_meta["data_loc"][i] = wcs_header["ILOC"]
-                if f"EBD_{i}" in hdul:
-                    org_image_meta["ebd_wcs"][i] = extract_wcs_from_hdu_header(hdul[f"EBD_{i}"].header)
+                # Extract the per-image WCS if one exists.
+                per_image_wcs.append(extract_wcs_from_hdu_header(hdul[f"WCS_{i}"].header))
+                per_image_ebd_wcs.append(extract_wcs_from_hdu_header(hdul[f"EBD_{i}"].header))
+                constituent_images.append(hdul[f"WCS_{i}"].header["ILOC"])

         result = WorkUnit(
             im_stack=im_stack,
             config=config,
             wcs=global_wcs,
+            constituent_images=constituent_images,
+            per_image_wcs=per_image_wcs,
+            per_image_ebd_wcs=per_image_ebd_wcs,
             heliocentric_distance=heliocentric_distance,
+            geocentric_distances=geocentric_distances,
             reprojected=reprojected,
             per_image_indices=per_image_indices,
-            org_image_meta=org_image_meta,
         )
         return result
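As a worked check of the extension count used above: each image contributes four layers (science, variance, mask, and PSF, per `to_fits` below), each constituent image contributes two (`WCS_i` and `EBD_i`), and three layers are shared (primary header, `metadata`, and `kbmod_config`). For a hypothetical file with 10 images built from 12 constituent images:

    num_images, n_constituents = 10, 12  # hypothetical sizes
    expected_num_images = (4 * num_images) + (2 * n_constituents) + 3
    assert expected_num_images == 67     # 40 image layers + 24 WCS layers + 3 shared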
@@ -450,10 +508,8 @@ def to_fits(self, filename, overwrite=False):
         if Path(filename).is_file() and not overwrite:
             raise FileExistsError(f"WorkUnit file {filename} already exists.")

-        # Create an HDU list with the metadata layers, including all the WCS info.
-        hdul = self.metadata_to_hdul()
+        hdul = self.metadata_to_primary_header(include_wcs=False)

-        # Create each image layer.
         for i in range(self.im_stack.img_count()):
             layered = self.im_stack.get_single_image(i)
             obstime = layered.get_obstime()
@@ -482,6 +538,8 @@ def to_fits(self, filename, overwrite=False):
             psf_hdu.name = f"PSF_{i}"
             hdul.append(psf_hdu)

+        self.append_all_wcs(hdul)
+
         hdul.writeto(filename, overwrite=overwrite)

     def to_sharded_fits(self, filename, directory, overwrite=False):
@@ -495,6 +553,7 @@ def to_sharded_fits(self, filename, directory, overwrite=False):
         image index in front of the given filename, e.g.
         "0_filename.fits".

+
         Primary File:
             0 - Primary header with overall metadata
             1 or "metadata" - The data provenance metadata
@@ -556,8 +615,7 @@ def to_sharded_fits(self, filename, directory, overwrite=False):
                 sub_hdul.append(psf_hdu)
             sub_hdul.writeto(os.path.join(directory, f"{i}_{filename}"))

-        # Create a primary file with all of the metadata, including all the WCS info.
-        hdul = self.metadata_to_hdul()
+        hdul = self.metadata_to_primary_header(include_wcs=True)
         hdul.writeto(os.path.join(directory, filename), overwrite=overwrite)

     @classmethod
@@ -602,19 +660,6 @@ def from_sharded_fits(cls, filename, directory, lazy=False):
         with fits.open(os.path.join(directory, filename)) as primary:
             config = SearchConfiguration.from_hdu(primary["kbmod_config"])

-            # Read the size and order information from the primary header.
-            num_images = primary[0].header["NUMIMG"]
-            n_constituents = primary[0].header["NCON"] if "NCON" in primary[0].header else num_images
-            logger.info(f"Loading {num_images} images.")
-
-            # Read in the per-image metadata for the constituent images.
-            if "IMG_META" in primary:
-                logger.debug("Reading original image metadata from IMG_META.")
-                hdu_meta = hdu_to_image_metadata_table(primary["IMG_META"])
-            else:
-                hdu_meta = None
-            org_image_meta = create_image_metadata(n_constituents, data=hdu_meta)
-
             # Read in the global WCS from extension 0 if the information exists.
             # We filter the warning that the image dimension does not match the WCS dimension
             # since the primary header does not have an image.
@@ -622,25 +667,26 @@ def from_sharded_fits(cls, filename, directory, lazy=False):
                 warnings.simplefilter("ignore", AstropyWarning)
                 global_wcs = extract_wcs_from_hdu_header(primary[0].header)

+            # Read the size and order information from the primary header.
+            num_images = primary[0].header["NUMIMG"]
+            n_constituents = primary[0].header["NCON"]
+            expected_num_images = (4 * num_images) + (2 * n_constituents) + 3
+
             # Misc. reprojection metadata
             reprojected = primary[0].header["REPRJCTD"]
             heliocentric_distance = primary[0].header["HELIO"]
+            geocentric_distances = []
             for i in range(n_constituents):
-                if f"GEO_{i}" in primary[0].header:
-                    org_image_meta["geocentric_distance"][i] = primary[0].header[f"GEO_{i}"]
+                geocentric_distances.append(primary[0].header[f"GEO_{i}"])

-            # Extract the per-image data from header information if needed.
-            # This happens when the WorkUnit was saved before metadata tables were
-            # saved as layers.
+            per_image_wcs = []
+            per_image_ebd_wcs = []
+            constituent_images = []
             for i in range(n_constituents):
-                if f"WCS_{i}" in primary:
-                    wcs_header = primary[f"WCS_{i}"].header
-                    org_image_meta["per_image_wcs"][i] = extract_wcs_from_hdu_header(wcs_header)
-                    if "ILOC" in wcs_header:
-                        org_image_meta["data_loc"][i] = wcs_header["ILOC"]
-                if f"EBD_{i}" in primary:
-                    org_image_meta["ebd_wcs"][i] = extract_wcs_from_hdu_header(primary[f"EBD_{i}"].header)
-
+                # Extract the per-image WCS if one exists.
+                per_image_wcs.append(extract_wcs_from_hdu_header(primary[f"WCS_{i}"].header))
+                per_image_ebd_wcs.append(extract_wcs_from_hdu_header(primary[f"EBD_{i}"].header))
+                constituent_images.append(primary[f"WCS_{i}"].header["ILOC"])
             per_image_indices = []
             file_paths = []
             obstimes = []
@@ -661,7 +707,6 @@ def from_sharded_fits(cls, filename, directory, lazy=False):
                 else:
                     file_paths.append(shard_path)

-                # Load the mapping of current image to constituent image.
                 n_indices = sci_hdu.header["NIND"]
                 sub_indices = []
                 for j in range(n_indices):
@@ -673,19 +718,29 @@ def from_sharded_fits(cls, filename, directory, lazy=False):
             im_stack=im_stack,
             config=config,
             wcs=global_wcs,
-            reprojected=reprojected,
-            lazy=lazy,
+            constituent_images=constituent_images,
+            per_image_wcs=per_image_wcs,
+            per_image_ebd_wcs=per_image_ebd_wcs,
             heliocentric_distance=heliocentric_distance,
+            geocentric_distances=geocentric_distances,
+            reprojected=reprojected,
             per_image_indices=per_image_indices,
+            lazy=lazy,
             file_paths=file_paths,
             obstimes=obstimes,
-            org_image_meta=org_image_meta,
         )
         return result

-    def metadata_to_hdul(self):
+    def metadata_to_primary_header(self, include_wcs=True):
         """Creates the metadata fits headers.

+        Parameters
+        ----------
+        include_wcs : `bool`
+            Whether or not to append all the per-image WCS
+            to the header (optional for the serial `to_fits`
+            case so that we can maintain the same indexing
+            as before).
+
         Returns
         -------
         hdul : `astropy.io.fits.HDUList`
@@ -699,22 +754,53 @@ def metadata_to_primary_header(self, include_wcs=True):
         pri.header["NCON"] = self.n_constituents
         pri.header["REPRJCTD"] = self.reprojected
         pri.header["HELIO"] = self.heliocentric_distance
+        for i in range(self.n_constituents):
+            pri.header[f"GEO_{i}"] = self.org_img_meta["geocentric_distance"][i]

-        # If the global WCS exists, append the corresponding keys to the primary header.
+        # If the global WCS exists, append the corresponding keys.
         if self.wcs is not None:
             append_wcs_to_hdu_header(self.wcs, pri.header)
+
         hdul.append(pri)

-        # Add the configuration layer.
+        meta_hdu = fits.BinTableHDU()
+        meta_hdu.name = "metadata"
+        hdul.append(meta_hdu)
+
         config_hdu = self.config.to_hdu()
         config_hdu.name = "kbmod_config"
         hdul.append(config_hdu)

-        # Save the additional metadata table into HDUs
-        hdul.append(image_metadata_table_to_hdu(self.org_img_meta, "IMG_META"))
+        if include_wcs:
+            self.append_all_wcs(hdul)

         return hdul

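Roughly how these two methods combine on the serial `to_fits` path above (a sketch; `wu` is an assumed WorkUnit):

    hdul = wu.metadata_to_primary_header(include_wcs=False)
    # ... the per-image layers (science, variance, mask, PSF) are appended here ...
    wu.append_all_wcs(hdul)   # WCS_i and EBD_i layers land after the image layers
    hdul.writeto("work_unit.fits", overwrite=True)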
+    def append_all_wcs(self, hdul):
+        """Append all the original WCS and EBD WCS to an HDU list.
+
+        Parameters
+        ----------
+        hdul : `astropy.io.fits.HDUList`
+            The HDU list.
+        """
+        all_ebd_wcs = self.get_constituent_meta("ebd_wcs")
+
+        for i in range(self.n_constituents):
+            img_location = self.org_img_meta["data_loc"][i]
+
+            orig_wcs = self._per_image_wcs[i]
+            wcs_hdu = fits.TableHDU()
+            append_wcs_to_hdu_header(orig_wcs, wcs_hdu.header)
+            wcs_hdu.name = f"WCS_{i}"
+            wcs_hdu.header["ILOC"] = img_location
+            hdul.append(wcs_hdu)
+
+            ebd_hdu = fits.TableHDU()
+            append_wcs_to_hdu_header(all_ebd_wcs[i], ebd_hdu.header)
+            ebd_hdu.name = f"EBD_{i}"
+            hdul.append(ebd_hdu)
+
     def image_positions_to_original_icrs(
         self, image_indices, positions, input_format="xy", output_format="xy", filter_in_frame=True
     ):
@@ -743,7 +829,6 @@ def image_positions_to_original_icrs(
             Whether or not to filter the output based on whether they fit within the
             original `constituent_image` frame. If `True`, only results that fall within
             the bounds of the original WCS will be returned.
-
         Returns
         -------
         positions : `list` of `astropy.coordinates.SkyCoord`s or `tuple`s
@@ -803,7 +888,7 @@ def image_positions_to_original_icrs(
             pos = []
             for j in inds:
                 con_image = self.org_img_meta["data_loc"][j]
-                con_wcs = self.org_img_meta["per_image_wcs"][j]
+                con_wcs = self._per_image_wcs[j]
                 height, width = con_wcs.array_shape
                 x, y = skycoord_to_pixel(coord, con_wcs)
                 x, y = float(x), float(y)
@@ -900,115 +985,3 @@ def raw_image_to_hdu(img, obstime, wcs=None):
         hdu.header["MJD"] = obstime
     return hdu
-
-
-# ------------------------------------------------------------------
-# --- Utility functions for the metadata table ---------------------
-# ------------------------------------------------------------------
-
-
-def create_image_metadata(n_images, data=None):
-    """Create an empty img_meta table, filling in default values
-    for any unspecified columns.
-
-    Parameters
-    ----------
-    n_images : `int`
-        The number of images to include.
-    data : `astropy.table.Table`
-        An existing table from which to fill in some of the data.
-
-    Returns
-    -------
-    img_meta : `astropy.table.Table`
-        The empty table of org_img_meta.
-    """
-    if n_images <= 0:
-        raise ValueError(f"Invalid metadata size: {n_images}")
-    img_meta = Table()
-
-    # Fill in the defaults.
-    for colname in ["data_loc", "ebd_wcs", "geocentric_distance", "per_image_wcs"]:
-        img_meta[colname] = np.full(n_images, None)
-
-    # Fill in any values from the given table. This overwrites the defaults.
-    if data is not None:
-        if len(data) != n_images:
-            raise ValueError(f"Metadata size mismatch. Expected {n_images}. Found {len(data)}")
-        for colname in data.colnames:
-            img_meta[colname] = data[colname]
-
-    return img_meta
-
-
-def image_metadata_table_to_hdu(data, layer_name=None):
-    """Create an HDU layer from an astropy table with custom
-    encodings for some columns (such as WCS).
-
-    Parameters
-    ----------
-    data : `astropy.table.Table`
-        The table of the data to save.
-    layer_name : `str`, optional
-        The name of the layer in which to save the table.
-    """
-    num_rows = len(data)
-    if num_rows == 0:
-        # No data to encode. Just use the current table.
-        save_table = data
-    else:
-        # Create a new table to save with the correct column
-        # values/names for the serialized information.
-        save_table = Table()
-        for colname in data.colnames:
-            col_data = data[colname].value
-
-            if np.all(col_data == None):
-                # The entire column is filled with Nones (probably from a default value).
-                save_table[f"_EMPTY_{colname}"] = np.full(num_rows, "None", dtype=str)
-            elif isinstance(col_data[0], WCS):
-                # Serialize WCS objects and use a custom tag so we can deserialize them.
-                values = np.array([serialize_wcs(entry) for entry in data[colname]], dtype=str)
-                save_table[f"_WCSSTR_{colname}"] = values
-            else:
-                save_table[colname] = data[colname]
-
-    # Format the metadata as a single HDU
-    meta_hdu = fits.TableHDU(save_table)
-    if layer_name is not None:
-        meta_hdu.name = layer_name
-    return meta_hdu
-
-
-def hdu_to_image_metadata_table(hdu):
-    """Load an HDU layer with custom encodings for some columns (such as WCS)
-    to an astropy table.
-
-    Parameters
-    ----------
-    hdu : `astropy.io.fits.TableHDU`
-        The HDUList for the fits file.
-
-    Returns
-    -------
-    data : `astropy.table.Table`
-        The table of loaded data.
-    """
-    if hdu is None:
-        # Nothing to decode. Return an empty table.
-        return Table()
-
-    data = Table(hdu.data)
-    all_cols = set(data.colnames)
-
-    # Check if there are any columns we need to decode. If so: decode them, add a new column,
-    # and delete the old column.
-    for colname in all_cols:
-        if colname.startswith("_WCSSTR_"):
-            data[colname[8:]] = np.array([deserialize_wcs(entry) for entry in data[colname]])
-            data.remove_column(colname)
-        elif colname.startswith("_EMPTY_"):
-            data[colname[7:]] = np.array([None for _ in data[colname]])
-            data.remove_column(colname)
-
-    return data
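For context on the coordinate round trip above — a rough sketch with made-up indices and pixel positions; `wu` is an assumed, reprojected WorkUnit, and the argument names and the `(SkyCoord, filename)` result shape follow the docstring above and the tests below:

    results = wu.image_positions_to_original_icrs(
        image_indices=[0, 1],
        positions=[(10.0, 20.0), (15.0, 25.0)],
        input_format="xy",
        output_format="radec",
        filter_in_frame=True,
    )
    for coord, filename in results:
        print(filename, coord.to_string("decimal"))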
- save_table[f"_EMPTY_{colname}"] = np.full(num_rows, "None", dtype=str) - elif isinstance(col_data[0], WCS): - # Serialize WCS objects and use a custom tag so we can unserialize them. - values = np.array([serialize_wcs(entry) for entry in data[colname]], dtype=str) - save_table[f"_WCSSTR_{colname}"] = values - else: - save_table[colname] = data[colname] - - # Format the metadata as a single HDU - meta_hdu = fits.TableHDU(save_table) - if layer_name is not None: - meta_hdu.name = layer_name - return meta_hdu - - -def hdu_to_image_metadata_table(hdu): - """Load a HDU layer with custom encodings for some columns (such as WCS) - to an astropy table. - - Parameters - ---------- - hdu : `astropy.io.fits.TableHDU` - The HDUList for the fits file. - - Returns - ------- - data : `astropy.table.Table` - The table of loaded data. - """ - if hdu is None: - # Nothing to decode. Return an empty table. - return Table() - - data = Table(hdu.data) - all_cols = set(data.colnames) - - # Check if there are any columns we need to decode. If so: decode them, add a new column, - # and delete the old column. - for colname in all_cols: - if colname.startswith("_WCSSTR_"): - data[colname[8:]] = np.array([deserialize_wcs(entry) for entry in data[colname]]) - data.remove_column(colname) - elif colname.startswith("_EMPTY_"): - data[colname[7:]] = np.array([None for _ in data[colname]]) - data.remove_column(colname) - - return data diff --git a/tests/test_reprojection.py b/tests/test_reprojection.py index 63c13defd..1135c7b50 100644 --- a/tests/test_reprojection.py +++ b/tests/test_reprojection.py @@ -19,6 +19,9 @@ def setUp(self): self.test_wunit = WorkUnit.from_fits(self.data_path) self.common_wcs = self.test_wunit.get_wcs(0) + # self.tmp_dir = os.path(tempfile.TemporaryDirectory()) + # self.test_wunit.to_sharded_fits("test_wunit.fits", self.tmp_dir) + def test_reproject(self): # test exception conditions self.assertRaises( diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py index 8e2a86c2f..abd61fa23 100644 --- a/tests/test_work_unit.py +++ b/tests/test_work_unit.py @@ -4,7 +4,6 @@ from astropy.coordinates import EarthLocation, SkyCoord from astropy.time import Time import numpy as np -import numpy.testing as npt import os from pathlib import Path import tempfile @@ -16,13 +15,7 @@ import kbmod.search as kb from kbmod.reprojection_utils import fit_barycentric_wcs from kbmod.wcs_utils import make_fake_wcs, wcs_fits_equal -from kbmod.work_unit import ( - create_image_metadata, - hdu_to_image_metadata_table, - image_metadata_table_to_hdu, - raw_image_to_hdu, - WorkUnit, -) +from kbmod.work_unit import raw_image_to_hdu, WorkUnit import numpy.testing as npt @@ -103,25 +96,18 @@ def setUp(self): "five.fits", ] - self.org_image_meta = Table( - { - "data_loc": np.array(self.constituent_images), - "ebd_wcs": np.array([self.per_image_ebd_wcs] * self.num_images), - "geocentric_distance": np.array([self.geo_dist] * self.num_images), - "per_image_wcs": np.array(self.per_image_wcs), - } - ) - def test_create(self): # Test the creation of a WorkUnit with no WCS. Should throw a warning. with warnings.catch_warnings(record=True) as wrn: warnings.simplefilter("always") work = WorkUnit(self.im_stack, self.config) + self.assertTrue("No WCS provided." 
+            self.assertTrue("No WCS provided." in str(wrn[-1].message))

         self.assertIsNotNone(work)
         self.assertEqual(work.im_stack.img_count(), 5)
         self.assertEqual(work.config["im_filepath"], "Here")
         self.assertEqual(work.config["num_obs"], 5)
+        self.assertFalse(work.has_common_wcs())
         self.assertIsNone(work.wcs)
         self.assertEqual(len(work), self.num_images)
         for i in range(self.num_images):
@@ -130,6 +116,7 @@ def test_create(self):
         # Create with a global WCS
         work2 = WorkUnit(self.im_stack, self.config, self.wcs)
         self.assertEqual(work2.im_stack.img_count(), 5)
+        self.assertTrue(work2.has_common_wcs())
         self.assertIsNotNone(work2.wcs)
         for i in range(self.num_images):
             self.assertIsNotNone(work2.get_wcs(i))
@@ -142,60 +129,26 @@ def test_create(self):
             self.im_stack,
             self.config,
             self.wcs,
+            [f"img_{i}" for i in range(self.im_stack.img_count())],
             [self.wcs, self.wcs, self.wcs],
         )

-    def test_metadata_helpers(self):
-        """Test that we can roundtrip an astropy table of metadata (including WCS)
-        into a BinTableHDU.
-        """
-        metadata_dict = {
-            "col1": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),  # Floats
-            "uri": np.array(["a", "bc", "def", "ghij", "other_strings"]),  # Strings
-            "wcs": np.array(self.per_image_wcs),  # WCSes
-            "none_col": np.array([None] * self.num_images),  # Empty column
-            "Other": np.arange(5),  # ints
-        }
-        metadata_table = Table(metadata_dict)
-
-        # Convert to an HDU
-        hdu = image_metadata_table_to_hdu(metadata_table)
-        self.assertIsNotNone(hdu)
-
-        # Convert it back.
-        md_table2 = hdu_to_image_metadata_table(hdu)
-        self.assertEqual(len(md_table2.colnames), 5)
-        npt.assert_array_equal(metadata_dict["col1"], md_table2["col1"])
-        npt.assert_array_equal(metadata_dict["uri"], md_table2["uri"])
-        npt.assert_array_equal(metadata_dict["Other"], md_table2["Other"])
-        self.assertTrue(np.all(md_table2["none_col"] == None))
-        for i in range(len(md_table2)):
-            self.assertTrue(isinstance(md_table2["wcs"][i], WCS))
-
-    def test_create_image_meta(self):
-        # Empty constituent image data.
-        org_img_meta = create_image_metadata(n_images=3, data=None)
-        self.assertEqual(len(org_img_meta), 3)
-        self.assertTrue("data_loc" in org_img_meta.colnames)
-        self.assertTrue("ebd_wcs" in org_img_meta.colnames)
-        self.assertTrue("geocentric_distance" in org_img_meta.colnames)
-        self.assertTrue("per_image_wcs" in org_img_meta.colnames)
-
-        # We can create from a dictionary. In this case we ignore n_images.
-        data_dict = {
-            "uri": ["file1", "file2", "file3"],
-            "geocentric_distance": [1.0, 2.0, 3.0],
-        }
-        org_img_meta2 = create_image_metadata(n_images=5, data=data_dict)
-        self.assertEqual(len(org_img_meta2), 3)
-        self.assertTrue("data_loc" in org_img_meta2.colnames)
-        self.assertTrue("ebd_wcs" in org_img_meta2.colnames)
-        self.assertTrue("geocentric_distance" in org_img_meta2.colnames)
-        self.assertTrue("per_image_wcs" in org_img_meta2.colnames)
-        self.assertTrue("uri" in org_img_meta2.colnames)
-
-        # We need either data or a positive number of images.
-        self.assertRaises(ValueError, create_image_metadata, None, -1)
+        # Create with per-image WCS that can be compressed to a global WCS.
+        per_image_wcs = [self.wcs] * self.num_images
+        work3 = WorkUnit(self.im_stack, self.config, per_image_wcs=per_image_wcs)
+        self.assertIsNotNone(work3.wcs)
+        self.assertTrue(work3.has_common_wcs())
+        for i in range(self.num_images):
+            self.assertIsNotNone(work3.get_wcs(i))
+            self.assertTrue(wcs_fits_equal(self.wcs, work3.get_wcs(i)))
+
+        # Create with per-image WCS that cannot be compressed to a global WCS.
+        work3 = WorkUnit(self.im_stack, self.config, per_image_wcs=self.diff_wcs)
+        self.assertIsNone(work3.wcs)
+        self.assertFalse(work3.has_common_wcs())
+        for i in range(self.num_images):
+            self.assertIsNotNone(work3.get_wcs(i))
+            self.assertTrue(wcs_fits_equal(work3.get_wcs(i), self.diff_wcs[i]))

     def test_save_and_load_fits(self):
         with tempfile.TemporaryDirectory() as dir_name:
@@ -206,19 +159,13 @@ def test_save_and_load_fits(self):
             self.assertRaises(ValueError, WorkUnit.from_fits, file_path)

             # Write out the existing WorkUnit with a different per-image wcs for all the entries.
-            # work = WorkUnit(self.im_stack, self.config, None, self.diff_wcs).
-            # Include extra per-image metadata.
-            extra_meta = {
-                "data_loc": np.array(self.constituent_images),
-                "int_index": np.arange(self.num_images),
-                "uri": np.array([f"file_loc_{i}" for i in range(self.num_images)]),
-            }
+            # work = WorkUnit(self.im_stack, self.config, None, self.diff_wcs)
             work = WorkUnit(
                 im_stack=self.im_stack,
                 config=self.config,
                 wcs=None,
                 per_image_wcs=self.diff_wcs,
-                org_image_meta=extra_meta,
+                constituent_images=self.constituent_images,
             )
             work.to_fits(file_path)
             self.assertTrue(Path(file_path).is_file())

             work2 = WorkUnit.from_fits(file_path)
             self.assertEqual(work2.im_stack.img_count(), self.num_images)
             self.assertIsNone(work2.wcs)
+            self.assertFalse(work2.has_common_wcs())
             for i in range(self.num_images):
                 li = work2.im_stack.get_single_image(i)
                 self.assertEqual(li.get_width(), self.width)
@@ -266,10 +214,9 @@ def test_save_and_load_fits(self):
             self.assertEqual(work2.config["im_filepath"], "Here")
             self.assertEqual(work2.config["num_obs"], self.num_images)

-            # Check that we retrieved the extra metadata that we added.
-            npt.assert_array_equal(work2.get_constituent_meta("uri"), extra_meta["uri"])
-            npt.assert_array_equal(work2.get_constituent_meta("int_index"), extra_meta["int_index"])
-            npt.assert_array_equal(work2.get_constituent_meta("data_loc"), self.constituent_images)
+            # Check that we correctly retrieved the provenance information via "data_loc"
+            for index, value in enumerate(work2.org_img_meta["data_loc"]):
+                self.assertEqual(value, self.constituent_images[index])

             # We throw an error if we try to overwrite a file with overwrite=False
             self.assertRaises(FileExistsError, work.to_fits, file_path)
@@ -295,6 +242,7 @@ def test_save_and_load_fits_shard(self):
             work2 = WorkUnit.from_sharded_fits(filename="test_workunit.fits", directory=dir_name)
             self.assertEqual(work2.im_stack.img_count(), self.num_images)
             self.assertIsNone(work2.wcs)
+            self.assertFalse(work2.has_common_wcs())
             for i in range(self.num_images):
                 li = work2.im_stack.get_single_image(i)
                 self.assertEqual(li.get_width(), self.width)
@@ -350,6 +298,7 @@ def test_save_and_load_fits_shard_lazy(self):
             work2 = WorkUnit.from_sharded_fits(filename="test_workunit.fits", directory=dir_name, lazy=True)
             self.assertEqual(len(work2.file_paths), self.num_images)
             self.assertIsNone(work2.wcs)
+            self.assertFalse(work2.has_common_wcs())

             # Check that we read in the configuration values correctly.
             self.assertEqual(work2.config["im_filepath"], "Here")
@@ -372,6 +321,7 @@ def test_save_and_load_fits_global_wcs(self):

             # Read in the file and check that the values agree.
             work2 = WorkUnit.from_fits(file_path)
             self.assertIsNotNone(work2.wcs)
+            self.assertTrue(work2.has_common_wcs())
             self.assertTrue(wcs_fits_equal(work2.wcs, self.wcs))
             for i in range(self.num_images):
                 self.assertIsNotNone(work2.get_wcs(i))
@@ -391,9 +341,12 @@ def test_image_positions_to_original_icrs_invalid_format(self):
             im_stack=self.im_stack,
             config=self.config,
             wcs=self.per_image_ebd_wcs,
+            per_image_wcs=self.per_image_wcs,
+            per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images,
+            geocentric_distances=[self.geo_dist] * self.num_images,
             heliocentric_distance=41.0,
+            constituent_images=self.constituent_images,
             reprojected=True,
-            org_image_meta=self.org_image_meta,
         )

         # Incorrect format for 'xy'
@@ -428,9 +381,12 @@ def test_image_positions_to_original_icrs_basic_inputs(self):
             im_stack=self.im_stack,
             config=self.config,
             wcs=self.per_image_ebd_wcs,
+            per_image_wcs=self.per_image_wcs,
+            per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images,
+            geocentric_distances=[self.geo_dist] * self.num_images,
             heliocentric_distance=41.0,
+            constituent_images=self.constituent_images,
             reprojected=True,
-            org_image_meta=self.org_image_meta,
         )

         res = work.image_positions_to_original_icrs(
@@ -474,9 +430,12 @@ def test_image_positions_to_original_icrs_filtering(self):
             im_stack=self.im_stack,
             config=self.config,
             wcs=self.per_image_ebd_wcs,
+            per_image_wcs=self.per_image_wcs,
+            per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images,
+            geocentric_distances=[self.geo_dist] * self.num_images,
             heliocentric_distance=41.0,
+            constituent_images=self.constituent_images,
             reprojected=True,
-            org_image_meta=self.org_image_meta,
        )

         res = work.image_positions_to_original_icrs(
@@ -499,13 +458,16 @@ def test_image_positions_to_original_icrs_mosaicking(self):
             im_stack=self.im_stack,
             config=self.config,
             wcs=self.per_image_ebd_wcs,
+            per_image_wcs=self.per_image_wcs,
+            per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images,
+            geocentric_distances=[self.geo_dist] * self.num_images,
             heliocentric_distance=41.0,
+            constituent_images=self.constituent_images,
             reprojected=True,
-            org_image_meta=self.org_image_meta,
         )

         new_wcs = make_fake_wcs(190.0, -7.7888, 500, 700)
-        work.org_img_meta["per_image_wcs"][-1] = new_wcs
+        work._per_image_wcs[-1] = new_wcs
         work._per_image_indices[3] = [3, 4]

         res = work.image_positions_to_original_icrs(
@@ -532,6 +494,9 @@ def test_image_positions_to_original_icrs_mosaicking(self):
         npt.assert_almost_equal(res[3][0].separation(self.expected_radec_positions[3]).deg, 0.0, decimal=5)
         assert res[3][1] == "five.fits"

+        # work._per_image_wcs[4] = work._per_image_wcs[3]
+        # work._per_image_ebd_wcs[4] = work._per_image_ebd_wcs[3]
+
         res = work.image_positions_to_original_icrs(
             self.indices,
             self.pixel_positions,
@@ -552,8 +517,11 @@ def test_get_unique_obstimes_and_indices(self):
             im_stack=self.im_stack,
             config=self.config,
             wcs=self.per_image_ebd_wcs,
+            per_image_wcs=self.per_image_wcs,
+            per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images,
+            geocentric_distances=[self.geo_dist] * self.num_images,
             heliocentric_distance=41.0,
-            org_image_meta=self.org_image_meta,
+            constituent_images=self.constituent_images,
         )
         times = work.get_all_obstimes()
         times[-1] = times[-2]