Skip to content

Commit

Permalink
skip "event" dims in log likelihood group whenever necessary (#1429)
Browse files Browse the repository at this point in the history
* add option to skip "event" dims in log likelihood group

* Update docstring

Co-authored-by: Alexandre ANDORRA <[email protected]>

* cover dim reduction but keeping # of dims and add tests

* add skip_event_dims to stan and dict converters

* black

* update changelog

* remove trailing whitespace

Co-authored-by: Alexandre ANDORRA <[email protected]>
  • Loading branch information
OriolAbril and AlexAndorra authored Oct 26, 2020
1 parent 475e34c commit ead5c98
Show file tree
Hide file tree
Showing 11 changed files with 147 additions and 31 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* Added group argument to summary ([1408](https://github.com/arviz-devs/arviz/pull/1408))
* Add `ref_line`, `bar`, `vlines` and `marker_vlines` kwargs to `plot_rank` ([1419](https://github.com/arviz-devs/arviz/pull/1419))
* Add observed argument to (un)plot observed data in `plot_ppc` ([1422](https://github.com/arviz-devs/arviz/pull/1422))
* Add support for named dims and coordinates with multivariate observations ([1429](https://github.com/arviz-devs/arviz/pull/1429))

### Maintenance and fixes
* prevent wrapping group names in InferenceData repr_html ([1407](https://github.com/arviz-devs/arviz/pull/1407))
Expand Down
69 changes: 50 additions & 19 deletions arviz/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def wrapped(cls, *args, **kwargs):
return wrapped


def generate_dims_coords(shape, var_name, dims=None, coords=None, default_dims=None):
def generate_dims_coords(
shape, var_name, dims=None, coords=None, default_dims=None, skip_event_dims=None
):
"""Generate default dimensions and coordinates for a variable.
Parameters
Expand All @@ -66,6 +68,7 @@ def generate_dims_coords(shape, var_name, dims=None, coords=None, default_dims=N
when manipulating Monte Carlo traces, the ``default_dims`` would be
``["chain" , "draw"]`` which ArviZ uses as its own names for dimensions
of MCMC traces.
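skip_event_dims : bool, default False
If True, trim ``dims`` to ``len(shape)`` entries instead of warning when more dims than array dimensions are given, and also cut ``dims`` at the first dimension whose coordinate length does not match the corresponding size in ``shape``.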
Returns
-------
Expand All @@ -78,26 +81,41 @@ def generate_dims_coords(shape, var_name, dims=None, coords=None, default_dims=N
default_dims = []
if dims is None:
dims = []
if len([dim for dim in dims if dim not in default_dims]) > len(shape):
warnings.warn(
(
"In variable {var_name}, there are "
+ "more dims ({dims_len}) given than exist ({shape_len}). "
+ "Passed array should have shape ({defaults}*shape)"
).format(
var_name=var_name,
dims_len=len(dims),
shape_len=len(shape),
defaults=",".join(default_dims) + ", " if default_dims is not None else "",
),
UserWarning,
)
if skip_event_dims is None:
skip_event_dims = False

if coords is None:
coords = {}

coords = deepcopy(coords)
dims = deepcopy(dims)

ndims = len([dim for dim in dims if dim not in default_dims])
if ndims > len(shape):
if skip_event_dims:
dims = dims[: len(shape)]
else:
warnings.warn(
(
"In variable {var_name}, there are "
+ "more dims ({dims_len}) given than exist ({shape_len}). "
+ "Passed array should have shape ({defaults}*shape)"
).format(
var_name=var_name,
dims_len=len(dims),
shape_len=len(shape),
defaults=",".join(default_dims) + ", " if default_dims is not None else "",
),
UserWarning,
)
if skip_event_dims:
# this is needed in case the reduction keeps the dimension with size 1
for i, (dim, dim_size) in enumerate(zip(dims, shape)):
print(f"{i}, dim: {dim}, {dim_size} =? {len(coords.get(dim, []))}")
if (dim in coords) and (dim_size != len(coords[dim])):
dims = dims[:i]
break

for idx, dim_len in enumerate(shape):
if (len(dims) < idx + 1) or (dims[idx] is None):
dim_name = "{var_name}_dim_{idx}".format(var_name=var_name, idx=idx)
Expand All @@ -112,7 +130,7 @@ def generate_dims_coords(shape, var_name, dims=None, coords=None, default_dims=N
return dims, coords


def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None):
def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None, skip_event_dims=None):
"""Convert a numpy array to an xarray.DataArray.
The first two dimensions will be (chain, draw), and any remaining
Expand All @@ -134,6 +152,7 @@ def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None):
is the name of the dimension, the values are the index values.
dims : List(str)
A list of coordinate names for the variable
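skip_event_dims : bool
Forwarded unchanged to ``generate_dims_coords``; if True, extra dims are cut to match the shape of the data.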
Returns
-------
Expand All @@ -154,7 +173,12 @@ def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None):
)

dims, coords = generate_dims_coords(
shape, var_name, dims=dims, coords=coords, default_dims=default_dims
shape,
var_name,
dims=dims,
coords=coords,
default_dims=default_dims,
skip_event_dims=skip_event_dims,
)

# reversed order for default dims: 'chain', 'draw'
Expand All @@ -173,7 +197,9 @@ def numpy_to_data_array(ary, *, var_name="data", coords=None, dims=None):
return xr.DataArray(ary, coords=coords, dims=dims)


def dict_to_dataset(data, *, attrs=None, library=None, coords=None, dims=None):
def dict_to_dataset(
data, *, attrs=None, library=None, coords=None, dims=None, skip_event_dims=None
):
"""Convert a dictionary of numpy arrays to an xarray.Dataset.
Parameters
Expand All @@ -189,6 +215,11 @@ def dict_to_dataset(data, *, attrs=None, library=None, coords=None, dims=None):
dims : dict[str] -> list[str]
Dimensions of each variable. The keys are variable names, values are lists of
coordinates.
skip_event_dims : bool
If True, cut extra dims whenever present to match the shape of the data.
Necessary for PPLs which have the same name in both observed data and log
likelihood groups, to account for their different shapes when observations are
multivariate.
Returns
-------
Expand All @@ -205,7 +236,7 @@ def dict_to_dataset(data, *, attrs=None, library=None, coords=None, dims=None):
data_vars = {}
for key, values in data.items():
data_vars[key] = numpy_to_data_array(
values, var_name=key, coords=coords, dims=dims.get(key)
values, var_name=key, coords=coords, dims=dims.get(key), skip_event_dims=skip_event_dims
)
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))

Expand Down
8 changes: 6 additions & 2 deletions arviz/data/io_cmdstan.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,12 @@ def log_likelihood_to_xarray(self):
)
attrs = None
return (
dict_to_dataset(data, coords=self.coords, dims=self.dims, attrs=attrs),
dict_to_dataset(data_warmup, coords=self.coords, dims=self.dims, attrs=attrs),
dict_to_dataset(
data, coords=self.coords, dims=self.dims, attrs=attrs, skip_event_dims=True
),
dict_to_dataset(
data_warmup, coords=self.coords, dims=self.dims, attrs=attrs, skip_event_dims=True
),
)

def to_inference_data(self):
Expand Down
14 changes: 12 additions & 2 deletions arviz/data/io_cmdstanpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,19 @@ def log_likelihood_to_xarray(self):
)

return (
dict_to_dataset(data, library=self.cmdstanpy, coords=self.coords, dims=self.dims),
dict_to_dataset(
data_warmup, library=self.cmdstanpy, coords=self.coords, dims=self.dims
data,
library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
skip_event_dims=True,
),
dict_to_dataset(
data_warmup,
library=self.cmdstanpy,
coords=self.coords,
dims=self.dims,
skip_event_dims=True,
),
)

Expand Down
14 changes: 12 additions & 2 deletions arviz/data/io_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,20 @@ def log_likelihood_to_xarray(self):

return (
dict_to_dataset(
data, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
data,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
skip_event_dims=True,
),
dict_to_dataset(
data_warmup, library=None, coords=self.coords, dims=self.dims, attrs=self.attrs
data_warmup,
library=None,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
skip_event_dims=True,
),
)

Expand Down
4 changes: 3 additions & 1 deletion arviz/data/io_numpyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,9 @@ def log_likelihood_to_xarray(self):
for obs_name, log_like in log_likelihood_dict.items():
shape = (self.nchains, self.ndraws) + log_like.shape[1:]
data[obs_name] = np.reshape(log_like.copy(), shape)
return dict_to_dataset(data, library=self.numpyro, dims=self.dims, coords=self.coords)
return dict_to_dataset(
data, library=self.numpyro, dims=self.dims, coords=self.coords, skip_event_dims=True
)

def translate_posterior_predictive_dict_to_xarray(self, dct, dims):
"""Convert posterior_predictive or prediction samples to xarray."""
Expand Down
12 changes: 10 additions & 2 deletions arviz/data/io_pymc3.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,16 @@ def log_likelihood_to_xarray(self):
except TypeError:
warnings.warn(warn_msg)
return (
dict_to_dataset(data, library=self.pymc3, dims=self.dims, coords=self.coords),
dict_to_dataset(data_warmup, library=self.pymc3, dims=self.dims, coords=self.coords),
dict_to_dataset(
data, library=self.pymc3, dims=self.dims, coords=self.coords, skip_event_dims=True
),
dict_to_dataset(
data_warmup,
library=self.pymc3,
dims=self.dims,
coords=self.coords,
skip_event_dims=True,
),
)

def translate_posterior_predictive_dict_to_xarray(self, dct) -> xr.Dataset:
Expand Down
4 changes: 3 additions & 1 deletion arviz/data/io_pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ def log_likelihood_to_xarray(self):
"Check your model vectorization or set log_likelihood=False"
)
return None
return dict_to_dataset(data, library=self.pyro, coords=self.coords, dims=self.dims)
return dict_to_dataset(
data, library=self.pyro, coords=self.coords, dims=self.dims, skip_event_dims=True
)

def translate_posterior_predictive_dict_to_xarray(self, dct, dims):
"""Convert posterior_predictive or prediction samples to xarray."""
Expand Down
12 changes: 10 additions & 2 deletions arviz/data/io_pystan.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,16 @@ def log_likelihood_to_xarray(self):
}

return (
dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims),
dict_to_dataset(data_warmup, library=self.pystan, coords=self.coords, dims=self.dims),
dict_to_dataset(
data, library=self.pystan, coords=self.coords, dims=self.dims, skip_event_dims=True
),
dict_to_dataset(
data_warmup,
library=self.pystan,
coords=self.coords,
dims=self.dims,
skip_event_dims=True,
),
)

@requires("posterior")
Expand Down
22 changes: 22 additions & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,20 @@ def test_dims_coords_extra_dims():
assert len(coords["xy"]) == 20


@pytest.mark.parametrize("shape", [(4, 20), (4, 20, 1)])
def test_dims_coords_skip_event_dims(shape):
    """Check that the extra "z" dim is dropped when ``skip_event_dims=True``.

    Two shapes are exercised: one with fewer axes than dims, and one where a
    third axis is kept with size 1, which does not match the length-5 "z"
    coordinate.
    """
    requested_dims = ["x", "y", "z"]
    input_coords = {
        "x": np.arange(4),
        "y": np.arange(20),
        "z": np.arange(5),
    }

    out_dims, out_coords = generate_dims_coords(
        shape,
        "name",
        dims=requested_dims,
        coords=input_coords,
        skip_event_dims=True,
    )

    # the dims matching the data are kept, the event dim is removed
    for kept in ("x", "y"):
        assert kept in out_dims
    assert "z" not in out_dims

    # coordinates follow the same pattern as the dims
    assert len(out_coords["x"]) == 4
    assert len(out_coords["y"]) == 20
    assert "z" not in out_coords


def test_make_attrs():
extra_attrs = {"key": "Value"}
attrs = make_attrs(attrs=extra_attrs)
Expand Down Expand Up @@ -898,6 +912,14 @@ def test_dict_to_dataset():
assert set(dataset.b.coords) == {"chain", "draw", "c"}


def test_dict_to_dataset_event_dims_error():
    """Check that mismatched dims raise when event dims are not skipped.

    Variable "a" has shape (1, 100, 10) but is given the dims ["b", "c"];
    ``convert_to_dataset`` is called without ``skip_event_dims`` and is
    expected to raise a ``ValueError`` with the message asserted below.
    """
    samples = {"a": np.random.randn(1, 100, 10)}
    coord_values = {"b": np.arange(10), "c": ["x", "y", "z"]}
    var_dims = {"a": ["b", "c"]}
    expected_msg = "different number of dimensions on data and dims"

    with pytest.raises(ValueError, match=expected_msg):
        convert_to_dataset(samples, coords=coord_values, dims=var_dims)


def test_convert_to_dataset_idempotent():
first = convert_to_dataset(np.random.randn(100))
second = convert_to_dataset(first)
Expand Down
18 changes: 18 additions & 0 deletions arviz/tests/external_tests/test_data_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,24 @@ def test_no_model_deprecation(self):
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails

def test_multivariate_observations(self):
    """Check dims of a multivariate observed variable across groups.

    A Multinomial observation is declared with dims
    ("experiment", "direction"). After sampling, "direction" must be absent
    from the log_likelihood group but present in observed_data.
    """
    model_coords = {"direction": ["x", "y", "z"], "experiment": np.arange(20)}
    observations = np.random.multinomial(20, [0.2, 0.3, 0.5], size=20)

    # sampling has to happen inside the model context
    with pm.Model(coords=model_coords):
        probs = pm.Beta("p", 1, 1, shape=(3,))
        pm.Multinomial(
            "y",
            20,
            probs,
            dims=("experiment", "direction"),
            observed=observations,
        )
        idata = pm.sample(draws=50, tune=100, return_inferencedata=True)

    expected_groups = {
        "posterior": ["p"],
        "sample_stats": ["lp"],
        "log_likelihood": ["y"],
        "observed_data": ["y"],
    }
    missing = check_multiple_attrs(expected_groups, idata)
    assert not missing

    assert "direction" not in idata.log_likelihood.dims
    assert "direction" in idata.observed_data.dims


class TestPyMC3WarmupHandling:
@pytest.mark.skipif(
Expand Down

0 comments on commit ead5c98

Please sign in to comment.