diff --git a/CHANGELOG.md b/CHANGELOG.md index ec23cde5f6..8a18c0dbfd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ * Added `to_dict` method for InferenceData object ([1223](https://github.com/arviz-devs/arviz/pull/1223)) * Added `circ_var_names` argument to `plot_trace` allowing for circular traceplot (Matplotlib) ([1336](https://github.com/arviz-devs/arviz/pull/1336)) * Ridgeplot is hdi aware. By default displays truncated densities at the specified `hdi_prop` level ([1348](https://github.com/arviz-devs/arviz/pull/1348)) +* Added `plot_separation` ([1359](https://github.com/arviz-devs/arviz/pull/1359)) ### Maintenance and fixes diff --git a/arviz/plots/__init__.py b/arviz/plots/__init__.py index 0f2a2d9334..3912c4e98a 100644 --- a/arviz/plots/__init__.py +++ b/arviz/plots/__init__.py @@ -22,6 +22,7 @@ from .rankplot import plot_rank from .traceplot import plot_trace from .violinplot import plot_violin +from .separationplot import plot_separation __all__ = [ "plot_autocorr", @@ -48,4 +49,5 @@ "plot_rank", "plot_trace", "plot_violin", + "plot_separation", ] diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py new file mode 100644 index 0000000000..78c153b3ec --- /dev/null +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -0,0 +1,101 @@ +"""Bokeh separation plot.""" +import numpy as np + +from ...plot_utils import _scale_fig_size, vectorized_to_hex +from .. import show_layout +from . import backend_kwarg_defaults, create_axes_grid + + +def plot_separation( + y, + y_hat, + y_hat_line, + label_y_hat, + expected_events, + figsize, + textsize, + color, + legend, # pylint: disable=unused-argument + locs, + width, + ax, + plot_kwargs, + y_hat_line_kwargs, + exp_events_kwargs, + backend_kwargs, + show, +): + """Matplotlib separation plot.""" + if backend_kwargs is None: + backend_kwargs = {} + + if plot_kwargs is None: + plot_kwargs = {} + + # plot_kwargs.setdefault("color", "#2a2eec") + # if color: + plot_kwargs["color"] = vectorized_to_hex(color) + + backend_kwargs = { + **backend_kwarg_defaults(), + **backend_kwargs, + } + + if y_hat_line_kwargs is None: + y_hat_line_kwargs = {} + + y_hat_line_kwargs.setdefault("color", "black") + y_hat_line_kwargs.setdefault("line_width", 2) + + if exp_events_kwargs is None: + exp_events_kwargs = {} + + exp_events_kwargs.setdefault("color", "black") + exp_events_kwargs.setdefault("size", 15) + + figsize, *_ = _scale_fig_size(figsize, textsize) + + idx = np.argsort(y_hat) + + backend_kwargs["x_range"] = (0, 1) + backend_kwargs["y_range"] = (0, 1) + + if ax is None: + ax = create_axes_grid(1, figsize=figsize, squeeze=True, backend_kwargs=backend_kwargs) + + for i, loc in enumerate(locs): + positive = not y[idx][i] == 0 + alpha = 1 if positive else 0.3 + ax.vbar( + loc, + top=1, + width=width, + fill_alpha=alpha, + line_alpha=alpha, + **plot_kwargs, + ) + + if y_hat_line: + ax.line( + np.linspace(0, 1, len(y_hat)), + y_hat[idx], + legend_label=label_y_hat, + **y_hat_line_kwargs, + ) + + if expected_events: + expected_events = int(np.round(np.sum(y_hat))) + ax.triangle( + y_hat[idx][len(y_hat) - expected_events - 1], + 0, + legend_label="Expected events", + **exp_events_kwargs, + ) + + ax.axis.visible = False + ax.xgrid.grid_line_color = None + ax.ygrid.grid_line_color = None + + show_layout(ax, show) + + return ax diff --git a/arviz/plots/backends/matplotlib/separationplot.py b/arviz/plots/backends/matplotlib/separationplot.py new file mode 100644 index 0000000000..3b8bc5bb57 --- /dev/null +++ b/arviz/plots/backends/matplotlib/separationplot.py @@ -0,0 +1,96 @@ +"""Matplotlib separation plot.""" +import matplotlib.pyplot as plt +import numpy as np + +from ...plot_utils import _scale_fig_size +from . import backend_kwarg_defaults, backend_show, create_axes_grid + + +def plot_separation( + y, + y_hat, + y_hat_line, + label_y_hat, + expected_events, + figsize, + textsize, + color, + legend, + locs, + width, + ax, + plot_kwargs, + y_hat_line_kwargs, + exp_events_kwargs, + backend_kwargs, + show, +): + """Matplotlib separation plot.""" + if backend_kwargs is None: + backend_kwargs = {} + + if plot_kwargs is None: + plot_kwargs = {} + + # plot_kwargs.setdefault("color", "C0") + # if color: + plot_kwargs["color"] = color + + if y_hat_line_kwargs is None: + y_hat_line_kwargs = {} + + y_hat_line_kwargs.setdefault("color", "k") + + if exp_events_kwargs is None: + exp_events_kwargs = {} + + exp_events_kwargs.setdefault("color", "k") + exp_events_kwargs.setdefault("marker", "^") + exp_events_kwargs.setdefault("s", 100) + exp_events_kwargs.setdefault("zorder", 2) + + backend_kwargs = { + **backend_kwarg_defaults(), + **backend_kwargs, + } + + (figsize, *_) = _scale_fig_size(figsize, textsize, 1, 1) + backend_kwargs.setdefault("figsize", figsize) + backend_kwargs["squeeze"] = True + + if ax is None: + _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs) + + idx = np.argsort(y_hat) + + for i, loc in enumerate(locs): + positive = not y[idx][i] == 0 + alpha = 1 if positive else 0.3 + ax.bar(loc, 1, width=width, alpha=alpha, **plot_kwargs) + + if y_hat_line: + ax.plot(np.linspace(0, 1, len(y_hat)), y_hat[idx], label=label_y_hat, **y_hat_line_kwargs) + + if expected_events: + expected_events = int(np.round(np.sum(y_hat))) + ax.scatter( + y_hat[idx][len(y_hat) - expected_events - 1], + 0, + label="Expected events", + **exp_events_kwargs + ) + + if legend and expected_events or y_hat_line: + handles, labels = ax.get_legend_handles_labels() + labels_dict = dict(zip(labels, handles)) + ax.legend(labels_dict.values(), labels_dict.keys()) + + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + + if backend_show(show): + plt.show() + + return ax diff --git a/arviz/plots/separationplot.py b/arviz/plots/separationplot.py new file mode 100644 index 0000000000..c91871e18d --- /dev/null +++ b/arviz/plots/separationplot.py @@ -0,0 +1,160 @@ +"""Separation plot for discrete outcome models.""" +import warnings +import numpy as np +import xarray as xr + +from ..data import InferenceData +from ..rcparams import rcParams +from .plot_utils import get_plotting_function + + +def plot_separation( + idata=None, + y=None, + y_hat=None, + y_hat_line=False, + expected_events=False, + figsize=None, + textsize=None, + color="C0", + legend=True, + ax=None, + plot_kwargs=None, + y_hat_line_kwargs=None, + exp_events_kwargs=None, + backend=None, + backend_kwargs=None, + show=None, +): + """Separation plot for binary outcome models. + + Model predictions are sorted and plotted using a color code according to + the observed data. + + Parameters + ---------- + idata : InferenceData + InferenceData object. + y : array, DataArray or str + Observed data. If str, idata must be present and contain the observed data group + y_hat : array, DataArray or str + Posterior predictive samples for ``y``. It must have the same shape as y. If str or + None, idata must contain the posterior predictive group. + y_hat_line : bool, optional + Plot the sorted `y_hat` predictions. + expected_events : bool, optional + Plot the total number of expected events. + figsize : figure size tuple, optional + If None, size is (8 + numvars, 8 + numvars) + textsize: int, optional + Text size for labels. If None it will be autoscaled based on figsize. + color : str, optional + Color to assign to the postive class. The negative class will be plotted using the + same color and an `alpha=0.3` transparency. + legend : bool, optional + Show the legend of the figure. + ax: axes, optional + Matplotlib axes or bokeh figures. + plot_kwargs : dict, optional + Additional keywords passed to :meth:`mpl:matplotlib.axes.Axes.bar` or + :meth:`bokeh:bokeh.plotting.Figure.vbar` for separation plot. + y_hat_line_kwargs : dict, optional + Additional keywords passed to ax.plot for `y_hat` line. + exp_events_kwargs : dict, optional + Additional keywords passed to ax.scatter for expected_events marker. + backend: str, optional + Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". + backend_kwargs: bool, optional + These are kwargs specific to the backend being used. For additional documentation + check the plotting method of the backend. + show : bool, optional + Call backend show function. + + Returns + ------- + axes : matplotlib axes or bokeh figures + + References + ---------- + * Greenhill, B. et al (2011) see https://doi.org/10.1111/j.1540-5907.2011.00525.x + + Examples + -------- + Separation plot for a logistic regression model. + + .. plot:: + :context: close-figs + + >>> import arviz as az + >>> idata = az.load_arviz_data('classification10d') + >>> az.plot_separation(idata=idata, y='outcome', y_hat='outcome', figsize=(8, 1)) + + """ + label_y_hat = "y_hat" + if idata is not None and not isinstance(idata, InferenceData): + raise ValueError("idata must be of type InferenceData or None") + + if idata is None: + if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)): + raise ValueError( + "y and y_hat must be array or DataArray when idata is None " + "but they are of types {}".format([type(arg) for arg in (y, y_hat)]) + ) + else: + + if y_hat is None and isinstance(y, str): + label_y_hat = y + y_hat = y + elif y_hat is None: + raise ValueError("y_hat cannot be None if y is not a str") + + if isinstance(y, str): + y = idata.observed_data[y].values + elif not isinstance(y, (np.ndarray, xr.DataArray)): + raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y))) + + if isinstance(y_hat, str): + label_y_hat = y_hat + y_hat = idata.posterior_predictive[y_hat].mean(dim=("chain", "draw")).values + elif not isinstance(y_hat, (np.ndarray, xr.DataArray)): + raise ValueError( + "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat)) + ) + + if len(y) != len(y_hat): + warnings.warn( + "y and y_hat must be the same lenght", + UserWarning, + ) + + locs = np.linspace(0, 1, len(y_hat)) + width = np.diff(locs).mean() + + separation_kwargs = dict( + y=y, + y_hat=y_hat, + y_hat_line=y_hat_line, + label_y_hat=label_y_hat, + expected_events=expected_events, + figsize=figsize, + textsize=textsize, + color=color, + legend=legend, + locs=locs, + width=width, + ax=ax, + plot_kwargs=plot_kwargs, + y_hat_line_kwargs=y_hat_line_kwargs, + exp_events_kwargs=exp_events_kwargs, + backend_kwargs=backend_kwargs, + show=show, + ) + + if backend is None: + backend = rcParams["plot.backend"] + backend = backend.lower() + + plot = get_plotting_function("plot_separation", "separationplot", backend) + axes = plot(**separation_kwargs) + + return axes diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index 1318019090..1cada4a1a8 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -29,6 +29,7 @@ plot_posterior, plot_ppc, plot_rank, + plot_separation, plot_trace, plot_violin, ) @@ -132,6 +133,22 @@ def test_plot_density_bad_kwargs(models): plot_density(obj, hdi_prob=2, backend="bokeh", show=False) +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"y_hat_line": True}, + {"expected_events": True}, + {"y_hat_line_kwargs": {"linestyle": "dotted"}}, + {"exp_events_kwargs": {"marker": "o"}}, + ], +) +def test_plot_separation(kwargs): + idata = load_arviz_data("classification10d") + ax = plot_separation(idata=idata, y="outcome", backend="bokeh", show=False, **kwargs) + assert ax + + @pytest.mark.parametrize( "kwargs", [ diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index e8f13b66a4..6e9c1f4fcb 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -33,6 +33,7 @@ plot_posterior, plot_ppc, plot_rank, + plot_separation, plot_trace, plot_violin, ) @@ -159,6 +160,22 @@ def test_plot_density_bad_kwargs(models): plot_density(obj, hdi_prob=2) +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"y_hat_line": True}, + {"expected_events": True}, + {"y_hat_line_kwargs": {"linestyle": "dotted"}}, + {"exp_events_kwargs": {"marker": "o"}}, + ], +) +def test_plot_separation(kwargs): + idata = load_arviz_data("classification10d") + ax = plot_separation(idata=idata, y="outcome", **kwargs) + assert ax + + @pytest.mark.parametrize( "kwargs", [ diff --git a/doc/api.rst b/doc/api.rst index 58c02bd523..d4d093bed2 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -34,6 +34,7 @@ Plots plot_posterior plot_ppc plot_rank + plot_separation plot_trace plot_violin diff --git a/examples/bokeh/bokeh_plot_separation.py b/examples/bokeh/bokeh_plot_separation.py new file mode 100644 index 0000000000..82abe87536 --- /dev/null +++ b/examples/bokeh/bokeh_plot_separation.py @@ -0,0 +1,11 @@ +""" +Separationplot +========== + +_thumb: .2, .8 +""" +import arviz as az + +idata = az.load_arviz_data("classification10d") + +ax = az.plot_separation(idata=idata, y="outcome", y_hat="outcome", figsize=(8, 1), backend="bokeh") diff --git a/examples/matplotlib/mpl_plot_separation.py b/examples/matplotlib/mpl_plot_separation.py new file mode 100644 index 0000000000..8d0e8370a8 --- /dev/null +++ b/examples/matplotlib/mpl_plot_separation.py @@ -0,0 +1,17 @@ +""" +Separationplot +========== + +_thumb: .2, .8 +""" +import matplotlib.pyplot as plt + +import arviz as az + +az.style.use("arviz-darkgrid") + +idata = az.load_arviz_data("classification10d") + +az.plot_separation(idata=idata, y="outcome", y_hat="outcome", figsize=(8, 1)) + +plt.show()