Skip to content

Commit

Permalink
Merge pull request #742 from dirac-institute/save_metadata
Browse files Browse the repository at this point in the history
Save metadata
  • Loading branch information
jeremykubica authored Nov 18, 2024
2 parents f02f9ac + 603d442 commit f910a3c
Show file tree
Hide file tree
Showing 10 changed files with 403 additions and 305 deletions.
16 changes: 13 additions & 3 deletions src/kbmod/image_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,11 +884,21 @@ def toWorkUnit(self, search_config=None, **kwargs):
from .work_unit import WorkUnit

logger.info("Building WorkUnit from ImageCollection")

# Extract data from each standardizer and each LayeredImage within
# that standardizer.
layeredImages = []
for std in self.get_standardizers(**kwargs):
for img in std["std"].toLayeredImage():
layeredImages.append(img)
imgstack = ImageStack(layeredImages)

# Extract all of the relevant metadata from the ImageCollection.
metadata = Table(self.toBinTableHDU().data)
if None not in self.wcs:
return WorkUnit(imgstack, search_config, per_image_wcs=list(self.wcs))
return WorkUnit(imgstack, search_config)
metadata["per_image_wcs"] = list(self.wcs)

# Create the basic WorkUnit from the ImageStack.
imgstack = ImageStack(layeredImages)
work = WorkUnit(imgstack, search_config, org_image_meta=metadata)

return work
31 changes: 18 additions & 13 deletions src/kbmod/reprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ def reproject_work_unit(
A `kbmod.WorkUnit` reprojected with a common `astropy.wcs.WCS`, or `None` in the case
where `write_output` is set to True.
"""
if work_unit.reprojected:
raise ValueError("Unable to reproject a reprojected WorkUnit.")

show_progress = is_interactive() if show_progress is None else show_progress
if (work_unit.lazy or write_output) and (directory is None or filename is None):
raise ValueError("can't write output to sharded fits without directory and filename provided.")
Expand Down Expand Up @@ -195,7 +198,7 @@ def _reproject_work_unit(

# Create a list of the correct WCS. We do this extraction once and reuse for all images.
if frame == "original":
wcs_list = work_unit.get_constituent_meta("original_wcs")
wcs_list = work_unit.get_constituent_meta("per_image_wcs")
elif frame == "ebd":
wcs_list = work_unit.get_constituent_meta("ebd_wcs")
else:
Expand Down Expand Up @@ -283,24 +286,25 @@ def _reproject_work_unit(
stack.append_image(new_layered_image, force_move=True)

if write_output:
# Create a copy of the WorkUnit to write the global metadata.
# We preserve the metgadata for the consituent images.
new_work_unit = copy(work_unit)
new_work_unit._per_image_indices = unique_obstime_indices
new_work_unit.wcs = common_wcs
new_work_unit.reprojected = True

hdul = new_work_unit.metadata_to_primary_header()
hdul = new_work_unit.metadata_to_hdul()
hdul.writeto(os.path.join(directory, filename))
else:
# Create a new WorkUnit with the new ImageStack and global WCS.
# We preserve the metgadata for the consituent images.
new_wunit = WorkUnit(
im_stack=stack,
config=work_unit.config,
wcs=common_wcs,
constituent_images=work_unit.get_constituent_meta("data_loc"),
per_image_wcs=work_unit._per_image_wcs,
per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"),
per_image_indices=unique_obstime_indices,
geocentric_distances=work_unit.get_constituent_meta("geocentric_distance"),
reprojected=True,
org_image_meta=work_unit.org_img_meta,
)

return new_wunit
Expand Down Expand Up @@ -425,7 +429,7 @@ def _reproject_work_unit_in_parallel(
new_work_unit.wcs = common_wcs
new_work_unit.reprojected = True

hdul = new_work_unit.metadata_to_primary_header()
hdul = new_work_unit.metadata_to_hdul()
hdul.writeto(os.path.join(directory, filename))
else:
stack = ImageStack([])
Expand All @@ -446,17 +450,15 @@ def _reproject_work_unit_in_parallel(
# sort by the time_stamp
stack.sort_by_time()

# Add the imageStack to a new WorkUnit and return it.
# Add the imageStack to a new WorkUnit and return it. We preserve the metgadata
# for the consituent images.
new_wunit = WorkUnit(
im_stack=stack,
config=work_unit.config,
wcs=common_wcs,
constituent_images=work_unit.get_constituent_meta("data_loc"),
per_image_wcs=work_unit._per_image_wcs,
per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"),
per_image_indices=unique_obstimes_indices,
geocentric_distances=work_unit.get_constituent_meta("geocentric_distances"),
reprojected=True,
org_image_meta=work_unit.org_img_meta,
)

return new_wunit
Expand Down Expand Up @@ -549,12 +551,13 @@ def reproject_lazy_work_unit(
if not result.result():
raise RuntimeError("one or more jobs failed.")

# We use new metadata for the new images and the same metadata for the original images.
new_work_unit = copy(work_unit)
new_work_unit._per_image_indices = unique_obstimes_indices
new_work_unit.wcs = common_wcs
new_work_unit.reprojected = True

hdul = new_work_unit.metadata_to_primary_header()
hdul = new_work_unit.metadata_to_hdul()
hdul.writeto(os.path.join(directory, filename))


Expand Down Expand Up @@ -591,6 +594,8 @@ def _validate_original_wcs(work_unit, indices, frame="original"):
else:
raise ValueError("Invalid projection frame provided.")

if len(original_wcs) == 0:
raise ValueError(f"No WCS found for frame {frame}")
if np.any(original_wcs) is None:
# find indices where the wcs is None
bad_indices = np.where(original_wcs == None)
Expand Down
6 changes: 4 additions & 2 deletions src/kbmod/standardizers/fits_standardizers/kbmodv05.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,15 @@ def __init__(self, location=None, hdulist=None, config=None, **kwargs):
]

def translateHeader(self):
"""Returns the following metadata, read from the primary header, as a
dictionary:
"""Returns at least the following metadata, read from the primary header,
as a dictionary:
======== ========== ===================================================
Key Header Key Description
======== ========== ===================================================
mjd DATE-AVG Decimal MJD timestamp of the middle of the exposure
filter FILTER Filter band
visit EXPID Exposure ID
======== ========== ===================================================
"""
# this is the 1 mandatory piece of metadata we need to extract
Expand All @@ -159,6 +160,7 @@ def translateHeader(self):

# these are all optional things
standardizedHeader["filter"] = self.primary["FILTER"]
standardizedHeader["visit"] = self.primary["EXPID"]

# If no observatory information is given, default to the Deccam data
# (Cerro Tololo Inter-American Observatory).
Expand Down
6 changes: 4 additions & 2 deletions src/kbmod/standardizers/fits_standardizers/kbmodv1.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ def translateHeader(self):
Key Header Key Description
======== ========== ===================================================
mjd_mid DATE-AVG Decimal MJD timestamp of the middle of the exposure
filter FILTER Filter band
visit_id IDNUM Visit ID
FILTER FILTER Filter band
visit EXPID Exposure ID
IDNUM IDNUM Visit ID
observat OBSERVAT Observatory name
obs_lat OBS-LAT Observatory Latitude
obs_lon OBS-LONG Observatory Longitude
Expand All @@ -164,6 +165,7 @@ def translateHeader(self):
# these are all optional things
standardizedHeader["FILTER"] = self.primary["FILTER"]
standardizedHeader["IDNUM"] = self.primary["IDNUM"]
standardizedHeader["visit"] = self.primary["EXPID"]
standardizedHeader["OBSID"] = self.primary["OBSID"]
standardizedHeader["DTNSANAM"] = self.primary["DTNSANAM"]
standardizedHeader["AIRMASS"] = self.primary["AIRMASS"]
Expand Down
14 changes: 10 additions & 4 deletions src/kbmod/wcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,14 +450,17 @@ def serialize_wcs(wcs):
Parameters
----------
wcs : `astropy.wcs.WCS`
wcs : `astropy.wcs.WCS` or None
The WCS to convert.
Returns
-------
wcs_str : `str`
The serialized WCS.
The serialized WCS. Returns an empty string if wcs is None.
"""
if wcs is None:
return ""

# Since AstroPy's WCS does not output NAXIS, we need to manually add those.
header = wcs.to_header(relax=True)
header["NAXIS1"], header["NAXIS2"] = wcs.pixel_shape
Expand All @@ -474,9 +477,12 @@ def deserialize_wcs(wcs_str):
Returns
-------
wcs : `astropy.wcs.WCS`
The resulting WCS.
wcs : `astropy.wcs.WCS` or None
The resulting WCS or None if no data is provided.
"""
if wcs_str == "" or wcs_str.lower() == "none":
return None

wcs_dict = json.loads(wcs_str)
wcs = astropy.wcs.WCS(wcs_dict)
wcs.pixel_shape = (wcs_dict["NAXIS1"], wcs_dict["NAXIS2"])
Expand Down
Loading

0 comments on commit f910a3c

Please sign in to comment.