Refactor the image-level meta data in WorkUnit #739

Merged
merged 5 commits into from Nov 11, 2024
55 changes: 26 additions & 29 deletions src/kbmod/reprojection.py
@@ -94,9 +94,8 @@ def reproject_work_unit(
The WCS to reproject all the images into.
frame : `str`
The WCS frame of reference to use when reprojecting.
Can either be 'original' or 'ebd' to specify whether to
use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs
respectively.
Can either be 'original' or 'ebd' to specify which WCS to access
from the WorkUnit.
parallelize : `bool`
If True, use multiprocessing to reproject the images in parallel.
Default is True.
@@ -175,9 +174,8 @@ def _reproject_work_unit(
The WCS to reproject all the images into.
frame : `str`
The WCS frame of reference to use when reprojecting.
Can either be 'original' or 'ebd' to specify whether to
use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs
respectively.
Can either be 'original' or 'ebd' to specify which WCS to access
from the WorkUnit.
write_output : `bool`
Whether or not to write the reprojection results out as a sharded `WorkUnit`.
directory : `str`
@@ -195,6 +193,14 @@ def _reproject_work_unit(
images = work_unit.im_stack.get_images()
unique_obstimes, unique_obstime_indices = work_unit.get_unique_obstimes_and_indices()

# Create a list of the correct WCS. We do this extraction once and reuse for all images.
if frame == "original":
wcs_list = work_unit.get_constituent_meta("original_wcs")
elif frame == "ebd":
wcs_list = work_unit.get_constituent_meta("ebd_wcs")
else:
raise ValueError("Invalid projection frame provided.")

stack = ImageStack()
for obstime_index, o_i in tqdm(
enumerate(zip(unique_obstimes, unique_obstime_indices)),
@@ -214,13 +220,7 @@ def _reproject_work_unit(
variance = image.get_variance()
mask = image.get_mask()

if frame == "original":
original_wcs = work_unit.get_wcs(index)
elif frame == "ebd":
original_wcs = work_unit._per_image_ebd_wcs[index]
else:
raise ValueError("Invalid projection frame provided.")

original_wcs = wcs_list[index]
if original_wcs is None:
raise ValueError(f"No WCS provided for index {index}")

@@ -295,11 +295,11 @@ def _reproject_work_unit(
im_stack=stack,
config=work_unit.config,
wcs=common_wcs,
constituent_images=work_unit.constituent_images,
constituent_images=work_unit.get_constituent_meta("data_loc"),
per_image_wcs=work_unit._per_image_wcs,
per_image_ebd_wcs=work_unit._per_image_ebd_wcs,
per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"),
per_image_indices=unique_obstime_indices,
geocentric_distances=work_unit.geocentric_distances,
geocentric_distances=work_unit.get_constituent_meta("geocentric_distance"),
reprojected=True,
)

@@ -328,9 +328,8 @@ def _reproject_work_unit_in_parallel(
The WCS to reproject all the images into.
frame : `str`
The WCS frame of reference to use when reprojecting.
Can either be 'original' or 'ebd' to specify whether to
use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs
respectively.
Can either be 'original' or 'ebd' to specify which WCS to access
from the WorkUnit.
max_parallel_processes : `int`
The maximum number of parallel processes to use when reprojecting.
Default is 8. For more see `concurrent.futures.ProcessPoolExecutor` in
@@ -452,11 +451,11 @@ def _reproject_work_unit_in_parallel(
im_stack=stack,
config=work_unit.config,
wcs=common_wcs,
constituent_images=work_unit.constituent_images,
constituent_images=work_unit.get_constituent_meta("data_loc"),
per_image_wcs=work_unit._per_image_wcs,
per_image_ebd_wcs=work_unit._per_image_ebd_wcs,
per_image_ebd_wcs=work_unit.get_constituent_meta("ebd_wcs"),
per_image_indices=unique_obstimes_indices,
geocentric_distances=work_unit.geocentric_distances,
geocentric_distances=work_unit.get_constituent_meta("geocentric_distances"),
reprojected=True,
)

@@ -492,9 +491,8 @@ def reproject_lazy_work_unit(
shards).
frame : `str`
The WCS frame of reference to use when reprojecting.
Can either be 'original' or 'ebd' to specify whether to
use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs
respectively.
Can either be 'original' or 'ebd' to specify which WCS to access
from the WorkUnit.
max_parallel_processes : `int`
The maximum number of parallel processes to use when reprojecting.
Default is 8. For more see `concurrent.futures.ProcessPoolExecutor` in
@@ -572,9 +570,8 @@ def _validate_original_wcs(work_unit, indices, frame="original"):
The indices to be validated in work_unit.
frame : `str`
The WCS frame of reference to use when reprojecting.
Can either be 'original' or 'ebd' to specify whether to
use the WorkUnit._per_image_wcs or ._per_image_ebd_wcs
respectively.
Can either be 'original' or 'ebd' to specify which WCS to access
from the WorkUnit.

Returns
-------
@@ -590,7 +587,7 @@ def _validate_original_wcs(work_unit, indices, frame="original"):
if frame == "original":
original_wcs = [work_unit.get_wcs(i) for i in indices]
elif frame == "ebd":
original_wcs = [work_unit._per_image_ebd_wcs[i] for i in indices]
original_wcs = [work_unit.get_constituent_meta("ebd_wcs")[i] for i in indices]
else:
raise ValueError("Invalid projection frame provided.")

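
The net effect of the reprojection.py changes is that the frame choice now selects a metadata column once, up front, instead of touching private WorkUnit attributes inside the per-image loop. A minimal sketch of that lookup pattern (not code from the PR), assuming an already-constructed WorkUnit `wu` that exposes the new `get_constituent_meta` accessor:

def select_wcs_list(wu, frame):
    # Pull the relevant WCS list out of the constituent-image metadata once,
    # rather than indexing _per_image_wcs / _per_image_ebd_wcs per image.
    if frame == "original":
        return wu.get_constituent_meta("original_wcs")
    elif frame == "ebd":
        return wu.get_constituent_meta("ebd_wcs")
    raise ValueError("Invalid projection frame provided.")

# Hypothetical usage inside the reprojection loop:
#   wcs_list = select_wcs_list(work_unit, frame)
#   original_wcs = wcs_list[index]
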
145 changes: 104 additions & 41 deletions src/kbmod/work_unit.py
@@ -4,6 +4,7 @@

from astropy.coordinates import SkyCoord, EarthLocation
from astropy.io import fits
from astropy.table import Table
from astropy.time import Time
from astropy.utils.exceptions import AstropyWarning
from astropy.wcs.utils import skycoord_to_pixel
@@ -22,8 +23,6 @@
calc_ecliptic_angle,
extract_wcs_from_hdu_header,
wcs_fits_equal,
wcs_from_dict,
wcs_to_dict,
)


@@ -40,6 +39,45 @@ class WorkUnit:

Attributes
----------
im_stack : `kbmod.search.ImageStack`
The image data for the KBMOD run.
config : `kbmod.configuration.SearchConfiguration`
The configuration for the KBMOD run.
n_constituents : `int`
The number of original images making up the data in this WorkUnit. This might be
different from the number of images stored in memory if the WorkUnit has been
reprojected.
org_img_meta : `astropy.table.Table`
The meta data for each constituent image. Includes columns:
* data_loc - the original location of the image
* ebd_wcs - Used to reproject the images into EBD space.
* geocentric_distance - The best fit geocentric distances used when creating
the per image EBD WCS.
* original_wcs - The original WCS of the image.
wcs : `astropy.wcs.WCS`
A global WCS for all images in the WorkUnit. Only exists
if all images have been projected to same pixel space.
per_image_wcs : `list`
A list with one WCS for each image in the WorkUnit. Used for when
the images have *not* been standardized to the same pixel space.
heliocentric_distance : `float`
The heliocentric distance that was used when creating the `per_image_ebd_wcs`.
reprojected : `bool`
Whether or not the WorkUnit image data has been reprojected.
per_image_indices : `list` of `list`
A list of lists containing the indices of `constituent_images` at each layer
of the `ImageStack`. Used for finding corresponding original images when we
stitch images together during reprojection.
lazy : `bool`
Whether or not to load the image data for the `WorkUnit`.
file_paths : `list[str]`
The paths for the shard files, only created if the `WorkUnit` is loaded
in lazy mode.
obstimes : `list[float]`
The MJD obstimes of the images.

Parameters
----------
im_stack : `kbmod.search.ImageStack`
The image data for the KBMOD run.
config : `kbmod.configuration.SearchConfiguration`
@@ -98,23 +136,29 @@ def __init__(
self.file_paths = file_paths
self._obstimes = obstimes

# Handle WCS input. If both the global and per-image WCS are provided,
# ensure they are consistent.
self.wcs = wcs
# Determine the number of constituent images. If we are given a list of constituent_images,
# use that. Otherwise use the size of the image stack.
if constituent_images is None:
n_constituents = im_stack.img_count()
self.constituent_images = [None] * n_constituents
self.n_constituents = im_stack.img_count()
else:
n_constituents = len(constituent_images)
self.constituent_images = constituent_images
self.n_constituents = len(constituent_images)

# Track the meta data for each constituent image in the WorkUnit. For the original
# WCS, we track the per-image WCS if it is provided and otherwise the global WCS.
self.org_img_meta = Table()
self.add_org_img_meta_data("data_loc", constituent_images)
self.add_org_img_meta_data("original_wcs", per_image_wcs, default=wcs)

# Handle WCS input. If both the global and per-image WCS are provided,
# ensure they are consistent.
self.wcs = wcs
if per_image_wcs is None:
self._per_image_wcs = [None] * n_constituents
self._per_image_wcs = [None] * self.n_constituents
if self.wcs is None and per_image_ebd_wcs is None:
warnings.warn("No WCS provided.", Warning)
else:
if len(per_image_wcs) != n_constituents:
raise ValueError(f"Incorrect number of WCS provided. Expected {n_constituents}")
if len(per_image_wcs) != self.n_constituents:
raise ValueError(f"Incorrect number of WCS provided. Expected {self.n_constituents}")
self._per_image_wcs = per_image_wcs

# Check if all the per-image WCS are None. This can happen during a load.
@@ -127,31 +171,52 @@ def __init__(
self.wcs = self._per_image_wcs[0]
self._per_image_wcs = [None] * im_stack.img_count()

# TODO: Refactor all of this code to make it cleaner

if per_image_ebd_wcs is None:
self._per_image_ebd_wcs = [None] * n_constituents
else:
if len(per_image_ebd_wcs) != n_constituents:
raise ValueError(f"Incorrect number of EBD WCS provided. Expected {n_constituents}")
self._per_image_ebd_wcs = per_image_ebd_wcs

if geocentric_distances is None:
self.geocentric_distances = [None] * n_constituents
else:
self.geocentric_distances = geocentric_distances

self.heliocentric_distance = heliocentric_distance
# Add the meta data needed for reprojection, including: the reprojected WCS, the geocentric
# distances, and each images indices in the original constituent images.
self.reprojected = reprojected
self.heliocentric_distance = heliocentric_distance
self.add_org_img_meta_data("geocentric_distance", geocentric_distances)
self.add_org_img_meta_data("ebd_wcs", per_image_ebd_wcs)

# If we have mosaicked images, each image in the stack could link back
# to more than one constituent image. Build a mapping of image stack index
# to needed original image indices.
if per_image_indices is None:
self._per_image_indices = [[i] for i in range(len(self.constituent_images))]
self._per_image_indices = [[i] for i in range(self.n_constituents)]
else:
self._per_image_indices = per_image_indices

def __len__(self):
"""Returns the size of the WorkUnit in number of images."""
return self.im_stack.img_count()

def get_constituent_meta(self, column):
"""Get the meta data values of a given column for all the constituent images."""
return list(self.org_img_meta[column].data)

def add_org_img_meta_data(self, column, data, default=None):
"""Add a column of meta data for the constituent images. Adds a column of all
default values if data is None and the column does not already exist.

Parameters
----------
column : `str`
The name of the meta data column.
data : list-like
The data for each constituent. If None then uses None for
each column.
"""
if data is None:
if column not in self.org_img_meta.colnames:
self.org_img_meta[column] = [default] * self.n_constituents
elif len(data) == self.n_constituents:
self.org_img_meta[column] = data
else:
raise ValueError(
f"Data mismatch size for WorkUnit metadata {column}. "
f"Expected {self.n_constituents} but found {len(data)}."
)

def has_common_wcs(self):
"""Returns whether the WorkUnit has a common WCS for all images."""
return self.wcs is not None
@@ -686,17 +751,16 @@ def metadata_to_primary_header(self, include_wcs=True):
hdul = fits.HDUList()
pri = fits.PrimaryHDU()
pri.header["NUMIMG"] = self.get_num_images()
pri.header["NCON"] = self.n_constituents
pri.header["REPRJCTD"] = self.reprojected
pri.header["HELIO"] = self.heliocentric_distance
for i in range(len(self.constituent_images)):
pri.header[f"GEO_{i}"] = self.geocentric_distances[i]
for i in range(self.n_constituents):
pri.header[f"GEO_{i}"] = self.org_img_meta["geocentric_distance"][i]

# If the global WCS exists, append the corresponding keys.
if self.wcs is not None:
append_wcs_to_hdu_header(self.wcs, pri.header)

pri.header["NCON"] = len(self.constituent_images)

hdul.append(pri)

meta_hdu = fits.BinTableHDU()
@@ -713,17 +777,17 @@ def metadata_to_primary_header(self, include_wcs=True):
return hdul

def append_all_wcs(self, hdul):
"""Append the `_per_image_wcs` and
`_per_image_ebd_wcs` elements to a header.
"""Append all the original WCS and EBD WCS to a header.

Parameters
----------
hdul : `astropy.io.fits.HDUList`
The HDU list.
"""
n_constituents = len(self.constituent_images)
for i in range(n_constituents):
img_location = self.constituent_images[i]
all_ebd_wcs = self.get_constituent_meta("ebd_wcs")

for i in range(self.n_constituents):
img_location = self.org_img_meta["data_loc"][i]

orig_wcs = self._per_image_wcs[i]
wcs_hdu = fits.TableHDU()
Expand All @@ -732,9 +796,8 @@ def append_all_wcs(self, hdul):
wcs_hdu.header["ILOC"] = img_location
hdul.append(wcs_hdu)

im_ebd_wcs = self._per_image_ebd_wcs[i]
ebd_hdu = fits.TableHDU()
append_wcs_to_hdu_header(im_ebd_wcs, ebd_hdu.header)
append_wcs_to_hdu_header(all_ebd_wcs[i], ebd_hdu.header)
ebd_hdu.name = f"EBD_{i}"
hdul.append(ebd_hdu)

@@ -797,7 +860,7 @@ def image_positions_to_original_icrs(
position_ebd_coords = radec_coords

helio_dist = self.heliocentric_distance
geo_dists = [self.geocentric_distances[i] for i in image_indices]
geo_dists = [self.org_img_meta["geocentric_distance"][i] for i in image_indices]
all_times = self.get_all_obstimes()
obstimes = [all_times[i] for i in image_indices]

@@ -824,7 +887,7 @@ def image_positions_to_original_icrs(
coord = inverted_coords[i]
pos = []
for j in inds:
con_image = self.constituent_images[j]
con_image = self.org_img_meta["data_loc"][j]
con_wcs = self._per_image_wcs[j]
height, width = con_wcs.array_shape
x, y = skycoord_to_pixel(coord, con_wcs)
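
Before the test changes below, it may help to see the new metadata table in isolation. The following is a standalone sketch, not code from the PR: it mirrors the behavior of `add_org_img_meta_data` and `get_constituent_meta` with a local helper and made-up values, to show how a single astropy Table with one row per constituent image replaces the old parallel lists (`constituent_images`, `geocentric_distances`, `_per_image_ebd_wcs`).

from astropy.table import Table

n_constituents = 3
org_img_meta = Table()

def add_meta_column(table, column, data, n_rows, default=None):
    # Mirrors WorkUnit.add_org_img_meta_data(): fill the column with a default
    # when no data is given, otherwise require one value per constituent image.
    if data is None:
        if column not in table.colnames:
            table[column] = [default] * n_rows
    elif len(data) == n_rows:
        table[column] = data
    else:
        raise ValueError(f"Expected {n_rows} values for {column}, found {len(data)}.")

add_meta_column(org_img_meta, "data_loc", ["img_0.fits", "img_1.fits", "img_2.fits"], n_constituents)
add_meta_column(org_img_meta, "geocentric_distance", None, n_constituents)  # all None until fit

# Column-wise access replaces the old per-attribute lists, e.g.
# work_unit.geocentric_distances -> work_unit.get_constituent_meta("geocentric_distance").
print(list(org_img_meta["data_loc"].data))
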
6 changes: 5 additions & 1 deletion tests/test_reprojection.py
@@ -75,7 +75,11 @@ def test_reproject(self):
assert reprojected_wunit.wcs != None
assert reprojected_wunit.im_stack.get_width() == 60
assert reprojected_wunit.im_stack.get_height() == 50
assert reprojected_wunit.geocentric_distances == self.test_wunit.geocentric_distances

test_dists = self.test_wunit.get_constituent_meta("geocentric_distance")
reproject_dists = reprojected_wunit.get_constituent_meta("geocentric_distance")
assert test_dists == reproject_dists

images = reprojected_wunit.im_stack.get_images()

# will be 3 as opposed to the four in the original `WorkUnit`,
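
One note on the assertion style above: the updated test pulls the values out with `get_constituent_meta` because the old `geocentric_distances` attribute is gone. Since the accessor returns a plain Python list (`list(self.org_img_meta[column].data)`), a direct `==` still yields a single boolean usable in an assert, whereas comparing the underlying astropy Column would broadcast element-wise. A small standalone illustration with made-up values:

from astropy.table import Table

meta = Table({"geocentric_distance": [41.0, 41.2, 40.9]})  # made-up distances

# Comparing the raw column is element-wise: a boolean array, awkward in an assert.
print(meta["geocentric_distance"] == [41.0, 41.2, 40.9])

# The accessor pattern converts to a list first, so equality gives one bool.
as_list = list(meta["geocentric_distance"].data)
assert as_list == [41.0, 41.2, 40.9]
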