Skip to content

Commit

Permalink
add errors for MAP and iid data, adapt tests
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 20, 2024
1 parent 5071e43 commit 4b3fc61
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 3 deletions.
3 changes: 3 additions & 0 deletions sbi/inference/posteriors/score_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,9 @@ def map(
Returns:
The MAP estimate.
"""
raise NotImplementedError(
"MAP estimation is currently not working accurately for ScorePosterior."
)
return super().map(
x=x,
num_iter=num_iter,
Expand Down
3 changes: 3 additions & 0 deletions sbi/inference/potentials/score_based_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def gradient(
input=theta, condition=self.x_o, time=time
)
else:
raise NotImplementedError(
"Score accumulation for IID data is not yet fully implemented."
)
if self.prior is None:
raise ValueError(
"Prior must be provided when interpreting the data as IID."
Expand Down
13 changes: 10 additions & 3 deletions tests/linearGaussian_npse_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import List

import pytest
import torch
from torch import eye, ones, zeros
Expand Down Expand Up @@ -30,7 +32,7 @@
],
)
def test_c2st_npse_on_linearGaussian(
sde_type, num_dim: int, prior_str: str, sample_with: list[str]
sde_type, num_dim: int, prior_str: str, sample_with: List[str]
):
"""Test whether NPSE infers well a simple example with available ground truth."""

Expand Down Expand Up @@ -78,7 +80,7 @@ def test_c2st_npse_on_linearGaussian(
check_c2st(
samples,
target_samples,
alg=f"npse-{sde_type or "vp"}-{prior_str}-{num_dim}D-{method}",
alg=f"npse-{sde_type or 'vp'}-{prior_str}-{num_dim}D-{method}",
)

# Checks for log_prob()
Expand Down Expand Up @@ -157,7 +159,12 @@ def simulator(theta):
check_c2st(samples, target_samples, alg="npse_different_dims_and_resume_training")


@pytest.mark.xfail(reason="iid_bridge not working.")
@pytest.mark.xfail(
reason="iid_bridge not working.",
raises=NotImplementedError,
strict=True,
match="Score accumulation*",
)
@pytest.mark.parametrize("num_trials", [2, 10])
def test_npse_iid_inference(num_trials):
"""Test whether NPSE infers well a simple example with available ground truth."""
Expand Down

0 comments on commit 4b3fc61

Please sign in to comment.