Skip to content

Commit

Permalink
Merge pull request #351 from keflavich/memmaped_coadd
Browse files Browse the repository at this point in the history
Generalize reproject_and_coadd for N-dimensional data, and add option to specify blank pixel value and progress bar
  • Loading branch information
astrofrog authored Jun 5, 2024
2 parents 9ba352c + f5e3016 commit 8ab8a78
Show file tree
Hide file tree
Showing 6 changed files with 198 additions and 125 deletions.
21 changes: 20 additions & 1 deletion reproject/array_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

__all__ = ["map_coordinates"]
__all__ = ["map_coordinates", "sample_array_edges"]


def map_coordinates(image, coords, **kwargs):
Expand Down Expand Up @@ -35,3 +35,22 @@ def map_coordinates(image, coords, **kwargs):
values[reset] = kwargs.get("cval", 0.0)

return values


def sample_array_edges(shape, *, n_samples):
# Given an N-dimensional array shape, sample each edge of the array using
# the requested number of samples (which will include vertices). To do this
# we iterate through the dimensions and for each one we sample the points
# in that dimension and iterate over the combination of other vertices.
# Returns an array with dimensions (N, n_samples)
all_positions = []
ndim = len(shape)
shape = np.array(shape)
for idim in range(ndim):
for vertex in range(2**ndim):
positions = -0.5 + shape * ((vertex & (2 ** np.arange(ndim))) > 0).astype(int)
positions = np.broadcast_to(positions, (n_samples, ndim)).copy()
positions[:, idim] = np.linspace(-0.5, shape[idim] - 0.5, n_samples)
all_positions.append(positions)
positions = np.unique(np.vstack(all_positions), axis=0).T
return positions
166 changes: 100 additions & 66 deletions reproject/mosaicking/coadd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
from astropy.wcs import WCS
from astropy.wcs.wcsapi import SlicedLowLevelWCS

from ..array_utils import sample_array_edges
from ..utils import parse_input_data, parse_input_weights, parse_output_projection
from .background import determine_offset_matrix, solve_corrections_sgd
from .subset_array import ReprojectedArraySubset

__all__ = ["reproject_and_coadd"]


def _noop(iterable):
return iterable


def reproject_and_coadd(
input_data,
output_projection,
Expand All @@ -24,14 +29,15 @@ def reproject_and_coadd(
background_reference=None,
output_array=None,
output_footprint=None,
block_sizes=None,
progress_bar=None,
blank_pixel_value=0,
**kwargs,
):
"""
Given a set of input images, reproject and co-add these to a single
Given a set of input data, reproject and co-add these to a single
final image.
This currently only works with 2-d images with celestial WCS.
Parameters
----------
input_data : iterable
Expand Down Expand Up @@ -77,7 +83,7 @@ def reproject_and_coadd(
`~astropy.io.fits.HDUList` instance, specifies the HDU to use.
reproject_function : callable
The function to use for the reprojection.
combine_function : { 'mean', 'sum', 'median', 'first', 'last', 'min', 'max' }
combine_function : { 'mean', 'sum', 'first', 'last', 'min', 'max' }
The type of function to use for combining the values into the final
image. For 'first' and 'last', respectively, the reprojected images are
simply overlaid on top of each other. With respect to the order of the
Expand All @@ -92,11 +98,22 @@ def reproject_and_coadd(
output_array : array or None
The final output array. Specify this if you already have an
appropriately-shaped array to store the data in. Must match shape
specified with ``shape_out`` or derived from the output projection.
specified with `shape_out` or derived from the output
projection.
output_footprint : array or None
The final output footprint array. Specify this if you already have an
appropriately-shaped array to store the data in. Must match shape
specified with ``shape_out`` or derived from the output projection.
specified with `shape_out` or derived from the output projection.
block_sizes : list of tuples or None
The block size to use for each dataset. Could also be a single tuple
if you want the sample block size for all data sets.
progress_bar : callable, optional
If specified, use this as a progress_bar to track loop iterations over
data sets.
blank_pixel_value : float, optional
Value to use for areas of the resulting mosaic that do not have input
data.
**kwargs
Keyword arguments to be passed to the reprojection function.
Expand All @@ -116,34 +133,49 @@ def reproject_and_coadd(

# Validate inputs

if combine_function not in ("mean", "sum", "median", "first", "last", "min", "max"):
raise ValueError("combine_function should be one of mean/sum/median/first/last/min/max")
if combine_function not in ("mean", "sum", "first", "last", "min", "max"):
raise ValueError("combine_function should be one of mean/sum/first/last/min/max")

if reproject_function is None:
raise ValueError(
"reprojection function should be specified with the reproject_function argument"
)

if progress_bar is None:
progress_bar = _noop

# Parse the output projection to avoid having to do it for each

wcs_out, shape_out = parse_output_projection(output_projection, shape_out=shape_out)

if output_array is not None and output_array.shape != shape_out:
if output_array is None:
output_array = np.zeros(shape_out)
elif output_array.shape != shape_out:
raise ValueError(
"If you specify an output array, it must have a shape matching "
f"the output shape {shape_out}"
)
if output_footprint is not None and output_footprint.shape != shape_out:

if output_footprint is None:
output_footprint = np.zeros(shape_out)
elif output_footprint.shape != shape_out:
raise ValueError(
"If you specify an output footprint array, it must have a shape matching "
f"the output shape {shape_out}"
)

# Define 'on-the-fly' mode: in the case where we don't need to match
# the backgrounds and we are combining with 'mean' or 'sum', we don't
# have to keep track of the intermediate arrays and can just modify
# the output array on-the-fly
on_the_fly = not match_background and combine_function in ("mean", "sum")

# Start off by reprojecting individual images to the final projection

arrays = []
if not on_the_fly:
arrays = []

for idata in range(len(input_data)):
for idata in progress_bar(range(len(input_data))):
# We need to pre-parse the data here since we need to figure out how to
# optimize/minimize the size of each output tile (see below).
array_in, wcs_in = parse_input_data(input_data[idata], hdu_in=hdu_in)
Expand All @@ -166,42 +198,48 @@ def reproject_and_coadd(
# significant distortion (when the edges of the input image become
# convex in the output projection), and transforming every edge pixel,
# which provides a lot of redundant information.
ny, nx = array_in.shape
n_per_edge = 11
xs = np.linspace(-0.5, nx - 0.5, n_per_edge)
ys = np.linspace(-0.5, ny - 0.5, n_per_edge)
xs = np.concatenate((xs, np.full(n_per_edge, xs[-1]), xs, np.full(n_per_edge, xs[0])))
ys = np.concatenate((np.full(n_per_edge, ys[0]), ys, np.full(n_per_edge, ys[-1]), ys))
xc_out, yc_out = wcs_out.world_to_pixel(wcs_in.pixel_to_world(xs, ys))

edges = sample_array_edges(array_in.shape, n_samples=11)[::-1]
edges_out = wcs_out.world_to_pixel(wcs_in.pixel_to_world(*edges))[::-1]

# Determine the cutout parameters

# In some cases, images might not have valid coordinates in the corners,
# such as all-sky images or full solar disk views. In this case we skip
# this step and just use the full output WCS for reprojection.

if np.any(np.isnan(xc_out)) or np.any(np.isnan(yc_out)):
imin = 0
imax = shape_out[1]
jmin = 0
jmax = shape_out[0]
else:
imin = max(0, int(np.floor(xc_out.min() + 0.5)))
imax = min(shape_out[1], int(np.ceil(xc_out.max() + 0.5)))
jmin = max(0, int(np.floor(yc_out.min() + 0.5)))
jmax = min(shape_out[0], int(np.ceil(yc_out.max() + 0.5)))
ndim_out = len(shape_out)

if imax < imin or jmax < jmin:
skip_data = False
if np.any(np.isnan(edges_out)):
bounds = list(zip([0] * ndim_out, shape_out))
else:
bounds = []
for idim in range(ndim_out):
imin = max(0, int(np.floor(edges_out[idim].min() + 0.5)))
imax = min(shape_out[idim], int(np.ceil(edges_out[idim].max() + 0.5)))
bounds.append((imin, imax))
if imax < imin:
skip_data = True
break

if skip_data:
continue

slice_out = tuple([slice(imin, imax) for (imin, imax) in bounds])

if isinstance(wcs_out, WCS):
wcs_out_indiv = wcs_out[jmin:jmax, imin:imax]
wcs_out_indiv = wcs_out[slice_out]
else:
wcs_out_indiv = SlicedLowLevelWCS(
wcs_out.low_level_wcs, (slice(jmin, jmax), slice(imin, imax))
)
wcs_out_indiv = SlicedLowLevelWCS(wcs_out.low_level_wcs, slice_out)

shape_out_indiv = (jmax - jmin, imax - imin)
shape_out_indiv = [imax - imin for (imin, imax) in bounds]

if block_sizes is not None:
if len(block_sizes) == len(input_data) and len(block_sizes[idata]) == len(shape_out):
kwargs["block_size"] = block_sizes[idata]
else:
kwargs["block_size"] = block_sizes

# TODO: optimize handling of weights by making reprojection functions
# able to handle weights, and make the footprint become the combined
Expand Down Expand Up @@ -235,12 +273,20 @@ def reproject_and_coadd(
weights[reset] = 0.0
footprint *= weights

array = ReprojectedArraySubset(array, footprint, imin, imax, jmin, jmax)
array = ReprojectedArraySubset(array, footprint, bounds)

# TODO: make sure we gracefully handle the case where the
# output image is empty (due e.g. to no overlap).

arrays.append(array)
if on_the_fly:
# By default, values outside of the footprint are set to NaN
# but we set these to 0 here to avoid getting NaNs in the
# means/sums.
array.array[array.footprint == 0] = 0
output_array[array.view_in_original_array] += array.array * array.footprint
output_footprint[array.view_in_original_array] += array.footprint
else:
arrays.append(array)

# If requested, try and match the backgrounds.
if match_background and len(arrays) > 1:
Expand All @@ -251,37 +297,32 @@ def reproject_and_coadd(
for array, correction in zip(arrays, corrections, strict=True):
array.array -= correction

# At this point, the images are now ready to be co-added.

if output_array is None:
output_array = np.zeros(shape_out)
if output_footprint is None:
output_footprint = np.zeros(shape_out)

if combine_function == "min":
output_array[...] = np.inf
elif combine_function == "max":
output_array[...] = -np.inf

if combine_function in ("mean", "sum"):
for array in arrays:
# By default, values outside of the footprint are set to NaN
# but we set these to 0 here to avoid getting NaNs in the
# means/sums.
array.array[array.footprint == 0] = 0
if match_background:
# if we're not matching the background, this part has already been done
for array in arrays:
# By default, values outside of the footprint are set to NaN
# but we set these to 0 here to avoid getting NaNs in the
# means/sums.
array.array[array.footprint == 0] = 0

output_array[array.view_in_original_array] += array.array * array.footprint
output_footprint[array.view_in_original_array] += array.footprint
output_array[array.view_in_original_array] += array.array * array.footprint
output_footprint[array.view_in_original_array] += array.footprint

if combine_function == "mean":
with np.errstate(invalid="ignore"):
output_array /= output_footprint
output_array[output_footprint == 0] = 0
output_array[output_footprint == 0] = blank_pixel_value

elif combine_function in ("first", "last", "min", "max"):
if combine_function == "min":
output_array[...] = np.inf
elif combine_function == "max":
output_array[...] = -np.inf

for array in arrays:
if combine_function == "first":
mask = (output_footprint[array.view_in_original_array] == 0) & (array.footprint > 0)
mask = output_footprint[array.view_in_original_array] == 0
elif combine_function == "last":
mask = array.footprint > 0
elif combine_function == "min":
Expand All @@ -300,13 +341,6 @@ def reproject_and_coadd(
mask, array.array, output_array[array.view_in_original_array]
)

elif combine_function == "median":
# Here we need to operate in chunks since we could otherwise run
# into memory issues

raise NotImplementedError("combine_function='median' is not yet implemented")

if combine_function in ("min", "max"):
output_array[output_footprint == 0] = 0.0
output_array[output_footprint == 0] = blank_pixel_value

return output_array, output_footprint
Loading

0 comments on commit 8ab8a78

Please sign in to comment.