diff --git a/ndcube/mixins/sequence_plotting.py b/ndcube/mixins/sequence_plotting.py index 73bb9c61c..4615469d5 100644 --- a/ndcube/mixins/sequence_plotting.py +++ b/ndcube/mixins/sequence_plotting.py @@ -34,9 +34,18 @@ def plot(self, cubesequence, axes=None, plot_as_cube=False, plot_as_line=False, dimension. Default=False - plot_axis_indices: `int` or iterable of one or two `int` - Default is images_axes=[-1, -2]. If sequence only has one dimension, - images_axes is forced to be 0. + plot_axis_indices: `int` or iterable of one or two `int`. + If two axis indices are given, the sequence is visualized as an image or + 2D animation, assuming the sequence has at least 2 dimensions. + (N.B. If plot_as_cube is True, the number of sequence dimensions is effectively + reduced by 1.) The dimension indicated by the 0th index is displayed on the + x-axis while the dimension indicated by the 1st index is displayed on the y-axis. + If only one axis index is given (either as an int or a list of one int), + then a 1D line animation is produced with the indicated dimension on the x-axis + and other dimensions represented by animations sliders. + Default=[-1, -2]. If sequence only has one dimension (or effectively one if + plot_as_cube is True), plot_axis_indices is ignored and a staice 1D line plot + is produced. axes_coordinates: `list` of physical coordinates for image axes and sliders or `None` If None coordinates derived from the WCS objects will be used for all axes. @@ -53,6 +62,14 @@ def plot(self, cubesequence, axes=None, plot_as_cube=False, plot_as_line=False, visualization is a 2D image or animation, i.e. if image_axis has length 2. """ + # If plot_axis_indices, axes_coordinates, axes_units are not None and not lists, + # convert to lists for consistent indexing behaviour. + if (not isinstance(plot_axis_indices, list)) and (plot_axis_indices is not None): + plot_axis_indices = [plot_axis_indices] + if (not isinstance(axes_coordinates, list)) and (axes_coordinates is not None): + axes_coordinates = [axes_coordinates] + if (not isinstance(axes_units, list)) and (axes_units is not None): + axes_units = [axes_units] # Ensure length of kwargs is consistent with dimensionality of sequence # and setting of plot_as_cube. naxis = len(cubesequence.dimensions) @@ -62,7 +79,7 @@ def plot(self, cubesequence, axes=None, plot_as_cube=False, plot_as_line=False, axes_units) # Make 1D line plot. ax = self._plot_1D_sequence( - cubesequence, axes=axes, x_axis_coordinates=x_axis_coordinates, + cubesequence, x_axis_coordinates=x_axis_coordinates, unit_x_axis=unit_x_axis, data_unit=data_unit, **kwargs) elif naxis == 2: if plot_as_cube: @@ -71,33 +88,41 @@ def plot(self, cubesequence, axes=None, plot_as_cube=False, plot_as_line=False, x_axis_coordinates, unit_x_axis = _derive_1D_coordinates_and_units( axes_coordinates, axes_units) ax = self._plot_2D_sequence_as_1Dline( - cubesequence, axes=axes, x_axis_coordinates=x_axis_coordinates, + cubesequence, x_axis_coordinates=x_axis_coordinates, unit_x_axis=unit_x_axis, data_unit=data_unit, **kwargs) else: - ax = self._plot_2D_sequence( - cubesequence, axes=None, plot_axis_indices=plot_axis_indices, - axes_coordinates=axes_coordinates, axes_units=axes_units, data_unit=data_unit, - **kwargs) + if plot_axis_indices is None: + plot_axis_indices = [-1, -2] + if len(plot_axis_indices) == 2: + ax = self._plot_2D_sequence( + cubesequence, plot_axis_indices=plot_axis_indices, + axes_coordinates=axes_coordinates, axes_units=axes_units, + data_unit=data_unit, **kwargs) + elif len(plot_axis_indices) == 1: + raise NotImplementedError() else: if plot_axis_indices is None: plot_axis_indices = [-1, -2] - if not plot_as_cube: - if axes_units is None: - axes_units = [None] * naxis - ax = ImageAnimatorNDCubeSequence( - cubesequence, image_axes=plot_axis_indices, - axis_ranges=axes_coordinates, unit_x_axis=axes_units[plot_axis_indices[0]], - unit_y_axis=axes_units[plot_axis_indices[1]], **kwargs) + if len(plot_axis_indices) == 2: + if not plot_as_cube: + if axes_units is None: + axes_units = [None] * naxis + ax = ImageAnimatorNDCubeSequence( + cubesequence, image_axes=plot_axis_indices, + axis_ranges=axes_coordinates, unit_x_axis=axes_units[plot_axis_indices[0]], + unit_y_axis=axes_units[plot_axis_indices[1]], **kwargs) + else: + if axes_units is None: + axes_units = [None] * (naxis-1) + ax = ImageAnimatorCommonAxisNDCubeSequence( + cubesequence, image_axes=plot_axis_indices, + axis_ranges=axes_coordinates, unit_x_axis=axes_units[plot_axis_indices[0]], + unit_y_axis=axes_units[plot_axis_indices[1]], **kwargs) else: - if axes_units is None: - axes_units = [None] * (naxis-1) - ax = ImageAnimatorCommonAxisNDCubeSequence( - cubesequence, axes=axes, image_axes=plot_axis_indices, - axis_ranges=axes_coordinates, unit_x_axis=axes_units[plot_axis_indices[0]], - unit_y_axis=axes_units[plot_axis_indices[1]], **kwargs) + raise NotImplementedError() return ax - def _plot_1D_sequence(self, cubesequence, axes=None, x_axis_coordinates=None, + def _plot_1D_sequence(self, cubesequence, x_axis_coordinates=None, unit_x_axis=None, data_unit=None, **kwargs): """ Visualizes an NDCubeSequence of scalar NDCubes as a line plot. @@ -106,9 +131,6 @@ def _plot_1D_sequence(self, cubesequence, axes=None, x_axis_coordinates=None, Parameters ---------- - axes: `astropy.visualization.wcsaxes.core.WCSAxes` or ??? or None. - The axes to plot onto. If None the current axes will be used. - x_axis_coordinates: `numpy.ndarray` or `astropy.unit.Quantity` or `str` or `None` Denotes the physical coordinates of the x-axis. If None, coordinates are derived from the WCS objects. @@ -168,7 +190,7 @@ def _plot_1D_sequence(self, cubesequence, axes=None, x_axis_coordinates=None, fig, ax = _make_1D_sequence_plot(xdata, ydata, yerror, unit_y_axis, default_xlabel, kwargs) return ax - def _plot_2D_sequence_as_1Dline(self, cubesequence, axes=None, x_axis_coordinates=None, + def _plot_2D_sequence_as_1Dline(self, cubesequence, x_axis_coordinates=None, axes_units=None, **kwargs): """ Visualizes an NDCubeSequence of 1D NDCubes with a common axis as a line plot. @@ -234,7 +256,7 @@ def _plot_2D_sequence_as_1Dline(self, cubesequence, axes=None, x_axis_coordinate fig, ax = _make_1D_sequence_plot(xdata, ydata, yerror, unit_y_axis, default_xlabel, kwargs) return ax - def _plot_2D_sequence(self, cubesequence, axes=None, plot_axis_indices=None, axes_coordinates=None, + def _plot_2D_sequence(self, cubesequence, plot_axis_indices=None, axes_coordinates=None, axes_units=None, data_unit=None, **kwargs): """ Visualizes an NDCubeSequence of 1D NDCubes as a 2D image. @@ -342,7 +364,7 @@ def _plot_2D_sequence(self, cubesequence, axes=None, plot_axis_indices=None, axe return ax - def _plot_3D_sequence_as_2Dimage(self, cubesequence, axes=None, plot_axis_indices=None, + def _plot_3D_sequence_as_2Dimage(self, cubesequence, plot_axis_indices=None, axes_coordinates=None, axes_units=None, data_unit=None, **kwargs): """ @@ -764,28 +786,27 @@ def _get_all_cube_units(sequence_data): def _check_kwargs_dimensions(naxis, plot_as_cube, plot_axis_indices, axes_coordinates, axes_units): - kws = [plot_axis_indices, axes_coordinates, axes_units] - kwarg_types = [(int), (u.Quantity, np.ndarray, str), (u.Unit, str)] - kwarg_names = ["plot_axis_indices", "axes_coordinates", "axes_unit"] if plot_as_cube is True: naxis = naxis - 1 - if naxis == 1: - kw_lens = [naxis] * len(kws) - for i in range(len(kws)): - if (not isinstance(kws[i], list)) and (kws[i] is not None): - kws[i] = [kws[i]] - else: - kw_lens = [2] + [naxis] * (len(kws) - 1) - for i in range(len(kws)): - if kws[i] is not None: - _check_single_kwarg_dimensions(kws[i], kw_lens[i], kwarg_types[i], kwarg_names[i]) - - -def _check_single_kwarg_dimensions(kw, kw_len, kw_types, str_kw): - if len(kw) != kw_len: - raise ValueError("length of {0} must be {1}.".format(str_kw, kw_len)) - if not isinstance(kw[0], kw_types): - raise TypeError("{0} must be one of ({1}) or list of one of ({1}).".format(str_kw, kw_types)) + if (plot_axis_indices is not None) and (naxis > 1): + if len(plot_axis_indices) not in [1, 2]: + raise ValueError("plot_axis_indices can have at most length 2.") + if axes_coordinates is not None: + if len(axes_coordinates) != naxis: + raise ValueError("length of axes_coordinates must be {0}.".format(naxis)) + ax_coord_types = (u.Quantity, np.ndarray, str) + for axis_coordinate in axes_coordinates: + if not instance(axis_coordinate, ax_coord_types): + raise TypeError("axes_coordinates must be one of {0} or list of {0}.".format( + ax_coord_types)) + if axes_units is not None: + if len(axes_units) != naxis: + raise ValueError("length of axes_units must be {0}.".format(naxis)) + ax_unit_types = (u.Unit, str) + for axis_coordinate in axes_units: + if not instance(axis_coordinate, ax_coord_types): + raise TypeError("axes_units must be one of {0} or list of {0}.".format( + ax_coord_types)) def _derive_1D_coordinates_and_units(axes_coordinates, axes_units):