-
-
Notifications
You must be signed in to change notification settings - Fork 407
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Separation plot #1359
Merged
aloctavodia
merged 14 commits into
arviz-devs:master
from
agustinaarroyuelo:separationplot
Sep 2, 2020
Merged
Separation plot #1359
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
b8b2576
separation plot
agustinaarroyuelo 8470952
add squeeze argument
agustinaarroyuelo c47cb43
remove cmap argument
agustinaarroyuelo 63fb226
add example and tests
agustinaarroyuelo 9825389
add gallery examples
agustinaarroyuelo c4cc963
update changelog
agustinaarroyuelo 64b1158
run black
agustinaarroyuelo 0d6ac2d
fix legend
agustinaarroyuelo 0d3194f
use labeled dimensions
agustinaarroyuelo 6e652af
update docstring
agustinaarroyuelo ca54234
update doc/api.rst
agustinaarroyuelo e5a3c33
change expected events plot
agustinaarroyuelo 5b4f80f
fix label
agustinaarroyuelo cb91fa1
move locs and width to general function
agustinaarroyuelo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
"""Bokeh separation plot.""" | ||
import numpy as np | ||
|
||
from ...plot_utils import _scale_fig_size, vectorized_to_hex | ||
from .. import show_layout | ||
from . import backend_kwarg_defaults, create_axes_grid | ||
|
||
|
||
def plot_separation( | ||
y, | ||
y_hat, | ||
y_hat_line, | ||
label_y_hat, | ||
expected_events, | ||
figsize, | ||
textsize, | ||
color, | ||
legend, # pylint: disable=unused-argument | ||
locs, | ||
width, | ||
ax, | ||
plot_kwargs, | ||
y_hat_line_kwargs, | ||
exp_events_kwargs, | ||
backend_kwargs, | ||
show, | ||
): | ||
"""Matplotlib separation plot.""" | ||
if backend_kwargs is None: | ||
backend_kwargs = {} | ||
|
||
if plot_kwargs is None: | ||
plot_kwargs = {} | ||
|
||
# plot_kwargs.setdefault("color", "#2a2eec") | ||
# if color: | ||
plot_kwargs["color"] = vectorized_to_hex(color) | ||
|
||
backend_kwargs = { | ||
**backend_kwarg_defaults(), | ||
**backend_kwargs, | ||
} | ||
|
||
if y_hat_line_kwargs is None: | ||
y_hat_line_kwargs = {} | ||
|
||
y_hat_line_kwargs.setdefault("color", "black") | ||
y_hat_line_kwargs.setdefault("line_width", 2) | ||
|
||
if exp_events_kwargs is None: | ||
exp_events_kwargs = {} | ||
|
||
exp_events_kwargs.setdefault("color", "black") | ||
exp_events_kwargs.setdefault("size", 15) | ||
|
||
figsize, *_ = _scale_fig_size(figsize, textsize) | ||
|
||
idx = np.argsort(y_hat) | ||
|
||
backend_kwargs["x_range"] = (0, 1) | ||
backend_kwargs["y_range"] = (0, 1) | ||
|
||
if ax is None: | ||
ax = create_axes_grid(1, figsize=figsize, squeeze=True, backend_kwargs=backend_kwargs) | ||
|
||
for i, loc in enumerate(locs): | ||
positive = not y[idx][i] == 0 | ||
alpha = 1 if positive else 0.3 | ||
ax.vbar( | ||
loc, | ||
top=1, | ||
width=width, | ||
fill_alpha=alpha, | ||
line_alpha=alpha, | ||
**plot_kwargs, | ||
) | ||
|
||
if y_hat_line: | ||
ax.line( | ||
np.linspace(0, 1, len(y_hat)), | ||
y_hat[idx], | ||
legend_label=label_y_hat, | ||
**y_hat_line_kwargs, | ||
) | ||
|
||
if expected_events: | ||
expected_events = int(np.round(np.sum(y_hat))) | ||
ax.triangle( | ||
y_hat[idx][len(y_hat) - expected_events - 1], | ||
0, | ||
legend_label="Expected events", | ||
**exp_events_kwargs, | ||
) | ||
|
||
ax.axis.visible = False | ||
ax.xgrid.grid_line_color = None | ||
ax.ygrid.grid_line_color = None | ||
|
||
show_layout(ax, show) | ||
|
||
return ax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
"""Matplotlib separation plot.""" | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
|
||
from ...plot_utils import _scale_fig_size | ||
from . import backend_kwarg_defaults, backend_show, create_axes_grid | ||
|
||
|
||
def plot_separation( | ||
y, | ||
y_hat, | ||
y_hat_line, | ||
label_y_hat, | ||
expected_events, | ||
figsize, | ||
textsize, | ||
color, | ||
legend, | ||
locs, | ||
width, | ||
ax, | ||
plot_kwargs, | ||
y_hat_line_kwargs, | ||
exp_events_kwargs, | ||
backend_kwargs, | ||
show, | ||
): | ||
"""Matplotlib separation plot.""" | ||
if backend_kwargs is None: | ||
backend_kwargs = {} | ||
|
||
if plot_kwargs is None: | ||
plot_kwargs = {} | ||
|
||
# plot_kwargs.setdefault("color", "C0") | ||
# if color: | ||
plot_kwargs["color"] = color | ||
|
||
if y_hat_line_kwargs is None: | ||
y_hat_line_kwargs = {} | ||
|
||
y_hat_line_kwargs.setdefault("color", "k") | ||
|
||
if exp_events_kwargs is None: | ||
exp_events_kwargs = {} | ||
|
||
exp_events_kwargs.setdefault("color", "k") | ||
exp_events_kwargs.setdefault("marker", "^") | ||
exp_events_kwargs.setdefault("s", 100) | ||
exp_events_kwargs.setdefault("zorder", 2) | ||
|
||
backend_kwargs = { | ||
**backend_kwarg_defaults(), | ||
**backend_kwargs, | ||
} | ||
|
||
(figsize, *_) = _scale_fig_size(figsize, textsize, 1, 1) | ||
backend_kwargs.setdefault("figsize", figsize) | ||
backend_kwargs["squeeze"] = True | ||
|
||
if ax is None: | ||
_, ax = create_axes_grid(1, backend_kwargs=backend_kwargs) | ||
|
||
idx = np.argsort(y_hat) | ||
|
||
for i, loc in enumerate(locs): | ||
positive = not y[idx][i] == 0 | ||
alpha = 1 if positive else 0.3 | ||
ax.bar(loc, 1, width=width, alpha=alpha, **plot_kwargs) | ||
|
||
if y_hat_line: | ||
ax.plot(np.linspace(0, 1, len(y_hat)), y_hat[idx], label=label_y_hat, **y_hat_line_kwargs) | ||
|
||
if expected_events: | ||
expected_events = int(np.round(np.sum(y_hat))) | ||
ax.scatter( | ||
y_hat[idx][len(y_hat) - expected_events - 1], | ||
0, | ||
label="Expected events", | ||
**exp_events_kwargs | ||
) | ||
|
||
if legend and expected_events or y_hat_line: | ||
handles, labels = ax.get_legend_handles_labels() | ||
labels_dict = dict(zip(labels, handles)) | ||
ax.legend(labels_dict.values(), labels_dict.keys()) | ||
|
||
ax.set_xticks([]) | ||
ax.set_yticks([]) | ||
ax.set_xlim(0, 1) | ||
ax.set_ylim(0, 1) | ||
|
||
if backend_show(show): | ||
plt.show() | ||
|
||
return ax |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
"""Separation plot for discrete outcome models.""" | ||
import warnings | ||
import numpy as np | ||
import xarray as xr | ||
|
||
from ..data import InferenceData | ||
from ..rcparams import rcParams | ||
from .plot_utils import get_plotting_function | ||
|
||
|
||
def plot_separation( | ||
idata=None, | ||
y=None, | ||
y_hat=None, | ||
y_hat_line=False, | ||
expected_events=False, | ||
figsize=None, | ||
textsize=None, | ||
color="C0", | ||
legend=True, | ||
ax=None, | ||
plot_kwargs=None, | ||
y_hat_line_kwargs=None, | ||
exp_events_kwargs=None, | ||
backend=None, | ||
backend_kwargs=None, | ||
show=None, | ||
): | ||
"""Separation plot for binary outcome models. | ||
|
||
Model predictions are sorted and plotted using a color code according to | ||
the observed data. | ||
|
||
Parameters | ||
---------- | ||
idata : InferenceData | ||
InferenceData object. | ||
y : array, DataArray or str | ||
Observed data. If str, idata must be present and contain the observed data group | ||
y_hat : array, DataArray or str | ||
Posterior predictive samples for ``y``. It must have the same shape as y. If str or | ||
None, idata must contain the posterior predictive group. | ||
y_hat_line : bool, optional | ||
Plot the sorted `y_hat` predictions. | ||
expected_events : bool, optional | ||
Plot the total number of expected events. | ||
figsize : figure size tuple, optional | ||
If None, size is (8 + numvars, 8 + numvars) | ||
textsize: int, optional | ||
Text size for labels. If None it will be autoscaled based on figsize. | ||
color : str, optional | ||
Color to assign to the postive class. The negative class will be plotted using the | ||
same color and an `alpha=0.3` transparency. | ||
legend : bool, optional | ||
Show the legend of the figure. | ||
ax: axes, optional | ||
Matplotlib axes or bokeh figures. | ||
plot_kwargs : dict, optional | ||
Additional keywords passed to :meth:`mpl:matplotlib.axes.Axes.bar` or | ||
:meth:`bokeh:bokeh.plotting.Figure.vbar` for separation plot. | ||
y_hat_line_kwargs : dict, optional | ||
Additional keywords passed to ax.plot for `y_hat` line. | ||
exp_events_kwargs : dict, optional | ||
Additional keywords passed to ax.scatter for expected_events marker. | ||
backend: str, optional | ||
Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". | ||
backend_kwargs: bool, optional | ||
These are kwargs specific to the backend being used. For additional documentation | ||
check the plotting method of the backend. | ||
show : bool, optional | ||
Call backend show function. | ||
|
||
aloctavodia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Returns | ||
------- | ||
axes : matplotlib axes or bokeh figures | ||
|
||
References | ||
---------- | ||
* Greenhill, B. et al (2011) see https://doi.org/10.1111/j.1540-5907.2011.00525.x | ||
|
||
Examples | ||
-------- | ||
Separation plot for a logistic regression model. | ||
|
||
.. plot:: | ||
:context: close-figs | ||
|
||
>>> import arviz as az | ||
>>> idata = az.load_arviz_data('classification10d') | ||
>>> az.plot_separation(idata=idata, y='outcome', y_hat='outcome', figsize=(8, 1)) | ||
|
||
""" | ||
label_y_hat = "y_hat" | ||
if idata is not None and not isinstance(idata, InferenceData): | ||
raise ValueError("idata must be of type InferenceData or None") | ||
|
||
if idata is None: | ||
if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)): | ||
raise ValueError( | ||
"y and y_hat must be array or DataArray when idata is None " | ||
"but they are of types {}".format([type(arg) for arg in (y, y_hat)]) | ||
) | ||
else: | ||
|
||
if y_hat is None and isinstance(y, str): | ||
label_y_hat = y | ||
y_hat = y | ||
elif y_hat is None: | ||
raise ValueError("y_hat cannot be None if y is not a str") | ||
|
||
if isinstance(y, str): | ||
y = idata.observed_data[y].values | ||
elif not isinstance(y, (np.ndarray, xr.DataArray)): | ||
raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y))) | ||
|
||
if isinstance(y_hat, str): | ||
label_y_hat = y_hat | ||
y_hat = idata.posterior_predictive[y_hat].mean(dim=("chain", "draw")).values | ||
elif not isinstance(y_hat, (np.ndarray, xr.DataArray)): | ||
raise ValueError( | ||
"y_hat must be of types array, DataArray or str, not {}".format(type(y_hat)) | ||
) | ||
|
||
if len(y) != len(y_hat): | ||
warnings.warn( | ||
"y and y_hat must be the same lenght", | ||
UserWarning, | ||
) | ||
|
||
locs = np.linspace(0, 1, len(y_hat)) | ||
width = np.diff(locs).mean() | ||
|
||
separation_kwargs = dict( | ||
y=y, | ||
y_hat=y_hat, | ||
y_hat_line=y_hat_line, | ||
label_y_hat=label_y_hat, | ||
expected_events=expected_events, | ||
figsize=figsize, | ||
textsize=textsize, | ||
color=color, | ||
legend=legend, | ||
locs=locs, | ||
width=width, | ||
ax=ax, | ||
plot_kwargs=plot_kwargs, | ||
y_hat_line_kwargs=y_hat_line_kwargs, | ||
exp_events_kwargs=exp_events_kwargs, | ||
backend_kwargs=backend_kwargs, | ||
show=show, | ||
) | ||
|
||
if backend is None: | ||
backend = rcParams["plot.backend"] | ||
backend = backend.lower() | ||
|
||
plot = get_plotting_function("plot_separation", "separationplot", backend) | ||
axes = plot(**separation_kwargs) | ||
|
||
return axes |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.