Skip to content

Commit

Permalink
Reduce usage of global matplotlib state in axisgrid objects (#2388)
Browse files Browse the repository at this point in the history
* Reduce usage of global matplotlib state in axisgrid objects

* Improve test coverage and legac compat in JointGrid

* Update release notes

* Always cast func.__module__ to str for safety
  • Loading branch information
mwaskom authored Dec 19, 2020
1 parent 9b57ca1 commit ead5a52
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 25 deletions.
4 changes: 4 additions & 0 deletions doc/releases/v0.11.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
v0.11.1 (Unreleased)
--------------------

- [Enhancement| Reduced the use of matplotlib global state in the :ref:`multi-grid classes <multi-plot-grids>` (pr:`2388`).

- |Fix| Restored support for using tuples or numeric keys to reference fields in a long-form `data` object (:pr:`2386`).

- |Fix| Fixed a bug in :func:`lineplot` where NAs were propagating into the confidence interval, sometimes erasing it from the plot (:pr:`2273`).
Expand All @@ -24,6 +26,8 @@ v0.11.1 (Unreleased)

- |Fix| Fixed a bug in :func:`clustermap` where `annot=False` was ignored (:pr:`2323`).

- |Fix| Fixed a bug in :func:`clustermap` where row/col color annotations could not have a categorical dtype (:pr:`2389`).

- |Fix| Fixed a bug in :func:`boxenplot` where the `linewidth` parameter was ignored (:func:`2287`).

- |Fix| Raise a more informative error in :class:`PairGrid`/:func:`pairplot` when no variables can be found to define the rows/columns of the grid (:func:`2382`).
Expand Down
83 changes: 59 additions & 24 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,10 +631,8 @@ def map(self, func, *args, **kwargs):
# If color was a keyword argument, grab it here
kw_color = kwargs.pop("color", None)

if hasattr(func, "__module__"):
func_module = str(func.__module__)
else:
func_module = ""
# How we use the function depends on where it comes from
func_module = str(getattr(func, "__module__", ""))

# Check for categorical plots without order information
if func_module == "seaborn.categorical":
Expand All @@ -657,7 +655,8 @@ def map(self, func, *args, **kwargs):
continue

# Get the current axis
ax = self.facet_axis(row_i, col_j)
modify_state = not func_module.startswith("seaborn")
ax = self.facet_axis(row_i, col_j, modify_state)

# Decide what color to plot with
kwargs["color"] = self._facet_color(hue_k, kw_color)
Expand Down Expand Up @@ -728,7 +727,8 @@ def map_dataframe(self, func, *args, **kwargs):
continue

# Get the current axis
ax = self.facet_axis(row_i, col_j)
modify_state = not str(func.__module__).startswith("seaborn")
ax = self.facet_axis(row_i, col_j, modify_state)

# Decide what color to plot with
kwargs["color"] = self._facet_color(hue_k, kw_color)
Expand Down Expand Up @@ -771,6 +771,7 @@ def _facet_plot(self, func, ax, plot_args, plot_kwargs):
for key, val in zip(semantics, plot_args):
plot_kwargs[key] = val
plot_args = []
plot_kwargs["ax"] = ax
func(*plot_args, **plot_kwargs)

# Sort out the supporting information
Expand All @@ -783,7 +784,7 @@ def _finalize_grid(self, axlabels):
self.set_titles()
self.tight_layout()

def facet_axis(self, row_i, col_j):
def facet_axis(self, row_i, col_j, modify_state=True):
"""Make the axis identified by these indices active and return it."""

# Calculate the actual indices of the axes to plot on
Expand All @@ -793,7 +794,8 @@ def facet_axis(self, row_i, col_j):
ax = self.axes[row_i, col_j]

# Get a reference to the axes object we want, and make it active
plt.sca(ax)
if modify_state:
plt.sca(ax)
return ax

def despine(self, **kwargs):
Expand Down Expand Up @@ -1374,8 +1376,11 @@ def map_diag(self, func, **kwargs):
# Loop over diagonal variables and axes, making one plot in each
for var, ax in zip(self.diag_vars, self.diag_axes):

plt.sca(ax)
plot_kwargs = kwargs.copy()
if str(func.__module__).startswith("seaborn"):
plot_kwargs["ax"] = ax
else:
plt.sca(ax)

vector = self.data[var]
if self._hue_var is not None:
Expand Down Expand Up @@ -1408,7 +1413,11 @@ def _map_diag_iter_hue(self, func, **kwargs):
for var, ax in zip(self.diag_vars, self.diag_axes):
hue_grouped = self.data[var].groupby(self.hue_vals)

plt.sca(ax)
plot_kwargs = kwargs.copy()
if str(func.__module__).startswith("seaborn"):
plot_kwargs["ax"] = ax
else:
plt.sca(ax)

for k, label_k in enumerate(self._hue_order):

Expand All @@ -1427,9 +1436,9 @@ def _map_diag_iter_hue(self, func, **kwargs):
data_k = utils.remove_na(data_k)

if str(func.__module__).startswith("seaborn"):
func(x=data_k, label=label_k, color=color, **kwargs)
func(x=data_k, label=label_k, color=color, **plot_kwargs)
else:
func(data_k, label=label_k, color=color, **kwargs)
func(data_k, label=label_k, color=color, **plot_kwargs)

self._clean_axis(ax)

Expand Down Expand Up @@ -1465,8 +1474,11 @@ def _plot_bivariate(self, x_var, y_var, ax, func, **kwargs):
self._plot_bivariate_iter_hue(x_var, y_var, ax, func, **kwargs)
return

plt.sca(ax)
kwargs = kwargs.copy()
if str(func.__module__).startswith("seaborn"):
kwargs["ax"] = ax
else:
plt.sca(ax)

if x_var == y_var:
axes_vars = [x_var]
Expand Down Expand Up @@ -1497,7 +1509,12 @@ def _plot_bivariate(self, x_var, y_var, ax, func, **kwargs):

def _plot_bivariate_iter_hue(self, x_var, y_var, ax, func, **kwargs):
"""Draw a bivariate plot while iterating over hue subsets."""
plt.sca(ax)
kwargs = kwargs.copy()
if str(func.__module__).startswith("seaborn"):
kwargs["ax"] = ax
else:
plt.sca(ax)

if x_var == y_var:
axes_vars = [x_var]
else:
Expand Down Expand Up @@ -1704,8 +1721,11 @@ def plot_joint(self, func, **kwargs):
Returns ``self`` for easy method chaining.
"""
plt.sca(self.ax_joint)
kwargs = kwargs.copy()
if str(func.__module__).startswith("seaborn"):
kwargs["ax"] = self.ax_joint
else:
plt.sca(self.ax_joint)
if self.hue is not None:
kwargs["hue"] = self.hue
self._inject_kwargs(func, kwargs, self._hue_params)
Expand Down Expand Up @@ -1738,25 +1758,40 @@ def plot_marginals(self, func, **kwargs):
Returns ``self`` for easy method chaining.
"""
seaborn_func = (
str(func.__module__).startswith("seaborn")
# deprecated distplot has a legacy API, special case it
and not func.__name__ == "distplot"
)
func_params = signature(func).parameters
kwargs = kwargs.copy()
if self.hue is not None:
kwargs["hue"] = self.hue
self._inject_kwargs(func, kwargs, self._hue_params)

if "legend" in signature(func).parameters:
if "legend" in func_params:
kwargs.setdefault("legend", False)

plt.sca(self.ax_marg_x)
if str(func.__module__).startswith("seaborn"):
func(x=self.x, **kwargs)
if "orientation" in func_params:
# e.g. plt.hist
orient_kw_x = {"orientation": "vertical"}
orient_kw_y = {"orientation": "horizontal"}
elif "vertical" in func_params:
# e.g. sns.distplot (also how did this get backwards?)
orient_kw_x = {"vertical": False}
orient_kw_y = {"vertical": True}

if seaborn_func:
func(x=self.x, ax=self.ax_marg_x, **kwargs)
else:
func(self.x, vertical=False, **kwargs)
plt.sca(self.ax_marg_x)
func(self.x, **orient_kw_x, **kwargs)

plt.sca(self.ax_marg_y)
if str(func.__module__).startswith("seaborn"):
func(y=self.y, **kwargs)
if seaborn_func:
func(y=self.y, ax=self.ax_marg_y, **kwargs)
else:
func(self.y, vertical=True, **kwargs)
plt.sca(self.ax_marg_y)
func(self.y, **orient_kw_y, **kwargs)

self.ax_marg_x.yaxis.get_label().set_visible(False)
self.ax_marg_y.xaxis.get_label().set_visible(False)
Expand Down
37 changes: 36 additions & 1 deletion seaborn/tests/test_axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .. import rcmod
from ..palettes import color_palette
from ..relational import scatterplot
from ..distributions import histplot, kdeplot
from ..distributions import histplot, kdeplot, distplot
from ..categorical import pointplot
from .. import axisgrid as ag
from .._testing import (
Expand Down Expand Up @@ -415,6 +415,8 @@ def test_map_dataframe(self):

def plot(x, y, data=None, **kws):
plt.plot(data[x], data[y], **kws)
# Modify __module__ so this doesn't look like a seaborn function
plot.__module__ = "test"

g.map_dataframe(plot, "x", "y", linestyle="--")

Expand Down Expand Up @@ -1003,6 +1005,20 @@ def test_diag_sharey(self):
for ax in g.diag_axes[1:]:
assert ax.get_ylim() == g.diag_axes[0].get_ylim()

def test_map_diag_matplotlib(self):

bins = 10
g = ag.PairGrid(self.df)
g.map_diag(plt.hist, bins=bins)
for ax in g.diag_axes:
assert len(ax.patches) == bins

levels = len(self.df["a"].unique())
g = ag.PairGrid(self.df, hue="a")
g.map_diag(plt.hist, bins=bins)
for ax in g.diag_axes:
assert len(ax.patches) == (bins * levels)

def test_palette(self):

rcmod.set()
Expand Down Expand Up @@ -1458,6 +1474,25 @@ def test_univariate_plot(self):
y2, _ = g.ax_marg_y.lines[0].get_xydata().T
npt.assert_array_equal(y1, y2)

def test_univariate_plot_distplot(self):

bins = 10
g = ag.JointGrid(x="x", y="x", data=self.data)
with pytest.warns(FutureWarning):
g.plot_marginals(distplot, bins=bins)
assert len(g.ax_marg_x.patches) == bins
assert len(g.ax_marg_y.patches) == bins
for x, y in zip(g.ax_marg_x.patches, g.ax_marg_y.patches):
assert x.get_height() == y.get_width()

def test_univariate_plot_matplotlib(self):

bins = 10
g = ag.JointGrid(x="x", y="x", data=self.data)
g.plot_marginals(plt.hist, bins=bins)
assert len(g.ax_marg_x.patches) == bins
assert len(g.ax_marg_y.patches) == bins

def test_plot(self):

g = ag.JointGrid(x="x", y="x", data=self.data)
Expand Down

0 comments on commit ead5a52

Please sign in to comment.