Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update unyt_arrayConverter to handle sequences of unyt types. #126

Merged
merged 14 commits into from
Jan 22, 2020
Merged
8 changes: 4 additions & 4 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1186,16 +1186,16 @@ Plotting with Matplotlib
++++++++++++++++++++++++
.. note::
- This is an experimental feature. Please report issues.
- The context manager ``MplUnitsCM`` will temporarily enable this feature
- The context manager ``matplotlib_support`` will temporarily enable this feature

Matplotlib is Unyt aware. With no additional effort, Matplotlib will label the x and y
axes with the units.

>>> import matplotlib.pyplot as plt
>>> from unyt import s, K, MplUnitsCM
>>> from unyt import s, K, matplotlib_support
>>> x = [0.0, 0.01, 0.02]*s
>>> y = [298.15, 308.15, 318.15]*K
>>> with MplUnitsCM():
>>> with matplotlib_support:
... plt.plot(x, y)
... plt.show()
[<matplotlib.lines.Line2D object at ...>]
Expand All @@ -1204,7 +1204,7 @@ axes with the units.

You can change the plotted units without affecting the original data.

>>> with MplUnitsCM():
>>> with matplotlib_support:
... plt.plot(x, y, xunits="ms", yunits=("J", "thermal"))
... plt.show()
[<matplotlib.lines.Line2D object at ...>]
Expand Down
4 changes: 3 additions & 1 deletion unyt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@
from unyt.unit_systems import UnitSystem # NOQA: F401
from unyt.testing import assert_allclose_units # NOQA: F401
from unyt.dimensions import accepts, returns # NOQA: F401
from unyt.mpl_interface import MplUnitsCM # NOQA: F401
from unyt.mpl_interface import matplotlib_support # NOQA: F401
ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved

matplotlib_support = matplotlib_support()
ngoldbaum marked this conversation as resolved.
Show resolved Hide resolved


# function to only import quantities into this namespace
Expand Down
59 changes: 55 additions & 4 deletions unyt/mpl_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
class unyt_arrayConverter(ConversionInterface):
"""Matplotlib interface for unyt_array"""

_labelstyle = "()"

@staticmethod
def axisinfo(unit, axis):
"""Set the axis label based on unit
Expand Down Expand Up @@ -55,7 +57,18 @@ def axisinfo(unit, axis):
label = ""
else:
unit_str = unit_obj.latex_representation()
label = "$\\left(" + unit_str + "\\right)$"
if unyt_arrayConverter._labelstyle == "[]":
label = "$\\left[" + unit_str + "\\right]$"
elif unyt_arrayConverter._labelstyle == "/":
axsym = axis.axis_name
if "/" in unit_str:
label = (
"$q_{" + axsym + "}\\;/\\;\\left(" + unit_str + "\\right)$"
)
else:
label = "$q_{" + axsym + "}\\;/\\;" + unit_str + "$"
else:
label = "$\\left(" + unit_str + "\\right)$"
return AxisInfo(label=label)

@staticmethod
Expand Down Expand Up @@ -113,11 +126,49 @@ def convert(value, unit, axis):
converted_value.append(obj.to(*unit))
converted_value = value_type(converted_value)
else:
raise ConversionError(f"unable to convert {value}")
raise ConversionError("unable to convert {%s}".format(value))
return converted_value

class MplUnitsCM:
"""Context manager for experimenting with Unyt in Matplotlib"""
class matplotlib_support:
"""Context manager for experimenting with Unyt in Matplotlib

Parameters
----------

label_style : string, one from the set {'()', '[]', '/'}
The axis label style.
'()' -> '(unit)'
'[]' -> '[unit]'
'/' -> 'q / unit'
SI standard where label is a mathematical expression.
'q' is a generic quantity symbol and for a value on the axis, x,
the equation is q = x * unit.
"""

def __init__(self, label_style="()"):
self._labelstyle = label_style
unyt_arrayConverter._labelstyle = label_style

def __call__(self):
self.__enter__()

@property
def label_style(self):
"""label_style : string, one from the set {'()', '[]', '/'}
The axis label style.
'()' -> '(unit)'
'[]' -> '[unit]'
'/' -> 'q / unit'
SI standard where label is a mathematical expression.
'q' is a generic quantity symbol and for a value on the axis, x,
the equation is q = x * unit.
"""
return self._labelstyle

@label_style.setter
def label_style(self, label_style="()"):
self._labelstyle = label_style
unyt_arrayConverter._labelstyle = label_style

def __enter__(self):
registry[unyt_array] = unyt_arrayConverter()
Expand Down
42 changes: 33 additions & 9 deletions unyt/tests/test_mpl_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import numpy as np
import pytest
from unyt._on_demand_imports import _matplotlib, NotAModule
from unyt import s, K, unyt_array, unyt_quantity
from unyt import s, K, unyt_array, unyt_quantity, matplotlib_support
l-johnston marked this conversation as resolved.
Show resolved Hide resolved
from unyt.exceptions import UnitConversionError
from unyt.mpl_interface import MplUnitsCM, unyt_arrayConverter
from unyt.mpl_interface import unyt_arrayConverter


check_matplotlib = pytest.mark.skipif(
Expand All @@ -14,12 +14,11 @@

@pytest.fixture
def ax(scope="module"):
mpl_units = MplUnitsCM()
mpl_units.enable()
matplotlib_support.enable()
fig, ax = _matplotlib.pyplot.subplots()
yield ax
_matplotlib.pyplot.close()
mpl_units.disable()
matplotlib_support.disable()


@check_matplotlib
Expand Down Expand Up @@ -146,12 +145,37 @@ def test_hist(ax):


@check_matplotlib
def test_MplUnitsCM():
mplunits = MplUnitsCM()
def test_matplotlib_support():
with pytest.raises(KeyError):
_matplotlib.units.registry[unyt_array]
mplunits.enable()
matplotlib_support.enable()
assert isinstance(_matplotlib.units.registry[unyt_array], unyt_arrayConverter)
mplunits.disable()
matplotlib_support.disable()
assert unyt_array not in _matplotlib.units.registry.keys()
assert unyt_quantity not in _matplotlib.units.registry.keys()
# test as a callable
matplotlib_support()
assert isinstance(_matplotlib.units.registry[unyt_array], unyt_arrayConverter)


@check_matplotlib
def test_labelstyle():
x = [0, 1, 2] * s
y = [3, 4, 5] * K
matplotlib_support.label_style = "[]"
matplotlib_support.enable()
assert unyt_arrayConverter._labelstyle == "[]"
fig, ax = _matplotlib.pyplot.subplots()
ax.plot(x, y)
expected_xlabel = "$\\left[\\rm{s}\\right]$"
assert ax.xaxis.get_label().get_text() == expected_xlabel
expected_ylabel = "$\\left[\\rm{K}\\right]$"
assert ax.yaxis.get_label().get_text() == expected_ylabel
matplotlib_support.label_style = "/"
ax.clear()
ax.plot(x, y)
expected_xlabel = "$q_{x}\\;/\\;\\rm{s}$"
assert ax.xaxis.get_label().get_text() == expected_xlabel
expected_ylabel = "$q_{y}\\;/\\;\\rm{K}$"
assert ax.yaxis.get_label().get_text() == expected_ylabel
_matplotlib.pyplot.close()