Skip to content

Commit

Permalink
refactor iid tutorial; separate mnle tutorial.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 7, 2023
1 parent 4d7d4da commit 8958650
Show file tree
Hide file tree
Showing 3 changed files with 1,733 additions and 4 deletions.
24 changes: 20 additions & 4 deletions tests/mnle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,17 @@
from pyro.distributions import InverseGamma
from torch.distributions import Beta, Binomial, Gamma

from sbi.inference import MNLE, MCMCPosterior, likelihood_estimator_based_potential
from sbi.inference import (
MNLE,
MCMCPosterior,
likelihood_estimator_based_potential,
)
from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.likelihood_based_potential import (
MixedLikelihoodBasedPotential,
)
from sbi.utils import BoxUniform, likelihood_nn, mcmc_transform
from sbi.utils.conditional_density_utils import ConditionedPotential
from sbi.utils.torchutils import atleast_2d
from sbi.utils.user_input_checks_utils import MultipleIndependent
from tests.test_utils import check_c2st
Expand All @@ -21,7 +29,10 @@ def test_mnle_on_device(device):
num_simulations = 100
theta = torch.rand(num_simulations, 2)
x = torch.cat(
(torch.rand(num_simulations, 1), torch.randint(0, 2, (num_simulations, 1))),
(
torch.rand(num_simulations, 1),
torch.randint(0, 2, (num_simulations, 1)),
),
dim=1,
).to(device)

Expand All @@ -41,7 +52,10 @@ def test_mnle_api(sampler):
num_simulations = 100
theta = torch.rand(num_simulations, 2)
x = torch.cat(
(torch.rand(num_simulations, 1), torch.randint(0, 2, (num_simulations, 1))),
(
torch.rand(num_simulations, 1),
torch.randint(0, 2, (num_simulations, 1)),
),
dim=1,
)

Expand Down Expand Up @@ -89,7 +103,9 @@ def mixed_simulator(theta):

# Sample choices and rts independently.
choices = Binomial(probs=ps).sample()
rts = InverseGamma(concentration=2 * torch.ones_like(beta), rate=beta).sample()
rts = InverseGamma(
concentration=2 * torch.ones_like(beta), rate=beta
).sample()

return torch.cat((rts, choices), dim=1)

Expand Down
898 changes: 898 additions & 0 deletions tutorials/14_iid_data_and_permutation_invariant_embeddings.ipynb

Large diffs are not rendered by default.

815 changes: 815 additions & 0 deletions tutorials/17_SBI_for_models_of_decision_making.ipynb

Large diffs are not rendered by default.

0 comments on commit 8958650

Please sign in to comment.