Skip to content

Commit

Permalink
Store more meta-data in the meta-datatable
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Nov 7, 2024
1 parent 679e4eb commit fd7e937
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 71 deletions.
55 changes: 26 additions & 29 deletions src/kbmod/reprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,8 @@ def reproject_work_unit(
The WCS to reproject all the images into.
frame : `str`
The WCS frame of reference to use when reprojecting.
Can either be 'original' or 'ebd' to specify whether to
use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs
respectively.
Can either be 'original' or 'ebd' to specify which WCS to access
from the WorkUnit.
parallelize : `bool`
If True, use multiprocessing to reproject the images in parallel.
Default is True.
Expand Down Expand Up @@ -175,9 +174,8 @@ def _reproject_work_unit(
The WCS to reproject all the images into.
frame : `str`
The WCS frame of reference to use when reprojecting.
Can either be 'original' or 'ebd' to specify whether to
use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs
respectively.
Can either be 'original' or 'ebd' to specify which WCS to access
from the WorkUnit.
write_output : `bool`
Whether or not to write the reprojection results out as a sharded `WorkUnit`.
directory : `str`
Expand All @@ -195,6 +193,14 @@ def _reproject_work_unit(
images = work_unit.im_stack.get_images()
unique_obstimes, unique_obstime_indices = work_unit.get_unique_obstimes_and_indices()

# Create a list of the correct WCS. We do this extraction once and reuse for all images.
if frame == "original":
wcs_list = work_unit.get_constituent_meta("original_wcs")
elif frame == "ebd":
wcs_list = work_unit.get_constituent_meta("ebd_wcs")
else:
raise ValueError("Invalid projection frame provided.")

stack = ImageStack()
for obstime_index, o_i in tqdm(
enumerate(zip(unique_obstimes, unique_obstime_indices)),
Expand All @@ -214,13 +220,7 @@ def _reproject_work_unit(
variance = image.get_variance()
mask = image.get_mask()

if frame == "original":
original_wcs = work_unit.get_wcs(index)
elif frame == "ebd":
original_wcs = work_unit._per_image_ebd_wcs[index]
else:
raise ValueError("Invalid projection frame provided.")

original_wcs = frame[index]
if original_wcs is None:
raise ValueError(f"No WCS provided for index {index}")

Expand Down Expand Up @@ -295,11 +295,11 @@ def _reproject_work_unit(
im_stack=stack,
config=work_unit.config,
wcs=common_wcs,
constituent_images=work_unit.constituent_images,
constituent_images=work_unit.get_constituent_meta("data_loc"),
per_image_wcs=work_unit._per_image_wcs,
per_image_ebd_wcs=work_unit._per_image_ebd_wcs,
per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"),
per_image_indices=unique_obstime_indices,
geocentric_distances=work_unit.geocentric_distances,
geocentric_distances=work_unit.get_constituent_meta("geocentric_distance"),
reprojected=True,
)

Expand Down Expand Up @@ -328,9 +328,8 @@ def _reproject_work_unit_in_parallel(
The WCS to reproject all the images into.
frame : `str`
The WCS frame of reference to use when reprojecting.
Can either be 'original' or 'ebd' to specify whether to
use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs
respectively.
Can either be 'original' or 'ebd' to specify which WCS to access
from the WorkUnit.
max_parallel_processes : `int`
The maximum number of parallel processes to use when reprojecting.
Default is 8. For more see `concurrent.futures.ProcessPoolExecutor` in
Expand Down Expand Up @@ -452,11 +451,11 @@ def _reproject_work_unit_in_parallel(
im_stack=stack,
config=work_unit.config,
wcs=common_wcs,
constituent_images=work_unit.constituent_images,
constituent_images=work_unit.get_constituent_meta("data_loc"),
per_image_wcs=work_unit._per_image_wcs,
per_image_ebd_wcs=work_unit._per_image_ebd_wcs,
per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"),
per_image_indices=unique_obstimes_indices,
geocentric_distances=work_unit.geocentric_distances,
geocentric_distances=work_unit.get_constituent_meta("geocentric_distances"),
reprojected=True,
)

Expand Down Expand Up @@ -492,9 +491,8 @@ def reproject_lazy_work_unit(
shards).
frame : `str`
The WCS frame of reference to use when reprojecting.
Can either be 'original' or 'ebd' to specify whether to
use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs
respectively.
Can either be 'original' or 'ebd' to specify which WCS to access
from the WorkUnit.
max_parallel_processes : `int`
The maximum number of parallel processes to use when reprojecting.
Default is 8. For more see `concurrent.futures.ProcessPoolExecutor` in
Expand Down Expand Up @@ -572,9 +570,8 @@ def _validate_original_wcs(work_unit, indices, frame="original"):
The indices to be validated in work_unit.
frame : `str`
The WCS frame of reference to use when reprojecting.
Can either be 'original' or 'ebd' to specify whether to
use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs
respectively.
Can either be 'original' or 'ebd' to specify which WCS to access
from the WorkUnit.
Returns
-------
Expand All @@ -590,7 +587,7 @@ def _validate_original_wcs(work_unit, indices, frame="original"):
if frame == "original":
original_wcs = [work_unit.get_wcs(i) for i in indices]
elif frame == "ebd":
original_wcs = [work_unit._per_image_ebd_wcs[i] for i in indices]
original_wcs = [work_unit.get_constituent_meta("ebd_wcs")[i] for i in indices]
else:
raise ValueError("Invalid projection frame provided.")

Expand Down
80 changes: 40 additions & 40 deletions src/kbmod/work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ class WorkUnit:
The number of original images making up the data in this WorkUnit. This might be
different from the number of images stored in memory if the WorkUnit has been
reprojected.
img_meta : `astropy.table.Table`
The meta data for each constituent image.
org_img_meta : `astropy.table.Table`
The meta data for each constituent image. Includes columns:
* data_loc - the original location of the image
* ebd_wcs - Used to reproject the images into EBD space.
* geocentric_distance - The best fit geocentric distances used when creating
the per image EBD WCS.
* original_wcs - The original WCS of the image.
wcs : `astropy.wcs.WCS`
A global WCS for all images in the WorkUnit. Only exists
if all images have been projected to same pixel space.
per_image_wcs : `list`
A list with one WCS for each image in the WorkUnit. Used for when
the images have *not* been standardized to the same pixel space.
per_image_ebd_wcs : `list`
A list with one WCS for each image in the WorkUnit. Used to reproject the images
into EBD space.
heliocentric_distance : `float`
The heliocentric distance that was used when creating the `per_image_ebd_wcs`.
geocentric_distances : `list`
The best fit geocentric distances used when creating the `per_image_ebd_wcs`.
reprojected : `bool`
Whether or not the WorkUnit image data has been reprojected.
per_image_indices : `list` of `list`
Expand Down Expand Up @@ -136,10 +136,18 @@ def __init__(
self.file_paths = file_paths
self._obstimes = obstimes

# Track the meta data for each constituent image in the WorkUnit.
self.n_constituents = im_stack.img_count() if constituent_images is None else len(constituent_images)
self.img_meta = Table()
self.add_img_meta_data("data_loc", constituent_images)
# Determine the number of constituent images. If we are given a list of constituent_images,
# use that. Otherwise use the size of the image stack.
if constituent_images is None:
self.n_constituents = im_stack.img_count()
else:
self.n_constituents = len(constituent_images)

# Track the meta data for each constituent image in the WorkUnit. For the original
# WCS, we track the per-image WCS if it is provided and otherwise the global WCS.
self.org_img_meta = Table()
self.add_org_img_meta_data("data_loc", constituent_images)
self.add_org_img_meta_data("original_wcs", per_image_wcs, default=wcs)

# Handle WCS input. If both the global and per-image WCS are provided,
# ensure they are consistent.
Expand Down Expand Up @@ -167,19 +175,12 @@ def __init__(
# distances, and each images indices in the original constituent images.
self.reprojected = reprojected
self.heliocentric_distance = heliocentric_distance
self.add_org_img_meta_data("geocentric_distance", geocentric_distances)
self.add_org_img_meta_data("ebd_wcs", per_image_ebd_wcs)

if per_image_ebd_wcs is None:
self._per_image_ebd_wcs = [None] * self.n_constituents
else:
if len(per_image_ebd_wcs) != self.n_constituents:
raise ValueError(f"Incorrect number of EBD WCS provided. Expected {self.n_constituents}")
self._per_image_ebd_wcs = per_image_ebd_wcs

if geocentric_distances is None:
self.geocentric_distances = [None] * self.n_constituents
else:
self.geocentric_distances = geocentric_distances

# If we have mosaicked images, each image in the stack could link back
# to more than one constituents image. Build a mapping of image stack index
# to needed original image indices.
if per_image_indices is None:
self._per_image_indices = [[i] for i in range(self.n_constituents)]
else:
Expand All @@ -189,14 +190,13 @@ def __len__(self):
"""Returns the size of the WorkUnit in number of images."""
return self.im_stack.img_count()

@property
def constituent_images(self):
"""Alias constituent_images to the correct column of image meta data."""
return self.img_meta["data_loc"].data
def get_constituent_meta(self, column):
    """Return the per-constituent-image values stored under the given column.

    Parameters
    ----------
    column : `str`
        The name of the meta data column to extract from ``org_img_meta``.

    Returns
    -------
    `list`
        The column's value for each constituent image, in order.
    """
    values = self.org_img_meta[column].data
    return [entry for entry in values]

def add_img_meta_data(self, column, data):
def add_org_img_meta_data(self, column, data, default=None):
"""Add a column of meta data for the constituent images. Adds a column of all
None if data is None and the column does not already exist.
default values if data is None and the column does not already exist.
Parameters
----------
Expand All @@ -207,10 +207,10 @@ def add_img_meta_data(self, column, data):
each column.
"""
if data is None:
if column not in self.img_meta.colnames:
self.img_meta[column] = [None] * self.n_constituents
if column not in self.org_img_meta.colnames:
self.org_img_meta[column] = [default] * self.n_constituents
elif len(data) == self.n_constituents:
self.img_meta[column] = data
self.org_img_meta[column] = data
else:
raise ValueError(
f"Data mismatch size for WorkUnit metadata {column}. "
Expand Down Expand Up @@ -755,7 +755,7 @@ def metadata_to_primary_header(self, include_wcs=True):
pri.header["REPRJCTD"] = self.reprojected
pri.header["HELIO"] = self.heliocentric_distance
for i in range(self.n_constituents):
pri.header[f"GEO_{i}"] = self.geocentric_distances[i]
pri.header[f"GEO_{i}"] = self.org_img_meta["geocentric_distance"][i]

# If the global WCS exists, append the corresponding keys.
if self.wcs is not None:
Expand All @@ -777,16 +777,17 @@ def metadata_to_primary_header(self, include_wcs=True):
return hdul

def append_all_wcs(self, hdul):
"""Append the `_per_image_wcs` and
`_per_image_ebd_wcs` elements to a header.
"""Append all the original WCS and EBD WCS to a header.
Parameters
----------
hdul : `astropy.io.fits.HDUList`
The HDU list.
"""
all_ebd_wcs = self.get_constituent_meta("ebd_wcs")

for i in range(self.n_constituents):
img_location = self.img_meta["data_loc"][i]
img_location = self.org_img_meta["data_loc"][i]

orig_wcs = self._per_image_wcs[i]
wcs_hdu = fits.TableHDU()
Expand All @@ -795,9 +796,8 @@ def append_all_wcs(self, hdul):
wcs_hdu.header["ILOC"] = img_location
hdul.append(wcs_hdu)

im_ebd_wcs = self._per_image_ebd_wcs[i]
ebd_hdu = fits.TableHDU()
append_wcs_to_hdu_header(im_ebd_wcs, ebd_hdu.header)
append_wcs_to_hdu_header(all_ebd_wcs[i], ebd_hdu.header)
ebd_hdu.name = f"EBD_{i}"
hdul.append(ebd_hdu)

Expand Down Expand Up @@ -860,7 +860,7 @@ def image_positions_to_original_icrs(
position_ebd_coords = radec_coords

helio_dist = self.heliocentric_distance
geo_dists = [self.geocentric_distances[i] for i in image_indices]
geo_dists = [self.org_img_meta["geocentric_distance"][i] for i in image_indices]
all_times = self.get_all_obstimes()
obstimes = [all_times[i] for i in image_indices]

Expand All @@ -887,7 +887,7 @@ def image_positions_to_original_icrs(
coord = inverted_coords[i]
pos = []
for j in inds:
con_image = self.img_meta["data_loc"][j]
con_image = self.org_img_meta["data_loc"][j]
con_wcs = self._per_image_wcs[j]
height, width = con_wcs.array_shape
x, y = skycoord_to_pixel(coord, con_wcs)
Expand Down
6 changes: 5 additions & 1 deletion tests/test_reprojection.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ def test_reproject(self):
assert reprojected_wunit.wcs != None
assert reprojected_wunit.im_stack.get_width() == 60
assert reprojected_wunit.im_stack.get_height() == 50
assert reprojected_wunit.geocentric_distances == self.test_wunit.geocentric_distances

test_dists = self.test_wunit.get_constituent_meta("geocentric_distance")
reproject_dists = reprojected_wunit.get_constituent_meta("geocentric_distance")
assert test_dists == reproject_dists

images = reprojected_wunit.im_stack.get_images()

# will be 3 as opposed to the four in the original `WorkUnit`,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_work_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_save_and_load_fits(self):
self.assertEqual(work2.config["num_obs"], self.num_images)

# Check that we correctly retrieved the provenance information via "data_loc"
for index, value in enumerate(work2.img_meta["data_loc"]):
for index, value in enumerate(work2.org_img_meta["data_loc"]):
self.assertEqual(value, self.constituent_images[index])

# We throw an error if we try to overwrite a file with overwrite=False
Expand Down

0 comments on commit fd7e937

Please sign in to comment.