Separation plot #1359

Merged: 14 commits merged into arviz-devs:master on Sep 2, 2020

Conversation

agustinaarroyuelo
Contributor

@agustinaarroyuelo agustinaarroyuelo commented Aug 21, 2020

Description

The separation plot is a simple and informative way to assess a model's fit when the outcome is binary. After fitting, the model's predictions are sorted and drawn as vertical lines. With a color scheme that distinguishes the positive and negative classes, a good fit shows most instances of the positive class on the right-hand side of the plot, where the highest-valued predictions are located.

idata = az.load_arviz_data('classification10d')
ax = az.plot_separation(
    idata=idata, 
    y='outcome',
    y_hat_line=True,
    expected_events=True,
    figsize=(10, 1),
)

[Image: separationplot]
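The description above can be sketched with plain NumPy. This is a minimal illustration, not the PR's implementation: the helper name `separation_data` is hypothetical, and taking the expected number of events as the rounded sum of the predicted probabilities is an assumption about what `expected_events=True` marks.

```python
import numpy as np


def separation_data(y, y_hat):
    """Hypothetical helper: sort predictions and carry the outcomes along."""
    y = np.asarray(y)
    y_hat = np.asarray(y_hat, dtype=float)
    if y.shape != y_hat.shape:
        raise ValueError("y and y_hat must have the same shape")
    # Sort predictions in ascending order; the highest values end up on the right
    idx = np.argsort(y_hat, kind="stable")
    # Assumption: expected events = rounded sum of predicted probabilities
    n_expected = int(np.round(y_hat.sum()))
    return y_hat[idx], y[idx], n_expected


y = np.array([0, 1, 0, 1, 1])
y_hat = np.array([0.2, 0.9, 0.4, 0.6, 0.8])
sorted_pred, sorted_y, n_expected = separation_data(y, y_hat)
```

In this toy case every positive outcome lands to the right of every negative one, which is what a well-separated fit looks like.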

It would be great if you could suggest how to choose default colors for this plot. Finally, I would like to know your overall comments on this code.

Checklist

  • Follows official PR format
  • Includes a sample plot to visually illustrate the changes (only for plot-related functions)
  • New features are properly documented (with an example if appropriate)?
  • Includes new or updated tests to cover the new feature
  • Code style correct (follows pylint and black guidelines)
  • Changes are listed in changelog

Contributor

@aloctavodia aloctavodia left a comment

Remember to also update arviz/doc/api.rst

@agustinaarroyuelo changed the title from "[WIP] Separation plot" to "Separation plot" on Aug 26, 2020
    backend_kwargs,
    show,
):
    """Matplotlib separation plot."""
Member

@OriolAbril OriolAbril Aug 27, 2020

Suggested change
-    """Matplotlib separation plot."""
+    """Bokeh separation plot."""

Comment on lines 59 to 87
    if idata is not None and not isinstance(idata, InferenceData):
        raise ValueError("idata must be of type InferenceData or None")

    if idata is None:
        if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)):
            raise ValueError(
                "y and y_hat must be array or DataArray when idata is None "
                "but they are of types {}".format([type(arg) for arg in (y, y_hat)])
            )
    else:

        if y_hat is None and isinstance(y, str):
            label_y_hat = y
            y_hat = y
        elif y_hat is None:
            raise ValueError("y_hat cannot be None if y is not a str")

        if isinstance(y, str):
            y = idata.observed_data[y].values
        elif not isinstance(y, (np.ndarray, xr.DataArray)):
            raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y)))

        if isinstance(y_hat, str):
            label_y_hat = y_hat
            y_hat = idata.posterior_predictive[y_hat].mean(axis=(1, 0)).values
        elif not isinstance(y_hat, (np.ndarray, xr.DataArray)):
            raise ValueError(
                "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat))
            )
Member

I think this can be done in the general function and then the backend specific ones only get the arrays (plus maybe labels too?)
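One way to read this suggestion is sketched below: the general function resolves and validates the inputs once, and each backend receives only plain arrays plus a label. The name `resolve_inputs` is hypothetical, and a nested dict of NumPy arrays stands in for `InferenceData` so the example runs without ArviZ or xarray.

```python
import numpy as np


def resolve_inputs(y, y_hat=None, idata=None):
    """Hypothetical helper: return (y, y_hat, label) as backend-agnostic arrays."""
    label = None
    if idata is None:
        if not all(isinstance(arg, np.ndarray) for arg in (y, y_hat)):
            raise ValueError("y and y_hat must be arrays when idata is None")
        return y, y_hat, label
    if y_hat is None:
        if not isinstance(y, str):
            raise ValueError("y_hat cannot be None if y is not a str")
        y_hat = y  # reuse the observed variable name for the predictions
    if isinstance(y_hat, str):
        label = y_hat
        # Average over the leading (chain, draw) axes
        y_hat = idata["posterior_predictive"][y_hat].mean(axis=(0, 1))
    if isinstance(y, str):
        y = idata["observed_data"][y]
    return np.asarray(y), np.asarray(y_hat), label


# Stand-in for InferenceData: 2 chains, 4 draws, 3 observations
idata = {
    "observed_data": {"outcome": np.array([0, 1, 1])},
    "posterior_predictive": {"outcome": np.ones((2, 4, 3)) * np.array([0.1, 0.7, 0.9])},
}
y, y_hat, label = resolve_inputs("outcome", idata=idata)
```

With this split, a backend function would only need to draw `y`, `y_hat`, and `label`, so the checks are not duplicated per backend.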

    expected_events=False,
    figsize=None,
    textsize=None,
    color=None,
Member

Should we set color="C0" here and then convert it to hex before passing it to the backend-specific function? This is more a philosophical question about whether we want this behaviour than anything else.

Contributor

Yeah, I think we want this everywhere, not only here.
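The conversion discussed here could look roughly like the following, using matplotlib's `matplotlib.colors.to_hex`, which accepts cycle aliases such as "C0". The helper name `resolve_color` is hypothetical and not part of ArviZ.

```python
from matplotlib.colors import to_hex


def resolve_color(color=None, default="C0"):
    """Hypothetical helper: turn any matplotlib color spec into a hex string."""
    return to_hex(default if color is None else color)


hex_default = resolve_color()     # "C0" resolved against the active color cycle
hex_named = resolve_color("red")  # named colors are converted as well
```

A hex string is understood by both matplotlib and Bokeh, which is why converting in the general function would let each backend use the same default.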

        ax.scatter(y_hat[idx][expected_events - 1], 0, label="Expected events", **exp_events_kwargs)

    if legend and expected_events or y_hat_line:
        handles, labels = plt.gca().get_legend_handles_labels()
Member

Why gca instead of using ax? Not sure how this is different from ax.legend() 😅
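A small sketch of the point being raised, assuming a figure with more than one axes: `plt.gca()` returns whichever axes is current, which need not be the `ax` the function was given, whereas calling the methods on `ax` directly is unambiguous. The axes names here are illustrative only.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

fig, (ax_top, ax_bottom) = plt.subplots(2, 1)
ax_top.plot([0, 1], [0, 1], label="y_hat")
ax_bottom.plot([0, 1], [1, 0], label="other")

# Query the axes we were given rather than the global "current" axes
handles, labels = ax_top.get_legend_handles_labels()
ax_top.legend(handles, labels)
```

When no reordering or filtering of handles is needed, `ax_top.legend()` alone gives the same result.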

codecov bot commented Aug 27, 2020

Codecov Report

Merging #1359 into master will decrease coverage by 0.01%.
The diff coverage is 90.69%.


@@            Coverage Diff             @@
##           master    #1359      +/-   ##
==========================================
- Coverage   91.74%   91.73%   -0.02%     
==========================================
  Files         102      105       +3     
  Lines       10778    10907     +129     
==========================================
+ Hits         9888    10005     +117     
- Misses        890      902      +12     
Impacted Files                                      Coverage Δ
arviz/plots/separationplot.py                       72.72% <72.72%> (ø)
arviz/plots/backends/matplotlib/separationplot.py   96.07% <96.07%> (ø)
arviz/plots/backends/bokeh/separationplot.py        97.72% <97.72%> (ø)
arviz/plots/__init__.py                             100.00% <100.00%> (ø)

Member

@OriolAbril OriolAbril left a comment

Added a small comment to try to minimize duplicated code, but it's not necessary to do; this can be merged as is.

I love the PR and having the separation plot in ArviZ. I have been asked a couple of times on Discourse for posterior predictive plots for binary outcomes. If you are up for it, it could be interesting to write a blog post about how to use and interpret it, and to share it on the PyMC and Stan Discourses.

Comment on lines 59 to 66
    if len(y) != len(y_hat):
        warnings.warn(
            "y and y_hat must be the same length",
            UserWarning,
        )

    locs = np.linspace(0, 1, len(y_hat))
    width = np.diff(locs).mean()
Member

maybe this could also go in the general function as it's not backend specific
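As a rough sketch of what moving this into the general function might look like, the bar positions and width depend only on the number of predictions, so they can be computed once and handed to either backend. The helper name `bar_layout` is hypothetical, and it assumes at least two predictions.

```python
import numpy as np


def bar_layout(n):
    """Hypothetical helper: evenly spaced bar positions on [0, 1] and their width."""
    if n < 2:
        raise ValueError("at least two predictions are needed to compute a width")
    locs = np.linspace(0, 1, n)
    # Spacing between neighbouring positions; constant for linspace
    width = np.diff(locs).mean()
    return locs, width


locs, width = bar_layout(5)
```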

@aloctavodia aloctavodia merged commit 5a6bbee into arviz-devs:master Sep 2, 2020
@aloctavodia
Contributor

thanks @agustinaarroyuelo!
