Skip to content

Commit

Permalink
extend functionality and improve docs
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Jun 17, 2020
1 parent 2bddf74 commit f335325
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 31 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* Added support for PyJAGS via the function `from_pyjags` in the module arviz.data.io_pyjags. (#1219)
* `from_pymc3` can now retrieve `coords` and `dims` from model context (#1228
and #1240)
* `plot_hdi` can now take already computed HDI values (#1241)

### Maintenance and fixes
* Include data from `MultiObservedRV` to `observed_data` when using
Expand Down
78 changes: 55 additions & 23 deletions arviz/plots/hdiplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,43 +37,74 @@ def plot_hdi(
x : array-like
Values to plot.
y : array-like, optional
Values from which to compute the HDI. Assumed shape (chain, draw, \*shape).
Values from which to compute the HDI. Assumed shape ``(chain, draw, \*shape)``.
Only optional if hdi_data is present.
hdi_data : array_like, optional
HDI values to use.
Precomputed HDI values to use. Assumed shape is ``(*x.shape, 2)``.
hdi_prob : float, optional
Probability for the highest density interval. Defaults to 0.94.
color : str
Probability for the highest density interval. Defaults to ``stats.hdi_prob`` rcParam.
color : str, optional
Color used for the limits of the HDI and fill. Should be a valid matplotlib color.
circular : bool, optional
Whether to compute the HDI taking into account `x` is a circular variable
(in the range [-np.pi, np.pi]) or not. Defaults to False (i.e. non-circular variables).
smooth : boolean
smooth : boolean, optional
If True the result will be smoothed by first computing a linear interpolation of the data
over a regular grid and then applying the Savitzky-Golay filter to the interpolated data.
Defaults to True.
smooth_kwargs : dict, optional
Additional keywords modifying the Savitzky-Golay filter. See Scipy's documentation for
details.
fill_kwargs : dict
Keywords passed to `fill_between` (use fill_kwargs={'alpha': 0} to disable fill).
plot_kwargs : dict
Keywords passed to HDI limits.
ax: axes, optional
Additional keywords modifying the Savitzky-Golay filter. See
:func:`scipy:scipy.signal.savgol_filter` for details.
fill_kwargs : dict, optional
Keywords passed to :meth:`mpl:matplotlib.axes.Axes.fill_between`
(use fill_kwargs={'alpha': 0} to disable fill) or to
:meth:`bokeh:bokeh.plotting.figure.Figure.patch`.
plot_kwargs : dict, optional
HDI limits keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.plot` or
:meth:`~bokeh:bokeh.plotting.figure.Figure.patch`.
hdi_kwargs : dict, optional
Keyword arguments passed to :func:`~arviz.hdi`. Ignored if ``hdi_data`` is present.
ax : axes, optional
Matplotlib axes or bokeh figures.
backend: str, optional
Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
backend_kwargs: bool, optional
These are kwargs specific to the backend being used. For additional documentation
check the plotting method of the backend.
backend : {"matplotlib","bokeh"}, optional
Select plotting backend.
backend_kwargs : dict, optional
These are kwargs specific to the backend being used. For additional documentation
check the plotting method of the backend.
show : bool, optional
Call backend show function.
credible_interval: float, optional
credible_interval : float, optional
Deprecated: Please see hdi_prob
Returns
-------
axes : matplotlib axes or bokeh figures
See Also
--------
hdi : Calculate highest density interval (HDI) of array for given probability.
Examples
--------
Plot HDI interval of simulated regression data using `y` argument:
.. plot::
:context: close-figs
>>> import numpy as np
>>> import arviz as az
>>> x_data = np.random.normal(0, 1, 100)
>>> y_data = np.random.normal(2 + x_data * 0.5, 0.5, (2, 50, 100))
>>> az.plot_hdi(x_data, y_data)
Precalculate HDI interval per chain and plot separately:
.. plot::
:context: close-figs
>>> hdi_data = az.hdi(y_data, input_core_dims=[["draw"]])
>>> ax = az.plot_hdi(x_data, hdi_data=hdi_data[0], color="r", fill_kwargs={"alpha": .2})
>>> az.plot_hdi(x_data, hdi_data=hdi_data[1], color="k", ax=ax, fill_kwargs={"alpha": .2})
"""
if credible_interval:
hdi_prob = credible_interval_warning(credible_interval, hdi_prob)
Expand All @@ -91,21 +122,22 @@ def plot_hdi(
x = np.asarray(x)
x_shape = x.shape


if y is None and hdi_data is None:
raise ValueError("One of {y, hdi_data} is required")
elif hdi_data is not None and y is not None:
if hdi_data is not None and y is not None:
warnings.warn("Both y and hdi_data arguments present, ignoring y")
elif y is not None:
elif hdi_data is not None:
hdi_prob = (
hdi_data.hdi.attrs.get("hdi_prob", np.nan) if hasattr(hdi_data, "hdi") else np.nan
)
else:
y = np.asarray(y)
if hdi_prob is None:
hdi_prob = rcParams["stats.hdi_prob"]
else:
if not 1 >= hdi_prob > 0:
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs)
else:
hdi_prob = hdi_data.hdi.attrs.get("hdi_prob", np.nan)

hdi_shape = hdi_data.shape
if hdi_shape[:-1] != x_shape:
Expand Down
27 changes: 19 additions & 8 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def hdi(
**kwargs,
):
"""
Calculate highest density interval (HDI) of array for given percentage.
Calculate highest density interval (HDI) of array for given probability.
The HDI is the minimum width Bayesian credible interval (BCI).
Expand All @@ -376,15 +376,15 @@ def hdi(
Any object that can be converted to an az.InferenceData object.
Refer to documentation of az.convert_to_dataset for details.
hdi_prob: float, optional
HDI prob for which interval will be computed. Defaults to 0.94.
HDI prob for which interval will be computed. Defaults to ``stats.hdi_prob`` rcParam.
circular: bool, optional
Whether to compute the hdi taking into account `x` is a circular variable
(in the range [-np.pi, np.pi]) or not. Defaults to False (i.e. non-circular variables).
Only works if multimodal is False.
multimodal: bool
multimodal: bool, optional
If true it may compute more than one hdi interval if the distribution is multimodal and the
modes are well separated.
skipna: bool
skipna: bool, optional
If true ignores nan values when computing the hdi interval. Defaults to false.
group: str, optional
Specifies which InferenceData group should be used to calculate hdi.
Expand All @@ -403,14 +403,18 @@ def hdi(
max_modes: int, optional
Specifies the maximum number of modes for multimodal case.
kwargs: dict, optional
Additional keywords passed to `wrap_xarray_ufunc`.
See the docstring of :obj:`wrap_xarray_ufunc method </.stats_utils.wrap_xarray_ufunc>`.
Additional keywords passed to :func:`~arviz.wrap_xarray_ufunc`.
Returns
-------
np.ndarray or xarray.Dataset, depending upon input
lower(s) and upper(s) values of the interval(s).
See Also
--------
plot_hdi : Plot HDI intervals for regression data.
xarray.Dataset.quantile : Calculate quantiles of array for given probabilities.
Examples
--------
Calculate the HDI of a Normal random variable:
Expand Down Expand Up @@ -476,7 +480,12 @@ def hdi(
return hdi_data[~np.isnan(hdi_data).all(axis=1), :] if multimodal else hdi_data

if isarray and ary.ndim == 2:
kwargs.setdefault("input_core_dims", [["chain"]])
warnings.warn(
"hdi currently interprets 2d data as (draw, shape) but this will change in "
"a future release to (chain, draw) for coherence with other functions",
FutureWarning,
)
ary = np.expand_dims(ary, 0)

ary = convert_to_dataset(ary, group=group)
if coords is not None:
Expand All @@ -485,7 +494,9 @@ def hdi(
ary = ary[var_names] if var_names else ary

hdi_coord = xr.DataArray(["lower", "higher"], dims=["hdi"], attrs=dict(hdi_prob=hdi_prob))
hdi_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs).assign_coords({"hdi": hdi_coord})
hdi_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs).assign_coords(
{"hdi": hdi_coord}
)
hdi_data = hdi_data.dropna("mode", how="all") if multimodal else hdi_data
return hdi_data.x.values if isarray else hdi_data

Expand Down
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,4 +269,5 @@ def setup(app):
"pymc3": ("https://docs.pymc.io/", None),
"mpl": ("https://matplotlib.org/", None),
"bokeh": ("https://docs.bokeh.org/en/latest/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
}

0 comments on commit f335325

Please sign in to comment.