diff --git a/spectral_cube/lower_dimensional_structures.py b/spectral_cube/lower_dimensional_structures.py index e7a077919..92d2906bd 100644 --- a/spectral_cube/lower_dimensional_structures.py +++ b/spectral_cube/lower_dimensional_structures.py @@ -8,7 +8,7 @@ from astropy import convolution from astropy import units as u from astropy import wcs -#from astropy import log +from astropy import log from astropy.io.fits import Header, Card, HDUList, PrimaryHDU from .io.core import determine_format @@ -41,10 +41,21 @@ def header(self): for keyword in header: if 'NAXIS' in keyword: del header[keyword] - header.insert(2, Card(keyword='NAXIS', value=self.ndim)) + + header.insert(3, Card(keyword='NAXIS', value=self.ndim)) for ind,sh in enumerate(self.shape[::-1]): - header.insert(3+ind, Card(keyword='NAXIS{0:1d}'.format(ind+1), - value=sh)) + log.debug('Adding NAXIS{0} at position {1}'.format(ind+1, + 3+ind+1)) + header.insert(3+ind+1, Card(keyword='NAXIS{0:1d}'.format(ind+1), + value=sh)) + if self.wcs.naxis > self.ndim: + assert ind != 0 + for ii in range(self.wcs.naxis - self.ndim): + log.debug('Adding NAXIS{0} at position {1}'.format(ind+ii+2, + 3+ind+ii+2)) + header.insert(3+ind+ii+2, + Card(keyword='NAXIS{0:1d}'.format(ind+ii+2), + value=1)) if 'beam' in self.meta: header.update(self.meta['beam'].to_header_keywords()) @@ -210,9 +221,6 @@ def __new__(cls, value, unit=None, dtype=None, copy=True, wcs=None, if np.asarray(value).ndim != 2: raise ValueError("value should be a 2-d array") - if wcs is not None and wcs.wcs.naxis != 2: - raise ValueError("wcs should have two dimension") - self = u.Quantity.__new__(cls, value, unit=unit, dtype=dtype, copy=copy).view(cls) self._wcs = wcs @@ -466,9 +474,6 @@ def __new__(cls, value, unit=None, dtype=None, copy=True, wcs=None, if np.asarray(value).ndim != 1: raise ValueError("value should be a 1-d array") - if wcs is not None and wcs.wcs.naxis != 1: - raise ValueError("wcs should have two dimension") - self = u.Quantity.__new__(cls, value, unit=unit, dtype=dtype, copy=copy).view(cls) self._wcs = wcs @@ -524,16 +529,8 @@ def __repr__(self): @property def header(self): - header = self._header - # This inplace update is OK; it's not bad to overwrite WCS in this - # header - if self.wcs is not None: - header.update(self.wcs.to_header()) - header['BUNIT'] = self.unit.to_string(format='fits') - header.insert(2, Card(keyword='NAXIS', value=self.ndim)) - for ind,sh in enumerate(self.shape[::-1]): - header.insert(3+ind, Card(keyword='NAXIS{0:1d}'.format(ind+1), - value=sh)) + # use the generic LowerDimensionalObject header, but more + header = super(OneDSpectrum, self).header # Preserve the spectrum's spectral units if 'CUNIT1' in header and self._spectral_unit != u.Unit(header['CUNIT1']): diff --git a/spectral_cube/spectral_cube.py b/spectral_cube/spectral_cube.py index 7c5420a42..030dfac29 100644 --- a/spectral_cube/spectral_cube.py +++ b/spectral_cube/spectral_cube.py @@ -268,6 +268,8 @@ def apply_numpy_function(self, function, fill=np.nan, unit=None, check_endian=False, progressbar=False, + dropped_axis_slice_position='middle', + dropped_axis_cdelt='same', **kwargs): """ Apply a numpy function to the cube @@ -306,6 +308,15 @@ def apply_numpy_function(self, function, fill=np.nan, progressbar : bool Show a progressbar while iterating over the slices through the cube? + dropped_axis_slice_position : 'middle', 'start', 'end' + If an axis is being dropped, where should the WCS say the + projection is? It can be at the start, middle, or end of the + axis. + dropped_axis_cdelt : 'same', 'full_range', or value + If an axis is being dropped, what should the new CDELT be? For an + integral, for example, one might want the value to be the full + range. For a slice, it should stay the same. For something like + min or max, it might be zero. kwargs : dict Passed to the numpy function. @@ -375,7 +386,14 @@ def apply_numpy_function(self, function, fill=np.nan, return out else: - new_wcs = wcs_utils.drop_axis(self._wcs, np2wcs[axis]) + + new_wcs = wcs_utils.drop_axis_by_slicing(self._wcs, + self.shape, + axis, + dropped_axis_cdelt=dropped_axis_cdelt, + dropped_axis_slice_position=dropped_axis_slice_position, + ) + header = self._nowcs_header return Projection(out, copy=False, wcs=new_wcs, meta=meta, @@ -472,7 +490,10 @@ def mean(self, axis=None, how='cube'): projection=False) out = ttl / counts if projection: - new_wcs = wcs_utils.drop_axis(self._wcs, np2wcs[axis]) + new_wcs = wcs_utils.drop_axis_by_slicing(self.wcs, self.shape, + dropped_axis=axis, + dropped_axis_slice_position='middle', + dropped_axis_cdelt='full_range') meta = {'collapse_axis': axis} meta.update(self._meta) return Projection(out, copy=False, wcs=new_wcs, @@ -540,7 +561,10 @@ def std(self, axis=None, how='cube', ddof=0): out = (result/(counts-ddof))**0.5 if projection: - new_wcs = wcs_utils.drop_axis(self._wcs, np2wcs[axis]) + new_wcs = wcs_utils.drop_axis_by_slicing(self.wcs, self.shape, + dropped_axis=axis, + dropped_axis_slice_position='middle', + dropped_axis_cdelt='full_range') meta = {'collapse_axis': axis} meta.update(self._meta) return Projection(out, copy=False, wcs=new_wcs, @@ -712,7 +736,10 @@ def _cube_on_cube_operation(self, function, cube, equivalencies=[], **kwargs): return self._new_cube_with(data=data, unit=unit) def apply_function(self, function, axis=None, weights=None, unit=None, - projection=False, progressbar=False, **kwargs): + projection=False, progressbar=False, + dropped_axis_slice_position='middle', + dropped_axis_cdelt='same', + **kwargs): """ Apply a function to valid data along the specified axis or to the whole cube, optionally using a weight array that is the same shape (or at @@ -736,6 +763,15 @@ def apply_function(self, function, axis=None, weights=None, unit=None, progressbar : bool Show a progressbar while iterating over the slices/rays through the cube? + dropped_axis_slice_position : 'middle', 'start', 'end' + If an axis is being dropped, where should the WCS say the + projection is? It can be at the start, middle, or end of the + axis. + dropped_axis_cdelt : 'same', 'full_range', or value + If an axis is being dropped, what should the new CDELT be? For an + integral, for example, one might want the value to be the full + range. For a slice, it should stay the same. For something like + min or max, it might be zero. Returns ------- @@ -778,7 +814,10 @@ def apply_function(self, function, axis=None, weights=None, unit=None, pbu() if projection and axis in (0,1,2): - new_wcs = wcs_utils.drop_axis(self._wcs, np2wcs[axis]) + new_wcs = wcs_utils.drop_axis_by_slicing(self.wcs, self.shape, + dropped_axis=axis, + dropped_axis_slice_position=dropped_axis_slice_position, + dropped_axis_cdelt=dropped_axis_cdelt) meta = {'collapse_axis': axis} meta.update(self._meta) @@ -1023,7 +1062,10 @@ def __getitem__(self, view): ) # only one element, so drop an axis - newwcs = wcs_utils.drop_axis(self._wcs, intslices[0]) + newwcs = wcs_utils.drop_axis_by_slicing(self.wcs, self.shape, + dropped_axis=intslices[0], + dropped_axis_slice_position=view[2-intslices[0]], + dropped_axis_cdelt='same') header = self._nowcs_header if intslices[0] == 0: @@ -1373,7 +1415,10 @@ def moment(self, order=0, axis=0, how='auto'): if order == 1 and axis == 0: out += self.world[0, :, :][0] - new_wcs = wcs_utils.drop_axis(self._wcs, np2wcs[axis]) + new_wcs = wcs_utils.drop_axis_by_slicing(self.wcs, self.shape, + dropped_axis=axis, + dropped_axis_slice_position='middle', + dropped_axis_cdelt='full_range') meta = {'moment_order': order, 'moment_axis': axis, @@ -2664,7 +2709,10 @@ def __getitem__(self, view): meta=meta) # only one element, so drop an axis - newwcs = wcs_utils.drop_axis(self._wcs, intslices[0]) + newwcs = wcs_utils.drop_axis_by_slicing(self.wcs, self.shape, + dropped_axis=intslices[0], + dropped_axis_slice_position='middle', + dropped_axis_cdelt='full_range') header = self._nowcs_header # Slice objects know how to parse Beam objects stored in the diff --git a/spectral_cube/tests/test_projection.py b/spectral_cube/tests/test_projection.py index 648359f43..bae8cb73f 100644 --- a/spectral_cube/tests/test_projection.py +++ b/spectral_cube/tests/test_projection.py @@ -511,3 +511,10 @@ def test_1d_slice_round(): assert 'OneDSpectrum' in sp.round().__repr__() assert 'OneDSpectrum' in sp[1:-1].round().__repr__() + +def test_slice_wcs(): + + cube, data = cube_and_raw('255_delta.fits') + + mom0 = cube.moment0(axis=0) + assert mom0.header['NAXIS3'] == 1 diff --git a/spectral_cube/tests/test_visualization.py b/spectral_cube/tests/test_visualization.py index b13790af4..833bd69e9 100644 --- a/spectral_cube/tests/test_visualization.py +++ b/spectral_cube/tests/test_visualization.py @@ -38,7 +38,7 @@ def test_to_pvextractor(): @pytest.mark.skipif("not MATPLOTLIB_INSTALLED") -def test_projvis(): +def test_projvis_noaplpy(): cube, data = cube_and_raw('vda_Jybeam_lower.fits') diff --git a/spectral_cube/tests/test_wcs_utils.py b/spectral_cube/tests/test_wcs_utils.py index 241a560e2..43279a46a 100644 --- a/spectral_cube/tests/test_wcs_utils.py +++ b/spectral_cube/tests/test_wcs_utils.py @@ -2,6 +2,8 @@ from astropy.io import fits +import pytest + from ..wcs_utils import * from . import path @@ -128,3 +130,36 @@ def test_strip_wcs(): header2_stripped = strip_wcs_from_header(header2) assert header1_stripped == header2_stripped + +@pytest.mark.parametrize(('position', 'result'), + (('start', 0.), + ('middle', 5e-5), + ('end', 10e-5))) +def test_drop_by_slice(position, result): + + wcs = WCS(naxis=3) + wcs.wcs.crpix = [1., 1., 1.] + wcs.wcs.crval = [0., 0., 0.] + wcs.wcs.cdelt = [1e-5, 2e-5, 3e-5] + wcs.wcs.ctype = ['RA---TAN', 'DEC--TAN', 'FREQ'] + + newwcs = drop_axis_by_slicing(wcs, shape=[10,12,14], dropped_axis=0, + dropped_axis_slice_position=position) + + # drop-by-slicing moves axis to be last + np.testing.assert_almost_equal(newwcs.wcs.crval[2], result) + assert all(newwcs.wcs.cdelt == [2e-5,3e-5,1e-5]) + +def test_drop_by_slice_middle_fullrange(): + + wcs = WCS(naxis=3) + wcs.wcs.crpix = [1., 1., 1.] + wcs.wcs.crval = [0., 0., 0.] + wcs.wcs.cdelt = [1e-5, 1e-5, 1e-5] + wcs.wcs.ctype = ['RA---TAN', 'DEC--TAN', 'FREQ'] + + newwcs = drop_axis_by_slicing(wcs, shape=[10,12,14], dropped_axis=0, + dropped_axis_cdelt='full_range') + + np.testing.assert_almost_equal(newwcs.wcs.crval[2], 5e-5) + np.testing.assert_almost_equal(newwcs.wcs.cdelt[2], 10e-5) diff --git a/spectral_cube/wcs_utils.py b/spectral_cube/wcs_utils.py index 2b780119d..fd3f6fa2b 100644 --- a/spectral_cube/wcs_utils.py +++ b/spectral_cube/wcs_utils.py @@ -1,7 +1,7 @@ from __future__ import print_function, absolute_import, division import numpy as np -from astropy.wcs import WCS +from astropy.wcs import WCS, InconsistentAxisTypesError import warnings from astropy import units as u from astropy import log @@ -18,6 +18,53 @@ 'WAVELENG':'WAVE', } +class WCSWrapper(WCS): + """ + Wrapper of WCS to deal with some of the special cases we face within + spectral_cube + """ + @staticmethod + def from_wcs(otherwcs): + """ + Create a WCSWrapper class from another WCS object + """ + new_wcs = WCSWrapper() + + new_wcs.wcs = otherwcs.wcs + new_wcs.naxis = otherwcs.naxis + + assert new_wcs.wcs.naxis == otherwcs.wcs.naxis == otherwcs.naxis == new_wcs.naxis + + return new_wcs + + @property + def has_celestial(self): + if hasattr(self, '_has_celestial'): + return self._has_celestial + try: + return self.celestial.naxis == 2 + except InconsistentAxisTypesError: + return False + + @has_celestial.setter + def has_celestial(self, value): + if value is not False: + warnings.warn("_has_celestial is being set to {0}, " + "which may not be what you want." + .format(value)) + self._has_celestial = value + + def wcs_pix2world(self, pixels, reference): + if ((pixels.shape[1] < self.naxis and + hasattr(self, 'active_dimensions') and + len(self.active_dimensions) < self.naxis)): + pixels = np.asarray(pixels) + pixels = np.c_[pixels, np.zeros(pixels.shape[0])] + result = super(WCSWrapper, self).wcs_pix2world(pixels, reference) + return result[:, :len(self.active_dimensions) - self.naxis] + else: + return super(WCSWrapper, self).wcs_pix2world(pixels, reference) + def drop_axis(wcs, dropax): """ Drop the ax on axis dropax @@ -150,6 +197,9 @@ def reindex_wcs(wcs, inds): ps_cards.append((i, m, v)) outwcs.wcs.set_ps(ps_cards) + assert outwcs.naxis == len(inds) + assert outwcs.wcs.naxis == len(inds) + return outwcs @@ -420,3 +470,83 @@ def diagonal_wcs_to_cdelt(mywcs): del mywcs.wcs.cd mywcs.wcs.cdelt = cdelt return mywcs + +def drop_axis_by_slicing(mywcs, shape, dropped_axis, + dropped_axis_slice_position='middle', + dropped_axis_cdelt='same', + convert_misaligned_to_offset=True, + ): + """ + Parameters + ---------- + dropped_axis_slice_position : 'middle', 'start', 'end' + If an axis is being dropped, where should the WCS say the + projection is? It can be at the start, middle, or end of the + axis. + dropped_axis_cdelt : 'same', 'full_range', or value + If an axis is being dropped, what should the new CDELT be? For an + integral, for example, one might want the value to be the full + range. For a slice, it should stay the same. For something like + min or max, it might be zero. + convert_misaligned_to_offset : bool + If the axes are misaligned, it is not possible to "drop" an axis. + In this case, a generic "offset axis" will be returned. + """ + log.debug("Dropping axis by slicing with args: {0}, {1}, {2}, {3}, {4}" + .format(mywcs, shape, dropped_axis, dropped_axis_slice_position, + dropped_axis_cdelt)) + ndim = len(shape) + + if mywcs.get_axis_types()[dropped_axis]['coordinate_type'] == 'celestial': + dropping_celestial = True + else: + dropping_celestial = False + + if dropped_axis_slice_position == 'middle': + dropped_axis_slice_position = shape[dropped_axis]//2 + elif dropped_axis_slice_position == 'start': + dropped_axis_slice_position = 0 + elif dropped_axis_slice_position == 'end': + dropped_axis_slice_position = shape[dropped_axis] + + dropax_slice = slice(dropped_axis_slice_position, + dropped_axis_slice_position+1) + + view = [slice(None) if ax!=dropped_axis else dropax_slice + for ax in range(ndim)] + + crpix_new = [0 if ax!=dropped_axis else dropped_axis_slice_position + for ax in range(ndim)] + new_crval = mywcs.wcs_pix2world([crpix_new], 0)[0, dropped_axis] + + result = slice_wcs(mywcs, view, shape=shape) + + result.wcs.crval[dropped_axis] = new_crval + result.wcs.crpix[dropped_axis] = 1 + + if dropped_axis_cdelt == 'same': + dropped_axis_cdelt = mywcs.wcs.cdelt[dropped_axis] + elif dropped_axis_cdelt == 'full_range': + ref_pixels = np.array( + [(0,0) if ax!=dropped_axis else (0, shape[dropped_axis]) + for ax in range(ndim)]) + refvals = mywcs.wcs_pix2world(ref_pixels.T, 0)[:, dropped_axis] + dropped_axis_cdelt = refvals[1]-refvals[0] + + result.wcs.cdelt[dropped_axis] = dropped_axis_cdelt + + new_inds = np.array([ii for ii in range(ndim) if ii != dropped_axis] + + [dropped_axis]) + result = reindex_wcs(result, new_inds) + + if dropping_celestial: + new_result = WCSWrapper.from_wcs(result) + new_result.has_celestial = False + new_result.active_dimensions = list(range(ndim-1)) + assert new_result.naxis == new_result.wcs.naxis + result = new_result + + assert result.naxis == result.wcs.naxis + assert result.naxis == mywcs.naxis + + return result