Skip to content

Commit

Permalink
Fill in plot() function of NDCubeSequencePlotMixin and some minor kwa…
Browse files Browse the repository at this point in the history
…rg API changes.
  • Loading branch information
DanRyanIrish committed Mar 22, 2018
1 parent 4a53405 commit 6835bcd
Showing 1 changed file with 119 additions and 41 deletions.
160 changes: 119 additions & 41 deletions ndcube/mixins/sequence_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
"All sequence sub-cubes' unit attribute are not compatible with unit_y_axis set by user."

class NDCubeSequencePlotMixin:
def plot(self, axes=None, plot_as_cube=False, plot_axes=None, axes_coordinates=None,
axes_units=None, data_unit=None, **kwargs):
def plot(self, cubesequence, axes=None, plot_as_cube=False, plot_as_line=False,
plot_axis_indices=None, axes_coordinates=None, axes_units=None, data_unit=None,
**kwargs):
"""
Visualizes data in the NDCubeSequence.
Expand All @@ -33,7 +34,7 @@ def plot(self, axes=None, plot_as_cube=False, plot_axes=None, axes_coordinates=N
dimension.
Default=False
plot_axes: `int` or iterable of one or two `int`
plot_axis_indices: `int` or iterable of one or two `int`
Default is images_axes=[-1, -2]. If sequence only has one dimension,
images_axes is forced to be 0.
Expand All @@ -52,10 +53,52 @@ def plot(self, axes=None, plot_as_cube=False, plot_axes=None, axes_coordinates=N
visualization is a 2D image or animation, i.e. if image_axis has length 2.
"""
raise NotImplementedError()
# Ensure length of kwargs is consistent with dimensionality of sequence
# and setting of plot_as_cube.
naxis = len(cubesequence.dimensions)
_check_kwargs_dimensions(naxis, plot_as_cube, plot_axis_indices, axes_coordinates, axes_units)
if naxis == 1:
x_axis_coordinates, unit_x_axis = _derive_1D_coordinates_and_units(axes_coordinates,
axes_units)
# Make 1D line plot.
ax = self._plot_1D_sequence(
cubesequence, axes=axes, x_axis_coordinates=x_axis_coordinates,
unit_x_axis=unit_x_axis, data_unit=data_unit, **kwargs)
elif naxis == 2:
if plot_as_cube:
if cubesequence._common_axis is None:
raise TypeError("Common axis must be set to plot sequence as cube.")
x_axis_coordinates, unit_x_axis = _derive_1D_coordinates_and_units(
axes_coordinates, axes_units)
ax = self._plot_2D_sequence_as_1Dline(
cubesequence, axes=axes, x_axis_coordinates=x_axis_coordinates,
unit_x_axis=unit_x_axis, data_unit=data_unit, **kwargs)
else:
ax = self._plot_2D_sequence(
cubesequence, axes=None, plot_axis_indices=plot_axis_indices,
axes_coordinates=axes_coordinates, axes_units=axes_units, data_unit=data_unit,
**kwargs)
else:
if plot_axis_indices is None:
plot_axis_indices = [-1, -2]
if not plot_as_cube:
if axes_units is None:
axes_units = [None] * naxis
ax = ImageAnimatorNDCubeSequence(
cubesequence, image_axes=plot_axis_indices,
axis_ranges=axes_coordinates, unit_x_axis=axes_units[plot_axis_indices[0]],
unit_y_axis=axes_units[plot_axis_indices[1]], **kwargs)
else:
if axes_units is None:
axes_units = [None] * (naxis-1)
ax = ImageAnimatorCommonAxisNDCubeSequence(
cubesequence, axes=axes, image_axes=plot_axis_indices,
axis_ranges=axes_coordinates, unit_x_axis=axes_units[plot_axis_indices[0]],
unit_y_axis=axes_units[plot_axis_indices[1]], **kwargs)
return ax

def _plot_1D_sequence(self, cubesequence, axes=None, x_axis_coordinates=None, axes_units=None,
**kwargs):
def _plot_1D_sequence(self, cubesequence, axes=None, x_axis_coordinates=None,
unit_x_axis=None, data_unit=None, **kwargs):
"""
Visualizes an NDCubeSequence of scalar NDCubes as a line plot.
Expand All @@ -66,7 +109,7 @@ def _plot_1D_sequence(self, cubesequence, axes=None, x_axis_coordinates=None, ax
axes: `astropy.visualization.wcsaxes.core.WCSAxes` or ??? or None.
The axes to plot onto. If None the current axes will be used.
x_axis_values: `numpy.ndarray` or `astropy.unit.Quantity` or `str` or `None`
x_axis_coordinates: `numpy.ndarray` or `astropy.unit.Quantity` or `str` or `None`
Denotes the physical coordinates of the x-axis.
If None, coordinates are derived from the WCS objects.
If an `astropy.units.Quantity` or a `numpy.ndarray` gives the coordinates for
Expand All @@ -79,17 +122,12 @@ def _plot_1D_sequence(self, cubesequence, axes=None, x_axis_coordinates=None, ax
the coordinate denoted by x_axis_range. Not used if x_axis_range is a
`numpy.ndarray` or the designated extra coordinate is a `numpy.ndarray`
unit_y_axis: `astropy.units.unit` or valid unit `str`
data_unit: `astropy.units.unit` or valid unit `str`
The units into which the y-axis should be displayed. The unit attribute of all
the sub-cubes must be compatible to set this kwarg.
"""
if axes_units is None:
unit_x_axis = None
unit_y_axis = None
else:
unit_x_axis = axes_units[0]
unit_y_axis = axes_units[1]
unit_y_axis = data_unit
# Check that the unit attribute is a set in all cubes and derive unit_y_axis if not set.
sequence_units, unit_y_axis = _determine_sequence_units(cubesequence.data, unit_y_axis)
# If all cubes have unit set, create a data quantity from cubes' data.
Expand Down Expand Up @@ -196,7 +234,7 @@ def _plot_2D_sequence_as_1Dline(self, cubesequence, axes=None, x_axis_coordinate
fig, ax = _make_1D_sequence_plot(xdata, ydata, yerror, unit_y_axis, default_xlabel, kwargs)
return ax

def _plot_2D_sequence(self, cubesequence, axes=None, plot_axes=None, axes_coordinates=None,
def _plot_2D_sequence(self, cubesequence, axes=None, plot_axis_indices=None, axes_coordinates=None,
axes_units=None, data_unit=None, **kwargs):
"""
Visualizes an NDCubeSequence of 1D NDCubes as a 2D image.
Expand All @@ -213,10 +251,10 @@ def _plot_2D_sequence(self, cubesequence, axes=None, plot_axes=None, axes_coordi
axes_coordinates = [None, None]
if axes_units is None:
axes_units = [None, None]
if plot_axes is None:
plot_axes = [-1, -2]
# Convert plot_axes to array for function operations.
plot_axes = np.asarray(plot_axes)
if plot_axis_indices is None:
plot_axis_indices = [-1, -2]
# Convert plot_axis_indices to array for function operations.
plot_axis_indices = np.asarray(plot_axis_indices)
# Check that the unit attribute is set of all cubes and derive unit_y_axis if not set.
sequence_units, data_unit = _determine_sequence_units(cubesequence.data, data_unit)
# If all cubes have unit set, create a data quantity from cube's data.
Expand All @@ -226,7 +264,7 @@ def _plot_2D_sequence(self, cubesequence, axes=None, plot_axes=None, axes_coordi
else:
data = np.stack([cube.data for i, cube in enumerate(cubesequence.data)])
# Transpose data if user-defined images_axes require it.
if plot_axes[0] < plot_axes[1]:
if plot_axis_indices[0] < plot_axis_indices[1]:
data = data.transpose()
# Determine index of above axes variables corresponding to cube axis.
cube_axis_index = 1
Expand Down Expand Up @@ -291,20 +329,20 @@ def _plot_2D_sequence(self, cubesequence, axes=None, plot_axes=None, axes_coordi
# Since we can't assume the x-axis will be uniform, create NonUniformImage
# axes and add it to the axes object.
im_ax = mpl.image.NonUniformImage(
ax, extent=(axes_coordinates[plot_axes[0]][0], axes_coordinates[plot_axes[0]][-1],
axes_coordinates[plot_axes[1]][0], axes_coordinates[plot_axes[1]][-1]),
ax, extent=(axes_coordinates[plot_axis_indices[0]][0], axes_coordinates[plot_axis_indices[0]][-1],
axes_coordinates[plot_axis_indices[1]][0], axes_coordinates[plot_axis_indices[1]][-1]),
**kwargs)
im_ax.set_data(axes_coordinates[plot_axes[0]], axes_coordinates[plot_axes[1]], data)
im_ax.set_data(axes_coordinates[plot_axis_indices[0]], axes_coordinates[plot_axis_indices[1]], data)
ax.add_image(im_ax)
# Set the limits, labels, etc. of the axes.
ax.set_xlim((axes_coordinates[plot_axes[0]][0], axes_coordinates[plot_axes[0]][-1]))
ax.set_ylim((axes_coordinates[plot_axes[1]][0], axes_coordinates[plot_axes[1]][-1]))
ax.set_xlabel(axes_labels[plot_axes[0]])
ax.set_ylabel(axes_labels[plot_axes[1]])
ax.set_xlim((axes_coordinates[plot_axis_indices[0]][0], axes_coordinates[plot_axis_indices[0]][-1]))
ax.set_ylim((axes_coordinates[plot_axis_indices[1]][0], axes_coordinates[plot_axis_indices[1]][-1]))
ax.set_xlabel(axes_labels[plot_axis_indices[0]])
ax.set_ylabel(axes_labels[plot_axis_indices[1]])

return ax

def _plot_3D_sequence_as_2Dimage(self, cubesequence, axes=None, plot_axes=None,
def _plot_3D_sequence_as_2Dimage(self, cubesequence, axes=None, plot_axis_indices=None,
axes_coordinates=None, axes_units=None, data_unit=None,
**kwargs):
"""
Expand All @@ -318,10 +356,10 @@ def _plot_3D_sequence_as_2Dimage(self, cubesequence, axes=None, plot_axes=None,
axes_coordinates = [None, None]
if axes_units is None:
axes_units = [None, None]
if plot_axes is None:
plot_axes = [-1, -2]
# Convert plot_axes to array for function operations.
plot_axes = np.asarray(plot_axes)
if plot_axis_indices is None:
plot_axis_indices = [-1, -2]
# Convert plot_axis_indices to array for function operations.
plot_axis_indices = np.asarray(plot_axis_indices)
# Check that the unit attribute is set of all cubes and derive unit_y_axis if not set.
sequence_units, data_unit = _determine_sequence_units(cubesequence.data, data_unit)
# If all cubes have unit set, create a data quantity from cube's data.
Expand All @@ -332,7 +370,7 @@ def _plot_3D_sequence_as_2Dimage(self, cubesequence, axes=None, plot_axes=None,
else:
data = np.concatenate([cube.data for cube in cubesequence.data],
axis=cubesequence._common_axis)
if plot_axes[0] < plot_axes[1]:
if plot_axis_indices[0] < plot_axis_indices[1]:
data = data.transpose()
# Determine index of common axis and other cube axis.
common_axis_index = cubesequence._common_axis
Expand Down Expand Up @@ -406,16 +444,16 @@ def _plot_3D_sequence_as_2Dimage(self, cubesequence, axes=None, plot_axes=None,
# Since we can't assume the x-axis will be uniform, create NonUniformImage
# axes and add it to the axes object.
im_ax = mpl.image.NonUniformImage(
ax, extent=(axes_coordinates[plot_axes[0]][0], axes_coordinates[plot_axes[0]][-1],
axes_coordinates[plot_axes[1]][0], axes_coordinates[plot_axes[1]][-1]),
ax, extent=(axes_coordinates[plot_axis_indices[0]][0], axes_coordinates[plot_axis_indices[0]][-1],
axes_coordinates[plot_axis_indices[1]][0], axes_coordinates[plot_axis_indices[1]][-1]),
**kwargs)
im_ax.set_data(axes_coordinates[plot_axes[0]], axes_coordinates[plot_axes[1]], data)
im_ax.set_data(axes_coordinates[plot_axis_indices[0]], axes_coordinates[plot_axis_indices[1]], data)
ax.add_image(im_ax)
# Set the limits, labels, etc. of the axes.
ax.set_xlim((axes_coordinates[plot_axes[0]][0], axes_coordinates[plot_axes[0]][-1]))
ax.set_ylim((axes_coordinates[plot_axes[1]][0], axes_coordinates[plot_axes[1]][-1]))
ax.set_xlabel(axes_labels[plot_axes[0]])
ax.set_ylabel(axes_labels[plot_axes[1]])
ax.set_xlim((axes_coordinates[plot_axis_indices[0]][0], axes_coordinates[plot_axis_indices[0]][-1]))
ax.set_ylim((axes_coordinates[plot_axis_indices[1]][0], axes_coordinates[plot_axis_indices[1]][-1]))
ax.set_xlabel(axes_labels[plot_axis_indices[0]])
ax.set_ylabel(axes_labels[plot_axis_indices[1]])

return ax

Expand Down Expand Up @@ -691,7 +729,6 @@ def _make_1D_sequence_plot(xdata, ydata, yerror, unit_y_axis, default_xlabel, kw
ylim = kwargs.pop("ylim", None)
# Plot data
fig, ax = plt.subplots(1, 1)
print(xdata.shape, ydata.shape)
ax.errorbar(xdata, ydata, yerror, **kwargs)
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
Expand Down Expand Up @@ -724,3 +761,44 @@ def _get_all_cube_units(sequence_data):
else:
sequence_units.append(cube.unit)
return sequence_units


def _check_kwargs_dimensions(naxis, plot_as_cube, plot_axis_indices, axes_coordinates, axes_units):
kws = [plot_axis_indices, axes_coordinates, axes_units]
kwarg_types = [(int), (u.Quantity, np.ndarray, str), (u.Unit, str)]
kwarg_names = ["plot_axis_indices", "axes_coordinates", "axes_unit"]
if plot_as_cube is True:
naxis = naxis - 1
if naxis == 1:
kw_lens = [naxis] * len(kws)
for i in range(len(kws)):
if (not isinstance(kws[i], list)) and (kws[i] is not None):
kws[i] = [kws[i]]
else:
kw_lens = [2] + [naxis] * (len(kws) - 1)
for i in range(len(kws)):
if kws[i] is not None:
_check_single_kwarg_dimensions(kws[i], kw_lens[i], kwarg_types[i], kwarg_names[i])


def _check_single_kwarg_dimensions(kw, kw_len, kw_types, str_kw):
if len(kw) != kw_len:
raise ValueError("length of {0} must be {1}.".format(str_kw, kw_len))
if not isinstance(kw[0], kw_types):
raise TypeError("{0} must be one of ({1}) or list of one of ({1}).".format(str_kw, kw_types))


def _derive_1D_coordinates_and_units(axes_coordinates, axes_units):
if axes_coordinates is None:
x_axis_coordinates = axes_coordinates
else:
if not isinstance(axes_coordinates, list):
axes_coordinates = [axes_coordinates]
x_axis_coordinates = axes_coordinates[0]
if axes_units is None:
unit_x_axis = axes_units
else:
if not isinstance(axes_units, list):
axes_units = [axes_units]
unit_x_axis = axes_units[0]
return x_axis_coordinates, unit_x_axis

0 comments on commit 6835bcd

Please sign in to comment.