Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests for NDCube.plot and API Update #103

Merged
merged 8 commits into from
Mar 29, 2018
205 changes: 134 additions & 71 deletions ndcube/mixins/plotting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from warnings import warn

import numpy as np
import matplotlib.pyplot as plt

Expand All @@ -13,8 +15,8 @@ class NDCubePlotMixin:
Add plotting functionality to a NDCube class.
"""

def plot(self, axes=None, image_axes=[-1, -2], unit_x_axis=None, unit_y_axis=None,
axis_ranges=None, unit=None, origin=0, **kwargs):
def plot(self, axes=None, plot_axis_indices=[-1, -2], axes_coordinates=None,
axes_units=None, data_unit=None, origin=0, **kwargs):
"""
Plots an interactive visualization of this cube with a slider
controlling the wavelength axis for data having dimensions greater than 2.
Expand All @@ -25,24 +27,20 @@ def plot(self, axes=None, image_axes=[-1, -2], unit_x_axis=None, unit_y_axis=Non

Parameters
----------
image_axes: `list`
plot_axis_indices: `list`
The two axes that make the image.
Like [-1,-2] this implies cube instance -1 dimension
will be x-axis and -2 dimension will be y-axis.

axes: `astropy.visualization.wcsaxes.core.WCSAxes` or None:
The axes to plot onto. If None the current axes will be used.

unit_x_axis: `astropy.units.Unit`
The unit of x axis for 2D plots.

unit_y_axis: `astropy.units.Unit`
The unit of y axis for 2D plots.
axes_unit: `list` of `astropy.units.Unit`

unit: `astropy.unit.Unit`
data_unit: `astropy.unit.Unit`
The data is changed to the unit given or the cube.unit if not given, for 1D plots.

axis_ranges: list of physical coordinates for array or None
axes_coordinates: list of physical coordinates for array or None
If None array indices will be used for all axes.
If a list it should contain one element for each axis of the numpy array.
For the image axes a [min, max] pair should be specified which will be
Expand All @@ -52,56 +50,43 @@ def plot(self, axes=None, image_axes=[-1, -2], unit_x_axis=None, unit_y_axis=Non
If None is specified for an axis then the array indices will be used
for that axis.
"""
# If old API is used, convert to new API.
plot_axis_indices, axes_coordiantes, axes_units, data_unit, kwargs = _support_101_plot_API(
plot_axis_indices, axes_coordinates, axes_units, data_unit, kwargs)
axis_data = ['x' for i in range(2)]
axis_data[image_axes[1]] = 'y'
if self.data.ndim >= 3:
plot = self._plot_3D_cube(image_axes=image_axes, unit_x_axis=unit_x_axis,
unit_y_axis=unit_y_axis, axis_ranges=axis_ranges, **kwargs)
axis_data[plot_axis_indices[1]] = 'y'
if self.data.ndim is 1:
plot = self._plot_1D_cube(data_unit=data_unit, origin=origin)
elif self.data.ndim is 2:
plot = self._plot_2D_cube(axes=axes, image_axes=axis_data[::-1], **kwargs)
elif self.data.ndim is 1:
plot = self._plot_1D_cube(unit=unit, origin=origin)
plot = self._plot_2D_cube(axes=axes, plot_axis_indices=axis_data[::-1], **kwargs)
else:
plot = self._plot_3D_cube(plot_axis_indices=plot_axis_indices,
axes_coordinates=axes_coordinates, axes_units=axes_units,
**kwargs)
return plot

def _plot_3D_cube(self, image_axes=None, unit_x_axis=None, unit_y_axis=None,
axis_ranges=None, **kwargs):
def _plot_1D_cube(self, data_unit=None, origin=0):
"""
Plots an interactive visualization of this cube using sliders to move through axes
plot using in the image.
Parameters other than data and wcs are passed to ImageAnimatorWCS, which in turn
passes them to imshow.
Plots a graph.
Keyword arguments are passed on to matplotlib.

Parameters
----------
image_axes: `list`
The two axes that make the image.
Like [-1,-2] this implies cube instance -1 dimension
will be x-axis and -2 dimension will be y-axis.

unit_x_axis: `astropy.units.Unit`
The unit of x axis.

unit_y_axis: `astropy.units.Unit`
The unit of y axis.

axis_ranges: `list` of physical coordinates for array or None
If None array indices will be used for all axes.
If a list it should contain one element for each axis of the numpy array.
For the image axes a [min, max] pair should be specified which will be
passed to :func:`matplotlib.pyplot.imshow` as extent.
For the slider axes a [min, max] pair can be specified or an array the
same length as the axis which will provide all values for that slider.
If None is specified for an axis then the array indices will be used
for that axis.
data_unit: `astropy.unit.Unit`
The data is changed to the unit given or the cube.unit if not given.
"""
if not image_axes:
image_axes = [-1, -2]
i = ImageAnimatorWCS(self.data, wcs=self.wcs, image_axes=image_axes,
unit_x_axis=unit_x_axis, unit_y_axis=unit_y_axis,
axis_ranges=axis_ranges, **kwargs)
return i
index_not_one = []
for i, _bool in enumerate(self.missing_axis):
if not _bool:
index_not_one.append(i)
if data_unit is None:
data_unit = self.wcs.wcs.cunit[index_not_one[0]]
plot = plt.plot(self.pixel_to_world(*[u.Quantity(np.arange(self.data.shape[0]),
unit=u.pix)])[0].to(data_unit),
self.data)
return plot

def _plot_2D_cube(self, axes=None, image_axes=None, **kwargs):
def _plot_2D_cube(self, axes=None, plot_axis_indices=None, **kwargs):
"""
Plots a 2D image onto the current
axes. Keyword arguments are passed on to matplotlib.
Expand All @@ -111,21 +96,21 @@ def _plot_2D_cube(self, axes=None, image_axes=None, **kwargs):
axes: `astropy.visualization.wcsaxes.core.WCSAxes` or `None`:
The axes to plot onto. If None the current axes will be used.

image_axes: `list`.
The first axis in WCS object will become the first axis of image_axes and
second axis in WCS object will become the second axis of image_axes.
plot_axis_indices: `list`.
The first axis in WCS object will become the first axis of plot_axis_indices and
second axis in WCS object will become the second axis of plot_axis_indices.
Default: ['x', 'y']
"""
if not image_axes:
image_axes = ['x', 'y']
if not plot_axis_indices:
plot_axis_indices = ['x', 'y']
if axes is None:
if self.wcs.naxis is not 2:
missing_axis = self.missing_axis
slice_list = []
index = 0
for i, bool_ in enumerate(missing_axis):
if not bool_:
slice_list.append(image_axes[index])
slice_list.append(plot_axis_indices[index])
index += 1
else:
slice_list.append(1)
Expand All @@ -135,23 +120,101 @@ def _plot_2D_cube(self, axes=None, image_axes=None, **kwargs):
plot = axes.imshow(self.data, **kwargs)
return plot

def _plot_1D_cube(self, unit=None, origin=0):
def _plot_3D_cube(self, plot_axis_indices=None, axes_units=None,
axes_coordinates=None, **kwargs):
"""
Plots a graph.
Keyword arguments are passed on to matplotlib.
Plots an interactive visualization of this cube using sliders to move through axes
plot using in the image.
Parameters other than data and wcs are passed to ImageAnimatorWCS, which in turn
passes them to imshow.

Parameters
----------
unit: `astropy.unit.Unit`
The data is changed to the unit given or the cube.unit if not given.
plot_axis_indices: `list`
The two axes that make the image.
Like [-1,-2] this implies cube instance -1 dimension
will be x-axis and -2 dimension will be y-axis.

axes_unit: `list` of `astropy.units.Unit`

axes_coordinates: `list` of physical coordinates for array or None
If None array indices will be used for all axes.
If a list it should contain one element for each axis of the numpy array.
For the image axes a [min, max] pair should be specified which will be
passed to :func:`matplotlib.pyplot.imshow` as extent.
For the slider axes a [min, max] pair can be specified or an array the
same length as the axis which will provide all values for that slider.
If None is specified for an axis then the array indices will be used
for that axis.
"""
index_not_one = []
for i, _bool in enumerate(self.missing_axis):
if not _bool:
index_not_one.append(i)
if unit is None:
unit = self.wcs.wcs.cunit[index_not_one[0]]
plot = plt.plot(self.pixel_to_world(*[u.Quantity(np.arange(self.data.shape[0]),
unit=u.pix)])[0].to(unit),
self.data)
return plot
if plot_axis_indices is None:
plot_axis_indices = [-1, -2]
if axes_units is None:
axes_units = [None, None]
i = ImageAnimatorWCS(self.data, wcs=self.wcs, image_axes=plot_axis_indices,
unit_x_axis=axes_units[0], unit_y_axis=axes_units[1],
axis_ranges=axes_coordinates, **kwargs)
return i


def _support_101_plot_API(plot_axis_indices, axes_coordinates, axes_units, data_unit, kwargs):
"""Check if user has used old API and convert it to new API."""
# Get old API variable values.
image_axes = kwargs.pop("image_axes", None)
axis_ranges = kwargs.pop("axis_ranges", None)
unit_x_axis = kwargs.pop("unit_x_axis", None)
unit_y_axis = kwargs.pop("unit_y_axis", None)
unit = kwargs.pop("unit", None)
# Check if conflicting new and old API values have been set.
# If not, set new API using old API and raise deprecation warning.
if image_axes is not None:
variable_names = ("image_axes", "plot_axis_indices")
_raise_101_API_deprecation_warning(*variable_names)
if plot_axis_indices is None:
plot_axis_indices = image_axes
else:
_raise_API_error(*variable_names)
if axis_ranges is not None:
variable_names = ("axis_ranges", "axes_coordinates")
_raise_101_API_deprecation_warning(*variable_names)
if axes_coordinates is None:
axes_coordinates = axis_ranges
else:
_raise_API_error(*variable_names)
if (unit_x_axis is not None or unit_y_axis is not None) and axes_units is not None:
_raise_API_error("unit_x_axis and/or unit_y_axis", "axes_units")
if axes_units is None:
variable_names = ("unit_x_axis and unit_y_axis", "axes_units")
if unit_x_axis is not None:
_raise_101_API_deprecation_warning(*variable_names)
if len(plot_axis_indices) == 1:
axes_units = unit_x_axis
elif len(plot_axis_indices) == 2:
if unit_y_axis is None:
axes_units = [unit_x_axis, None]
else:
axes_units = [unit_x_axis, unit_y_axis]
else:
raise ValueError("Length of image_axes must be less than 3.")
else:
if unit_y_axis is not None:
_raise_101_API_deprecation_warning(*variable_names)
axes_units = [None, unit_y_axis]
if unit is not None:
variable_names = ("unit", "data_unit")
_raise_101_API_deprecation_warning(*variable_names)
if data_unit is None:
data_unit = unit
else:
_raise_API_error(*variable_names)
# Return values of new API
return plot_axis_indices, axes_coordinates, axes_units, data_unit, kwargs


def _raise_API_error(old_name, new_name):
raise ValueError(
"Conflicting inputs: {0} (old API) cannot be set if {1} (new API) is set".format(
old_name, new_name))

def _raise_101_API_deprecation_warning(old_name, new_name):
warn("{0} is deprecated and will not be supported in version 2.0. It will be replaced by {1}. See docstring.".format(old_name, new_name), DeprecationWarning)
Loading