Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dataset.sizes instead of dataset.dims #7057

Merged
merged 1 commit into from
Dec 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,8 @@
)
if hasattr(idata_orig, "posterior"):
assert idata_orig is not None
converter.nchains = idata_orig["posterior"].dims["chain"]
converter.ndraws = idata_orig["posterior"].dims["draw"]
converter.nchains = idata_orig["posterior"].sizes["chain"]
converter.ndraws = idata_orig["posterior"].sizes["draw"]

Check warning on line 586 in pymc/backends/arviz.py

View check run for this annotation

Codecov / codecov/patch

pymc/backends/arviz.py#L585-L586

Added lines #L585 - L586 were not covered by tests
else:
aelem = next(iter(predictions.values()))
converter.nchains, converter.ndraws = aelem.shape[:2]
Expand Down
32 changes: 16 additions & 16 deletions tests/backends/test_arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ def test_to_idata(self, data, eight_schools_params, chains, draws):
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails
chains = inference_data.posterior.dims["chain"]
draws = inference_data.posterior.dims["draw"]
chains = inference_data.posterior.sizes["chain"]
draws = inference_data.posterior.sizes["draw"]
obs = inference_data.observed_data["obs"]
assert inference_data.log_likelihood["obs"].shape == (chains, draws) + obs.shape

Expand All @@ -177,7 +177,7 @@ def test_predictions_to_idata(self, data, eight_schools_params):
assert not fails
for key, ivalues in inference_data.predictions.items():
assert (
len(ivalues["chain"]) == inference_data.posterior.dims["chain"]
len(ivalues["chain"]) == inference_data.posterior.sizes["chain"]
) # same chains as in posterior

# check adding in place
Expand All @@ -188,7 +188,7 @@ def test_predictions_to_idata(self, data, eight_schools_params):
assert not fails
for key, ivalues in inference_data.predictions.items():
assert (
len(ivalues["chain"]) == inference_data.posterior.dims["chain"]
len(ivalues["chain"]) == inference_data.posterior.sizes["chain"]
) # same chains as in posterior

def test_predictions_to_idata_new(self, data, eight_schools_params):
Expand Down Expand Up @@ -241,10 +241,10 @@ def test_posterior_predictive_thinned(self, data):
}
fails = check_multiple_attrs(test_dict, idata)
assert not fails
assert idata.posterior.dims["chain"] == 2
assert idata.posterior.dims["draw"] == draws
assert idata.posterior_predictive.dims["chain"] == 2
assert idata.posterior_predictive.dims["draw"] == draws / thin_by
assert idata.posterior.sizes["chain"] == 2
assert idata.posterior.sizes["draw"] == draws
assert idata.posterior_predictive.sizes["chain"] == 2
assert idata.posterior_predictive.sizes["draw"] == draws / thin_by
assert np.allclose(idata.posterior["draw"], np.arange(draws))
assert np.allclose(idata.posterior_predictive["draw"], np.arange(draws, step=thin_by))

Expand Down Expand Up @@ -723,11 +723,11 @@ def test_save_warmup(self, save_warmup, chains, tune, draws):
fails = check_multiple_attrs(test_dict, idata)
assert not fails
if hasattr(idata, "posterior"):
assert idata.posterior.dims["chain"] == chains
assert idata.posterior.dims["draw"] == draws
assert idata.posterior.sizes["chain"] == chains
assert idata.posterior.sizes["draw"] == draws
if hasattr(idata, "warmup_posterior"):
assert idata.warmup_posterior.dims["chain"] == chains
assert idata.warmup_posterior.dims["draw"] == tune
assert idata.warmup_posterior.sizes["chain"] == chains
assert idata.warmup_posterior.sizes["draw"] == tune

def test_save_warmup_issue_1208_after_3_9(self):
with pm.Model():
Expand Down Expand Up @@ -757,8 +757,8 @@ def test_save_warmup_issue_1208_after_3_9(self):
}
fails = check_multiple_attrs(test_dict, idata)
assert not fails
assert idata.posterior.dims["chain"] == 2
assert idata.posterior.dims["draw"] == 200
assert idata.posterior.sizes["chain"] == 2
assert idata.posterior.sizes["draw"] == 200

# manually sliced trace triggers the same warning as <=3.8
with pytest.warns(UserWarning, match="Warmup samples"):
Expand All @@ -771,5 +771,5 @@ def test_save_warmup_issue_1208_after_3_9(self):
}
fails = check_multiple_attrs(test_dict, idata)
assert not fails
assert idata.posterior.dims["chain"] == 2
assert idata.posterior.dims["draw"] == 30
assert idata.posterior.sizes["chain"] == 2
assert idata.posterior.sizes["draw"] == 30
4 changes: 2 additions & 2 deletions tests/sampling/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,8 @@ def test_idata_contains_stats(sampler_name: str):

stats = idata.get("sample_stats")
assert stats is not None
n_chains = stats.dims["chain"]
n_draws = stats.dims["draw"]
n_chains = stats.sizes["chain"]
n_draws = stats.sizes["draw"]

# Stats vars expected for both samplers
expected_stat_vars = {
Expand Down
4 changes: 2 additions & 2 deletions tests/smc/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ def test_return_datatype(self, chains):

assert isinstance(idata, InferenceData)
assert "sample_stats" in idata
assert idata.posterior.dims["chain"] == chains
assert idata.posterior.dims["draw"] == draws
assert idata.posterior.sizes["chain"] == chains
assert idata.posterior.sizes["draw"] == draws

assert isinstance(mt, MultiTrace)
assert mt.nchains == chains
Expand Down
Loading