From 236e2256aa502a8421bd1d1c3d2771e424d9de17 Mon Sep 17 00:00:00 2001
From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com>
Date: Fri, 15 Nov 2024 09:42:20 -0500
Subject: [PATCH] Bug fixes

---
 src/kbmod/image_collection.py |   7 +-
 src/kbmod/reprojection.py     |  29 ++-
 src/kbmod/work_unit.py        | 461 ++++++++++++++++++----------------
 tests/test_reprojection.py    |   3 -
 tests/test_work_unit.py       | 140 +++++++----
 5 files changed, 352 insertions(+), 288 deletions(-)

diff --git a/src/kbmod/image_collection.py b/src/kbmod/image_collection.py
index dce44354d..4e75155f5 100644
--- a/src/kbmod/image_collection.py
+++ b/src/kbmod/image_collection.py
@@ -888,7 +888,10 @@ def toWorkUnit(self, search_config=None, **kwargs):
         for std in self.get_standardizers(**kwargs):
             for img in std["std"].toLayeredImage():
                 layeredImages.append(img)
+
         imgstack = ImageStack(layeredImages)
+        img_metadata = Table()
+
         if None not in self.wcs:
-            return WorkUnit(imgstack, search_config, per_image_wcs=list(self.wcs))
-        return WorkUnit(imgstack, search_config)
+            img_metadata["per_image_wcs"] = list(self.wcs)
+        return WorkUnit(imgstack, search_config, org_image_meta=img_metadata)
diff --git a/src/kbmod/reprojection.py b/src/kbmod/reprojection.py
index 2c2fc68c3..49e4a0102 100644
--- a/src/kbmod/reprojection.py
+++ b/src/kbmod/reprojection.py
@@ -118,6 +118,9 @@ def reproject_work_unit(
         A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the case
         where `write_output` is set to True.
     """
+    if work_unit.reprojected:
+        raise ValueError("Unable to reproject a reprojected WorkUnit.")
+
     show_progress = is_interactive() if show_progress is None else show_progress
     if (work_unit.lazy or write_output) and (directory is None or filename is None):
         raise ValueError("can't write output to sharded fits without directory and filename provided.")
@@ -195,7 +198,7 @@ def _reproject_work_unit(
 
     # Create a list of the correct WCS. We do this extraction once and reuse for all images.
     if frame == "original":
-        wcs_list = work_unit.get_constituent_meta("original_wcs")
+        wcs_list = work_unit.get_constituent_meta("per_image_wcs")
     elif frame == "ebd":
         wcs_list = work_unit.get_constituent_meta("ebd_wcs")
     else:
@@ -283,24 +286,25 @@ def _reproject_work_unit(
             stack.append_image(new_layered_image, force_move=True)
 
     if write_output:
+        # Create a copy of the WorkUnit to write the global metadata.
+        # We preserve the metadata for the constituent images.
         new_work_unit = copy(work_unit)
         new_work_unit._per_image_indices = unique_obstime_indices
         new_work_unit.wcs = common_wcs
         new_work_unit.reprojected = True
 
-        hdul = new_work_unit.metadata_to_primary_header()
+        hdul = new_work_unit.metadata_to_hdul()
         hdul.writeto(os.path.join(directory, filename))
     else:
+        # Create a new WorkUnit with the new ImageStack and global WCS.
+        # We preserve the metadata for the constituent images.
         new_wunit = WorkUnit(
             im_stack=stack,
             config=work_unit.config,
             wcs=common_wcs,
-            constituent_images=work_unit.get_constituent_meta("data_loc"),
-            per_image_wcs=work_unit._per_image_wcs,
-            per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"),
             per_image_indices=unique_obstime_indices,
-            geocentric_distances=work_unit.get_constituent_meta("geocentric_distance"),
             reprojected=True,
+            org_image_meta=work_unit.org_img_meta,
         )
 
         return new_wunit
@@ -425,7 +429,7 @@ def _reproject_work_unit_in_parallel(
         new_work_unit.wcs = common_wcs
         new_work_unit.reprojected = True
 
-        hdul = new_work_unit.metadata_to_primary_header()
+        hdul = new_work_unit.metadata_to_hdul()
         hdul.writeto(os.path.join(directory, filename))
     else:
         stack = ImageStack([])
@@ -446,17 +450,15 @@
         # sort by the time_stamp
         stack.sort_by_time()
 
-    # Add the imageStack to a new WorkUnit and return it.
+    # Add the imageStack to a new WorkUnit and return it. We preserve the metadata
+    # for the constituent images.
     new_wunit = WorkUnit(
         im_stack=stack,
         config=work_unit.config,
         wcs=common_wcs,
-        constituent_images=work_unit.get_constituent_meta("data_loc"),
-        per_image_wcs=work_unit._per_image_wcs,
-        per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"),
         per_image_indices=unique_obstimes_indices,
-        geocentric_distances=work_unit.get_constituent_meta("geocentric_distances"),
         reprojected=True,
+        org_image_meta=work_unit.org_img_meta,
    )

    return new_wunit
@@ -549,6 +551,7 @@ def reproject_lazy_work_unit(
         if not result.result():
             raise RuntimeError("one or more jobs failed.")
 
+    # We use new metadata for the new images and the same metadata for the original images.
     new_work_unit = copy(work_unit)
     new_work_unit._per_image_indices = unique_obstimes_indices
     new_work_unit.wcs = common_wcs
@@ -591,6 +594,8 @@ def _validate_original_wcs(work_unit, indices, frame="original"):
     else:
         raise ValueError("Invalid projection frame provided.")
 
+    if len(original_wcs) == 0:
+        raise ValueError(f"No WCS found for frame {frame}")
     if np.any(original_wcs) is None:
         # find indices where the wcs is None
         bad_indices = np.where(original_wcs == None)
diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py
index f0320624c..05bf542da 100644
--- a/src/kbmod/work_unit.py
+++ b/src/kbmod/work_unit.py
@@ -7,6 +7,7 @@
 from astropy.table import Table
 from astropy.time import Time
 from astropy.utils.exceptions import AstropyWarning
+from astropy.wcs import WCS
 from astropy.wcs.utils import skycoord_to_pixel
 import astropy.units as u
@@ -21,7 +22,9 @@
 from kbmod.wcs_utils import (
     append_wcs_to_hdu_header,
     calc_ecliptic_angle,
+    deserialize_wcs,
     extract_wcs_from_hdu_header,
+    serialize_wcs,
     wcs_fits_equal,
 )
@@ -43,6 +46,9 @@ class WorkUnit:
         The image data for the KBMOD run.
     config : `kbmod.configuration.SearchConfiguration`
         The configuration for the KBMOD run.
+    n_images : `int`
+        The number of images. This may differ from the length of the ImageStack due
+        to lazy loading.
     n_constituents : `int`
         The number of original images making up the data in this WorkUnit. This might be
         different from the number of images stored in memory if the WorkUnit has been
@@ -53,13 +59,10 @@ class WorkUnit:
         * ebd_wcs - Used to reproject the images into EBD space.
         * geocentric_distance - The best fit geocentric distances used when
           creating the per image EBD WCS.
-        * original_wcs - The original WCS of the image.
+        * original_wcs - The original per-image WCS of the image.
     wcs : `astropy.wcs.WCS`
         A global WCS for all images in the WorkUnit.
         Only exists if all images have been projected to same pixel space.
-    per_image_wcs : `list`
-        A list with one WCS for each image in the WorkUnit. Used for when
-        the images have *not* been standardized to the same pixel space.
     heliocentric_distance : `float`
         The heliocentric distance that was used when creating the `per_image_ebd_wcs`.
     reprojected : `bool`
         Whether or not the WorkUnit image data has been reprojected.
@@ -82,53 +85,43 @@ class WorkUnit:
         The image data for the KBMOD run.
     config : `kbmod.configuration.SearchConfiguration`
         The configuration for the KBMOD run.
-    wcs : `astropy.wcs.WCS`
+    wcs : `astropy.wcs.WCS`, optional
         A global WCS for all images in the WorkUnit. Only exists
         if all images have been projected to same pixel space.
-    constituent_images: `list`
-        A list of strings with the original locations of images used
-        to construct the WorkUnit. This is necessary to maintain as metadata
-        because after reprojection we may stitch multiple images into one.
-    per_image_wcs : `list`
-        A list with one WCS for each image in the WorkUnit. Used for when
-        the images have *not* been standardized to the same pixel space.
-    per_image_ebd_wcs : `list`
-        A list with one WCS for each image in the WorkUnit. Used to reproject the images
-        into EBD space.
-    heliocentric_distance : `float`
-        The heliocentric distance that was used when creating the `per_image_ebd_wcs`.
-    geocentric_distances : `list`
-        The best fit geocentric distances used when creating the `per_image_ebd_wcs`.
-    reprojected : `bool`
+    reprojected : `bool`, optional
         Whether or not the WorkUnit image data has been reprojected.
-    per_image_indices : `list` of `list`
+    per_image_indices : `list` of `list`, optional
         A list of lists containing the indices of `constituent_images` at each layer
         of the `ImageStack`. Used for finding corresponding original images when we
         stitch images together during reprojection.
-    lazy : `bool`
+    heliocentric_distance : `float`, optional
+        The heliocentric distance that was used when creating the `per_image_ebd_wcs`.
+    lazy : `bool`, optional
         Whether or not to load the image data for the `WorkUnit`.
-    file_paths : `list[str]`
+    file_paths : `list[str]`, optional
         The paths for the shard files, only created if the `WorkUnit` is loaded
         in lazy mode.
     obstimes : `list[float]`
         The MJD obstimes of the images.
+    org_image_meta : `dict` or `astropy.table.Table`, optional
+        A table of per-image data for the constituent images.
     """
 
     def __init__(
         self,
-        im_stack=None,
-        config=None,
+        im_stack,
+        config,
         wcs=None,
-        constituent_images=None,
-        per_image_wcs=None,
-        per_image_ebd_wcs=None,
-        heliocentric_distance=None,
-        geocentric_distances=None,
         reprojected=False,
         per_image_indices=None,
+        heliocentric_distance=None,
         lazy=False,
         file_paths=None,
        obstimes=None,
+        org_image_meta=None,
     ):
         self.im_stack = im_stack
         self.config = config
@@ -136,47 +129,28 @@ def __init__(
         self.file_paths = file_paths
         self._obstimes = obstimes
 
-        # Determine the number of constituent images. If we are given a list of constituent_images,
-        # use that. Otherwise use the size of the image stack.
-        if constituent_images is None:
+        # If we are given a dictionary of per-image metadata, convert it to a
+        # table so that its length is the number of rows (constituent images).
+        if isinstance(org_image_meta, dict):
+            org_image_meta = Table(org_image_meta)
+
+        # Determine the number of constituent images. If we are given metadata for the
+        # constituent images, use that. Otherwise use the size of the image stack.
+        if org_image_meta is None:
             self.n_constituents = im_stack.img_count()
         else:
-            self.n_constituents = len(constituent_images)
+            self.n_constituents = len(org_image_meta)
 
-        # Track the meta data for each constituent image in the WorkUnit. For the original
-        # WCS, we track the per-image WCS if it is provided and otherwise the global WCS.
-        self.org_img_meta = Table()
-        self.add_org_img_meta_data("data_loc", constituent_images)
-        self.add_org_img_meta_data("original_wcs", per_image_wcs, default=wcs)
+        # Track the metadata for each constituent image in the WorkUnit. If no constituent
+        # data is provided, this creates a table of default values with one row per image.
+        self.org_img_meta = create_image_metadata(self.n_constituents, data=org_image_meta)
 
         # Handle WCS input. If both the global and per-image WCS are provided,
         # ensure they are consistent.
         self.wcs = wcs
-        if per_image_wcs is None:
-            self._per_image_wcs = [None] * self.n_constituents
-            if self.wcs is None and per_image_ebd_wcs is None:
-                warnings.warn("No WCS provided.", Warning)
-        else:
-            if len(per_image_wcs) != self.n_constituents:
-                raise ValueError(f"Incorrect number of WCS provided. Expected {self.n_constituents}")
-            self._per_image_wcs = per_image_wcs
-
-            # Check if all the per-image WCS are None. This can happen during a load.
-            all_none = self.per_image_wcs_all_match(None)
-            if self.wcs is None and all_none:
-                warnings.warn("No WCS provided.", Warning)
-
-            # See if we can compress the per-image WCS into a global one.
-            if self.wcs is None and not all_none and self.per_image_wcs_all_match(self._per_image_wcs[0]):
-                self.wcs = self._per_image_wcs[0]
-                self._per_image_wcs = [None] * im_stack.img_count()
-
-        # Add the meta data needed for reprojection, including: the reprojected WCS, the geocentric
-        # distances, and each images indices in the original constituent images.
+        if np.all(self.org_img_meta["per_image_wcs"] == None):
+            self.org_img_meta["per_image_wcs"] = np.full(self.n_constituents, self.wcs)
+        if np.any(self.org_img_meta["per_image_wcs"] == None):
+            warnings.warn("At least one image does not have a WCS.", Warning)
+
+        # Set the global metadata for reprojection.
         self.reprojected = reprojected
         self.heliocentric_distance = heliocentric_distance
-        self.add_org_img_meta_data("geocentric_distance", geocentric_distances)
-        self.add_org_img_meta_data("ebd_wcs", per_image_ebd_wcs)
 
         # If we have mosaicked images, each image in the stack could link back
         # to more than one constituent image. Build a mapping of image stack index
@@ -186,59 +160,36 @@ def __init__(
         else:
             self._per_image_indices = per_image_indices
 
+        # Run some basic validity checks.
+        if self.reprojected and self.wcs is None:
+            raise ValueError("Global WCS required for reprojected data.")
+        for inds in self._per_image_indices:
+            if np.max(inds) >= self.n_constituents:
+                raise ValueError(
+                    f"Found pointer to constituent image {np.max(inds)} of {self.n_constituents}"
+                )
+
     def __len__(self):
         """Returns the size of the WorkUnit in number of images."""
         return self.im_stack.img_count()
 
-    def get_constituent_meta(self, column):
-        """Get the meta data values of a given column for all the constituent images."""
-        return list(self.org_img_meta[column].data)
+    def get_num_images(self):
+        """Returns the number of images (layers) in the WorkUnit's ImageStack."""
+        return len(self._per_image_indices)
 
-    def add_org_img_meta_data(self, column, data, default=None):
-        """Add a column of meta data for the constituent images. Adds a column of all
-        default values if data is None and the column does not already exist.
+    def get_constituent_meta(self, column):
+        """Get the meta data values of a given column for all the constituent images.
 
         Parameters
         ----------
         column : `str`
-            The name of the meta data column.
-        data : list-like
-            The data for each constituent. If None then uses None for
-            each column.
-        """
-        if data is None:
-            if column not in self.org_img_meta.colnames:
-                self.org_img_meta[column] = [default] * self.n_constituents
-        elif len(data) == self.n_constituents:
-            self.org_img_meta[column] = data
-        else:
-            raise ValueError(
-                f"Data mismatch size for WorkUnit metadata {column}. "
-                f"Expected {self.n_constituents} but found {len(data)}."
-            )
-
-    def has_common_wcs(self):
-        """Returns whether the WorkUnit has a common WCS for all images."""
-        return self.wcs is not None
-
-    def per_image_wcs_all_match(self, target=None):
-        """Check if all the per-image WCS are the same as a given target value.
-
-        Parameters
-        ----------
-        target : `astropy.wcs.WCS`, optional
-            The WCS to which to compare the per-image WCS. If None, checks that
-            all of the per-image WCS are None.
+            The column name to fetch.
 
         Returns
         -------
-        result : `bool`
-            A Boolean indicating that all the per-images WCS match the target.
+        data : `list`
+            A list of the meta-data for each constituent image.
         """
-        for current in self._per_image_wcs:
-            if not wcs_fits_equal(current, target):
-                return False
-        return True
+        return list(self.org_img_meta[column].data)
 
     def get_wcs(self, img_num):
         """Return the WCS for a given image. Always prioritizes
@@ -253,26 +204,12 @@ def get_wcs(self, img_num):
         -------
         wcs : `astropy.wcs.WCS`
             The image's WCS if one exists. Otherwise None.
-
-        Raises
-        ------
-        IndexError if an invalid index is given.
         """
-        if img_num < 0 or img_num >= len(self._per_image_wcs):
-            raise IndexError(f"Invalid image number {img_num}")
-
-        # Extract the per-image WCS if one exists.
-        if self._per_image_wcs is not None and img_num < len(self._per_image_wcs):
-            per_img = self._per_image_wcs[img_num]
-        else:
-            per_img = None
-
         if self.wcs is not None:
-            if per_img is not None and not wcs_fits_equal(self.wcs, per_img):
-                warnings.warn("Both a global and per-image WCS given. Using global WCS.", Warning)
             return self.wcs
-
-        return per_img
+        else:
+            # If there is no common WCS, use the original per-image one.
+            return self.org_img_meta["per_image_wcs"][img_num]
 
     def get_pixel_coordinates(self, ra, dec, times=None):
         """Get the pixel coordinates for pairs of (RA, dec) coordinates. Uses the global
@@ -319,7 +256,7 @@ def get_pixel_coordinates(self, ra, dec, times=None):
         for i, index in enumerate(inds):
             if index == -1:
                 raise ValueError(f"Unmatched time {times[i]}.")
-            current_wcs = self._per_image_wcs[index]
+            current_wcs = self.org_img_meta["per_image_wcs"][index]
             curr_x, curr_y = current_wcs.world_to_pixel(
                 SkyCoord(ra=ra[i] * u.degree, dec=dec[i] * u.degree)
             )
@@ -363,9 +300,6 @@ def get_unique_obstimes_and_indices(self):
         unique_indices = [list(np.where(all_obstimes == time)[0]) for time in unique_obstimes]
         return unique_obstimes, unique_indices
 
-    def get_num_images(self):
-        return len(self._per_image_indices)
-
     @classmethod
     def from_fits(cls, filename, show_progress=None):
         """Create a WorkUnit from a single FITS file.
@@ -403,11 +337,22 @@ def from_fits(cls, filename, show_progress=None):
             if num_layers < 5:
                 raise ValueError(f"WorkUnit file has too few extensions {len(hdul)}.")
 
-            # TODO - Read in provenance metadata from extension #1.
-
             # Read in the search parameters from the 'kbmod_config' extension.
             config = SearchConfiguration.from_hdu(hdul["kbmod_config"])
 
+            # Read the size and order information from the primary header.
+            num_images = hdul[0].header["NUMIMG"]
+            n_constituents = hdul[0].header["NCON"] if "NCON" in hdul[0].header else num_images
+            logger.info(f"Loading {num_images} images.")
+
+            # Read in the per-image metadata for the constituent images.
+            if "IMG_META" in hdul:
+                logger.debug("Reading original image metadata from IMG_META.")
+                hdu_meta = hdu_to_image_metadata_table(hdul["IMG_META"])
+            else:
+                hdu_meta = None
+            org_image_meta = create_image_metadata(n_constituents, data=hdu_meta)
+
             # Read in the global WCS from extension 0 if the information exists.
             # We filter the warning that the image dimension does not match the WCS dimension
             # since the primary header does not have an image.
@@ -415,23 +360,18 @@ def from_fits(cls, filename, show_progress=None):
                 warnings.simplefilter("ignore", AstropyWarning)
                 global_wcs = extract_wcs_from_hdu_header(hdul[0].header)
 
-            # Read the size and order information from the primary header.
-            num_images = hdul[0].header["NUMIMG"]
-            n_constituents = hdul[0].header["NCON"]
-            expected_num_images = (4 * num_images) + (2 * n_constituents) + 3
-            if len(hdul) != expected_num_images:
-                raise ValueError(f"WorkUnit wrong number of extensions. Expected " f"{expected_num_images}.")
-            logger.info(f"Loading {num_images} images and {expected_num_images} total layers.")
-
             # Misc. reprojection metadata
             reprojected = hdul[0].header["REPRJCTD"]
             heliocentric_distance = hdul[0].header["HELIO"]
-            geocentric_distances = []
-            for i in range(num_images):
-                geocentric_distances.append(hdul[0].header[f"GEO_{i}"])
 
-            per_image_indices = []
+            # If there are geocentric distances in the header information
+            # (legacy approach), read those.
+            for i in range(n_constituents):
+                if f"GEO_{i}" in hdul[0].header:
+                    org_image_meta["geocentric_distance"][i] = hdul[0].header[f"GEO_{i}"]
 
+            # Read in all the image files.
+            per_image_indices = []
             for i in tqdm(
                 range(num_images),
                 bar_format=_DEFAULT_WORKUNIT_TQDM_BAR,
@@ -452,37 +392,39 @@ def from_fits(cls, filename, show_progress=None):
                 # force_move destroys img object, but avoids a copy.
                 im_stack.append_image(img, force_move=True)
 
+                # Read the mapping of current image to constituent image from the header info.
+                # TODO: Serialize this into its own table.
                 n_indices = sci_hdu.header["NIND"]
                 sub_indices = []
                 for j in range(n_indices):
                     sub_indices.append(sci_hdu.header[f"IND_{j}"])
                 per_image_indices.append(sub_indices)
 
-            per_image_wcs = []
-            per_image_ebd_wcs = []
-            constituent_images = []
+            # Extract the per-image data from header information if needed. This happens
+            # when the WorkUnit was saved before metadata tables were saved as layers and
+            # all the information is in header values.
             for i in tqdm(
                 range(n_constituents),
                 bar_format=_DEFAULT_WORKUNIT_TQDM_BAR,
                 desc="Loading WCS",
                 disable=not show_progress,
             ):
-                # Extract the per-image WCS if one exists.
-                per_image_wcs.append(extract_wcs_from_hdu_header(hdul[f"WCS_{i}"].header))
-                per_image_ebd_wcs.append(extract_wcs_from_hdu_header(hdul[f"EBD_{i}"].header))
-                constituent_images.append(hdul[f"WCS_{i}"].header["ILOC"])
+                if f"WCS_{i}" in hdul:
+                    wcs_header = hdul[f"WCS_{i}"].header
+                    org_image_meta["per_image_wcs"][i] = extract_wcs_from_hdu_header(wcs_header)
+                    if "ILOC" in wcs_header:
+                        org_image_meta["data_loc"][i] = wcs_header["ILOC"]
+                if f"EBD_{i}" in hdul:
+                    org_image_meta["ebd_wcs"][i] = extract_wcs_from_hdu_header(hdul[f"EBD_{i}"].header)
 
         result = WorkUnit(
             im_stack=im_stack,
             config=config,
             wcs=global_wcs,
-            constituent_images=constituent_images,
-            per_image_wcs=per_image_wcs,
-            per_image_ebd_wcs=per_image_ebd_wcs,
             heliocentric_distance=heliocentric_distance,
-            geocentric_distances=geocentric_distances,
             reprojected=reprojected,
             per_image_indices=per_image_indices,
+            org_image_meta=org_image_meta,
         )
         return result
@@ -508,8 +450,10 @@ def to_fits(self, filename, overwrite=False):
         if Path(filename).is_file() and not overwrite:
             raise FileExistsError(f"WorkUnit file {filename} already exists.")
 
-        hdul = self.metadata_to_primary_header(include_wcs=False)
+        # Create an HDU list with the metadata layers, including all the WCS info.
+        hdul = self.metadata_to_hdul()
 
+        # Create each image layer.
         for i in range(self.im_stack.img_count()):
             layered = self.im_stack.get_single_image(i)
             obstime = layered.get_obstime()
@@ -538,8 +482,6 @@ def to_fits(self, filename, overwrite=False):
             psf_hdu.name = f"PSF_{i}"
             hdul.append(psf_hdu)
 
-        self.append_all_wcs(hdul)
-
         hdul.writeto(filename, overwrite=overwrite)
@@ -553,7 +495,6 @@ def to_sharded_fits(self, filename, directory, overwrite=False):
         image index in front of the given filename, e.g. "0_filename.fits".
-
         Primary File:
             0 - Primary header with overall metadata
             1 or "metadata" - The data provenance metadata
@@ -615,7 +556,8 @@ def to_sharded_fits(self, filename, directory, overwrite=False):
             sub_hdul.append(psf_hdu)
             sub_hdul.writeto(os.path.join(directory, f"{i}_{filename}"))
 
-        hdul = self.metadata_to_primary_header(include_wcs=True)
+        # Create a primary file with all of the metadata, including all the WCS info.
+        hdul = self.metadata_to_hdul()
         hdul.writeto(os.path.join(directory, filename), overwrite=overwrite)
 
     @classmethod
@@ -660,6 +602,19 @@ def from_sharded_fits(cls, filename, directory, lazy=False):
         with fits.open(os.path.join(directory, filename)) as primary:
             config = SearchConfiguration.from_hdu(primary["kbmod_config"])
 
+            # Read the size and order information from the primary header.
+            num_images = primary[0].header["NUMIMG"]
+            n_constituents = primary[0].header["NCON"] if "NCON" in primary[0].header else num_images
+            logger.info(f"Loading {num_images} images.")
+
+            # Read in the per-image metadata for the constituent images.
+            if "IMG_META" in primary:
+                logger.debug("Reading original image metadata from IMG_META.")
+                hdu_meta = hdu_to_image_metadata_table(primary["IMG_META"])
+            else:
+                hdu_meta = None
+            org_image_meta = create_image_metadata(n_constituents, data=hdu_meta)
+
             # Read in the global WCS from extension 0 if the information exists.
             # We filter the warning that the image dimension does not match the WCS dimension
             # since the primary header does not have an image.
@@ -667,26 +622,25 @@ def from_sharded_fits(cls, filename, directory, lazy=False):
                 warnings.simplefilter("ignore", AstropyWarning)
                 global_wcs = extract_wcs_from_hdu_header(primary[0].header)
 
-            # Read the size and order information from the primary header.
-            num_images = primary[0].header["NUMIMG"]
-            n_constituents = primary[0].header["NCON"]
-            expected_num_images = (4 * num_images) + (2 * n_constituents) + 3
-
             # Misc. reprojection metadata
             reprojected = primary[0].header["REPRJCTD"]
             heliocentric_distance = primary[0].header["HELIO"]
-            geocentric_distances = []
             for i in range(n_constituents):
-                geocentric_distances.append(primary[0].header[f"GEO_{i}"])
+                if f"GEO_{i}" in primary[0].header:
+                    org_image_meta["geocentric_distance"][i] = primary[0].header[f"GEO_{i}"]
 
-            per_image_wcs = []
-            per_image_ebd_wcs = []
-            constituent_images = []
+            # Extract the per-image data from header information if needed.
+            # This happens when the WorkUnit was saved before metadata tables were
+            # saved as layers.
             for i in range(n_constituents):
-                # Extract the per-image WCS if one exists.
-                per_image_wcs.append(extract_wcs_from_hdu_header(primary[f"WCS_{i}"].header))
-                per_image_ebd_wcs.append(extract_wcs_from_hdu_header(primary[f"EBD_{i}"].header))
-                constituent_images.append(primary[f"WCS_{i}"].header["ILOC"])
+                if f"WCS_{i}" in primary:
+                    wcs_header = primary[f"WCS_{i}"].header
+                    org_image_meta["per_image_wcs"][i] = extract_wcs_from_hdu_header(wcs_header)
+                    if "ILOC" in wcs_header:
+                        org_image_meta["data_loc"][i] = wcs_header["ILOC"]
+                if f"EBD_{i}" in primary:
+                    org_image_meta["ebd_wcs"][i] = extract_wcs_from_hdu_header(primary[f"EBD_{i}"].header)
 
+            per_image_indices = []
             file_paths = []
             obstimes = []
@@ -707,6 +661,7 @@ def from_sharded_fits(cls, filename, directory, lazy=False):
                 else:
                     file_paths.append(shard_path)
 
+                # Load the mapping of current image to constituent image.
                 n_indices = sci_hdu.header["NIND"]
                 sub_indices = []
                 for j in range(n_indices):
@@ -718,29 +673,19 @@ def from_sharded_fits(cls, filename, directory, lazy=False):
             im_stack=im_stack,
             config=config,
             wcs=global_wcs,
-            constituent_images=constituent_images,
-            per_image_wcs=per_image_wcs,
-            per_image_ebd_wcs=per_image_ebd_wcs,
-            heliocentric_distance=heliocentric_distance,
-            geocentric_distances=geocentric_distances,
             reprojected=reprojected,
-            per_image_indices=per_image_indices,
             lazy=lazy,
+            heliocentric_distance=heliocentric_distance,
+            per_image_indices=per_image_indices,
             file_paths=file_paths,
             obstimes=obstimes,
+            org_image_meta=org_image_meta,
         )
         return result
 
-    def metadata_to_primary_header(self, include_wcs=True):
+    def metadata_to_hdul(self):
         """Creates the metadata fits headers.
 
-        Parameters
-        ----------
-        include_wcs : `bool`
-            whether or not to append all the per image wcses
-            to the header (optional for the serial `to_fits`
-            case so that we can maintain the same indexing
-            as before).
-
         Returns
         -------
         hdul : `astropy.io.fits.HDUList`
@@ -754,53 +699,22 @@ def metadata_to_hdul(self):
         pri.header["NCON"] = self.n_constituents
         pri.header["REPRJCTD"] = self.reprojected
         pri.header["HELIO"] = self.heliocentric_distance
-        for i in range(self.n_constituents):
-            pri.header[f"GEO_{i}"] = self.org_img_meta["geocentric_distance"][i]
 
-        # If the global WCS exists, append the corresponding keys.
+        # If the global WCS exists, append the corresponding keys to the primary header.
         if self.wcs is not None:
             append_wcs_to_hdu_header(self.wcs, pri.header)
-
         hdul.append(pri)
 
-        meta_hdu = fits.BinTableHDU()
-        meta_hdu.name = "metadata"
-        hdul.append(meta_hdu)
-
+        # Add the configuration layer.
         config_hdu = self.config.to_hdu()
         config_hdu.name = "kbmod_config"
         hdul.append(config_hdu)
 
-        if include_wcs:
-            self.append_all_wcs(hdul)
+        # Save the additional metadata table into an HDU.
+        hdul.append(image_metadata_table_to_hdu(self.org_img_meta, "IMG_META"))
 
         return hdul
 
-    def append_all_wcs(self, hdul):
-        """Append all the original WCS and EBD WCS to a header.
-
-        Parameters
-        ----------
-        hdul : `astropy.io.fits.HDUList`
-            The HDU list.
-        """
-        all_ebd_wcs = self.get_constituent_meta("ebd_wcs")
-
-        for i in range(self.n_constituents):
-            img_location = self.org_img_meta["data_loc"][i]
-
-            orig_wcs = self._per_image_wcs[i]
-            wcs_hdu = fits.TableHDU()
-            append_wcs_to_hdu_header(orig_wcs, wcs_hdu.header)
-            wcs_hdu.name = f"WCS_{i}"
-            wcs_hdu.header["ILOC"] = img_location
-            hdul.append(wcs_hdu)
-
-            ebd_hdu = fits.TableHDU()
-            append_wcs_to_hdu_header(all_ebd_wcs[i], ebd_hdu.header)
-            ebd_hdu.name = f"EBD_{i}"
-            hdul.append(ebd_hdu)
-
     def image_positions_to_original_icrs(
         self, image_indices, positions, input_format="xy", output_format="xy", filter_in_frame=True
     ):
@@ -829,6 +743,7 @@ def image_positions_to_original_icrs(
             Whether or not to filter the output based on whether they fit within the
             original `constituent_image` frame. If `True`, only results that fall within
             the bounds of the original WCS will be returned.
+
         Returns
         -------
         positions : `list` of `astropy.coordinates.SkyCoord`s or `tuple`s
@@ -888,7 +803,7 @@ def image_positions_to_original_icrs(
                 pos = []
                 for j in inds:
                     con_image = self.org_img_meta["data_loc"][j]
-                    con_wcs = self._per_image_wcs[j]
+                    con_wcs = self.org_img_meta["per_image_wcs"][j]
                     height, width = con_wcs.array_shape
                     x, y = skycoord_to_pixel(coord, con_wcs)
                     x, y = float(x), float(y)
@@ -985,3 +900,115 @@ def raw_image_to_hdu(img, obstime, wcs=None):
     hdu.header["MJD"] = obstime
 
     return hdu
+
+
+# ------------------------------------------------------------------
+# --- Utility functions for the metadata table ---------------------
+# ------------------------------------------------------------------
+
+
+def create_image_metadata(n_images=None, data=None):
+    """Create an img_meta table, filling in default values
+    for any unspecified columns.
+
+    Parameters
+    ----------
+    n_images : `int`, optional
+        The number of images to include. Ignored if ``data`` is provided.
+    data : `dict` or `astropy.table.Table`, optional
+        An existing dictionary or table from which to fill in some of the data.
+
+    Returns
+    -------
+    img_meta : `astropy.table.Table`
+        The table of per-image metadata with defaults filled in.
+    """
+    if data is not None:
+        # If data is provided, its rows determine the size of the table.
+        if isinstance(data, dict):
+            data = Table(data)
+        n_images = len(data)
+    if n_images is None or n_images <= 0:
+        raise ValueError(f"Invalid metadata size: {n_images}")
+
+    img_meta = Table()
+
+    # Fill in the defaults.
+    for colname in ["data_loc", "ebd_wcs", "geocentric_distance", "per_image_wcs"]:
+        img_meta[colname] = np.full(n_images, None)
+
+    # Fill in any values from the given table. This overwrites the defaults.
+    if data is not None:
+        for colname in data.colnames:
+            img_meta[colname] = data[colname]
+
+    return img_meta
+
+
+def image_metadata_table_to_hdu(data, layer_name=None):
+    """Create a HDU layer from an astropy table with custom
+    encodings for some columns (such as WCS).
+
+    Parameters
+    ----------
+    data : `astropy.table.Table`
+        The table of the data to save.
+    layer_name : `str`, optional
+        The name of the layer in which to save the table.
+
+    Returns
+    -------
+    meta_hdu : `astropy.io.fits.TableHDU`
+        The HDU holding the (serialized) table.
+    """
+    num_rows = len(data)
+    if num_rows == 0:
+        # No data to encode. Just use the current table.
+        save_table = data
+    else:
+        # Create a new table to save with the correct column
+        # values/names for the serialized information.
+        save_table = Table()
+        for colname in data.colnames:
+            col_data = data[colname].value
+
+            if np.all(col_data == None):
+                # The entire column is filled with Nones (probably from a default value).
+                save_table[f"_EMPTY_{colname}"] = np.full(num_rows, "None", dtype=str)
+            elif isinstance(col_data[0], WCS):
+                # Serialize WCS objects and use a custom tag so we can unserialize them.
+                values = np.array([serialize_wcs(entry) for entry in data[colname]], dtype=str)
+                save_table[f"_WCSSTR_{colname}"] = values
+            else:
+                save_table[colname] = data[colname]
+
+    # Format the metadata as a single HDU.
+    meta_hdu = fits.TableHDU(save_table)
+    if layer_name is not None:
+        meta_hdu.name = layer_name
+    return meta_hdu
+
+
+def hdu_to_image_metadata_table(hdu):
+    """Load a HDU layer with custom encodings for some columns (such as WCS)
+    to an astropy table.
+
+    Parameters
+    ----------
+    hdu : `astropy.io.fits.TableHDU`
+        The HDU holding the metadata table.
+
+    Returns
+    -------
+    data : `astropy.table.Table`
+        The table of loaded data.
+    """
+    if hdu is None:
+        # Nothing to decode. Return an empty table.
+        return Table()
+
+    data = Table(hdu.data)
+    all_cols = set(data.colnames)
+
+    # Check if there are any columns we need to decode. If so: decode them, add a new column,
+    # and delete the old column.
+    for colname in all_cols:
+        if colname.startswith("_WCSSTR_"):
+            data[colname[8:]] = np.array([deserialize_wcs(entry) for entry in data[colname]])
+            data.remove_column(colname)
+        elif colname.startswith("_EMPTY_"):
+            data[colname[7:]] = np.array([None for _ in data[colname]])
+            data.remove_column(colname)
+
+    return data
diff --git a/tests/test_reprojection.py b/tests/test_reprojection.py
index 1135c7b50..63c13defd 100644
--- a/tests/test_reprojection.py
+++ b/tests/test_reprojection.py
@@ -19,9 +19,6 @@ def setUp(self):
         self.test_wunit = WorkUnit.from_fits(self.data_path)
         self.common_wcs = self.test_wunit.get_wcs(0)
 
-        # self.tmp_dir = os.path(tempfile.TemporaryDirectory())
-        # self.test_wunit.to_sharded_fits("test_wunit.fits", self.tmp_dir)
-
     def test_reproject(self):
         # test exception conditions
         self.assertRaises(
diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py
index abd61fa23..8e2a86c2f 100644
--- a/tests/test_work_unit.py
+++ b/tests/test_work_unit.py
@@ -4,6 +4,7 @@
 from astropy.coordinates import EarthLocation, SkyCoord
 from astropy.time import Time
 import numpy as np
+import numpy.testing as npt
 import os
 from pathlib import Path
 import tempfile
@@ -15,7 +16,13 @@
 import kbmod.search as kb
 from kbmod.reprojection_utils import fit_barycentric_wcs
 from kbmod.wcs_utils import make_fake_wcs, wcs_fits_equal
-from kbmod.work_unit import raw_image_to_hdu, WorkUnit
+from kbmod.work_unit import (
+    create_image_metadata,
+    hdu_to_image_metadata_table,
+    image_metadata_table_to_hdu,
+    raw_image_to_hdu,
+    WorkUnit,
+)
 
 import numpy.testing as npt
@@ -96,18 +103,25 @@ def setUp(self):
             "five.fits",
         ]
 
+        self.org_image_meta = Table(
+            {
+                "data_loc": np.array(self.constituent_images),
+                "ebd_wcs": np.array([self.per_image_ebd_wcs] * self.num_images),
+                "geocentric_distance": np.array([self.geo_dist] * self.num_images),
+                "per_image_wcs": np.array(self.per_image_wcs),
+            }
+        )
+
     def test_create(self):
         # Test the creation of a WorkUnit with no WCS. Should throw a warning.
         with warnings.catch_warnings(record=True) as wrn:
             warnings.simplefilter("always")
             work = WorkUnit(self.im_stack, self.config)
-            self.assertTrue("No WCS provided." in str(wrn[-1].message))
         self.assertIsNotNone(work)
 
         self.assertEqual(work.im_stack.img_count(), 5)
         self.assertEqual(work.config["im_filepath"], "Here")
         self.assertEqual(work.config["num_obs"], 5)
-        self.assertFalse(work.has_common_wcs())
         self.assertIsNone(work.wcs)
         self.assertEqual(len(work), self.num_images)
         for i in range(self.num_images):
@@ -116,7 +130,6 @@ def test_create(self):
         # Create with a global WCS
         work2 = WorkUnit(self.im_stack, self.config, self.wcs)
         self.assertEqual(work2.im_stack.img_count(), 5)
-        self.assertTrue(work2.has_common_wcs())
         self.assertIsNotNone(work2.wcs)
         for i in range(self.num_images):
             self.assertIsNotNone(work2.get_wcs(i))
@@ -129,26 +142,60 @@ def test_create(self):
             self.im_stack,
             self.config,
             self.wcs,
-            [f"img_{i}" for i in range(self.im_stack.img_count())],
-            [self.wcs, self.wcs, self.wcs],
+            org_image_meta={"per_image_wcs": [self.wcs, self.wcs, self.wcs]},
         )
 
-        # Create with per-image WCS that can be compressed to a global WCS.
-        per_image_wcs = [self.wcs] * self.num_images
-        work3 = WorkUnit(self.im_stack, self.config, per_image_wcs=per_image_wcs)
-        self.assertIsNotNone(work3.wcs)
-        self.assertTrue(work3.has_common_wcs())
-        for i in range(self.num_images):
-            self.assertIsNotNone(work3.get_wcs(i))
-            self.assertTrue(wcs_fits_equal(self.wcs, work3.get_wcs(i)))
-
-        # Create with per-image WCS that cannot be compressed to a global WCS.
-        work3 = WorkUnit(self.im_stack, self.config, per_image_wcs=self.diff_wcs)
-        self.assertIsNone(work3.wcs)
-        self.assertFalse(work3.has_common_wcs())
-        for i in range(self.num_images):
-            self.assertIsNotNone(work3.get_wcs(i))
-            self.assertTrue(wcs_fits_equal(work3.get_wcs(i), self.diff_wcs[i]))
+    def test_metadata_helpers(self):
+        """Test that we can roundtrip an astropy table of metadata (including WCS)
+        into a TableHDU.
+        """
+        metadata_dict = {
+            "col1": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),  # Floats
+            "uri": np.array(["a", "bc", "def", "ghij", "other_strings"]),  # Strings
+            "wcs": np.array(self.per_image_wcs),  # WCSes
+            "none_col": np.array([None] * self.num_images),  # Empty column
+            "Other": np.arange(5),  # ints
+        }
+        metadata_table = Table(metadata_dict)
+
+        # Convert to an HDU.
+        hdu = image_metadata_table_to_hdu(metadata_table)
+        self.assertIsNotNone(hdu)
+
+        # Convert it back.
+        md_table2 = hdu_to_image_metadata_table(hdu)
+        self.assertEqual(len(md_table2.colnames), 5)
+        npt.assert_array_equal(metadata_dict["col1"], md_table2["col1"])
+        npt.assert_array_equal(metadata_dict["uri"], md_table2["uri"])
+        npt.assert_array_equal(metadata_dict["Other"], md_table2["Other"])
+        self.assertTrue(np.all(md_table2["none_col"] == None))
+        for i in range(len(md_table2)):
+            self.assertTrue(isinstance(md_table2["wcs"][i], WCS))
+
+    def test_create_image_meta(self):
+        # Empty constituent image data.
+        org_img_meta = create_image_metadata(n_images=3, data=None)
+        self.assertEqual(len(org_img_meta), 3)
+        self.assertTrue("data_loc" in org_img_meta.colnames)
+        self.assertTrue("ebd_wcs" in org_img_meta.colnames)
+        self.assertTrue("geocentric_distance" in org_img_meta.colnames)
+        self.assertTrue("per_image_wcs" in org_img_meta.colnames)
+
+        # We can create from a dictionary.
+        # In this case the rows of the data determine the size and n_images is ignored.
+        data_dict = {
+            "uri": ["file1", "file2", "file3"],
+            "geocentric_distance": [1.0, 2.0, 3.0],
+        }
+        org_img_meta2 = create_image_metadata(n_images=5, data=data_dict)
+        self.assertEqual(len(org_img_meta2), 3)
+        self.assertTrue("data_loc" in org_img_meta2.colnames)
+        self.assertTrue("ebd_wcs" in org_img_meta2.colnames)
+        self.assertTrue("geocentric_distance" in org_img_meta2.colnames)
+        self.assertTrue("per_image_wcs" in org_img_meta2.colnames)
+        self.assertTrue("uri" in org_img_meta2.colnames)
+
+        # We need either data or a positive number of images.
+        self.assertRaises(ValueError, create_image_metadata, -1, None)
 
     def test_save_and_load_fits(self):
         with tempfile.TemporaryDirectory() as dir_name:
@@ -159,13 +206,19 @@ def test_save_and_load_fits(self):
             self.assertRaises(ValueError, WorkUnit.from_fits, file_path)
 
             # Write out the existing WorkUnit with a different per-image wcs for all the entries.
-            # work = WorkUnit(self.im_stack, self.config, None, self.diff_wcs)
+            # Include extra per-image metadata.
+            extra_meta = {
+                "data_loc": np.array(self.constituent_images),
+                "int_index": np.arange(self.num_images),
+                "uri": np.array([f"file_loc_{i}" for i in range(self.num_images)]),
+                "per_image_wcs": np.array(self.diff_wcs),
+            }
             work = WorkUnit(
                 im_stack=self.im_stack,
                 config=self.config,
                 wcs=None,
-                per_image_wcs=self.diff_wcs,
-                constituent_images=self.constituent_images,
+                org_image_meta=extra_meta,
             )
             work.to_fits(file_path)
             self.assertTrue(Path(file_path).is_file())
@@ -174,7 +227,6 @@ def test_save_and_load_fits(self):
             work2 = WorkUnit.from_fits(file_path)
             self.assertEqual(work2.im_stack.img_count(), self.num_images)
             self.assertIsNone(work2.wcs)
-            self.assertFalse(work2.has_common_wcs())
             for i in range(self.num_images):
                 li = work2.im_stack.get_single_image(i)
                 self.assertEqual(li.get_width(), self.width)
@@ -214,9 +266,10 @@ def test_save_and_load_fits(self):
             self.assertEqual(work2.config["im_filepath"], "Here")
             self.assertEqual(work2.config["num_obs"], self.num_images)
 
-            # Check that we correctly retrieved the provenance information via "data_loc"
-            for index, value in enumerate(work2.org_img_meta["data_loc"]):
-                self.assertEqual(value, self.constituent_images[index])
+            # Check that we retrieved the extra metadata that we added.
+            npt.assert_array_equal(work2.get_constituent_meta("uri"), extra_meta["uri"])
+            npt.assert_array_equal(work2.get_constituent_meta("int_index"), extra_meta["int_index"])
+            npt.assert_array_equal(work2.get_constituent_meta("data_loc"), self.constituent_images)
 
             # We throw an error if we try to overwrite a file with overwrite=False
             self.assertRaises(FileExistsError, work.to_fits, file_path)
@@ -242,7 +295,6 @@ def test_save_and_load_fits_shard(self):
             work2 = WorkUnit.from_sharded_fits(filename="test_workunit.fits", directory=dir_name)
             self.assertEqual(work2.im_stack.img_count(), self.num_images)
             self.assertIsNone(work2.wcs)
-            self.assertFalse(work2.has_common_wcs())
             for i in range(self.num_images):
                 li = work2.im_stack.get_single_image(i)
                 self.assertEqual(li.get_width(), self.width)
@@ -298,7 +350,6 @@ def test_save_and_load_fits_shard_lazy(self):
             work2 = WorkUnit.from_sharded_fits(filename="test_workunit.fits", directory=dir_name, lazy=True)
             self.assertEqual(len(work2.file_paths), self.num_images)
             self.assertIsNone(work2.wcs)
-            self.assertFalse(work2.has_common_wcs())
 
             # Check that we read in the configuration values correctly.
             self.assertEqual(work2.config["im_filepath"], "Here")
@@ -321,7 +372,6 @@ def test_save_and_load_fits_global_wcs(self):
             # Read in the file and check that the values agree.
             work2 = WorkUnit.from_fits(file_path)
             self.assertIsNotNone(work2.wcs)
-            self.assertTrue(work2.has_common_wcs())
             self.assertTrue(wcs_fits_equal(work2.wcs, self.wcs))
             for i in range(self.num_images):
                 self.assertIsNotNone(work2.get_wcs(i))
@@ -341,12 +391,9 @@ def test_image_positions_to_original_icrs_invalid_format(self):
             im_stack=self.im_stack,
             config=self.config,
             wcs=self.per_image_ebd_wcs,
-            per_image_wcs=self.per_image_wcs,
-            per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images,
-            geocentric_distances=[self.geo_dist] * self.num_images,
             heliocentric_distance=41.0,
-            constituent_images=self.constituent_images,
             reprojected=True,
+            org_image_meta=self.org_image_meta,
         )
 
         # Incorrect format for 'xy'
@@ -381,12 +428,9 @@ def test_image_positions_to_original_icrs_basic_inputs(self):
             im_stack=self.im_stack,
             config=self.config,
             wcs=self.per_image_ebd_wcs,
-            per_image_wcs=self.per_image_wcs,
-            per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images,
-            geocentric_distances=[self.geo_dist] * self.num_images,
             heliocentric_distance=41.0,
-            constituent_images=self.constituent_images,
             reprojected=True,
+            org_image_meta=self.org_image_meta,
         )
 
         res = work.image_positions_to_original_icrs(
@@ -430,12 +474,9 @@ def test_image_positions_to_original_icrs_filtering(self):
             im_stack=self.im_stack,
             config=self.config,
             wcs=self.per_image_ebd_wcs,
-            per_image_wcs=self.per_image_wcs,
-            per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images,
-            geocentric_distances=[self.geo_dist] * self.num_images,
             heliocentric_distance=41.0,
-            constituent_images=self.constituent_images,
             reprojected=True,
+            org_image_meta=self.org_image_meta,
         )
 
         res = work.image_positions_to_original_icrs(
@@ -458,16 +499,13 @@ def test_image_positions_to_original_icrs_mosaicking(self):
             im_stack=self.im_stack,
             config=self.config,
             wcs=self.per_image_ebd_wcs,
-            per_image_wcs=self.per_image_wcs,
-            per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images,
-            geocentric_distances=[self.geo_dist] * self.num_images,
             heliocentric_distance=41.0,
-            constituent_images=self.constituent_images,
             reprojected=True,
+            org_image_meta=self.org_image_meta,
         )
 
         new_wcs = make_fake_wcs(190.0, -7.7888, 500, 700)
-        work._per_image_wcs[-1] = new_wcs
+        work.org_img_meta["per_image_wcs"][-1] = new_wcs
         work._per_image_indices[3] = [3, 4]
 
         res = work.image_positions_to_original_icrs(
@@ -494,9 +532,6 @@ def test_image_positions_to_original_icrs_mosaicking(self):
         npt.assert_almost_equal(res[3][0].separation(self.expected_radec_positions[3]).deg, 0.0, decimal=5)
         assert res[3][1] == "five.fits"
 
-        # work._per_image_wcs[4] = work._per_image_wcs[3]
-        # work._per_image_ebd_wcs[4] = work._per_image_ebd_wcs[3]
-
         res = work.image_positions_to_original_icrs(
             self.indices,
             self.pixel_positions,
@@ -517,11 +552,8 @@ def test_get_unique_obstimes_and_indices(self):
             im_stack=self.im_stack,
             config=self.config,
             wcs=self.per_image_ebd_wcs,
-            per_image_wcs=self.per_image_wcs,
-            per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images,
-            geocentric_distances=[self.geo_dist] * self.num_images,
             heliocentric_distance=41.0,
-            constituent_images=self.constituent_images,
+            org_image_meta=self.org_image_meta,
         )
         times = work.get_all_obstimes()
         times[-1] = times[-2]
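
Usage sketch: the snippet below illustrates the per-image metadata round trip that this patch introduces. It is a minimal, hedged example, not part of the patch itself; it assumes the `create_image_metadata(n_images=None, data=None)` signature shown above and kbmod's `make_fake_wcs` helper from `kbmod.wcs_utils`, and the file names are placeholders.

    import numpy as np

    from kbmod.wcs_utils import make_fake_wcs
    from kbmod.work_unit import (
        create_image_metadata,
        hdu_to_image_metadata_table,
        image_metadata_table_to_hdu,
    )

    # Build a constituent-image metadata table. Unspecified columns
    # ("ebd_wcs", "geocentric_distance") are filled with None defaults.
    meta = create_image_metadata(
        data={
            "data_loc": np.array(["one.fits", "two.fits", "three.fits"]),
            "per_image_wcs": np.array([make_fake_wcs(200.615, -7.789, 500, 700)] * 3),
        }
    )

    # Serialize into a FITS table HDU: WCS columns are stored as strings under
    # a _WCSSTR_ prefix and all-None columns under an _EMPTY_ prefix.
    hdu = image_metadata_table_to_hdu(meta, layer_name="IMG_META")

    # Decoding restores the original column names and values.
    meta2 = hdu_to_image_metadata_table(hdu)
    assert list(meta2["data_loc"]) == ["one.fits", "two.fits", "three.fits"]

The same table is what `WorkUnit` now carries as `org_img_meta` and writes to the `IMG_META` layer in `to_fits` / `to_sharded_fits`, which is why the loaders above fall back to header-based `WCS_{i}` / `EBD_{i}` extensions only when `IMG_META` is absent.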