Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separation plot #1359

Merged
merged 14 commits into from
Sep 2, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Added `to_dict` method for InferenceData object ([1223](https://github.com/arviz-devs/arviz/pull/1223))
* Added `circ_var_names` argument to `plot_trace` allowing for circular traceplot (Matplotlib) ([1336](https://github.com/arviz-devs/arviz/pull/1336))
* Ridgeplot is hdi aware. By default displays truncated densities at the specified `hdi_prop` level ([1348](https://github.com/arviz-devs/arviz/pull/1348))
* Added `plot_separation` ([1359](https://github.com/arviz-devs/arviz/pull/1359))


### Maintenance and fixes
Expand Down
2 changes: 2 additions & 0 deletions arviz/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .rankplot import plot_rank
from .traceplot import plot_trace
from .violinplot import plot_violin
from .separationplot import plot_separation

__all__ = [
"plot_autocorr",
Expand All @@ -48,4 +49,5 @@
"plot_rank",
"plot_trace",
"plot_violin",
"plot_separation",
]
101 changes: 101 additions & 0 deletions arviz/plots/backends/bokeh/separationplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Bokeh separation plot."""
import numpy as np

from ...plot_utils import _scale_fig_size, vectorized_to_hex
from .. import show_layout
from . import backend_kwarg_defaults, create_axes_grid


def plot_separation(
    y,
    y_hat,
    y_hat_line,
    label_y_hat,
    expected_events,
    figsize,
    textsize,
    color,
    legend,  # pylint: disable=unused-argument
    locs,
    width,
    ax,
    plot_kwargs,
    y_hat_line_kwargs,
    exp_events_kwargs,
    backend_kwargs,
    show,
):
    """Bokeh separation plot.

    Draws one vertical bar per observation, ordered by increasing ``y_hat``.
    Bars whose observed value is non-zero are fully opaque; bars whose observed
    value equals 0 are drawn with an alpha of 0.3.

    Parameters
    ----------
    y : array-like
        Observed outcomes, indexable with the integer array returned by
        ``np.argsort(y_hat)``.
    y_hat : array-like
        Predicted values, one per observation.
    y_hat_line : bool
        If truthy, draw the sorted ``y_hat`` values as a line.
    label_y_hat : str
        Legend label used for the ``y_hat`` line.
    expected_events : bool
        If truthy, draw a triangle marker for the expected number of events
        (``round(sum(y_hat))``).
    figsize, textsize
        Forwarded to ``_scale_fig_size``.
    color
        Bar color; converted with ``vectorized_to_hex`` and always written to
        ``plot_kwargs["color"]``.
    legend
        Unused by this backend.
    locs : sequence
        x positions of the bars, one per observation.
    width : float
        Width of every bar.
    ax : bokeh figure or None
        Figure to draw on. A new one is created when None.
    plot_kwargs, y_hat_line_kwargs, exp_events_kwargs : dict or None
        Extra keywords for ``ax.vbar``, ``ax.line`` and ``ax.triangle``.
        The dictionaries passed by the caller are not modified.
    backend_kwargs : dict or None
        Extra keywords merged over ``backend_kwarg_defaults()`` and passed to
        ``create_axes_grid``.
    show
        Forwarded to ``show_layout``.

    Returns
    -------
    ax
        The bokeh figure that was drawn on.
    """
    # Work on copies so that defaults set below never leak into the caller's dicts.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    y_hat_line_kwargs = {} if y_hat_line_kwargs is None else dict(y_hat_line_kwargs)
    exp_events_kwargs = {} if exp_events_kwargs is None else dict(exp_events_kwargs)

    # ``color`` takes precedence over any "color" entry supplied in plot_kwargs.
    plot_kwargs["color"] = vectorized_to_hex(color)

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **({} if backend_kwargs is None else backend_kwargs),
    }
    backend_kwargs["x_range"] = (0, 1)
    backend_kwargs["y_range"] = (0, 1)

    y_hat_line_kwargs.setdefault("color", "black")
    y_hat_line_kwargs.setdefault("line_width", 2)

    exp_events_kwargs.setdefault("color", "black")
    exp_events_kwargs.setdefault("size", 15)

    figsize, *_ = _scale_fig_size(figsize, textsize)

    # Sort once; indexing inside the loop would re-sort the data for every bar.
    idx = np.argsort(y_hat)
    y_sorted = y[idx]
    y_hat_sorted = y_hat[idx]

    if ax is None:
        ax = create_axes_grid(1, figsize=figsize, squeeze=True, backend_kwargs=backend_kwargs)

    for loc, observed in zip(locs, y_sorted):
        positive = not observed == 0
        alpha = 1 if positive else 0.3
        ax.vbar(
            loc,
            top=1,
            width=width,
            fill_alpha=alpha,
            line_alpha=alpha,
            **plot_kwargs,
        )

    if y_hat_line:
        ax.line(
            np.linspace(0, 1, len(y_hat)),
            y_hat_sorted,
            legend_label=label_y_hat,
            **y_hat_line_kwargs,
        )

    if expected_events:
        n_expected = int(np.round(np.sum(y_hat)))
        # Clamp at 0: when every observation is an expected event the raw index
        # is -1, which would silently wrap around to the last (largest) element.
        marker_idx = max(len(y_hat) - n_expected - 1, 0)
        # NOTE(review): the marker's x coordinate is a y_hat value rather than a
        # bar location from ``locs`` -- confirm this is intended.
        ax.triangle(
            y_hat_sorted[marker_idx],
            0,
            legend_label="Expected events",
            **exp_events_kwargs,
        )

    ax.axis.visible = False
    ax.xgrid.grid_line_color = None
    ax.ygrid.grid_line_color = None

    show_layout(ax, show)

    return ax
96 changes: 96 additions & 0 deletions arviz/plots/backends/matplotlib/separationplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Matplotlib separation plot."""
import matplotlib.pyplot as plt
import numpy as np

from ...plot_utils import _scale_fig_size
from . import backend_kwarg_defaults, backend_show, create_axes_grid


def plot_separation(
    y,
    y_hat,
    y_hat_line,
    label_y_hat,
    expected_events,
    figsize,
    textsize,
    color,
    legend,
    locs,
    width,
    ax,
    plot_kwargs,
    y_hat_line_kwargs,
    exp_events_kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib separation plot.

    Draws one bar per observation, ordered by increasing ``y_hat``. Bars whose
    observed value is non-zero are fully opaque; bars whose observed value
    equals 0 are drawn with an alpha of 0.3.

    Parameters
    ----------
    y : array-like
        Observed outcomes, indexable with the integer array returned by
        ``np.argsort(y_hat)``.
    y_hat : array-like
        Predicted values, one per observation.
    y_hat_line : bool
        If truthy, draw the sorted ``y_hat`` values as a line.
    label_y_hat : str
        Legend label used for the ``y_hat`` line.
    expected_events : bool
        If truthy, draw a marker for the expected number of events
        (``round(sum(y_hat))``).
    figsize, textsize
        Forwarded to ``_scale_fig_size``.
    color
        Bar color; always written to ``plot_kwargs["color"]``.
    legend : bool
        If truthy, and the line or the expected-events marker is drawn,
        add a legend to the axes.
    locs : sequence
        x positions of the bars, one per observation.
    width : float
        Width of every bar.
    ax : matplotlib axes or None
        Axes to draw on. A new one is created when None.
    plot_kwargs, y_hat_line_kwargs, exp_events_kwargs : dict or None
        Extra keywords for ``ax.bar``, ``ax.plot`` and ``ax.scatter``.
        The dictionaries passed by the caller are not modified.
    backend_kwargs : dict or None
        Extra keywords merged over ``backend_kwarg_defaults()`` and passed to
        ``create_axes_grid``.
    show
        Forwarded to ``backend_show``; ``plt.show()`` is called when it
        returns a truthy value.

    Returns
    -------
    ax
        The matplotlib axes that was drawn on.
    """
    # Work on copies so that defaults set below never leak into the caller's dicts.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    y_hat_line_kwargs = {} if y_hat_line_kwargs is None else dict(y_hat_line_kwargs)
    exp_events_kwargs = {} if exp_events_kwargs is None else dict(exp_events_kwargs)

    # ``color`` takes precedence over any "color" entry supplied in plot_kwargs.
    plot_kwargs["color"] = color

    y_hat_line_kwargs.setdefault("color", "k")

    exp_events_kwargs.setdefault("color", "k")
    exp_events_kwargs.setdefault("marker", "^")
    exp_events_kwargs.setdefault("s", 100)
    exp_events_kwargs.setdefault("zorder", 2)

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **({} if backend_kwargs is None else backend_kwargs),
    }

    (figsize, *_) = _scale_fig_size(figsize, textsize, 1, 1)
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs["squeeze"] = True

    if ax is None:
        _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs)

    # Sort once; indexing inside the loop would re-sort the data for every bar.
    idx = np.argsort(y_hat)
    y_sorted = y[idx]
    y_hat_sorted = y_hat[idx]

    for loc, observed in zip(locs, y_sorted):
        positive = not observed == 0
        alpha = 1 if positive else 0.3
        ax.bar(loc, 1, width=width, alpha=alpha, **plot_kwargs)

    if y_hat_line:
        ax.plot(np.linspace(0, 1, len(y_hat)), y_hat_sorted, label=label_y_hat, **y_hat_line_kwargs)

    if expected_events:
        # Keep the count in its own name: overwriting the ``expected_events`` flag
        # would make the legend check below fail when the rounded count is 0.
        n_expected = int(np.round(np.sum(y_hat)))
        # Clamp at 0: when every observation is an expected event the raw index
        # is -1, which would silently wrap around to the last (largest) element.
        marker_idx = max(len(y_hat) - n_expected - 1, 0)
        # NOTE(review): the marker's x coordinate is a y_hat value rather than a
        # bar location from ``locs`` -- confirm this is intended.
        ax.scatter(
            y_hat_sorted[marker_idx],
            0,
            label="Expected events",
            **exp_events_kwargs,
        )

    # Parentheses matter: ``and`` binds tighter than ``or``, so without them the
    # legend would be drawn whenever y_hat_line is set, even with legend=False.
    if legend and (expected_events or y_hat_line):
        handles, labels = ax.get_legend_handles_labels()
        # Going through a dict collapses repeated labels into a single entry.
        labels_dict = dict(zip(labels, handles))
        ax.legend(labels_dict.values(), labels_dict.keys())

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)

    if backend_show(show):
        plt.show()

    return ax
160 changes: 160 additions & 0 deletions arviz/plots/separationplot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Separation plot for discrete outcome models."""
import warnings
import numpy as np
import xarray as xr

from ..data import InferenceData
from ..rcparams import rcParams
from .plot_utils import get_plotting_function


def plot_separation(
    idata=None,
    y=None,
    y_hat=None,
    y_hat_line=False,
    expected_events=False,
    figsize=None,
    textsize=None,
    color="C0",
    legend=True,
    ax=None,
    plot_kwargs=None,
    y_hat_line_kwargs=None,
    exp_events_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    """Separation plot for binary outcome models.

    Model predictions are sorted and plotted using a color code according to
    the observed data.

    Parameters
    ----------
    idata : InferenceData
        InferenceData object.
    y : array, DataArray or str
        Observed data. If str, idata must be present and contain the observed data group
    y_hat : array, DataArray or str
        Posterior predictive samples for ``y``. It must have the same shape as y. If str or
        None, idata must contain the posterior predictive group. When given as a str, the
        posterior predictive variable is averaged over the ``chain`` and ``draw`` dimensions.
    y_hat_line : bool, optional
        Plot the sorted `y_hat` predictions.
    expected_events : bool, optional
        Plot the total number of expected events.
    figsize : figure size tuple, optional
        If None, a default size is chosen by the plotting backend.
    textsize: int, optional
        Text size for labels. If None it will be autoscaled based on figsize.
    color : str, optional
        Color to assign to the positive class. The negative class will be plotted using the
        same color and an `alpha=0.3` transparency.
    legend : bool, optional
        Show the legend of the figure.
    ax: axes, optional
        Matplotlib axes or bokeh figures.
    plot_kwargs : dict, optional
        Additional keywords passed to :meth:`mpl:matplotlib.axes.Axes.bar` or
        :meth:`bokeh:bokeh.plotting.Figure.vbar` for separation plot.
    y_hat_line_kwargs : dict, optional
        Additional keywords passed to ax.plot (matplotlib) or ax.line (bokeh) for the
        `y_hat` line.
    exp_events_kwargs : dict, optional
        Additional keywords passed to ax.scatter (matplotlib) or ax.triangle (bokeh) for
        the expected_events marker.
    backend: str, optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used. For additional documentation
        check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If ``idata`` is neither None nor an InferenceData, or if ``y`` / ``y_hat`` are not
        of a supported type for the given ``idata``.

    Warns
    -----
    UserWarning
        If ``y`` and ``y_hat`` do not have the same length.

    References
    ----------
    * Greenhill, B. et al (2011) see https://doi.org/10.1111/j.1540-5907.2011.00525.x

    Examples
    --------
    Separation plot for a logistic regression model.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> idata = az.load_arviz_data('classification10d')
        >>> az.plot_separation(idata=idata, y='outcome', y_hat='outcome', figsize=(8, 1))

    """
    label_y_hat = "y_hat"
    if idata is not None and not isinstance(idata, InferenceData):
        raise ValueError("idata must be of type InferenceData or None")

    if idata is None:
        # Without an InferenceData there is nothing to look names up in,
        # so both arguments must already be array-like.
        if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)):
            raise ValueError(
                "y and y_hat must be array or DataArray when idata is None "
                "but they are of types {}".format([type(arg) for arg in (y, y_hat)])
            )
    else:
        # A single variable name can serve for both the observed data and the
        # posterior predictive groups.
        if y_hat is None and isinstance(y, str):
            label_y_hat = y
            y_hat = y
        elif y_hat is None:
            raise ValueError("y_hat cannot be None if y is not a str")

        if isinstance(y, str):
            y = idata.observed_data[y].values
        elif not isinstance(y, (np.ndarray, xr.DataArray)):
            raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y)))

        if isinstance(y_hat, str):
            label_y_hat = y_hat
            y_hat = idata.posterior_predictive[y_hat].mean(dim=("chain", "draw")).values
        elif not isinstance(y_hat, (np.ndarray, xr.DataArray)):
            raise ValueError(
                "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat))
            )

    if len(y) != len(y_hat):
        warnings.warn(
            "y and y_hat must be the same length",
            UserWarning,
        )

    locs = np.linspace(0, 1, len(y_hat))
    # With fewer than two locations there are no gaps to average; np.diff would
    # return an empty array and its mean would be NaN (with a RuntimeWarning).
    width = np.diff(locs).mean() if len(locs) > 1 else 1.0

    separation_kwargs = dict(
        y=y,
        y_hat=y_hat,
        y_hat_line=y_hat_line,
        label_y_hat=label_y_hat,
        expected_events=expected_events,
        figsize=figsize,
        textsize=textsize,
        color=color,
        legend=legend,
        locs=locs,
        width=width,
        ax=ax,
        plot_kwargs=plot_kwargs,
        y_hat_line_kwargs=y_hat_line_kwargs,
        exp_events_kwargs=exp_events_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_separation", "separationplot", backend)
    axes = plot(**separation_kwargs)

    return axes
Loading