Skip to content

Commit

Permalink
Merge pull request #118 from BaptistePellorceAstro/explode_along_axis
Browse files Browse the repository at this point in the history
Adding a new explode_along_axis function inside NDCube
  • Loading branch information
DanRyanIrish authored May 18, 2018
2 parents 2837f04 + 1d43c63 commit c72df57
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
35 changes: 35 additions & 0 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from astropy.utils.misc import InheritDocstrings

from ndcube import utils
from ndcube.ndcube_sequence import NDCubeSequence
from ndcube.utils.wcs import wcs_ivoa_mapping
from ndcube.mixins import NDCubeSlicingMixin, NDCubePlotMixin

Expand Down Expand Up @@ -526,6 +527,40 @@ def __repr__(self):
""".format(wcs=self.wcs.__repr__(), lengthNDCube=self.dimensions,
axis_type=self.world_axis_physical_types))

def explode_along_axis(self, axis):
"""
Separates slices of NDCubes along a given cube axis into a NDCubeSequence
of (N-1)DCubes.
Parameters
----------
axis : `int`
The axis along which the data is to be changed.
Returns
-------
result : `ndcube_sequence.NDCubeSequence`
"""
# If axis is -ve then calculate the axis from the length of the dimensions of one cube
if axis < 0:
axis = len(self.dimensions) + axis
# To store the resultant cube
result_cubes = []
# All slices are initially initialised as slice(None, None, None)
cube_slices = [slice(None, None, None)] * self.data.ndim
# Slicing the cube inside result_cube
for i in range(self.data.shape[axis]):
# Setting the slice value to the index so that the slices are done correctly.
cube_slices[axis] = i
# Set to None the metadata of sliced cubes.
item = tuple(cube_slices)
sliced_cube = self[item]
sliced_cube.meta = None
# Appending the sliced cubes in the result_cube list
result_cubes.append(sliced_cube)
# Creating a new NDCubeSequence with the result_cubes and common axis as axis
return NDCubeSequence(result_cubes, common_axis=axis, meta=self.meta)

class NDCube(NDCubeBase, NDCubePlotMixin, astropy.nddata.NDArithmeticMixin):
pass
Expand Down
2 changes: 0 additions & 2 deletions ndcube/ndcube_sequence.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import numpy as np

import astropy.units as u
import sunpy.map
from sunpy.map import MapCube

from ndcube import utils
from ndcube.mixins.sequence_plotting import NDCubeSequencePlotMixin
Expand Down
22 changes: 20 additions & 2 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
'''
Tests for NDCube
'''
from collections import namedtuple
from collections import OrderedDict
import datetime

import pytest
import sunpy.map
import numpy as np
import astropy.units as u

from ndcube import NDCube, NDCubeOrdered
from ndcube.utils.wcs import WCS, _wcs_slicer
from ndcube.tests import helpers
from ndcube.ndcube_sequence import NDCubeSequence

# sample data for tests
# TODO: use a fixture reading from a test file. file TBD.
Expand Down Expand Up @@ -910,3 +910,21 @@ def test_axis_world_coords_without_input(test_input, expected):
for i in range(len(all_coords)):
np.testing.assert_allclose(all_coords[i].value, expected[i].value)
assert all_coords[i].unit == expected[i].unit


@pytest.mark.parametrize("test_input,expected", [
((cubem, 0, 0), ((2*u.pix, 3*u.pix, 4*u.pix), NDCubeSequence, dict, NDCube, OrderedDict)),
((cubem, 1, 0), ((3*u.pix, 2*u.pix, 4*u.pix), NDCubeSequence, dict, NDCube, OrderedDict)),
((cubem, -2, 0), ((3*u.pix, 2*u.pix, 4*u.pix), NDCubeSequence, dict, NDCube, OrderedDict))
])
def test_explode_along_axis(test_input, expected):
inp_cube, inp_axis, inp_slice = test_input
exp_dimensions, exp_type_seq, exp_meta_seq, exp_type_cube, exp_meta_cube = expected
output = inp_cube.explode_along_axis(inp_axis)
assert tuple(output.dimensions) == tuple(exp_dimensions)
assert any(output[inp_slice].dimensions == \
u.Quantity((exp_dimensions[1], exp_dimensions[2]), unit='pix'))
assert isinstance(output, exp_type_seq)
assert isinstance(output[inp_slice], exp_type_cube)
assert isinstance(output.meta, exp_meta_seq)
assert isinstance(output[inp_slice].meta, exp_meta_cube)

0 comments on commit c72df57

Please sign in to comment.