Skip to content

Commit

Permalink
fix reproject with scalars
Browse files Browse the repository at this point in the history
  • Loading branch information
Nadia committed Jun 13, 2024
1 parent 8df2c80 commit 315da2d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
32 changes: 22 additions & 10 deletions src/stcal/alignment/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _calculate_new_wcs(ref_model, shape, wcs_list, fiducial, crpix=None, transfo
----------
ref_model :
The reference model to be used when extracting metadata.
bp
shape : list
The shape of the new WCS's pixel grid. If `None`, then the output bounding box
will be used to determine it.
Expand Down Expand Up @@ -785,18 +785,30 @@ def _reproject(x: float | np.ndarray, y: float | np.ndarray) -> tuple:
tuple
Tuple of np.ndarrays including reprojected x and y coordinates.
"""
sky = forward_transform(x, y)
flat_sky = []
for axis in sky:
flat_sky.append(axis.flatten())
# sky = forward_transform(x, y)
# flat_sky = []
# for axis in sky:
# flat_sky.append(axis.flatten())
# # Filter out RuntimeWarnings due to computed NaNs in the WCS
# with warnings.catch_warnings():
# warnings.simplefilter("ignore", RuntimeWarning)
# det = backward_transform(*tuple(flat_sky))
# det_reshaped = []
# for axis in det:
# det_reshaped.append(axis.reshape(x.shape))

# return tuple(det_reshaped)
shape = np.array(x).shape
sky = forward_transform(x, y)
flat_sky = [axis.flatten() for axis in sky]

# Filter out RuntimeWarnings due to computed NaNs in the WCS
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
det = backward_transform(*tuple(flat_sky))
det_reshaped = []
for axis in det:
det_reshaped.append(axis.reshape(x.shape))
detector = backward_transform(*tuple(flat_sky))

return tuple(det_reshaped)
if shape == ():
return tuple([axis.item() for axis in detector])
return tuple([axis.reshape(shape) for axis in detector])

return _reproject
2 changes: 1 addition & 1 deletion tests/test_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def get_fake_wcs():
[
(1000, 2000, np.array(2000), np.array(4000)), # string input test
([1000], [2000], np.array(2000), np.array(4000)), # array input test
pytest.param(1, 2, 3, 4, marks=pytest.mark.xfail), # expected failure test
(1, 2, 2, 4),
],
)
def test_reproject(x_inp, y_inp, x_expected, y_expected):
Expand Down

0 comments on commit 315da2d

Please sign in to comment.