Skip to content

Commit

Permalink
Fix error in dataset_to_point_list when chain, draw are not the lea…
Browse files Browse the repository at this point in the history
…ding dims (pymc-devs#7180)
  • Loading branch information
OriolAbril authored and mkusnetsov committed Oct 26, 2024
1 parent 38d86fc commit 30ac872
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pymc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,10 @@ def dataset_to_point_list(
raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
num_sample_dims = len(sample_dims)
stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims}
transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()}
stacked_dict = {
vn: da.transpose(*sample_dims, ...).values.reshape((-1, *da.shape[num_sample_dims:]))
for vn, da in ds.items()
vn: da.values.reshape((-1, *da.shape[num_sample_dims:]))
for vn, da in transposed_dict.items()
}
points = [
{vn: stacked_dict[vn][i, ...] for vn in var_names}
Expand Down
10 changes: 10 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,16 @@ def test_dataset_to_point_list(input_type):
assert isinstance(pl[0]["A"], np.ndarray)


def test_transposed_dataset_to_point_list():
ds = xarray.Dataset()
ds["A"] = xarray.DataArray([[[1, 2, 3], [2, 3, 4]]] * 5, dims=("team", "draw", "chain"))
pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
assert isinstance(pl, list)
assert len(pl) == 6
assert isinstance(pl[0], dict)
assert isinstance(pl[0]["A"], np.ndarray)


def test_dataset_to_point_list_str_key():
# Check that non-str keys are caught
ds = xarray.Dataset()
Expand Down

0 comments on commit 30ac872

Please sign in to comment.