diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 39553f6250..644225f7e2 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -582,8 +582,8 @@ def predictions_to_inference_data( ) if hasattr(idata_orig, "posterior"): assert idata_orig is not None - converter.nchains = idata_orig["posterior"].dims["chain"] - converter.ndraws = idata_orig["posterior"].dims["draw"] + converter.nchains = idata_orig["posterior"].sizes["chain"] + converter.ndraws = idata_orig["posterior"].sizes["draw"] else: aelem = next(iter(predictions.values())) converter.nchains, converter.ndraws = aelem.shape[:2] diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py index 4f6d312e29..f297c0eb57 100644 --- a/tests/backends/test_arviz.py +++ b/tests/backends/test_arviz.py @@ -156,8 +156,8 @@ def test_to_idata(self, data, eight_schools_params, chains, draws): } fails = check_multiple_attrs(test_dict, inference_data) assert not fails - chains = inference_data.posterior.dims["chain"] - draws = inference_data.posterior.dims["draw"] + chains = inference_data.posterior.sizes["chain"] + draws = inference_data.posterior.sizes["draw"] obs = inference_data.observed_data["obs"] assert inference_data.log_likelihood["obs"].shape == (chains, draws) + obs.shape @@ -177,7 +177,7 @@ def test_predictions_to_idata(self, data, eight_schools_params): assert not fails for key, ivalues in inference_data.predictions.items(): assert ( - len(ivalues["chain"]) == inference_data.posterior.dims["chain"] + len(ivalues["chain"]) == inference_data.posterior.sizes["chain"] ) # same chains as in posterior # check adding in place @@ -188,7 +188,7 @@ def test_predictions_to_idata(self, data, eight_schools_params): assert not fails for key, ivalues in inference_data.predictions.items(): assert ( - len(ivalues["chain"]) == inference_data.posterior.dims["chain"] + len(ivalues["chain"]) == inference_data.posterior.sizes["chain"] ) # same chains as in posterior def test_predictions_to_idata_new(self, data, eight_schools_params): @@ -241,10 +241,10 @@ def test_posterior_predictive_thinned(self, data): } fails = check_multiple_attrs(test_dict, idata) assert not fails - assert idata.posterior.dims["chain"] == 2 - assert idata.posterior.dims["draw"] == draws - assert idata.posterior_predictive.dims["chain"] == 2 - assert idata.posterior_predictive.dims["draw"] == draws / thin_by + assert idata.posterior.sizes["chain"] == 2 + assert idata.posterior.sizes["draw"] == draws + assert idata.posterior_predictive.sizes["chain"] == 2 + assert idata.posterior_predictive.sizes["draw"] == draws / thin_by assert np.allclose(idata.posterior["draw"], np.arange(draws)) assert np.allclose(idata.posterior_predictive["draw"], np.arange(draws, step=thin_by)) @@ -723,11 +723,11 @@ def test_save_warmup(self, save_warmup, chains, tune, draws): fails = check_multiple_attrs(test_dict, idata) assert not fails if hasattr(idata, "posterior"): - assert idata.posterior.dims["chain"] == chains - assert idata.posterior.dims["draw"] == draws + assert idata.posterior.sizes["chain"] == chains + assert idata.posterior.sizes["draw"] == draws if hasattr(idata, "warmup_posterior"): - assert idata.warmup_posterior.dims["chain"] == chains - assert idata.warmup_posterior.dims["draw"] == tune + assert idata.warmup_posterior.sizes["chain"] == chains + assert idata.warmup_posterior.sizes["draw"] == tune def test_save_warmup_issue_1208_after_3_9(self): with pm.Model(): @@ -757,8 +757,8 @@ def test_save_warmup_issue_1208_after_3_9(self): } fails = check_multiple_attrs(test_dict, idata) assert not fails - assert idata.posterior.dims["chain"] == 2 - assert idata.posterior.dims["draw"] == 200 + assert idata.posterior.sizes["chain"] == 2 + assert idata.posterior.sizes["draw"] == 200 # manually sliced trace triggers the same warning as <=3.8 with pytest.warns(UserWarning, match="Warmup samples"): @@ -771,5 +771,5 @@ def test_save_warmup_issue_1208_after_3_9(self): } fails = check_multiple_attrs(test_dict, idata) assert not fails - assert idata.posterior.dims["chain"] == 2 - assert idata.posterior.dims["draw"] == 30 + assert idata.posterior.sizes["chain"] == 2 + assert idata.posterior.sizes["draw"] == 30 diff --git a/tests/sampling/test_jax.py b/tests/sampling/test_jax.py index ed7c81fef8..8ec95552a0 100644 --- a/tests/sampling/test_jax.py +++ b/tests/sampling/test_jax.py @@ -433,8 +433,8 @@ def test_idata_contains_stats(sampler_name: str): stats = idata.get("sample_stats") assert stats is not None - n_chains = stats.dims["chain"] - n_draws = stats.dims["draw"] + n_chains = stats.sizes["chain"] + n_draws = stats.sizes["draw"] # Stats vars expected for both samplers expected_stat_vars = { diff --git a/tests/smc/test_smc.py b/tests/smc/test_smc.py index 81907739fe..e969b6a6e6 100644 --- a/tests/smc/test_smc.py +++ b/tests/smc/test_smc.py @@ -222,8 +222,8 @@ def test_return_datatype(self, chains): assert isinstance(idata, InferenceData) assert "sample_stats" in idata - assert idata.posterior.dims["chain"] == chains - assert idata.posterior.dims["draw"] == draws + assert idata.posterior.sizes["chain"] == chains + assert idata.posterior.sizes["draw"] == draws assert isinstance(mt, MultiTrace) assert mt.nchains == chains