From acdf3511acdad97ee1f06ea989fe78857f3ab9bb Mon Sep 17 00:00:00 2001
From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com>
Date: Wed, 13 Nov 2024 14:37:43 -0500
Subject: [PATCH 01/15] Allow None in WCS serialization

---
 src/kbmod/wcs_utils.py  | 14 ++++++++++----
 tests/test_wcs_utils.py |  7 +++++++
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/src/kbmod/wcs_utils.py b/src/kbmod/wcs_utils.py
index 890e6aaae..1a84509ad 100644
--- a/src/kbmod/wcs_utils.py
+++ b/src/kbmod/wcs_utils.py
@@ -450,14 +450,17 @@ def serialize_wcs(wcs):
 
     Parameters
     ----------
-    wcs : `astropy.wcs.WCS`
+    wcs : `astropy.wcs.WCS` or None
         The WCS to convert.
 
     Returns
     -------
     wcs_str : `str`
-        The serialized WCS.
+        The serialized WCS. Returns an empty string if wcs is None.
     """
+    if wcs is None:
+        return ""
+
     # Since AstroPy's WCS does not output NAXIS, we need to manually add those.
     header = wcs.to_header(relax=True)
     header["NAXIS1"], header["NAXIS2"] = wcs.pixel_shape
@@ -474,9 +477,12 @@ def deserialize_wcs(wcs_str):
 
     Returns
     -------
-    wcs : `astropy.wcs.WCS`
-        The resulting WCS.
+    wcs : `astropy.wcs.WCS` or None
+        The resulting WCS or None if no data is provided.
     """
+    if wcs_str == "" or wcs_str.lower() == "none":
+        return None
+
     wcs_dict = json.loads(wcs_str)
     wcs = astropy.wcs.WCS(wcs_dict)
     wcs.pixel_shape = (wcs_dict["NAXIS1"], wcs_dict["NAXIS2"])
diff --git a/tests/test_wcs_utils.py b/tests/test_wcs_utils.py
index 043d4e596..4e5791bb2 100644
--- a/tests/test_wcs_utils.py
+++ b/tests/test_wcs_utils.py
@@ -67,6 +67,13 @@ def test_serialization(self):
         self.assertEqual(self.wcs.pixel_shape, wcs2.pixel_shape)
         self.assertTrue(wcs_fits_equal(self.wcs, wcs2))
 
+        # Test that we can serialize and deserialize None.
+        none_str = serialize_wcs(None)
+        self.assertEqual(none_str, "")
+        self.assertIsNone(deserialize_wcs(""))
+        self.assertIsNone(deserialize_wcs("none"))
+        self.assertIsNone(deserialize_wcs("None"))
+
     def test_append_wcs_to_hdu_header(self):
         for use_dictionary in [True, False]:
             if use_dictionary:

From cf97bf79f6786c2d7e4b31d794776a3d3e8e954d Mon Sep 17 00:00:00 2001
From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com>
Date: Wed, 13 Nov 2024 16:19:48 -0500
Subject: [PATCH 02/15] v1 of saving meta data to a HDU

---
 src/kbmod/work_unit.py  | 137 ++++++++++++++++++++++++++++++++++++++--
 tests/test_work_unit.py |  34 +++++++++-
 2 files changed, 164 insertions(+), 7 deletions(-)

diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py
index f0320624c..b56597ab2 100644
--- a/src/kbmod/work_unit.py
+++ b/src/kbmod/work_unit.py
@@ -7,6 +7,7 @@
 from astropy.table import Table
 from astropy.time import Time
 from astropy.utils.exceptions import AstropyWarning
+from astropy.wcs import WCS
 from astropy.wcs.utils import skycoord_to_pixel
 import astropy.units as u
 
@@ -21,7 +22,9 @@
 from kbmod.wcs_utils import (
     append_wcs_to_hdu_header,
     calc_ecliptic_angle,
+    deserialize_wcs,
     extract_wcs_from_hdu_header,
+    serialize_wcs,
     wcs_fits_equal,
 )
 
@@ -112,6 +115,8 @@ class WorkUnit:
         in lazy mode.
     obstimes : `list[float]`
         The MJD obstimes of the images.
+    per_image_meta : `dict` or `astropy.table.Table`
+        A table of additional per-image metadata.
     """
 
     def __init__(
@@ -129,6 +134,7 @@ def __init__(
         lazy=False,
         file_paths=None,
         obstimes=None,
+        per_image_meta=None,
     ):
         self.im_stack = im_stack
         self.config = config
@@ -145,7 +151,14 @@ def __init__(
 
         # Track the meta data for each constituent image in the WorkUnit. For the original
         # WCS, we track the per-image WCS if it is provided and otherwise the global WCS.
-        self.org_img_meta = Table()
+        if per_image_meta is None:
+            self.org_img_meta = Table()
+        elif isinstance(per_image_meta, Table):
+            self.org_img_meta = per_image_meta.copy()
+        elif isinstance(per_image_meta, dict):
+            self.org_img_meta = Table(per_image_meta)
+        else:
+            raise ValueError(f"Invalid type for per_image_meta: {type(per_image_meta)}")
         self.add_org_img_meta_data("data_loc", constituent_images)
         self.add_org_img_meta_data("original_wcs", per_image_wcs, default=wcs)
 
@@ -191,7 +204,18 @@ def __len__(self):
         return self.im_stack.img_count()
 
     def get_constituent_meta(self, column):
-        """Get the meta data values of a given column for all the constituent images."""
+        """Get the meta data values of a given column for all the constituent images.
+
+        Parameters
+        ----------
+        column : `str`
+            The column name to fetch.
+
+        Returns
+        -------
+        data : `list`
+            A list of the meta-data for each constituent image.
+        """
         return list(self.org_img_meta[column].data)
 
     def add_org_img_meta_data(self, column, data, default=None):
@@ -205,12 +229,15 @@ def add_org_img_meta_data(self, column, data, default=None):
         data : list-like
             The data for each constituent. If None then uses None for
             each column.
+        dtype : type
+            The numpy dtype to use when storing this data. If None
+
         """
         if data is None:
             if column not in self.org_img_meta.colnames:
-                self.org_img_meta[column] = [default] * self.n_constituents
+                self.org_img_meta[column] = np.array([default] * self.n_constituents)
         elif len(data) == self.n_constituents:
-            self.org_img_meta[column] = data
+            self.org_img_meta[column] = np.array(data)
         else:
             raise ValueError(
                 f"Data mismatch size for WorkUnit metadata {column}. "
                 f"Expected {self.n_constituents} but found {len(data)}."
             )
@@ -403,11 +430,15 @@ def from_fits(cls, filename, show_progress=None):
             if num_layers < 5:
                 raise ValueError(f"WorkUnit file has too few extensions {len(hdul)}.")
 
-            # TODO - Read in provenance metadata from extension #1.
-
             # Read in the search parameters from the 'kbmod_config' extension.
             config = SearchConfiguration.from_hdu(hdul["kbmod_config"])
 
+            # Read in the per-image metadata.
+            if "IMG_META" in hdul:
+                per_image_meta = hdu_to_metadata_table(hdul["IMG_META"])
+            else:
+                per_image_meta = None
+
             # Read in the global WCS from extension 0 if the information exists.
             # We filter the warning that the image dimension does not match the WCS dimension
             # since the primary header does not have an image.
@@ -483,6 +514,7 @@ def from_fits(cls, filename, show_progress=None):
             geocentric_distances=geocentric_distances,
             reprojected=reprojected,
             per_image_indices=per_image_indices,
+            per_image_meta=per_image_meta,
         )
 
         return result
@@ -538,6 +570,9 @@ def to_fits(self, filename, overwrite=False):
             psf_hdu.name = f"PSF_{i}"
             hdul.append(psf_hdu)
 
+        # Format additional metadata.
+        meta_hdu = metadata_table_to_hdu(self.org_img_meta, "IMG_META")
+        hdul.append(meta_hdu)
         self.append_all_wcs(hdul)
 
         hdul.writeto(filename, overwrite=overwrite)
@@ -616,6 +651,11 @@ def to_sharded_fits(self, filename, directory, overwrite=False):
             sub_hdul.writeto(os.path.join(directory, f"{i}_{filename}"))
 
         hdul = self.metadata_to_primary_header(include_wcs=True)
+
+        # Format additional metadata as a single HDU
+        meta_hdu = metadata_table_to_hdu(self.org_img_meta, "IMG_META")
+        hdul.append(meta_hdu)
+
         hdul.writeto(os.path.join(directory, filename), overwrite=overwrite)
 
     @classmethod
@@ -660,6 +700,12 @@ def from_sharded_fits(cls, filename, directory, lazy=False):
         with fits.open(os.path.join(directory, filename)) as primary:
             config = SearchConfiguration.from_hdu(primary["kbmod_config"])
 
+            # Read in the per-image metadata.
+            if "IMG_META" in primary:
+                per_image_meta = hdu_to_metadata_table(primary["IMG_META"])
+            else:
+                per_image_meta = None
+
             # Read in the global WCS from extension 0 if the information exists.
             # We filter the warning that the image dimension does not match the WCS dimension
             # since the primary header does not have an image.
@@ -728,6 +774,7 @@ def from_sharded_fits(cls, filename, directory, lazy=False):
             lazy=lazy,
             file_paths=file_paths,
             obstimes=obstimes,
+            per_image_meta=per_image_meta,
         )
         return result
@@ -985,3 +1032,81 @@ def raw_image_to_hdu(img, obstime, wcs=None):
         hdu.header["MJD"] = obstime
 
     return hdu
+
+
+# ------------------------------------------------------------------
+# --- Utility functions for the metadata table ---------------------
+# ------------------------------------------------------------------
+
+
+def metadata_table_to_hdu(data, layer_name=None):
+    """Create a HDU layer from an astropy table with custom
+    encodings for some columns (such as WCS).
+
+    Parameters
+    ----------
+    data : `astropy.table.Table`
+        The table of the data to save.
+    layer_name : `str`, optional
+        The name of the layer in which to save the table.
+    """
+    num_rows = len(data)
+    if num_rows == 0:
+        # No data to encode. Just use the current table.
+        meta_hdu = fits.BinTableHDU(data)
+    else:
+        # Create a new table to save with the correct column
+        # values/names for the serialized information.
+        save_table = Table()
+        for colname in data.colnames:
+            col_data = data[colname].value
+
+            if np.all(col_data == None):
+                # The entire column is filled with Nones (probably from a default value).
+                save_table[f"_EMPTY_{colname}"] = np.full(num_rows, "None")
+            elif isinstance(col_data[0], WCS):
+                # Serialize WCS objects and use a custom tag so we can unserialize them.
+                values = np.array([serialize_wcs(entry) for entry in data[colname]])
+                save_table[f"_WCSSTR_{colname}"] = values
+            else:
+                save_table[colname] = data[colname]
+
+    # Format the metadata as a single HDU
+    meta_hdu = fits.BinTableHDU(save_table)
+    if layer_name is not None:
+        meta_hdu.name = layer_name
+    return meta_hdu
+
+
+def hdu_to_metadata_table(hdu):
+    """Load a HDU layer with custom encodings for some columns (such as WCS)
+    to an astropy table.
+
+    Parameters
+    ----------
+    hdu : `astropy.io.fits.BinTableHDU`
+        The HDUList for the fits file.
+
+    Returns
+    -------
+    data : `astropy.table.Table`
+        The table of loaded data.
+    """
+    if hdu is None:
+        # Nothing to decode. Return an empty table.
+        return Table()
+
+    data = Table(hdu.data)
+    all_cols = set(data.colnames)
+
+    # Check if there are any columns we need to decode. If so: decode them, add a new column,
+    # and delete the old column.
+    for colname in all_cols:
+        if colname.startswith("_WCSSTR_"):
+            data[colname[8:]] = np.array([deserialize_wcs(entry) for entry in data[colname]])
+            data.remove_column(colname)
+        elif colname.startswith("_EMPTY_"):
+            data[colname[7:]] = np.array([None for _ in data[colname]])
+            data.remove_column(colname)
+
+    return data
diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py
index abd61fa23..283aa2ff0 100644
--- a/tests/test_work_unit.py
+++ b/tests/test_work_unit.py
@@ -15,7 +15,12 @@
 import kbmod.search as kb
 from kbmod.reprojection_utils import fit_barycentric_wcs
 from kbmod.wcs_utils import make_fake_wcs, wcs_fits_equal
-from kbmod.work_unit import raw_image_to_hdu, WorkUnit
+from kbmod.work_unit import (
+    hdu_to_metadata_table,
+    metadata_table_to_hdu,
+    raw_image_to_hdu,
+    WorkUnit,
+)
 
 import numpy.testing as npt
 
@@ -150,6 +155,33 @@ def test_create(self):
             self.assertIsNotNone(work3.get_wcs(i))
             self.assertTrue(wcs_fits_equal(work3.get_wcs(i), self.diff_wcs[i]))
 
+    def test_metadata_helpers(self):
+        """Test that we can roundtrip an astropy table of metadata (including WCS)
+        into a BinTableHDU.
+        """
+        metadata_dict = {
+            "col1": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),  # Floats
+            "uri": np.array(["a", "bc", "def", "ghij", "other_strings"]),  # Strings
+            "wcs": np.array(self.per_image_wcs),  # WCSes
+            "none_col": np.array([None] * self.num_images),  # Empty column
+            "Other": np.arange(5),  # ints
+        }
+        metadata_table = Table(metadata_dict)
+
+        # Convert to an HDU
+        hdu = metadata_table_to_hdu(metadata_table)
+        self.assertIsNotNone(hdu)
+
+        # Convert it back.
+        md_table2 = hdu_to_metadata_table(hdu)
+        self.assertEqual(len(md_table2.colnames), 5)
+        npt.assert_array_equal(metadata_dict["col1"], md_table2["col1"])
+        npt.assert_array_equal(metadata_dict["uri"], md_table2["uri"])
+        npt.assert_array_equal(metadata_dict["Other"], md_table2["Other"])
+        self.assertTrue(np.all(md_table2["none_col"] == None))
+        for i in range(len(md_table2)):
+            self.assertTrue(isinstance(md_table2["wcs"][i], WCS))
+
     def test_save_and_load_fits(self):
         with tempfile.TemporaryDirectory() as dir_name:
             file_path = os.path.join(dir_name, "test_workunit.fits")

From 43a6250c342cf7360c3aa1e14ebdc0375d7197d3 Mon Sep 17 00:00:00 2001
From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com>
Date: Thu, 14 Nov 2024 17:31:20 -0500
Subject: [PATCH 03/15] Save all metadata in tables part 1

Tests do not pass yet
---
 src/kbmod/reprojection.py |  64 +++--
 src/kbmod/work_unit.py    | 477 ++++++++++++++++----------------------
 tests/test_work_unit.py   |  84 +++----
 3 files changed, 281 insertions(+), 344 deletions(-)

diff --git a/src/kbmod/reprojection.py b/src/kbmod/reprojection.py
index 2c2fc68c3..431080b36 100644
--- a/src/kbmod/reprojection.py
+++ b/src/kbmod/reprojection.py
@@ -19,6 +19,31 @@
 _DEFAULT_TQDM_BAR = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]"
 
 
+def create_new_image_metadata(unique_obstime_indices, common_wcs):
+    """Create a table of the metadata for the new reprojected images.
+
+    Parameters
+    ----------
+    unique_obstime_indices : `numpy.ndarray`
+        An array of lists (or arrays) indicating from which original images
+        the new images were created.
+    common_wcs : `astropy.wcs.WCS`
+        The new WCS for the images.
+
+    Returns
+    -------
+    metadata : `astropy.table.Table`
+        A table of metadata for the new images.
+ """ + metadata = Table( + { + "per_image_indices": np.array(unique_obstime_indices), + "wcs": np.full(len(unique_obstime_indices), common_wcs), + } + ) + return metadata + + def reproject_image(image, original_wcs, common_wcs): """Given an ndarray representing image data (either science or variance, when used with `reproject_work_unit`), as well as a common wcs, return the reprojected @@ -118,6 +143,9 @@ def reproject_work_unit( A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the case where `write_output` is set to True. """ + if work_unit.reprojected: + raise ValueError("Unable to reproject a reprojected WorkUnit.") + show_progress = is_interactive() if show_progress is None else show_progress if (work_unit.lazy or write_output) and (directory is None or filename is None): raise ValueError("can't write output to sharded fits without directory and filename provided.") @@ -282,25 +310,25 @@ def _reproject_work_unit( ) stack.append_image(new_layered_image, force_move=True) + # Determine the metadata for the new reprojected images. + new_image_meta = create_new_image_metadata(unique_obstime_indices, common_wcs) + if write_output: new_work_unit = copy(work_unit) - new_work_unit._per_image_indices = unique_obstime_indices - new_work_unit.wcs = common_wcs + new_work_unit.img_meta = new_image_meta new_work_unit.reprojected = True + new_work_unit.wcs = common_wcs - hdul = new_work_unit.metadata_to_primary_header() + hdul = new_work_unit.metadata_to_primary_hdul() hdul.writeto(os.path.join(directory, filename)) else: new_wunit = WorkUnit( im_stack=stack, config=work_unit.config, wcs=common_wcs, - constituent_images=work_unit.get_constituent_meta("data_loc"), - per_image_wcs=work_unit._per_image_wcs, - per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"), - per_image_indices=unique_obstime_indices, - geocentric_distances=work_unit.get_constituent_meta("geocentric_distance"), reprojected=True, + image_meta=new_image_meta, + org_image_meta=work_unit.org_img_meta, ) return new_wunit @@ -415,17 +443,20 @@ def _reproject_work_unit_in_parallel( # when all the multiprocessing has finished, convert the returned numpy arrays to RawImages. concurrent.futures.wait(future_reprojections, return_when=concurrent.futures.ALL_COMPLETED) + # Determine the metadata for the new reprojected images. 
+ new_image_meta = create_new_image_metadata(unique_obstime_indices, common_wcs) + if write_output: for result in future_reprojections: if not result.result(): raise RuntimeError("one or more jobs failed.") new_work_unit = copy(work_unit) - new_work_unit._per_image_indices = unique_obstimes_indices - new_work_unit.wcs = common_wcs + new_work_unit.img_meta = new_image_meta new_work_unit.reprojected = True + new_work_unit.wcs = common_wcs - hdul = new_work_unit.metadata_to_primary_header() + hdul = new_work_unit.metadata_to_primary_hdul() hdul.writeto(os.path.join(directory, filename)) else: stack = ImageStack([]) @@ -451,12 +482,9 @@ def _reproject_work_unit_in_parallel( im_stack=stack, config=work_unit.config, wcs=common_wcs, - constituent_images=work_unit.get_constituent_meta("data_loc"), - per_image_wcs=work_unit._per_image_wcs, - per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"), - per_image_indices=unique_obstimes_indices, - geocentric_distances=work_unit.get_constituent_meta("geocentric_distances"), reprojected=True, + image_meta=new_image_meta, + org_image_meta=work_unit.org_img_meta, ) return new_wunit @@ -550,9 +578,9 @@ def reproject_lazy_work_unit( raise RuntimeError("one or more jobs failed.") new_work_unit = copy(work_unit) - new_work_unit._per_image_indices = unique_obstimes_indices - new_work_unit.wcs = common_wcs + new_work_unit.img_meta = create_new_image_metadata(unique_obstime_indices, common_wcs) new_work_unit.reprojected = True + new_work_unit.wcs = common_wcs hdul = new_work_unit.metadata_to_primary_header() hdul.writeto(os.path.join(directory, filename)) diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py index b56597ab2..6364c5d10 100644 --- a/src/kbmod/work_unit.py +++ b/src/kbmod/work_unit.py @@ -46,10 +46,19 @@ class WorkUnit: The image data for the KBMOD run. config : `kbmod.configuration.SearchConfiguration` The configuration for the KBMOD run. + n_images : `int` + The number of current images. n_constituents : `int` The number of original images making up the data in this WorkUnit. This might be different from the number of images stored in memory if the WorkUnit has been reprojected. + img_meta : `astropy.table.Table` + The meta data for each of the current images. These might differ from the + constituent images if the WorkUnit has been filtered or reprojected. + * wcs - The WCS of the image. This is set even if all image share a global WCS. + * per_image_indices - A lists containing the indicies of `constituent_images` + for each current image. Used for finding corresponding original images when we + stitch images together during reprojection. org_img_meta : `astropy.table.Table` The meta data for each constituent image. Includes columns: * data_loc - the original location of the image @@ -60,9 +69,6 @@ class WorkUnit: wcs : `astropy.wcs.WCS` A global WCS for all images in the WorkUnit. Only exists if all images have been projected to same pixel space. - per_image_wcs : `list` - A list with one WCS for each image in the WorkUnit. Used for when - the images have *not* been standardized to the same pixel space. heliocentric_distance : `float` The heliocentric distance that was used when creating the `per_image_ebd_wcs`. reprojected : `bool` @@ -88,26 +94,13 @@ class WorkUnit: wcs : `astropy.wcs.WCS` A global WCS for all images in the WorkUnit. Only exists if all images have been projected to same pixel space. - constituent_images: `list` - A list of strings with the original locations of images used - to construct the WorkUnit. 
This is necessary to maintain as metadata - because after reprojection we may stitch multiple images into one. per_image_wcs : `list` A list with one WCS for each image in the WorkUnit. Used for when the images have *not* been standardized to the same pixel space. - per_image_ebd_wcs : `list` - A list with one WCS for each image in the WorkUnit. Used to reproject the images - into EBD space. - heliocentric_distance : `float` - The heliocentric distance that was used when creating the `per_image_ebd_wcs`. - geocentric_distances : `list` - The best fit geocentric distances used when creating the `per_image_ebd_wcs`. reprojected : `bool` Whether or not the WorkUnit image data has been reprojected. - per_image_indices : `list` of `list` - A list of lists containing the indicies of `constituent_images` at each layer - of the `ImageStack`. Used for finding corresponding original images when we - stitch images together during reprojection. + heliocentric_distance : `float` + The heliocentric distance that was used when creating the `per_image_ebd_wcs`. lazy : `bool` Whether or not to load the image data for the `WorkUnit`. file_paths : `list[str]` @@ -115,8 +108,10 @@ class WorkUnit: in lazy mode. obstimes : `list[float]` The MJD obstimes of the images. - per_image_meta : `dict` or `astropy.table.Table` - A table of additional per-image metadata. + img_meta : `dict` or `astropy.table.Table`, optional + The meta data for each of the current images. + org_image_meta : `dict` or `astropy.table.Table`, optional + A table of per-image data for the constituent images. """ def __init__( @@ -124,17 +119,14 @@ def __init__( im_stack=None, config=None, wcs=None, - constituent_images=None, per_image_wcs=None, - per_image_ebd_wcs=None, - heliocentric_distance=None, - geocentric_distances=None, reprojected=False, - per_image_indices=None, + heliocentric_distance=None, lazy=False, file_paths=None, obstimes=None, - per_image_meta=None, + image_meta=None, + org_image_meta=None, ): self.im_stack = im_stack self.config = config @@ -142,130 +134,123 @@ def __init__( self.file_paths = file_paths self._obstimes = obstimes - # Determine the number of constituent images. If we are given a list of constituent_images, - # use that. Otherwise use the size of the image stack. - if constituent_images is None: - self.n_constituents = im_stack.img_count() - else: - self.n_constituents = len(constituent_images) - - # Track the meta data for each constituent image in the WorkUnit. For the original - # WCS, we track the per-image WCS if it is provided and otherwise the global WCS. - if per_image_meta is None: - self.org_img_meta = Table() - elif isinstance(per_image_meta, Table): - self.org_img_meta = per_image_meta.copy() - elif isinstance(per_image_meta, dict): - self.org_img_meta = Table(per_image_meta) - else: - raise ValueError(f"Invalid type for per_image_meta: {type(per_image_meta)}") - self.add_org_img_meta_data("data_loc", constituent_images) - self.add_org_img_meta_data("original_wcs", per_image_wcs, default=wcs) + # Track the metadata for each of the current images. + self.img_meta = WorkUnit.create_meta( + constituent=False, + data=image_meta, + n_images=im_stack.img_count(), + ) + + # Base the number of current images on the metadata because in a lazy load, + # the ImageStack might be empty. + self.n_images = len(self.img_meta) + + # Track the metadata for each constituent image in the WorkUnit. If no constituent + # data is provided, this will create an empty array the same size as the original. 
+ no_org_img_meta_given = org_image_meta is None + self.org_img_meta = WorkUnit.create_meta( + constituent=True, + data=org_image_meta, + n_images=self.n_images, + ) + self.n_constituents = len(self.org_img_meta) - # Handle WCS input. If both the global and per-image WCS are provided, - # ensure they are consistent. + # Handle WCS input. self.wcs = wcs - if per_image_wcs is None: - self._per_image_wcs = [None] * self.n_constituents - if self.wcs is None and per_image_ebd_wcs is None: - warnings.warn("No WCS provided.", Warning) - else: - if len(per_image_wcs) != self.n_constituents: - raise ValueError(f"Incorrect number of WCS provided. Expected {self.n_constituents}") - self._per_image_wcs = per_image_wcs - - # Check if all the per-image WCS are None. This can happen during a load. - all_none = self.per_image_wcs_all_match(None) - if self.wcs is None and all_none: - warnings.warn("No WCS provided.", Warning) - - # See if we can compress the per-image WCS into a global one. - if self.wcs is None and not all_none and self.per_image_wcs_all_match(self._per_image_wcs[0]): - self.wcs = self._per_image_wcs[0] - self._per_image_wcs = [None] * im_stack.img_count() - - # Add the meta data needed for reprojection, including: the reprojected WCS, the geocentric - # distances, and each images indices in the original constituent images. + if per_image_wcs is not None: + # If we are given explicit per-image WCS, use those. Overwrite the values + # the current image metadata. + if len(per_image_wcs) != self.n_images: + raise ValueError(f"Incorrect number of WCS provided. Expected {self.n_images}") + self.img_meta["wcs"] = np.array(per_image_wcs) + elif "wcs" not in self.img_meta.colnames or np.all(self.img_meta["wcs"] == None): + # If we have no per-image WCS already, use the global one (which might be None). + self.img_meta["wcs"] = np.array([self.wcs] * self.n_images) + + # If no constituent data was provided, then save the current image's WCS as the original. + # This is needed to ensure that we always have a correct original WCS. + if no_org_img_meta_given: + for i in range(self.n_images): + self.org_img_meta["original_wcs"][i] = self.img_meta["wcs"][i] + + # If both the global and per-image WCS are provided, ensure they are consistent. + if self.wcs is not None and not np.all(self.img_meta["wcs"].value == None): + for idx in range(im_stack.img_count()): + if not wcs_fits_equal(self.wcs, self.img_meta["wcs"][idx]): + raise ValueError(f"Inconsistent WCS at index {idx}.") + if self.wcs is None and np.any(self.img_meta["wcs"].value == None): + logger.warning("No WCS provided for at least one image.") + + # Set the global metadata for reprojection. self.reprojected = reprojected self.heliocentric_distance = heliocentric_distance - self.add_org_img_meta_data("geocentric_distance", geocentric_distances) - self.add_org_img_meta_data("ebd_wcs", per_image_ebd_wcs) - - # If we have mosaicked images, each image in the stack could link back - # to more than one constituents image. Build a mapping of image stack index - # to needed original image indices. - if per_image_indices is None: - self._per_image_indices = [[i] for i in range(self.n_constituents)] - else: - self._per_image_indices = per_image_indices def __len__(self): """Returns the size of the WorkUnit in number of images.""" return self.im_stack.img_count() - def get_constituent_meta(self, column): - """Get the meta data values of a given column for all the constituent images. 
+ @staticmethod + def create_meta(constituent=False, n_images=None, data=None): + """Create an img_meta table, filling in default values + for any unspecified columns. Parameters ---------- - column : `str` - The column name to fetch. + constituent : `bool` + Indicates the type of table. True indicates a table of constituent (original) + images. False indicates a table of current images. + data : `dict`, `astropy.table.Table`, or None + The data from which to seed the table. + n_images : `int`, optional + The number of images to include. Only use when no data is + provided in order to fill in defaults. Returns ------- - data : `list` - A list of the meta-data for each constituent image. - """ - return list(self.org_img_meta[column].data) - - def add_org_img_meta_data(self, column, data, default=None): - """Add a column of meta data for the constituent images. Adds a column of all - default values if data is None and the column does not already exist. - - Parameters - ---------- - column : `str` - The name of the meta data column. - data : list-like - The data for each constituent. If None then uses None for - each column. - dtype : type - The numpy dtype to use when storing this data. If None - + img_meta : `astropy.table.Table` + The empty table of org_img_meta. """ if data is None: - if column not in self.org_img_meta.colnames: - self.org_img_meta[column] = np.array([default] * self.n_constituents) - elif len(data) == self.n_constituents: - self.org_img_meta[column] = np.array(data) + if n_images is None or n_images <= 0: + raise ValueError("If no data provided 'n_images' must be >= 1. Is {n_images}") + + # Add a place holder column of the correct size. + data = Table({"_index": np.arange(n_images)}) + elif isinstance(data, dict): + data = Table(data) + elif isinstance(data, Table): + data = data.copy() else: - raise ValueError( - f"Data mismatch size for WorkUnit metadata {column}. " - f"Expected {self.n_constituents} but found {len(data)}." - ) - - def has_common_wcs(self): - """Returns whether the WorkUnit has a common WCS for all images.""" - return self.wcs is not None + raise TypeError("Unsupported type for data table.") + n_images = len(data) + + if constituent: + # Fill in the defaults for the original/constituent images. + for colname in ["data_loc", "ebd_wcs", "geocentric_distance", "original_wcs"]: + if colname not in data.colnames: + data[colname] = np.full(n_images, None) + else: + # Fill in the defaults for the current image. + if "per_image_indices" not in data.colnames: + data["per_image_indices"] = [[i] for i in range(n_images)] + if "wcs" not in data.colnames: + data["wcs"] = np.full(n_images, None) + return data - def per_image_wcs_all_match(self, target=None): - """Check if all the per-image WCS are the same as a given target value. + def get_constituent_meta(self, column): + """Get the meta data values of a given column for all the constituent images. Parameters ---------- - target : `astropy.wcs.WCS`, optional - The WCS to which to compare the per-image WCS. If None, checks that - all of the per-image WCS are None. + column : `str` + The column name to fetch. Returns ------- - result : `bool` - A Boolean indicating that all the per-images WCS match the target. + data : `list` + A list of the meta-data for each constituent image. """ - for current in self._per_image_wcs: - if not wcs_fits_equal(current, target): - return False - return True + return list(self.org_img_meta[column].data) def get_wcs(self, img_num): """Return the WCS for the a given image. 
Alway prioritizes @@ -280,26 +265,11 @@ def get_wcs(self, img_num): ------- wcs : `astropy.wcs.WCS` The image's WCS if one exists. Otherwise None. - - Raises - ------ - IndexError if an invalid index is given. """ - if img_num < 0 or img_num >= len(self._per_image_wcs): - raise IndexError(f"Invalid image number {img_num}") - - # Extract the per-image WCS if one exists. - if self._per_image_wcs is not None and img_num < len(self._per_image_wcs): - per_img = self._per_image_wcs[img_num] - else: - per_img = None - - if self.wcs is not None: - if per_img is not None and not wcs_fits_equal(self.wcs, per_img): - warnings.warn("Both a global and per-image WCS given. Using global WCS.", Warning) + if self.wcs: return self.wcs - - return per_img + else: + return self.img_meta["wcs"][img_num] def get_pixel_coordinates(self, ra, dec, times=None): """Get the pixel coordinates for pairs of (RA, dec) coordinates. Uses the global @@ -346,7 +316,7 @@ def get_pixel_coordinates(self, ra, dec, times=None): for i, index in enumerate(inds): if index == -1: raise ValueError(f"Unmatched time {times[i]}.") - current_wcs = self._per_image_wcs[index] + current_wcs = self.img_meta["wcs"][index] curr_x, curr_y = current_wcs.world_to_pixel( SkyCoord(ra=ra[i] * u.degree, dec=dec[i] * u.degree) ) @@ -390,9 +360,6 @@ def get_unique_obstimes_and_indices(self): unique_indices = [list(np.where(all_obstimes == time)[0]) for time in unique_obstimes] return unique_obstimes, unique_indices - def get_num_images(self): - return len(self._per_image_indices) - @classmethod def from_fits(cls, filename, show_progress=None): """Create a WorkUnit from a single FITS file. @@ -433,11 +400,23 @@ def from_fits(cls, filename, show_progress=None): # Read in the search parameters from the 'kbmod_config' extension. config = SearchConfiguration.from_hdu(hdul["kbmod_config"]) - # Read in the per-image metadata. + # Read the size and order information from the primary header. + num_images = hdul[0].header["NUMIMG"] + n_constituents = hdul[0].header["NCON"] if "NCON" in hdul[0].header else num_images + logger.info(f"Loading {num_images} images.") + + # Read in the per-image metadata for the current images and the constituent images. if "IMG_META" in hdul: + logger.debug("Reading image metadata from IMG_META.") per_image_meta = hdu_to_metadata_table(hdul["IMG_META"]) else: - per_image_meta = None + per_image_meta = WorkUnit.create_meta(constituent=True, data=None, n_images=num_images) + + if "ORG_META" in hdul: + logger.debug("Reading original image metadata from ORG_META.") + org_image_meta = hdu_to_metadata_table(hdul["ORG_META"]) + else: + org_image_meta = WorkUnit.create_meta(constituent=True, data=None, n_images=n_constituents) # Read in the global WCS from extension 0 if the information exists. # We filter the warning that the image dimension does not match the WCS dimension @@ -446,22 +425,16 @@ def from_fits(cls, filename, show_progress=None): warnings.simplefilter("ignore", AstropyWarning) global_wcs = extract_wcs_from_hdu_header(hdul[0].header) - # Read the size and order information from the primary header. - num_images = hdul[0].header["NUMIMG"] - n_constituents = hdul[0].header["NCON"] - expected_num_images = (4 * num_images) + (2 * n_constituents) + 3 - if len(hdul) != expected_num_images: - raise ValueError(f"WorkUnit wrong number of extensions. Expected " f"{expected_num_images}.") - logger.info(f"Loading {num_images} images and {expected_num_images} total layers.") - # Misc. 
reprojection metadata reprojected = hdul[0].header["REPRJCTD"] heliocentric_distance = hdul[0].header["HELIO"] - geocentric_distances = [] - for i in range(num_images): - geocentric_distances.append(hdul[0].header[f"GEO_{i}"]) + if np.all(org_image_meta["geocentric_distance"] == None): + # If the metadata table does not have the geocentric_distance, try + # loading it from the primary header's GEO_i fields (legacy approach). + for i in range(n_constituents): + if f"GEO_{i}" in hdul[0].header: + org_image_meta["geocentric_distance"][i] = hdul[0].header[f"GEO_{i}"] - per_image_indices = [] # Read in all the image files. for i in tqdm( range(num_images), @@ -483,38 +456,38 @@ def from_fits(cls, filename, show_progress=None): # force_move destroys img object, but avoids a copy. im_stack.append_image(img, force_move=True) - n_indices = sci_hdu.header["NIND"] - sub_indices = [] - for j in range(n_indices): - sub_indices.append(sci_hdu.header[f"IND_{j}"]) - per_image_indices.append(sub_indices) - - per_image_wcs = [] - per_image_ebd_wcs = [] - constituent_images = [] + # Check if we need to load the map of current images to constituent images + # from the (legacy) headers. + if "NIND" in sci_hdu.header: + n_indices = sci_hdu.header["NIND"] + sub_indices = [] + for j in range(n_indices): + sub_indices.append(sci_hdu.header[f"IND_{j}"]) + per_image_meta["per_image_indices"][i] = sub_indices + + # Extract the per-image data from header information if needed. This happens + # when the WorkUnit was saved before metadata tables were saved as layers and + # all the information is in header values. for i in tqdm( range(n_constituents), bar_format=_DEFAULT_WORKUNIT_TQDM_BAR, desc="Loading WCS", disable=not show_progress, ): - # Extract the per-image WCS if one exists. - per_image_wcs.append(extract_wcs_from_hdu_header(hdul[f"WCS_{i}"].header)) - per_image_ebd_wcs.append(extract_wcs_from_hdu_header(hdul[f"EBD_{i}"].header)) - constituent_images.append(hdul[f"WCS_{i}"].header["ILOC"]) + if f"WCS_{i}" in hdul: + org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header(hdul[f"WCS_{i}"].header) + org_image_meta["data_loc"][i] = hdul[f"WCS_{i}"].header["ILOC"] + if f"EBD_{i}" in hdul: + org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header(hdul[f"EBD_{i}"].header) result = WorkUnit( im_stack=im_stack, config=config, wcs=global_wcs, - constituent_images=constituent_images, - per_image_wcs=per_image_wcs, - per_image_ebd_wcs=per_image_ebd_wcs, heliocentric_distance=heliocentric_distance, - geocentric_distances=geocentric_distances, reprojected=reprojected, - per_image_indices=per_image_indices, - per_image_meta=per_image_meta, + image_meta=per_image_meta, + org_image_meta=org_image_meta, ) return result @@ -540,20 +513,17 @@ def to_fits(self, filename, overwrite=False): if Path(filename).is_file() and not overwrite: raise FileExistsError(f"WorkUnit file {filename} already exists.") - hdul = self.metadata_to_primary_header(include_wcs=False) + # Create an HDU list with the metadata layers. + hdul = self.metadata_to_hdul(include_wcs=False) + # Create each image layer. 
for i in range(self.im_stack.img_count()): layered = self.im_stack.get_single_image(i) obstime = layered.get_obstime() - c_indices = self._per_image_indices[i] - n_indices = len(c_indices) img_wcs = self.get_wcs(i) sci_hdu = raw_image_to_hdu(layered.get_science(), obstime, img_wcs) sci_hdu.name = f"SCI_{i}" - sci_hdu.header["NIND"] = n_indices - for j in range(n_indices): - sci_hdu.header[f"IND_{j}"] = c_indices[j] hdul.append(sci_hdu) var_hdu = raw_image_to_hdu(layered.get_variance(), obstime) @@ -570,11 +540,6 @@ def to_fits(self, filename, overwrite=False): psf_hdu.name = f"PSF_{i}" hdul.append(psf_hdu) - # Format additional metadata. - meta_hdu = metadata_table_to_hdu(self.org_img_meta, "IMG_META") - hdul.append(meta_hdu) - self.append_all_wcs(hdul) - hdul.writeto(filename, overwrite=overwrite) def to_sharded_fits(self, filename, directory, overwrite=False): @@ -623,16 +588,11 @@ def to_sharded_fits(self, filename, directory, overwrite=False): for i in range(self.im_stack.img_count()): layered = self.im_stack.get_single_image(i) obstime = layered.get_obstime() - c_indices = self._per_image_indices[i] - n_indices = len(c_indices) sub_hdul = fits.HDUList() img_wcs = self.get_wcs(i) sci_hdu = raw_image_to_hdu(layered.get_science(), obstime, img_wcs) sci_hdu.name = f"SCI_{i}" - sci_hdu.header["NIND"] = n_indices - for j in range(n_indices): - sci_hdu.header[f"IND_{j}"] = c_indices[j] sub_hdul.append(sci_hdu) var_hdu = raw_image_to_hdu(layered.get_variance(), obstime) @@ -650,12 +610,8 @@ def to_sharded_fits(self, filename, directory, overwrite=False): sub_hdul.append(psf_hdu) sub_hdul.writeto(os.path.join(directory, f"{i}_{filename}")) - hdul = self.metadata_to_primary_header(include_wcs=True) - - # Format additional metadata as a single HDU - meta_hdu = metadata_table_to_hdu(self.org_img_meta, "IMG_META") - hdul.append(meta_hdu) - + # Create a primary file with all of the metadata + hdul = self.metadata_to_hdul(include_wcs=True) hdul.writeto(os.path.join(directory, filename), overwrite=overwrite) @classmethod @@ -700,11 +656,23 @@ def from_sharded_fits(cls, filename, directory, lazy=False): with fits.open(os.path.join(directory, filename)) as primary: config = SearchConfiguration.from_hdu(primary["kbmod_config"]) - # Read in the per-image metadata. + # Read the size and order information from the primary header. + num_images = primary[0].header["NUMIMG"] + n_constituents = primary[0].header["NCON"] if "NCON" in primary[0].header else num_images + logger.info(f"Loading {num_images} images.") + + # Read in the per-image metadata for the current images and the constituent images. if "IMG_META" in primary: + logger.debug("Reading image metadata from IMG_META.") per_image_meta = hdu_to_metadata_table(primary["IMG_META"]) else: - per_image_meta = None + per_image_meta = WorkUnit.create_meta(constituent=True, data=None, n_images=num_images) + + if "ORG_META" in primary: + logger.debug("Reading original image metadata from ORG_META.") + org_image_meta = hdu_to_metadata_table(primary["ORG_META"]) + else: + org_image_meta = WorkUnit.create_meta(constituent=True, data=None, n_images=n_constituents) # Read in the global WCS from extension 0 if the information exists. # We filter the warning that the image dimension does not match the WCS dimension @@ -713,27 +681,27 @@ def from_sharded_fits(cls, filename, directory, lazy=False): warnings.simplefilter("ignore", AstropyWarning) global_wcs = extract_wcs_from_hdu_header(primary[0].header) - # Read the size and order information from the primary header. 
- num_images = primary[0].header["NUMIMG"] - n_constituents = primary[0].header["NCON"] - expected_num_images = (4 * num_images) + (2 * n_constituents) + 3 - # Misc. reprojection metadata reprojected = primary[0].header["REPRJCTD"] heliocentric_distance = primary[0].header["HELIO"] - geocentric_distances = [] for i in range(n_constituents): - geocentric_distances.append(primary[0].header[f"GEO_{i}"]) + if f"GEO_{i}" in primary[0].header: + org_image_meta["geocentric_distance"][i] = primary[0].header[f"GEO_{i}"] - per_image_wcs = [] - per_image_ebd_wcs = [] - constituent_images = [] + # Extract the per-image data from header innformation if needed. + # This happens with when the WorkUnit was saved before metadata tables were + # saved as layers.that we will fill in from the headers. for i in range(n_constituents): - # Extract the per-image WCS if one exists. - per_image_wcs.append(extract_wcs_from_hdu_header(primary[f"WCS_{i}"].header)) - per_image_ebd_wcs.append(extract_wcs_from_hdu_header(primary[f"EBD_{i}"].header)) - constituent_images.append(primary[f"WCS_{i}"].header["ILOC"]) - per_image_indices = [] + if f"WCS_{i}" in primary: + org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header( + primary[f"WCS_{i}"].header + ) + org_image_meta["data_loc"][i] = primary[f"WCS_{i}"].header["ILOC"] + if f"EBD_{i}" in primary: + org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header( + primary[f"EBD_{i}"].header + ) + file_paths = [] obstimes = [] for i in range(num_images): @@ -753,32 +721,28 @@ def from_sharded_fits(cls, filename, directory, lazy=False): else: file_paths.append(shard_path) - n_indices = sci_hdu.header["NIND"] - sub_indices = [] - for j in range(n_indices): - sub_indices.append(sci_hdu.header[f"IND_{j}"]) - per_image_indices.append(sub_indices) + if "NIND" in sci_hdu.header: + n_indices = sci_hdu.header["NIND"] + sub_indices = [] + for j in range(n_indices): + sub_indices.append(sci_hdu.header[f"IND_{j}"]) + per_image_meta["per_image_indices"][i] = sub_indices file_paths = None if not lazy else file_paths result = WorkUnit( im_stack=im_stack, config=config, wcs=global_wcs, - constituent_images=constituent_images, - per_image_wcs=per_image_wcs, - per_image_ebd_wcs=per_image_ebd_wcs, - heliocentric_distance=heliocentric_distance, - geocentric_distances=geocentric_distances, reprojected=reprojected, - per_image_indices=per_image_indices, - lazy=lazy, + heliocentric_distance=heliocentric_distance, file_paths=file_paths, - obstimes=obstimes, - per_image_meta=per_image_meta, + lazy=lazy, + image_meta=per_image_meta, + org_image_meta=org_image_meta, ) return result - def metadata_to_primary_header(self, include_wcs=True): + def metadata_to_hdul(self, include_wcs=True): """Creates the metadata fits headers. Parameters @@ -797,57 +761,27 @@ def metadata_to_primary_header(self, include_wcs=True): # the metadata (empty), and the configuration. hdul = fits.HDUList() pri = fits.PrimaryHDU() - pri.header["NUMIMG"] = self.get_num_images() + pri.header["NUMIMG"] = self.n_images pri.header["NCON"] = self.n_constituents pri.header["REPRJCTD"] = self.reprojected pri.header["HELIO"] = self.heliocentric_distance - for i in range(self.n_constituents): - pri.header[f"GEO_{i}"] = self.org_img_meta["geocentric_distance"][i] # If the global WCS exists, append the corresponding keys. if self.wcs is not None: append_wcs_to_hdu_header(self.wcs, pri.header) - hdul.append(pri) - meta_hdu = fits.BinTableHDU() - meta_hdu.name = "metadata" - hdul.append(meta_hdu) - + # Add the configuration layer. 
config_hdu = self.config.to_hdu() config_hdu.name = "kbmod_config" hdul.append(config_hdu) - if include_wcs: - self.append_all_wcs(hdul) + # Save the additional metadata tables into HDUs + hdul.append(metadata_table_to_hdu(self.img_meta, "IMG_META")) + hdul.append(metadata_table_to_hdu(self.org_img_meta, "ORG_META")) return hdul - def append_all_wcs(self, hdul): - """Append all the original WCS and EBD WCS to a header. - - Parameters - ---------- - hdul : `astropy.io.fits.HDUList` - The HDU list. - """ - all_ebd_wcs = self.get_constituent_meta("ebd_wcs") - - for i in range(self.n_constituents): - img_location = self.org_img_meta["data_loc"][i] - - orig_wcs = self._per_image_wcs[i] - wcs_hdu = fits.TableHDU() - append_wcs_to_hdu_header(orig_wcs, wcs_hdu.header) - wcs_hdu.name = f"WCS_{i}" - wcs_hdu.header["ILOC"] = img_location - hdul.append(wcs_hdu) - - ebd_hdu = fits.TableHDU() - append_wcs_to_hdu_header(all_ebd_wcs[i], ebd_hdu.header) - ebd_hdu.name = f"EBD_{i}" - hdul.append(ebd_hdu) - def image_positions_to_original_icrs( self, image_indices, positions, input_format="xy", output_format="xy", filter_in_frame=True ): @@ -876,6 +810,7 @@ def image_positions_to_original_icrs( Whether or not to filter the output based on whether they fit within the original `constituent_image` frame. If `True`, only results that fall within the bounds of the original WCS will be returned. + Returns ------- positions : `list` of `astropy.coordinates.SkyCoord`s or `tuple`s @@ -930,12 +865,12 @@ def image_positions_to_original_icrs( positions = [] for i in image_indices: - inds = self._per_image_indices[i] + inds = self.img_meta["per_image_indices"][i] coord = inverted_coords[i] pos = [] for j in inds: con_image = self.org_img_meta["data_loc"][j] - con_wcs = self._per_image_wcs[j] + con_wcs = self.org_img_meta["original_wcs"][j] height, width = con_wcs.array_shape x, y = skycoord_to_pixel(coord, con_wcs) x, y = float(x), float(y) @@ -1053,7 +988,7 @@ def metadata_table_to_hdu(data, layer_name=None): num_rows = len(data) if num_rows == 0: # No data to encode. Just use the current table. - meta_hdu = fits.BinTableHDU(data) + save_table = data else: # Create a new table to save with the correct column # values/names for the serialized information. diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py index 283aa2ff0..a60abcaaa 100644 --- a/tests/test_work_unit.py +++ b/tests/test_work_unit.py @@ -4,6 +4,7 @@ from astropy.coordinates import EarthLocation, SkyCoord from astropy.time import Time import numpy as np +import numpy.testing as npt import os from pathlib import Path import tempfile @@ -101,18 +102,25 @@ def setUp(self): "five.fits", ] + self.org_image_meta = Table( + { + "data_loc": np.array(self.constituent_images), + "ebd_wcs": np.array([self.per_image_ebd_wcs] * self.num_images), + "geocentric_distance": np.array([self.geo_dist] * self.num_images), + "original_wcs": np.array(self.per_image_wcs), + } + ) + def test_create(self): # Test the creation of a WorkUnit with no WCS. Should throw a warning. with warnings.catch_warnings(record=True) as wrn: warnings.simplefilter("always") work = WorkUnit(self.im_stack, self.config) - self.assertTrue("No WCS provided." 
in str(wrn[-1].message)) self.assertIsNotNone(work) self.assertEqual(work.im_stack.img_count(), 5) self.assertEqual(work.config["im_filepath"], "Here") self.assertEqual(work.config["num_obs"], 5) - self.assertFalse(work.has_common_wcs()) self.assertIsNone(work.wcs) self.assertEqual(len(work), self.num_images) for i in range(self.num_images): @@ -121,7 +129,6 @@ def test_create(self): # Create with a global WCS work2 = WorkUnit(self.im_stack, self.config, self.wcs) self.assertEqual(work2.im_stack.img_count(), 5) - self.assertTrue(work2.has_common_wcs()) self.assertIsNotNone(work2.wcs) for i in range(self.num_images): self.assertIsNotNone(work2.get_wcs(i)) @@ -134,27 +141,9 @@ def test_create(self): self.im_stack, self.config, self.wcs, - [f"img_{i}" for i in range(self.im_stack.img_count())], [self.wcs, self.wcs, self.wcs], ) - # Create with per-image WCS that can be compressed to a global WCS. - per_image_wcs = [self.wcs] * self.num_images - work3 = WorkUnit(self.im_stack, self.config, per_image_wcs=per_image_wcs) - self.assertIsNotNone(work3.wcs) - self.assertTrue(work3.has_common_wcs()) - for i in range(self.num_images): - self.assertIsNotNone(work3.get_wcs(i)) - self.assertTrue(wcs_fits_equal(self.wcs, work3.get_wcs(i))) - - # Create with per-image WCS that cannot be compressed to a global WCS. - work3 = WorkUnit(self.im_stack, self.config, per_image_wcs=self.diff_wcs) - self.assertIsNone(work3.wcs) - self.assertFalse(work3.has_common_wcs()) - for i in range(self.num_images): - self.assertIsNotNone(work3.get_wcs(i)) - self.assertTrue(wcs_fits_equal(work3.get_wcs(i), self.diff_wcs[i])) - def test_metadata_helpers(self): """Test that we can roundtrip an astropy table of metadata (including) WCS into a BinTableHDU. @@ -191,13 +180,19 @@ def test_save_and_load_fits(self): self.assertRaises(ValueError, WorkUnit.from_fits, file_path) # Write out the existing WorkUnit with a different per-image wcs for all the entries. - # work = WorkUnit(self.im_stack, self.config, None, self.diff_wcs) + # work = WorkUnit(self.im_stack, self.config, None, self.diff_wcs). + # Include extra per-image metadata. + extra_meta = { + "data_loc": np.array(self.constituent_images), + "int_index": np.arange(self.num_images), + "uri": np.array([f"file_loc_{i}" for i in range(self.num_images)]), + } work = WorkUnit( im_stack=self.im_stack, config=self.config, wcs=None, per_image_wcs=self.diff_wcs, - constituent_images=self.constituent_images, + org_image_meta=extra_meta, ) work.to_fits(file_path) self.assertTrue(Path(file_path).is_file()) @@ -206,7 +201,6 @@ def test_save_and_load_fits(self): work2 = WorkUnit.from_fits(file_path) self.assertEqual(work2.im_stack.img_count(), self.num_images) self.assertIsNone(work2.wcs) - self.assertFalse(work2.has_common_wcs()) for i in range(self.num_images): li = work2.im_stack.get_single_image(i) self.assertEqual(li.get_width(), self.width) @@ -246,9 +240,10 @@ def test_save_and_load_fits(self): self.assertEqual(work2.config["im_filepath"], "Here") self.assertEqual(work2.config["num_obs"], self.num_images) - # Check that we correctly retrieved the provenance information via “data_loc” - for index, value in enumerate(work2.org_img_meta["data_loc"]): - self.assertEqual(value, self.constituent_images[index]) + # Check that we retrieved the extra metadata that we added. 
+ npt.assert_array_equal(work2.get_constituent_meta("uri"), extra_meta["uri"]) + npt.assert_array_equal(work2.get_constituent_meta("int_index"), extra_meta["int_index"]) + npt.assert_array_equal(work2.get_constituent_meta("data_loc"), self.constituent_images) # We throw an error if we try to overwrite a file with overwrite=False self.assertRaises(FileExistsError, work.to_fits, file_path) @@ -274,7 +269,6 @@ def test_save_and_load_fits_shard(self): work2 = WorkUnit.from_sharded_fits(filename="test_workunit.fits", directory=dir_name) self.assertEqual(work2.im_stack.img_count(), self.num_images) self.assertIsNone(work2.wcs) - self.assertFalse(work2.has_common_wcs()) for i in range(self.num_images): li = work2.im_stack.get_single_image(i) self.assertEqual(li.get_width(), self.width) @@ -330,7 +324,6 @@ def test_save_and_load_fits_shard_lazy(self): work2 = WorkUnit.from_sharded_fits(filename="test_workunit.fits", directory=dir_name, lazy=True) self.assertEqual(len(work2.file_paths), self.num_images) self.assertIsNone(work2.wcs) - self.assertFalse(work2.has_common_wcs()) # Check that we read in the configuration values correctly. self.assertEqual(work2.config["im_filepath"], "Here") @@ -353,7 +346,6 @@ def test_save_and_load_fits_global_wcs(self): # Read in the file and check that the values agree. work2 = WorkUnit.from_fits(file_path) self.assertIsNotNone(work2.wcs) - self.assertTrue(work2.has_common_wcs()) self.assertTrue(wcs_fits_equal(work2.wcs, self.wcs)) for i in range(self.num_images): self.assertIsNotNone(work2.get_wcs(i)) @@ -373,12 +365,9 @@ def test_image_positions_to_original_icrs_invalid_format(self): im_stack=self.im_stack, config=self.config, wcs=self.per_image_ebd_wcs, - per_image_wcs=self.per_image_wcs, - per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images, - geocentric_distances=[self.geo_dist] * self.num_images, heliocentric_distance=41.0, - constituent_images=self.constituent_images, reprojected=True, + org_image_meta=self.org_image_meta, ) # Incorrect format for 'xy' @@ -413,12 +402,9 @@ def test_image_positions_to_original_icrs_basic_inputs(self): im_stack=self.im_stack, config=self.config, wcs=self.per_image_ebd_wcs, - per_image_wcs=self.per_image_wcs, - per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images, - geocentric_distances=[self.geo_dist] * self.num_images, heliocentric_distance=41.0, - constituent_images=self.constituent_images, reprojected=True, + org_image_meta=self.org_image_meta, ) res = work.image_positions_to_original_icrs( @@ -462,12 +448,9 @@ def test_image_positions_to_original_icrs_filtering(self): im_stack=self.im_stack, config=self.config, wcs=self.per_image_ebd_wcs, - per_image_wcs=self.per_image_wcs, - per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images, - geocentric_distances=[self.geo_dist] * self.num_images, heliocentric_distance=41.0, - constituent_images=self.constituent_images, reprojected=True, + org_image_meta=self.org_image_meta, ) res = work.image_positions_to_original_icrs( @@ -490,17 +473,14 @@ def test_image_positions_to_original_icrs_mosaicking(self): im_stack=self.im_stack, config=self.config, wcs=self.per_image_ebd_wcs, - per_image_wcs=self.per_image_wcs, - per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images, - geocentric_distances=[self.geo_dist] * self.num_images, heliocentric_distance=41.0, - constituent_images=self.constituent_images, reprojected=True, + org_image_meta=self.org_image_meta, ) new_wcs = make_fake_wcs(190.0, -7.7888, 500, 700) - work._per_image_wcs[-1] = new_wcs - 
work._per_image_indices[3] = [3, 4] + work.org_img_meta["original_wcs"][-1] = new_wcs + work.img_meta["per_image_indices"] = [[0], [1], [2], [3, 4], [4]] res = work.image_positions_to_original_icrs( self.indices, @@ -526,9 +506,6 @@ def test_image_positions_to_original_icrs_mosaicking(self): npt.assert_almost_equal(res[3][0].separation(self.expected_radec_positions[3]).deg, 0.0, decimal=5) assert res[3][1] == "five.fits" - # work._per_image_wcs[4] = work._per_image_wcs[3] - # work._per_image_ebd_wcs[4] = work._per_image_ebd_wcs[3] - res = work.image_positions_to_original_icrs( self.indices, self.pixel_positions, @@ -549,11 +526,8 @@ def test_get_unique_obstimes_and_indices(self): im_stack=self.im_stack, config=self.config, wcs=self.per_image_ebd_wcs, - per_image_wcs=self.per_image_wcs, - per_image_ebd_wcs=[self.per_image_ebd_wcs] * self.num_images, - geocentric_distances=[self.geo_dist] * self.num_images, heliocentric_distance=41.0, - constituent_images=self.constituent_images, + org_image_meta=self.org_image_meta, ) times = work.get_all_obstimes() times[-1] = times[-2] From 41e0006d55d127c70f02c31e2eb3117b5adf8610 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 14 Nov 2024 22:52:51 -0500 Subject: [PATCH 04/15] Revert some of the changes --- src/kbmod/reprojection.py | 46 ++----- src/kbmod/work_unit.py | 253 +++++++++++++++++++------------------ tests/test_reprojection.py | 3 - tests/test_work_unit.py | 28 +++- 4 files changed, 168 insertions(+), 162 deletions(-) diff --git a/src/kbmod/reprojection.py b/src/kbmod/reprojection.py index 431080b36..1d60e7230 100644 --- a/src/kbmod/reprojection.py +++ b/src/kbmod/reprojection.py @@ -19,31 +19,6 @@ _DEFAULT_TQDM_BAR = "{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}]" -def create_new_image_metadata(unique_obstime_indices, common_wcs): - """Create a table of the metadata for the new reprojected images. - - Parameters - ---------- - unique_obstime_indices : `numpy.ndarray` - An array of lists (or arrays) indicating from which original images - the new images were created. - common_wcs : `astropy.wcs.WCS` - The new WCS for the images. - - Returns - ------- - metadata : `astropy.table.Table` - A table of metadata for the new images. - """ - metadata = Table( - { - "per_image_indices": np.array(unique_obstime_indices), - "wcs": np.full(len(unique_obstime_indices), common_wcs), - } - ) - return metadata - - def reproject_image(image, original_wcs, common_wcs): """Given an ndarray representing image data (either science or variance, when used with `reproject_work_unit`), as well as a common wcs, return the reprojected @@ -310,12 +285,9 @@ def _reproject_work_unit( ) stack.append_image(new_layered_image, force_move=True) - # Determine the metadata for the new reprojected images. - new_image_meta = create_new_image_metadata(unique_obstime_indices, common_wcs) - if write_output: new_work_unit = copy(work_unit) - new_work_unit.img_meta = new_image_meta + new_work_unit._per_image_indices = unique_obstime_indices new_work_unit.reprojected = True new_work_unit.wcs = common_wcs @@ -326,8 +298,9 @@ def _reproject_work_unit( im_stack=stack, config=work_unit.config, wcs=common_wcs, + per_image_wcs=work_unit._per_image_wcs, + per_image_indices=unique_obstime_indices, reprojected=True, - image_meta=new_image_meta, org_image_meta=work_unit.org_img_meta, ) @@ -443,16 +416,13 @@ def _reproject_work_unit_in_parallel( # when all the multiprocessing has finished, convert the returned numpy arrays to RawImages. 
concurrent.futures.wait(future_reprojections, return_when=concurrent.futures.ALL_COMPLETED) - # Determine the metadata for the new reprojected images. - new_image_meta = create_new_image_metadata(unique_obstime_indices, common_wcs) - if write_output: for result in future_reprojections: if not result.result(): raise RuntimeError("one or more jobs failed.") new_work_unit = copy(work_unit) - new_work_unit.img_meta = new_image_meta + new_work_unit._per_image_indices = unique_obstime_indices new_work_unit.reprojected = True new_work_unit.wcs = common_wcs @@ -482,8 +452,9 @@ def _reproject_work_unit_in_parallel( im_stack=stack, config=work_unit.config, wcs=common_wcs, + per_image_wcs=work_unit._per_image_wcs, + per_image_indices=unique_obstime_indices, reprojected=True, - image_meta=new_image_meta, org_image_meta=work_unit.org_img_meta, ) @@ -577,8 +548,9 @@ def reproject_lazy_work_unit( if not result.result(): raise RuntimeError("one or more jobs failed.") + # We use new metadata for the new images and the same metadata for the original images. new_work_unit = copy(work_unit) - new_work_unit.img_meta = create_new_image_metadata(unique_obstime_indices, common_wcs) + new_work_unit._per_image_indices = unique_obstime_indices new_work_unit.reprojected = True new_work_unit.wcs = common_wcs @@ -619,6 +591,8 @@ def _validate_original_wcs(work_unit, indices, frame="original"): else: raise ValueError("Invalid projection frame provided.") + if len(original_wcs) == 0: + raise ValueError(f"No WCS found for frame {frame}") if np.any(original_wcs) is None: # find indices where the wcs is None bad_indices = np.where(original_wcs == None) diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py index 6364c5d10..fcf7d891b 100644 --- a/src/kbmod/work_unit.py +++ b/src/kbmod/work_unit.py @@ -52,13 +52,6 @@ class WorkUnit: The number of original images making up the data in this WorkUnit. This might be different from the number of images stored in memory if the WorkUnit has been reprojected. - img_meta : `astropy.table.Table` - The meta data for each of the current images. These might differ from the - constituent images if the WorkUnit has been filtered or reprojected. - * wcs - The WCS of the image. This is set even if all image share a global WCS. - * per_image_indices - A lists containing the indicies of `constituent_images` - for each current image. Used for finding corresponding original images when we - stitch images together during reprojection. org_img_meta : `astropy.table.Table` The meta data for each constituent image. Includes columns: * data_loc - the original location of the image @@ -99,6 +92,10 @@ class WorkUnit: the images have *not* been standardized to the same pixel space. reprojected : `bool` Whether or not the WorkUnit image data has been reprojected. + per_image_indices : `list` of `list` + A list of lists containing the indicies of `constituent_images` at each layer + of the `ImageStack`. Used for finding corresponding original images when we + stitch images together during reprojection. heliocentric_distance : `float` The heliocentric distance that was used when creating the `per_image_ebd_wcs`. lazy : `bool` @@ -108,8 +105,6 @@ class WorkUnit: in lazy mode. obstimes : `list[float]` The MJD obstimes of the images. - img_meta : `dict` or `astropy.table.Table`, optional - The meta data for each of the current images. org_image_meta : `dict` or `astropy.table.Table`, optional A table of per-image data for the constituent images. 
""" @@ -121,11 +116,11 @@ def __init__( wcs=None, per_image_wcs=None, reprojected=False, + per_image_indices=None, heliocentric_distance=None, lazy=False, file_paths=None, obstimes=None, - image_meta=None, org_image_meta=None, ): self.im_stack = im_stack @@ -134,71 +129,63 @@ def __init__( self.file_paths = file_paths self._obstimes = obstimes - # Track the metadata for each of the current images. - self.img_meta = WorkUnit.create_meta( - constituent=False, - data=image_meta, - n_images=im_stack.img_count(), - ) - - # Base the number of current images on the metadata because in a lazy load, - # the ImageStack might be empty. - self.n_images = len(self.img_meta) + # Try to infer the number of images. This is a bit complex because of lazy loading. + # We test multiple sources in a predefined order. + for test_item in [im_stack, per_image_wcs, per_image_indices, obstimes]: + if test_item is not None and len(test_item) > 0: + self.n_images = len(test_item) + break # Track the metadata for each constituent image in the WorkUnit. If no constituent # data is provided, this will create an empty array the same size as the original. no_org_img_meta_given = org_image_meta is None - self.org_img_meta = WorkUnit.create_meta( - constituent=True, + self.org_img_meta = WorkUnit.create_image_meta( data=org_image_meta, n_images=self.n_images, ) self.n_constituents = len(self.org_img_meta) - # Handle WCS input. + # Handle WCS input. If both the global and per-image WCS are provided, + # ensure they are consistent. self.wcs = wcs - if per_image_wcs is not None: - # If we are given explicit per-image WCS, use those. Overwrite the values - # the current image metadata. + if per_image_wcs is None: + self._per_image_wcs = [self.wcs for _ in range(self.n_images)] + else: if len(per_image_wcs) != self.n_images: raise ValueError(f"Incorrect number of WCS provided. Expected {self.n_images}") - self.img_meta["wcs"] = np.array(per_image_wcs) - elif "wcs" not in self.img_meta.colnames or np.all(self.img_meta["wcs"] == None): - # If we have no per-image WCS already, use the global one (which might be None). - self.img_meta["wcs"] = np.array([self.wcs] * self.n_images) + self._per_image_wcs = per_image_wcs - # If no constituent data was provided, then save the current image's WCS as the original. + if np.any(self._per_image_wcs == None): + warnings.warn("At least one image without a WCS.", Warning) + + # If no constituent data was provided, we save the current image's WCS as the original. # This is needed to ensure that we always have a correct original WCS. if no_org_img_meta_given: - for i in range(self.n_images): - self.org_img_meta["original_wcs"][i] = self.img_meta["wcs"][i] - - # If both the global and per-image WCS are provided, ensure they are consistent. - if self.wcs is not None and not np.all(self.img_meta["wcs"].value == None): - for idx in range(im_stack.img_count()): - if not wcs_fits_equal(self.wcs, self.img_meta["wcs"][idx]): - raise ValueError(f"Inconsistent WCS at index {idx}.") - if self.wcs is None and np.any(self.img_meta["wcs"].value == None): - logger.warning("No WCS provided for at least one image.") + self.org_img_meta["original_wcs"] = self._per_image_wcs # Set the global metadata for reprojection. self.reprojected = reprojected self.heliocentric_distance = heliocentric_distance + # If we have mosaicked images, each image in the stack could link back + # to more than one constituents image. Build a mapping of image stack index + # to needed original image indices. 
+ if per_image_indices is None: + self._per_image_indices = [[i] for i in range(self.n_constituents)] + else: + self._per_image_indices = per_image_indices + def __len__(self): """Returns the size of the WorkUnit in number of images.""" return self.im_stack.img_count() @staticmethod - def create_meta(constituent=False, n_images=None, data=None): + def create_image_meta(data=None, n_images=None): """Create an img_meta table, filling in default values for any unspecified columns. Parameters ---------- - constituent : `bool` - Indicates the type of table. True indicates a table of constituent (original) - images. False indicates a table of current images. data : `dict`, `astropy.table.Table`, or None The data from which to seed the table. n_images : `int`, optional @@ -224,17 +211,10 @@ def create_meta(constituent=False, n_images=None, data=None): raise TypeError("Unsupported type for data table.") n_images = len(data) - if constituent: - # Fill in the defaults for the original/constituent images. - for colname in ["data_loc", "ebd_wcs", "geocentric_distance", "original_wcs"]: - if colname not in data.colnames: - data[colname] = np.full(n_images, None) - else: - # Fill in the defaults for the current image. - if "per_image_indices" not in data.colnames: - data["per_image_indices"] = [[i] for i in range(n_images)] - if "wcs" not in data.colnames: - data["wcs"] = np.full(n_images, None) + # Fill in the defaults for the original/constituent images. + for colname in ["data_loc", "ebd_wcs", "geocentric_distance", "original_wcs"]: + if colname not in data.colnames: + data[colname] = np.full(n_images, None) return data def get_constituent_meta(self, column): @@ -252,6 +232,29 @@ def get_constituent_meta(self, column): """ return list(self.org_img_meta[column].data) + def per_image_wcs_all_match(self, target=None): + """Check if all the per-image WCS are the same as a given target value. + + Parameters + ---------- + target : `astropy.wcs.WCS`, optional + The WCS to which to compare the per-image WCS. If None, checks that + all of the per-image WCS are None. + + Returns + ------- + result : `bool` + A Boolean indicating that all the per-images WCS match the target. + """ + for current in self._per_image_wcs: + if not wcs_fits_equal(current, target): + return False + return True + + def has_common_wcs(self): + """Returns whether the WorkUnit has a common WCS for all images.""" + return self.wcs is not None + def get_wcs(self, img_num): """Return the WCS for the a given image. Alway prioritizes a global WCS if one exits. @@ -266,10 +269,10 @@ def get_wcs(self, img_num): wcs : `astropy.wcs.WCS` The image's WCS if one exists. Otherwise None. """ - if self.wcs: + if self.wcs is not None: return self.wcs else: - return self.img_meta["wcs"][img_num] + return self._per_image_wcs[img_num] def get_pixel_coordinates(self, ra, dec, times=None): """Get the pixel coordinates for pairs of (RA, dec) coordinates. 
Uses the global @@ -316,7 +319,7 @@ def get_pixel_coordinates(self, ra, dec, times=None): for i, index in enumerate(inds): if index == -1: raise ValueError(f"Unmatched time {times[i]}.") - current_wcs = self.img_meta["wcs"][index] + current_wcs = self._per_image_wcs[index] curr_x, curr_y = current_wcs.world_to_pixel( SkyCoord(ra=ra[i] * u.degree, dec=dec[i] * u.degree) ) @@ -405,18 +408,12 @@ def from_fits(cls, filename, show_progress=None): n_constituents = hdul[0].header["NCON"] if "NCON" in hdul[0].header else num_images logger.info(f"Loading {num_images} images.") - # Read in the per-image metadata for the current images and the constituent images. + # Read in the per-image metadata for the constituent images. if "IMG_META" in hdul: - logger.debug("Reading image metadata from IMG_META.") - per_image_meta = hdu_to_metadata_table(hdul["IMG_META"]) + logger.debug("Reading original image metadata from IMG_META.") + org_image_meta = hdu_to_metadata_table(hdul["IMG_META"]) else: - per_image_meta = WorkUnit.create_meta(constituent=True, data=None, n_images=num_images) - - if "ORG_META" in hdul: - logger.debug("Reading original image metadata from ORG_META.") - org_image_meta = hdu_to_metadata_table(hdul["ORG_META"]) - else: - org_image_meta = WorkUnit.create_meta(constituent=True, data=None, n_images=n_constituents) + org_image_meta = WorkUnit.create_image_meta(data=None, n_images=n_constituents) # Read in the global WCS from extension 0 if the information exists. # We filter the warning that the image dimension does not match the WCS dimension @@ -436,6 +433,8 @@ def from_fits(cls, filename, show_progress=None): org_image_meta["geocentric_distance"][i] = hdul[0].header[f"GEO_{i}"] # Read in all the image files. + per_image_indices = [] + per_image_wcs = [] for i in tqdm( range(num_images), bar_format=_DEFAULT_WORKUNIT_TQDM_BAR, @@ -456,14 +455,13 @@ def from_fits(cls, filename, show_progress=None): # force_move destroys img object, but avoids a copy. im_stack.append_image(img, force_move=True) - # Check if we need to load the map of current images to constituent images - # from the (legacy) headers. - if "NIND" in sci_hdu.header: - n_indices = sci_hdu.header["NIND"] - sub_indices = [] - for j in range(n_indices): - sub_indices.append(sci_hdu.header[f"IND_{j}"]) - per_image_meta["per_image_indices"][i] = sub_indices + # Read the mapping of current image to constituent image from the header info. + # TODO: Serialize this. + n_indices = sci_hdu.header["NIND"] + sub_indices = [] + for j in range(n_indices): + sub_indices.append(sci_hdu.header[f"IND_{j}"]) + per_image_indices.append(sub_indices) # Extract the per-image data from header information if needed. 
This happens # when the WorkUnit was saved before metadata tables were saved as layers and @@ -474,19 +472,23 @@ def from_fits(cls, filename, show_progress=None): desc="Loading WCS", disable=not show_progress, ): + per_image_wcs.append(extract_wcs_from_hdu_header(hdul[f"WCS_{i}"].header)) if f"WCS_{i}" in hdul: - org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header(hdul[f"WCS_{i}"].header) - org_image_meta["data_loc"][i] = hdul[f"WCS_{i}"].header["ILOC"] + wcs_header = hdul[f"WCS_{i}"].header + org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header(wcs_header) + if "ILOC" in wcs_header: + org_image_meta["data_loc"][i] = wcs_header["ILOC"] if f"EBD_{i}" in hdul: - org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header(hdul[f"EBD_{i}"].header) + org_image_meta["ebd_wcs"][i] = extract_wcs_from_hdu_header(hdul[f"EBD_{i}"].header) result = WorkUnit( im_stack=im_stack, config=config, wcs=global_wcs, + per_image_wcs=per_image_wcs, heliocentric_distance=heliocentric_distance, reprojected=reprojected, - image_meta=per_image_meta, + per_image_indices=per_image_indices, org_image_meta=org_image_meta, ) return result @@ -513,17 +515,22 @@ def to_fits(self, filename, overwrite=False): if Path(filename).is_file() and not overwrite: raise FileExistsError(f"WorkUnit file {filename} already exists.") - # Create an HDU list with the metadata layers. - hdul = self.metadata_to_hdul(include_wcs=False) + # Create an HDU list with the metadata layers, including all the WCS info. + hdul = self.metadata_to_hdul() # Create each image layer. for i in range(self.im_stack.img_count()): layered = self.im_stack.get_single_image(i) obstime = layered.get_obstime() + c_indices = self._per_image_indices[i] + n_indices = len(c_indices) img_wcs = self.get_wcs(i) sci_hdu = raw_image_to_hdu(layered.get_science(), obstime, img_wcs) sci_hdu.name = f"SCI_{i}" + sci_hdu.header["NIND"] = n_indices + for j in range(n_indices): + sci_hdu.header[f"IND_{j}"] = c_indices[j] hdul.append(sci_hdu) var_hdu = raw_image_to_hdu(layered.get_variance(), obstime) @@ -553,7 +560,6 @@ def to_sharded_fits(self, filename, directory, overwrite=False): image index infront of the given filename, e.g. "0_filename.fits". - Primary File: 0 - Primary header with overall metadata 1 or "metadata" - The data provenance metadata @@ -588,11 +594,16 @@ def to_sharded_fits(self, filename, directory, overwrite=False): for i in range(self.im_stack.img_count()): layered = self.im_stack.get_single_image(i) obstime = layered.get_obstime() + c_indices = self._per_image_indices[i] + n_indices = len(c_indices) sub_hdul = fits.HDUList() img_wcs = self.get_wcs(i) sci_hdu = raw_image_to_hdu(layered.get_science(), obstime, img_wcs) sci_hdu.name = f"SCI_{i}" + sci_hdu.header["NIND"] = n_indices + for j in range(n_indices): + sci_hdu.header[f"IND_{j}"] = c_indices[j] sub_hdul.append(sci_hdu) var_hdu = raw_image_to_hdu(layered.get_variance(), obstime) @@ -610,8 +621,8 @@ def to_sharded_fits(self, filename, directory, overwrite=False): sub_hdul.append(psf_hdu) sub_hdul.writeto(os.path.join(directory, f"{i}_{filename}")) - # Create a primary file with all of the metadata - hdul = self.metadata_to_hdul(include_wcs=True) + # Create a primary file with all of the metadata, including all the WCS info. 
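# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch (not part of the patch) of the on-disk
# layout produced by to_sharded_fits("wu.fits", out_dir) for a hypothetical
# three-image WorkUnit. The primary file carries only metadata; each shard
# carries one image's layers (SCI_i, VAR_i, PSF_i, ...).
expected_shard_layout = [
    "wu.fits",    # primary: global header values, kbmod_config, metadata table
    "0_wu.fits",  # layers for image 0
    "1_wu.fits",  # layers for image 1
    "2_wu.fits",  # layers for image 2
]
# ---------------------------------------------------------------------------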
+ hdul = self.metadata_to_hdul() hdul.writeto(os.path.join(directory, filename), overwrite=overwrite) @classmethod @@ -661,18 +672,12 @@ def from_sharded_fits(cls, filename, directory, lazy=False): n_constituents = primary[0].header["NCON"] if "NCON" in primary[0].header else num_images logger.info(f"Loading {num_images} images.") - # Read in the per-image metadata for the current images and the constituent images. + # Read in the per-image metadata for the constituent images. if "IMG_META" in primary: - logger.debug("Reading image metadata from IMG_META.") - per_image_meta = hdu_to_metadata_table(primary["IMG_META"]) + logger.debug("Reading original image metadata from IMG_META.") + org_image_meta = hdu_to_metadata_table(primary["IMG_META"]) else: - per_image_meta = WorkUnit.create_meta(constituent=True, data=None, n_images=num_images) - - if "ORG_META" in primary: - logger.debug("Reading original image metadata from ORG_META.") - org_image_meta = hdu_to_metadata_table(primary["ORG_META"]) - else: - org_image_meta = WorkUnit.create_meta(constituent=True, data=None, n_images=n_constituents) + org_image_meta = WorkUnit.create_image_meta(data=None, n_images=n_constituents) # Read in the global WCS from extension 0 if the information exists. # We filter the warning that the image dimension does not match the WCS dimension @@ -688,20 +693,23 @@ def from_sharded_fits(cls, filename, directory, lazy=False): if f"GEO_{i}" in primary[0].header: org_image_meta["geocentric_distance"][i] = primary[0].header[f"GEO_{i}"] - # Extract the per-image data from header innformation if needed. + # Extract the per-image data from header information if needed. # This happens with when the WorkUnit was saved before metadata tables were - # saved as layers.that we will fill in from the headers. + # saved as layers. + per_image_wcs = [] for i in range(n_constituents): + per_image_wcs.append(extract_wcs_from_hdu_header(primary[f"WCS_{i}"].header)) if f"WCS_{i}" in primary: - org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header( - primary[f"WCS_{i}"].header - ) - org_image_meta["data_loc"][i] = primary[f"WCS_{i}"].header["ILOC"] + wcs_header = primary[f"WCS_{i}"].header + org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header(wcs_header) + if "ILOC" in wcs_header: + org_image_meta["data_loc"][i] = wcs_header["ILOC"] if f"EBD_{i}" in primary: org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header( primary[f"EBD_{i}"].header ) + per_image_indices = [] file_paths = [] obstimes = [] for i in range(num_images): @@ -721,37 +729,32 @@ def from_sharded_fits(cls, filename, directory, lazy=False): else: file_paths.append(shard_path) - if "NIND" in sci_hdu.header: - n_indices = sci_hdu.header["NIND"] - sub_indices = [] - for j in range(n_indices): - sub_indices.append(sci_hdu.header[f"IND_{j}"]) - per_image_meta["per_image_indices"][i] = sub_indices + # Load the mapping of current image to constituent image. 
+ n_indices = sci_hdu.header["NIND"] + sub_indices = [] + for j in range(n_indices): + sub_indices.append(sci_hdu.header[f"IND_{j}"]) + per_image_indices.append(sub_indices) file_paths = None if not lazy else file_paths result = WorkUnit( im_stack=im_stack, config=config, wcs=global_wcs, + per_image_wcs=per_image_wcs, reprojected=reprojected, + lazy=lazy, heliocentric_distance=heliocentric_distance, + per_image_indices=per_image_indices, file_paths=file_paths, - lazy=lazy, - image_meta=per_image_meta, + obstimes=obstimes, org_image_meta=org_image_meta, ) return result - def metadata_to_hdul(self, include_wcs=True): + def metadata_to_hdul(self): """Creates the metadata fits headers. - Parameters - ---------- - include_wcs : `bool` - whether or not to append all the per image wcses - to the header (optional for the serial `to_fits` - case so that we can maintain the same indexing - as before). Returns ------- hdul : `astropy.io.fits.HDUList` @@ -776,9 +779,15 @@ def metadata_to_hdul(self, include_wcs=True): config_hdu.name = "kbmod_config" hdul.append(config_hdu) - # Save the additional metadata tables into HDUs - hdul.append(metadata_table_to_hdu(self.img_meta, "IMG_META")) - hdul.append(metadata_table_to_hdu(self.org_img_meta, "ORG_META")) + # Save the additional metadata table into HDUs + hdul.append(metadata_table_to_hdu(self.org_img_meta, "IMG_META")) + + # Save the WCS layers. + for idx, wcs in enumerate(self._per_image_wcs): + wcs_hdu = fits.TableHDU() + append_wcs_to_hdu_header(wcs, wcs_hdu.header) + wcs_hdu.name = f"WCS_{idx}" + hdul.append(wcs_hdu) return hdul @@ -865,7 +874,7 @@ def image_positions_to_original_icrs( positions = [] for i in image_indices: - inds = self.img_meta["per_image_indices"][i] + inds = self._per_image_indices[i] coord = inverted_coords[i] pos = [] for j in inds: diff --git a/tests/test_reprojection.py b/tests/test_reprojection.py index 1135c7b50..63c13defd 100644 --- a/tests/test_reprojection.py +++ b/tests/test_reprojection.py @@ -19,9 +19,6 @@ def setUp(self): self.test_wunit = WorkUnit.from_fits(self.data_path) self.common_wcs = self.test_wunit.get_wcs(0) - # self.tmp_dir = os.path(tempfile.TemporaryDirectory()) - # self.test_wunit.to_sharded_fits("test_wunit.fits", self.tmp_dir) - def test_reproject(self): # test exception conditions self.assertRaises( diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py index a60abcaaa..dcb71dc90 100644 --- a/tests/test_work_unit.py +++ b/tests/test_work_unit.py @@ -171,6 +171,32 @@ def test_metadata_helpers(self): for i in range(len(md_table2)): self.assertTrue(isinstance(md_table2["wcs"][i], WCS)) + def test_create_image_meta(self): + # Empty constituent image data. + org_img_meta = WorkUnit.create_image_meta(n_images=3, data=None) + self.assertEqual(len(org_img_meta), 3) + self.assertTrue("data_loc" in org_img_meta.colnames) + self.assertTrue("ebd_wcs" in org_img_meta.colnames) + self.assertTrue("geocentric_distance" in org_img_meta.colnames) + self.assertTrue("original_wcs" in org_img_meta.colnames) + + # We can create from a dictionary. 
In this case we ignore n_images + data_dict = { + "uri": ["file1", "file2", "file3"], + "geocentric_distance": [1.0, 2.0, 3.0], + } + org_img_meta2 = WorkUnit.create_image_meta(n_images=5, data=data_dict) + self.assertEqual(len(org_img_meta2), 3) + self.assertTrue("data_loc" in org_img_meta2.colnames) + self.assertTrue("ebd_wcs" in org_img_meta2.colnames) + self.assertTrue("geocentric_distance" in org_img_meta2.colnames) + self.assertTrue("original_wcs" in org_img_meta2.colnames) + self.assertTrue("uri" in org_img_meta2.colnames) + + # We need either data or a positive number of images. + self.assertRaises(ValueError, WorkUnit.create_image_meta, None, None) + self.assertRaises(ValueError, WorkUnit.create_image_meta, None, -1) + def test_save_and_load_fits(self): with tempfile.TemporaryDirectory() as dir_name: file_path = os.path.join(dir_name, "test_workunit.fits") @@ -480,7 +506,7 @@ def test_image_positions_to_original_icrs_mosaicking(self): new_wcs = make_fake_wcs(190.0, -7.7888, 500, 700) work.org_img_meta["original_wcs"][-1] = new_wcs - work.img_meta["per_image_indices"] = [[0], [1], [2], [3, 4], [4]] + work._per_image_indices[3] = [3, 4] res = work.image_positions_to_original_icrs( self.indices, From 1c01a2023ced1a42881700e61a11722cc3ca54ae Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 14 Nov 2024 22:57:16 -0500 Subject: [PATCH 05/15] Update reprojection.py --- src/kbmod/reprojection.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/kbmod/reprojection.py b/src/kbmod/reprojection.py index 1d60e7230..3cfdcc8e2 100644 --- a/src/kbmod/reprojection.py +++ b/src/kbmod/reprojection.py @@ -423,8 +423,8 @@ def _reproject_work_unit_in_parallel( new_work_unit = copy(work_unit) new_work_unit._per_image_indices = unique_obstime_indices - new_work_unit.reprojected = True new_work_unit.wcs = common_wcs + new_work_unit.reprojected = True hdul = new_work_unit.metadata_to_primary_hdul() hdul.writeto(os.path.join(directory, filename)) @@ -453,7 +453,7 @@ def _reproject_work_unit_in_parallel( config=work_unit.config, wcs=common_wcs, per_image_wcs=work_unit._per_image_wcs, - per_image_indices=unique_obstime_indices, + per_image_indices=unique_obstimes_indices, reprojected=True, org_image_meta=work_unit.org_img_meta, ) @@ -550,9 +550,9 @@ def reproject_lazy_work_unit( # We use new metadata for the new images and the same metadata for the original images. 
new_work_unit = copy(work_unit) - new_work_unit._per_image_indices = unique_obstime_indices - new_work_unit.reprojected = True + new_work_unit._per_image_indices = unique_obstimes_indices new_work_unit.wcs = common_wcs + new_work_unit.reprojected = True hdul = new_work_unit.metadata_to_primary_header() hdul.writeto(os.path.join(directory, filename)) From 5d446fed0e289097ec489694bd6c494d9e5fa40e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Thu, 14 Nov 2024 22:58:56 -0500 Subject: [PATCH 06/15] Update reprojection.py --- src/kbmod/reprojection.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/kbmod/reprojection.py b/src/kbmod/reprojection.py index 3cfdcc8e2..2dfd224bd 100644 --- a/src/kbmod/reprojection.py +++ b/src/kbmod/reprojection.py @@ -288,8 +288,8 @@ def _reproject_work_unit( if write_output: new_work_unit = copy(work_unit) new_work_unit._per_image_indices = unique_obstime_indices - new_work_unit.reprojected = True new_work_unit.wcs = common_wcs + new_work_unit.reprojected = True hdul = new_work_unit.metadata_to_primary_hdul() hdul.writeto(os.path.join(directory, filename)) @@ -422,7 +422,7 @@ def _reproject_work_unit_in_parallel( raise RuntimeError("one or more jobs failed.") new_work_unit = copy(work_unit) - new_work_unit._per_image_indices = unique_obstime_indices + new_work_unit._per_image_indices = unique_obstimes_indices new_work_unit.wcs = common_wcs new_work_unit.reprojected = True From 1000ba507ead9f76d5ce2923f376f505988bf4fe Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 15 Nov 2024 09:43:53 -0500 Subject: [PATCH 07/15] bugfixes --- src/kbmod/image_collection.py | 7 +- src/kbmod/reprojection.py | 15 ++- src/kbmod/work_unit.py | 240 ++++++++++++++-------------------- tests/test_work_unit.py | 24 ++-- 4 files changed, 125 insertions(+), 161 deletions(-) diff --git a/src/kbmod/image_collection.py b/src/kbmod/image_collection.py index dce44354d..4e75155f5 100644 --- a/src/kbmod/image_collection.py +++ b/src/kbmod/image_collection.py @@ -888,7 +888,10 @@ def toWorkUnit(self, search_config=None, **kwargs): for std in self.get_standardizers(**kwargs): for img in std["std"].toLayeredImage(): layeredImages.append(img) + imgstack = ImageStack(layeredImages) + img_metadata = Table() + if None not in self.wcs: - return WorkUnit(imgstack, search_config, per_image_wcs=list(self.wcs)) - return WorkUnit(imgstack, search_config) + img_metadata["per_image_wcs"] = list(self.wcs) + return WorkUnit(imgstack, search_config, org_image_meta=img_metadata) diff --git a/src/kbmod/reprojection.py b/src/kbmod/reprojection.py index 2dfd224bd..49e4a0102 100644 --- a/src/kbmod/reprojection.py +++ b/src/kbmod/reprojection.py @@ -198,7 +198,7 @@ def _reproject_work_unit( # Create a list of the correct WCS. We do this extraction once and reuse for all images. if frame == "original": - wcs_list = work_unit.get_constituent_meta("original_wcs") + wcs_list = work_unit.get_constituent_meta("per_image_wcs") elif frame == "ebd": wcs_list = work_unit.get_constituent_meta("ebd_wcs") else: @@ -286,19 +286,22 @@ def _reproject_work_unit( stack.append_image(new_layered_image, force_move=True) if write_output: + # Create a copy of the WorkUnit to write the global metadata. + # We preserve the metgadata for the consituent images. 
new_work_unit = copy(work_unit) new_work_unit._per_image_indices = unique_obstime_indices new_work_unit.wcs = common_wcs new_work_unit.reprojected = True - hdul = new_work_unit.metadata_to_primary_hdul() + hdul = new_work_unit.metadata_to_hdul() hdul.writeto(os.path.join(directory, filename)) else: + # Create a new WorkUnit with the new ImageStack and global WCS. + # We preserve the metgadata for the consituent images. new_wunit = WorkUnit( im_stack=stack, config=work_unit.config, wcs=common_wcs, - per_image_wcs=work_unit._per_image_wcs, per_image_indices=unique_obstime_indices, reprojected=True, org_image_meta=work_unit.org_img_meta, @@ -426,7 +429,7 @@ def _reproject_work_unit_in_parallel( new_work_unit.wcs = common_wcs new_work_unit.reprojected = True - hdul = new_work_unit.metadata_to_primary_hdul() + hdul = new_work_unit.metadata_to_hdul() hdul.writeto(os.path.join(directory, filename)) else: stack = ImageStack([]) @@ -447,12 +450,12 @@ def _reproject_work_unit_in_parallel( # sort by the time_stamp stack.sort_by_time() - # Add the imageStack to a new WorkUnit and return it. + # Add the imageStack to a new WorkUnit and return it. We preserve the metgadata + # for the consituent images. new_wunit = WorkUnit( im_stack=stack, config=work_unit.config, wcs=common_wcs, - per_image_wcs=work_unit._per_image_wcs, per_image_indices=unique_obstimes_indices, reprojected=True, org_image_meta=work_unit.org_img_meta, diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py index fcf7d891b..05bf542da 100644 --- a/src/kbmod/work_unit.py +++ b/src/kbmod/work_unit.py @@ -47,7 +47,8 @@ class WorkUnit: config : `kbmod.configuration.SearchConfiguration` The configuration for the KBMOD run. n_images : `int` - The number of current images. + The number of images. This may differ from the length of the ImageStack due + to lazy loading. n_constituents : `int` The number of original images making up the data in this WorkUnit. This might be different from the number of images stored in memory if the WorkUnit has been @@ -58,7 +59,7 @@ class WorkUnit: * ebd_wcs - Used to reproject the images into EBD space. * geocentric_distance - The best fit geocentric distances used when creating the per image EBD WCS. - * original_wcs - The original WCS of the image. + * original_wcs - The original per-image WCS of the image. wcs : `astropy.wcs.WCS` A global WCS for all images in the WorkUnit. Only exists if all images have been projected to same pixel space. @@ -84,23 +85,23 @@ class WorkUnit: The image data for the KBMOD run. config : `kbmod.configuration.SearchConfiguration` The configuration for the KBMOD run. - wcs : `astropy.wcs.WCS` + num_images : `int`, optional + The number of images. This may differ from the length of the ImageStack due + to lazy loading. + wcs : `astropy.wcs.WCS`, optional A global WCS for all images in the WorkUnit. Only exists if all images have been projected to same pixel space. - per_image_wcs : `list` - A list with one WCS for each image in the WorkUnit. Used for when - the images have *not* been standardized to the same pixel space. - reprojected : `bool` + reprojected : `bool`, optional Whether or not the WorkUnit image data has been reprojected. - per_image_indices : `list` of `list` + per_image_indices : `list` of `list`, optional A list of lists containing the indicies of `constituent_images` at each layer of the `ImageStack`. Used for finding corresponding original images when we stitch images together during reprojection. 
- heliocentric_distance : `float` + heliocentric_distance : `float`, optional The heliocentric distance that was used when creating the `per_image_ebd_wcs`. - lazy : `bool` + lazy : `bool`, optional Whether or not to load the image data for the `WorkUnit`. - file_paths : `list[str]` + file_paths : `list[str]`, optional The paths for the shard files, only created if the `WorkUnit` is loaded in lazy mode. obstimes : `list[float]` @@ -111,10 +112,9 @@ class WorkUnit: def __init__( self, - im_stack=None, - config=None, + im_stack, + config, wcs=None, - per_image_wcs=None, reprojected=False, per_image_indices=None, heliocentric_distance=None, @@ -129,39 +129,24 @@ def __init__( self.file_paths = file_paths self._obstimes = obstimes - # Try to infer the number of images. This is a bit complex because of lazy loading. - # We test multiple sources in a predefined order. - for test_item in [im_stack, per_image_wcs, per_image_indices, obstimes]: - if test_item is not None and len(test_item) > 0: - self.n_images = len(test_item) - break + # Determine the number of constituent images. If we are given metadata for the + # of constituent_images, use that. Otherwise use the size of the image stack. + if org_image_meta is None: + self.n_constituents = im_stack.img_count() + else: + self.n_constituents = len(org_image_meta) # Track the metadata for each constituent image in the WorkUnit. If no constituent # data is provided, this will create an empty array the same size as the original. - no_org_img_meta_given = org_image_meta is None - self.org_img_meta = WorkUnit.create_image_meta( - data=org_image_meta, - n_images=self.n_images, - ) - self.n_constituents = len(self.org_img_meta) + self.org_img_meta = create_image_metadata(self.n_constituents, data=org_image_meta) # Handle WCS input. If both the global and per-image WCS are provided, # ensure they are consistent. self.wcs = wcs - if per_image_wcs is None: - self._per_image_wcs = [self.wcs for _ in range(self.n_images)] - else: - if len(per_image_wcs) != self.n_images: - raise ValueError(f"Incorrect number of WCS provided. Expected {self.n_images}") - self._per_image_wcs = per_image_wcs - - if np.any(self._per_image_wcs == None): - warnings.warn("At least one image without a WCS.", Warning) - - # If no constituent data was provided, we save the current image's WCS as the original. - # This is needed to ensure that we always have a correct original WCS. - if no_org_img_meta_given: - self.org_img_meta["original_wcs"] = self._per_image_wcs + if np.all(self.org_img_meta["per_image_wcs"] == None): + self.org_img_meta["per_image_wcs"] = np.full(self.n_constituents, self.wcs) + if np.any(self.org_img_meta["per_image_wcs"] == None): + warnings.warn("At least one image was does not have a WCS.", Warning) # Set the global metadata for reprojection. self.reprojected = reprojected @@ -175,47 +160,21 @@ def __init__( else: self._per_image_indices = per_image_indices + # Run some basic validity checks. + if self.reprojected and self.wcs is None: + raise ValueError("Global WCS required for reprojected data.") + for inds in self._per_image_indices: + if np.max(inds) >= self.n_constituents: + raise ValueError( + f"Found pointer to constituents image {np.max(inds)} of {self.n_constituents}" + ) + def __len__(self): """Returns the size of the WorkUnit in number of images.""" return self.im_stack.img_count() - @staticmethod - def create_image_meta(data=None, n_images=None): - """Create an img_meta table, filling in default values - for any unspecified columns. 
- - Parameters - ---------- - data : `dict`, `astropy.table.Table`, or None - The data from which to seed the table. - n_images : `int`, optional - The number of images to include. Only use when no data is - provided in order to fill in defaults. - - Returns - ------- - img_meta : `astropy.table.Table` - The empty table of org_img_meta. - """ - if data is None: - if n_images is None or n_images <= 0: - raise ValueError("If no data provided 'n_images' must be >= 1. Is {n_images}") - - # Add a place holder column of the correct size. - data = Table({"_index": np.arange(n_images)}) - elif isinstance(data, dict): - data = Table(data) - elif isinstance(data, Table): - data = data.copy() - else: - raise TypeError("Unsupported type for data table.") - n_images = len(data) - - # Fill in the defaults for the original/constituent images. - for colname in ["data_loc", "ebd_wcs", "geocentric_distance", "original_wcs"]: - if colname not in data.colnames: - data[colname] = np.full(n_images, None) - return data + def get_num_images(self): + return len(self._per_image_indices) def get_constituent_meta(self, column): """Get the meta data values of a given column for all the constituent images. @@ -232,29 +191,6 @@ def get_constituent_meta(self, column): """ return list(self.org_img_meta[column].data) - def per_image_wcs_all_match(self, target=None): - """Check if all the per-image WCS are the same as a given target value. - - Parameters - ---------- - target : `astropy.wcs.WCS`, optional - The WCS to which to compare the per-image WCS. If None, checks that - all of the per-image WCS are None. - - Returns - ------- - result : `bool` - A Boolean indicating that all the per-images WCS match the target. - """ - for current in self._per_image_wcs: - if not wcs_fits_equal(current, target): - return False - return True - - def has_common_wcs(self): - """Returns whether the WorkUnit has a common WCS for all images.""" - return self.wcs is not None - def get_wcs(self, img_num): """Return the WCS for the a given image. Alway prioritizes a global WCS if one exits. @@ -272,7 +208,8 @@ def get_wcs(self, img_num): if self.wcs is not None: return self.wcs else: - return self._per_image_wcs[img_num] + # If there is no common WCS, use the original per-image one. + return self.org_img_meta["per_image_wcs"][img_num] def get_pixel_coordinates(self, ra, dec, times=None): """Get the pixel coordinates for pairs of (RA, dec) coordinates. Uses the global @@ -319,7 +256,7 @@ def get_pixel_coordinates(self, ra, dec, times=None): for i, index in enumerate(inds): if index == -1: raise ValueError(f"Unmatched time {times[i]}.") - current_wcs = self._per_image_wcs[index] + current_wcs = self.org_img_meta["per_image_wcs"][index] curr_x, curr_y = current_wcs.world_to_pixel( SkyCoord(ra=ra[i] * u.degree, dec=dec[i] * u.degree) ) @@ -411,9 +348,10 @@ def from_fits(cls, filename, show_progress=None): # Read in the per-image metadata for the constituent images. if "IMG_META" in hdul: logger.debug("Reading original image metadata from IMG_META.") - org_image_meta = hdu_to_metadata_table(hdul["IMG_META"]) + hdu_meta = hdu_to_image_metadata_table(hdul["IMG_META"]) else: - org_image_meta = WorkUnit.create_image_meta(data=None, n_images=n_constituents) + hdu_meta = None + org_image_meta = create_image_metadata(n_constituents, data=hdu_meta) # Read in the global WCS from extension 0 if the information exists. 
# We filter the warning that the image dimension does not match the WCS dimension @@ -425,16 +363,15 @@ def from_fits(cls, filename, show_progress=None): # Misc. reprojection metadata reprojected = hdul[0].header["REPRJCTD"] heliocentric_distance = hdul[0].header["HELIO"] - if np.all(org_image_meta["geocentric_distance"] == None): - # If the metadata table does not have the geocentric_distance, try - # loading it from the primary header's GEO_i fields (legacy approach). - for i in range(n_constituents): - if f"GEO_{i}" in hdul[0].header: - org_image_meta["geocentric_distance"][i] = hdul[0].header[f"GEO_{i}"] + + # If there is geocentric distances in the header information + # (legacy approach), in read those. + for i in range(n_constituents): + if f"GEO_{i}" in hdul[0].header: + org_image_meta["geocentric_distance"][i] = hdul[0].header[f"GEO_{i}"] # Read in all the image files. per_image_indices = [] - per_image_wcs = [] for i in tqdm( range(num_images), bar_format=_DEFAULT_WORKUNIT_TQDM_BAR, @@ -456,7 +393,7 @@ def from_fits(cls, filename, show_progress=None): im_stack.append_image(img, force_move=True) # Read the mapping of current image to constituent image from the header info. - # TODO: Serialize this. + # TODO: Serialize this into its own table. n_indices = sci_hdu.header["NIND"] sub_indices = [] for j in range(n_indices): @@ -472,10 +409,9 @@ def from_fits(cls, filename, show_progress=None): desc="Loading WCS", disable=not show_progress, ): - per_image_wcs.append(extract_wcs_from_hdu_header(hdul[f"WCS_{i}"].header)) if f"WCS_{i}" in hdul: wcs_header = hdul[f"WCS_{i}"].header - org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header(wcs_header) + org_image_meta["per_image_wcs"][i] = extract_wcs_from_hdu_header(wcs_header) if "ILOC" in wcs_header: org_image_meta["data_loc"][i] = wcs_header["ILOC"] if f"EBD_{i}" in hdul: @@ -485,7 +421,6 @@ def from_fits(cls, filename, show_progress=None): im_stack=im_stack, config=config, wcs=global_wcs, - per_image_wcs=per_image_wcs, heliocentric_distance=heliocentric_distance, reprojected=reprojected, per_image_indices=per_image_indices, @@ -675,9 +610,10 @@ def from_sharded_fits(cls, filename, directory, lazy=False): # Read in the per-image metadata for the constituent images. if "IMG_META" in primary: logger.debug("Reading original image metadata from IMG_META.") - org_image_meta = hdu_to_metadata_table(primary["IMG_META"]) + hdu_meta = hdu_to_image_metadata_table(primary["IMG_META"]) else: - org_image_meta = WorkUnit.create_image_meta(data=None, n_images=n_constituents) + hdu_meta = None + org_image_meta = create_image_metadata(n_constituents, data=hdu_meta) # Read in the global WCS from extension 0 if the information exists. # We filter the warning that the image dimension does not match the WCS dimension @@ -696,18 +632,14 @@ def from_sharded_fits(cls, filename, directory, lazy=False): # Extract the per-image data from header information if needed. # This happens with when the WorkUnit was saved before metadata tables were # saved as layers. 
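# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch (not part of the patch) of the legacy
# naming convention handled here -- "WCS_{i}" layers for the original per-image
# WCS (with an optional "ILOC" card giving the data location), "EBD_{i}" layers
# for the EBD-space WCS, and "GEO_{i}" primary-header cards for the geocentric
# distances. With plain astropy, image 0's original WCS could be recovered from
# such a file as follows (the file name is hypothetical):
from astropy.io import fits
from astropy.wcs import WCS

with fits.open("legacy_workunit.fits") as old_hdul:
    wcs0 = WCS(old_hdul["WCS_0"].header)                    # original WCS of image 0
    data_loc0 = old_hdul["WCS_0"].header.get("ILOC", None)  # original data location
# ---------------------------------------------------------------------------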
- per_image_wcs = [] for i in range(n_constituents): - per_image_wcs.append(extract_wcs_from_hdu_header(primary[f"WCS_{i}"].header)) if f"WCS_{i}" in primary: wcs_header = primary[f"WCS_{i}"].header - org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header(wcs_header) + org_image_meta["per_image_wcs"][i] = extract_wcs_from_hdu_header(wcs_header) if "ILOC" in wcs_header: org_image_meta["data_loc"][i] = wcs_header["ILOC"] if f"EBD_{i}" in primary: - org_image_meta["original_wcs"][i] = extract_wcs_from_hdu_header( - primary[f"EBD_{i}"].header - ) + org_image_meta["ebd_wcs"][i] = extract_wcs_from_hdu_header(primary[f"EBD_{i}"].header) per_image_indices = [] file_paths = [] @@ -741,7 +673,6 @@ def from_sharded_fits(cls, filename, directory, lazy=False): im_stack=im_stack, config=config, wcs=global_wcs, - per_image_wcs=per_image_wcs, reprojected=reprojected, lazy=lazy, heliocentric_distance=heliocentric_distance, @@ -764,12 +695,12 @@ def metadata_to_hdul(self): # the metadata (empty), and the configuration. hdul = fits.HDUList() pri = fits.PrimaryHDU() - pri.header["NUMIMG"] = self.n_images + pri.header["NUMIMG"] = self.get_num_images() pri.header["NCON"] = self.n_constituents pri.header["REPRJCTD"] = self.reprojected pri.header["HELIO"] = self.heliocentric_distance - # If the global WCS exists, append the corresponding keys. + # If the global WCS exists, append the corresponding keys to the primary header. if self.wcs is not None: append_wcs_to_hdu_header(self.wcs, pri.header) hdul.append(pri) @@ -780,14 +711,7 @@ def metadata_to_hdul(self): hdul.append(config_hdu) # Save the additional metadata table into HDUs - hdul.append(metadata_table_to_hdu(self.org_img_meta, "IMG_META")) - - # Save the WCS layers. - for idx, wcs in enumerate(self._per_image_wcs): - wcs_hdu = fits.TableHDU() - append_wcs_to_hdu_header(wcs, wcs_hdu.header) - wcs_hdu.name = f"WCS_{idx}" - hdul.append(wcs_hdu) + hdul.append(image_metadata_table_to_hdu(self.org_img_meta, "IMG_META")) return hdul @@ -879,7 +803,7 @@ def image_positions_to_original_icrs( pos = [] for j in inds: con_image = self.org_img_meta["data_loc"][j] - con_wcs = self.org_img_meta["original_wcs"][j] + con_wcs = self.org_img_meta["per_image_wcs"][j] height, width = con_wcs.array_shape x, y = skycoord_to_pixel(coord, con_wcs) x, y = float(x), float(y) @@ -983,7 +907,41 @@ def raw_image_to_hdu(img, obstime, wcs=None): # ------------------------------------------------------------------ -def metadata_table_to_hdu(data, layer_name=None): +def create_image_metadata(n_images, data=None): + """Create an empty img_meta table, filling in default values + for any unspecified columns. + + Parameters + ---------- + n_images : `int` + The number of images to include. + data : `astropy.table.Table` + An existing table from which to fill in some of the data. + + Returns + ------- + img_meta : `astropy.table.Table` + The empty table of org_img_meta. + """ + if n_images <= 0: + raise ValueError("Invalid metadata size: {n_images}") + img_meta = Table() + + # Fill in the defaults. + for colname in ["data_loc", "ebd_wcs", "geocentric_distance", "per_image_wcs"]: + img_meta[colname] = np.full(n_images, None) + + # Fill in any values from the given table. This overwrites the defaults. + if data is not None: + if len(data) != n_images: + raise ValueError(f"Metadata size mismatch. Expected {n_images}. 
Found {len(data)}") + for colname in data.colnames: + img_meta[colname] = data[colname] + + return img_meta + + +def image_metadata_table_to_hdu(data, layer_name=None): """Create a HDU layer from an astropy table with custom encodings for some columns (such as WCS). @@ -1007,28 +965,28 @@ def metadata_table_to_hdu(data, layer_name=None): if np.all(col_data == None): # The entire column is filled with Nones (probably from a default value). - save_table[f"_EMPTY_{colname}"] = np.full(num_rows, "None") + save_table[f"_EMPTY_{colname}"] = np.full(num_rows, "None", dtype=str) elif isinstance(col_data[0], WCS): # Serialize WCS objects and use a custom tag so we can unserialize them. - values = np.array([serialize_wcs(entry) for entry in data[colname]]) + values = np.array([serialize_wcs(entry) for entry in data[colname]], dtype=str) save_table[f"_WCSSTR_{colname}"] = values else: save_table[colname] = data[colname] # Format the metadata as a single HDU - meta_hdu = fits.BinTableHDU(save_table) + meta_hdu = fits.TableHDU(save_table) if layer_name is not None: meta_hdu.name = layer_name return meta_hdu -def hdu_to_metadata_table(hdu): +def hdu_to_image_metadata_table(hdu): """Load a HDU layer with custom encodings for some columns (such as WCS) to an astropy table. Parameters ---------- - hdu : `astropy.io.fits.BinTableHDU` + hdu : `astropy.io.fits.TableHDU` The HDUList for the fits file. Returns diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py index dcb71dc90..8e2a86c2f 100644 --- a/tests/test_work_unit.py +++ b/tests/test_work_unit.py @@ -17,8 +17,9 @@ from kbmod.reprojection_utils import fit_barycentric_wcs from kbmod.wcs_utils import make_fake_wcs, wcs_fits_equal from kbmod.work_unit import ( - hdu_to_metadata_table, - metadata_table_to_hdu, + create_image_metadata, + hdu_to_image_metadata_table, + image_metadata_table_to_hdu, raw_image_to_hdu, WorkUnit, ) @@ -107,7 +108,7 @@ def setUp(self): "data_loc": np.array(self.constituent_images), "ebd_wcs": np.array([self.per_image_ebd_wcs] * self.num_images), "geocentric_distance": np.array([self.geo_dist] * self.num_images), - "original_wcs": np.array(self.per_image_wcs), + "per_image_wcs": np.array(self.per_image_wcs), } ) @@ -158,11 +159,11 @@ def test_metadata_helpers(self): metadata_table = Table(metadata_dict) # Convert to an HDU - hdu = metadata_table_to_hdu(metadata_table) + hdu = image_metadata_table_to_hdu(metadata_table) self.assertIsNotNone(hdu) # Convert it back. - md_table2 = hdu_to_metadata_table(hdu) + md_table2 = hdu_to_image_metadata_table(hdu) self.assertEqual(len(md_table2.colnames), 5) npt.assert_array_equal(metadata_dict["col1"], md_table2["col1"]) npt.assert_array_equal(metadata_dict["uri"], md_table2["uri"]) @@ -173,29 +174,28 @@ def test_metadata_helpers(self): def test_create_image_meta(self): # Empty constituent image data. - org_img_meta = WorkUnit.create_image_meta(n_images=3, data=None) + org_img_meta = create_image_meta(n_images=3, data=None) self.assertEqual(len(org_img_meta), 3) self.assertTrue("data_loc" in org_img_meta.colnames) self.assertTrue("ebd_wcs" in org_img_meta.colnames) self.assertTrue("geocentric_distance" in org_img_meta.colnames) - self.assertTrue("original_wcs" in org_img_meta.colnames) + self.assertTrue("per_image_wcs" in org_img_meta.colnames) # We can create from a dictionary. 
In this case we ignore n_images data_dict = { "uri": ["file1", "file2", "file3"], "geocentric_distance": [1.0, 2.0, 3.0], } - org_img_meta2 = WorkUnit.create_image_meta(n_images=5, data=data_dict) + org_img_meta2 = create_image_meta(n_images=5, data=data_dict) self.assertEqual(len(org_img_meta2), 3) self.assertTrue("data_loc" in org_img_meta2.colnames) self.assertTrue("ebd_wcs" in org_img_meta2.colnames) self.assertTrue("geocentric_distance" in org_img_meta2.colnames) - self.assertTrue("original_wcs" in org_img_meta2.colnames) + self.assertTrue("per_image_wcs" in org_img_meta2.colnames) self.assertTrue("uri" in org_img_meta2.colnames) # We need either data or a positive number of images. - self.assertRaises(ValueError, WorkUnit.create_image_meta, None, None) - self.assertRaises(ValueError, WorkUnit.create_image_meta, None, -1) + self.assertRaises(ValueError, create_image_meta, None, -1) def test_save_and_load_fits(self): with tempfile.TemporaryDirectory() as dir_name: @@ -505,7 +505,7 @@ def test_image_positions_to_original_icrs_mosaicking(self): ) new_wcs = make_fake_wcs(190.0, -7.7888, 500, 700) - work.org_img_meta["original_wcs"][-1] = new_wcs + work.org_img_meta["per_image_wcs"][-1] = new_wcs work._per_image_indices[3] = [3, 4] res = work.image_positions_to_original_icrs( From 6d6449d81870405daee670d08a7afef948c883ae Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 15 Nov 2024 12:41:36 -0500 Subject: [PATCH 08/15] Bug fixes --- src/kbmod/image_collection.py | 34 ++++++++++++-- src/kbmod/reprojection.py | 2 +- .../fits_standardizers/kbmodv05.py | 6 ++- .../fits_standardizers/kbmodv1.py | 6 ++- src/kbmod/work_unit.py | 39 +++++++++------ tests/test_imagecollection.py | 9 ++++ tests/test_work_unit.py | 47 +++++++++---------- 7 files changed, 96 insertions(+), 47 deletions(-) diff --git a/src/kbmod/image_collection.py b/src/kbmod/image_collection.py index 4e75155f5..e764aa0ed 100644 --- a/src/kbmod/image_collection.py +++ b/src/kbmod/image_collection.py @@ -884,14 +884,40 @@ def toWorkUnit(self, search_config=None, **kwargs): from .work_unit import WorkUnit logger.info("Building WorkUnit from ImageCollection") + + # Create a storage location for additional image metadata to save. + # Not all of this may be present in from the ImageCollection, + # in which case we will append None. + metadata_vals = { + "visit": [], + } + + # Extract data from each standardizer and each LayeredImage within + # that standardizer. layeredImages = [] for std in self.get_standardizers(**kwargs): + num_added = 0 for img in std["std"].toLayeredImage(): layeredImages.append(img) + num_added += 1 - imgstack = ImageStack(layeredImages) - img_metadata = Table() + # Get each meta data value from the standardizer so it can be + # passed to the WorkUnit. Use the same value for all images from + # this standardizer. + metadata = std["std"].standardizeMetadata() + for col in metadata_vals.keys(): + value = metadata.get(col, None) + metadata_vals[col].extend([value] * num_added) + # Append the WCS information if we have it. if None not in self.wcs: - img_metadata["per_image_wcs"] = list(self.wcs) - return WorkUnit(imgstack, search_config, org_image_meta=img_metadata) + metadata_vals["per_image_wcs"] = list(self.wcs) + + # Save the metadata as a table, but prune it if empty. + image_metadata = Table(metadata_vals) if len(metadata_vals.keys()) > 0 else None + + # Create the basic WorkUnit from the ImageStack. 
+ imgstack = ImageStack(layeredImages) + work = WorkUnit(imgstack, search_config, org_image_meta=image_metadata) + + return work diff --git a/src/kbmod/reprojection.py b/src/kbmod/reprojection.py index 49e4a0102..ea3883584 100644 --- a/src/kbmod/reprojection.py +++ b/src/kbmod/reprojection.py @@ -557,7 +557,7 @@ def reproject_lazy_work_unit( new_work_unit.wcs = common_wcs new_work_unit.reprojected = True - hdul = new_work_unit.metadata_to_primary_header() + hdul = new_work_unit.metadata_to_hdul() hdul.writeto(os.path.join(directory, filename)) diff --git a/src/kbmod/standardizers/fits_standardizers/kbmodv05.py b/src/kbmod/standardizers/fits_standardizers/kbmodv05.py index 315dca5df..1c45ee4d5 100644 --- a/src/kbmod/standardizers/fits_standardizers/kbmodv05.py +++ b/src/kbmod/standardizers/fits_standardizers/kbmodv05.py @@ -141,14 +141,15 @@ def __init__(self, location=None, hdulist=None, config=None, **kwargs): ] def translateHeader(self): - """Returns the following metadata, read from the primary header, as a - dictionary: + """Returns at least the following metadata, read from the primary header, + as a dictionary: ======== ========== =================================================== Key Header Key Description ======== ========== =================================================== mjd DATE-AVG Decimal MJD timestamp of the middle of the exposure filter FILTER Filter band + visit EXPID Exposure ID ======== ========== =================================================== """ # this is the 1 mandatory piece of metadata we need to extract @@ -159,6 +160,7 @@ def translateHeader(self): # these are all optional things standardizedHeader["filter"] = self.primary["FILTER"] + standardizedHeader["visit"] = self.primary["EXPID"] # If no observatory information is given, default to the Deccam data # (Cerro Tololo Inter-American Observatory). diff --git a/src/kbmod/standardizers/fits_standardizers/kbmodv1.py b/src/kbmod/standardizers/fits_standardizers/kbmodv1.py index 1982b8045..c5c1db87a 100644 --- a/src/kbmod/standardizers/fits_standardizers/kbmodv1.py +++ b/src/kbmod/standardizers/fits_standardizers/kbmodv1.py @@ -143,8 +143,9 @@ def translateHeader(self): Key Header Key Description ======== ========== =================================================== mjd_mid DATE-AVG Decimal MJD timestamp of the middle of the exposure - filter FILTER Filter band - visit_id IDNUM Visit ID + FILTER FILTER Filter band + visit EXPID Exposure ID + IDNUM IDNUM Visit ID observat OBSERVAT Observatory name obs_lat OBS-LAT Observatory Latitude obs_lon OBS-LONG Observatory Longitude @@ -164,6 +165,7 @@ def translateHeader(self): # these are all optional things standardizedHeader["FILTER"] = self.primary["FILTER"] standardizedHeader["IDNUM"] = self.primary["IDNUM"] + standardizedHeader["visit"] = self.primary["EXPID"] standardizedHeader["OBSID"] = self.primary["OBSID"] standardizedHeader["DTNSANAM"] = self.primary["DTNSANAM"] standardizedHeader["AIRMASS"] = self.primary["AIRMASS"] diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py index 05bf542da..4cc5c382c 100644 --- a/src/kbmod/work_unit.py +++ b/src/kbmod/work_unit.py @@ -91,6 +91,9 @@ class WorkUnit: wcs : `astropy.wcs.WCS`, optional A global WCS for all images in the WorkUnit. Only exists if all images have been projected to same pixel space. + per_image_wcs : `list` + A list with one WCS for each image in the WorkUnit. Used for when + the images have *not* been standardized to the same pixel space. 
reprojected : `bool`, optional Whether or not the WorkUnit image data has been reprojected. per_image_indices : `list` of `list`, optional @@ -115,6 +118,7 @@ def __init__( im_stack, config, wcs=None, + per_image_wcs=None, reprojected=False, per_image_indices=None, heliocentric_distance=None, @@ -131,18 +135,24 @@ def __init__( # Determine the number of constituent images. If we are given metadata for the # of constituent_images, use that. Otherwise use the size of the image stack. - if org_image_meta is None: - self.n_constituents = im_stack.img_count() - else: + if org_image_meta is not None: self.n_constituents = len(org_image_meta) + elif per_image_wcs is not None: + self.n_constituents = len(per_image_wcs) + else: + self.n_constituents = im_stack.img_count() # Track the metadata for each constituent image in the WorkUnit. If no constituent # data is provided, this will create an empty array the same size as the original. self.org_img_meta = create_image_metadata(self.n_constituents, data=org_image_meta) - # Handle WCS input. If both the global and per-image WCS are provided, - # ensure they are consistent. + # Handle WCS input. If per_image_wcs is provided on the command line use that. + # If no per_image_wcs values are provided, use the global one. self.wcs = wcs + if per_image_wcs is not None: + if len(per_image_wcs) != self.n_constituents: + raise ValueError(f"Incorrect number of WCS provided. Expected {self.n_constituents}") + self.org_img_meta["per_image_wcs"] = per_image_wcs if np.all(self.org_img_meta["per_image_wcs"] == None): self.org_img_meta["per_image_wcs"] = np.full(self.n_constituents, self.wcs) if np.any(self.org_img_meta["per_image_wcs"] == None): @@ -932,7 +942,7 @@ def create_image_metadata(n_images, data=None): img_meta[colname] = np.full(n_images, None) # Fill in any values from the given table. This overwrites the defaults. - if data is not None: + if data is not None and len(data) > 0: if len(data) != n_images: raise ValueError(f"Metadata size mismatch. Expected {n_images}. Found {len(data)}") for colname in data.colnames: @@ -963,18 +973,22 @@ def image_metadata_table_to_hdu(data, layer_name=None): for colname in data.colnames: col_data = data[colname].value - if np.all(col_data == None): - # The entire column is filled with Nones (probably from a default value). - save_table[f"_EMPTY_{colname}"] = np.full(num_rows, "None", dtype=str) + if data[colname].dtype != "O": + # If this is something we know how to encode (float, int, string), just add the column. + save_table[colname] = data[colname] + elif np.all(col_data == None): + # Skip completely empty columns. + logger.debug("Skipping empty metadata column {colname}") elif isinstance(col_data[0], WCS): # Serialize WCS objects and use a custom tag so we can unserialize them. values = np.array([serialize_wcs(entry) for entry in data[colname]], dtype=str) save_table[f"_WCSSTR_{colname}"] = values else: - save_table[colname] = data[colname] + # Try converting to a string. 
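# ---------------------------------------------------------------------------
# Editor's note: an illustrative sketch (not part of the patch) of round-tripping
# a metadata table with a WCS column through these helpers. The table contents
# are made up; make_fake_wcs is the test helper used elsewhere in this patch.
import numpy as np
from astropy.table import Table
from kbmod.wcs_utils import make_fake_wcs
from kbmod.work_unit import hdu_to_image_metadata_table, image_metadata_table_to_hdu

meta = Table(
    {
        "data_loc": ["img0.fits", "img1.fits"],
        "per_image_wcs": np.array([make_fake_wcs(10.0, 15.0, 100, 200)] * 2),
    }
)
meta_hdu = image_metadata_table_to_hdu(meta, "IMG_META")  # WCS stored as _WCSSTR_ strings
meta2 = hdu_to_image_metadata_table(meta_hdu)             # WCS objects restored on read
# ---------------------------------------------------------------------------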
+ save_table[colname] = data[colname].data.astype(str) # Format the metadata as a single HDU - meta_hdu = fits.TableHDU(save_table) + meta_hdu = fits.BinTableHDU(save_table) if layer_name is not None: meta_hdu.name = layer_name return meta_hdu @@ -1007,8 +1021,5 @@ def hdu_to_image_metadata_table(hdu): if colname.startswith("_WCSSTR_"): data[colname[8:]] = np.array([deserialize_wcs(entry) for entry in data[colname]]) data.remove_column(colname) - elif colname.startswith("_EMPTY_"): - data[colname[7:]] = np.array([None for _ in data[colname]]) - data.remove_column(colname) return data diff --git a/tests/test_imagecollection.py b/tests/test_imagecollection.py index d320e5d0a..1ecbbc633 100644 --- a/tests/test_imagecollection.py +++ b/tests/test_imagecollection.py @@ -69,6 +69,7 @@ def test_basics(self): "obs_elev", "FILTER", "IDNUM", + "visit", "OBSID", "DTNSANAM", "AIRMASS", @@ -156,6 +157,14 @@ def test_workunit(self): data = self.fitsFactory.get_n(3, spoof_data=True) ic = ImageCollection.fromTargets(data) wu = ic.toWorkUnit(search_config=SearchConfiguration()) + self.assertEqual(len(wu), 3) + + # We can retrieve the meta data from the WorkUnit. + filter_info = wu.get_constituent_meta("visit") + self.assertEqual(len(filter_info), 3) + self.assertIsNotNone(filter_info[0]) + + # We can write the whole work unit to a file. with tempfile.TemporaryDirectory() as dir_name: wu.to_fits(f"{dir_name}/test.fits") diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py index 8e2a86c2f..cd1a5935f 100644 --- a/tests/test_work_unit.py +++ b/tests/test_work_unit.py @@ -135,16 +135,6 @@ def test_create(self): self.assertIsNotNone(work2.get_wcs(i)) self.assertTrue(wcs_fits_equal(self.wcs, work2.get_wcs(i))) - # Mismatch with the number of WCS. - self.assertRaises( - ValueError, - WorkUnit, - self.im_stack, - self.config, - self.wcs, - [self.wcs, self.wcs, self.wcs], - ) - def test_metadata_helpers(self): """Test that we can roundtrip an astropy table of metadata (including) WCS into a BinTableHDU. @@ -162,31 +152,33 @@ def test_metadata_helpers(self): hdu = image_metadata_table_to_hdu(metadata_table) self.assertIsNotNone(hdu) - # Convert it back. + # Convert it back. We should have dropped the column of all None. md_table2 = hdu_to_image_metadata_table(hdu) - self.assertEqual(len(md_table2.colnames), 5) + self.assertEqual(len(md_table2.colnames), 4) npt.assert_array_equal(metadata_dict["col1"], md_table2["col1"]) npt.assert_array_equal(metadata_dict["uri"], md_table2["uri"]) npt.assert_array_equal(metadata_dict["Other"], md_table2["Other"]) - self.assertTrue(np.all(md_table2["none_col"] == None)) + self.assertFalse("none_col" in md_table2.colnames) for i in range(len(md_table2)): self.assertTrue(isinstance(md_table2["wcs"][i], WCS)) - def test_create_image_meta(self): + def test_create_image_metadata(self): # Empty constituent image data. - org_img_meta = create_image_meta(n_images=3, data=None) + org_img_meta = create_image_metadata(3, data=None) self.assertEqual(len(org_img_meta), 3) self.assertTrue("data_loc" in org_img_meta.colnames) self.assertTrue("ebd_wcs" in org_img_meta.colnames) self.assertTrue("geocentric_distance" in org_img_meta.colnames) self.assertTrue("per_image_wcs" in org_img_meta.colnames) - # We can create from a dictionary. In this case we ignore n_images - data_dict = { - "uri": ["file1", "file2", "file3"], - "geocentric_distance": [1.0, 2.0, 3.0], - } - org_img_meta2 = create_image_meta(n_images=5, data=data_dict) + # We can create from a Table. 
+ data = Table( + { + "uri": ["file1", "file2", "file3"], + "geocentric_distance": [1.0, 2.0, 3.0], + } + ) + org_img_meta2 = create_image_metadata(3, data) self.assertEqual(len(org_img_meta2), 3) self.assertTrue("data_loc" in org_img_meta2.colnames) self.assertTrue("ebd_wcs" in org_img_meta2.colnames) @@ -194,8 +186,15 @@ def test_create_image_meta(self): self.assertTrue("per_image_wcs" in org_img_meta2.colnames) self.assertTrue("uri" in org_img_meta2.colnames) - # We need either data or a positive number of images. - self.assertRaises(ValueError, create_image_meta, None, -1) + npt.assert_array_equal(org_img_meta2["geocentric_distance"], data["geocentric_distance"]) + npt.assert_array_equal(org_img_meta2["uri"], data["uri"]) + self.assertTrue(np.all(org_img_meta2["ebd_wcs"] == None)) + self.assertTrue(np.all(org_img_meta2["per_image_wcs"] == None)) + self.assertTrue(np.all(org_img_meta2["data_loc"] == None)) + + # We need a positive number of images that matches the length of data (if provided). + self.assertRaises(ValueError, create_image_metadata, -1, None) + self.assertRaises(ValueError, create_image_metadata, 2, data) def test_save_and_load_fits(self): with tempfile.TemporaryDirectory() as dir_name: @@ -218,7 +217,7 @@ def test_save_and_load_fits(self): config=self.config, wcs=None, per_image_wcs=self.diff_wcs, - org_image_meta=extra_meta, + org_image_meta=Table(extra_meta), ) work.to_fits(file_path) self.assertTrue(Path(file_path).is_file()) From 7f045b5fbda5441316dd32ec24b1f3764e11296f Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 15 Nov 2024 13:24:33 -0500 Subject: [PATCH 09/15] Fix comments --- src/kbmod/work_unit.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py index 4cc5c382c..b14a09cc4 100644 --- a/src/kbmod/work_unit.py +++ b/src/kbmod/work_unit.py @@ -46,9 +46,6 @@ class WorkUnit: The image data for the KBMOD run. config : `kbmod.configuration.SearchConfiguration` The configuration for the KBMOD run. - n_images : `int` - The number of images. This may differ from the length of the ImageStack due - to lazy loading. n_constituents : `int` The number of original images making up the data in this WorkUnit. This might be different from the number of images stored in memory if the WorkUnit has been @@ -85,15 +82,13 @@ class WorkUnit: The image data for the KBMOD run. config : `kbmod.configuration.SearchConfiguration` The configuration for the KBMOD run. - num_images : `int`, optional - The number of images. This may differ from the length of the ImageStack due - to lazy loading. wcs : `astropy.wcs.WCS`, optional A global WCS for all images in the WorkUnit. Only exists if all images have been projected to same pixel space. - per_image_wcs : `list` + per_image_wcs : `list`, optional A list with one WCS for each image in the WorkUnit. Used for when - the images have *not* been standardized to the same pixel space. + the images have *not* been standardized to the same pixel space. If provided + this will the WCS values in org_image_meta reprojected : `bool`, optional Whether or not the WorkUnit image data has been reprojected. per_image_indices : `list` of `list`, optional @@ -109,7 +104,7 @@ class WorkUnit: in lazy mode. obstimes : `list[float]` The MJD obstimes of the images. 
- org_image_meta : `dict` or `astropy.table.Table`, optional + org_image_meta : `astropy.table.Table`, optional A table of per-image data for the constituent images. """ @@ -143,10 +138,10 @@ def __init__( self.n_constituents = im_stack.img_count() # Track the metadata for each constituent image in the WorkUnit. If no constituent - # data is provided, this will create an empty array the same size as the original. + # data is provided, this will create a table of default values the correct size. self.org_img_meta = create_image_metadata(self.n_constituents, data=org_image_meta) - # Handle WCS input. If per_image_wcs is provided on the command line use that. + # Handle WCS input. If per_image_wcs is provided as an argument, use that. # If no per_image_wcs values are provided, use the global one. self.wcs = wcs if per_image_wcs is not None: @@ -1008,10 +1003,6 @@ def hdu_to_image_metadata_table(hdu): data : `astropy.table.Table` The table of loaded data. """ - if hdu is None: - # Nothing to decode. Return an empty table. - return Table() - data = Table(hdu.data) all_cols = set(data.colnames) From 603d442098d5b28fa99317b6335e14390ed8be1f Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 15 Nov 2024 14:52:20 -0500 Subject: [PATCH 10/15] Address Dinos comment in the PR --- src/kbmod/image_collection.py | 27 ++++----------------------- 1 file changed, 4 insertions(+), 23 deletions(-) diff --git a/src/kbmod/image_collection.py b/src/kbmod/image_collection.py index e764aa0ed..ff2b9ded6 100644 --- a/src/kbmod/image_collection.py +++ b/src/kbmod/image_collection.py @@ -885,39 +885,20 @@ def toWorkUnit(self, search_config=None, **kwargs): logger.info("Building WorkUnit from ImageCollection") - # Create a storage location for additional image metadata to save. - # Not all of this may be present in from the ImageCollection, - # in which case we will append None. - metadata_vals = { - "visit": [], - } - # Extract data from each standardizer and each LayeredImage within # that standardizer. layeredImages = [] for std in self.get_standardizers(**kwargs): - num_added = 0 for img in std["std"].toLayeredImage(): layeredImages.append(img) - num_added += 1 - - # Get each meta data value from the standardizer so it can be - # passed to the WorkUnit. Use the same value for all images from - # this standardizer. - metadata = std["std"].standardizeMetadata() - for col in metadata_vals.keys(): - value = metadata.get(col, None) - metadata_vals[col].extend([value] * num_added) - # Append the WCS information if we have it. + # Extract all of the relevant metadata from the ImageCollection. + metadata = Table(self.toBinTableHDU().data) if None not in self.wcs: - metadata_vals["per_image_wcs"] = list(self.wcs) - - # Save the metadata as a table, but prune it if empty. - image_metadata = Table(metadata_vals) if len(metadata_vals.keys()) > 0 else None + metadata["per_image_wcs"] = list(self.wcs) # Create the basic WorkUnit from the ImageStack. 
imgstack = ImageStack(layeredImages) - work = WorkUnit(imgstack, search_config, org_image_meta=image_metadata) + work = WorkUnit(imgstack, search_config, org_image_meta=metadata) return work From 137f83e5abc04666190930c6e81c23ba5f158784 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Mon, 18 Nov 2024 09:15:34 -0500 Subject: [PATCH 11/15] Pass more metadata through results --- src/kbmod/run_search.py | 33 ++++++++++++++++++++++++--------- src/kbmod/work_unit.py | 25 +++++++++++++++++++------ tests/test_regression_test.py | 9 ++++++++- tests/test_work_unit.py | 7 +++++++ 4 files changed, 58 insertions(+), 16 deletions(-) diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index faf21cede..dca56dea3 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -187,7 +187,7 @@ def do_gpu_search(self, config, stack, trj_generator): keep = self.load_and_filter_results(search, config) return keep - def run_search(self, config, stack, trj_generator=None, wcs=None): + def run_search(self, config, stack, trj_generator=None, wcs=None, extra_meta=None): """This function serves as the highest-level python interface for starting a KBMOD search given an ImageStack and SearchConfiguration. @@ -202,6 +202,8 @@ def run_search(self, config, stack, trj_generator=None, wcs=None): If None uses the default EclipticCenteredSearch wcs : `astropy.wcs.WCS`, optional A global WCS for all images in the search. + extra_meta : `dict`, optional + Any additional metadata to save as part of the results file. Returns ------- @@ -256,11 +258,11 @@ def run_search(self, config, stack, trj_generator=None, wcs=None): # Create and save any additional meta data that should be saved with the results. num_img = stack.img_count() - meta = { - "num_img": num_img, - "dims": (stack.get_width(), stack.get_height()), - "mjd_mid": [stack.get_obstime(i) for i in range(num_img)], - } + + meta_to_save = extra_meta.copy() + meta_to_save["num_img"] = num_img + meta_to_save["dims"] = stack.get_width(), stack.get_height() + meta_to_save["mjd_mid"] = [stack.get_obstime(i) for i in range(num_img)] # Save the results in as an ecsv file and/or a legacy text file. if config["legacy_filename"] is not None: @@ -270,9 +272,11 @@ def run_search(self, config, stack, trj_generator=None, wcs=None): if config["result_filename"] is not None: logger.info(f"Saving results table to {config['result_filename']}") if not config["save_all_stamps"]: - keep.write_table(config["result_filename"], cols_to_drop=["all_stamps"], extra_meta=meta) + keep.write_table( + config["result_filename"], cols_to_drop=["all_stamps"], extra_meta=meta_to_save + ) else: - keep.write_table(config["result_filename"], extra_meta=meta) + keep.write_table(config["result_filename"], extra_meta=meta_to_save) full_timer.stop() return keep @@ -292,5 +296,16 @@ def run_search_from_work_unit(self, work): """ trj_generator = create_trajectory_generator(work.config, work_unit=work) + # Extract extra metadata. We do not use the full org_image_meta table from the WorkUnit + # because this can be very large and varies with the source. Instead we only save a + # few pre-defined fields to the results data. + extra_meta = work.get_constituent_meta(["visit", "filter"]) + # Run the search. 
- return self.run_search(work.config, work.im_stack, trj_generator=trj_generator, wcs=work.wcs) + return self.run_search( + work.config, + work.im_stack, + trj_generator=trj_generator, + wcs=work.wcs, + extra_meta=extra_meta, + ) diff --git a/src/kbmod/work_unit.py b/src/kbmod/work_unit.py index b14a09cc4..f49c22bbb 100644 --- a/src/kbmod/work_unit.py +++ b/src/kbmod/work_unit.py @@ -1,3 +1,4 @@ +from collections.abc import Iterable import os import warnings from pathlib import Path @@ -182,19 +183,31 @@ def get_num_images(self): return len(self._per_image_indices) def get_constituent_meta(self, column): - """Get the meta data values of a given column for all the constituent images. + """Get the metadata values of a given column or a list of columns + for all the constituent images. Parameters ---------- - column : `str` - The column name to fetch. + column : `str`, or Iterable + The column name(s) to fetch. Returns ------- - data : `list` - A list of the meta-data for each constituent image. + data : `list` or `dict` + If a single string column name is provided, the function returns the + values in a list. Otherwise it returns a dictionary, mapping + each column name to its values. """ - return list(self.org_img_meta[column].data) + if isinstance(column, str): + return self.org_img_meta[column].data.tolist() + elif isinstance(column, Iterable): + results = {} + for col in column: + if col in self.org_img_meta.colnames: + results[col] = self.org_img_meta[col].data.tolist() + return results + else: + raise TypeError(f"Unsupported column type {type(column)}") def get_wcs(self, img_num): """Return the WCS for the a given image. Alway prioritizes diff --git a/tests/test_regression_test.py b/tests/test_regression_test.py index 850b27219..af4ef37e6 100644 --- a/tests/test_regression_test.py +++ b/tests/test_regression_test.py @@ -266,7 +266,10 @@ def perform_search(im_stack, res_filename, default_psf): } config = SearchConfiguration.from_dict(input_parameters) - wu = WorkUnit(im_stack=im_stack, config=config) # , wcs=fake_wcs) + # Create fake visit metadata to confirm we pass it along. + wu = WorkUnit(im_stack=im_stack, config=config) + wu.org_img_meta["visit"] = [f"img_{i}" for i in range(im_stack.img_count())] + rs = SearchRunner() rs.run_search_from_work_unit(wu) @@ -340,6 +343,10 @@ def run_full_test(): assert loaded_data.table.meta["num_img"] == num_times assert loaded_data.table.meta["dims"] == (stack.get_width(), stack.get_height()) assert np.allclose(loaded_data.table.meta["mjd_mid"], times) + assert np.array_equal( + loaded_data.table.meta["visit"], + [f"img_{i}" for i in range(stack.img_count())], + ) # Determine which trajectories we did not recover. overlap = find_unique_overlap(trjs, found, 3.0, [0.0, 2.0]) diff --git a/tests/test_work_unit.py b/tests/test_work_unit.py index cd1a5935f..c331632cb 100644 --- a/tests/test_work_unit.py +++ b/tests/test_work_unit.py @@ -270,6 +270,13 @@ def test_save_and_load_fits(self): npt.assert_array_equal(work2.get_constituent_meta("int_index"), extra_meta["int_index"]) npt.assert_array_equal(work2.get_constituent_meta("data_loc"), self.constituent_images) + # Check that we can retrieve the extra metadata in a single request. 
+ meta2 = work2.get_constituent_meta(["uri", "int_index", "nonexistent_column"]) + self.assertEqual(len(meta2), 2) + npt.assert_array_equal(meta2["uri"], extra_meta["uri"]) + npt.assert_array_equal(meta2["int_index"], extra_meta["int_index"]) + self.assertFalse("nonexistent_column" in extra_meta) + # We throw an error if we try to overwrite a file with overwrite=False self.assertRaises(FileExistsError, work.to_fits, file_path) From 1aa7354f57f9aa2a2f28ea94b5e2c566d6a58fb1 Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Mon, 18 Nov 2024 19:55:58 -0800 Subject: [PATCH 12/15] Add a KBMOD results filter for matching "known objects" (#741) * Add filter for known objects * Modify init with known object filter * Clean up comments and tests * Lint fixes * Refactored to KnownObjsMatcher and added filters * More refactoring and renaming * Separate match vs obs_valid filters and clean up comments. * Update test names and documentaiton * Format * Revert local change * Remove blank line * Address small comments * Fix time filter unit conversion and testing * Make obs_ratio and min_obs function parameters * Remove unneeded obs_match_ratio field --- src/kbmod/filters/__init__.py | 1 + src/kbmod/filters/known_object_filters.py | 438 +++++++++++++++ tests/test_known_object_filters.py | 627 ++++++++++++++++++++++ 3 files changed, 1066 insertions(+) create mode 100644 src/kbmod/filters/known_object_filters.py create mode 100644 tests/test_known_object_filters.py diff --git a/src/kbmod/filters/__init__.py b/src/kbmod/filters/__init__.py index 79d271cd0..f9e607cae 100644 --- a/src/kbmod/filters/__init__.py +++ b/src/kbmod/filters/__init__.py @@ -1,5 +1,6 @@ from . import ( clustering_filters, + known_object_filters, sigma_g_filter, stamp_filters, ) diff --git a/src/kbmod/filters/known_object_filters.py b/src/kbmod/filters/known_object_filters.py new file mode 100644 index 000000000..6c6cdc6ec --- /dev/null +++ b/src/kbmod/filters/known_object_filters.py @@ -0,0 +1,438 @@ +import astropy.units as u +import numpy as np +from astropy.coordinates import SkyCoord, search_around_sky + +import kbmod.search as kb +from kbmod.trajectory_utils import trajectory_predict_skypos +from collections import Counter + +logger = kb.Logging.getLogger(__name__) + + +class KnownObjsMatcher: + """ + A class which ingests an astopy table of object data expected to be found in the dataset + searched by KBMOD (either real objects or inserted synthetic fakes) and provides methods for + matching to the observations in a given set of KBMOD Results. + + It allows for configuration of how the matching is done, including the maximum + separation in arcseconds between a known object and a result to be considered a match, + the maximum time separation in seconds between a known object and the observation + used in a KBMOD result. + + In addition to modifying a KBMOD `Results` table to include columns for matched known objects, + it also provides methods for filtering the results based on the matches. This includes + marking observations that matched to known objects as invalid, and filtering out results that matched to known objects by + either the minimum number of observations that matched to that known object or the proportion + of observations from the catalog for that known object that were matched to a given result. 
+ """ + + def __init__( + self, + table, + obstimes, + matcher_name, + sep_thresh=1.0, + time_thresh_s=600.0, + mjd_col="mjd_mid", + ra_col="RA", + dec_col="DEC", + name_col="Name", + ): + """ + Parameters + ---------- + table : astropy.table.Table + A table containing our catalog of observations of known objects. + obstimes : list(float) + The MJD times of each observation within KBMOD results we want to match to + the known objects. + matcher_name : str + The name of the filter to apply to the results. This both determines + the name of the column of matched observations which may be added to + the `Results` table and how the filtering and matching phases are identified within KBMOD logs. + sep_thresh : float, optional + The maximum separation in arcseconds between a known object and a result + to be considered a match. Default is 1.0. + time_thresh_s : float, optional + The maximum time separation in seconds between a known object and the observation + used in a KBMOD result. Default is 600.0. + mjd_col : str, optional + The name of the catalog column containing the MJD of the known objects. Default is "mjd_mid". + ra_col : str, optional + The name of the catalog column containing the RA of the known objects. Default is "RA". + dec_col : str, optional + The name of the catalog column containing the DEC of the known objects. Default is "DEC". + name_col : str, optional + The name of the catalog column containing the name of the known objects. Default is "Name". + + Raises + ------ + ValueError + If the required columns are not present in the table. + + Returns + ------- + KnownObjsMatcher + A KnownObjsMatcher object. + """ + self.data = table + + # Map our required columns to any specified column names. + self.mjd_col = mjd_col + self.ra_col = ra_col + self.dec_col = dec_col + self.name_col = name_col + + # Check that the required columns are present + user_cols = set([self.mjd_col, self.ra_col, self.dec_col, self.name_col]) + invalid_cols = user_cols - set(self.data.colnames) + if invalid_cols: + raise ValueError(f"{invalid_cols} not found in KnownObjs data.") + + self.obstimes = obstimes + if len(self.obstimes) == 0: + raise ValueError("No obstimes provided") + + self.matcher_name = matcher_name + self.sep_thresh = sep_thresh * u.arcsec + self.time_thresh_s = time_thresh_s + + # Pre-filter down our data to window of temporally relevant observations to speed up matching. + time_thresh_days = self.time_thresh_s / (24 * 3600) # Convert seconds to days + start_mjd = max(0, min(self.obstimes) - time_thresh_days - 1e-6) + end_mjd = max(self.obstimes) + time_thresh_days + 1e-6 + + # Filter out known object observations outside of our time thresholds + self.data = self.data[(self.data[self.mjd_col] >= start_mjd) & (self.data[self.mjd_col] <= end_mjd)] + + def match_min_obs_col(self, min_obs): + """A colummn name for objects that matched results based on the minimum number of observations.""" + return f"recovered_{self.matcher_name}_min_obs_{min_obs}" + + def match_obs_ratio_col(self, obs_ratio): + # A column name for objects that matched results based on the proportion of observations that + # matched to the known observations for that object within the catalog. + return f"recovered_{self.matcher_name}_obs_ratio_{obs_ratio}" + + def __len__(self): + """Returns the number of observations known objects of interest in this matcher's catalog.""" + return len(self.data) + + def get_mjd(self, ko_idx): + """ + Returns the MJD of the known object at a given index. 
+ """ + return self.data[ko_idx][self.mjd_col] + + def get_ra(self, ko_idx): + """ + Returns the RA of the known object at a given index. + """ + return self.data[ko_idx][self.ra_col] + + def get_dec(self, ko_idx): + """ + Returns the DEC of the known object at a given index. + """ + return self.data[ko_idx][self.dec_col] + + def get_name(self, ko_idx): + """ + Returns the name of the known object at a given index. + """ + return self.data[ko_idx][self.name_col] + + def to_skycoords(self): + """ + Returns a SkyCoord representation of the known objects. + """ + return SkyCoord(ra=self.data[self.ra_col], dec=self.data[self.dec_col], unit="deg") + + def match(self, result_data, wcs): + """This function takes a list of results and matches them to known objects. + + This modifies the `Results` table by adding a column with name `self.matcher_name` that provides for each result a dictionary mapping the names of known + objects (as defined by the catalog's `name_col`) to a boolean array indicating which observations + in the result matched to that known object. Note that depending on the matching parameters, a result + can match to multiple known objects from the catalog even at the same observation time. + + So for a dataset with 5 observations a result matching to 2 known objects, A and B, might have an entry in the column `self.matcher_name` like: + ```{ + "A": [True, True, False, False, False], + "B": [False, False, False, True, True], + }``` + + Parameters + ---------- + result_data: `Results` + The set of results to filter. This data gets modified directly by + the filtering. + wcs: `astropy.wcs.WCS` + The common WCS object for the stack of images. + + Returns + ------- + `Results` + The modified `Results` object returned for chaining. + """ + logger.info(f"Matching known objects to {len(result_data)} results using {self.matcher_name} filter") + all_matches = [] + + # Get the RA and DEC of the known objects and the trajectories of the results for matching + known_objs_ra_dec = self.to_skycoords() + trj_list = result_data.make_trajectory_list() + + for result_idx in range(len(result_data)): + # Generate (RA, Dec) pairs for all of the valid observations for this result trajectory + valid_obstimes = self.obstimes[result_data[result_idx]["obs_valid"]] + trj_skycoords = trajectory_predict_skypos(trj_list[result_idx], wcs, valid_obstimes) + + # Because we're only matching using the subset of obstimes that were valid for this result, we + # can use this to later map back to the original index of all observations in the stack. + trj_idx_to_obs_idx = np.where(result_data[result_idx]["obs_valid"])[0] + + # Now we can compare the SkyCoords of the known objects to the SkyCoords of the result trajectories using search_around_sky + # This will return a list of indices of known objects that are within sep_thresh of a trajectory + # Note that subsequent calls by default will use the same underlying KD-Tree iin coords2.cache. 
+ trjs_idx, known_objs_idx, _, _ = search_around_sky( + trj_skycoords, known_objs_ra_dec, self.sep_thresh + ) + + # Now we can count per-known object how many observations matched within this result + matched_known_objs = {} + for t_idx, ko_idx in zip(trjs_idx, known_objs_idx): + # The observation spatially matched but now check that the time separation is witihin our threshold + if abs(self.get_mjd(ko_idx) - valid_obstimes[t_idx]) * 24 * 3600 <= self.time_thresh_s: + # The name of the object that matched to this observation + obj_name = self.get_name(ko_idx) + if obj_name not in matched_known_objs: + # Create an array of which observations match to this object. + # Note that we need to use the length of all obstimes, not just the presently valid ones + matched_known_objs[obj_name] = np.full(len(self.obstimes), False) + # Map to the original set of all obstimes (valid or invalid) since that's what we + # want for results filtering. + obs_idx = trj_idx_to_obs_idx[t_idx] + matched_known_objs[obj_name][obs_idx] = True + all_matches.append(matched_known_objs) + + # Add matches as a result column + result_data.table[self.matcher_name] = all_matches + + logger.info(f"Matched known objects to {len(result_data)} results using {self.matcher_name} filter") + + return result_data + + def mark_matched_obs_invalid( + self, + result_data, + drop_empty_rows=True, + ): + """ + Mark observations that matched to known objects as invalid, by default dropping + results that no longer have any valid observations. + + Note that a given result can match to multiple objects, and that we expect the + `Results` table to have a column with name corresponding to `self.matcher_name` that + contains which observations were matched to each known object. + + Parameters + ---------- + result_data : `Results` + The results to filter. + drop_empty_rows : bool, optional + If True, drop rows that have no valid observations after filtering. Default is True. + + Returns + ------- + `Results` + The modified `Results` object returned for chaining. + """ + # Skip filtering if there is nothing to filter. + if len(result_data) == 0 or len(self.obstimes) == 0 or len(self.data) == 0: + return result_data + + if self.matcher_name not in result_data.table.colnames: + raise ValueError( + f"Column {self.matcher_name} not found in results table. Please run match() first." + ) + + matched_known_objs = result_data.table[self.matcher_name] + new_obs_valid = result_data["obs_valid"] + for result_idx in range(len(result_data)): + # A result can match to multiple objects, so we want to logically OR + # against all matching objects with a logical OR using np.any. + # We can then use bitwise NOT and AND to mark any previously valid + # observations that matched to known objects as invalid. + new_obs_valid[result_idx] &= ~np.any( + np.array(list(matched_known_objs[result_idx].values())), axis=0 + ) + + return result_data.update_obs_valid(new_obs_valid, drop_empty_rows=drop_empty_rows) + + def match_on_min_obs( + self, + result_data, + min_obs, + ): + """ + Create a column corresponding to the known objects that were matched to a result + based on the minimum number of observations that matched to that known object. + Note that the ratio is calculated based on the total number of observations + that were within `time_sep_thresh_s` of the `obstimes` we are matching to. Observations + outside of that time range are not considered. + + Note that a given result can match to multiple objects. 
+ + Parameters + ---------- + result_data : `Results` + The results to filter. + min_obs : int + The minimum number of observations within a KBMOD result that must match to a known + object for that result to be considered a match. + + Returns + ------- + `Results` + The modified `Results` object returned for chaining. + """ + matched_objs = [] + for idx in range(len(result_data)): + matched_objs.append(set([])) + matches = result_data[self.matcher_name][idx] + for name in matches: + if np.count_nonzero(matches[name]) >= min_obs: + matched_objs[-1].add(name) + result_data.table[self.match_min_obs_col(min_obs)] = matched_objs + + return result_data + + def match_on_obs_ratio( + self, + result_data, + obs_ratio, + ): + """ + Create a column corresponding to the known objects that were matched to a result + based on the proportion of observations that matched to that known object within the catalog. + + Note that a given result can match to multiple objects. + + Parameters + ---------- + result_data : `Results` + The results to filter. + obs_ratio : float + The minimum ratio of observations within a KBMOD result that must match to the total + observations within our catalog of known objects for that result to be considered a match. + Must be within the range [0, 1]. + + Returns + ------- + `Results` + The modified `Results` object returned for chaining. + + Raises + ------ + ValueError + If `obs_ratio` is not within the range [0, 1]. + """ + if obs_ratio < 0 or obs_ratio > 1: + raise ValueError("obs_ratio must be within the range [0, 1].") + + # Create a dictionary of how many observations we have for each known object + # in our catalog + known_obj_cnts = dict(Counter(self.data[self.name_col])) + matched_objs = [] + for idx in range(len(result_data)): + matched_objs.append(set([])) + matches = result_data[self.matcher_name][idx] + for name in matches: + if name not in known_obj_cnts: + raise ValueError(f"Unknown known object {name}") + + curr_obs_ratio = np.count_nonzero(matches[name]) / known_obj_cnts[name] + if curr_obs_ratio <= obs_ratio: + matched_objs[-1].add(name) + + result_data.table[self.match_obs_ratio_col(obs_ratio)] = matched_objs + + return result_data + + def get_recovered_objects(self, result_data, match_col): + """ + Get the set of objects that were recovered or missed in the results. + + For our purposes, a recovered object is one that was matched to a result based on the + matching column of choice in the results table and a missing object are objects in + the catalog that were not matched. Note that not all catalogs may be + constructed in a way where all objects could be spatially present and + recoverable in the results. + + Parameters + ---------- + result_data : `Results` + The results to filter. + match_col : str + The name of the column in the results table that contains the matched objects. 
+ + Returns + ------- + set, set + A tuple of sets where the first set contains the names of objects that were recovered + and the second set contains the names objects that were missed + + Raises + ------ + ValueError + If the `match_col` is not present in the results table + """ + if match_col not in result_data.table.colnames: + raise ValueError(f"Column {match_col} not found in results table.") + + if len(result_data) == 0 or len(self.data) == 0: + return set(), set() + + expected_objects = set(self.data[self.name_col]) + matched_objects = set() + for idx in range(len(result_data)): + matched_objects.update(result_data[match_col][idx]) + recovered_objects = matched_objects.intersection(expected_objects) + missed_objects = expected_objects - recovered_objects + + return recovered_objects, missed_objects + + def filter_matches(self, result_data, match_col): + """ + Filter out the results table to only include results that did not match to any known objects. + + Parameters + ---------- + result_data : `Results` + The results to filter. + match_col : str + The name of the column in the results table that contains the matched objects. + + Returns + ------- + `Results` + The modified `Results` object returned for chaining. + + Raises + ------ + ValueError + If the `match_col` is not present in the results table. + """ + if match_col not in result_data.table.colnames: + raise ValueError(f"Column {match_col} not found in results table.") + + if len(result_data) == 0: + return result_data + + # Only keep results that did not match to any known objects in our column + idx_to_keep = np.array([len(x) == 0 for x in result_data[match_col]]) + # Use the name of our matching column as the filter name + result_data = result_data.filter_rows(idx_to_keep, match_col) + + return result_data diff --git a/tests/test_known_object_filters.py b/tests/test_known_object_filters.py new file mode 100644 index 000000000..a6d2e7b11 --- /dev/null +++ b/tests/test_known_object_filters.py @@ -0,0 +1,627 @@ +import random +import unittest + +import numpy as np +from astropy.table import Table + +from kbmod.fake_data.fake_data_creator import FakeDataSet, create_fake_times +from kbmod.filters.known_object_filters import KnownObjsMatcher +from kbmod.results import Results +from kbmod.search import * +from kbmod.trajectory_utils import trajectory_predict_skypos +from kbmod.wcs_utils import make_fake_wcs + + +class TestKnownObjMatcher(unittest.TestCase): + def setUp(self): + # Seed for reproducibility of random generated trajectories + self.seed = 500 + np.random.seed(self.seed) + random.seed(self.seed) + + # Set up some default parameters for our matcher + self.matcher_name = "test_matches" + self.sep_thresh = 1.0 + self.time_thresh_s = 600.0 + + # Create a fake dataset with 15 x 10 images and 25 obstimes. 
+ num_images = 25 + self.obstimes = np.array(create_fake_times(num_images)) + ds = FakeDataSet(15, 10, self.obstimes, use_seed=True) + self.wcs = make_fake_wcs(10.0, 15.0, 15, 10) + ds.set_wcs(self.wcs) + + # Randomly generate a Trajectory for each of our 10 results + num_results = 10 + for i in range(num_results): + ds.insert_random_object(self.seed) + self.res = Results.from_trajectories(ds.trajectories, track_filtered=True) + self.assertEqual(len(ds.trajectories), num_results) + + # Generate which observations are valid observations for each result + self.obs_valid = np.full((num_results, num_images), True) + for i in range(num_results): + # For each result include a random set of 5 invalid observations + invalid_obs = np.random.choice(num_images, 5, replace=False) + self.obs_valid[i][invalid_obs] = False + self.res.update_obs_valid(self.obs_valid) + assert set(self.res.table.columns) == set( + ["x", "y", "vx", "vy", "likelihood", "flux", "obs_count", "obs_valid"] + ) + + # Use the results' trajectories to generate a set of known objects that intersect our generated results in various + # ways. + self.known_objs = Table({"Name": np.empty(0, dtype=str), "RA": [], "DEC": [], "mjd_mid": []}) + + # Have the temporal offset for near and far objects be just below and above our time threshold + time_offset_mjd_close = (self.time_thresh_s - 1) / (24.0 * 3600) + time_offset_mjd_far = (self.time_thresh_s + 1) / (24.0 * 3600) + + # Case 1: Near in space and near in time just within the range of our filters to result 1 + self.generate_known_obj_from_result( + self.known_objs, + 1, # Base off result 1 + self.obstimes, # Use all possible obstimes + "spatial_close_time_close_1", + spatial_offset=0.00001, + time_offset=time_offset_mjd_close, + ) + + # Case 2 near in space to result 3, but farther in time. + self.generate_known_obj_from_result( + self.known_objs, + 3, # Base off result 3 + self.obstimes, # Use all possible obstimes + "spatial_close_time_far_3", + spatial_offset=0.0001, + time_offset=time_offset_mjd_far, + ) + + # Case 3: A similar trajectory to result 5, but farther in space with similar timestamps. + self.generate_known_obj_from_result( + self.known_objs, + 5, # Base off result 5 + self.obstimes, # Use all possible obstimes + "spatial_far_time_close_5", + spatial_offset=5, + time_offset=time_offset_mjd_close, + ) + + # Case 4: A similar trajectory to result 7, but far off spatially and temporally + self.generate_known_obj_from_result( + self.known_objs, + 7, # Base off result 7 + self.obstimes, # Use all possible obstimes + "spatial_far_time_far_7", + spatial_offset=5, + time_offset=time_offset_mjd_far, + ) + + # Case 5: a trajectory matching result 8 but with only a few observations. 
+ self.generate_known_obj_from_result( + self.known_objs, + 8, # Base off result 8 + self.obstimes[::10], # Samples down to every 10th observation + "sparse_8", + spatial_offset=0.0001, + time_offset=time_offset_mjd_close, + ) + + def test_known_objs_matcher_init( + self, + ): # Test that a table with no columns specified raises a ValueError + with self.assertRaises(ValueError): + KnownObjsMatcher( + Table(), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Test that a table with no Name column raises a ValueError + with self.assertRaises(ValueError): + KnownObjsMatcher( + Table({"RA": [], "DEC": [], "mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Test that a table with no RA column raises a ValueError + with self.assertRaises(ValueError): + KnownObjsMatcher( + Table({"Name": [], "DEC": [], "mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Test that a table with no DEC column raises a ValueError + with self.assertRaises(ValueError): + KnownObjsMatcher( + Table({"Name": [], "RA": [], "mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Test that a table with no mjd_mid column raises a ValueError + with self.assertRaises(ValueError): + KnownObjsMatcher( + Table({"Name": [], "RA": [], "DEC": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Test that a table with all columns specified does not raise an error + correct = KnownObjsMatcher( + Table({"Name": [], "RA": [], "DEC": [], "mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + self.assertEqual(0, len(correct)) + + # Test a table where we override the names for each column + self.assertEqual( + 0, + len( + KnownObjsMatcher( + Table({"my_Name": [], "my_RA": [], "my_DEC": [], "my_mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + mjd_col="my_mjd_mid", + ra_col="my_RA", + dec_col="my_DEC", + name_col="my_Name", + ) + ), + ) + + def generate_known_obj_from_result( + self, + known_obj_table, + res_idx, + obstimes, + name, + spatial_offset=0.0001, + time_offset=0.00025, + ): + """Helper function to generate a known object based on an existing result trajectory""" + trj_skycoords = trajectory_predict_skypos( + self.res.make_trajectory_list()[res_idx], + self.wcs, + obstimes, + ) + for i in range(len(obstimes)): + known_obj_table.add_row( + { + "Name": name, + "RA": trj_skycoords[i].ra.degree + spatial_offset, + "DEC": trj_skycoords[i].dec.degree + spatial_offset, + "mjd_mid": obstimes[i] + time_offset, + } + ) + + def test_known_objs_match_empty(self): + # Here we test the filter across various empty parameters + + # Test that the filter is not applied when no known objects were provided + empty_objs = KnownObjsMatcher( + Table({"Name": np.empty(0, dtype=str), "RA": [], "DEC": [], "mjd_mid": []}), + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + self.res = empty_objs.match( + self.res, + self.wcs, + ) + # Though there were no known objects, check that the results table still has rows + self.assertEqual(10, len(self.res)) + # We should still apply the matching column to the results table even if empty + matches = self.res[empty_objs.matcher_name] + self.assertEqual(0, sum([len(m.keys()) for m in matches])) + + # Test that we can apply the filter even when there are known 
results + self.res = empty_objs.mark_matched_obs_invalid(self.res, drop_empty_rows=True) + self.assertEqual(10, len(self.res)) + + # Test that the filter is not applied when there were no results. + empty_res = Results() + empty_res = empty_objs.match( + empty_res, + self.wcs, + ) + matches = empty_res[empty_objs.matcher_name] + self.assertEqual(0, sum([len(m.keys()) for m in matches])) + + empty_res = empty_objs.mark_matched_obs_invalid(empty_res, drop_empty_rows=True) + self.assertEqual(0, len(empty_res)) + + def test_match(self): + # We expect to find only the objects close in time and space to our results, + # including one object matching closely to a result across all observations + # and also a sparsely represented object with only a few observations. + expected_matches = set(["spatial_close_time_close_1", "sparse_8"]) + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Generate matches for the results according to the known objects + self.res = matcher.match( + self.res, + self.wcs, + ) + matches = self.res[self.matcher_name] + # Assert the expected result + obs_matches = set() + for m in matches: + obs_matches.update(m.keys()) + self.assertEqual(expected_matches, obs_matches) + + # Check that the close known object we inserted near result 1 is dropped + # But the sparsely observed known object will not get filtered out. + self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=True) + self.assertEqual(9, len(self.res)) + + # Check that the close known object we inserted near result 1 is present + self.assertEqual(len(matches[1]), 1) + self.assertTrue("spatial_close_time_close_1" in matches[1]) + + self.assertEqual(len(matches[8]), 1) + self.assertTrue("sparse_8" in matches[8]) + + # Check that no results other than results 1 and 8 have a match + for i in range(len(self.res)): + if i != 1 and i != 8: + self.assertEqual(0, len(matches[i])) + + def test_match_excessive_spatial_filtering(self): + # Here we only filter for exact spatial matches and should return no results + self.sep_thresh = 0.0 + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + self.res = matcher.match( + self.res, + self.wcs, + ) + matches = self.res[matcher.matcher_name] + self.assertEqual(0, sum([len(m.keys()) for m in matches])) + + self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=True) + self.assertEqual(10, len(self.res)) + + def test_match_spatial_filtering(self): + # Here we use a filter that only matches spatially with an unreasonably generous time filter + self.time_thresh_s += 2 + # Our expected matches now include all objects that are close in space to our results regardless + # of the time offset we generated. + expected_matches = set(["spatial_close_time_close_1", "spatial_close_time_far_3", "sparse_8"]) + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Performing matching + self.res = matcher.match( + self.res, + self.wcs, + ) + matches = self.res[matcher.matcher_name] + + # Confirm that the expected matches are present + obs_matches = set() + for m in matches: + obs_matches.update(m.keys()) + self.assertEqual(expected_matches, obs_matches) + + # Check that the close known objects we inserted are removed by valid obs filtering + # while the sparse known object does not fully filter out that result. 
+ self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=True) + self.assertEqual(8, len(self.res)) + + # Check that the close known object we inserted near result 1 is present + self.assertEqual(1, len(matches[1])) + self.assertTrue("spatial_close_time_close_1" in matches[1]) + self.assertEqual( + np.count_nonzero(self.obs_valid[1]), + np.count_nonzero(matches[1]["spatial_close_time_close_1"]), + ) + + # Check that the close known object we inserted near result 3 is present + self.assertEqual(1, len(matches[3])) + self.assertTrue("spatial_close_time_far_3" in matches[3]) + self.assertEqual( + np.count_nonzero(self.obs_valid[3]), + np.count_nonzero(matches[3]["spatial_close_time_far_3"]), + ) + + # Check that the sparse known object we inserted near result 8 is present + self.assertEqual(1, len(matches[8])) + self.assertTrue("sparse_8" in matches[8]) + self.assertGreaterEqual( + len(self.known_objs[self.known_objs["Name"] == "sparse_8"]), + np.count_nonzero(matches[8]["sparse_8"]), + ) + + # Check that no results other than results 1 and 3 are full matches + # Since these are based off of random trajectories we can't guarantee there + # won't some overlapping observations. + for i in range(len(self.res)): + if i not in [1, 3]: + for obj_name in matches[i]: + self.assertGreater( + np.count_nonzero(self.obs_valid[i]), + np.count_nonzero(matches[i][obj_name]), + ) + + def test_match_temporal_filtering(self): + # Here we use a filter that only matches temporally with an unreasonably generous spatial filter + self.sep_thresh = 100000 + expected_matches = set(["spatial_close_time_close_1", "spatial_far_time_close_5", "sparse_8"]) + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + + # Generate matches + self.res = matcher.match( + self.res, + self.wcs, + ) + matches = self.res[matcher.matcher_name] + + # Confirm that the expected matches are present + obs_matches = set() + for m in matches: + obs_matches.update(m.keys()) + self.assertEqual(expected_matches, obs_matches) + + # Because we have objects that match to each observation temporally, + # a generous spatial filter will filter out all valid observations. + self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=True) + self.assertEqual(0, len(self.res)) + + for i in range(len(matches)): + self.assertEqual(expected_matches, set(matches[i].keys())) + # Check that all observations were matched to the known objects + for obj_name in matches[i]: + if obj_name == "sparse_8": + # The sparse object only has a few observations to match + self.assertGreaterEqual( + len(self.known_objs[self.known_objs["Name"] == "sparse_8"]), + np.count_nonzero(matches[i]["sparse_8"]), + ) + else: + # The other objects have a full set of observations to match + self.assertEqual( + np.count_nonzero(self.obs_valid[i]), + np.count_nonzero(matches[i][obj_name]), + ) + + def test_match_all(self): + # Here we use generous temporal and spatial filters to recover all objects + self.sep_thresh = 100000 + self.time_thresh_s = 1000000 + expected_matches = set( + [ + "spatial_close_time_close_1", + "spatial_close_time_far_3", + "spatial_far_time_close_5", + "spatial_far_time_far_7", + "sparse_8", + ] + ) + # Perform the matching + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + self.sep_thresh, + self.time_thresh_s, + ) + self.res = matcher.match( + self.res, + self.wcs, + ) + + # Here we expect to recover all of our known objects. 
+ matches = self.res[matcher.matcher_name] + obs_matches = set() + for m in matches: + obs_matches.update(m.keys()) + self.assertEqual(expected_matches, obs_matches) + + # Each result should have matched to every object + self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=True) + self.assertEqual(0, len(self.res)) + + # Check that every result matches to all of expected known objects + for i in range(len(matches)): + self.assertEqual(expected_matches, set(matches[i].keys())) + # Check that all observations were matched to the known objects since + # ven the most sparse object should match to every observation with + # our time filter. + for obj_name in matches[i]: + self.assertEqual( + np.count_nonzero(self.obs_valid[i]), + np.count_nonzero(matches[i][obj_name]), + ) + + def test_match_obs_ratio_invalid(self): + # Here we test that we raise an error for observation ratios outside of the valid range + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + ) + self.res = matcher.match(self.res, self.wcs) + + # Test some inavlid ratios outside of the range [0, 1] + with self.assertRaises(ValueError): + matcher.match_on_obs_ratio(self.res, 1.1) + with self.assertRaises(ValueError): + matcher.match_on_obs_ratio(self.res, -0.1) + + def test_match_obs_ratio(self): + # Here we test considering a known object recovered based on the ratio of observations + # in the catalog that were temporally within + min_obs_ratios = [ + 0.0, + 1.0, + ] + # The expected matching objects for each min_obs_ratio parameter chosen. + expected_matches = [ + set([]), + set(["spatial_close_time_close_1", "sparse_8"]), + ] + orig_res = self.res.table.copy() + for obs_ratio, expected in zip(min_obs_ratios, expected_matches): + self.res = Results(data=orig_res.copy()) + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + matcher_name=self.matcher_name, + sep_thresh=self.sep_thresh, + time_thresh_s=self.time_thresh_s, + ) + + # Perform the intial matching + self.res = matcher.match( + self.res, + self.wcs, + ) + + # Validate that we did not filter any results by obstimes + assert self.matcher_name in self.res.table.columns + self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=False) + self.assertEqual(10, len(self.res)) + + # Generate the column of which objects were "recovered" + matcher.match_on_obs_ratio(self.res, obs_ratio) + match_col = f"recovered_test_matches_obs_ratio_{obs_ratio}" + assert match_col in self.res.table.columns + assert match_col == matcher.match_obs_ratio_col(obs_ratio) + + # Verify that we recovered the expected matches + recovered, missed = matcher.get_recovered_objects( + self.res, matcher.match_obs_ratio_col(obs_ratio) + ) + self.assertEqual(expected, recovered) + # The missed object are all other known objects in our catalog - the expected objects + expected_missed = set(self.known_objs["Name"]) - expected + self.assertEqual(expected_missed, missed) + + # Verify that we filter out our expected results + matcher.filter_matches(self.res, match_col) + self.assertEqual(10 - len(expected), len(self.res)) + + def test_match_min_obs(self): + # Here we test considering a known object recovered based on the ratio of observations + # in the catalog that were temporally within + min_obs_settings = [ + 100, # No objects should be recovered since our catalog objects have fewer observations + 1, + 5, # The sparse object will not have enough observations to be recovered. 
+ ] + expected_matches = [ + set([]), + set(["spatial_close_time_close_1", "sparse_8"]), + set(["spatial_close_time_close_1"]), + ] + orig_res = self.res.table.copy() + for min_obs, expected in zip(min_obs_settings, expected_matches): + self.res = Results(data=orig_res.copy()) + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + matcher_name=self.matcher_name, + sep_thresh=self.sep_thresh, + time_thresh_s=self.time_thresh_s, + ) + # Perform the initial matching + matcher.match( + self.res, + self.wcs, + ) + # Validate that we did not filter any results + assert self.matcher_name in self.res.table.columns + self.res = matcher.mark_matched_obs_invalid(self.res, drop_empty_rows=False) + self.assertEqual(10, len(self.res)) + + # Generate the recovered object column for a minimum number of observations + matcher.match_on_min_obs(self.res, min_obs) + match_col = f"recovered_test_matches_min_obs_{min_obs}" + assert match_col in self.res.table.columns + assert match_col == matcher.match_min_obs_col(min_obs) + + # Verify that we recovered the expected matches + recovered, missed = matcher.get_recovered_objects(self.res, matcher.match_min_obs_col(min_obs)) + self.assertEqual(expected, recovered) + # The missed object are all other known objects in our catalog - the expected objects + expected_missed = set(self.known_objs["Name"]) - expected + self.assertEqual(expected_missed, missed) + + # Verify that we filter out our expected results + matcher.filter_matches(self.res, match_col) + self.assertEqual(10 - len(expected), len(self.res)) + + def test_empty_filter_matches(self): + # Test that we can filter matches with an empty Results table + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + ) + # Adds a matching column to our empty table. + empty_res = matcher.match_on_obs_ratio(Results(), 0.5) + with self.assertRaises(ValueError): + # Test an inavlid matching column + matcher.filter_matches(empty_res, "empty") + + empty_res = matcher.filter_matches(empty_res, matcher.match_obs_ratio_col(0.5)) + self.assertEqual(0, len(empty_res)) + + def test_empty_get_recovered_objects(self): + # Test that we can get recovered objects with an empty Results table + matcher = KnownObjsMatcher( + self.known_objs, + self.obstimes, + self.matcher_name, + ) + # Adds a matching column to our empty table. 
+ empty_res = matcher.match_on_min_obs(Results(), 5) + with self.assertRaises(ValueError): + # Test an inavlid matching column + matcher.get_recovered_objects(empty_res, "empty") + + recovered, missed = matcher.get_recovered_objects(empty_res, matcher.match_min_obs_col(5)) + self.assertEqual(0, len(recovered)) + self.assertEqual(0, len(missed)) From b59fc2fe90a9783fb310ac78cb0213383330232e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:15:31 -0500 Subject: [PATCH 13/15] Add a notebook for joining results and known objects --- notebooks/join_known_objects_example.ipynb | 166 +++++++++++++++++++++ src/kbmod/filters/known_object_filters.py | 30 ++-- src/kbmod/results.py | 32 ++++ tests/test_known_object_filters.py | 32 +++- tests/test_results.py | 27 ++++ 5 files changed, 272 insertions(+), 15 deletions(-) create mode 100644 notebooks/join_known_objects_example.ipynb diff --git a/notebooks/join_known_objects_example.ipynb b/notebooks/join_known_objects_example.ipynb new file mode 100644 index 000000000..bd1170beb --- /dev/null +++ b/notebooks/join_known_objects_example.ipynb @@ -0,0 +1,166 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "a89ad362-6ed6-489f-806c-2fd94fc9356d", + "metadata": {}, + "source": [ + "# Example Known Object Labeling\n", + "\n", + "This notebook serves as an example (and usable tool) for labeling objects in the results file as corresponding to a known object. It assumes the user has run KBMOD to produce a results .ecsv and has access to a table of known results.\n", + "\n", + "This notebook uses specific files and parameters from the DEEP reprocessing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b1d47da-db6f-4530-93a2-9cb83a6af145", + "metadata": {}, + "outputs": [], + "source": [ + "from astropy.table import Table\n", + "import numpy as np\n", + "\n", + "from kbmod.filters.known_object_filters import KnownObjsMatcher\n", + "from kbmod.results import Results" + ] + }, + { + "cell_type": "markdown", + "id": "43b6aaa7-9bf3-4ed7-8659-8daea4ed4765", + "metadata": {}, + "source": [ + "We start by loading the results data and known object data. The results data is the ecsv file produced by a KBMOD run and contains information on each trajectory found. The known object table is a given file with information on the location (RA, dec) of each observation at different time steps.\n", + "\n", + "We also extract the required metadata (global WCS and a list of all observation times) from the results." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "6ce8dcad-fa91-4267-b2cf-7c787ac0d665",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# These two files are specific to the UW DEEP reprocessing runs and should be replaced\n",
+    "# by the user's files of interest.\n",
+    "res_file = \"/epyc/projects/kbmod/runs/DEEP/results/20190402_A0b_001.results.ecsv\"\n",
+    "known_file = \"/epyc/projects/kbmod/data/fakes_detections_joined.fits\"\n",
+    "\n",
+    "in_results = Results.read_table(res_file)\n",
+    "print(f\"Loaded a results table with {len(in_results)} entries and columns:\\n{in_results.colnames}\")\n",
+    "\n",
+    "wcs = in_results.wcs\n",
+    "if wcs is None:\n",
+    "    raise ValueError(\"WCS missing from results file.\")\n",
+    "\n",
+    "if \"mjd_mid\" in in_results.table.meta:\n",
+    "    obstimes = np.array(in_results.table.meta[\"mjd_mid\"])\n",
+    "else:\n",
+    "    raise ValueError(\"Metadata 'mjd_mid' missing from results file.\")\n",
+    "print(f\"Loaded {len(obstimes)} timestamps.\")\n",
+    "\n",
+    "known_table = Table.read(known_file)\n",
+    "print(f\"\\n\\nLoaded a known objects table with {len(known_table)} entries and columns:\\n{known_table.colnames}\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c06cc1e2-2186-4418-a746-1bbe5e7bc971",
+   "metadata": {},
+   "source": [
+    "We use the `KnownObjsMatcher` to determine which of the found results correspond to previously known objects. `KnownObjsMatcher` provides the ability to match by either the number or ratio of observations that are in close proximity to the known object. Here we use a minimum number with reasonable proximity thresholds in space and time."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "19eeeb00-b7f2-4969-8737-185f9161b34a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "min_obs = 10\n",
+    "fake_matcher = KnownObjsMatcher(\n",
+    "    known_table,\n",
+    "    obstimes,\n",
+    "    matcher_name=\"known_matcher\",\n",
+    "    sep_thresh=2.0,  # Obs must be within 2 arcsecs.\n",
+    "    time_thresh_s=600.0,  # Obs must be within 10 minutes.\n",
+    "    name_col=\"ORBITID\",  # For the DEEP-data known objects only.\n",
+    ")\n",
+    "\n",
+    "# First create the matches column.\n",
+    "fake_matcher.match(in_results, wcs)\n",
+    "\n",
+    "# Second filter the matches.\n",
+    "fake_matcher.match_on_min_obs(in_results, min_obs)\n",
+    "\n",
+    "matched_col_name = fake_matcher.match_min_obs_col(min_obs)\n",
+    "print(f\"Matches stored in column '{matched_col_name}'\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2d8faadb-a496-4798-a0df-27712e4db8cd",
+   "metadata": {},
+   "source": [
+    "Iterate over the matched column, computing a Boolean of whether there was any match (True if the match entry is not empty), and add the resulting values as a new \"is_known\" column."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b6490988-242e-4a0f-b355-fcb1804d118d", + "metadata": {}, + "outputs": [], + "source": [ + "is_known = ~in_results.is_empty_value(matched_col_name)\n", + "in_results.table[\"is_known\"] = is_known\n", + "matched_count = np.count_nonzero(is_known)\n", + "\n", + "print(f\"Found {matched_count} of the {len(in_results)} results matched known objects.\")" + ] + }, + { + "cell_type": "markdown", + "id": "c4c79978-ed82-4261-abd4-14e745efb09a", + "metadata": {}, + "source": [ + "We could save the resulting joined table using:\n", + "```\n", + "in_results.write_table(output_filename)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56aadd08-aaed-4de5-a9eb-77ecafef1e55", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Jeremy's KBMOD", + "language": "python", + "name": "kbmod_jk" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/kbmod/filters/known_object_filters.py b/src/kbmod/filters/known_object_filters.py index 6c6cdc6ec..a3fa20f8a 100644 --- a/src/kbmod/filters/known_object_filters.py +++ b/src/kbmod/filters/known_object_filters.py @@ -22,9 +22,9 @@ class KnownObjsMatcher: In addition to modifying a KBMOD `Results` table to include columns for matched known objects, it also provides methods for filtering the results based on the matches. This includes - marking observations that matched to known objects as invalid, and filtering out results that matched to known objects by - either the minimum number of observations that matched to that known object or the proportion - of observations from the catalog for that known object that were matched to a given result. + marking observations that matched to known objects as invalid, and filtering out results that matched to + known objects by either the minimum number of observations that matched to that known object or the + proportion of observations from the catalog for that known object that were matched to a given result. """ def __init__( @@ -152,12 +152,13 @@ def to_skycoords(self): def match(self, result_data, wcs): """This function takes a list of results and matches them to known objects. - This modifies the `Results` table by adding a column with name `self.matcher_name` that provides for each result a dictionary mapping the names of known - objects (as defined by the catalog's `name_col`) to a boolean array indicating which observations - in the result matched to that known object. Note that depending on the matching parameters, a result - can match to multiple known objects from the catalog even at the same observation time. + This modifies the `Results` table by adding a column with name `self.matcher_name` that provides for each + result a dictionary mapping the names of known objects (as defined by the catalog's `name_col`) to a boolean + array indicating which observations in the result matched to that known object. Note that depending on the + matching parameters, a result can match to multiple known objects from the catalog even at the same observation time. 
- So for a dataset with 5 observations a result matching to 2 known objects, A and B, might have an entry in the column `self.matcher_name` like: + So for a dataset with 5 observations a result matching to 2 known objects, A and B, might have an entry in + the column `self.matcher_name` like: ```{ "A": [True, True, False, False, False], "B": [False, False, False, True, True], @@ -192,7 +193,8 @@ def match(self, result_data, wcs): # can use this to later map back to the original index of all observations in the stack. trj_idx_to_obs_idx = np.where(result_data[result_idx]["obs_valid"])[0] - # Now we can compare the SkyCoords of the known objects to the SkyCoords of the result trajectories using search_around_sky + # Now we can compare the SkyCoords of the known objects to the SkyCoords of the result trajectories + # using search_around_sky. # This will return a list of indices of known objects that are within sep_thresh of a trajectory # Note that subsequent calls by default will use the same underlying KD-Tree iin coords2.cache. trjs_idx, known_objs_idx, _, _ = search_around_sky( @@ -297,6 +299,11 @@ def match_on_min_obs( `Results` The modified `Results` object returned for chaining. """ + if self.matcher_name not in result_data.table.colnames: + raise ValueError( + f"Column {self.matcher_name} not found in results table. Please run match() first." + ) + matched_objs = [] for idx in range(len(result_data)): matched_objs.append(set([])) @@ -341,6 +348,11 @@ def match_on_obs_ratio( if obs_ratio < 0 or obs_ratio > 1: raise ValueError("obs_ratio must be within the range [0, 1].") + if self.matcher_name not in result_data.table.colnames: + raise ValueError( + f"Column {self.matcher_name} not found in results table. Please run match() first." + ) + # Create a dictionary of how many observations we have for each known object # in our catalog known_obj_cnts = dict(Counter(self.data[self.name_col])) diff --git a/src/kbmod/results.py b/src/kbmod/results.py index 9cd79c57f..15b6ebac6 100644 --- a/src/kbmod/results.py +++ b/src/kbmod/results.py @@ -480,6 +480,38 @@ def mask_based_on_invalid_obs(self, input_mat, mask_value): masked_mat[~self.table["obs_valid"]] = mask_value return masked_mat + def is_empty_value(self, colname): + """Create a Boolean vector indicating whether the entry in each row + is an 'empty' value (None or anything of length 0). Used to mark or + filter missing values. + + Parameter + --------- + colname : str + The name of the column to check. + + Returns + ------- + result : `numpy.ndarray` + An array of Boolean values indicating whether the result is + one of the empty values. + """ + if colname not in self.table.colnames: + raise KeyError(f"Querying unknown column {colname}") + + # Skip numeric types (integers, floats, etc.) + result = np.full(len(self.table), False) + if np.issubdtype(self.table[colname].dtype, np.number): + return result + + # Go through each entry and check whether it is None or something of length=0. 
+        for idx, val in enumerate(self.table[colname]):
+            if val is None:
+                result[idx] = True
+            elif hasattr(val, "__len__") and len(val) == 0:
+                result[idx] = True
+        return result
+
     def filter_rows(self, rows, label=""):
         """Filter the rows in the `Results` to only include those
         indices that are provided in a list of row indices (integers) or marked
diff --git a/tests/test_known_object_filters.py b/tests/test_known_object_filters.py
index a6d2e7b11..c19ad8029 100644
--- a/tests/test_known_object_filters.py
+++ b/tests/test_known_object_filters.py
@@ -595,15 +595,23 @@ def test_match_min_obs(self):
 
     def test_empty_filter_matches(self):
         # Test that we can filter matches with an empty Results table
+        empty_res = Results()
         matcher = KnownObjsMatcher(
             self.known_objs,
             self.obstimes,
             self.matcher_name,
         )
-        # Adds a matching column to our empty table.
-        empty_res = matcher.match_on_obs_ratio(Results(), 0.5)
+
+        # No matcher_name column in the data.
+        with self.assertRaises(ValueError):
+            _ = matcher.match_on_obs_ratio(empty_res, 0.5)
+
+        # Do the match to add the columns.
+        matcher.match(empty_res, self.wcs)
+        empty_res = matcher.match_on_obs_ratio(empty_res, 0.5)
+
+        # Test an invalid matching column
         with self.assertRaises(ValueError):
-            # Test an invalid matching column
             matcher.filter_matches(empty_res, "empty")
 
         empty_res = matcher.filter_matches(empty_res, matcher.match_obs_ratio_col(0.5))
@@ -611,17 +619,29 @@ def test_empty_filter_matches(self):
 
     def test_empty_get_recovered_objects(self):
         # Test that we can get recovered objects with an empty Results table
+        empty_res = Results()
         matcher = KnownObjsMatcher(
             self.known_objs,
             self.obstimes,
             self.matcher_name,
         )
-        # Adds a matching column to our empty table.
-        empty_res = matcher.match_on_min_obs(Results(), 5)
+
+        # No matcher_name column in the data.
+        with self.assertRaises(ValueError):
+            _ = matcher.match_on_min_obs(empty_res, 5)
+
+        # Do the match to add the columns.
+        matcher.match(empty_res, self.wcs)
+        empty_res = matcher.match_on_min_obs(empty_res, 5)
+
+        # Test an invalid matching column
         with self.assertRaises(ValueError):
-            # Test an invalid matching column
             matcher.get_recovered_objects(empty_res, "empty")
 
         recovered, missed = matcher.get_recovered_objects(empty_res, matcher.match_min_obs_col(5))
         self.assertEqual(0, len(recovered))
         self.assertEqual(0, len(missed))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_results.py b/tests/test_results.py
index 84664e277..56bc73d73 100644
--- a/tests/test_results.py
+++ b/tests/test_results.py
@@ -251,6 +251,33 @@ def test_compute_likelihood_curves(self):
         )
         self.assertTrue(np.array_equal(np.isfinite(lh_mat3), expected))
 
+    def test_is_empty_value(self):
+        table = Results.from_trajectories(self.trj_list)
+
+        # Create two new columns: one with integers and the other with meaningless
+        # index pairs (three of which are empty)
+        nums_col = [i for i in range(len(table))]
+        table.table["nums"] = nums_col
+
+        pairs_col = [(i, i + 1) for i in range(len(table))]
+        pairs_col[1] = None
+        pairs_col[3] = ()
+        pairs_col[7] = ()
+        table.table["pairs"] = pairs_col
+
+        expected = [False] * len(table)
+        expected[1] = True
+        expected[3] = True
+        expected[7] = True
+
+        # Check that we can tell which entries are empty.
+ nums_is_empty = table.is_empty_value("nums") + self.assertFalse(np.any(nums_is_empty)) + + pairs_is_empty = table.is_empty_value("pairs") + print(pairs_is_empty) + self.assertTrue(np.array_equal(pairs_is_empty, expected)) + def test_filter_by_index(self): table = Results.from_trajectories(self.trj_list) self.assertEqual(len(table), self.num_entries) From dca67575a0bf38053b614403ce48dbb016998de7 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 22 Nov 2024 12:19:26 -0500 Subject: [PATCH 14/15] Fix linting error --- notebooks/join_known_objects_example.ipynb | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/notebooks/join_known_objects_example.ipynb b/notebooks/join_known_objects_example.ipynb index bd1170beb..f73f639fb 100644 --- a/notebooks/join_known_objects_example.ipynb +++ b/notebooks/join_known_objects_example.ipynb @@ -62,7 +62,10 @@ "print(f\"Loaded {len(obstimes)} timestamps.\")\n", "\n", "known_table = Table.read(known_file)\n", - "print(f\"\\n\\nLoaded a known objects table with {len(known_table)} entries and columns:\\n{known_table.colnames}\")" + "print(\n", + " f\"\\n\\nLoaded a known objects table with {len(known_table)} entries \"\n", + " f\"and columns:\\n{known_table.colnames}\"\n", + ")" ] }, { From bd5bb4b1ccfb7c65d1251dc13d5fbc6d8c50717d Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Mon, 2 Dec 2024 12:39:41 -0800 Subject: [PATCH 15/15] Only run test_core_search_exact with GPUs (#747) --- tests/test_core_search_exact.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_core_search_exact.py b/tests/test_core_search_exact.py index f36f4dc52..fce4bb4c0 100644 --- a/tests/test_core_search_exact.py +++ b/tests/test_core_search_exact.py @@ -11,6 +11,7 @@ from kbmod.trajectory_generator import VelocityGridSearch +@unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") class test_search_exact(unittest.TestCase): def test_core_search_exact(self): # image properties