Skip to content

Commit

Permalink
Add weight_predictions function (#2147)
Browse files Browse the repository at this point in the history
* add weight_predictions

* clean, add checks and docstring

* checks

* add test

* update changelog

* update per comments

* update per comments
  • Loading branch information
aloctavodia authored Nov 12, 2022
1 parent 7e0691b commit 24e66c3
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 4 deletions.
5 changes: 3 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Change Log


## v0.x.x Unreleased

### New features
- Adds Savage-Dickey density ratio plot for Bayes factor approximation. ([2037](https://github.com/arviz-devs/arviz/pull/2037), [2152](https://github.com/arviz-devs/arviz/pull/2152))

* Add `weight_predictions` function to allow generation of weighted predictions from two or more InfereceData with `posterior_predictive` groups and a set of weights ([2147](https://github.com/arviz-devs/arviz/pull/2147))
- Add Savage-Dickey density ratio plot for Bayes factor approximation. ([2037](https://github.com/arviz-devs/arviz/pull/2037), [2152](https://github.com/arviz-devs/arviz/pull/2152

### Maintenance and fixes
- Fix dimension ordering for `plot_trace` with divergences ([2151](https://github.com/arviz-devs/arviz/pull/2151))
Expand Down
1 change: 1 addition & 0 deletions arviz/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"r2_score",
"summary",
"waic",
"weight_predictions",
"ELPDData",
"ess",
"rhat",
Expand Down
72 changes: 71 additions & 1 deletion arviz/stats/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
NO_GET_ARGS = True

from .. import _log
from ..data import InferenceData, convert_to_dataset, convert_to_inference_data
from ..data import InferenceData, convert_to_dataset, convert_to_inference_data, extract
from ..rcparams import rcParams, ScaleKeyword, ICKeyword
from ..utils import Numba, _numba_var, _var_names, get_coords
from .density_utils import get_bins as _get_bins
Expand Down Expand Up @@ -49,6 +49,7 @@
"r2_score",
"summary",
"waic",
"weight_predictions",
"_calculate_ics",
]

Expand Down Expand Up @@ -2043,3 +2044,72 @@ def apply_test_function(
setattr(out, grp, out_group)

return out


def weight_predictions(idatas, weights=None):
"""
Generate weighted posterior predictive samples from a list of InferenceData
and a set of weights.
Parameters
---------
idatas : list[InferenceData]
List of :class:`arviz.InferenceData` objects containing the groups `posterior_predictive`
and `observed_data`. Observations should be the same for all InferenceData objects.
weights : array-like, optional
Individual weights for each model. Weights should be positive. If they do not sum up to 1,
they will be normalized. Default, same weight for each model.
Weights can be computed using many different methods including those in
:func:`arviz.compare`.
Returns
-------
idata: InferenceData
Output InferenceData object with the groups `posterior_predictive` and `observed_data`.
See Also
--------
compare : Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation
"""
if len(idatas) < 2:
raise ValueError("You should provide a list with at least two InferenceData objects")

if not all("posterior_predictive" in idata.groups() for idata in idatas):
raise ValueError(
"All the InferenceData objects must contain the `posterior_predictive` group"
)

if not all(idatas[0].observed_data.equals(idata.observed_data) for idata in idatas[1:]):
raise ValueError("The observed data should be the same for all InferenceData objects")

if weights is None:
weights = np.ones(len(idatas)) / len(idatas)
elif len(idatas) != len(weights):
raise ValueError(
"The number of weights should be the same as the number of InferenceData objects"
)

weights = np.array(weights, dtype=float)
weights /= weights.sum()

len_idatas = [
idata.posterior_predictive.dims["chain"] * idata.posterior_predictive.dims["draw"]
for idata in idatas
]

if not all(len_idatas):
raise ValueError("At least one of your idatas has 0 samples")

new_samples = (np.min(len_idatas) * weights).astype(int)

new_idatas = [
extract(idata, group="posterior_predictive", num_samples=samples).reset_coords()
for samples, idata in zip(new_samples, idatas)
]

weighted_samples = InferenceData(
posterior_predictive=xr.concat(new_idatas, dim="sample"),
observed_data=idatas[0].observed_data,
)

return weighted_samples
31 changes: 30 additions & 1 deletion arviz/tests/base_tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

import numpy as np
import pytest
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
from numpy.testing import (
assert_allclose,
assert_array_almost_equal,
assert_almost_equal,
assert_array_equal,
)
from scipy.special import logsumexp
from scipy.stats import linregress
from xarray import DataArray, Dataset
Expand All @@ -21,6 +26,7 @@
r2_score,
summary,
waic,
weight_predictions,
_calculate_ics,
)
from ...stats.stats import _gpinv
Expand Down Expand Up @@ -800,3 +806,26 @@ def test_apply_test_function_should_overwrite_error(centered_eight):
"""Test error when overwrite=False but out_name is already a present variable."""
with pytest.raises(ValueError, match="Should overwrite"):
apply_test_function(centered_eight, lambda y, theta: y, out_name_data="obs")


def test_weight_predictions():
idata0 = from_dict(
posterior_predictive={"a": np.random.normal(-1, 1, 1000)}, observed_data={"a": [1]}
)
idata1 = from_dict(
posterior_predictive={"a": np.random.normal(1, 1, 1000)}, observed_data={"a": [1]}
)

new = weight_predictions([idata0, idata1])
assert (
idata1.posterior_predictive.mean()
> new.posterior_predictive.mean()
> idata0.posterior_predictive.mean()
)
assert "posterior_predictive" in new
assert "observed_data" in new

new = weight_predictions([idata0, idata1], weights=[0.5, 0.5])
assert_almost_equal(new.posterior_predictive["a"].mean(), 0, decimal=1)
new = weight_predictions([idata0, idata1], weights=[0.9, 0.1])
assert_almost_equal(new.posterior_predictive["a"].mean(), -0.8, decimal=1)

0 comments on commit 24e66c3

Please sign in to comment.