Skip to content

Commit

Permalink
Properly support user-provided norm. (#2443)
Browse files Browse the repository at this point in the history
* Properly support user-provided norm.

Fixes #2381

* remove top level mpl import.

* More accurate error message.

* whats-new fixes.
  • Loading branch information
dcherian authored Oct 8, 2018
1 parent cf1e6c7 commit 5f09deb
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 21 deletions.
13 changes: 8 additions & 5 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ Breaking changes

Documentation
~~~~~~~~~~~~~

Enhancements
~~~~~~~~~~~~

- Added support for Python 3.7. (:issue:`2271`).
By `Joe Hamman <https://github.com/jhamman>`_.

- Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a
CFTimeIndex by a specified frequency. (:issue:`2244`). By `Spencer Clark
<https://github.com/spencerkclark>`_.
CFTimeIndex by a specified frequency. (:issue:`2244`).
By `Spencer Clark <https://github.com/spencerkclark>`_.
- Added support for using ``cftime.datetime`` coordinates with
:py:meth:`~xarray.DataArray.differentiate`,
:py:meth:`~xarray.Dataset.differentiate`,
Expand All @@ -60,11 +60,14 @@ Bug fixes
~~~~~~~~~

- Addition and subtraction operators used with a CFTimeIndex now preserve the
index's type. (:issue:`2244`). By `Spencer Clark <https://github.com/spencerkclark>`_.
index's type. (:issue:`2244`).
By `Spencer Clark <https://github.com/spencerkclark>`_.
- ``xarray.DataArray.roll`` correctly handles multidimensional arrays.
(:issue:`2445`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.

- ``xarray.plot()`` now properly accepts a ``norm`` argument and does not override
the norm's ``vmin`` and ``vmax``. (:issue:`2381`)
By `Deepak Cherian <https://github.com/dcherian>`_.
- ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument.
(:issue:`2240`)
By `Keisuke Fujii <https://github.com/fujiisoup>`_.
Expand Down
12 changes: 7 additions & 5 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,9 @@ def _plot2d(plotfunc):
Adds colorbar to axis
add_labels : Boolean, optional
Use xarray metadata to label axes
norm : ``matplotlib.colors.Normalize`` instance, optional
If the ``norm`` has vmin or vmax specified, the corresponding kwarg
must be None.
vmin, vmax : floats, optional
Values to anchor the colormap, otherwise they are inferred from the
data and other keyword arguments. When a diverging dataset is inferred,
Expand Down Expand Up @@ -630,7 +633,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
levels=None, infer_intervals=None, colors=None,
subplot_kws=None, cbar_ax=None, cbar_kwargs=None,
xscale=None, yscale=None, xticks=None, yticks=None,
xlim=None, ylim=None, **kwargs):
xlim=None, ylim=None, norm=None, **kwargs):
# All 2d plots in xarray share this function signature.
# Method signature below should be consistent.

Expand Down Expand Up @@ -727,6 +730,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
'extend': extend,
'levels': levels,
'filled': plotfunc.__name__ != 'contour',
'norm': norm,
}

cmap_params = _determine_cmap_params(**cmap_kwargs)
Expand All @@ -746,9 +750,6 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
if 'pcolormesh' == plotfunc.__name__:
kwargs['infer_intervals'] = infer_intervals

# This allows the user to pass in a custom norm coming via kwargs
kwargs.setdefault('norm', cmap_params['norm'])

if 'imshow' == plotfunc.__name__ and isinstance(aspect, basestring):
# forbid usage of mpl strings
raise ValueError("plt.imshow's `aspect` kwarg is not available "
Expand All @@ -758,6 +759,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None,
primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'],
vmin=cmap_params['vmin'],
vmax=cmap_params['vmax'],
norm=cmap_params['norm'],
**kwargs)

# Label the plot with metadata
Expand Down Expand Up @@ -809,7 +811,7 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None,
levels=None, infer_intervals=None, subplot_kws=None,
cbar_ax=None, cbar_kwargs=None,
xscale=None, yscale=None, xticks=None, yticks=None,
xlim=None, ylim=None, **kwargs):
xlim=None, ylim=None, norm=None, **kwargs):
"""
The method should have the same signature as the function.
Expand Down
33 changes: 30 additions & 3 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
# vlim might be computed below
vlim = None

# save state; needed later
vmin_was_none = vmin is None
vmax_was_none = vmax is None

if vmin is None:
if robust:
vmin = np.percentile(calc_data, ROBUST_PERCENTILE)
Expand Down Expand Up @@ -204,6 +208,28 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
vmin += center
vmax += center

# now check norm and harmonize with vmin, vmax
if norm is not None:
if norm.vmin is None:
norm.vmin = vmin
else:
if not vmin_was_none and vmin != norm.vmin:
raise ValueError('Cannot supply vmin and a norm'
+ ' with a different vmin.')
vmin = norm.vmin

if norm.vmax is None:
norm.vmax = vmax
else:
if not vmax_was_none and vmax != norm.vmax:
raise ValueError('Cannot supply vmax and a norm'
+ ' with a different vmax.')
vmax = norm.vmax

# if BoundaryNorm, then set levels
if isinstance(norm, mpl.colors.BoundaryNorm):
levels = norm.boundaries

# Choose default colormaps if not provided
if cmap is None:
if divergent:
Expand All @@ -212,7 +238,7 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
cmap = OPTIONS['cmap_sequential']

# Handle discrete levels
if levels is not None:
if levels is not None and norm is None:
if is_scalar(levels):
if user_minmax:
levels = np.linspace(vmin, vmax, levels)
Expand All @@ -227,8 +253,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None,
if extend is None:
extend = _determine_extend(calc_data, vmin, vmax)

if levels is not None:
cmap, norm = _build_discrete_cmap(cmap, levels, extend, filled)
if levels is not None or isinstance(norm, mpl.colors.BoundaryNorm):
cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled)
norm = newnorm if norm is None else norm

return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend,
levels=levels, norm=norm)
Expand Down
50 changes: 42 additions & 8 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,26 @@ def test_divergentcontrol(self):
assert cmap_params['vmax'] == 0.6
assert cmap_params['cmap'] == "viridis"

def test_norm_sets_vmin_vmax(self):
vmin = self.data.min()
vmax = self.data.max()

for norm, extend in zip([mpl.colors.LogNorm(),
mpl.colors.LogNorm(vmin + 1, vmax - 1),
mpl.colors.LogNorm(None, vmax - 1),
mpl.colors.LogNorm(vmin + 1, None)],
['neither', 'both', 'max', 'min']):

test_min = vmin if norm.vmin is None else norm.vmin
test_max = vmax if norm.vmax is None else norm.vmax

cmap_params = _determine_cmap_params(self.data, norm=norm)

assert cmap_params['vmin'] == test_min
assert cmap_params['vmax'] == test_max
assert cmap_params['extend'] == extend
assert cmap_params['norm'] == norm


@requires_matplotlib
class TestDiscreteColorMap(object):
Expand Down Expand Up @@ -665,10 +685,10 @@ def test_build_discrete_cmap(self):

@pytest.mark.slow
def test_discrete_colormap_list_of_levels(self):
for extend, levels in [('max', [-1, 2, 4, 8, 10]), ('both',
[2, 5, 10, 11]),
('neither', [0, 5, 10, 15]), ('min',
[2, 5, 10, 15])]:
for extend, levels in [('max', [-1, 2, 4, 8, 10]),
('both', [2, 5, 10, 11]),
('neither', [0, 5, 10, 15]),
('min', [2, 5, 10, 15])]:
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
primitive = getattr(self.darray.plot, kind)(levels=levels)
assert_array_equal(levels, primitive.norm.boundaries)
Expand All @@ -682,10 +702,10 @@ def test_discrete_colormap_list_of_levels(self):

@pytest.mark.slow
def test_discrete_colormap_int_levels(self):
for extend, levels, vmin, vmax in [('neither', 7, None,
None), ('neither', 7, None, 20),
('both', 7, 4, 8), ('min', 10, 4,
15)]:
for extend, levels, vmin, vmax in [('neither', 7, None, None),
('neither', 7, None, 20),
('both', 7, 4, 8),
('min', 10, 4, 15)]:
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']:
primitive = getattr(self.darray.plot, kind)(
levels=levels, vmin=vmin, vmax=vmax)
Expand All @@ -711,6 +731,11 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self):
assert primitive.norm.vmax == max(levels)
assert primitive.norm.vmin == min(levels)

def test_discrete_colormap_provided_boundary_norm(self):
norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4)
primitive = self.darray.plot.contourf(norm=norm)
np.testing.assert_allclose(primitive.levels, norm.boundaries)


class Common2dMixin(object):
"""
Expand Down Expand Up @@ -1085,6 +1110,15 @@ def test_cmap_and_color_both(self):
with pytest.raises(ValueError):
self.plotmethod(colors='k', cmap='RdBu')

def test_colormap_error_norm_and_vmin_vmax(self):
norm = mpl.colors.LogNorm(0.1, 1e1)

with pytest.raises(ValueError):
self.darray.plot(norm=norm, vmin=2)

with pytest.raises(ValueError):
self.darray.plot(norm=norm, vmax=2)


@pytest.mark.slow
class TestContourf(Common2dMixin, PlotTestCase):
Expand Down

0 comments on commit 5f09deb

Please sign in to comment.