Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly support user-provided norm. #2443

Merged
merged 5 commits into from
Oct 8, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ Enhancements
- Added support for Python 3.7. (:issue:`2271`).
By `Joe Hamman <https://github.com/jhamman>`_.

- ``xarray.plot()`` now properly accepts a ``norm`` argument and does not override
the norm's ``vmin`` and ``vmax``.
(:issue:`2381`)
By `Deepak Cherian <https://github.com/dcherian>`_.

Bug fixes
~~~~~~~~~

Expand Down
12 changes: 7 additions & 5 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,9 @@ def _plot2d(plotfunc):
Adds colorbar to axis
add_labels : Boolean, optional
Use xarray metadata to label axes
norm : ``matplotlib.colors.Normalize`` instance, optional
If the ``norm`` has vmin or vmax specified, the corresponding kwarg
must be None.
vmin, vmax : floats, optional
Values to anchor the colormap, otherwise they are inferred from the
data and other keyword arguments. When a diverging dataset is inferred,
Expand Down Expand Up @@ -630,7 +633,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
levels=None, infer_intervals=None, colors=None,
subplot_kws=None, cbar_ax=None, cbar_kwargs=None,
xscale=None, yscale=None, xticks=None, yticks=None,
xlim=None, ylim=None, **kwargs):
xlim=None, ylim=None, norm=None, **kwargs):
# All 2d plots in xarray share this function signature.
# Method signature below should be consistent.

Expand Down Expand Up @@ -727,6 +730,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
'extend': extend,
'levels': levels,
'filled': plotfunc.__name__ != 'contour',
'norm': norm,
}

cmap_params = _determine_cmap_params(**cmap_kwargs)
Expand All @@ -741,9 +745,6 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
if 'pcolormesh' == plotfunc.__name__:
kwargs['infer_intervals'] = infer_intervals

# This allows the user to pass in a custom norm coming via kwargs
kwargs.setdefault('norm', cmap_params['norm'])

if 'imshow' == plotfunc.__name__ and isinstance(aspect, basestring):
# forbid usage of mpl strings
raise ValueError("plt.imshow's `aspect` kwarg is not available "
Expand All @@ -753,6 +754,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'],
vmin=cmap_params['vmin'],
vmax=cmap_params['vmax'],
norm=cmap_params['norm'],
**kwargs)

# Label the plot with metadata
Expand Down Expand Up @@ -804,7 +806,7 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None,
levels=None, infer_intervals=None, subplot_kws=None,
cbar_ax=None, cbar_kwargs=None,
xscale=None, yscale=None, xticks=None, yticks=None,
xlim=None, ylim=None, **kwargs):
xlim=None, ylim=None, norm=None, **kwargs):
"""
The method should have the same signature as the function.

Expand Down
33 changes: 30 additions & 3 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
# vlim might be computed below
vlim = None

# save state; needed later
vmin_was_none = vmin is None
vmax_was_none = vmax is None

if vmin is None:
if robust:
vmin = np.percentile(calc_data, ROBUST_PERCENTILE)
Expand Down Expand Up @@ -205,6 +209,28 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
vmin += center
vmax += center

# now check norm and harmonize with vmin, vmax
if norm is not None:
if norm.vmin is None:
norm.vmin = vmin
else:
if not vmin_was_none and vmin != norm.vmin:
raise ValueError('Cannot supply vmin and a norm'
+ ' with a different vmin.')
dcherian marked this conversation as resolved.
Show resolved Hide resolved
vmin = norm.vmin

if norm.vmax is None:
norm.vmax = vmax
else:
if not vmax_was_none and vmax != norm.vmax:
raise ValueError('Cannot supply vmax and a norm'
+ ' with a different vmax.')
vmax = norm.vmax

# if BoundaryNorm, then set levels
if isinstance(norm, mpl.colors.BoundaryNorm):
levels = norm.boundaries

# Choose default colormaps if not provided
if cmap is None:
if divergent:
Expand All @@ -213,7 +239,7 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
cmap = OPTIONS['cmap_sequential']

# Handle discrete levels
if levels is not None:
if levels is not None and norm is None:
if is_scalar(levels):
if user_minmax:
levels = np.linspace(vmin, vmax, levels)
Expand All @@ -228,8 +254,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
if extend is None:
extend = _determine_extend(calc_data, vmin, vmax)

if levels is not None:
cmap, norm = _build_discrete_cmap(cmap, levels, extend, filled)
if levels is not None or isinstance(norm, mpl.colors.BoundaryNorm):
cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled)
norm = newnorm if norm is None else norm

return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend,
levels=levels, norm=norm)
Expand Down
57 changes: 45 additions & 12 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import numpy as np
import pandas as pd
import xarray as xr
import pytest

import xarray as xr
import xarray.plot as xplt
from xarray import DataArray
from xarray.coding.times import _import_cftime
Expand All @@ -17,9 +17,8 @@
import_seaborn, label_from_attrs)

from . import (
TestCase, assert_array_equal, assert_equal, raises_regex,
requires_matplotlib, requires_matplotlib2, requires_seaborn,
requires_cftime)
TestCase, assert_array_equal, assert_equal, raises_regex, requires_cftime,
requires_matplotlib, requires_matplotlib2, requires_seaborn)

# import mpl and change the backend before other mpl imports
try:
Expand Down Expand Up @@ -623,6 +622,26 @@ def test_divergentcontrol(self):
assert cmap_params['vmax'] == 0.6
assert cmap_params['cmap'] == "viridis"

def test_norm_sets_vmin_vmax(self):
vmin = self.data.min()
vmax = self.data.max()

for norm, extend in zip([mpl.colors.LogNorm(),
mpl.colors.LogNorm(vmin + 1, vmax - 1),
mpl.colors.LogNorm(None, vmax - 1),
mpl.colors.LogNorm(vmin + 1, None)],
['neither', 'both', 'max', 'min']):

test_min = vmin if norm.vmin is None else norm.vmin
test_max = vmax if norm.vmax is None else norm.vmax

cmap_params = _determine_cmap_params(self.data, norm=norm)

assert cmap_params['vmin'] == test_min
assert cmap_params['vmax'] == test_max
assert cmap_params['extend'] == extend
assert cmap_params['norm'] == norm


@requires_matplotlib
class TestDiscreteColorMap(TestCase):
Expand Down Expand Up @@ -659,10 +678,10 @@ def test_build_discrete_cmap(self):

@pytest.mark.slow
def test_discrete_colormap_list_of_levels(self):
for extend, levels in [('max', [-1, 2, 4, 8, 10]), ('both',
[2, 5, 10, 11]),
('neither', [0, 5, 10, 15]), ('min',
[2, 5, 10, 15])]:
for extend, levels in [('max', [-1, 2, 4, 8, 10]),
('both', [2, 5, 10, 11]),
('neither', [0, 5, 10, 15]),
('min', [2, 5, 10, 15])]:
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
primitive = getattr(self.darray.plot, kind)(levels=levels)
assert_array_equal(levels, primitive.norm.boundaries)
Expand All @@ -676,10 +695,10 @@ def test_discrete_colormap_list_of_levels(self):

@pytest.mark.slow
def test_discrete_colormap_int_levels(self):
for extend, levels, vmin, vmax in [('neither', 7, None,
None), ('neither', 7, None, 20),
('both', 7, 4, 8), ('min', 10, 4,
15)]:
for extend, levels, vmin, vmax in [('neither', 7, None, None),
('neither', 7, None, 20),
('both', 7, 4, 8),
('min', 10, 4, 15)]:
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
primitive = getattr(self.darray.plot, kind)(
levels=levels, vmin=vmin, vmax=vmax)
Expand All @@ -705,6 +724,11 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self):
assert primitive.norm.vmax == max(levels)
assert primitive.norm.vmin == min(levels)

def test_discrete_colormap_provided_boundary_norm(self):
norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4)
primitive = self.darray.plot.contourf(norm=norm)
np.testing.assert_allclose(primitive.levels, norm.boundaries)


class Common2dMixin:
"""
Expand Down Expand Up @@ -1078,6 +1102,15 @@ def test_cmap_and_color_both(self):
with pytest.raises(ValueError):
self.plotmethod(colors='k', cmap='RdBu')

def test_colormap_error_norm_and_vmin_vmax(self):
norm = mpl.colors.LogNorm(0.1, 1e1)

with pytest.raises(ValueError):
self.darray.plot(norm=norm, vmin=2)

with pytest.raises(ValueError):
self.darray.plot(norm=norm, vmax=2)


@pytest.mark.slow
class TestContourf(Common2dMixin, PlotTestCase):
Expand Down