Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update unyt_arrayConverter to handle sequences of unyt types. #126

Merged
merged 14 commits into from
Jan 22, 2020
Merged
Binary file added docs/_static/mpl_fig3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 2 additions & 3 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
numpy
sympy
astropy
pint
sphinx==1.7.9
matplotlib
sphinx
57 changes: 46 additions & 11 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1184,29 +1184,64 @@ calculation involves many operations on arrays with only a few elements.

Plotting with Matplotlib
++++++++++++++++++++++++
.. note::
- This is an experimental feature. Please report issues.
- This feature works in Matplotlib versions 2.2.4 and above
- Matplotlib is not a dependency of Unyt

Matplotlib is Unyt aware. With no additional effort, Matplotlib will label the x and y
axes with the units.
Matplotlib is Unyt aware. After enabling support in :mod:`unyt` using the
:class:`unyt.matplotlib_support <unyt.mpl_interface.matplotlib_support>` context
manager, Matplotlib will label the x and y axes with the units.

>>> import matplotlib.pyplot as plt
>>> from unyt import s, K
>>> from unyt import s, K, matplotlib_support
>>> x = [0.0, 0.01, 0.02]*s
>>> y = [298.15, 308.15, 318.15]*K
>>> plt.plot(x, y)
>>> with matplotlib_support:
... plt.plot(x, y)
... plt.show()
[<matplotlib.lines.Line2D object at ...>]
>>> plt.show()

.. image:: _static/mpl_fig1.png

You can change the plotted units without affecting the original data.

>>> plt.plot(x, y, xunits="ms", yunits=("J", "thermal"))
>>> with matplotlib_support:
... plt.plot(x, y, xunits="ms", yunits=("J", "thermal"))
... plt.show()
[<matplotlib.lines.Line2D object at ...>]
>>> plt.show()

.. image:: _static/mpl_fig2.png

.. note::

- This feature works in Matplotlib versions 2.2.4 and above
- Matplotlib is not a dependency of Unyt
It is also possible to set the label style, the choices ``"()"``, ``"[]"`` and
``"/"`` are supported.

>>> import matplotlib.pyplot as plt
>>> from unyt import s, K, matplotlib_support
>>> matplotlib_support.label_style = "[]"
>>> with matplotlib_support:
... plt.plot([0, 1, 2]*s, [3, 4, 5]*K)
... plt.show()
[<matplotlib.lines.Line2D object at ...>]

.. image:: _static/mpl_fig3.png

There are three ways to use the context manager:

1. As a conventional context manager in a ``with`` statement as shown above

2. As a feature toggle in an interactive session:

>>> import matplotlib.pyplot as plt
>>> from unyt import s, K, matplotlib_support
>>> matplotlib_support.enable()
>>> plt.plot([0, 1, 2]*s, [3, 4, 5]*K)
[<matplotlib.lines.Line2D object at ...>]
>>> plt.show()
>>> matplotlib_support.disable()

3. As an enable for a complete session:

>>> import unyt
>>> unyt.matplotlib_support()
>>> import matplotlib.pyplot as plt
8 changes: 7 additions & 1 deletion unyt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,13 @@
from unyt.unit_systems import UnitSystem # NOQA: F401
from unyt.testing import assert_allclose_units # NOQA: F401
from unyt.dimensions import accepts, returns # NOQA: F401
from unyt import mpl_interface # NOQA: F401

try:
from unyt.mpl_interface import matplotlib_support # NOQA: F401
except ImportError:
pass
else:
matplotlib_support = matplotlib_support()


# function to only import quantities into this namespace
Expand Down
11 changes: 11 additions & 0 deletions unyt/_on_demand_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class matplotlib_imports(object):
_name = "matplotlib"
_pyplot = None
_units = None
_use = None

@property
def __version__(self):
Expand Down Expand Up @@ -163,5 +164,15 @@ def units(self):
self._units = units
return self._units

@property
def use(self):
if self._use is None:
try:
from matplotlib import use
except ImportError:
use = NotAModule(self._name)
self._use = use
return self._use


_matplotlib = matplotlib_imports()
90 changes: 79 additions & 11 deletions unyt/mpl_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@


try:
from matplotlib.units import (
ConversionInterface,
AxisInfo,
registry,
)
from matplotlib.units import ConversionInterface, AxisInfo, registry
except ImportError:
pass
else:
Expand All @@ -27,6 +23,8 @@
class unyt_arrayConverter(ConversionInterface):
"""Matplotlib interface for unyt_array"""

_labelstyle = "()"

@staticmethod
def axisinfo(unit, axis):
"""Set the axis label based on unit
Expand Down Expand Up @@ -54,7 +52,18 @@ def axisinfo(unit, axis):
label = ""
else:
unit_str = unit_obj.latex_representation()
label = "$\\left(" + unit_str + "\\right)$"
if unyt_arrayConverter._labelstyle == "[]":
label = "$\\left[" + unit_str + "\\right]$"
elif unyt_arrayConverter._labelstyle == "/":
axsym = axis.axis_name
if "/" in unit_str:
label = (
"$q_{" + axsym + "}\\;/\\;\\left(" + unit_str + "\\right)$"
)
else:
label = "$q_{" + axsym + "}\\;/\\;" + unit_str + "$"
else:
label = "$\\left(" + unit_str + "\\right)$"
return AxisInfo(label=label)

@staticmethod
Expand All @@ -81,7 +90,7 @@ def convert(value, unit, axis):
Parameters
----------

value : unyt_array
value : unyt_array, unyt_quantity, or sequence there of
unit : Unit, string or tuple
This parameter comes from unyt_arrayConverter.default_units() or from
user code such as Axes.plot(), Axis.set_units(), etc. In user code, it
Expand All @@ -98,11 +107,70 @@ def convert(value, unit, axis):
Raises
------

ConversionError if unit does not have the same dimensions as value
UnitConversionError if unit does not have the same dimensions as value or
if we don't know how to convert value.
"""
converted_value = value
if isinstance(unit, str) or isinstance(unit, Unit):
unit = (unit,)
return value.to(*unit)
if isinstance(value, (unyt_array, unyt_quantity)):
converted_value = value.to(*unit)
else:
value_type = type(value)
converted_value = []
for obj in value:
converted_value.append(obj.to(*unit))
converted_value = value_type(converted_value)
return converted_value

class matplotlib_support:
"""Context manager for setting up integration with Unyt in Matplotlib

Parameters
----------

label_style : str
One of the following set, ``{'()', '[]', '/'}``. These choices
correspond to the following unit labels:

* ``'()'`` -> ``'(unit)'``
* ``'[]'`` -> ``'[unit]'``
* ``'/'`` -> ``'q_x / unit'``
"""

def __init__(self, label_style="()"):
self._labelstyle = label_style
unyt_arrayConverter._labelstyle = label_style

def __call__(self):
self.__enter__()

@property
def label_style(self):
"""str: One of the following set, ``{'()', '[]', '/'}``.
These choices correspond to the following unit labels:

* ``'()'`` -> ``'(unit)'``
* ``'[]'`` -> ``'[unit]'``
* ``'/'`` -> ``'q_x / unit'``
"""
return self._labelstyle

@label_style.setter
def label_style(self, label_style="()"):
self._labelstyle = label_style
unyt_arrayConverter._labelstyle = label_style

def __enter__(self):
registry[unyt_array] = unyt_arrayConverter()
registry[unyt_quantity] = unyt_arrayConverter()

def __exit__(self, exc_type, exc_val, exc_tb):
registry.pop(unyt_array)
registry.pop(unyt_quantity)

def enable(self):
self.__enter__()

registry[unyt_array] = unyt_arrayConverter()
registry[unyt_quantity] = unyt_arrayConverter()
def disable(self):
self.__exit__(None, None, None)
95 changes: 93 additions & 2 deletions unyt/tests/test_mpl_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,41 @@
import numpy as np
import pytest
from unyt._on_demand_imports import _matplotlib, NotAModule
from unyt import s, K
from unyt import m, s, K, unyt_array, unyt_quantity
from unyt.exceptions import UnitConversionError

try:
from unyt import matplotlib_support
from unyt.mpl_interface import unyt_arrayConverter
except ImportError:
pass

check_matplotlib = pytest.mark.skipif(
isinstance(_matplotlib.pyplot, NotAModule), reason="matplotlib not installed"
)


@pytest.fixture
def ax(scope="module"):
def ax():
_matplotlib.use("agg")
matplotlib_support.enable()
fig, ax = _matplotlib.pyplot.subplots()
yield ax
_matplotlib.pyplot.close()
matplotlib_support.disable()


@check_matplotlib
def test_label(ax):
x = [0, 1, 2] * s
y = [3, 4, 5] * K
matplotlib_support.label_style = "()"
ax.plot(x, y)
expected_xlabel = "$\\left(\\rm{s}\\right)$"
assert ax.xaxis.get_label().get_text() == expected_xlabel
expected_ylabel = "$\\left(\\rm{K}\\right)$"
assert ax.yaxis.get_label().get_text() == expected_ylabel
_matplotlib.pyplot.close()


@check_matplotlib
Expand Down Expand Up @@ -85,6 +95,7 @@ def test_conversionerror(ax):
def test_ndarray_label(ax):
x = [0, 1, 2] * s
y = np.arange(3, 6)
matplotlib_support.label_style = "()"
ax.plot(x, y)
expected_xlabel = "$\\left(\\rm{s}\\right)$"
assert ax.xaxis.get_label().get_text() == expected_xlabel
Expand All @@ -96,8 +107,88 @@ def test_ndarray_label(ax):
def test_list_label(ax):
x = [0, 1, 2] * s
y = [3, 4, 5]
matplotlib_support.label_style = "()"
ax.plot(x, y)
expected_xlabel = "$\\left(\\rm{s}\\right)$"
assert ax.xaxis.get_label().get_text() == expected_xlabel
expected_ylabel = ""
assert ax.yaxis.get_label().get_text() == expected_ylabel


@check_matplotlib
def test_errorbar(ax):
x = unyt_array([8, 9, 10], "cm")
y = unyt_array([8, 9, 10], "kg")
y_scatter = [unyt_array([0.1, 0.2, 0.3], "kg"), unyt_array([0.1, 0.2, 0.3], "kg")]
x_lims = (unyt_quantity(5, "cm"), unyt_quantity(12, "cm"))
y_lims = (unyt_quantity(5, "kg"), unyt_quantity(12, "kg"))

ax.errorbar(x, y, yerr=y_scatter)
x_lims = (unyt_quantity(5, "cm"), unyt_quantity(12, "cm"))
y_lims = (unyt_quantity(5, "kg"), unyt_quantity(12, "kg"))
ax.set_xlim(*x_lims)
ax.set_ylim(*y_lims)


@check_matplotlib
def test_hist2d(ax):
x = np.random.normal(size=50000) * s
y = 3 * x + np.random.normal(size=50000) * s
ax.hist2d(x, y, bins=(50, 50))


@check_matplotlib
def test_imshow(ax):
data = np.reshape(np.random.normal(size=10000), (100, 100))
ax.imshow(data, vmin=data.min(), vmax=data.max())


@check_matplotlib
def test_hist(ax):
data = np.random.normal(size=10000) * s
bin_edges = np.linspace(data.min(), data.max(), 50)
ax.hist(data, bins=bin_edges)


@check_matplotlib
def test_matplotlib_support():
with pytest.raises(KeyError):
_matplotlib.units.registry[unyt_array]
matplotlib_support.enable()
assert isinstance(_matplotlib.units.registry[unyt_array], unyt_arrayConverter)
matplotlib_support.disable()
assert unyt_array not in _matplotlib.units.registry.keys()
assert unyt_quantity not in _matplotlib.units.registry.keys()
# test as a callable
matplotlib_support()
assert isinstance(_matplotlib.units.registry[unyt_array], unyt_arrayConverter)


@check_matplotlib
def test_labelstyle():
x = [0, 1, 2] * s
y = [3, 4, 5] * K
matplotlib_support.label_style = "[]"
assert matplotlib_support.label_style == "[]"
matplotlib_support.enable()
assert unyt_arrayConverter._labelstyle == "[]"
fig, ax = _matplotlib.pyplot.subplots()
ax.plot(x, y)
expected_xlabel = "$\\left[\\rm{s}\\right]$"
assert ax.xaxis.get_label().get_text() == expected_xlabel
expected_ylabel = "$\\left[\\rm{K}\\right]$"
assert ax.yaxis.get_label().get_text() == expected_ylabel
matplotlib_support.label_style = "/"
ax.clear()
ax.plot(x, y)
expected_xlabel = "$q_{x}\\;/\\;\\rm{s}$"
assert ax.xaxis.get_label().get_text() == expected_xlabel
expected_ylabel = "$q_{y}\\;/\\;\\rm{K}$"
assert ax.yaxis.get_label().get_text() == expected_ylabel
x = [0, 1, 2] * m / s
ax.clear()
ax.plot(x, y)
expected_xlabel = "$q_{x}\\;/\\;\\left(\\rm{m} / \\rm{s}\\right)$"
assert ax.xaxis.get_label().get_text() == expected_xlabel
_matplotlib.pyplot.close()
matplotlib_support.disable()