diff --git a/changelog/204.bugfix.rst b/changelog/204.bugfix.rst new file mode 100644 index 000000000..e5d0cb746 --- /dev/null +++ b/changelog/204.bugfix.rst @@ -0,0 +1 @@ +Fix the ability to pass a custom Axes to `ndcube.NDCube.plot` for a 2D cube. diff --git a/ndcube/mixins/plotting.py b/ndcube/mixins/plotting.py index 677e4f4f2..2dea566ab 100644 --- a/ndcube/mixins/plotting.py +++ b/ndcube/mixins/plotting.py @@ -5,6 +5,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt import astropy.units as u +from astropy.visualization.wcsaxes import WCSAxes import sunpy.visualization.wcsaxes_compat as wcsaxes_compat try: from sunpy.visualization.animator import ImageAnimator, ImageAnimatorWCS, LineAnimator @@ -80,6 +81,7 @@ def plot(self, axes=None, plot_axis_indices=None, axes_coordinates=None, axes_units, data_unit, **kwargs) else: if len(plot_axis_indices) == 1: + ax = self._animate_cube_1D( plot_axis_index=plot_axis_indices[0], axes_coordinates=axes_coordinates, axes_units=axes_units, data_unit=data_unit, **kwargs) @@ -202,12 +204,12 @@ def _plot_2D_cube(self, axes=None, plot_axis_indices=None, axes_coordinates=None data = (self.data * self.unit).to(data_unit).value # Combine data with mask data = np.ma.masked_array(data, self.mask) - if axes is None: - try: - axes_coord_check = axes_coordinates == [None, None] - except Exception: - axes_coord_check = False - if axes_coord_check: + try: + axes_coord_check = axes_coordinates == [None, None] + except Exception: + axes_coord_check = False + if axes_coord_check and (isinstance(axes, WCSAxes) or axes is None): + if axes is None: # Build slice list for WCS for initializing WCSAxes object. if self.wcs.naxis != 2: slice_list = [] @@ -220,52 +222,61 @@ def _plot_2D_cube(self, axes=None, plot_axis_indices=None, axes_coordinates=None slice_list.append(1) if index != 2: raise ValueError("Dimensions of WCS and data don't match") - ax = wcsaxes_compat.gca_wcs(self.wcs, slices=tuple(slice_list)) + axes = wcsaxes_compat.gca_wcs(self.wcs, slices=tuple(slice_list)) else: - ax = wcsaxes_compat.gca_wcs(self.wcs) - # Set axis labels - x_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[0], - self.missing_axes) - ax.set_xlabel("{} [{}]".format( - self.world_axis_physical_types[plot_axis_indices[0]], - self.wcs.wcs.cunit[x_wcs_axis])) - y_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[1], - self.missing_axes) - ax.set_ylabel("{} [{}]".format( - self.world_axis_physical_types[plot_axis_indices[1]], - self.wcs.wcs.cunit[y_wcs_axis])) - # Plot data - ax.imshow(data, **kwargs) - else: - # Else manually set axes x and y values based on user's input for axes_coordinates. - new_axes_coordinates, new_axis_units, default_labels = \ - self._derive_axes_coordinates(axes_coordinates, axes_units, data.shape) - # Initialize axes object and set values along axis. - fig, ax = plt.subplots(1, 1) - # Since we can't assume the x-axis will be uniform, create NonUniformImage - # axes and add it to the axes object. - if plot_axis_indices[0] < plot_axis_indices[1]: - data = data.transpose() - im_ax = mpl.image.NonUniformImage( - ax, extent=(new_axes_coordinates[plot_axis_indices[0]][0], - new_axes_coordinates[plot_axis_indices[0]][-1], - new_axes_coordinates[plot_axis_indices[1]][0], - new_axes_coordinates[plot_axis_indices[1]][-1]), **kwargs) - im_ax.set_data(new_axes_coordinates[plot_axis_indices[0]], - new_axes_coordinates[plot_axis_indices[1]], data) - ax.add_image(im_ax) - # Set the limits, labels, etc. of the axes. - xlim = kwargs.pop("xlim", (new_axes_coordinates[plot_axis_indices[0]][0], - new_axes_coordinates[plot_axis_indices[0]][-1])) - ax.set_xlim(xlim) - ylim = kwargs.pop("xlim", (new_axes_coordinates[plot_axis_indices[1]][0], - new_axes_coordinates[plot_axis_indices[1]][-1])) - ax.set_ylim(ylim) - xlabel = kwargs.pop("xlabel", default_labels[plot_axis_indices[0]]) - ylabel = kwargs.pop("ylabel", default_labels[plot_axis_indices[1]]) - ax.set_xlabel(xlabel) - ax.set_ylabel(ylabel) - return ax + axes = wcsaxes_compat.gca_wcs(self.wcs) + + # Plot data + axes.imshow(data, **kwargs) + + # Set axis labels + x_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[0], + self.missing_axes) + + axes.coords[x_wcs_axis].set_axislabel("{} [{}]".format( + self.world_axis_physical_types[plot_axis_indices[0]], + self.wcs.wcs.cunit[x_wcs_axis])) + + y_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[1], + self.missing_axes) + + axes.coords[y_wcs_axis].set_axislabel("{} [{}]".format( + self.world_axis_physical_types[plot_axis_indices[1]], + self.wcs.wcs.cunit[y_wcs_axis])) + + else: + # Else manually set axes x and y values based on user's input for axes_coordinates. + new_axes_coordinates, new_axis_units, default_labels = \ + self._derive_axes_coordinates(axes_coordinates, axes_units, data.shape) + # Initialize axes object and set values along axis. + if axes is None: + axes = plt.gca() + # Since we can't assume the x-axis will be uniform, create NonUniformImage + # axes and add it to the axes object. + if plot_axis_indices[0] < plot_axis_indices[1]: + data = data.transpose() + im_ax = mpl.image.NonUniformImage( + axes, extent=(new_axes_coordinates[plot_axis_indices[0]][0], + new_axes_coordinates[plot_axis_indices[0]][-1], + new_axes_coordinates[plot_axis_indices[1]][0], + new_axes_coordinates[plot_axis_indices[1]][-1]), **kwargs) + im_ax.set_data(new_axes_coordinates[plot_axis_indices[0]], + new_axes_coordinates[plot_axis_indices[1]], data) + axes.add_image(im_ax) + # Set the limits, labels, etc. of the axes. + xlim = kwargs.pop("xlim", (new_axes_coordinates[plot_axis_indices[0]][0], + new_axes_coordinates[plot_axis_indices[0]][-1])) + axes.set_xlim(xlim) + ylim = kwargs.pop("xlim", (new_axes_coordinates[plot_axis_indices[1]][0], + new_axes_coordinates[plot_axis_indices[1]][-1])) + axes.set_ylim(ylim) + + xlabel = kwargs.pop("xlabel", default_labels[plot_axis_indices[0]]) + ylabel = kwargs.pop("ylabel", default_labels[plot_axis_indices[1]]) + axes.set_xlabel(xlabel) + axes.set_ylabel(ylabel) + + return axes def _plot_3D_cube(self, plot_axis_indices=None, axes_coordinates=None, axes_units=None, data_unit=None, **kwargs): diff --git a/ndcube/tests/test_plotting.py b/ndcube/tests/test_plotting.py index 8f5e14272..b92be27af 100644 --- a/ndcube/tests/test_plotting.py +++ b/ndcube/tests/test_plotting.py @@ -4,7 +4,9 @@ import numpy as np import astropy.units as u +from astropy.visualization.wcsaxes import WCSAxes import matplotlib +import matplotlib.pyplot as plt try: from sunpy.visualization.animator import ImageAnimatorWCS, LineAnimator except ImportError: @@ -30,6 +32,14 @@ 'CTYPE3': 'HPLN-TAN', 'CUNIT3': 'deg', 'CDELT3': 0.4, 'CRPIX3': 2, 'CRVAL3': 1, 'NAXIS3': 2} wm = WCS(header=hm, naxis=3) +spatial = { + 'CTYPE1': 'HPLT-TAN', 'CUNIT1': 'deg', 'CDELT1': 0.5, 'CRPIX1': 2, 'CRVAL1': 0.5, + 'NAXIS1': 3, + 'CTYPE2': 'HPLN-TAN', 'CUNIT2': 'deg', 'CDELT2': 0.4, 'CRPIX2': 2, 'CRVAL2': 1, + 'NAXIS2': 2 +} +spatial = WCS(header=spatial, naxis=2) + data = np.array([[[1, 2, 3, 4], [2, 4, 5, 3], [0, -1, 2, 3]], [[2, 4, 5, 1], [10, 5, 2, 2], [10, 3, 3, 0]]]) @@ -51,6 +61,10 @@ ('array coord', 2, np.arange(100, 100 + data.shape[2])) ]) +cube_spatial = NDCube( + data[0], + spatial) + cube_unit = NDCube( data, wt, @@ -207,7 +221,13 @@ def test_cube_plot_1D_errors(test_input, test_kwargs, expected_error): @pytest.mark.parametrize("test_input, test_kwargs, expected_values", [ (cube[0], {}, (np.ma.masked_array(cube[0].data, cube[0].mask), "time [min]", "em.wl [m]", - (-0.5, 3.5, 2.5, -0.5))), + (-0.5, 3.5, -0.5, 2.5))), + + (cube_spatial, {'axes': WCSAxes(plt.figure(), (0, 0, 1, 1), wcs=cube_spatial.wcs)}, + (cube_spatial.data, + "custom:pos.helioprojective.lat [deg]", + "custom:pos.helioprojective.lon [deg]", + (-0.5, 3.5, -0.5, 2.5))), (cube[0], {"axes_coordinates": ["bye", None], "axes_units": [None, u.cm]}, (np.ma.masked_array(cube[0].data, cube[0].mask), "bye [m]", "em.wl [cm]", @@ -229,6 +249,7 @@ def test_cube_plot_1D_errors(test_input, test_kwargs, expected_error): "em.wl [m]", "bye [m]", (2e-11, 6e-11, 0.0, 3.0))) ]) def test_cube_plot_2D(test_input, test_kwargs, expected_values): + fig = plt.figure() # Unpack expected properties. expected_data, expected_xlabel, expected_ylabel, expected_extent = \ expected_values @@ -237,8 +258,12 @@ def test_cube_plot_2D(test_input, test_kwargs, expected_values): # Check plot properties are correct. assert isinstance(output, matplotlib.axes.Axes) np.testing.assert_array_equal(output.images[0].get_array(), expected_data) - assert output.axes.xaxis.get_label_text() == expected_xlabel - assert output.axes.yaxis.get_label_text() == expected_ylabel + if isinstance(output, WCSAxes): + assert output.coords[0].get_axislabel() == expected_xlabel + assert output.coords[1].get_axislabel() == expected_ylabel + else: + assert output.axes.yaxis.get_label_text() == expected_ylabel + assert output.axes.xaxis.get_label_text() == expected_xlabel assert np.allclose(output.images[0].get_extent(), expected_extent)