diff --git a/CHANGELOG.md b/CHANGELOG.md index c1c1168b79..3da6aa0b74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ * Added `html_repr` of InferenceData objects for jupyter notebooks. (#1217) * Added support for PyJAGS via the function `from_pyjags`. (#1219 and #1245) * `from_pymc3` can now retrieve `coords` and `dims` from model context (#1228, #1240 and #1249) +* `plot_hdi` can now take already computed HDI values (#1241) ### Maintenance and fixes * Include data from `MultiObservedRV` to `observed_data` when using @@ -16,12 +17,14 @@ * Added `log_likelihood` argument to `from_pyro` and a warning if log likelihood cannot be obtained (#1227) * Skip tests on matplotlib animations if ffmpeg is not installed (#1227) * Fix hpd bug where arguments were being ignored (#1236) +* Remove false positive warning in `plot_hdi` and fixed matplotlib axes generation (#1241) * Change the default `zorder` of scatter points from `0` to `0.6` in `plot_pair` (#1246) * Update `get_bins` for numpy 1.19 compatibility (#1256) ### Deprecation * Using `from_pymc3` without a model context available now raises a `FutureWarning` and will be deprecated in a future version (#1227) +* `hdi` with 2d data raises a FutureWarning (#1241) ### Documentation * A section has been added to the documentation at InferenceDataCookbook.ipynb illustrating the use of ArviZ in conjunction with PyJAGS. (#1219 and #1245) diff --git a/arviz/plots/backends/bokeh/hdiplot.py b/arviz/plots/backends/bokeh/hdiplot.py index cf511716eb..0ce80c22d3 100644 --- a/arviz/plots/backends/bokeh/hdiplot.py +++ b/arviz/plots/backends/bokeh/hdiplot.py @@ -1,8 +1,5 @@ """Bokeh hdiplot.""" -from itertools import cycle - import bokeh.plotting as bkp -from matplotlib.pyplot import rcParams as mpl_rcParams import numpy as np from . import backend_kwarg_defaults @@ -21,26 +18,10 @@ def plot_hdi(ax, x_data, y_data, plot_kwargs, fill_kwargs, backend_kwargs, show) if ax is None: ax = bkp.figure(**backend_kwargs) - color = plot_kwargs.pop("color") - if len(color) == 2 and color[0] == "C": - color = [ - prop - for _, prop in zip( - range(int(color[1:])), cycle(mpl_rcParams["axes.prop_cycle"].by_key()["color"]) - ) - ][-1] - plot_kwargs.setdefault("line_color", color) + plot_kwargs.setdefault("line_color", plot_kwargs.pop("color")) plot_kwargs.setdefault("line_alpha", plot_kwargs.pop("alpha", 0)) - color = fill_kwargs.pop("color") - if len(color) == 2 and color[0] == "C": - color = [ - prop - for _, prop in zip( - range(int(color[1:])), cycle(mpl_rcParams["axes.prop_cycle"].by_key()["color"]) - ) - ][-1] - fill_kwargs.setdefault("fill_color", color) + fill_kwargs.setdefault("fill_color", fill_kwargs.pop("color")) fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0)) ax.patch( diff --git a/arviz/plots/backends/matplotlib/hdiplot.py b/arviz/plots/backends/matplotlib/hdiplot.py index 585acb40f3..d93ebec7bc 100644 --- a/arviz/plots/backends/matplotlib/hdiplot.py +++ b/arviz/plots/backends/matplotlib/hdiplot.py @@ -1,21 +1,21 @@ """Matplotlib hdiplot.""" -import warnings import matplotlib.pyplot as plt -from . import backend_show +from . import backend_kwarg_defaults, backend_show def plot_hdi(ax, x_data, y_data, plot_kwargs, fill_kwargs, backend_kwargs, show): """Matplotlib HDI plot.""" - if backend_kwargs is not None: - warnings.warn( - ( - "Argument backend_kwargs has not effect in matplotlib.plot_hdi" - "Supplied value won't be used" - ) - ) + if backend_kwargs is None: + backend_kwargs = {} + + backend_kwargs = { + **backend_kwarg_defaults(), + **backend_kwargs, + } if ax is None: - ax = plt.gca() + _, ax = plt.subplots(1, 1, **backend_kwargs) + ax.plot(x_data, y_data, **plot_kwargs) ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], **fill_kwargs) diff --git a/arviz/plots/hdiplot.py b/arviz/plots/hdiplot.py index 9585401e9e..62c9e948f2 100644 --- a/arviz/plots/hdiplot.py +++ b/arviz/plots/hdiplot.py @@ -4,23 +4,26 @@ import numpy as np from scipy.interpolate import griddata from scipy.signal import savgol_filter +from xarray import Dataset from ..stats import hdi -from .plot_utils import get_plotting_function, matplotlib_kwarg_dealiaser +from .plot_utils import get_plotting_function, matplotlib_kwarg_dealiaser, vectorized_to_hex from ..rcparams import rcParams from ..utils import credible_interval_warning def plot_hdi( x, - y, + y=None, hdi_prob=None, + hdi_data=None, color="C1", circular=False, smooth=True, smooth_kwargs=None, fill_kwargs=None, plot_kwargs=None, + hdi_kwargs=None, ax=None, backend=None, backend_kwargs=None, @@ -34,75 +37,124 @@ def plot_hdi( ---------- x : array-like Values to plot. - y : array-like - Values from which to compute the HDI. Assumed shape (chain, draw, \*shape). + y : array-like, optional + Values from which to compute the HDI. Assumed shape ``(chain, draw, \*shape)``. + Only optional if hdi_data is present. + hdi_data : array_like, optional + Precomputed HDI values to use. Assumed shape is ``(*x.shape, 2)``. hdi_prob : float, optional - Probability for the highest density interval. Defaults to 0.94. - color : str + Probability for the highest density interval. Defaults to ``stats.hdi_prob`` rcParam. + color : str, optional Color used for the limits of the HDI and fill. Should be a valid matplotlib color. circular : bool, optional Whether to compute the HDI taking into account `x` is a circular variable (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables). - smooth : boolean + smooth : boolean, optional If True the result will be smoothed by first computing a linear interpolation of the data over a regular grid and then applying the Savitzky-Golay filter to the interpolated data. Defaults to True. smooth_kwargs : dict, optional - Additional keywords modifying the Savitzky-Golay filter. See Scipy's documentation for - details. - fill_kwargs : dict - Keywords passed to `fill_between` (use fill_kwargs={'alpha': 0} to disable fill). - plot_kwargs : dict - Keywords passed to HDI limits. - ax: axes, optional + Additional keywords modifying the Savitzky-Golay filter. See + :func:`scipy:scipy.signal.savgol_filter` for details. + fill_kwargs : dict, optional + Keywords passed to :meth:`mpl:matplotlib.axes.Axes.fill_between` + (use fill_kwargs={'alpha': 0} to disable fill) or to + :meth:`bokeh:bokeh.plotting.figure.Figure.patch`. + plot_kwargs : dict, optional + HDI limits keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.plot` or + :meth:`bokeh:bokeh.plotting.figure.Figure.patch`. + hdi_kwargs : dict, optional + Keyword arguments passed to :func:`~arviz.hdi`. Ignored if ``hdi_data`` is present. + ax : axes, optional Matplotlib axes or bokeh figures. - backend: str, optional - Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib". - backend_kwargs: bool, optional - These are kwargs specific to the backend being used. For additional documentation - check the plotting method of the backend. + backend : {"matplotlib","bokeh"}, optional + Select plotting backend. + backend_kwargs : bool, optional + These are kwargs specific to the backend being used. Passed to ::`` show : bool, optional Call backend show function. - credible_interval: float, optional + credible_interval : float, optional Deprecated: Please see hdi_prob Returns ------- axes : matplotlib axes or bokeh figures + + See Also + -------- + hdi : Calculate highest density interval (HDI) of array for given probability. + + Examples + -------- + Plot HDI interval of simulated regression data using `y` argument: + + .. plot:: + :context: close-figs + + >>> import numpy as np + >>> import arviz as az + >>> x_data = np.random.normal(0, 1, 100) + >>> y_data = np.random.normal(2 + x_data * 0.5, 0.5, (2, 50, 100)) + >>> az.plot_hdi(x_data, y_data) + + Precalculate HDI interval per chain and plot separately: + + .. plot:: + :context: close-figs + + >>> hdi_data = az.hdi(y_data, input_core_dims=[["draw"]]) + >>> ax = az.plot_hdi(x_data, hdi_data=hdi_data[0], color="r", fill_kwargs={"alpha": .2}) + >>> az.plot_hdi(x_data, hdi_data=hdi_data[1], color="k", ax=ax, fill_kwargs={"alpha": .2}) + """ if credible_interval: hdi_prob = credible_interval_warning(credible_interval, hdi_prob) + if hdi_kwargs is None: + hdi_kwargs = {} plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot") - plot_kwargs.setdefault("color", color) + plot_kwargs["color"] = vectorized_to_hex(plot_kwargs.get("color", color)) plot_kwargs.setdefault("alpha", 0) - fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "hexbin") - fill_kwargs.setdefault("color", color) + fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "fill_between") + fill_kwargs["color"] = vectorized_to_hex(fill_kwargs.get("color", color)) fill_kwargs.setdefault("alpha", 0.5) x = np.asarray(x) - y = np.asarray(y) - x_shape = x.shape - y_shape = y.shape - if y_shape[-len(x_shape) :] != x_shape: - msg = "Dimension mismatch for x: {} and y: {}." - msg += " y-dimensions should be (chain, draw, *x.shape) or" - msg += " (draw, *x.shape)" - raise TypeError(msg.format(x_shape, y_shape)) - - if len(y_shape[: -len(x_shape)]) > 1: - new_shape = tuple([-1] + list(x_shape)) - y = y.reshape(new_shape) - - if hdi_prob is None: - hdi_prob = rcParams["stats.hdi_prob"] - else: - if not 1 >= hdi_prob > 0: - raise ValueError("The value of hdi_prob should be in the interval (0, 1]") - hdi_ = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False) + if y is None and hdi_data is None: + raise ValueError("One of {y, hdi_data} is required") + if hdi_data is not None and y is not None: + warnings.warn("Both y and hdi_data arguments present, ignoring y") + elif hdi_data is not None: + hdi_prob = ( + hdi_data.hdi.attrs.get("hdi_prob", np.nan) if hasattr(hdi_data, "hdi") else np.nan + ) + if isinstance(hdi_data, Dataset): + data_vars = list(hdi_data.data_vars) + if len(data_vars) != 1: + raise ValueError( + "Found several variables in hdi_data. Only single variable Datasets are " + "supported." + ) + hdi_data = hdi_data[data_vars[0]] + else: + y = np.asarray(y) + if hdi_prob is None: + hdi_prob = rcParams["stats.hdi_prob"] + else: + if not 1 >= hdi_prob > 0: + raise ValueError("The value of hdi_prob should be in the interval (0, 1]") + hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs) + + hdi_shape = hdi_data.shape + if hdi_shape[:-1] != x_shape: + msg = ( + "Dimension mismatch for x: {} and hdi: {}. Check the dimensions of y and" + "hdi_kwargs to make sure they are compatible" + ) + raise TypeError(msg.format(x_shape, hdi_shape)) if smooth: if smooth_kwargs is None: @@ -111,12 +163,12 @@ def plot_hdi( smooth_kwargs.setdefault("polyorder", 2) x_data = np.linspace(x.min(), x.max(), 200) x_data[0] = (x_data[0] + x_data[1]) / 2 - hdi_interp = griddata(x, hdi_, x_data) + hdi_interp = griddata(x, hdi_data, x_data) y_data = savgol_filter(hdi_interp, axis=0, **smooth_kwargs) else: idx = np.argsort(x) x_data = x[idx] - y_data = hdi_[idx] + y_data = hdi_data[idx] hdiplot_kwargs = dict( ax=ax, @@ -132,12 +184,11 @@ def plot_hdi( backend = rcParams["plot.backend"] backend = backend.lower() - # TODO: Add backend kwargs plot = get_plotting_function("plot_hdi", "hdiplot", backend) ax = plot(**hdiplot_kwargs) return ax def plot_hpd(*args, **kwargs): # noqa: D103 - warnings.warn("plot_hdi has been deprecated, please use plot_hdi", DeprecationWarning) + warnings.warn("plot_hpd has been deprecated, please use plot_hdi", DeprecationWarning) return plot_hdi(*args, **kwargs) diff --git a/arviz/plots/plot_utils.py b/arviz/plots/plot_utils.py index 14bc8057f3..a9c229e621 100644 --- a/arviz/plots/plot_utils.py +++ b/arviz/plots/plot_utils.py @@ -663,6 +663,7 @@ def matplotlib_kwarg_dealiaser(args, kind, backend="matplotlib"): "plot": mpl.lines.Line2D, "hist": mpl.patches.Patch, "hexbin": mpl.collections.PolyCollection, + "fill_between": mpl.collections.PolyCollection, "hlines": mpl.collections.LineCollection, "text": mpl.text.Text, "contour": mpl.contour.ContourSet, diff --git a/arviz/stats/stats.py b/arviz/stats/stats.py index e2503fe0fc..35b9922d5c 100644 --- a/arviz/stats/stats.py +++ b/arviz/stats/stats.py @@ -365,7 +365,7 @@ def hdi( **kwargs, ): """ - Calculate highest density interval (HDI) of array for given percentage. + Calculate highest density interval (HDI) of array for given probability. The HDI is the minimum width Bayesian credible interval (BCI). @@ -376,15 +376,15 @@ def hdi( Any object that can be converted to an az.InferenceData object. Refer to documentation of az.convert_to_dataset for details. hdi_prob: float, optional - HDI prob for which interval will be computed. Defaults to 0.94. + HDI prob for which interval will be computed. Defaults to ``stats.hdi_prob`` rcParam. circular: bool, optional Whether to compute the hdi taking into account `x` is a circular variable (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables). Only works if multimodal is False. - multimodal: bool + multimodal: bool, optional If true it may compute more than one hdi interval if the distribution is multimodal and the modes are well separated. - skipna: bool + skipna: bool, optional If true ignores nan values when computing the hdi interval. Defaults to false. group: str, optional Specifies which InferenceData group should be used to calculate hdi. @@ -403,14 +403,18 @@ def hdi( max_modes: int, optional Specifies the maximum number of modes for multimodal case. kwargs: dict, optional - Additional keywords passed to `wrap_xarray_ufunc`. - See the docstring of :obj:`wrap_xarray_ufunc method `. + Additional keywords passed to :func:`~arviz.wrap_xarray_ufunc`. Returns ------- np.ndarray or xarray.Dataset, depending upon input lower(s) and upper(s) values of the interval(s). + See Also + -------- + plot_hdi : Plot HDI intervals for regression data. + xarray.Dataset.quantile : Calculate quantiles of array for given probabilities. + Examples -------- Calculate the HDI of a Normal random variable: @@ -476,7 +480,12 @@ def hdi( return hdi_data[~np.isnan(hdi_data).all(axis=1), :] if multimodal else hdi_data if isarray and ary.ndim == 2: - kwargs.setdefault("input_core_dims", [["chain"]]) + warnings.warn( + "hdi currently interprets 2d data as (draw, shape) but this will change in " + "a future release to (chain, draw) for coherence with other functions", + FutureWarning, + ) + ary = np.expand_dims(ary, 0) ary = convert_to_dataset(ary, group=group) if coords is not None: @@ -484,7 +493,10 @@ def hdi( var_names = _var_names(var_names, ary, filter_vars) ary = ary[var_names] if var_names else ary - hdi_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs) + hdi_coord = xr.DataArray(["lower", "higher"], dims=["hdi"], attrs=dict(hdi_prob=hdi_prob)) + hdi_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs).assign_coords( + {"hdi": hdi_coord} + ) hdi_data = hdi_data.dropna("mode", how="all") if multimodal else hdi_data return hdi_data.x.values if isarray else hdi_data @@ -1161,13 +1173,9 @@ def summary( sd = posterior.std(dim=("chain", "draw"), ddof=1, skipna=skipna) - hdi_lower, hdi_higher = xr.apply_ufunc( - _make_ufunc(hdi, n_output=2), - posterior, - kwargs=dict(hdi_prob=hdi_prob, multimodal=False, skipna=skipna), - input_core_dims=(("chain", "draw"),), - output_core_dims=tuple([] for _ in range(2)), - ) + hdi_post = hdi(posterior, hdi_prob=hdi_prob, multimodal=False, skipna=skipna) + hdi_lower = hdi_post.sel(hdi="lower", drop=True) + hdi_higher = hdi_post.sel(hdi="higher", drop=True) if include_circ: nan_policy = "omit" if skipna else "propagate" @@ -1199,13 +1207,9 @@ def summary( input_core_dims=(("chain", "draw"),), ) - circ_hdi_lower, circ_hdi_higher = xr.apply_ufunc( - _make_ufunc(hdi, n_output=2), - posterior, - kwargs=dict(hdi_prob=hdi_prob, circular=True, skipna=skipna), - input_core_dims=(("chain", "draw"),), - output_core_dims=tuple([] for _ in range(2)), - ) + circ_hdi = hdi(posterior, hdi_prob=hdi_prob, circular=True, skipna=skipna) + circ_hdi_lower = circ_hdi.sel(hdi="lower", drop=True) + circ_hdi_higher = circ_hdi.sel(hdi="higher", drop=True) if kind in ["all", "diagnostics"]: mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat = xr.apply_ufunc( diff --git a/arviz/tests/base_tests/test_plots_bokeh.py b/arviz/tests/base_tests/test_plots_bokeh.py index 9e345322f6..11d1417757 100644 --- a/arviz/tests/base_tests/test_plots_bokeh.py +++ b/arviz/tests/base_tests/test_plots_bokeh.py @@ -40,7 +40,7 @@ plot_trace, plot_violin, ) -from ...stats import compare, loo, waic # pylint: disable=wrong-import-position +from ...stats import compare, loo, waic, hdi # pylint: disable=wrong-import-position # Skip tests if bokeh not installed bkp = importorskip("bokeh.plotting") # pylint: disable=invalid-name @@ -483,16 +483,20 @@ def test_plot_forest_bad(models, model_fits): "kwargs", [ {"color": "C5", "circular": True}, - {"fill_kwargs": {"alpha": 0}}, + {"hdi_data": True, "fill_kwargs": {"alpha": 0}}, {"plot_kwargs": {"alpha": 0}}, {"smooth_kwargs": {"window_length": 33, "polyorder": 5, "mode": "mirror"}}, - {"smooth": False}, + {"hdi_data": True, "smooth": False, "color": "xkcd:jade"}, ], ) def test_plot_hdi(models, data, kwargs): - axis = plot_hdi( - data["y"], models.model_1.posterior["theta"], backend="bokeh", show=False, **kwargs - ) + hdi_data = kwargs.pop("hdi_data", None) + y_data = models.model_1.posterior["theta"] + if hdi_data: + hdi_data = hdi(y_data) + axis = plot_hdi(data["y"], hdi_data=hdi_data, backend="bokeh", show=False, **kwargs) + else: + axis = plot_hdi(data["y"], y_data, backend="bokeh", show=False, **kwargs) assert axis diff --git a/arviz/tests/base_tests/test_plots_matplotlib.py b/arviz/tests/base_tests/test_plots_matplotlib.py index fc7f576c84..3e91daef50 100644 --- a/arviz/tests/base_tests/test_plots_matplotlib.py +++ b/arviz/tests/base_tests/test_plots_matplotlib.py @@ -10,7 +10,7 @@ import pytest from ...data import from_dict, load_arviz_data -from ...stats import compare, loo, waic +from ...stats import compare, loo, waic, hdi from ..helpers import ( # pylint: disable=unused-import eight_schools_params, models, @@ -847,14 +847,43 @@ def test_plot_compare_no_ic(models): "kwargs", [ {"color": "0.5", "circular": True}, - {"fill_kwargs": {"alpha": 0}}, + {"hdi_data": True, "fill_kwargs": {"alpha": 0}}, {"plot_kwargs": {"alpha": 0}}, {"smooth_kwargs": {"window_length": 33, "polyorder": 5, "mode": "mirror"}}, - {"smooth": False}, + {"hdi_data": True, "smooth": False}, ], ) def test_plot_hdi(models, data, kwargs): - plot_hdi(data["y"], models.model_1.posterior["theta"], **kwargs) + hdi_data = kwargs.pop("hdi_data", None) + if hdi_data: + hdi_data = hdi(models.model_1.posterior["theta"]) + ax = plot_hdi(data["y"], hdi_data=hdi_data, **kwargs) + else: + ax = plot_hdi(data["y"], models.model_1.posterior["theta"], **kwargs) + assert ax + + +def test_plot_hdi_warning(): + """Check using both y and hdi_data sends a warning.""" + x_data = np.random.normal(0, 1, 100) + y_data = np.random.normal(2 + x_data * 0.5, 0.5, (1, 200, 100)) + hdi_data = hdi(y_data) + with pytest.warns(UserWarning, match="Both y and hdi_data"): + ax = plot_hdi(x_data, y=y_data, hdi_data=hdi_data) + assert ax + + +def test_plot_hdi_missing_arg_error(): + """Check that both y and hdi_data missing raises an error.""" + with pytest.raises(ValueError, match="One of {y, hdi_data"): + plot_hdi(np.arange(20)) + + +def test_plot_hdi_dataset_error(models): + """Check hdi_data as multiple variable Dataset raises an error.""" + hdi_data = hdi(models.model_1) + with pytest.raises(ValueError, match="Only single variable Dataset"): + plot_hdi(np.arange(8), hdi_data=hdi_data) @pytest.mark.parametrize("limits", [(-10.0, 10.0), (-5, 5), (None, None)]) diff --git a/doc/conf.py b/doc/conf.py index 690c33e374..aa24f3915f 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -269,4 +269,5 @@ def setup(app): "pymc3": ("https://docs.pymc.io/", None), "mpl": ("https://matplotlib.org/", None), "bokeh": ("https://docs.bokeh.org/en/latest/", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), }