Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] fix hdi false positive warning #1241

Merged
merged 11 commits into from
Jun 22, 2020
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Added `html_repr` of InferenceData objects for jupyter notebooks. (#1217)
* Added support for PyJAGS via the function `from_pyjags`. (#1219 and #1245)
* `from_pymc3` can now retrieve `coords` and `dims` from model context (#1228, #1240 and #1249)
* `plot_hdi` can now take already computed HDI values (#1241)

### Maintenance and fixes
* Include data from `MultiObservedRV` to `observed_data` when using
Expand All @@ -16,12 +17,14 @@
* Added `log_likelihood` argument to `from_pyro` and a warning if log likelihood cannot be obtained (#1227)
* Skip tests on matplotlib animations if ffmpeg is not installed (#1227)
* Fix hpd bug where arguments were being ignored (#1236)
* Remove false positive warning in `plot_hdi` and fix matplotlib axes generation (#1241)
* Change the default `zorder` of scatter points from `0` to `0.6` in `plot_pair` (#1246)
* Update `get_bins` for numpy 1.19 compatibility (#1256)

### Deprecation
* Using `from_pymc3` without a model context available now raises a
`FutureWarning` and will be deprecated in a future version (#1227)
* `hdi` with 2d data raises a FutureWarning (#1241)

### Documentation
* A section has been added to the documentation at InferenceDataCookbook.ipynb illustrating the use of ArviZ in conjunction with PyJAGS. (#1219 and #1245)
Expand Down
23 changes: 2 additions & 21 deletions arviz/plots/backends/bokeh/hdiplot.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""Bokeh hdiplot."""
from itertools import cycle

import bokeh.plotting as bkp
from matplotlib.pyplot import rcParams as mpl_rcParams
import numpy as np

from . import backend_kwarg_defaults
Expand All @@ -21,26 +18,10 @@ def plot_hdi(ax, x_data, y_data, plot_kwargs, fill_kwargs, backend_kwargs, show)
if ax is None:
ax = bkp.figure(**backend_kwargs)

color = plot_kwargs.pop("color")
if len(color) == 2 and color[0] == "C":
color = [
prop
for _, prop in zip(
range(int(color[1:])), cycle(mpl_rcParams["axes.prop_cycle"].by_key()["color"])
)
][-1]
plot_kwargs.setdefault("line_color", color)
plot_kwargs.setdefault("line_color", plot_kwargs.pop("color"))
plot_kwargs.setdefault("line_alpha", plot_kwargs.pop("alpha", 0))

color = fill_kwargs.pop("color")
if len(color) == 2 and color[0] == "C":
color = [
prop
for _, prop in zip(
range(int(color[1:])), cycle(mpl_rcParams["axes.prop_cycle"].by_key()["color"])
)
][-1]
fill_kwargs.setdefault("fill_color", color)
fill_kwargs.setdefault("fill_color", fill_kwargs.pop("color"))
fill_kwargs.setdefault("fill_alpha", fill_kwargs.pop("alpha", 0))

ax.patch(
Expand Down
20 changes: 10 additions & 10 deletions arviz/plots/backends/matplotlib/hdiplot.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
"""Matplotlib hdiplot."""
import warnings
import matplotlib.pyplot as plt

from . import backend_show
from . import backend_kwarg_defaults, backend_show


def plot_hdi(ax, x_data, y_data, plot_kwargs, fill_kwargs, backend_kwargs, show):
"""Matplotlib HDI plot."""
if backend_kwargs is not None:
warnings.warn(
(
"Argument backend_kwargs has not effect in matplotlib.plot_hdi"
"Supplied value won't be used"
)
)
if backend_kwargs is None:
backend_kwargs = {}

backend_kwargs = {
**backend_kwarg_defaults(),
**backend_kwargs,
}
if ax is None:
ax = plt.gca()
_, ax = plt.subplots(1, 1, **backend_kwargs)

ax.plot(x_data, y_data, **plot_kwargs)
ax.fill_between(x_data, y_data[:, 0], y_data[:, 1], **fill_kwargs)

Expand Down
143 changes: 97 additions & 46 deletions arviz/plots/hdiplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,26 @@
import numpy as np
from scipy.interpolate import griddata
from scipy.signal import savgol_filter
from xarray import Dataset

from ..stats import hdi
from .plot_utils import get_plotting_function, matplotlib_kwarg_dealiaser
from .plot_utils import get_plotting_function, matplotlib_kwarg_dealiaser, vectorized_to_hex
from ..rcparams import rcParams
from ..utils import credible_interval_warning


def plot_hdi(
x,
y,
y=None,
hdi_prob=None,
hdi_data=None,
color="C1",
circular=False,
smooth=True,
smooth_kwargs=None,
fill_kwargs=None,
plot_kwargs=None,
hdi_kwargs=None,
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved
ax=None,
backend=None,
backend_kwargs=None,
Expand All @@ -34,75 +37,124 @@ def plot_hdi(
----------
x : array-like
Values to plot.
y : array-like
Values from which to compute the HDI. Assumed shape (chain, draw, \*shape).
y : array-like, optional
Values from which to compute the HDI. Assumed shape ``(chain, draw, \*shape)``.
Only optional if hdi_data is present.
hdi_data : array_like, optional
Precomputed HDI values to use. Assumed shape is ``(*x.shape, 2)``.
hdi_prob : float, optional
Probability for the highest density interval. Defaults to 0.94.
color : str
Probability for the highest density interval. Defaults to ``stats.hdi_prob`` rcParam.
color : str, optional
Color used for the limits of the HDI and fill. Should be a valid matplotlib color.
circular : bool, optional
Whether to compute the HDI taking into account that `x` is a circular variable
(in the range [-np.pi, np.pi]) or not. Defaults to False (i.e. non-circular variables).
smooth : boolean
smooth : boolean, optional
If True the result will be smoothed by first computing a linear interpolation of the data
over a regular grid and then applying the Savitzky-Golay filter to the interpolated data.
Defaults to True.
smooth_kwargs : dict, optional
Additional keywords modifying the Savitzky-Golay filter. See Scipy's documentation for
details.
fill_kwargs : dict
Keywords passed to `fill_between` (use fill_kwargs={'alpha': 0} to disable fill).
plot_kwargs : dict
Keywords passed to HDI limits.
ax: axes, optional
Additional keywords modifying the Savitzky-Golay filter. See
:func:`scipy:scipy.signal.savgol_filter` for details.
fill_kwargs : dict, optional
Keywords passed to :meth:`mpl:matplotlib.axes.Axes.fill_between`
(use fill_kwargs={'alpha': 0} to disable fill) or to
:meth:`bokeh:bokeh.plotting.figure.Figure.patch`.
plot_kwargs : dict, optional
HDI limits keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.plot` or
:meth:`bokeh:bokeh.plotting.figure.Figure.patch`.
hdi_kwargs : dict, optional
Keyword arguments passed to :func:`~arviz.hdi`. Ignored if ``hdi_data`` is present.
ax : axes, optional
Matplotlib axes or bokeh figures.
backend: str, optional
Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
backend_kwargs: bool, optional
These are kwargs specific to the backend being used. For additional documentation
check the plotting method of the backend.
backend : {"matplotlib","bokeh"}, optional
Select plotting backend.
backend_kwargs : dict, optional
These are kwargs specific to the backend being used. Passed to
:func:`mpl:matplotlib.pyplot.subplots` or :func:`bokeh:bokeh.plotting.figure`.
show : bool, optional
Call backend show function.
credible_interval: float, optional
credible_interval : float, optional
Deprecated: Please see hdi_prob

Returns
-------
axes : matplotlib axes or bokeh figures

See Also
--------
hdi : Calculate highest density interval (HDI) of array for given probability.

Examples
--------
Plot HDI interval of simulated regression data using `y` argument:

.. plot::
:context: close-figs

>>> import numpy as np
>>> import arviz as az
>>> x_data = np.random.normal(0, 1, 100)
>>> y_data = np.random.normal(2 + x_data * 0.5, 0.5, (2, 50, 100))
>>> az.plot_hdi(x_data, y_data)

Precalculate HDI interval per chain and plot separately:

.. plot::
:context: close-figs

>>> hdi_data = az.hdi(y_data, input_core_dims=[["draw"]])
>>> ax = az.plot_hdi(x_data, hdi_data=hdi_data[0], color="r", fill_kwargs={"alpha": .2})
>>> az.plot_hdi(x_data, hdi_data=hdi_data[1], color="k", ax=ax, fill_kwargs={"alpha": .2})

"""
if credible_interval:
hdi_prob = credible_interval_warning(credible_interval, hdi_prob)

if hdi_kwargs is None:
hdi_kwargs = {}
plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
plot_kwargs.setdefault("color", color)
plot_kwargs["color"] = vectorized_to_hex(plot_kwargs.get("color", color))
plot_kwargs.setdefault("alpha", 0)

fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "hexbin")
fill_kwargs.setdefault("color", color)
fill_kwargs = matplotlib_kwarg_dealiaser(fill_kwargs, "fill_between")
fill_kwargs["color"] = vectorized_to_hex(fill_kwargs.get("color", color))
fill_kwargs.setdefault("alpha", 0.5)

x = np.asarray(x)
y = np.asarray(y)

x_shape = x.shape
y_shape = y.shape
if y_shape[-len(x_shape) :] != x_shape:
msg = "Dimension mismatch for x: {} and y: {}."
msg += " y-dimensions should be (chain, draw, *x.shape) or"
msg += " (draw, *x.shape)"
raise TypeError(msg.format(x_shape, y_shape))

if len(y_shape[: -len(x_shape)]) > 1:
new_shape = tuple([-1] + list(x_shape))
y = y.reshape(new_shape)

if hdi_prob is None:
hdi_prob = rcParams["stats.hdi_prob"]
else:
if not 1 >= hdi_prob > 0:
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

hdi_ = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False)
if y is None and hdi_data is None:
raise ValueError("One of {y, hdi_data} is required")
if hdi_data is not None and y is not None:
warnings.warn("Both y and hdi_data arguments present, ignoring y")
elif hdi_data is not None:
hdi_prob = (
hdi_data.hdi.attrs.get("hdi_prob", np.nan) if hasattr(hdi_data, "hdi") else np.nan
)
if isinstance(hdi_data, Dataset):
data_vars = list(hdi_data.data_vars)
if len(data_vars) != 1:
raise ValueError(
"Found several variables in hdi_data. Only single variable Datasets are "
"supported."
)
hdi_data = hdi_data[data_vars[0]]
else:
y = np.asarray(y)
if hdi_prob is None:
hdi_prob = rcParams["stats.hdi_prob"]
else:
if not 1 >= hdi_prob > 0:
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs)

hdi_shape = hdi_data.shape
if hdi_shape[:-1] != x_shape:
msg = (
"Dimension mismatch for x: {} and hdi: {}. Check the dimensions of y and"
"hdi_kwargs to make sure they are compatible"
)
raise TypeError(msg.format(x_shape, hdi_shape))

if smooth:
if smooth_kwargs is None:
Expand All @@ -111,12 +163,12 @@ def plot_hdi(
smooth_kwargs.setdefault("polyorder", 2)
x_data = np.linspace(x.min(), x.max(), 200)
x_data[0] = (x_data[0] + x_data[1]) / 2
hdi_interp = griddata(x, hdi_, x_data)
hdi_interp = griddata(x, hdi_data, x_data)
y_data = savgol_filter(hdi_interp, axis=0, **smooth_kwargs)
else:
idx = np.argsort(x)
x_data = x[idx]
y_data = hdi_[idx]
y_data = hdi_data[idx]

hdiplot_kwargs = dict(
ax=ax,
Expand All @@ -132,12 +184,11 @@ def plot_hdi(
backend = rcParams["plot.backend"]
backend = backend.lower()

# TODO: Add backend kwargs
plot = get_plotting_function("plot_hdi", "hdiplot", backend)
ax = plot(**hdiplot_kwargs)
return ax


def plot_hpd(*args, **kwargs): # noqa: D103
warnings.warn("plot_hdi has been deprecated, please use plot_hdi", DeprecationWarning)
warnings.warn("plot_hpd has been deprecated, please use plot_hdi", DeprecationWarning)
OriolAbril marked this conversation as resolved.
Show resolved Hide resolved
return plot_hdi(*args, **kwargs)
1 change: 1 addition & 0 deletions arviz/plots/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,6 +663,7 @@ def matplotlib_kwarg_dealiaser(args, kind, backend="matplotlib"):
"plot": mpl.lines.Line2D,
"hist": mpl.patches.Patch,
"hexbin": mpl.collections.PolyCollection,
"fill_between": mpl.collections.PolyCollection,
"hlines": mpl.collections.LineCollection,
"text": mpl.text.Text,
"contour": mpl.contour.ContourSet,
Expand Down
Loading