diff --git a/ndcube/mixins/plotting.py b/ndcube/mixins/plotting.py index e75d0879b..2470eed1f 100644 --- a/ndcube/mixins/plotting.py +++ b/ndcube/mixins/plotting.py @@ -1,3 +1,5 @@ +from warnings import warn + import numpy as np import matplotlib.pyplot as plt @@ -13,8 +15,8 @@ class NDCubePlotMixin: Add plotting functionality to a NDCube class. """ - def plot(self, axes=None, image_axes=[-1, -2], unit_x_axis=None, unit_y_axis=None, - axis_ranges=None, unit=None, origin=0, **kwargs): + def plot(self, axes=None, plot_axis_indices=[-1, -2], axes_coordinates=None, + axes_units=None, data_unit=None, origin=0, **kwargs): """ Plots an interactive visualization of this cube with a slider controlling the wavelength axis for data having dimensions greater than 2. @@ -25,7 +27,7 @@ def plot(self, axes=None, image_axes=[-1, -2], unit_x_axis=None, unit_y_axis=Non Parameters ---------- - image_axes: `list` + plot_axis_indices: `list` The two axes that make the image. Like [-1,-2] this implies cube instance -1 dimension will be x-axis and -2 dimension will be y-axis. @@ -33,16 +35,12 @@ def plot(self, axes=None, image_axes=[-1, -2], unit_x_axis=None, unit_y_axis=Non axes: `astropy.visualization.wcsaxes.core.WCSAxes` or None: The axes to plot onto. If None the current axes will be used. - unit_x_axis: `astropy.units.Unit` - The unit of x axis for 2D plots. - - unit_y_axis: `astropy.units.Unit` - The unit of y axis for 2D plots. + axes_unit: `list` of `astropy.units.Unit` - unit: `astropy.unit.Unit` + data_unit: `astropy.unit.Unit` The data is changed to the unit given or the cube.unit if not given, for 1D plots. - axis_ranges: list of physical coordinates for array or None + axes_coordinates: list of physical coordinates for array or None If None array indices will be used for all axes. If a list it should contain one element for each axis of the numpy array. For the image axes a [min, max] pair should be specified which will be @@ -52,56 +50,43 @@ def plot(self, axes=None, image_axes=[-1, -2], unit_x_axis=None, unit_y_axis=Non If None is specified for an axis then the array indices will be used for that axis. """ + # If old API is used, convert to new API. + plot_axis_indices, axes_coordiantes, axes_units, data_unit, kwargs = _support_101_plot_API( + plot_axis_indices, axes_coordinates, axes_units, data_unit, kwargs) axis_data = ['x' for i in range(2)] - axis_data[image_axes[1]] = 'y' - if self.data.ndim >= 3: - plot = self._plot_3D_cube(image_axes=image_axes, unit_x_axis=unit_x_axis, - unit_y_axis=unit_y_axis, axis_ranges=axis_ranges, **kwargs) + axis_data[plot_axis_indices[1]] = 'y' + if self.data.ndim is 1: + plot = self._plot_1D_cube(data_unit=data_unit, origin=origin) elif self.data.ndim is 2: - plot = self._plot_2D_cube(axes=axes, image_axes=axis_data[::-1], **kwargs) - elif self.data.ndim is 1: - plot = self._plot_1D_cube(unit=unit, origin=origin) + plot = self._plot_2D_cube(axes=axes, plot_axis_indices=axis_data[::-1], **kwargs) + else: + plot = self._plot_3D_cube(plot_axis_indices=plot_axis_indices, + axes_coordinates=axes_coordinates, axes_units=axes_units, + **kwargs) return plot - def _plot_3D_cube(self, image_axes=None, unit_x_axis=None, unit_y_axis=None, - axis_ranges=None, **kwargs): + def _plot_1D_cube(self, data_unit=None, origin=0): """ - Plots an interactive visualization of this cube using sliders to move through axes - plot using in the image. - Parameters other than data and wcs are passed to ImageAnimatorWCS, which in turn - passes them to imshow. + Plots a graph. + Keyword arguments are passed on to matplotlib. Parameters ---------- - image_axes: `list` - The two axes that make the image. - Like [-1,-2] this implies cube instance -1 dimension - will be x-axis and -2 dimension will be y-axis. - - unit_x_axis: `astropy.units.Unit` - The unit of x axis. - - unit_y_axis: `astropy.units.Unit` - The unit of y axis. - - axis_ranges: `list` of physical coordinates for array or None - If None array indices will be used for all axes. - If a list it should contain one element for each axis of the numpy array. - For the image axes a [min, max] pair should be specified which will be - passed to :func:`matplotlib.pyplot.imshow` as extent. - For the slider axes a [min, max] pair can be specified or an array the - same length as the axis which will provide all values for that slider. - If None is specified for an axis then the array indices will be used - for that axis. + data_unit: `astropy.unit.Unit` + The data is changed to the unit given or the cube.unit if not given. """ - if not image_axes: - image_axes = [-1, -2] - i = ImageAnimatorWCS(self.data, wcs=self.wcs, image_axes=image_axes, - unit_x_axis=unit_x_axis, unit_y_axis=unit_y_axis, - axis_ranges=axis_ranges, **kwargs) - return i + index_not_one = [] + for i, _bool in enumerate(self.missing_axis): + if not _bool: + index_not_one.append(i) + if data_unit is None: + data_unit = self.wcs.wcs.cunit[index_not_one[0]] + plot = plt.plot(self.pixel_to_world(*[u.Quantity(np.arange(self.data.shape[0]), + unit=u.pix)])[0].to(data_unit), + self.data) + return plot - def _plot_2D_cube(self, axes=None, image_axes=None, **kwargs): + def _plot_2D_cube(self, axes=None, plot_axis_indices=None, **kwargs): """ Plots a 2D image onto the current axes. Keyword arguments are passed on to matplotlib. @@ -111,13 +96,13 @@ def _plot_2D_cube(self, axes=None, image_axes=None, **kwargs): axes: `astropy.visualization.wcsaxes.core.WCSAxes` or `None`: The axes to plot onto. If None the current axes will be used. - image_axes: `list`. - The first axis in WCS object will become the first axis of image_axes and - second axis in WCS object will become the second axis of image_axes. + plot_axis_indices: `list`. + The first axis in WCS object will become the first axis of plot_axis_indices and + second axis in WCS object will become the second axis of plot_axis_indices. Default: ['x', 'y'] """ - if not image_axes: - image_axes = ['x', 'y'] + if not plot_axis_indices: + plot_axis_indices = ['x', 'y'] if axes is None: if self.wcs.naxis is not 2: missing_axis = self.missing_axis @@ -125,7 +110,7 @@ def _plot_2D_cube(self, axes=None, image_axes=None, **kwargs): index = 0 for i, bool_ in enumerate(missing_axis): if not bool_: - slice_list.append(image_axes[index]) + slice_list.append(plot_axis_indices[index]) index += 1 else: slice_list.append(1) @@ -135,23 +120,101 @@ def _plot_2D_cube(self, axes=None, image_axes=None, **kwargs): plot = axes.imshow(self.data, **kwargs) return plot - def _plot_1D_cube(self, unit=None, origin=0): + def _plot_3D_cube(self, plot_axis_indices=None, axes_units=None, + axes_coordinates=None, **kwargs): """ - Plots a graph. - Keyword arguments are passed on to matplotlib. + Plots an interactive visualization of this cube using sliders to move through axes + plot using in the image. + Parameters other than data and wcs are passed to ImageAnimatorWCS, which in turn + passes them to imshow. Parameters ---------- - unit: `astropy.unit.Unit` - The data is changed to the unit given or the cube.unit if not given. + plot_axis_indices: `list` + The two axes that make the image. + Like [-1,-2] this implies cube instance -1 dimension + will be x-axis and -2 dimension will be y-axis. + + axes_unit: `list` of `astropy.units.Unit` + + axes_coordinates: `list` of physical coordinates for array or None + If None array indices will be used for all axes. + If a list it should contain one element for each axis of the numpy array. + For the image axes a [min, max] pair should be specified which will be + passed to :func:`matplotlib.pyplot.imshow` as extent. + For the slider axes a [min, max] pair can be specified or an array the + same length as the axis which will provide all values for that slider. + If None is specified for an axis then the array indices will be used + for that axis. """ - index_not_one = [] - for i, _bool in enumerate(self.missing_axis): - if not _bool: - index_not_one.append(i) - if unit is None: - unit = self.wcs.wcs.cunit[index_not_one[0]] - plot = plt.plot(self.pixel_to_world(*[u.Quantity(np.arange(self.data.shape[0]), - unit=u.pix)])[0].to(unit), - self.data) - return plot + if plot_axis_indices is None: + plot_axis_indices = [-1, -2] + if axes_units is None: + axes_units = [None, None] + i = ImageAnimatorWCS(self.data, wcs=self.wcs, image_axes=plot_axis_indices, + unit_x_axis=axes_units[0], unit_y_axis=axes_units[1], + axis_ranges=axes_coordinates, **kwargs) + return i + + +def _support_101_plot_API(plot_axis_indices, axes_coordinates, axes_units, data_unit, kwargs): + """Check if user has used old API and convert it to new API.""" + # Get old API variable values. + image_axes = kwargs.pop("image_axes", None) + axis_ranges = kwargs.pop("axis_ranges", None) + unit_x_axis = kwargs.pop("unit_x_axis", None) + unit_y_axis = kwargs.pop("unit_y_axis", None) + unit = kwargs.pop("unit", None) + # Check if conflicting new and old API values have been set. + # If not, set new API using old API and raise deprecation warning. + if image_axes is not None: + variable_names = ("image_axes", "plot_axis_indices") + _raise_101_API_deprecation_warning(*variable_names) + if plot_axis_indices is None: + plot_axis_indices = image_axes + else: + _raise_API_error(*variable_names) + if axis_ranges is not None: + variable_names = ("axis_ranges", "axes_coordinates") + _raise_101_API_deprecation_warning(*variable_names) + if axes_coordinates is None: + axes_coordinates = axis_ranges + else: + _raise_API_error(*variable_names) + if (unit_x_axis is not None or unit_y_axis is not None) and axes_units is not None: + _raise_API_error("unit_x_axis and/or unit_y_axis", "axes_units") + if axes_units is None: + variable_names = ("unit_x_axis and unit_y_axis", "axes_units") + if unit_x_axis is not None: + _raise_101_API_deprecation_warning(*variable_names) + if len(plot_axis_indices) == 1: + axes_units = unit_x_axis + elif len(plot_axis_indices) == 2: + if unit_y_axis is None: + axes_units = [unit_x_axis, None] + else: + axes_units = [unit_x_axis, unit_y_axis] + else: + raise ValueError("Length of image_axes must be less than 3.") + else: + if unit_y_axis is not None: + _raise_101_API_deprecation_warning(*variable_names) + axes_units = [None, unit_y_axis] + if unit is not None: + variable_names = ("unit", "data_unit") + _raise_101_API_deprecation_warning(*variable_names) + if data_unit is None: + data_unit = unit + else: + _raise_API_error(*variable_names) + # Return values of new API + return plot_axis_indices, axes_coordinates, axes_units, data_unit, kwargs + + +def _raise_API_error(old_name, new_name): + raise ValueError( + "Conflicting inputs: {0} (old API) cannot be set if {1} (new API) is set".format( + old_name, new_name)) + +def _raise_101_API_deprecation_warning(old_name, new_name): + warn("{0} is deprecated and will not be supported in version 2.0. It will be replaced by {1}. See docstring.".format(old_name, new_name), DeprecationWarning) diff --git a/ndcube/tests/test_plotting.py b/ndcube/tests/test_plotting.py new file mode 100644 index 000000000..1d44888d4 --- /dev/null +++ b/ndcube/tests/test_plotting.py @@ -0,0 +1,177 @@ +# -*- coding: utf-8 -*- +import pytest +import datetime + +import numpy as np +import astropy.units as u +import matplotlib +import sunpy.visualization.imageanimator + +from ndcube import NDCube +from ndcube.utils.wcs import WCS +from ndcube.mixins import plotting + + +# sample data for tests +# TODO: use a fixture reading from a test file. file TBD. +ht = {'CTYPE3': 'HPLT-TAN', 'CUNIT3': 'deg', 'CDELT3': 0.5, 'CRPIX3': 0, 'CRVAL3': 0, 'NAXIS3': 2, + 'CTYPE2': 'WAVE ', 'CUNIT2': 'Angstrom', 'CDELT2': 0.2, 'CRPIX2': 0, 'CRVAL2': 0, + 'NAXIS2': 3, + 'CTYPE1': 'TIME ', 'CUNIT1': 'min', 'CDELT1': 0.4, 'CRPIX1': 0, 'CRVAL1': 0, 'NAXIS1': 4} +wt = WCS(header=ht, naxis=3) + +hm = {'CTYPE1': 'WAVE ', 'CUNIT1': 'Angstrom', 'CDELT1': 0.2, 'CRPIX1': 0, 'CRVAL1': 10, + 'NAXIS1': 4, + 'CTYPE2': 'HPLT-TAN', 'CUNIT2': 'deg', 'CDELT2': 0.5, 'CRPIX2': 2, 'CRVAL2': 0.5, + 'NAXIS2': 3, + 'CTYPE3': 'HPLN-TAN', 'CUNIT3': 'deg', 'CDELT3': 0.4, 'CRPIX3': 2, 'CRVAL3': 1, 'NAXIS3': 2} +wm = WCS(header=hm, naxis=3) + +data = np.array([[[1, 2, 3, 4], [2, 4, 5, 3], [0, -1, 2, 3]], + [[2, 4, 5, 1], [10, 5, 2, 2], [10, 3, 3, 0]]]) +uncertainty = np.sqrt(data) +mask_cube = data < 0 + +cube = NDCube( + data, + wt, + mask=mask_cube, + uncertainty=uncertainty, + missing_axis=[False, False, False, True], + extra_coords=[('time', 0, u.Quantity(range(data.shape[0]), unit=u.pix)), + ('hello', 1, u.Quantity(range(data.shape[1]), unit=u.pix)), + ('bye', 2, u.Quantity(range(data.shape[2]), unit=u.pix))]) + +cubem = NDCube( + data, + wm, + mask=mask_cube, + uncertainty=uncertainty, + extra_coords=[('time', 0, u.Quantity(range(data.shape[0]), unit=u.pix)), + ('hello', 1, u.Quantity(range(data.shape[1]), unit=u.pix)), + ('bye', 2, u.Quantity(range(data.shape[2]), unit=u.pix))]) + + +@pytest.mark.parametrize("test_input, test_kwargs, expected_values", [ + (cube[0, 0], {}, + (u.Quantity([0.4, 0.8, 1.2, 1.6], unit="min"), np.array([1, 2, 3, 4]), + "", "", (0.4, 1.6), (1, 4))) + ]) +def test_cube_plot_1D(test_input, test_kwargs, expected_values): + # Unpack expected properties. + expected_xdata, expected_ydata, expected_xlabel, expected_ylabel, \ + expected_xlim, expected_ylim = expected_values + # Run plot method. + output = test_input.plot(**test_kwargs) + # Check plot properties are correct. + assert type(output) is list + assert len(output) == 1 + output = output[0] + assert type(output) is matplotlib.lines.Line2D + output_xdata = (output.axes.lines[0].get_xdata()) + if type(expected_xdata) == u.Quantity: + assert output_xdata.unit == expected_xdata.unit + assert np.allclose(output_xdata.value, expected_xdata.value) + else: + np.testing.assert_array_equal(output.axes.lines[0].get_xdata(), expected_xdata) + if type(expected_ydata) == u.Quantity: + assert output_ydata.unit == expected_ydata.unit + assert np.allclose(output_ydata.value, expected_ydata.value) + else: + np.testing.assert_array_equal(output.axes.lines[0].get_ydata(), expected_ydata) + assert output.axes.get_xlabel() == expected_xlabel + assert output.axes.get_ylabel() == expected_ylabel + output_xlim = output.axes.get_xlim() + assert output_xlim[0] <= expected_xlim[0] + assert output_xlim[1] >= expected_xlim[1] + output_ylim = output.axes.get_ylim() + assert output_ylim[0] <= expected_ylim[0] + assert output_ylim[1] >= expected_ylim[1] + + +@pytest.mark.parametrize("test_input, test_kwargs, expected_values", [ + (cube[0], {}, + (cube[0].data, "", "", + (-0.5, 3.5, 2.5, -0.5))) + ]) +def test_cube_plot_2D(test_input, test_kwargs, expected_values): + # Unpack expected properties. + expected_data, expected_xlabel, expected_ylabel, expected_extent = \ + expected_values + # Run plot method. + output = test_input.plot(**test_kwargs) + # Check plot properties are correct. + assert type(output) is matplotlib.image.AxesImage + np.testing.assert_array_equal(output.get_array(), expected_data) + assert output.axes.xaxis.get_label_text() == expected_xlabel + assert output.axes.yaxis.get_label_text() == expected_ylabel + assert np.allclose(output.get_extent(), expected_extent) + + +@pytest.mark.parametrize("test_input, test_kwargs, expected_values", [ + (cubem, {}, + (cubem.data, [np.array([0., 2.]), [0, 3], [0, 4]], "", "")) + ]) +def test_cube_animate_ND(test_input, test_kwargs, expected_values): + # Unpack expected properties. + expected_data, expected_axis_ranges, expected_xlabel, expected_ylabel = expected_values + # Run plot method. + output = test_input.plot(**test_kwargs) + # Check plot properties are correct. + assert type(output) is sunpy.visualization.imageanimator.ImageAnimatorWCS + np.testing.assert_array_equal(output.data, expected_data) + assert output.axes.xaxis.get_label_text() == expected_xlabel + assert output.axes.yaxis.get_label_text() == expected_ylabel + + +def test_cube_plot_ND_as_2DAnimation(): + pass + + +@pytest.mark.parametrize("input_values, expected_values", [ + ((None, None, None, None, {"image_axes": [-1, -2], + "axis_ranges": [np.arange(3), np.arange(3)], + "unit_x_axis": "km", + "unit_y_axis": u.s, + "unit": u.W}), + ([-1, -2], [np.arange(3), np.arange(3)], ["km", u.s], u.W, {})), + (([-1, -2], [np.arange(3), np.arange(3)], ["km", u.s], u.W, {}), + ([-1, -2], [np.arange(3), np.arange(3)], ["km", u.s], u.W, {})), + (([-1], None, None, None, {"unit_x_axis": "km"}), + ([-1], None, "km", None, {})), + (([-1, -2], None, None, None, {"unit_x_axis": "km"}), + (([-1, -2], None, ["km", None], None, {}))), + (([-1, -2], None, None, None, {"unit_y_axis": "km"}), + (([-1, -2], None, [None, "km"], None, {}))) + ]) +def test_support_101_plot_API(input_values, expected_values): + # Define expected values. + expected_plot_axis_indices, expected_axes_coordinates, expected_axes_units, \ + expected_data_unit, expected_kwargs = expected_values + # Run function + output_plot_axis_indices, output_axes_coordinates, output_axes_units, \ + output_data_unit, output_kwargs = plotting._support_101_plot_API(*input_values) + # Check values are correct + assert output_plot_axis_indices == expected_plot_axis_indices + if expected_axes_coordinates is None: + assert output_axes_coordinates == expected_axes_coordinates + elif type(expected_axes_coordinates) is list: + for i, ac in enumerate(output_axes_coordinates): + np.testing.assert_array_equal(ac, expected_axes_coordinates[i]) + assert output_axes_units == expected_axes_units + assert output_data_unit == expected_data_unit + assert output_kwargs == expected_kwargs + + +@pytest.mark.parametrize("input_values", [ + ([0, 1], None, None, None, {"image_axes": [-1, -2]}), + (None, [np.arange(1, 4), np.arange(1, 4)], None, None, + {"axis_ranges": [np.arange(3), np.arange(3)]}), + (None, None, [u.s, "km"], None, {"unit_x_axis": u.W}), + (None, None, [u.s, "km"], None, {"unit_y_axis": u.W}), + (None, None, None, u.s, {"unit": u.W}), + ([0, 1, 2], None, None, None, {"unit_x_axis": [u.s, u.km, u.W]}), + ]) +def test_support_101_plot_API_errors(input_values): + with pytest.raises(ValueError): + output = plotting._support_101_plot_API(*input_values)