From 6440365af15b0fa7d937df37eb90739868794ee6 Mon Sep 17 00:00:00 2001 From: dcherian Date: Sat, 15 Dec 2018 16:05:04 -0800 Subject: [PATCH] More fixes to hue, cmap_kwargs. --- xarray/plot/dataset_plot.py | 51 ++++++++++++++++++++++++------------- xarray/plot/facetgrid.py | 10 +++++--- xarray/tests/test_plot.py | 2 +- 3 files changed, 41 insertions(+), 22 deletions(-) diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 7f5c345c595..80400d3fdb2 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -30,13 +30,16 @@ def _infer_scatter_meta_data(ds, x, y, hue, add_legend, discrete_legend): if add_legend and not hue: raise ValueError('hue must be specified for generating a legend') - if hue and not _ensure_numeric(ds[hue].values): + if hue and add_legend and not _ensure_numeric(ds[hue].values): if discrete_legend is None: discrete_legend = True elif discrete_legend is False: raise ValueError('Cannot create a colorbar for a non numeric' ' coordinate') + if not hue and add_legend is None: + discrete_legend = None + dims = ds[x].dims if ds[y].dims != dims: raise ValueError('{} and {} must have the same dimensions.' @@ -49,15 +52,17 @@ def _infer_scatter_meta_data(ds, x, y, hue, add_legend, discrete_legend): if hue: hue_label = label_from_attrs(ds.coords[hue]) + hue_values = ds[x].coords[hue] if discrete_legend else None else: hue_label = None + hue_values = None return {'add_legend': add_legend, 'discrete_legend': discrete_legend, 'hue_label': hue_label, 'xlabel': label_from_attrs(ds[x]), 'ylabel': label_from_attrs(ds[y]), - 'hue_values': ds[x].coords[hue] if discrete_legend else None} + 'hue_values': hue_values} def _infer_scatter_data(ds, x, y, hue): @@ -175,30 +180,40 @@ def scatter(ds, x, y, hue=None, col=None, row=None, figsize = kwargs.pop('figsize', None) ax = kwargs.pop('ax', None) ax = get_axis(figsize, size, aspect, ax) + + kwargs = kwargs.copy() + _meta_data = kwargs.pop('_meta_data', None) + if discrete_legend: primitive = [] for label, grp in ds.groupby(ds[hue]): data = _infer_scatter_data(grp, x, y, hue=None) - primitive.append(ax.scatter(data['x'], data['y'], label=label)) + primitive.append(ax.scatter(data['x'], data['y'], label=label, + **kwargs)) else: data = _infer_scatter_data(ds, x, y, hue) - cmap_kwargs = {'plot_data': ds[hue], - 'vmin': vmin, - 'vmax': vmax, - 'cmap': colors if colors else cmap, - 'center': center, - 'robust': robust, - 'extend': extend, - 'levels': levels, - 'filled': None, - 'norm': norm, - } - cmap_params = _determine_cmap_params(**cmap_kwargs) + if hue is not None: + cmap_kwargs = {'plot_data': ds[hue], + 'vmin': vmin, + 'vmax': vmax, + 'cmap': colors if colors else cmap, + 'center': center, + 'robust': robust, + 'extend': extend, + 'levels': levels, + 'filled': None, + 'norm': norm} + cmap_params = _determine_cmap_params(**cmap_kwargs) + cmap_kwargs_subset = dict( + (vv, cmap_kwargs[vv]) + for vv in ['vmin', 'vmax', 'norm', 'cmap']) + else: + cmap_kwargs_subset = {} + primitive = ax.scatter(data['x'], data['y'], c=data['color'], - vmin=cmap_kwargs['vmin'], - vmax=cmap_kwargs['vmax']) + **cmap_kwargs_subset, **kwargs) - if '_meta_data' in kwargs: # if this was called from map_scatter, + if _meta_data: # if this was called from map_scatter, return primitive # finish here. Else, make labels if meta_data.get('xlabel', None): diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 1c0ef9d73f2..717a98eec04 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -327,9 +327,13 @@ def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False, # None is the sentinel value if d is not None: subset = self.data.loc[d] - mappable = scatter(subset, x=x, y=y, hue=hue, - ax=ax, **kwargs) - self._mappables.append(mappable) + maybe_mappable = scatter(subset, x=x, y=y, hue=hue, + ax=ax, **kwargs) + # TODO: better way to verify that an artist is mappable? + # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522 + if (maybe_mappable + and hasattr(maybe_mappable, 'autoscale_None')): + self._mappables.append(maybe_mappable) self._finalize_grid(meta_data['xlabel'], meta_data['ylabel']) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 9eeff5027a0..0dfcafb4a75 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -1791,7 +1791,7 @@ def test_wrong_num_of_dimensions(self): self.darray.plot.line(row='row', hue='hue') -class TestScatterPlots(PlotTestCase): +class TestDatasetScatterPlots(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): das = [DataArray(np.random.randn(3, 3, 4, 4),