Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #490: make rescale_imagehdu more robust against dimension mismatches #503

Merged
merged 2 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 32 additions & 27 deletions scopesim/optics/image_plane_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,19 +487,22 @@ def rescale_imagehdu(imagehdu: fits.ImageHDU, pixel_scale: float | u.Quantity,
primary_wcs = WCS(imagehdu.header, key=wcs_suffix[0])

# make sure that units are correct and zoom factor is positive
# The length of the zoom factor will be determined by imagehdu.data,
# which might differ from the dimension of primary_wcs. Here, pick
# the spatial dimensions only.
pixel_scale = pixel_scale << u.Unit(primary_wcs.wcs.cunit[0])
zoom = np.abs(primary_wcs.wcs.cdelt / pixel_scale.value)
zoom = np.abs(primary_wcs.wcs.cdelt[:2] / pixel_scale.value)

if len(imagehdu.data.shape) == 3:
zoom = np.append(zoom, [1.]) # wavelength dimension unscaled if present

logger.debug("zoom factor: %s", zoom)

if primary_wcs.naxis == 3:
# zoom = np.append(zoom, [1])
zoom[2] = 1.
if primary_wcs.naxis != imagehdu.data.ndim:
# FIXME: this happens often - shouldn't WCSs be trimmed down before? (OC)
logger.warning("imagehdu.data.ndim is %d, but primary_wcs.naxis with "
"key %s is %d, both should be equal.",
imagehdu.data.ndim, wcs_suffix, primary_wcs.naxis)
zoom = zoom[:2]

logger.debug("zoom %s", zoom)
"key %s is %d, both should be equal.",
imagehdu.data.ndim, wcs_suffix, primary_wcs.naxis)

if all(zoom == 1.):
# Nothing to do
Expand All @@ -525,28 +528,30 @@ def rescale_imagehdu(imagehdu: fits.ImageHDU, pixel_scale: float | u.Quantity,
logger.warning("imagehdu.data.ndim is %d, but wcs.naxis with key "
"%s is %d, both should be equal.",
imagehdu.data.ndim, ww.wcs.alt, ww.naxis)
# TODO: could this be ww = ww.sub(2) instead? or .celestial?
# ww = WCS(imagehdu.header, key=key, naxis=imagehdu.data.ndim)

if any(ctype != "LINEAR" for ctype in ww.wcs.ctype):
logger.warning("Non-linear WCS rescaled using linear procedure.")

new_crpix = (zoom + 1) / 2 + (ww.wcs.crpix - 1) * zoom
#ew_crpix = np.round(new_crpix * 2) / 2 # round to nearest half-pixel
logger.debug("new crpix %s", new_crpix)
ww.wcs.crpix = new_crpix

# Keep CDELT3 if cube...
new_cdelt = ww.wcs.cdelt[:]
new_cdelt /= zoom
ww.wcs.cdelt = new_cdelt

# TODO: is forcing deg here really the best way?
# FIXME: NO THIS WILL MESS UP IF new_cdelt IS IN ARCSEC!!!!!
# new_cunit = [str(cunit) for cunit in ww.wcs.cunit]
# new_cunit[0] = "mm" if key == "D" else "deg"
# new_cunit[1] = "mm" if key == "D" else "deg"
# ww.wcs.cunit = new_cunit
# Assuming linearity, a given world coordinate is determined by
# VAL = CRVAL + (PIX - CRPIX ) * CDELT (old system)
# = CRVAL' + (PIX' - CRPIX') * CDELT' (new system)
# CDELT is simply transformed by the zoom factor:
# CDELT' = CDELT / ZOOM
# The transformation keeps CRVAL' = CRVAL, hence
# CRPIX' = PIX' - (PIX - CRPIX) * ZOOM
# The relation between PIX' and PIX is linear
# PIX' = CONST + ZOOM * PIX
# The fix point is PIX = PIX' = 1/2, which is the lower/left edge of the field,
# thus PIX' = (1 - ZOOM)/2 + ZOOM * PIX
# This leads to
# CRPIX' = 1/2 + (CRPIX - 1/2) * ZOOM
#
# The transformation only applies to spatial coordinates, which we assume to be
# the first two in the WCS.
ww.wcs.cdelt[:2] /= zoom[:2]
ww.wcs.crpix[:2] = 0.5 + (ww.wcs.crpix[:2] - 0.5) * zoom[:2]
#ww.wcs.crpix[:2] = (zoom[:2] + 1) / 2 + (ww.wcs.crpix[:2] - 1) * zoom[:2]
logger.debug("new crpix %s", ww.wcs.crpix)

imagehdu.header.update(ww.to_header())

Expand Down
27 changes: 27 additions & 0 deletions scopesim/tests/mocks/py_objects/imagehdu_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,30 @@ def _image_hdu_three_wcs():
hdu.header.update(wcs_g.to_header())

return hdu

def _image_hdu_3d_data():
nx, ny = 100, 100
nz = 3

# a 3D WCS
the_wcs0 = wcs.WCS(naxis=3, key="")
the_wcs0.wcs.ctype = ["LINEAR", "LINEAR", "WAVE"]
the_wcs0.wcs.cunit = ["arcsec", "arcsec", "um"]
the_wcs0.wcs.cdelt = [1, 1, 0.1]
the_wcs0.wcs.crval = [0, 0, 2.2]
the_wcs0.wcs.crpix = [(nx + 1) / 2, (ny + 1) / 2, 1]

# a 2D WCS for spatial dimensions
the_wcsd = wcs.WCS(naxis=2, key="D")
the_wcsd.wcs.ctype = ["LINEAR", "LINEAR"]
the_wcsd.wcs.cunit = ["mm", "mm"]
the_wcsd.wcs.cdelt = [1, 1]
the_wcsd.wcs.crval = [0, 0]
the_wcsd.wcs.crpix = [(nx + 1) / 2, (ny + 1) / 2]

image = np.ones((nz, ny, nx))
hdr = the_wcs0.to_header()
hdr.extend(the_wcsd.to_header())
hdu = fits.ImageHDU(data=image, header=hdr)

return hdu
72 changes: 49 additions & 23 deletions scopesim/tests/tests_optics/test_ImagePlane.py
Original file line number Diff line number Diff line change
@@ -1,51 +1,64 @@
"""Tests for ImagePlane and some ImagePlaneUtils"""

# pylint: disable=missing-class-docstring
# pylint: disable=missing-function-docstring

from copy import deepcopy

import pytest
from pytest import approx
from copy import deepcopy

import numpy as np
from astropy.io import fits
from astropy import units as u
from astropy.table import Table
from astropy import wcs

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

import scopesim.optics.image_plane as opt_imp
import scopesim.optics.image_plane_utils as imp_utils

from scopesim.tests.mocks.py_objects.imagehdu_objects import \
_image_hdu_square, _image_hdu_rect, _image_hdu_three_wcs

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

_image_hdu_square, _image_hdu_rect, _image_hdu_three_wcs,\
_image_hdu_3d_data

PLOTS = False


@pytest.fixture(scope="function")
def image_hdu_rect():
@pytest.fixture(scope="function", name="image_hdu_rect")
def fixture_image_hdu_rect():
return _image_hdu_rect()


@pytest.fixture(scope="function")
def image_hdu_rect_mm():
@pytest.fixture(scope="function", name="image_hdu_rect_mm")
def fixture_image_hdu_rect_mm():
return _image_hdu_rect("D")


@pytest.fixture(scope="function")
def image_hdu_square():
@pytest.fixture(scope="function", name="image_hdu_square")
def fixture_image_hdu_square():
return _image_hdu_square()


@pytest.fixture(scope="function")
def image_hdu_square_mm():
@pytest.fixture(scope="function", name="image_hdu_square_mm")
def fixture_image_hdu_square_mm():
return _image_hdu_square("D")

@pytest.fixture(scope="function")
def image_hdu_three_wcs():

@pytest.fixture(scope="function", name="image_hdu_three_wcs")
def fixture_image_hdu_three_wcs():
return _image_hdu_three_wcs()

@pytest.fixture(scope="function")
def input_table():

@pytest.fixture(scope="function", name="image_hdu_3d_data")
def fixture_image_hdu_3d_data():
return _image_hdu_3d_data()


@pytest.fixture(scope="function", name="input_table")
def fixture_input_table():
x = [-10, -10, 0, 10, 10] * u.arcsec
y = [-10, 10, 0, -10, 10] * u.arcsec
f = [1, 3, 1, 1, 5]
Expand All @@ -54,8 +67,8 @@ def input_table():
return tbl


@pytest.fixture(scope="function")
def input_table_mm():
@pytest.fixture(scope="function", name="input_table_mm")
def fixture_input_table_mm():
x = [-10, -10, 0, 10, 10] * u.mm
y = [-10, 10, 0, -10, 10] * u.mm
f = [1, 3, 1, 1, 5]
Expand Down Expand Up @@ -312,7 +325,7 @@ def test_points_are_added_to_small_canvas(self, input_table):
assert np.sum(canvas_hdu.data) == np.sum(tbl1["flux"])

if PLOTS:
"top left is green, top right is yellow"
# "top left is green, top right is yellow"
plt.imshow(canvas_hdu.data, origin="lower")
plt.show()

Expand All @@ -328,7 +341,7 @@ def test_mm_points_are_added_to_small_canvas(self, input_table_mm):
assert np.sum(canvas_hdu.data) == np.sum(tbl1["flux"])

if PLOTS:
"top left is green, top right is yellow"
# "top left is green, top right is yellow"
plt.imshow(canvas_hdu.data, origin="lower")
plt.show()

Expand Down Expand Up @@ -387,7 +400,7 @@ def test_mm_points_are_added_to_massive_canvas(self, input_table_mm):
if PLOTS:
x, y = imp_utils.val2pix(hdr, 0, 0, "D")
plt.plot(x, y, "ro")
"top left is green, top right is yellow"
# "top left is green, top right is yellow"
plt.imshow(canvas_hdu.data, origin="lower")
plt.show()

Expand Down Expand Up @@ -701,6 +714,19 @@ def test_rescale_works_on_nondefault_wcs(self, image_hdu_three_wcs):
assert new_hdu.header['CDELT1D'] == 20


def test_rescale_works_on_3d_imageplane(self, image_hdu_3d_data):
pixel_scale = 0.274
wcses = wcs.find_all_wcs(image_hdu_3d_data.header)
fact = pixel_scale / wcses[0].wcs.cdelt[0]

new_hdu = imp_utils.rescale_imagehdu(image_hdu_3d_data, pixel_scale)
new_wcses = wcs.find_all_wcs(new_hdu.header)

assert new_wcses[0].wcs.cdelt[0] == pixel_scale
assert new_wcses[0].wcs.cdelt[2] == wcses[0].wcs.cdelt[2]
assert new_wcses[1].wcs.cdelt[1] / fact == approx(wcses[1].wcs.cdelt[1])


###############################################################################
# ..todo: When you have time, reintegrate these tests, There are some good ones

Expand Down
Loading