From e53ed5341b25e4a69ef635b8fb69afcdf591b33f Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Tue, 20 Aug 2024 15:02:16 +0200 Subject: [PATCH] fix map error handling and tests. --- sbi/inference/posteriors/base_posterior.py | 5 ++--- sbi/utils/sbiutils.py | 4 ++-- tests/linearGaussian_npse_test.py | 3 ++- tests/posterior_nn_test.py | 8 +++++++- tests/sbc_test.py | 5 +---- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/sbi/inference/posteriors/base_posterior.py b/sbi/inference/posteriors/base_posterior.py index 5eafe1bcf..4ea9eee98 100644 --- a/sbi/inference/posteriors/base_posterior.py +++ b/sbi/inference/posteriors/base_posterior.py @@ -52,9 +52,8 @@ def __init__( stacklevel=2, ) - if not isinstance(potential_fn, BasePotential) and not isinstance( - potential_fn, BasePotential - ): + # Wrap as `CallablePotentialWrapper` if `potential_fn` is a Callable. + if not isinstance(potential_fn, BasePotential): kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys()) for key in ["theta", "x_o"]: assert key in kwargs_of_callable, ( diff --git a/sbi/utils/sbiutils.py b/sbi/utils/sbiutils.py index dc3c1f392..b2053a44e 100644 --- a/sbi/utils/sbiutils.py +++ b/sbi/utils/sbiutils.py @@ -945,13 +945,13 @@ def gradient_ascent( try: optimize_inits.requires_grad_(False) # type: ignore gradient = potential_fn.gradient(optimize_inits) - except NotImplementedError: + except (NotImplementedError, AttributeError): optimize_inits.requires_grad_(True) # type: ignore probs = potential_fn(optimize_inits).squeeze() loss = probs.sum() loss.backward() gradient = optimize_inits.grad - assert gradient is Tensor, "Gradient must be a tensor." + assert isinstance(gradient, Tensor), "Gradient must be a tensor." # Update the parameters with gradient descent. # See https://discuss.pytorch.org/t/updatation-of-parameters-without-using-optimizer-step/34244/2 diff --git a/tests/linearGaussian_npse_test.py b/tests/linearGaussian_npse_test.py index 8d41dd5e9..0cb0e70e7 100644 --- a/tests/linearGaussian_npse_test.py +++ b/tests/linearGaussian_npse_test.py @@ -206,7 +206,8 @@ def test_npse_iid_inference(num_trials): @pytest.mark.slow @pytest.mark.xfail( - raises=AssertionError, reason="MAP optimization via score not working accurately." + raises=NotImplementedError, + reason="MAP optimization via score not working accurately.", ) def test_npse_map(): num_dim = 2 diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py index 10ecf5490..8f42a2dfb 100644 --- a/tests/posterior_nn_test.py +++ b/tests/posterior_nn_test.py @@ -32,7 +32,13 @@ ( 0, 1, - pytest.param(2, marks=pytest.mark.xfail(raises=AssertionError)), + pytest.param( + 2, + marks=pytest.mark.xfail( + raises=AssertionError, + reason=".log_prob() supports only batch size 1 for x_o.", + ), + ), ), ) def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool): diff --git a/tests/sbc_test.py b/tests/sbc_test.py index bbbce9332..2e2fa362f 100644 --- a/tests/sbc_test.py +++ b/tests/sbc_test.py @@ -13,10 +13,7 @@ from sbi.analysis import sbc_rank_plot from sbi.diagnostics import check_sbc, get_nltp, run_sbc from sbi.inference import NPSE, SNLE, SNPE -from sbi.simulators import linear_gaussian -from sbi.simulators.linear_gaussian import ( - linear_gaussian, -) +from sbi.simulators.linear_gaussian import linear_gaussian from sbi.utils import BoxUniform, MultipleIndependent from tests.test_utils import PosteriorPotential, TractablePosterior