diff --git a/pymc/util.py b/pymc/util.py index 799d92b1b4..8388a8ed49 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -249,9 +249,10 @@ def dataset_to_point_list( raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.") num_sample_dims = len(sample_dims) stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims} + transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()} stacked_dict = { - vn: da.transpose(*sample_dims, ...).values.reshape((-1, *da.shape[num_sample_dims:])) - for vn, da in ds.items() + vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) + for vn, da in transposed_dict.items() } points = [ {vn: stacked_dict[vn][i, ...] for vn in var_names} diff --git a/tests/test_util.py b/tests/test_util.py index e984fdbd1c..61d916249e 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -170,6 +170,16 @@ def test_dataset_to_point_list(input_type): assert isinstance(pl[0]["A"], np.ndarray) +def test_transposed_dataset_to_point_list(): + ds = xarray.Dataset() + ds["A"] = xarray.DataArray([[[1, 2, 3], [2, 3, 4]]] * 5, dims=("team", "draw", "chain")) + pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"]) + assert isinstance(pl, list) + assert len(pl) == 6 + assert isinstance(pl[0], dict) + assert isinstance(pl[0]["A"], np.ndarray) + + def test_dataset_to_point_list_str_key(): # Check that non-str keys are caught ds = xarray.Dataset()