Skip to content

Commit

Permalink
extend hdi functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril committed Jun 14, 2020
1 parent 032429b commit f6a7b55
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 21 deletions.
51 changes: 33 additions & 18 deletions arviz/plots/hdiplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@

def plot_hdi(
x,
y,
y=None,
hdi_prob=None,
hdi_data=None,
color="C1",
circular=False,
smooth=True,
smooth_kwargs=None,
fill_kwargs=None,
plot_kwargs=None,
hdi_kwargs=None,
ax=None,
backend=None,
backend_kwargs=None,
Expand All @@ -34,8 +36,11 @@ def plot_hdi(
----------
x : array-like
Values to plot.
y : array-like
y : array-like, optional
Values from which to compute the hdi. Assumed shape (chain, draw, \*shape).
Only optional if hdi_data is present.
hdi_data : array_like, optional
Hdi values to use.
hdi_prob : float, optional
Probability for the highest density interval. Defaults to 0.94.
color : str
Expand Down Expand Up @@ -73,6 +78,8 @@ def plot_hdi(
if credible_interval:
hdi_prob = credible_interval_warning(credible_interval, hdi_prob)

if hdi_kwargs is None:
hdi_kwargs = {}
plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
plot_kwargs.setdefault("color", color)
plot_kwargs.setdefault("alpha", 0)
Expand All @@ -82,23 +89,31 @@ def plot_hdi(
fill_kwargs.setdefault("alpha", 0.5)

x = np.asarray(x)
y = np.asarray(y)

x_shape = x.shape
y_shape = y.shape
if y_shape[-len(x_shape) :] != x_shape:
msg = "Dimension mismatch for x: {} and y: {}."
msg += " y-dimensions should be (chain, draw, *x.shape) or"
msg += " (draw, *x.shape)"
raise TypeError(msg.format(x_shape, y_shape))

if hdi_prob is None:
hdi_prob = rcParams["stats.hdi_prob"]


if y is None and hdi_data is None:
raise ValueError("One of {y, hdi_data} is required")
elif hdi_data is not None and y is not None:
warnings.warn("Both y and hdi_data arguments present, ignoring y")
elif y is not None:
y = np.asarray(y)
if hdi_prob is None:
hdi_prob = rcParams["stats.hdi_prob"]
else:
if not 1 >= hdi_prob > 0:
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs)
else:
if not 1 >= hdi_prob > 0:
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
hdi_prob = hdi_data.hdi.attrs.get("hdi_prob", np.nan)

hdi_ = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False)
hdi_shape = hdi_data.shape
if hdi_shape[:-1] != x_shape:
msg = (
"Dimension mismatch for x: {} and hdi: {}. Check the dimensions of y and"
"hdi_kwargs to make sure they are compatible"
)
raise TypeError(msg.format(x_shape, hdi_shape))

if smooth:
if smooth_kwargs is None:
Expand All @@ -107,12 +122,12 @@ def plot_hdi(
smooth_kwargs.setdefault("polyorder", 2)
x_data = np.linspace(x.min(), x.max(), 200)
x_data[0] = (x_data[0] + x_data[1]) / 2
hdi_interp = griddata(x, hdi_, x_data)
hdi_interp = griddata(x, hdi_data, x_data)
y_data = savgol_filter(hdi_interp, axis=0, **smooth_kwargs)
else:
idx = np.argsort(x)
x_data = x[idx]
y_data = hdi_[idx]
y_data = hdi_data[idx]

hdiplot_kwargs = dict(
ax=ax,
Expand Down
5 changes: 2 additions & 3 deletions arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,9 +484,8 @@ def hdi(
var_names = _var_names(var_names, ary, filter_vars)
ary = ary[var_names] if var_names else ary

hdi_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs).assign_coords(
hdi=["lower", "higher"]
)
hdi_coord = xr.DataArray(["lower", "higher"], dims=["hdi"], attrs=dict(hdi_prob=hdi_prob))
hdi_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs).assign_coords({"hdi": hdi_coord})
hdi_data = hdi_data.dropna("mode", how="all") if multimodal else hdi_data
return hdi_data.x.values if isarray else hdi_data

Expand Down

0 comments on commit f6a7b55

Please sign in to comment.