diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 725452fba..65687cfd1 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -30,6 +30,8 @@ API Changes Bug Fixes --------- +- Allowed `~ndcube.NDCubeBase.axis_world_coords` to accept negative + axis indices as arguments. [#106] - Fixed bug in ``NDCube.crop_by_coords`` in case where real world coordinate system was rotated relative to pixel grid. [#113]. diff --git a/ndcube/mixins/plotting.py b/ndcube/mixins/plotting.py index 2470eed1f..44c739a4d 100644 --- a/ndcube/mixins/plotting.py +++ b/ndcube/mixins/plotting.py @@ -1,22 +1,31 @@ from warnings import warn +import copy +import datetime import numpy as np +import matplotlib as mpl import matplotlib.pyplot as plt - import astropy.units as u -from sunpy.visualization.imageanimator import ImageAnimatorWCS +from sunpy.visualization.imageanimator import ImageAnimator, ImageAnimatorWCS import sunpy.visualization.wcsaxes_compat as wcsaxes_compat +from ndcube import utils +from ndcube.mixins import sequence_plotting + __all__ = ['NDCubePlotMixin'] +INVALID_UNIT_SET_MESSAGE = "Can only set unit for axis if corresponding coordinates in " + \ + "axes_coordinates are set to None, an astropy Quantity or the name of an extra coord that " + \ + "is an astropy Quantity." + class NDCubePlotMixin: """ Add plotting functionality to a NDCube class. """ - def plot(self, axes=None, plot_axis_indices=[-1, -2], axes_coordinates=None, - axes_units=None, data_unit=None, origin=0, **kwargs): + def plot(self, axes=None, plot_axis_indices=None, axes_coordinates=None, + axes_units=None, data_unit=None, **kwargs): """ Plots an interactive visualization of this cube with a slider controlling the wavelength axis for data having dimensions greater than 2. @@ -29,7 +38,7 @@ def plot(self, axes=None, plot_axis_indices=[-1, -2], axes_coordinates=None, ---------- plot_axis_indices: `list` The two axes that make the image. - Like [-1,-2] this implies cube instance -1 dimension + Default=[-1,-2]. This implies cube instance -1 dimension will be x-axis and -2 dimension will be y-axis. axes: `astropy.visualization.wcsaxes.core.WCSAxes` or None: @@ -49,23 +58,33 @@ def plot(self, axes=None, plot_axis_indices=[-1, -2], axes_coordinates=None, same length as the axis which will provide all values for that slider. If None is specified for an axis then the array indices will be used for that axis. + """ # If old API is used, convert to new API. plot_axis_indices, axes_coordiantes, axes_units, data_unit, kwargs = _support_101_plot_API( plot_axis_indices, axes_coordinates, axes_units, data_unit, kwargs) - axis_data = ['x' for i in range(2)] - axis_data[plot_axis_indices[1]] = 'y' - if self.data.ndim is 1: - plot = self._plot_1D_cube(data_unit=data_unit, origin=origin) - elif self.data.ndim is 2: - plot = self._plot_2D_cube(axes=axes, plot_axis_indices=axis_data[::-1], **kwargs) + # Check kwargs are in consistent formats and set default values if not done so by user. + naxis = len(self.dimensions) + plot_axis_indices, axes_coordinates, axes_units = sequence_plotting._prep_axes_kwargs( + naxis, plot_axis_indices, axes_coordinates, axes_units) + if naxis is 1: + ax = self._plot_1D_cube(axes, axes_coordinates, + axes_units, data_unit, **kwargs) else: - plot = self._plot_3D_cube(plot_axis_indices=plot_axis_indices, - axes_coordinates=axes_coordinates, axes_units=axes_units, - **kwargs) - return plot + if len(plot_axis_indices) == 1: + raise NotImplementedError() + else: + if naxis == 2: + ax = self._plot_2D_cube(axes, plot_axis_indices, axes_coordinates, + axes_units, data_unit, **kwargs) + else: + ax = self._plot_3D_cube( + plot_axis_indices=plot_axis_indices, axes_coordinates=axes_coordinates, + axes_units=axes_units, **kwargs) + return ax - def _plot_1D_cube(self, data_unit=None, origin=0): + def _plot_1D_cube(self, axes=None, axes_coordinates=None, axes_units=None, data_unit=None, + **kwargs): """ Plots a graph. Keyword arguments are passed on to matplotlib. @@ -74,19 +93,73 @@ def _plot_1D_cube(self, data_unit=None, origin=0): ---------- data_unit: `astropy.unit.Unit` The data is changed to the unit given or the cube.unit if not given. + """ - index_not_one = [] - for i, _bool in enumerate(self.missing_axis): - if not _bool: - index_not_one.append(i) - if data_unit is None: - data_unit = self.wcs.wcs.cunit[index_not_one[0]] - plot = plt.plot(self.pixel_to_world(*[u.Quantity(np.arange(self.data.shape[0]), - unit=u.pix)])[0].to(data_unit), - self.data) - return plot + # Derive x-axis coordinates and unit from inputs. + x_axis_coordinates, unit_x_axis = sequence_plotting._derive_1D_coordinates_and_units( + axes_coordinates, axes_units) + if x_axis_coordinates is None: + # Default is to derive x coords and defaul xlabel from WCS object. + xname = self.world_axis_physical_types[0] + xdata = self.axis_world_coords() + elif isinstance(x_axis_coordinates, str): + # User has entered a str as x coords, get that extra coord. + xname = x_axis_coordinates + xdata = self.extra_coords[x_axis_coordinates]["value"] + else: + # Else user must have set the x-values manually. + xname = "" + xdata = x_axis_coordinates + # If a unit has been set for the x-axis, try to convert x coords to that unit. + if isinstance(xdata, u.Quantity): + if unit_x_axis is None: + unit_x_axis = xdata.unit + xdata = xdata.value + else: + xdata = xdata.to(unit_x_axis).value + else: + if unit_x_axis is not None: + raise TypeError(INVALID_UNIT_SET_MESSAGE) + # Define default x axis label. + default_xlabel = "{0} [{1}]".format(xname, unit_x_axis) + # Combine data and uncertainty with mask. + xdata = np.ma.masked_array(xdata, self.mask) + # Derive y-axis coordinates, uncertainty and unit from the NDCube's data. + if self.unit is None: + if data_unit is not None: + raise TypeError("Can only set y-axis unit if self.unit is set to a " + "compatible unit.") + else: + ydata = self.data + if self.uncertainty is None: + yerror = None + else: + yerror = self.uncertainty.array + else: + if data_unit is None: + data_unit = self.unit + ydata = self.data + if self.uncertainty is None: + yerror = None + else: + yerror = self.uncertainty.array + else: + ydata = (self.data * self.unit).to(data_unit).value + if self.uncertainty is None: + yerror = None + else: + yerror = (self.uncertainty.array * self.unit).to(data_unit).value + # Combine data and uncertainty with mask. + ydata = np.ma.masked_array(ydata, self.mask) + if yerror is not None: + yerror = np.ma.masked_array(yerror, self.mask) + # Create plot + fig, ax = sequence_plotting._make_1D_sequence_plot(xdata, ydata, yerror, + data_unit, default_xlabel, kwargs) + return ax - def _plot_2D_cube(self, axes=None, plot_axis_indices=None, **kwargs): + def _plot_2D_cube(self, axes=None, plot_axis_indices=None, axes_coordinates=None, + axes_units=None, data_unit=None, **kwargs): """ Plots a 2D image onto the current axes. Keyword arguments are passed on to matplotlib. @@ -100,28 +173,93 @@ def _plot_2D_cube(self, axes=None, plot_axis_indices=None, **kwargs): The first axis in WCS object will become the first axis of plot_axis_indices and second axis in WCS object will become the second axis of plot_axis_indices. Default: ['x', 'y'] + """ - if not plot_axis_indices: - plot_axis_indices = ['x', 'y'] + # Set default values of kwargs if not set. + if axes_coordinates is None: + axes_coordinates = [None, None] + if axes_units is None: + axes_units = [None, None] + # Set which cube dimensions are on the x an y axes. + axis_data = ['x', 'x'] + axis_data[plot_axis_indices[1]] = 'y' + axis_data = axis_data[::-1] + # Determine data to be plotted + if data_unit is None: + data = self.data + else: + # If user set data_unit, convert dat to desired unit if self.unit set. + if self.unit is None: + raise TypeError("Can only set data_unit if NDCube.unit is set.") + else: + data = (self.data * self.unit).to(data_unit).value + # Combine data with mask + data = np.ma.masked_array(data, self.mask) if axes is None: - if self.wcs.naxis is not 2: - missing_axis = self.missing_axis - slice_list = [] - index = 0 - for i, bool_ in enumerate(missing_axis): - if not bool_: - slice_list.append(plot_axis_indices[index]) - index += 1 - else: - slice_list.append(1) - if index is not 2: - raise ValueError("Dimensions of WCS and data don't match") - axes = wcsaxes_compat.gca_wcs(self.wcs, slices=slice_list) - plot = axes.imshow(self.data, **kwargs) - return plot - - def _plot_3D_cube(self, plot_axis_indices=None, axes_units=None, - axes_coordinates=None, **kwargs): + try: + axes_coord_check == [None, None] + except: + axes_coord_check = False + if axes_coord_check: + # Build slice list for WCS for initializing WCSAxes object. + if self.wcs.naxis is not 2: + slice_list = [] + index = 0 + for i, bool_ in enumerate(self.missing_axis): + if not bool_: + slice_list.append(axis_data[index]) + index += 1 + else: + slice_list.append(1) + if index is not 2: + raise ValueError("Dimensions of WCS and data don't match") + ax = wcsaxes_compat.gca_wcs(self.wcs, slices=slice_list) + # Set axis labels + x_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[0], + self.missing_axis) + ax.set_xlabel("{0} [{1}]".format( + self.world_axis_physical_types[plot_axis_indices[0]], + self.wcs.wcs.cunit[x_wcs_axis])) + y_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[1], + self.missing_axis) + ax.set_ylabel("{0} [{1}]".format( + self.world_axis_physical_types[plot_axis_indices[1]], + self.wcs.wcs.cunit[y_wcs_axis])) + # Plot data + ax.imshow(data, **kwargs) + else: + # Else manually set axes x and y values based on user's input for axes_coordinates. + new_axes_coordinates, new_axis_units, default_labels = \ + self._derive_axes_coordinates(axes_coordinates, axes_units) + # Initialize axes object and set values along axis. + fig, ax = plt.subplots(1, 1) + # Since we can't assume the x-axis will be uniform, create NonUniformImage + # axes and add it to the axes object. + if plot_axis_indices[0] < plot_axis_indices[1]: + data = data.transpose() + im_ax = mpl.image.NonUniformImage( + ax, extent=(new_axes_coordinates[plot_axis_indices[0]][0], + new_axes_coordinates[plot_axis_indices[0]][-1], + new_axes_coordinates[plot_axis_indices[1]][0], + new_axes_coordinates[plot_axis_indices[1]][-1]), **kwargs) + im_ax.set_data(new_axes_coordinates[plot_axis_indices[0]], + new_axes_coordinates[plot_axis_indices[1]], data) + ax.add_image(im_ax) + # Set the limits, labels, etc. of the axes. + xlim = kwargs.pop("xlim", (new_axes_coordinates[plot_axis_indices[0]][0], + new_axes_coordinates[plot_axis_indices[0]][-1])) + ax.set_xlim(xlim) + ylim = kwargs.pop("xlim", (new_axes_coordinates[plot_axis_indices[1]][0], + new_axes_coordinates[plot_axis_indices[1]][-1])) + ax.set_ylim(ylim) + xlabel = kwargs.pop("xlabel", default_labels[plot_axis_indices[0]]) + ylabel = kwargs.pop("ylabel", default_labels[plot_axis_indices[1]]) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + return ax + + def _plot_3D_cube(self, plot_axis_indices=None, axes_coordinates=None, + axes_units=None, data_unit=None, **kwargs): """ Plots an interactive visualization of this cube using sliders to move through axes plot using in the image. @@ -146,15 +284,110 @@ def _plot_3D_cube(self, plot_axis_indices=None, axes_units=None, same length as the axis which will provide all values for that slider. If None is specified for an axis then the array indices will be used for that axis. + """ - if plot_axis_indices is None: - plot_axis_indices = [-1, -2] + # For convenience in inserting dummy variables later, ensure + # plot_axis_indices are all positive. + plot_axis_indices = [i if i >= 0 else self.data.ndim + i for i in plot_axis_indices] + # If axes kwargs not set by user, set them as list of Nones for + # each axis for consistent behaviour. + if axes_coordinates is None: + axes_coordinates = [None] * self.data.ndim if axes_units is None: - axes_units = [None, None] - i = ImageAnimatorWCS(self.data, wcs=self.wcs, image_axes=plot_axis_indices, - unit_x_axis=axes_units[0], unit_y_axis=axes_units[1], - axis_ranges=axes_coordinates, **kwargs) - return i + axes_units = [None] * self.data.ndim + # If data_unit set, convert data to that unit + if data_unit is None: + data = self.data + else: + data = (self.data * self.unit).to(data_unit).value + # Combine data values with mask. + data = np.ma.masked_array(data, self.mask) + # If axes_coordinates not provided generate an ImageAnimatorWCS plot + # using NDCube's wcs object. + if (axes_coordinates[plot_axis_indices[0]] is None and + axes_coordinates[plot_axis_indices[1]] is None): + # If there are missing axes in WCS object, add corresponding dummy axes to data. + if data.ndim < self.wcs.naxis: + new_shape = list(data.shape) + for i in np.arange(self.wcs.naxis)[self.missing_axis[::-1]]: + new_shape.insert(i, 1) + # Also insert dummy coordinates and units. + axes_coordinates.insert(i, None) + axes_units.insert(i, None) + # Iterate plot_axis_indices if neccessary + for j, pai in enumerate(plot_axis_indices): + if pai >= i: + plot_axis_indices[j] = plot_axis_indices[j] + 1 + # Reshape data + data = data.reshape(new_shape) + # Generate plot + ax = ImageAnimatorWCS(data, wcs=self.wcs, image_axes=plot_axis_indices, + unit_x_axis=axes_units[plot_axis_indices[0]], + unit_y_axis=axes_units[plot_axis_indices[1]], + axis_ranges=axes_coordinates, **kwargs) + # If one of the plot axes is set manually, produce a basic ImageAnimator object. + else: + new_axes_coordinates, new_axes_units, default_labels = \ + self._derive_axes_coordinates(axes_coordinates, axes_units) + # If axis labels not set by user add to kwargs. + ax = ImageAnimator(data, image_axes=plot_axis_indices, + axis_ranges=new_axes_coordinates, **kwargs) + return ax + + def _derive_axes_coordinates(self, axes_coordinates, axes_units): + new_axes_coordinates = [] + new_axes_units = [] + default_labels = [] + default_label_text = "" + for i, axis_coordinate in enumerate(axes_coordinates): + # If axis coordinate is None, derive axis values from WCS. + if axis_coordinate is None: + # N.B. This assumes axes are independent. Fix this before merging!!! + new_axis_coordinate = self.axis_world_coords(i) + axis_label_text = self.world_axis_physical_types[i] + elif isinstance(axis_coordinate, str): + # If axis coordinate is a string, derive axis values from + # corresponding extra coord. + new_axis_coordinate = self.extra_coords[axis_coordinate]["value"] + axis_label_text = axis_coordinate + else: + # Else user must have manually set the axis coordinates. + new_axis_coordinate = axis_coordinate + axis_label_text = default_label_text + # If axis coordinate is a Quantity, convert to unit supplied by user. + if isinstance(new_axis_coordinate, u.Quantity): + if axes_units[i] is None: + new_axis_unit = new_axis_coordinate.unit + new_axis_coordinate = new_axis_coordinate.value + else: + new_axis_unit = axes_units[i] + new_axis_coordinate = new_axis_coordinate.to(new_axis_unit).value + elif isinstance(new_axis_coordinate[0], datetime.datetime): + axis_label_text = "{0}/sec since {1}".format( + axis_label_text, new_axis_coordinate[0]) + new_axis_coordinate = np.array([(t-new_axis_coordinate[0]).total_seconds() + for t in new_axis_coordinate]) + new_axis_unit = u.s + else: + if axes_units[i] is None: + new_axis_unit = None + else: + raise TypeError(INVALID_UNIT_SET_MESSAGE) + # Derive default axis label + if type(new_axis_coordinate[0]) is datetime.datetime: + if axis_label_text == default_label_text: + default_label = "{0}".format(new_axis_coordinate[0].strftime("%Y/%m/%d %H:%M")) + else: + default_label = "{0} [{1}]".format( + axis_label_text, new_axis_coordinate[0].strftime("%Y/%m/%d %H:%M")) + else: + default_label = "{0} [{1}]".format(axis_label_text, new_axis_unit) + # Append new coordinates, units and labels to output list. + new_axes_coordinates.append(new_axis_coordinate) + new_axes_units.append(new_axis_unit) + default_labels.append(default_label) + return new_axes_coordinates, new_axes_units, default_labels + def _support_101_plot_API(plot_axis_indices, axes_coordinates, axes_units, data_unit, kwargs): diff --git a/ndcube/ndcube.py b/ndcube/ndcube.py index 1014898ee..fbec7ef76 100644 --- a/ndcube/ndcube.py +++ b/ndcube/ndcube.py @@ -336,7 +336,10 @@ def axis_world_coords(self, *axes): int_axes = np.empty(len(axes), dtype=int) for i, axis in enumerate(axes): if isinstance(axis, int): - int_axes[i] = axis + if axis < 0: + int_axes[i] = n_dimensions + axis + else: + int_axes[i] = axis elif isinstance(axis, str): int_axes[i] = utils.cube.get_axis_number_from_axis_name( axis, world_axis_types) diff --git a/ndcube/tests/test_plotting.py b/ndcube/tests/test_plotting.py index 1d44888d4..33ac5f0aa 100644 --- a/ndcube/tests/test_plotting.py +++ b/ndcube/tests/test_plotting.py @@ -38,24 +38,106 @@ mask=mask_cube, uncertainty=uncertainty, missing_axis=[False, False, False, True], - extra_coords=[('time', 0, u.Quantity(range(data.shape[0]), unit=u.pix)), - ('hello', 1, u.Quantity(range(data.shape[1]), unit=u.pix)), - ('bye', 2, u.Quantity(range(data.shape[2]), unit=u.pix))]) + extra_coords=[('time', 0, u.Quantity(range(data.shape[0]), unit=u.s)), + ('hello', 1, u.Quantity(range(data.shape[1]), unit=u.W)), + ('bye', 2, u.Quantity(range(data.shape[2]), unit=u.m)), + ('another time', 2, np.array( + [datetime.datetime(2000, 1, 1)+datetime.timedelta(minutes=i) + for i in range(data.shape[2])])), + ('array coord', 2, np.arange(100, 100+data.shape[2])) + ]) + +cube_unit = NDCube( + data, + wt, + mask=mask_cube, + unit=u.J, + uncertainty=uncertainty, + missing_axis=[False, False, False, True], + extra_coords=[('time', 0, u.Quantity(range(data.shape[0]), unit=u.s)), + ('hello', 1, u.Quantity(range(data.shape[1]), unit=u.W)), + ('bye', 2, u.Quantity(range(data.shape[2]), unit=u.m)), + ('another time', 2, np.array( + [datetime.datetime(2000, 1, 1)+datetime.timedelta(minutes=i) + for i in range(data.shape[2])])) + ]) + +cube_no_uncertainty = NDCube( + data, + wt, + mask=mask_cube, + missing_axis=[False, False, False, True], + extra_coords=[('time', 0, u.Quantity(range(data.shape[0]), unit=u.s)), + ('hello', 1, u.Quantity(range(data.shape[1]), unit=u.W)), + ('bye', 2, u.Quantity(range(data.shape[2]), unit=u.m)), + ('another time', 2, np.array( + [datetime.datetime(2000, 1, 1)+datetime.timedelta(minutes=i) + for i in range(data.shape[2])])) + ]) + +cube_unit_no_uncertainty = NDCube( + data, + wt, + mask=mask_cube, + unit=u.J, + missing_axis=[False, False, False, True], + extra_coords=[('time', 0, u.Quantity(range(data.shape[0]), unit=u.s)), + ('hello', 1, u.Quantity(range(data.shape[1]), unit=u.W)), + ('bye', 2, u.Quantity(range(data.shape[2]), unit=u.m)), + ('another time', 2, np.array( + [datetime.datetime(2000, 1, 1)+datetime.timedelta(minutes=i) + for i in range(data.shape[2])])) + ]) cubem = NDCube( data, wm, mask=mask_cube, uncertainty=uncertainty, - extra_coords=[('time', 0, u.Quantity(range(data.shape[0]), unit=u.pix)), - ('hello', 1, u.Quantity(range(data.shape[1]), unit=u.pix)), - ('bye', 2, u.Quantity(range(data.shape[2]), unit=u.pix))]) + extra_coords=[('time', 0, u.Quantity(range(data.shape[0]), unit=u.s)), + ('hello', 1, u.Quantity(range(data.shape[1]), unit=u.W)), + ('bye', 2, u.Quantity(range(data.shape[2]), unit=u.m)), + ('another time', 2, np.array( + [datetime.datetime(2000, 1, 1)+datetime.timedelta(minutes=i) + for i in range(data.shape[2])])) + ]) @pytest.mark.parametrize("test_input, test_kwargs, expected_values", [ (cube[0, 0], {}, - (u.Quantity([0.4, 0.8, 1.2, 1.6], unit="min"), np.array([1, 2, 3, 4]), - "", "", (0.4, 1.6), (1, 4))) + (np.ma.masked_array([0.4, 0.8, 1.2, 1.6], cube[0, 0].mask), + np.ma.masked_array(cube[0, 0].data, cube[0, 0].mask), + "time [min]", "Data [None]", (0.4, 1.6), (1, 4))), + + (cube_unit[0, 0], {"axes_coordinates": "bye", "axes_units": "km", "data_unit": u.erg}, + (np.ma.masked_array(cube_unit[0, 0].extra_coords["bye"]["value"].to(u.km).value, + cube_unit[0, 0].mask), + np.ma.masked_array(u.Quantity(cube_unit[0, 0].data, + unit=cube_unit[0, 0].unit).to(u.erg).value, + cube_unit[0, 0].mask), + "bye [km]", "Data [erg]", (0, 0.003), (10000000, 40000000))), + + (cube_unit[0, 0], {"axes_coordinates": np.arange(10, 10+cube_unit[0, 0].data.shape[0])}, + (np.ma.masked_array(np.arange(10, 10+cube_unit[0, 0].data.shape[0]), cube_unit[0, 0].mask), + np.ma.masked_array(cube_unit[0, 0].data, cube_unit[0, 0].mask), + " [None]", "Data [J]", (10, 10+cube_unit[0, 0].data.shape[0]-1), (1, 4))), + + (cube_no_uncertainty[0, 0], {}, + (np.ma.masked_array([0.4, 0.8, 1.2, 1.6], cube_no_uncertainty[0, 0].mask), + np.ma.masked_array(cube_no_uncertainty[0, 0].data, cube_no_uncertainty[0, 0].mask), + "time [min]", "Data [None]", (0.4, 1.6), (1, 4))), + + (cube_unit_no_uncertainty[0, 0], {}, + (np.ma.masked_array([0.4, 0.8, 1.2, 1.6], cube_unit_no_uncertainty[0, 0].mask), + np.ma.masked_array(cube_no_uncertainty[0, 0].data, cube_unit_no_uncertainty[0, 0].mask), + "time [min]", "Data [J]", (0.4, 1.6), (1, 4))), + + (cube_unit_no_uncertainty[0, 0], {"data_unit": u.erg}, + (np.ma.masked_array([0.4, 0.8, 1.2, 1.6], cube_unit_no_uncertainty[0, 0].mask), + np.ma.masked_array(u.Quantity(cube_unit[0, 0].data, + unit=cube_unit[0, 0].unit).to(u.erg).value, + cube_unit[0, 0].mask), + "time [min]", "Data [erg]", (0.4, 1.6), (10000000, 40000000))) ]) def test_cube_plot_1D(test_input, test_kwargs, expected_values): # Unpack expected properties. @@ -64,23 +146,26 @@ def test_cube_plot_1D(test_input, test_kwargs, expected_values): # Run plot method. output = test_input.plot(**test_kwargs) # Check plot properties are correct. - assert type(output) is list - assert len(output) == 1 - output = output[0] - assert type(output) is matplotlib.lines.Line2D + # Type + assert isinstance(output, matplotlib.axes.Axes) + # Check x axis data output_xdata = (output.axes.lines[0].get_xdata()) - if type(expected_xdata) == u.Quantity: - assert output_xdata.unit == expected_xdata.unit - assert np.allclose(output_xdata.value, expected_xdata.value) + assert np.allclose(output_xdata.data, expected_xdata.data) + if isinstance(output_xdata.mask, np.ndarray): + np.testing.assert_array_equal(output_xdata.mask, expected_xdata.mask) else: - np.testing.assert_array_equal(output.axes.lines[0].get_xdata(), expected_xdata) - if type(expected_ydata) == u.Quantity: - assert output_ydata.unit == expected_ydata.unit - assert np.allclose(output_ydata.value, expected_ydata.value) + assert output_xdata.mask == expected_xdata.mask + # Check y axis data + output_ydata = (output.axes.lines[0].get_ydata()) + assert np.allclose(output_ydata.data, expected_ydata.data) + if isinstance(output_ydata.mask, np.ndarray): + np.testing.assert_array_equal(output_ydata.mask, expected_ydata.mask) else: - np.testing.assert_array_equal(output.axes.lines[0].get_ydata(), expected_ydata) + assert output_ydata.mask == expected_ydata.mask + # Check axis labels assert output.axes.get_xlabel() == expected_xlabel assert output.axes.get_ylabel() == expected_ylabel + # Check axis limits output_xlim = output.axes.get_xlim() assert output_xlim[0] <= expected_xlim[0] assert output_xlim[1] >= expected_xlim[1] @@ -89,10 +174,39 @@ def test_cube_plot_1D(test_input, test_kwargs, expected_values): assert output_ylim[1] >= expected_ylim[1] +@pytest.mark.parametrize("test_input, test_kwargs, expected_error", [ + (cube[0, 0], {"axes_coordinates": np.arange(10, 10+cube_unit[0, 0].data.shape[0]), + "axes_units": u.C}, TypeError), + (cube[0, 0], {"data_unit": u.C}, TypeError) + ]) +def test_cube_plot_1D_errors(test_input, test_kwargs, expected_error): + with pytest.raises(expected_error): + output = test_input.plot(**test_kwargs) + + @pytest.mark.parametrize("test_input, test_kwargs, expected_values", [ (cube[0], {}, - (cube[0].data, "", "", - (-0.5, 3.5, 2.5, -0.5))) + (np.ma.masked_array(cube[0].data, cube[0].mask), "time [min]", "em.wl [m]", + (0.4, 1.6, 2e-11, 6e-11))), + + (cube[0], {"axes_coordinates": ["bye", None], "axes_units": [None, u.cm]}, + (np.ma.masked_array(cube[0].data, cube[0].mask), "bye [m]", "em.wl [cm]", + (0.0, 3.0, 2e-9, 6e-9))), + + (cube[0], {"axes_coordinates": [np.arange(10, 10+cube[0].data.shape[1]), + u.Quantity(np.arange(10, 10+cube[0].data.shape[0]), unit=u.m)], + "axes_units": [None, u.cm]}, + (np.ma.masked_array(cube[0].data, cube[0].mask), " [None]", " [cm]", (10, 13, 1000, 1200))), + + (cube[0], {"axes_coordinates": [np.arange(10, 10+cube[0].data.shape[1]), + u.Quantity(np.arange(10, 10+cube[0].data.shape[0]), unit=u.m)]}, + (np.ma.masked_array(cube[0].data, cube[0].mask), " [None]", " [m]", (10, 13, 10, 12))), + + (cube_unit[0], {"plot_axis_indices": [0, 1], "axes_coordinates": [None, "bye"], + "data_unit": u.erg}, + (np.ma.masked_array((cube_unit[0].data * cube_unit[0].unit).to(u.erg).value, + cube_unit[0].mask).transpose(), + "em.wl [m]", "bye [m]", (2e-11, 6e-11, 0.0, 3.0))) ]) def test_cube_plot_2D(test_input, test_kwargs, expected_values): # Unpack expected properties. @@ -101,18 +215,29 @@ def test_cube_plot_2D(test_input, test_kwargs, expected_values): # Run plot method. output = test_input.plot(**test_kwargs) # Check plot properties are correct. - assert type(output) is matplotlib.image.AxesImage - np.testing.assert_array_equal(output.get_array(), expected_data) + assert isinstance(output, matplotlib.axes.Axes) + np.testing.assert_array_equal(output.images[0].get_array(), expected_data) assert output.axes.xaxis.get_label_text() == expected_xlabel assert output.axes.yaxis.get_label_text() == expected_ylabel - assert np.allclose(output.get_extent(), expected_extent) + assert np.allclose(output.images[0].get_extent(), expected_extent) + + +@pytest.mark.parametrize("test_input, test_kwargs, expected_error", [ + (cube[0], {"axes_coordinates": ["array coord", None], "axes_units": [u.cm, None]}, TypeError), + (cube[0], {"axes_coordinates": [np.arange(10, 10+cube[0].data.shape[1]), None], + "axes_units": [u.cm, None]}, TypeError), + (cube[0], {"data_unit": u.cm}, TypeError) + ]) +def test_cube_plot_2D_errors(test_input, test_kwargs, expected_error): + with pytest.raises(expected_error): + output = test_input.plot(**test_kwargs) @pytest.mark.parametrize("test_input, test_kwargs, expected_values", [ (cubem, {}, (cubem.data, [np.array([0., 2.]), [0, 3], [0, 4]], "", "")) ]) -def test_cube_animate_ND(test_input, test_kwargs, expected_values): +def test_cube_plot_ND_as_2DAnimation(test_input, test_kwargs, expected_values): # Unpack expected properties. expected_data, expected_axis_ranges, expected_xlabel, expected_ylabel = expected_values # Run plot method. @@ -124,10 +249,6 @@ def test_cube_animate_ND(test_input, test_kwargs, expected_values): assert output.axes.yaxis.get_label_text() == expected_ylabel -def test_cube_plot_ND_as_2DAnimation(): - pass - - @pytest.mark.parametrize("input_values, expected_values", [ ((None, None, None, None, {"image_axes": [-1, -2], "axis_ranges": [np.arange(3), np.arange(3)],