-
-
Notifications
You must be signed in to change notification settings - Fork 407
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Separation plot #1359
Separation plot #1359
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remember to also update arviz/doc/api.rst
backend_kwargs, | ||
show, | ||
): | ||
"""Matplotlib separation plot.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"""Matplotlib separation plot.""" | |
"""Bokeh separation plot.""" |
if idata is not None and not isinstance(idata, InferenceData): | ||
raise ValueError("idata must be of type InferenceData or None") | ||
|
||
if idata is None: | ||
if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)): | ||
raise ValueError( | ||
"y and y_hat must be array or DataArray when idata is None " | ||
"but they are of types {}".format([type(arg) for arg in (y, y_hat)]) | ||
) | ||
else: | ||
|
||
if y_hat is None and isinstance(y, str): | ||
label_y_hat = y | ||
y_hat = y | ||
elif y_hat is None: | ||
raise ValueError("y_hat cannot be None if y is not a str") | ||
|
||
if isinstance(y, str): | ||
y = idata.observed_data[y].values | ||
elif not isinstance(y, (np.ndarray, xr.DataArray)): | ||
raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y))) | ||
|
||
if isinstance(y_hat, str): | ||
label_y_hat = y_hat | ||
y_hat = idata.posterior_predictive[y_hat].mean(axis=(1, 0)).values | ||
elif not isinstance(y_hat, (np.ndarray, xr.DataArray)): | ||
raise ValueError( | ||
"y_hat must be of types array, DataArray or str, not {}".format(type(y_hat)) | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this can be done in the general function and then the backend specific ones only get the arrays (plus maybe labels too?)
arviz/plots/separationplot.py
Outdated
expected_events=False, | ||
figsize=None, | ||
textsize=None, | ||
color=None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should we set color="C0"
here and then convert to hex before passing it to the backend specific function? This is more a philosophical question about whether we want this behaviour than anything else.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I think we want this everywhere, not only here.
ax.scatter(y_hat[idx][expected_events - 1], 0, label="Expected events", **exp_events_kwargs) | ||
|
||
if legend and expected_events or y_hat_line: | ||
handles, labels = plt.gca().get_legend_handles_labels() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why gca instead of using ax? not sure how is this different from ax.legend()
😅
Codecov Report
@@ Coverage Diff @@
## master #1359 +/- ##
==========================================
- Coverage 91.74% 91.73% -0.02%
==========================================
Files 102 105 +3
Lines 10778 10907 +129
==========================================
+ Hits 9888 10005 +117
- Misses 890 902 +12
Continue to review full report at Codecov.
|
pydocstyle
Co-authored-by: Oriol Abril-Pla <[email protected]>
Co-authored-by: Oriol Abril-Pla <[email protected]>
Co-authored-by: Oriol Abril-Pla <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a small comment to try to minimize duplicated code, but it's note necessary to do, can be merged as is.
I love the PR and having the separation plot on ArviZ. I have been asked a couple times for posterior predictive plots for binary outcomes in Discourse, if you are up for it, it could be interesting to make a blogpost about it, how to use it, how to interpret it and share it in pymc and stan discourses
if len(y) != len(y_hat): | ||
warnings.warn( | ||
"y and y_hat must be the same lenght", | ||
UserWarning, | ||
) | ||
|
||
locs = np.linspace(0, 1, len(y_hat)) | ||
width = np.diff(locs).mean() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe this could also go in the general function as it's not backend specific
thanks @agustinaarroyuelo! |
Description
The separation plot is a really interesting and simple way for assessing a model's fit when the outcome is binary. After fitting, model predictions are sorted and represented as vertical lines. When adding a color scheme that identifies the positive and negative class, for a good fit, we should see most of the instances of the positive class on the right hand side of the plot. That is, where the highest valued predictions are located.
It would be great if you could suggest how to choose default colors for this plot. Finally, I would like to know your overall comments on this code.Checklist