Skip to content

Commit

Permalink
Merge pull request #204 from Cadair/wcsaxes-2d
Browse files Browse the repository at this point in the history
Fix explicit axes argument for 2D plots
  • Loading branch information
Cadair authored Sep 10, 2019
2 parents e6c9b83 + f856415 commit d8f4c27
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 54 deletions.
1 change: 1 addition & 0 deletions changelog/204.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix the ability to pass a custom Axes to `ndcube.NDCube.plot` for a 2D cube.
113 changes: 62 additions & 51 deletions ndcube/mixins/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import astropy.units as u
from astropy.visualization.wcsaxes import WCSAxes
import sunpy.visualization.wcsaxes_compat as wcsaxes_compat
try:
from sunpy.visualization.animator import ImageAnimator, ImageAnimatorWCS, LineAnimator
Expand Down Expand Up @@ -80,6 +81,7 @@ def plot(self, axes=None, plot_axis_indices=None, axes_coordinates=None,
axes_units, data_unit, **kwargs)
else:
if len(plot_axis_indices) == 1:

ax = self._animate_cube_1D(
plot_axis_index=plot_axis_indices[0], axes_coordinates=axes_coordinates,
axes_units=axes_units, data_unit=data_unit, **kwargs)
Expand Down Expand Up @@ -202,12 +204,12 @@ def _plot_2D_cube(self, axes=None, plot_axis_indices=None, axes_coordinates=None
data = (self.data * self.unit).to(data_unit).value
# Combine data with mask
data = np.ma.masked_array(data, self.mask)
if axes is None:
try:
axes_coord_check = axes_coordinates == [None, None]
except Exception:
axes_coord_check = False
if axes_coord_check:
try:
axes_coord_check = axes_coordinates == [None, None]
except Exception:
axes_coord_check = False
if axes_coord_check and (isinstance(axes, WCSAxes) or axes is None):
if axes is None:
# Build slice list for WCS for initializing WCSAxes object.
if self.wcs.naxis != 2:
slice_list = []
Expand All @@ -220,52 +222,61 @@ def _plot_2D_cube(self, axes=None, plot_axis_indices=None, axes_coordinates=None
slice_list.append(1)
if index != 2:
raise ValueError("Dimensions of WCS and data don't match")
ax = wcsaxes_compat.gca_wcs(self.wcs, slices=tuple(slice_list))
axes = wcsaxes_compat.gca_wcs(self.wcs, slices=tuple(slice_list))
else:
ax = wcsaxes_compat.gca_wcs(self.wcs)
# Set axis labels
x_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[0],
self.missing_axes)
ax.set_xlabel("{} [{}]".format(
self.world_axis_physical_types[plot_axis_indices[0]],
self.wcs.wcs.cunit[x_wcs_axis]))
y_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[1],
self.missing_axes)
ax.set_ylabel("{} [{}]".format(
self.world_axis_physical_types[plot_axis_indices[1]],
self.wcs.wcs.cunit[y_wcs_axis]))
# Plot data
ax.imshow(data, **kwargs)
else:
# Else manually set axes x and y values based on user's input for axes_coordinates.
new_axes_coordinates, new_axis_units, default_labels = \
self._derive_axes_coordinates(axes_coordinates, axes_units, data.shape)
# Initialize axes object and set values along axis.
fig, ax = plt.subplots(1, 1)
# Since we can't assume the x-axis will be uniform, create NonUniformImage
# axes and add it to the axes object.
if plot_axis_indices[0] < plot_axis_indices[1]:
data = data.transpose()
im_ax = mpl.image.NonUniformImage(
ax, extent=(new_axes_coordinates[plot_axis_indices[0]][0],
new_axes_coordinates[plot_axis_indices[0]][-1],
new_axes_coordinates[plot_axis_indices[1]][0],
new_axes_coordinates[plot_axis_indices[1]][-1]), **kwargs)
im_ax.set_data(new_axes_coordinates[plot_axis_indices[0]],
new_axes_coordinates[plot_axis_indices[1]], data)
ax.add_image(im_ax)
# Set the limits, labels, etc. of the axes.
xlim = kwargs.pop("xlim", (new_axes_coordinates[plot_axis_indices[0]][0],
new_axes_coordinates[plot_axis_indices[0]][-1]))
ax.set_xlim(xlim)
ylim = kwargs.pop("xlim", (new_axes_coordinates[plot_axis_indices[1]][0],
new_axes_coordinates[plot_axis_indices[1]][-1]))
ax.set_ylim(ylim)
xlabel = kwargs.pop("xlabel", default_labels[plot_axis_indices[0]])
ylabel = kwargs.pop("ylabel", default_labels[plot_axis_indices[1]])
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
return ax
axes = wcsaxes_compat.gca_wcs(self.wcs)

# Plot data
axes.imshow(data, **kwargs)

# Set axis labels
x_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[0],
self.missing_axes)

axes.coords[x_wcs_axis].set_axislabel("{} [{}]".format(
self.world_axis_physical_types[plot_axis_indices[0]],
self.wcs.wcs.cunit[x_wcs_axis]))

y_wcs_axis = utils.cube.data_axis_to_wcs_axis(plot_axis_indices[1],
self.missing_axes)

axes.coords[y_wcs_axis].set_axislabel("{} [{}]".format(
self.world_axis_physical_types[plot_axis_indices[1]],
self.wcs.wcs.cunit[y_wcs_axis]))

else:
# Else manually set axes x and y values based on user's input for axes_coordinates.
new_axes_coordinates, new_axis_units, default_labels = \
self._derive_axes_coordinates(axes_coordinates, axes_units, data.shape)
# Initialize axes object and set values along axis.
if axes is None:
axes = plt.gca()
# Since we can't assume the x-axis will be uniform, create NonUniformImage
# axes and add it to the axes object.
if plot_axis_indices[0] < plot_axis_indices[1]:
data = data.transpose()
im_ax = mpl.image.NonUniformImage(
axes, extent=(new_axes_coordinates[plot_axis_indices[0]][0],
new_axes_coordinates[plot_axis_indices[0]][-1],
new_axes_coordinates[plot_axis_indices[1]][0],
new_axes_coordinates[plot_axis_indices[1]][-1]), **kwargs)
im_ax.set_data(new_axes_coordinates[plot_axis_indices[0]],
new_axes_coordinates[plot_axis_indices[1]], data)
axes.add_image(im_ax)
# Set the limits, labels, etc. of the axes.
xlim = kwargs.pop("xlim", (new_axes_coordinates[plot_axis_indices[0]][0],
new_axes_coordinates[plot_axis_indices[0]][-1]))
axes.set_xlim(xlim)
ylim = kwargs.pop("xlim", (new_axes_coordinates[plot_axis_indices[1]][0],
new_axes_coordinates[plot_axis_indices[1]][-1]))
axes.set_ylim(ylim)

xlabel = kwargs.pop("xlabel", default_labels[plot_axis_indices[0]])
ylabel = kwargs.pop("ylabel", default_labels[plot_axis_indices[1]])
axes.set_xlabel(xlabel)
axes.set_ylabel(ylabel)

return axes

def _plot_3D_cube(self, plot_axis_indices=None, axes_coordinates=None,
axes_units=None, data_unit=None, **kwargs):
Expand Down
31 changes: 28 additions & 3 deletions ndcube/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

import numpy as np
import astropy.units as u
from astropy.visualization.wcsaxes import WCSAxes
import matplotlib
import matplotlib.pyplot as plt
try:
from sunpy.visualization.animator import ImageAnimatorWCS, LineAnimator
except ImportError:
Expand All @@ -30,6 +32,14 @@
'CTYPE3': 'HPLN-TAN', 'CUNIT3': 'deg', 'CDELT3': 0.4, 'CRPIX3': 2, 'CRVAL3': 1, 'NAXIS3': 2}
wm = WCS(header=hm, naxis=3)

spatial = {
'CTYPE1': 'HPLT-TAN', 'CUNIT1': 'deg', 'CDELT1': 0.5, 'CRPIX1': 2, 'CRVAL1': 0.5,
'NAXIS1': 3,
'CTYPE2': 'HPLN-TAN', 'CUNIT2': 'deg', 'CDELT2': 0.4, 'CRPIX2': 2, 'CRVAL2': 1,
'NAXIS2': 2
}
spatial = WCS(header=spatial, naxis=2)

data = np.array([[[1, 2, 3, 4], [2, 4, 5, 3], [0, -1, 2, 3]],
[[2, 4, 5, 1], [10, 5, 2, 2], [10, 3, 3, 0]]])

Expand All @@ -51,6 +61,10 @@
('array coord', 2, np.arange(100, 100 + data.shape[2]))
])

cube_spatial = NDCube(
data[0],
spatial)

cube_unit = NDCube(
data,
wt,
Expand Down Expand Up @@ -207,7 +221,13 @@ def test_cube_plot_1D_errors(test_input, test_kwargs, expected_error):
@pytest.mark.parametrize("test_input, test_kwargs, expected_values", [
(cube[0], {},
(np.ma.masked_array(cube[0].data, cube[0].mask), "time [min]", "em.wl [m]",
(-0.5, 3.5, 2.5, -0.5))),
(-0.5, 3.5, -0.5, 2.5))),
(cube_spatial, {'axes': WCSAxes(plt.figure(), (0, 0, 1, 1), wcs=cube_spatial.wcs)},
(cube_spatial.data,
"custom:pos.helioprojective.lat [deg]",
"custom:pos.helioprojective.lon [deg]",
(-0.5, 3.5, -0.5, 2.5))),
(cube[0], {"axes_coordinates": ["bye", None], "axes_units": [None, u.cm]},
(np.ma.masked_array(cube[0].data, cube[0].mask), "bye [m]", "em.wl [cm]",
Expand All @@ -229,6 +249,7 @@ def test_cube_plot_1D_errors(test_input, test_kwargs, expected_error):
"em.wl [m]", "bye [m]", (2e-11, 6e-11, 0.0, 3.0)))
])
def test_cube_plot_2D(test_input, test_kwargs, expected_values):
fig = plt.figure()
# Unpack expected properties.
expected_data, expected_xlabel, expected_ylabel, expected_extent = \
expected_values
Expand All @@ -237,8 +258,12 @@ def test_cube_plot_2D(test_input, test_kwargs, expected_values):
# Check plot properties are correct.
assert isinstance(output, matplotlib.axes.Axes)
np.testing.assert_array_equal(output.images[0].get_array(), expected_data)
assert output.axes.xaxis.get_label_text() == expected_xlabel
assert output.axes.yaxis.get_label_text() == expected_ylabel
if isinstance(output, WCSAxes):
assert output.coords[0].get_axislabel() == expected_xlabel
assert output.coords[1].get_axislabel() == expected_ylabel
else:
assert output.axes.yaxis.get_label_text() == expected_ylabel
assert output.axes.xaxis.get_label_text() == expected_xlabel
assert np.allclose(output.images[0].get_extent(), expected_extent)


Expand Down

0 comments on commit d8f4c27

Please sign in to comment.