From b8b2576157aec255955c4960f160152cdb532217 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 21 Aug 2020 10:51:04 -0300 Subject: [PATCH 01/14] separation plot --- arviz/plots/__init__.py | 2 + arviz/plots/backends/bokeh/separationplot.py | 95 ++++++++++++++++++ .../backends/matplotlib/separationplot.py | 99 +++++++++++++++++++ arviz/plots/separationplot.py | 99 +++++++++++++++++++ 4 files changed, 295 insertions(+) create mode 100644 arviz/plots/backends/bokeh/separationplot.py create mode 100644 arviz/plots/backends/matplotlib/separationplot.py create mode 100644 arviz/plots/separationplot.py diff --git a/arviz/plots/__init__.py b/arviz/plots/__init__.py index 0f2a2d9334..3912c4e98a 100644 --- a/arviz/plots/__init__.py +++ b/arviz/plots/__init__.py @@ -22,6 +22,7 @@ from .rankplot import plot_rank from .traceplot import plot_trace from .violinplot import plot_violin +from .separationplot import plot_separation __all__ = [ "plot_autocorr", @@ -48,4 +49,5 @@ "plot_rank", "plot_trace", "plot_violin", + "plot_separation", ] diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py new file mode 100644 index 0000000000..04e09bb26d --- /dev/null +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -0,0 +1,95 @@ +"""Matplotlib separation plot""" +import matplotlib.pyplot as plt +import numpy as np + +from ...plot_utils import _scale_fig_size +from . import backend_kwarg_defaults, create_axes_grid +from .. import show_layout + + +def plot_separation( + idata, + y, + y_hat, + y_hat_line, + expected_events, + figsize, + textsize, + color, + cmap, + legend, # pylint: disable=unused-argument + ax, + plot_kwargs, + backend_kwargs, + show, +): + """Matplotlib separation plot.""" + if backend_kwargs is None: + backend_kwargs = {} + + if plot_kwargs is None: + plot_kwargs = {} + + backend_kwargs = { + **backend_kwarg_defaults(), + **backend_kwargs, + } + + if cmap: + cmap = plt.get_cmap(cmap).colors + negative_color, positive_color = cmap[-1], cmap[0] + else: + if color: + negative_color, positive_color = color[0], color[1] + else: + negative_color, positive_color = "peru", "maroon" + + figsize, *_ = _scale_fig_size(figsize, textsize) + if isinstance(y_hat, str): + y_hat_var = idata.posterior_predictive[y_hat].values.mean(1).mean(0) + label_line = y_hat + + idx = np.argsort(y_hat_var) + + if isinstance(y, str): + y = idata.observed_data[y].values[idx].ravel() + + widths = np.linspace(0, 1, len(y_hat_var)) + delta = np.diff(widths).mean() + + backend_kwargs["x_range"] = (delta, 1.5) + backend_kwargs["y_range"] = (0, 1) + + if ax is None: + ax = create_axes_grid(1, figsize=figsize, backend_kwargs=backend_kwargs,) + ax = ax.ravel()[0] + + for i, width in enumerate(widths): + bar_color, tag = (negative_color, False) if y[i] == 0 else (positive_color, True) + label = "Positive class" if tag else "Negative class" + + ax.vbar(width, top=1, width=width, color=bar_color, legend_label=label, **plot_kwargs) + + if y_hat_line: + ax.line( + np.linspace(delta, 1.5, len(y_hat_var)), + y_hat_var[idx], + color="black", + legend_label=label_line, + ) + + if expected_events: + expected_events = int(np.round(np.sum(y_hat_var))) + ax.triangle( + y_hat_var[idx][expected_events], + 0, + color="black", + legend_label="Expected events", + **plot_kwargs + ) + + ax.axis.visible = False + + show_layout(ax, show) + + return ax diff --git a/arviz/plots/backends/matplotlib/separationplot.py b/arviz/plots/backends/matplotlib/separationplot.py new file mode 100644 index 0000000000..19ca2e8952 --- /dev/null +++ b/arviz/plots/backends/matplotlib/separationplot.py @@ -0,0 +1,99 @@ +"""Matplotlib separation plot""" +import matplotlib.pyplot as plt +import numpy as np + +from ...plot_utils import _scale_fig_size +from . import backend_kwarg_defaults, backend_show, create_axes_grid + + +def plot_separation( + idata, + y, + y_hat, + y_hat_line, + expected_events, + figsize, + textsize, + color, + cmap, + legend, + ax, + plot_kwargs, + backend_kwargs, + show, +): + """Matplotlib separation plot.""" + if backend_kwargs is None: + backend_kwargs = {} + + if plot_kwargs is None: + plot_kwargs = {} + + backend_kwargs = { + **backend_kwarg_defaults(), + **backend_kwargs, + } + + if cmap: + cmap = plt.get_cmap(cmap).colors + negative_color, positive_color = cmap[-1], cmap[0] + else: + if color: + negative_color, positive_color = color[0], color[1] + else: + negative_color, positive_color = "C1", "C3" + + (figsize, *_) = _scale_fig_size(figsize, textsize, 1, 1) + backend_kwargs.setdefault("figsize", figsize) + backend_kwargs["squeeze"] = True + + if ax is None: + _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs) + + if isinstance(y_hat, str): + y_hat_var = idata.posterior_predictive[y_hat].values.mean(1).mean(0) + label_line = y_hat + + idx = np.argsort(y_hat_var) + + if isinstance(y, str): + y = idata.observed_data[y].values[idx].ravel() + + widths = np.linspace(0, 1, len(y_hat_var)) + + for i, width in enumerate(widths): + bar_color, tag = (negative_color, False) if y[i] == 0 else (positive_color, True) + label = "Positive class" if tag else "Negative class" + ax.bar(width, 1, width=width, color=bar_color, align="edge", label=label, **plot_kwargs) + + delta = np.diff(widths).mean() + + if y_hat_line: + ax.plot( + np.linspace(delta, 1.5, len(y_hat_var)), + y_hat_var[idx], + color="k", + label=label_line, + **plot_kwargs + ) + + if expected_events: + expected_events = int(np.round(np.sum(y_hat_var))) + ax.scatter( + y_hat_var[idx][expected_events], 0, marker="^", color="k", label="Expected events", + ) + + if legend: + handles, labels = plt.gca().get_legend_handles_labels() + labels_dict = dict(zip(labels, handles)) + ax.legend(labels_dict.values(), labels_dict.keys()) + + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_xlim(delta, 1.5) + ax.set_ylim(0, 1) + + if backend_show(show): + plt.show() + + return ax diff --git a/arviz/plots/separationplot.py b/arviz/plots/separationplot.py new file mode 100644 index 0000000000..f7041cb6dd --- /dev/null +++ b/arviz/plots/separationplot.py @@ -0,0 +1,99 @@ +"""Plot separation plot for discrete outcome models.""" +from ..rcparams import rcParams +from .plot_utils import get_plotting_function + + +def plot_separation( + idata=None, + y=None, + y_hat=None, + y_hat_line=False, + expected_events=False, + figsize=None, + textsize=None, + color=None, + cmap=None, + legend=True, + ax=None, + plot_kwargs=None, + backend=None, + backend_kwargs=None, + show=None, +): + + """Plot separation plot for discrete outcome models. + + Parameters + ---------- + idata : InferenceData + InferenceData object. + y : array, DataArray or str + Observed data. If str, idata must be present and contain the observed data group + y_hat : array, DataArray or str + Posterior predictive samples for ``y``. It must have the same shape as y plus an + extra dimension at the end of size n_samples (chains and draws stacked). If str or + None, idata must contain the posterior predictive group. If None, y_hat is taken + equal to y, thus, y must be str too. + y_hat_line : bool, optional + Plot the sorted `y_hat` predictions. + expected_events : bool, optional + Plot the total number of expected events. + figsize : figure size tuple, optional + If None, size is (8 + numvars, 8 + numvars) + textsize: int, optional + Text size for labels. If None it will be autoscaled based on figsize. + color : list or array_like, optional + The first color will be used to plot the negative class while the second color will + be assigned to the positive class. + cmap : str, optional + Colors for the separation plot will be taken from both ends of the color map + respectively. + legend : bool, optional + Show the legend of the figure. + ax: axes, optional + Matplotlib axes or bokeh figures. + plot_kwargs : dict, optional + Additional keywords passed to ax.plot for `y_hat` line. + backend: str, optional + Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". + backend_kwargs: bool, optional + These are kwargs specific to the backend being used. For additional documentation + check the plotting method of the backend. + show : bool, optional + Call backend show function. + + Returns + ------- + axes : matplotlib axes or bokeh figures + + References + ---------- + * Greenhill, B. et al (2011) see https://doi.org/10.1111/j.1540-5907.2011.00525.x + + """ + + separation_kwargs = dict( + idata=idata, + y=y, + y_hat=y_hat, + y_hat_line=y_hat_line, + expected_events=expected_events, + figsize=figsize, + textsize=textsize, + color=color, + cmap=cmap, + legend=legend, + ax=ax, + plot_kwargs=plot_kwargs, + backend_kwargs=backend_kwargs, + show=show, + ) + + if backend is None: + backend = rcParams["plot.backend"] + backend = backend.lower() + + plot = get_plotting_function("plot_separation", "separationplot", backend) + axes = plot(**separation_kwargs) + + return axes From 847095289d5ba1b45b2ad3914a772820b0f46b9a Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 21 Aug 2020 11:28:08 -0300 Subject: [PATCH 02/14] add squeeze argument --- arviz/plots/backends/bokeh/separationplot.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py index 04e09bb26d..e73c74e7b2 100644 --- a/arviz/plots/backends/bokeh/separationplot.py +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -61,8 +61,7 @@ def plot_separation( backend_kwargs["y_range"] = (0, 1) if ax is None: - ax = create_axes_grid(1, figsize=figsize, backend_kwargs=backend_kwargs,) - ax = ax.ravel()[0] + ax = create_axes_grid(1, figsize=figsize, squeeze=True, backend_kwargs=backend_kwargs,) for i, width in enumerate(widths): bar_color, tag = (negative_color, False) if y[i] == 0 else (positive_color, True) From c47cb436a94a1b9e0684358096ff0a5ab6b27aa4 Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 21 Aug 2020 13:56:28 -0300 Subject: [PATCH 03/14] remove cmap argument --- arviz/plots/backends/bokeh/separationplot.py | 32 ++++++++++--------- .../backends/matplotlib/separationplot.py | 32 ++++++++++--------- arviz/plots/separationplot.py | 11 ++----- 3 files changed, 37 insertions(+), 38 deletions(-) diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py index e73c74e7b2..a72ee4f626 100644 --- a/arviz/plots/backends/bokeh/separationplot.py +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -1,5 +1,4 @@ """Matplotlib separation plot""" -import matplotlib.pyplot as plt import numpy as np from ...plot_utils import _scale_fig_size @@ -16,7 +15,6 @@ def plot_separation( figsize, textsize, color, - cmap, legend, # pylint: disable=unused-argument ax, plot_kwargs, @@ -35,14 +33,8 @@ def plot_separation( **backend_kwargs, } - if cmap: - cmap = plt.get_cmap(cmap).colors - negative_color, positive_color = cmap[-1], cmap[0] - else: - if color: - negative_color, positive_color = color[0], color[1] - else: - negative_color, positive_color = "peru", "maroon" + if not color: + color = "blue" figsize, *_ = _scale_fig_size(figsize, textsize) if isinstance(y_hat, str): @@ -57,21 +49,29 @@ def plot_separation( widths = np.linspace(0, 1, len(y_hat_var)) delta = np.diff(widths).mean() - backend_kwargs["x_range"] = (delta, 1.5) + backend_kwargs["x_range"] = (0, 1) backend_kwargs["y_range"] = (0, 1) if ax is None: ax = create_axes_grid(1, figsize=figsize, squeeze=True, backend_kwargs=backend_kwargs,) for i, width in enumerate(widths): - bar_color, tag = (negative_color, False) if y[i] == 0 else (positive_color, True) + tag = False if y[i] == 0 else True label = "Positive class" if tag else "Negative class" - - ax.vbar(width, top=1, width=width, color=bar_color, legend_label=label, **plot_kwargs) + alpha = 0.3 if not tag else 1 + ax.vbar( + width, + top=1, + width=delta, + color=color, + fill_alpha=alpha, + legend_label=label, + **plot_kwargs + ) if y_hat_line: ax.line( - np.linspace(delta, 1.5, len(y_hat_var)), + np.linspace(0, 1, len(y_hat_var)), y_hat_var[idx], color="black", legend_label=label_line, @@ -88,6 +88,8 @@ def plot_separation( ) ax.axis.visible = False + ax.xgrid.grid_line_color = None + ax.ygrid.grid_line_color = None show_layout(ax, show) diff --git a/arviz/plots/backends/matplotlib/separationplot.py b/arviz/plots/backends/matplotlib/separationplot.py index 19ca2e8952..53fa7697eb 100644 --- a/arviz/plots/backends/matplotlib/separationplot.py +++ b/arviz/plots/backends/matplotlib/separationplot.py @@ -15,7 +15,6 @@ def plot_separation( figsize, textsize, color, - cmap, legend, ax, plot_kwargs, @@ -34,14 +33,8 @@ def plot_separation( **backend_kwargs, } - if cmap: - cmap = plt.get_cmap(cmap).colors - negative_color, positive_color = cmap[-1], cmap[0] - else: - if color: - negative_color, positive_color = color[0], color[1] - else: - negative_color, positive_color = "C1", "C3" + if not color: + color = "C0" (figsize, *_) = _scale_fig_size(figsize, textsize, 1, 1) backend_kwargs.setdefault("figsize", figsize) @@ -60,17 +53,26 @@ def plot_separation( y = idata.observed_data[y].values[idx].ravel() widths = np.linspace(0, 1, len(y_hat_var)) + delta = np.diff(widths).mean() for i, width in enumerate(widths): - bar_color, tag = (negative_color, False) if y[i] == 0 else (positive_color, True) + tag = False if y[i] == 0 else True label = "Positive class" if tag else "Negative class" - ax.bar(width, 1, width=width, color=bar_color, align="edge", label=label, **plot_kwargs) - - delta = np.diff(widths).mean() + alpha = 0.3 if not tag else 1 + ax.bar( + width, + 1, + width=delta, + color=color, + align="edge", + label=label, + alpha=alpha, + **plot_kwargs + ) if y_hat_line: ax.plot( - np.linspace(delta, 1.5, len(y_hat_var)), + np.linspace(0, 1, len(y_hat_var)), y_hat_var[idx], color="k", label=label_line, @@ -90,7 +92,7 @@ def plot_separation( ax.set_xticks([]) ax.set_yticks([]) - ax.set_xlim(delta, 1.5) + ax.set_xlim(0, 1) ax.set_ylim(0, 1) if backend_show(show): diff --git a/arviz/plots/separationplot.py b/arviz/plots/separationplot.py index f7041cb6dd..aa410ae0dd 100644 --- a/arviz/plots/separationplot.py +++ b/arviz/plots/separationplot.py @@ -12,7 +12,6 @@ def plot_separation( figsize=None, textsize=None, color=None, - cmap=None, legend=True, ax=None, plot_kwargs=None, @@ -42,12 +41,9 @@ def plot_separation( If None, size is (8 + numvars, 8 + numvars) textsize: int, optional Text size for labels. If None it will be autoscaled based on figsize. - color : list or array_like, optional - The first color will be used to plot the negative class while the second color will - be assigned to the positive class. - cmap : str, optional - Colors for the separation plot will be taken from both ends of the color map - respectively. + color : str, optional + Color to assign to the postive class. The negative class will be plotted using the + same color and an `alpha=0.3` transparency. legend : bool, optional Show the legend of the figure. ax: axes, optional @@ -81,7 +77,6 @@ def plot_separation( figsize=figsize, textsize=textsize, color=color, - cmap=cmap, legend=legend, ax=ax, plot_kwargs=plot_kwargs, From 63fb226a193f0ed05a6d1a79645657bf8b56ac87 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 25 Aug 2020 13:48:45 -0300 Subject: [PATCH 04/14] add example and tests pydocstyle --- arviz/plots/backends/bokeh/separationplot.py | 105 +++++++++++------ .../backends/matplotlib/separationplot.py | 108 +++++++++++------- arviz/plots/separationplot.py | 26 ++++- arviz/tests/base_tests/test_plots_bokeh.py | 17 +++ .../tests/base_tests/test_plots_matplotlib.py | 17 +++ 5 files changed, 196 insertions(+), 77 deletions(-) diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py index a72ee4f626..435db2eb3b 100644 --- a/arviz/plots/backends/bokeh/separationplot.py +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -1,9 +1,12 @@ -"""Matplotlib separation plot""" +"""Bokeh separation plot.""" +import warnings import numpy as np +import xarray as xr +from ....data import InferenceData from ...plot_utils import _scale_fig_size -from . import backend_kwarg_defaults, create_axes_grid from .. import show_layout +from . import backend_kwarg_defaults, create_axes_grid def plot_separation( @@ -18,6 +21,8 @@ def plot_separation( legend, # pylint: disable=unused-argument ax, plot_kwargs, + y_hat_line_kwargs, + exp_events_kwargs, backend_kwargs, show, ): @@ -28,63 +33,99 @@ def plot_separation( if plot_kwargs is None: plot_kwargs = {} + plot_kwargs.setdefault("color", "royalblue") + if color: + plot_kwargs["color"] = color + backend_kwargs = { **backend_kwarg_defaults(), **backend_kwargs, } - if not color: - color = "blue" + if y_hat_line_kwargs is None: + y_hat_line_kwargs = {} - figsize, *_ = _scale_fig_size(figsize, textsize) - if isinstance(y_hat, str): - y_hat_var = idata.posterior_predictive[y_hat].values.mean(1).mean(0) - label_line = y_hat + y_hat_line_kwargs.setdefault("color", "black") - idx = np.argsort(y_hat_var) + if exp_events_kwargs is None: + exp_events_kwargs = {} - if isinstance(y, str): - y = idata.observed_data[y].values[idx].ravel() + exp_events_kwargs.setdefault("color", "black") + + figsize, *_ = _scale_fig_size(figsize, textsize) + + if idata is not None and not isinstance(idata, InferenceData): + raise ValueError("idata must be of type InferenceData or None") + + if idata is None: + if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)): + raise ValueError( + "y and y_hat must be array or DataArray when idata is None " + "but they are of types {}".format([type(arg) for arg in (y, y_hat)]) + ) + else: + + if y_hat is None and isinstance(y, str): + label_y_hat = y + y_hat = y + elif y_hat is None: + raise ValueError("y_hat cannot be None if y is not a str") + + if isinstance(y, str): + y = idata.observed_data[y].values + elif not isinstance(y, (np.ndarray, xr.DataArray)): + raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y))) + + if isinstance(y_hat, str): + label_y_hat = y_hat + y_hat = idata.posterior_predictive[y_hat].mean(axis=(1, 0)).values + elif not isinstance(y_hat, (np.ndarray, xr.DataArray)): + raise ValueError( + "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat)) + ) + + idx = np.argsort(y_hat) + + if len(y) != len(y_hat): + warnings.warn( + "y and y_hat must be the same lenght", UserWarning, + ) - widths = np.linspace(0, 1, len(y_hat_var)) - delta = np.diff(widths).mean() + locs = np.linspace(0, 1, len(y_hat)) + width = np.diff(locs).mean() backend_kwargs["x_range"] = (0, 1) backend_kwargs["y_range"] = (0, 1) if ax is None: - ax = create_axes_grid(1, figsize=figsize, squeeze=True, backend_kwargs=backend_kwargs,) + ax = create_axes_grid(1, figsize=figsize, squeeze=True, backend_kwargs=backend_kwargs) - for i, width in enumerate(widths): - tag = False if y[i] == 0 else True - label = "Positive class" if tag else "Negative class" - alpha = 0.3 if not tag else 1 + for i, loc in enumerate(locs): + positive = not y[idx][i] == 0 + label = "Positive class" if positive else "Negative class" + alpha = 1 if positive else 0.3 ax.vbar( - width, + loc, top=1, - width=delta, - color=color, + width=width, fill_alpha=alpha, legend_label=label, - **plot_kwargs + line_alpha=alpha, + **plot_kwargs, ) if y_hat_line: ax.line( - np.linspace(0, 1, len(y_hat_var)), - y_hat_var[idx], - color="black", - legend_label=label_line, + np.linspace(0, 1, len(y_hat)), + y_hat[idx], + legend_label=label_y_hat, + **y_hat_line_kwargs, ) if expected_events: - expected_events = int(np.round(np.sum(y_hat_var))) + expected_events = int(np.round(np.sum(y_hat))) ax.triangle( - y_hat_var[idx][expected_events], - 0, - color="black", - legend_label="Expected events", - **plot_kwargs + y_hat[idx][expected_events - 1], 0, legend_label="Expected events", **exp_events_kwargs, ) ax.axis.visible = False diff --git a/arviz/plots/backends/matplotlib/separationplot.py b/arviz/plots/backends/matplotlib/separationplot.py index 53fa7697eb..c80188c4d7 100644 --- a/arviz/plots/backends/matplotlib/separationplot.py +++ b/arviz/plots/backends/matplotlib/separationplot.py @@ -1,7 +1,10 @@ -"""Matplotlib separation plot""" +"""Matplotlib separation plot.""" +import warnings import matplotlib.pyplot as plt import numpy as np +import xarray as xr +from ....data import InferenceData from ...plot_utils import _scale_fig_size from . import backend_kwarg_defaults, backend_show, create_axes_grid @@ -18,6 +21,8 @@ def plot_separation( legend, ax, plot_kwargs, + y_hat_line_kwargs, + exp_events_kwargs, backend_kwargs, show, ): @@ -28,14 +33,26 @@ def plot_separation( if plot_kwargs is None: plot_kwargs = {} + plot_kwargs.setdefault("color", "C0") + if color: + plot_kwargs["color"] = color + + if y_hat_line_kwargs is None: + y_hat_line_kwargs = {} + + y_hat_line_kwargs.setdefault("color", "k") + + if exp_events_kwargs is None: + exp_events_kwargs = {} + + exp_events_kwargs.setdefault("color", "k") + exp_events_kwargs.setdefault("marker", "^") + backend_kwargs = { **backend_kwarg_defaults(), **backend_kwargs, } - if not color: - color = "C0" - (figsize, *_) = _scale_fig_size(figsize, textsize, 1, 1) backend_kwargs.setdefault("figsize", figsize) backend_kwargs["squeeze"] = True @@ -43,47 +60,58 @@ def plot_separation( if ax is None: _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs) - if isinstance(y_hat, str): - y_hat_var = idata.posterior_predictive[y_hat].values.mean(1).mean(0) - label_line = y_hat - - idx = np.argsort(y_hat_var) - - if isinstance(y, str): - y = idata.observed_data[y].values[idx].ravel() - - widths = np.linspace(0, 1, len(y_hat_var)) - delta = np.diff(widths).mean() - - for i, width in enumerate(widths): - tag = False if y[i] == 0 else True - label = "Positive class" if tag else "Negative class" - alpha = 0.3 if not tag else 1 - ax.bar( - width, - 1, - width=delta, - color=color, - align="edge", - label=label, - alpha=alpha, - **plot_kwargs + if idata is not None and not isinstance(idata, InferenceData): + raise ValueError("idata must be of type InferenceData or None") + + if idata is None: + if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)): + raise ValueError( + "y and y_hat must be array or DataArray when idata is None " + "but they are of types {}".format([type(arg) for arg in (y, y_hat)]) + ) + else: + + if y_hat is None and isinstance(y, str): + label_y_hat = y + y_hat = y + elif y_hat is None: + raise ValueError("y_hat cannot be None if y is not a str") + + if isinstance(y, str): + y = idata.observed_data[y].values + elif not isinstance(y, (np.ndarray, xr.DataArray)): + raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y))) + + if isinstance(y_hat, str): + label_y_hat = y_hat + y_hat = idata.posterior_predictive[y_hat].mean(axis=(1, 0)).values + elif not isinstance(y_hat, (np.ndarray, xr.DataArray)): + raise ValueError( + "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat)) + ) + + idx = np.argsort(y_hat) + + if len(y) != len(y_hat): + warnings.warn( + "y and y_hat must be the same lenght", UserWarning, ) + locs = np.linspace(0, 1, len(y_hat)) + width = np.diff(locs).mean() + + for i, loc in enumerate(locs): + positive = not y[idx][i] == 0 + label = "Positive class" if positive else "Negative class" + alpha = 1 if positive else 0.3 + ax.bar(loc, 1, width=width, label=label, alpha=alpha, **plot_kwargs) + if y_hat_line: - ax.plot( - np.linspace(0, 1, len(y_hat_var)), - y_hat_var[idx], - color="k", - label=label_line, - **plot_kwargs - ) + ax.plot(np.linspace(0, 1, len(y_hat)), y_hat[idx], label=label_y_hat, **y_hat_line_kwargs) if expected_events: - expected_events = int(np.round(np.sum(y_hat_var))) - ax.scatter( - y_hat_var[idx][expected_events], 0, marker="^", color="k", label="Expected events", - ) + expected_events = int(np.round(np.sum(y_hat))) + ax.scatter(y_hat[idx][expected_events - 1], 0, label="Expected events", **exp_events_kwargs) if legend: handles, labels = plt.gca().get_legend_handles_labels() diff --git a/arviz/plots/separationplot.py b/arviz/plots/separationplot.py index aa410ae0dd..babe9b9745 100644 --- a/arviz/plots/separationplot.py +++ b/arviz/plots/separationplot.py @@ -15,11 +15,12 @@ def plot_separation( legend=True, ax=None, plot_kwargs=None, + y_hat_line_kwargs=None, + exp_events_kwargs=None, backend=None, backend_kwargs=None, show=None, ): - """Plot separation plot for discrete outcome models. Parameters @@ -29,10 +30,8 @@ def plot_separation( y : array, DataArray or str Observed data. If str, idata must be present and contain the observed data group y_hat : array, DataArray or str - Posterior predictive samples for ``y``. It must have the same shape as y plus an - extra dimension at the end of size n_samples (chains and draws stacked). If str or - None, idata must contain the posterior predictive group. If None, y_hat is taken - equal to y, thus, y must be str too. + Posterior predictive samples for ``y``. It must have the same shape as y. If str or + None, idata must contain the posterior predictive group. y_hat_line : bool, optional Plot the sorted `y_hat` predictions. expected_events : bool, optional @@ -49,7 +48,11 @@ def plot_separation( ax: axes, optional Matplotlib axes or bokeh figures. plot_kwargs : dict, optional + Additional keywords passed to ax.bar for separation plot. + y_hat_line_kwargs : dict, optional Additional keywords passed to ax.plot for `y_hat` line. + exp_events_kwargs : dict, optional + Additional keywords passed to ax.scatter for expected_events marker. backend: str, optional Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". backend_kwargs: bool, optional @@ -66,6 +69,17 @@ def plot_separation( ---------- * Greenhill, B. et al (2011) see https://doi.org/10.1111/j.1540-5907.2011.00525.x + Examples + -------- + Separation plot for a logistic regression model. + + .. plot:: + :context: close-figs + + >>> import arviz as az + >>> idata = az.load_arviz_data('classification10d') + >>> az.plot_separation(idata=idata, y='outcome', y_hat='outcome', figsize=(8, 1)) + """ separation_kwargs = dict( @@ -80,6 +94,8 @@ def plot_separation( legend=legend, ax=ax, plot_kwargs=plot_kwargs, + y_hat_line_kwargs=y_hat_line_kwargs, + exp_events_kwargs=exp_events_kwargs, backend_kwargs=backend_kwargs, show=show, ) diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index 1318019090..1cada4a1a8 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -29,6 +29,7 @@ plot_posterior, plot_ppc, plot_rank, + plot_separation, plot_trace, plot_violin, ) @@ -132,6 +133,22 @@ def test_plot_density_bad_kwargs(models): plot_density(obj, hdi_prob=2, backend="bokeh", show=False) +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"y_hat_line": True}, + {"expected_events": True}, + {"y_hat_line_kwargs": {"linestyle": "dotted"}}, + {"exp_events_kwargs": {"marker": "o"}}, + ], +) +def test_plot_separation(kwargs): + idata = load_arviz_data("classification10d") + ax = plot_separation(idata=idata, y="outcome", backend="bokeh", show=False, **kwargs) + assert ax + + @pytest.mark.parametrize( "kwargs", [ diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index e8f13b66a4..6e9c1f4fcb 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -33,6 +33,7 @@ plot_posterior, plot_ppc, plot_rank, + plot_separation, plot_trace, plot_violin, ) @@ -159,6 +160,22 @@ def test_plot_density_bad_kwargs(models): plot_density(obj, hdi_prob=2) +@pytest.mark.parametrize( + "kwargs", + [ + {}, + {"y_hat_line": True}, + {"expected_events": True}, + {"y_hat_line_kwargs": {"linestyle": "dotted"}}, + {"exp_events_kwargs": {"marker": "o"}}, + ], +) +def test_plot_separation(kwargs): + idata = load_arviz_data("classification10d") + ax = plot_separation(idata=idata, y="outcome", **kwargs) + assert ax + + @pytest.mark.parametrize( "kwargs", [ From 9825389f56a984e9b7267daa34e3a69c81e49c12 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 26 Aug 2020 12:36:21 -0300 Subject: [PATCH 05/14] add gallery examples --- arviz/plots/backends/bokeh/separationplot.py | 13 ++++--------- .../plots/backends/matplotlib/separationplot.py | 5 +++-- arviz/plots/separationplot.py | 8 +++++--- doc/api.rst | 1 + examples/bokeh/bokeh_plot_separation.py | 11 +++++++++++ examples/matplotlib/mpl_plot_separation.py | 17 +++++++++++++++++ 6 files changed, 41 insertions(+), 14 deletions(-) create mode 100644 examples/bokeh/bokeh_plot_separation.py create mode 100644 examples/matplotlib/mpl_plot_separation.py diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py index 435db2eb3b..c818e6c3cd 100644 --- a/arviz/plots/backends/bokeh/separationplot.py +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -33,7 +33,7 @@ def plot_separation( if plot_kwargs is None: plot_kwargs = {} - plot_kwargs.setdefault("color", "royalblue") + plot_kwargs.setdefault("color", "#2a2eec") if color: plot_kwargs["color"] = color @@ -46,11 +46,13 @@ def plot_separation( y_hat_line_kwargs = {} y_hat_line_kwargs.setdefault("color", "black") + y_hat_line_kwargs.setdefault("line_width", 2) if exp_events_kwargs is None: exp_events_kwargs = {} exp_events_kwargs.setdefault("color", "black") + exp_events_kwargs.setdefault("size", 15) figsize, *_ = _scale_fig_size(figsize, textsize) @@ -102,16 +104,9 @@ def plot_separation( for i, loc in enumerate(locs): positive = not y[idx][i] == 0 - label = "Positive class" if positive else "Negative class" alpha = 1 if positive else 0.3 ax.vbar( - loc, - top=1, - width=width, - fill_alpha=alpha, - legend_label=label, - line_alpha=alpha, - **plot_kwargs, + loc, top=1, width=width, fill_alpha=alpha, line_alpha=alpha, **plot_kwargs, ) if y_hat_line: diff --git a/arviz/plots/backends/matplotlib/separationplot.py b/arviz/plots/backends/matplotlib/separationplot.py index c80188c4d7..205a2288ce 100644 --- a/arviz/plots/backends/matplotlib/separationplot.py +++ b/arviz/plots/backends/matplotlib/separationplot.py @@ -47,6 +47,8 @@ def plot_separation( exp_events_kwargs.setdefault("color", "k") exp_events_kwargs.setdefault("marker", "^") + exp_events_kwargs.setdefault("s", 100) + exp_events_kwargs.setdefault("zorder", 2) backend_kwargs = { **backend_kwarg_defaults(), @@ -102,9 +104,8 @@ def plot_separation( for i, loc in enumerate(locs): positive = not y[idx][i] == 0 - label = "Positive class" if positive else "Negative class" alpha = 1 if positive else 0.3 - ax.bar(loc, 1, width=width, label=label, alpha=alpha, **plot_kwargs) + ax.bar(loc, 1, width=width, alpha=alpha, **plot_kwargs) if y_hat_line: ax.plot(np.linspace(0, 1, len(y_hat)), y_hat[idx], label=label_y_hat, **y_hat_line_kwargs) diff --git a/arviz/plots/separationplot.py b/arviz/plots/separationplot.py index babe9b9745..13528fa61c 100644 --- a/arviz/plots/separationplot.py +++ b/arviz/plots/separationplot.py @@ -1,4 +1,4 @@ -"""Plot separation plot for discrete outcome models.""" +"""Separation plot for discrete outcome models.""" from ..rcparams import rcParams from .plot_utils import get_plotting_function @@ -21,7 +21,10 @@ def plot_separation( backend_kwargs=None, show=None, ): - """Plot separation plot for discrete outcome models. + """Separation plot for binary outcome models. + + Model predictions are sorted and plotted using a color code according to + the observed data. Parameters ---------- @@ -81,7 +84,6 @@ def plot_separation( >>> az.plot_separation(idata=idata, y='outcome', y_hat='outcome', figsize=(8, 1)) """ - separation_kwargs = dict( idata=idata, y=y, diff --git a/doc/api.rst b/doc/api.rst index 58c02bd523..ee98866068 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -35,6 +35,7 @@ Plots plot_ppc plot_rank plot_trace + plot_separation plot_violin .. _stats_api: diff --git a/examples/bokeh/bokeh_plot_separation.py b/examples/bokeh/bokeh_plot_separation.py new file mode 100644 index 0000000000..82abe87536 --- /dev/null +++ b/examples/bokeh/bokeh_plot_separation.py @@ -0,0 +1,11 @@ +""" +Separationplot +========== + +_thumb: .2, .8 +""" +import arviz as az + +idata = az.load_arviz_data("classification10d") + +ax = az.plot_separation(idata=idata, y="outcome", y_hat="outcome", figsize=(8, 1), backend="bokeh") diff --git a/examples/matplotlib/mpl_plot_separation.py b/examples/matplotlib/mpl_plot_separation.py new file mode 100644 index 0000000000..8d0e8370a8 --- /dev/null +++ b/examples/matplotlib/mpl_plot_separation.py @@ -0,0 +1,17 @@ +""" +Separationplot +========== + +_thumb: .2, .8 +""" +import matplotlib.pyplot as plt + +import arviz as az + +az.style.use("arviz-darkgrid") + +idata = az.load_arviz_data("classification10d") + +az.plot_separation(idata=idata, y="outcome", y_hat="outcome", figsize=(8, 1)) + +plt.show() From c4cc963a40ac55ad446b4aecdd2e3c32388bf3da Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 26 Aug 2020 13:11:56 -0300 Subject: [PATCH 06/14] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ec23cde5f6..733a8c1089 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ * Added `to_dict` method for InferenceData object ([1223](https://github.com/arviz-devs/arviz/pull/1223)) * Added `circ_var_names` argument to `plot_trace` allowing for circular traceplot (Matplotlib) ([1336](https://github.com/arviz-devs/arviz/pull/1336)) * Ridgeplot is hdi aware. By default displays truncated densities at the specified `hdi_prop` level ([1348](https://github.com/arviz-devs/arviz/pull/1348)) +* Added `plot_separation` ([1369](https://github.com/arviz-devs/arviz/pull/1359)) ### Maintenance and fixes From 64b1158e27d7441638799926ebdcf1faa7a928f8 Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 26 Aug 2020 20:08:24 -0300 Subject: [PATCH 07/14] run black --- CHANGELOG.md | 2 +- arviz/plots/backends/bokeh/separationplot.py | 15 ++++++++++++--- arviz/plots/backends/matplotlib/separationplot.py | 3 ++- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 733a8c1089..8a18c0dbfd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,7 +7,7 @@ * Added `to_dict` method for InferenceData object ([1223](https://github.com/arviz-devs/arviz/pull/1223)) * Added `circ_var_names` argument to `plot_trace` allowing for circular traceplot (Matplotlib) ([1336](https://github.com/arviz-devs/arviz/pull/1336)) * Ridgeplot is hdi aware. By default displays truncated densities at the specified `hdi_prop` level ([1348](https://github.com/arviz-devs/arviz/pull/1348)) -* Added `plot_separation` ([1369](https://github.com/arviz-devs/arviz/pull/1359)) +* Added `plot_separation` ([1359](https://github.com/arviz-devs/arviz/pull/1359)) ### Maintenance and fixes diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py index c818e6c3cd..44c186c5f8 100644 --- a/arviz/plots/backends/bokeh/separationplot.py +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -90,7 +90,8 @@ def plot_separation( if len(y) != len(y_hat): warnings.warn( - "y and y_hat must be the same lenght", UserWarning, + "y and y_hat must be the same lenght", + UserWarning, ) locs = np.linspace(0, 1, len(y_hat)) @@ -106,7 +107,12 @@ def plot_separation( positive = not y[idx][i] == 0 alpha = 1 if positive else 0.3 ax.vbar( - loc, top=1, width=width, fill_alpha=alpha, line_alpha=alpha, **plot_kwargs, + loc, + top=1, + width=width, + fill_alpha=alpha, + line_alpha=alpha, + **plot_kwargs, ) if y_hat_line: @@ -120,7 +126,10 @@ def plot_separation( if expected_events: expected_events = int(np.round(np.sum(y_hat))) ax.triangle( - y_hat[idx][expected_events - 1], 0, legend_label="Expected events", **exp_events_kwargs, + y_hat[idx][expected_events - 1], + 0, + legend_label="Expected events", + **exp_events_kwargs, ) ax.axis.visible = False diff --git a/arviz/plots/backends/matplotlib/separationplot.py b/arviz/plots/backends/matplotlib/separationplot.py index 205a2288ce..055ad89221 100644 --- a/arviz/plots/backends/matplotlib/separationplot.py +++ b/arviz/plots/backends/matplotlib/separationplot.py @@ -96,7 +96,8 @@ def plot_separation( if len(y) != len(y_hat): warnings.warn( - "y and y_hat must be the same lenght", UserWarning, + "y and y_hat must be the same lenght", + UserWarning, ) locs = np.linspace(0, 1, len(y_hat)) From 0d6ac2d827f9ba53d367d52dcea77566e8462afb Mon Sep 17 00:00:00 2001 From: unknown Date: Wed, 26 Aug 2020 20:15:46 -0300 Subject: [PATCH 08/14] fix legend --- arviz/plots/backends/matplotlib/separationplot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/plots/backends/matplotlib/separationplot.py b/arviz/plots/backends/matplotlib/separationplot.py index 055ad89221..435eca95dc 100644 --- a/arviz/plots/backends/matplotlib/separationplot.py +++ b/arviz/plots/backends/matplotlib/separationplot.py @@ -115,7 +115,7 @@ def plot_separation( expected_events = int(np.round(np.sum(y_hat))) ax.scatter(y_hat[idx][expected_events - 1], 0, label="Expected events", **exp_events_kwargs) - if legend: + if legend and expected_events or y_hat_line: handles, labels = plt.gca().get_legend_handles_labels() labels_dict = dict(zip(labels, handles)) ax.legend(labels_dict.values(), labels_dict.keys()) From 0d3194f27fcc8a5b6fea17832f60114eab3620c6 Mon Sep 17 00:00:00 2001 From: Agustina Arroyuelo Date: Thu, 27 Aug 2020 09:43:16 -0300 Subject: [PATCH 09/14] use labeled dimensions Co-authored-by: Oriol Abril-Pla --- arviz/plots/backends/bokeh/separationplot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py index 44c186c5f8..dda7e3c18e 100644 --- a/arviz/plots/backends/bokeh/separationplot.py +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -80,7 +80,7 @@ def plot_separation( if isinstance(y_hat, str): label_y_hat = y_hat - y_hat = idata.posterior_predictive[y_hat].mean(axis=(1, 0)).values + y_hat = idata.posterior_predictive[y_hat].mean(dim=("chain", "draw")).values elif not isinstance(y_hat, (np.ndarray, xr.DataArray)): raise ValueError( "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat)) From 6e652af489eacab14df69406cb73864a3f83f795 Mon Sep 17 00:00:00 2001 From: Agustina Arroyuelo Date: Thu, 27 Aug 2020 09:43:54 -0300 Subject: [PATCH 10/14] update docstring Co-authored-by: Oriol Abril-Pla --- arviz/plots/separationplot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arviz/plots/separationplot.py b/arviz/plots/separationplot.py index 13528fa61c..239787b634 100644 --- a/arviz/plots/separationplot.py +++ b/arviz/plots/separationplot.py @@ -51,7 +51,7 @@ def plot_separation( ax: axes, optional Matplotlib axes or bokeh figures. plot_kwargs : dict, optional - Additional keywords passed to ax.bar for separation plot. + Additional keywords passed to :meth:`mpl:matplotlib.axes.Axes.bar` or :meth:`bokeh:bokeh.plotting.Figure.vbar` for separation plot. y_hat_line_kwargs : dict, optional Additional keywords passed to ax.plot for `y_hat` line. exp_events_kwargs : dict, optional From ca5423416f7a06d917278abddf7df7ae1eb3d3d2 Mon Sep 17 00:00:00 2001 From: Agustina Arroyuelo Date: Thu, 27 Aug 2020 09:44:12 -0300 Subject: [PATCH 11/14] update doc/api.rst Co-authored-by: Oriol Abril-Pla --- doc/api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/api.rst b/doc/api.rst index ee98866068..d4d093bed2 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -34,8 +34,8 @@ Plots plot_posterior plot_ppc plot_rank - plot_trace plot_separation + plot_trace plot_violin .. _stats_api: From e5a3c330ab39a0c58f63f9a63b3519c1d118f6d3 Mon Sep 17 00:00:00 2001 From: unknown Date: Thu, 27 Aug 2020 11:10:43 -0300 Subject: [PATCH 12/14] change expected events plot --- arviz/plots/backends/bokeh/separationplot.py | 44 +++-------------- .../backends/matplotlib/separationplot.py | 49 +++++-------------- arviz/plots/separationplot.py | 41 ++++++++++++++-- 3 files changed, 55 insertions(+), 79 deletions(-) diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py index dda7e3c18e..988cc387df 100644 --- a/arviz/plots/backends/bokeh/separationplot.py +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -1,19 +1,17 @@ """Bokeh separation plot.""" import warnings import numpy as np -import xarray as xr -from ....data import InferenceData -from ...plot_utils import _scale_fig_size +from ...plot_utils import _scale_fig_size, vectorized_to_hex from .. import show_layout from . import backend_kwarg_defaults, create_axes_grid def plot_separation( - idata, y, y_hat, y_hat_line, + label_y_hat, expected_events, figsize, textsize, @@ -33,9 +31,9 @@ def plot_separation( if plot_kwargs is None: plot_kwargs = {} - plot_kwargs.setdefault("color", "#2a2eec") - if color: - plot_kwargs["color"] = color + # plot_kwargs.setdefault("color", "#2a2eec") + # if color: + plot_kwargs["color"] = vectorized_to_hex(color) backend_kwargs = { **backend_kwarg_defaults(), @@ -56,36 +54,6 @@ def plot_separation( figsize, *_ = _scale_fig_size(figsize, textsize) - if idata is not None and not isinstance(idata, InferenceData): - raise ValueError("idata must be of type InferenceData or None") - - if idata is None: - if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)): - raise ValueError( - "y and y_hat must be array or DataArray when idata is None " - "but they are of types {}".format([type(arg) for arg in (y, y_hat)]) - ) - else: - - if y_hat is None and isinstance(y, str): - label_y_hat = y - y_hat = y - elif y_hat is None: - raise ValueError("y_hat cannot be None if y is not a str") - - if isinstance(y, str): - y = idata.observed_data[y].values - elif not isinstance(y, (np.ndarray, xr.DataArray)): - raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y))) - - if isinstance(y_hat, str): - label_y_hat = y_hat - y_hat = idata.posterior_predictive[y_hat].mean(dim=("chain", "draw")).values - elif not isinstance(y_hat, (np.ndarray, xr.DataArray)): - raise ValueError( - "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat)) - ) - idx = np.argsort(y_hat) if len(y) != len(y_hat): @@ -126,7 +94,7 @@ def plot_separation( if expected_events: expected_events = int(np.round(np.sum(y_hat))) ax.triangle( - y_hat[idx][expected_events - 1], + y_hat[idx][len(y_hat) - expected_events - 1], 0, legend_label="Expected events", **exp_events_kwargs, diff --git a/arviz/plots/backends/matplotlib/separationplot.py b/arviz/plots/backends/matplotlib/separationplot.py index 435eca95dc..5ca9ebe985 100644 --- a/arviz/plots/backends/matplotlib/separationplot.py +++ b/arviz/plots/backends/matplotlib/separationplot.py @@ -2,18 +2,16 @@ import warnings import matplotlib.pyplot as plt import numpy as np -import xarray as xr -from ....data import InferenceData from ...plot_utils import _scale_fig_size from . import backend_kwarg_defaults, backend_show, create_axes_grid def plot_separation( - idata, y, y_hat, y_hat_line, + label_y_hat, expected_events, figsize, textsize, @@ -33,9 +31,9 @@ def plot_separation( if plot_kwargs is None: plot_kwargs = {} - plot_kwargs.setdefault("color", "C0") - if color: - plot_kwargs["color"] = color + # plot_kwargs.setdefault("color", "C0") + # if color: + plot_kwargs["color"] = color if y_hat_line_kwargs is None: y_hat_line_kwargs = {} @@ -62,36 +60,6 @@ def plot_separation( if ax is None: _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs) - if idata is not None and not isinstance(idata, InferenceData): - raise ValueError("idata must be of type InferenceData or None") - - if idata is None: - if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)): - raise ValueError( - "y and y_hat must be array or DataArray when idata is None " - "but they are of types {}".format([type(arg) for arg in (y, y_hat)]) - ) - else: - - if y_hat is None and isinstance(y, str): - label_y_hat = y - y_hat = y - elif y_hat is None: - raise ValueError("y_hat cannot be None if y is not a str") - - if isinstance(y, str): - y = idata.observed_data[y].values - elif not isinstance(y, (np.ndarray, xr.DataArray)): - raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y))) - - if isinstance(y_hat, str): - label_y_hat = y_hat - y_hat = idata.posterior_predictive[y_hat].mean(axis=(1, 0)).values - elif not isinstance(y_hat, (np.ndarray, xr.DataArray)): - raise ValueError( - "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat)) - ) - idx = np.argsort(y_hat) if len(y) != len(y_hat): @@ -113,10 +81,15 @@ def plot_separation( if expected_events: expected_events = int(np.round(np.sum(y_hat))) - ax.scatter(y_hat[idx][expected_events - 1], 0, label="Expected events", **exp_events_kwargs) + ax.scatter( + y_hat[idx][len(y_hat) - expected_events - 1], + 0, + label="Expected events", + **exp_events_kwargs + ) if legend and expected_events or y_hat_line: - handles, labels = plt.gca().get_legend_handles_labels() + handles, labels = ax.get_legend_handles_labels() labels_dict = dict(zip(labels, handles)) ax.legend(labels_dict.values(), labels_dict.keys()) diff --git a/arviz/plots/separationplot.py b/arviz/plots/separationplot.py index 239787b634..e75902e4c8 100644 --- a/arviz/plots/separationplot.py +++ b/arviz/plots/separationplot.py @@ -1,4 +1,8 @@ """Separation plot for discrete outcome models.""" +import numpy as np +import xarray as xr + +from ..data import InferenceData from ..rcparams import rcParams from .plot_utils import get_plotting_function @@ -11,7 +15,7 @@ def plot_separation( expected_events=False, figsize=None, textsize=None, - color=None, + color="C0", legend=True, ax=None, plot_kwargs=None, @@ -51,7 +55,8 @@ def plot_separation( ax: axes, optional Matplotlib axes or bokeh figures. plot_kwargs : dict, optional - Additional keywords passed to :meth:`mpl:matplotlib.axes.Axes.bar` or :meth:`bokeh:bokeh.plotting.Figure.vbar` for separation plot. + Additional keywords passed to :meth:`mpl:matplotlib.axes.Axes.bar` or + :meth:`bokeh:bokeh.plotting.Figure.vbar` for separation plot. y_hat_line_kwargs : dict, optional Additional keywords passed to ax.plot for `y_hat` line. exp_events_kwargs : dict, optional @@ -84,11 +89,41 @@ def plot_separation( >>> az.plot_separation(idata=idata, y='outcome', y_hat='outcome', figsize=(8, 1)) """ + if idata is not None and not isinstance(idata, InferenceData): + raise ValueError("idata must be of type InferenceData or None") + + if idata is None: + if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat)): + raise ValueError( + "y and y_hat must be array or DataArray when idata is None " + "but they are of types {}".format([type(arg) for arg in (y, y_hat)]) + ) + else: + + if y_hat is None and isinstance(y, str): + label_y_hat = y + y_hat = y + elif y_hat is None: + raise ValueError("y_hat cannot be None if y is not a str") + + if isinstance(y, str): + y = idata.observed_data[y].values + elif not isinstance(y, (np.ndarray, xr.DataArray)): + raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y))) + + if isinstance(y_hat, str): + label_y_hat = y_hat + y_hat = idata.posterior_predictive[y_hat].mean(dim=("chain", "draw")).values + elif not isinstance(y_hat, (np.ndarray, xr.DataArray)): + raise ValueError( + "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat)) + ) + separation_kwargs = dict( - idata=idata, y=y, y_hat=y_hat, y_hat_line=y_hat_line, + label_y_hat=label_y_hat, expected_events=expected_events, figsize=figsize, textsize=textsize, From 5b4f80ff3ff79cf7504fa02cdd5f5dd6768b0b7d Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 28 Aug 2020 09:14:19 -0300 Subject: [PATCH 13/14] fix label --- arviz/plots/separationplot.py | 1 + 1 file changed, 1 insertion(+) diff --git a/arviz/plots/separationplot.py b/arviz/plots/separationplot.py index e75902e4c8..0430def4f4 100644 --- a/arviz/plots/separationplot.py +++ b/arviz/plots/separationplot.py @@ -89,6 +89,7 @@ def plot_separation( >>> az.plot_separation(idata=idata, y='outcome', y_hat='outcome', figsize=(8, 1)) """ + label_y_hat = "y_hat" if idata is not None and not isinstance(idata, InferenceData): raise ValueError("idata must be of type InferenceData or None") From cb91fa1c336597a2598f58e8f9bc8c9c617f8bda Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 1 Sep 2020 09:17:16 -0300 Subject: [PATCH 14/14] move locs and width to general function --- arviz/plots/backends/bokeh/separationplot.py | 12 ++---------- arviz/plots/backends/matplotlib/separationplot.py | 12 ++---------- arviz/plots/separationplot.py | 12 ++++++++++++ 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/arviz/plots/backends/bokeh/separationplot.py b/arviz/plots/backends/bokeh/separationplot.py index 988cc387df..78c153b3ec 100644 --- a/arviz/plots/backends/bokeh/separationplot.py +++ b/arviz/plots/backends/bokeh/separationplot.py @@ -1,5 +1,4 @@ """Bokeh separation plot.""" -import warnings import numpy as np from ...plot_utils import _scale_fig_size, vectorized_to_hex @@ -17,6 +16,8 @@ def plot_separation( textsize, color, legend, # pylint: disable=unused-argument + locs, + width, ax, plot_kwargs, y_hat_line_kwargs, @@ -56,15 +57,6 @@ def plot_separation( idx = np.argsort(y_hat) - if len(y) != len(y_hat): - warnings.warn( - "y and y_hat must be the same lenght", - UserWarning, - ) - - locs = np.linspace(0, 1, len(y_hat)) - width = np.diff(locs).mean() - backend_kwargs["x_range"] = (0, 1) backend_kwargs["y_range"] = (0, 1) diff --git a/arviz/plots/backends/matplotlib/separationplot.py b/arviz/plots/backends/matplotlib/separationplot.py index 5ca9ebe985..3b8bc5bb57 100644 --- a/arviz/plots/backends/matplotlib/separationplot.py +++ b/arviz/plots/backends/matplotlib/separationplot.py @@ -1,5 +1,4 @@ """Matplotlib separation plot.""" -import warnings import matplotlib.pyplot as plt import numpy as np @@ -17,6 +16,8 @@ def plot_separation( textsize, color, legend, + locs, + width, ax, plot_kwargs, y_hat_line_kwargs, @@ -62,15 +63,6 @@ def plot_separation( idx = np.argsort(y_hat) - if len(y) != len(y_hat): - warnings.warn( - "y and y_hat must be the same lenght", - UserWarning, - ) - - locs = np.linspace(0, 1, len(y_hat)) - width = np.diff(locs).mean() - for i, loc in enumerate(locs): positive = not y[idx][i] == 0 alpha = 1 if positive else 0.3 diff --git a/arviz/plots/separationplot.py b/arviz/plots/separationplot.py index 0430def4f4..c91871e18d 100644 --- a/arviz/plots/separationplot.py +++ b/arviz/plots/separationplot.py @@ -1,4 +1,5 @@ """Separation plot for discrete outcome models.""" +import warnings import numpy as np import xarray as xr @@ -120,6 +121,15 @@ def plot_separation( "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat)) ) + if len(y) != len(y_hat): + warnings.warn( + "y and y_hat must be the same lenght", + UserWarning, + ) + + locs = np.linspace(0, 1, len(y_hat)) + width = np.diff(locs).mean() + separation_kwargs = dict( y=y, y_hat=y_hat, @@ -130,6 +140,8 @@ def plot_separation( textsize=textsize, color=color, legend=legend, + locs=locs, + width=width, ax=ax, plot_kwargs=plot_kwargs, y_hat_line_kwargs=y_hat_line_kwargs,