Skip to content

Commit

Permalink
fix map error handling and tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 20, 2024
1 parent 4b3fc61 commit c0217b4
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 7 deletions.
5 changes: 2 additions & 3 deletions sbi/inference/posteriors/base_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ def __init__(
stacklevel=2,
)

if not isinstance(potential_fn, BasePotential) and not isinstance(
potential_fn, BasePotential
):
# Wrap as `CallablePotentialWrapper` if `potential_fn` is a Callable.
if not isinstance(potential_fn, BasePotential):
kwargs_of_callable = list(inspect.signature(potential_fn).parameters.keys())
for key in ["theta", "x_o"]:
assert key in kwargs_of_callable, (
Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,13 +945,13 @@ def gradient_ascent(
try:
optimize_inits.requires_grad_(False) # type: ignore
gradient = potential_fn.gradient(optimize_inits)
except NotImplementedError:
except (NotImplementedError, AttributeError):
optimize_inits.requires_grad_(True) # type: ignore
probs = potential_fn(optimize_inits).squeeze()
loss = probs.sum()
loss.backward()
gradient = optimize_inits.grad
assert gradient is Tensor, "Gradient must be a tensor."
assert isinstance(gradient, Tensor), "Gradient must be a tensor."

# Update the parameters with gradient descent.
# See https://discuss.pytorch.org/t/updatation-of-parameters-without-using-optimizer-step/34244/2
Expand Down
3 changes: 2 additions & 1 deletion tests/linearGaussian_npse_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ def test_npse_iid_inference(num_trials):

@pytest.mark.slow
@pytest.mark.xfail(
raises=AssertionError, reason="MAP optimization via score not working accurately."
raises=NotImplementedError,
reason="MAP optimization via score not working accurately.",
)
def test_npse_map():
num_dim = 2
Expand Down
8 changes: 7 additions & 1 deletion tests/posterior_nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@
(
0,
1,
pytest.param(2, marks=pytest.mark.xfail(raises=AssertionError)),
pytest.param(
2,
marks=pytest.mark.xfail(
raises=AssertionError,
reason=".log_prob() supports only batch size 1 for x_o.",
),
),
),
)
def test_log_prob_with_different_x(snpe_method: type, x_o_batch_dim: bool):
Expand Down

0 comments on commit c0217b4

Please sign in to comment.