diff --git a/doc/whats-new.rst b/doc/whats-new.rst index e9c223ff801..85e9f2313d6 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -40,15 +40,15 @@ Breaking changes Documentation ~~~~~~~~~~~~~ + Enhancements ~~~~~~~~~~~~ - Added support for Python 3.7. (:issue:`2271`). By `Joe Hamman `_. - - Added :py:meth:`~xarray.CFTimeIndex.shift` for shifting the values of a - CFTimeIndex by a specified frequency. (:issue:`2244`). By `Spencer Clark - `_. + CFTimeIndex by a specified frequency. (:issue:`2244`). + By `Spencer Clark `_. - Added support for using ``cftime.datetime`` coordinates with :py:meth:`~xarray.DataArray.differentiate`, :py:meth:`~xarray.Dataset.differentiate`, @@ -60,11 +60,14 @@ Bug fixes ~~~~~~~~~ - Addition and subtraction operators used with a CFTimeIndex now preserve the - index's type. (:issue:`2244`). By `Spencer Clark `_. + index's type. (:issue:`2244`). + By `Spencer Clark `_. - ``xarray.DataArray.roll`` correctly handles multidimensional arrays. (:issue:`2445`) By `Keisuke Fujii `_. - +- ``xarray.plot()`` now properly accepts a ``norm`` argument and does not override + the norm's ``vmin`` and ``vmax``. (:issue:`2381`) + By `Deepak Cherian `_. - ``xarray.DataArray.std()`` now correctly accepts ``ddof`` keyword argument. (:issue:`2240`) By `Keisuke Fujii `_. diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 3f9f1090c70..b44ae7b3856 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -562,6 +562,9 @@ def _plot2d(plotfunc): Adds colorbar to axis add_labels : Boolean, optional Use xarray metadata to label axes + norm : ``matplotlib.colors.Normalize`` instance, optional + If the ``norm`` has vmin or vmax specified, the corresponding kwarg + must be None. vmin, vmax : floats, optional Values to anchor the colormap, otherwise they are inferred from the data and other keyword arguments. When a diverging dataset is inferred, @@ -630,7 +633,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, levels=None, infer_intervals=None, colors=None, subplot_kws=None, cbar_ax=None, cbar_kwargs=None, xscale=None, yscale=None, xticks=None, yticks=None, - xlim=None, ylim=None, **kwargs): + xlim=None, ylim=None, norm=None, **kwargs): # All 2d plots in xarray share this function signature. # Method signature below should be consistent. @@ -727,6 +730,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, 'extend': extend, 'levels': levels, 'filled': plotfunc.__name__ != 'contour', + 'norm': norm, } cmap_params = _determine_cmap_params(**cmap_kwargs) @@ -746,9 +750,6 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, if 'pcolormesh' == plotfunc.__name__: kwargs['infer_intervals'] = infer_intervals - # This allows the user to pass in a custom norm coming via kwargs - kwargs.setdefault('norm', cmap_params['norm']) - if 'imshow' == plotfunc.__name__ and isinstance(aspect, basestring): # forbid usage of mpl strings raise ValueError("plt.imshow's `aspect` kwarg is not available " @@ -758,6 +759,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'], vmin=cmap_params['vmin'], vmax=cmap_params['vmax'], + norm=cmap_params['norm'], **kwargs) # Label the plot with metadata @@ -809,7 +811,7 @@ def plotmethod(_PlotMethods_obj, x=None, y=None, figsize=None, size=None, levels=None, infer_intervals=None, subplot_kws=None, cbar_ax=None, cbar_kwargs=None, xscale=None, yscale=None, xticks=None, yticks=None, - xlim=None, ylim=None, **kwargs): + xlim=None, ylim=None, norm=None, **kwargs): """ The method should have the same signature as the function. diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index a284c186937..be38a6d7a4c 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -172,6 +172,10 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, # vlim might be computed below vlim = None + # save state; needed later + vmin_was_none = vmin is None + vmax_was_none = vmax is None + if vmin is None: if robust: vmin = np.percentile(calc_data, ROBUST_PERCENTILE) @@ -204,6 +208,28 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, vmin += center vmax += center + # now check norm and harmonize with vmin, vmax + if norm is not None: + if norm.vmin is None: + norm.vmin = vmin + else: + if not vmin_was_none and vmin != norm.vmin: + raise ValueError('Cannot supply vmin and a norm' + + ' with a different vmin.') + vmin = norm.vmin + + if norm.vmax is None: + norm.vmax = vmax + else: + if not vmax_was_none and vmax != norm.vmax: + raise ValueError('Cannot supply vmax and a norm' + + ' with a different vmax.') + vmax = norm.vmax + + # if BoundaryNorm, then set levels + if isinstance(norm, mpl.colors.BoundaryNorm): + levels = norm.boundaries + # Choose default colormaps if not provided if cmap is None: if divergent: @@ -212,7 +238,7 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, cmap = OPTIONS['cmap_sequential'] # Handle discrete levels - if levels is not None: + if levels is not None and norm is None: if is_scalar(levels): if user_minmax: levels = np.linspace(vmin, vmax, levels) @@ -227,8 +253,9 @@ def _determine_cmap_params(plot_data, vmin=None, vmax=None, cmap=None, if extend is None: extend = _determine_extend(calc_data, vmin, vmax) - if levels is not None: - cmap, norm = _build_discrete_cmap(cmap, levels, extend, filled) + if levels is not None or isinstance(norm, mpl.colors.BoundaryNorm): + cmap, newnorm = _build_discrete_cmap(cmap, levels, extend, filled) + norm = newnorm if norm is None else norm return dict(vmin=vmin, vmax=vmax, cmap=cmap, extend=extend, levels=levels, norm=norm) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 01303202c93..53f6077ee4f 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -628,6 +628,26 @@ def test_divergentcontrol(self): assert cmap_params['vmax'] == 0.6 assert cmap_params['cmap'] == "viridis" + def test_norm_sets_vmin_vmax(self): + vmin = self.data.min() + vmax = self.data.max() + + for norm, extend in zip([mpl.colors.LogNorm(), + mpl.colors.LogNorm(vmin + 1, vmax - 1), + mpl.colors.LogNorm(None, vmax - 1), + mpl.colors.LogNorm(vmin + 1, None)], + ['neither', 'both', 'max', 'min']): + + test_min = vmin if norm.vmin is None else norm.vmin + test_max = vmax if norm.vmax is None else norm.vmax + + cmap_params = _determine_cmap_params(self.data, norm=norm) + + assert cmap_params['vmin'] == test_min + assert cmap_params['vmax'] == test_max + assert cmap_params['extend'] == extend + assert cmap_params['norm'] == norm + @requires_matplotlib class TestDiscreteColorMap(object): @@ -665,10 +685,10 @@ def test_build_discrete_cmap(self): @pytest.mark.slow def test_discrete_colormap_list_of_levels(self): - for extend, levels in [('max', [-1, 2, 4, 8, 10]), ('both', - [2, 5, 10, 11]), - ('neither', [0, 5, 10, 15]), ('min', - [2, 5, 10, 15])]: + for extend, levels in [('max', [-1, 2, 4, 8, 10]), + ('both', [2, 5, 10, 11]), + ('neither', [0, 5, 10, 15]), + ('min', [2, 5, 10, 15])]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)(levels=levels) assert_array_equal(levels, primitive.norm.boundaries) @@ -682,10 +702,10 @@ def test_discrete_colormap_list_of_levels(self): @pytest.mark.slow def test_discrete_colormap_int_levels(self): - for extend, levels, vmin, vmax in [('neither', 7, None, - None), ('neither', 7, None, 20), - ('both', 7, 4, 8), ('min', 10, 4, - 15)]: + for extend, levels, vmin, vmax in [('neither', 7, None, None), + ('neither', 7, None, 20), + ('both', 7, 4, 8), + ('min', 10, 4, 15)]: for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: primitive = getattr(self.darray.plot, kind)( levels=levels, vmin=vmin, vmax=vmax) @@ -711,6 +731,11 @@ def test_discrete_colormap_list_levels_and_vmin_or_vmax(self): assert primitive.norm.vmax == max(levels) assert primitive.norm.vmin == min(levels) + def test_discrete_colormap_provided_boundary_norm(self): + norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4) + primitive = self.darray.plot.contourf(norm=norm) + np.testing.assert_allclose(primitive.levels, norm.boundaries) + class Common2dMixin(object): """ @@ -1085,6 +1110,15 @@ def test_cmap_and_color_both(self): with pytest.raises(ValueError): self.plotmethod(colors='k', cmap='RdBu') + def test_colormap_error_norm_and_vmin_vmax(self): + norm = mpl.colors.LogNorm(0.1, 1e1) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmin=2) + + with pytest.raises(ValueError): + self.darray.plot(norm=norm, vmax=2) + @pytest.mark.slow class TestContourf(Common2dMixin, PlotTestCase):