Skip to content

Commit

Permalink
Add case to _derive_axes_coordinates() for when axis coordinates are …
Browse files Browse the repository at this point in the history
…datetimes and removed kwargs not accepted currently when calling ImageAnimator.
  • Loading branch information
DanRyanIrish committed Mar 31, 2018
1 parent 1ba5b1a commit a48106e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
12 changes: 6 additions & 6 deletions ndcube/mixins/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,13 +329,7 @@ def _plot_3D_cube(self, plot_axis_indices=None, axes_coordinates=None,
new_axes_coordinates, new_axes_units, default_labels = \
self._derive_axes_coordinates(axes_coordinates, axes_units)
# If axis labels not set by user add to kwargs.
if "xlabel" not in kwargs:
kwargs["xlabel"] = default_labels[plot_axis_indices[0]]
if "ylabel" not in kwargs:
kwargs["ylabel"] = default_labels[plot_axis_indices[1]]
ax = ImageAnimator(data, image_axes=plot_axis_indices,
unit_x_axis=new_axes_units[plot_axis_indices[0]],
unit_y_axis=new_axes_units[plot_axis_indices[1]],
axis_ranges=new_axes_coordinates, **kwargs)
return ax

Expand Down Expand Up @@ -367,6 +361,12 @@ def _derive_axes_coordinates(self, axes_coordinates, axes_units):
else:
new_axis_unit = axes_units[i]
new_axis_coordinate = new_axis_coordinate.to(new_axis_unit).value
elif isinstance(new_axis_coordinate[0], datetime.datetime):
axis_label_text = "{0}/sec since {1}".format(
axis_label_text, new_axis_coordinate[0])
new_axis_coordinate = np.array([(t-new_axis_coordinate[0]).total_seconds()
for t in new_axis_coordinate])
new_axis_unit = u.s
else:
if axes_units[i] is None:
new_axis_unit = None
Expand Down
6 changes: 4 additions & 2 deletions ndcube/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
('bye', 2, u.Quantity(range(data.shape[2]), unit=u.m)),
('another time', 2, np.array(
[datetime.datetime(2000, 1, 1)+datetime.timedelta(minutes=i)
for i in range(data.shape[2])]))
for i in range(data.shape[2])])),
('array coord', 2, np.arange(100, 100+data.shape[2]))
])

cube_unit = NDCube(
Expand Down Expand Up @@ -222,7 +223,7 @@ def test_cube_plot_2D(test_input, test_kwargs, expected_values):


@pytest.mark.parametrize("test_input, test_kwargs, expected_error", [
(cube[0], {"axes_coordinates": ["another time", None], "axes_units": [u.cm, None]}, TypeError),
(cube[0], {"axes_coordinates": ["array coord", None], "axes_units": [u.cm, None]}, TypeError),
(cube[0], {"axes_coordinates": [np.arange(10, 10+cube[0].data.shape[1]), None],
"axes_units": [u.cm, None]}, TypeError),
(cube[0], {"data_unit": u.cm}, TypeError)
Expand All @@ -231,6 +232,7 @@ def test_cube_plot_2D_errors(test_input, test_kwargs, expected_error):
with pytest.raises(expected_error):
output = test_input.plot(**test_kwargs)


@pytest.mark.parametrize("test_input, test_kwargs, expected_values", [
(cubem, {},
(cubem.data, [np.array([0., 2.]), [0, 3], [0, 4]], "", ""))
Expand Down

0 comments on commit a48106e

Please sign in to comment.