Skip to content

Commit

Permalink
More fixes to hue, cmap_kwargs.
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Dec 16, 2018
1 parent a938d24 commit 6440365
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 22 deletions.
51 changes: 33 additions & 18 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@ def _infer_scatter_meta_data(ds, x, y, hue, add_legend, discrete_legend):
if add_legend and not hue:
raise ValueError('hue must be specified for generating a legend')

if hue and not _ensure_numeric(ds[hue].values):
if hue and add_legend and not _ensure_numeric(ds[hue].values):
if discrete_legend is None:
discrete_legend = True
elif discrete_legend is False:
raise ValueError('Cannot create a colorbar for a non numeric'
' coordinate')

if not hue and add_legend is None:
discrete_legend = None

dims = ds[x].dims
if ds[y].dims != dims:
raise ValueError('{} and {} must have the same dimensions.'
Expand All @@ -49,15 +52,17 @@ def _infer_scatter_meta_data(ds, x, y, hue, add_legend, discrete_legend):

if hue:
hue_label = label_from_attrs(ds.coords[hue])
hue_values = ds[x].coords[hue] if discrete_legend else None
else:
hue_label = None
hue_values = None

return {'add_legend': add_legend,
'discrete_legend': discrete_legend,
'hue_label': hue_label,
'xlabel': label_from_attrs(ds[x]),
'ylabel': label_from_attrs(ds[y]),
'hue_values': ds[x].coords[hue] if discrete_legend else None}
'hue_values': hue_values}


def _infer_scatter_data(ds, x, y, hue):
Expand Down Expand Up @@ -175,30 +180,40 @@ def scatter(ds, x, y, hue=None, col=None, row=None,
figsize = kwargs.pop('figsize', None)
ax = kwargs.pop('ax', None)
ax = get_axis(figsize, size, aspect, ax)

kwargs = kwargs.copy()
_meta_data = kwargs.pop('_meta_data', None)

if discrete_legend:
primitive = []
for label, grp in ds.groupby(ds[hue]):
data = _infer_scatter_data(grp, x, y, hue=None)
primitive.append(ax.scatter(data['x'], data['y'], label=label))
primitive.append(ax.scatter(data['x'], data['y'], label=label,
**kwargs))
else:
data = _infer_scatter_data(ds, x, y, hue)
cmap_kwargs = {'plot_data': ds[hue],
'vmin': vmin,
'vmax': vmax,
'cmap': colors if colors else cmap,
'center': center,
'robust': robust,
'extend': extend,
'levels': levels,
'filled': None,
'norm': norm,
}
cmap_params = _determine_cmap_params(**cmap_kwargs)
if hue is not None:
cmap_kwargs = {'plot_data': ds[hue],
'vmin': vmin,
'vmax': vmax,
'cmap': colors if colors else cmap,
'center': center,
'robust': robust,
'extend': extend,
'levels': levels,
'filled': None,
'norm': norm}
cmap_params = _determine_cmap_params(**cmap_kwargs)
cmap_kwargs_subset = dict(
(vv, cmap_kwargs[vv])
for vv in ['vmin', 'vmax', 'norm', 'cmap'])
else:
cmap_kwargs_subset = {}

primitive = ax.scatter(data['x'], data['y'], c=data['color'],
vmin=cmap_kwargs['vmin'],
vmax=cmap_kwargs['vmax'])
**cmap_kwargs_subset, **kwargs)

if '_meta_data' in kwargs: # if this was called from map_scatter,
if _meta_data: # if this was called from map_scatter,
return primitive # finish here. Else, make labels

if meta_data.get('xlabel', None):
Expand Down
10 changes: 7 additions & 3 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,13 @@ def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False,
# None is the sentinel value
if d is not None:
subset = self.data.loc[d]
mappable = scatter(subset, x=x, y=y, hue=hue,
ax=ax, **kwargs)
self._mappables.append(mappable)
maybe_mappable = scatter(subset, x=x, y=y, hue=hue,
ax=ax, **kwargs)
# TODO: better way to verify that an artist is mappable?
# https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522
if (maybe_mappable
and hasattr(maybe_mappable, 'autoscale_None')):
self._mappables.append(maybe_mappable)

self._finalize_grid(meta_data['xlabel'], meta_data['ylabel'])

Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1791,7 +1791,7 @@ def test_wrong_num_of_dimensions(self):
self.darray.plot.line(row='row', hue='hue')


class TestScatterPlots(PlotTestCase):
class TestDatasetScatterPlots(PlotTestCase):
@pytest.fixture(autouse=True)
def setUp(self):
das = [DataArray(np.random.randn(3, 3, 4, 4),
Expand Down

0 comments on commit 6440365

Please sign in to comment.