Skip to content

Commit

Permalink
test: conditioning on experimental conditions with mnle.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 7, 2023
1 parent 8958650 commit e7f07bb
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 49 deletions.
28 changes: 27 additions & 1 deletion sbi/inference/snle/mnle.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,29 @@ def __init__(
kwargs = del_entries(locals(), entries=("self", "__class__"))
super().__init__(**kwargs)

def train(
self,
training_batch_size: int = 50,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
resume_training: bool = False,
discard_prior_samples: bool = False,
retrain_from_scratch: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[Dict] = None,
) -> MixedDensityEstimator:
density_estimator = super().train(
**del_entries(locals(), entries=("self", "__class__"))
)
assert isinstance(
density_estimator, MixedDensityEstimator
), f"""Internal net must be of type
MixedDensityEstimator but is {type(density_estimator)}."""
return density_estimator

def build_posterior(
self,
density_estimator: Optional[TorchModule] = None,
Expand Down Expand Up @@ -128,7 +151,10 @@ def build_posterior(
), f"""net must be of type MixedDensityEstimator but is {type
(likelihood_estimator)}."""

potential_fn, theta_transform = mixed_likelihood_estimator_based_potential(
(
potential_fn,
theta_transform,
) = mixed_likelihood_estimator_based_potential(
likelihood_estimator=likelihood_estimator, prior=prior, x_o=None
)

Expand Down
6 changes: 2 additions & 4 deletions tests/linearGaussian_snre_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,7 @@ def test_api_sre_on_linearGaussian(num_dim: int, SNRE: RatioEstimator):
prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))

simulator, prior = prepare_for_sbi(diagonal_linear_gaussian, prior)
inference = SNRE(
classifier="resnet",
show_progress_bars=False,
)
inference = SNRE(classifier="resnet", show_progress_bars=False)

theta, x = simulate_for_sbi(simulator, prior, 1000, simulation_batch_size=50)
ratio_estimator = inference.append_simulations(theta, x).train(max_num_epochs=5)
Expand All @@ -70,6 +67,7 @@ def test_api_sre_on_linearGaussian(num_dim: int, SNRE: RatioEstimator):
num_chains=2,
)
posterior.sample(sample_shape=(10,))
posterior.map(num_iter=1)


@pytest.mark.parametrize("SNRE", (SNRE_B, SNRE_C))
Expand Down
165 changes: 121 additions & 44 deletions tests/mnle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@

import pytest
import torch
from numpy import isin
from pyro.distributions import InverseGamma
from torch.distributions import Beta, Binomial, Gamma
from torch.distributions import Beta, Binomial, Categorical, Gamma

from sbi.inference import (
MNLE,
MCMCPosterior,
likelihood_estimator_based_potential,
)
from sbi.inference import MNLE, MCMCPosterior, likelihood_estimator_based_potential
from sbi.inference.posteriors.rejection_posterior import RejectionPosterior
from sbi.inference.posteriors.vi_posterior import VIPosterior
from sbi.inference.potentials.base_potential import BasePotential
from sbi.inference.potentials.likelihood_based_potential import (
MixedLikelihoodBasedPotential,
Expand All @@ -22,6 +21,28 @@
from tests.test_utils import check_c2st


# toy simulator for mixed data
def mixed_simulator(theta, stimulus_condition=2.0):
# Extract parameters
beta, ps = theta[:, :1], theta[:, 1:]

# Sample choices and rts independently.
choices = Binomial(probs=ps).sample()
rts = InverseGamma(
concentration=stimulus_condition * torch.ones_like(beta), rate=beta
).sample()

return torch.cat((rts, choices), dim=1)


mcmc_kwargs = dict(
num_chains=10,
warmup_steps=100,
method="slice_np_vectorized",
init_strategy="proposal",
)


@pytest.mark.gpu
@pytest.mark.parametrize("device", ("cpu", "cuda"))
def test_mnle_on_device(device):
Expand Down Expand Up @@ -63,49 +84,38 @@ def test_mnle_api(sampler):
prior = BoxUniform(torch.zeros(2), torch.ones(2))
x_o = x[0]
# Build estimator manually.
density_estimator = likelihood_nn(model="mnle", **dict(tail_bound=2.0))
density_estimator = likelihood_nn(model="mnle")
trainer = MNLE(density_estimator=density_estimator)
mnle = trainer.append_simulations(theta, x).train(max_num_epochs=1)
trainer.append_simulations(theta, x).train(max_num_epochs=5)

# Test different samplers.
posterior = trainer.build_posterior(prior=prior, sample_with=sampler)
posterior.set_default_x(x_o)
if sampler == "vi":
posterior.train()
posterior.sample((1,), show_progress_bars=False)

# MNLE should work with the default potential as well.
potential_fn, parameter_transform = likelihood_estimator_based_potential(
mnle, prior, x_o
)
posterior = MCMCPosterior(
potential_fn,
proposal=prior,
theta_transform=parameter_transform,
init_strategy="proposal",
)
posterior.sample((1,), show_progress_bars=False)
if isinstance(posterior, VIPosterior):
posterior.train().sample((1,))
elif isinstance(posterior, RejectionPosterior):
posterior.sample((1,))
else:
posterior.sample(
(1,),
num_chains=2,
warmup_steps=1,
method="slice_np_vectorized",
init_strategy="proposal",
thin=1,
)


@pytest.mark.slow
@pytest.mark.parametrize(
"sampler",
(
"mcmc",
"rejection",
# "vi", # Failing because of transformed space dimension mismatch.
),
)
@pytest.mark.parametrize("sampler", ("mcmc", "rejection", "vi"))
def test_mnle_accuracy(sampler):
def mixed_simulator(theta):
# Extract parameters
beta, ps = theta[:, :1], theta[:, 1:]

# Sample choices and rts independently.
choices = Binomial(probs=ps).sample()
rts = InverseGamma(
concentration=2 * torch.ones_like(beta), rate=beta
).sample()
rts = InverseGamma(concentration=1 * torch.ones_like(beta), rate=beta).sample()

return torch.cat((rts, choices), dim=1)

Expand All @@ -127,13 +137,6 @@ def mixed_simulator(theta):
trainer.append_simulations(theta, x).train()
posterior = trainer.build_posterior()

mcmc_kwargs = dict(
num_chains=10,
warmup_steps=100,
method="slice_np_vectorized",
init_strategy="proposal",
)

for num_trials in [10]:
theta_o = prior.sample((1,))
x_o = mixed_simulator(theta_o.repeat(num_trials, 1))
Expand All @@ -154,7 +157,7 @@ def mixed_simulator(theta):

mnle_posterior_samples = posterior.sample(
sample_shape=(num_samples,),
show_progress_bars=False,
show_progress_bars=True,
**mcmc_kwargs if sampler == "mcmc" else {},
)

Expand All @@ -170,9 +173,11 @@ class PotentialFunctionProvider(BasePotential):

allow_iid_x = True # type: ignore

def __init__(self, prior, x_o, device="cpu"):
def __init__(self, prior, x_o, concentration_scaling=1.0, device="cpu"):
super().__init__(prior, x_o, device)

self.concentration_scaling = concentration_scaling

def __call__(self, theta, track_gradients: bool = True):
theta = atleast_2d(theta)

Expand All @@ -195,7 +200,8 @@ def iid_likelihood(self, theta: torch.Tensor) -> torch.Tensor:
lp_rts = torch.stack(
[
InverseGamma(
concentration=2 * torch.ones_like(beta_i), rate=beta_i
concentration=self.concentration_scaling * torch.ones_like(beta_i),
rate=beta_i,
).log_prob(self.x_o[:, :1])
for beta_i in theta[:, :1]
],
Expand All @@ -207,3 +213,74 @@ def iid_likelihood(self, theta: torch.Tensor) -> torch.Tensor:
)

return joint_likelihood.sum(0)


@pytest.mark.slow
def test_mnle_with_experiment_conditions():
def sim_wrapper(theta):
# simulate with experiment conditions
return mixed_simulator(theta[:, :2], theta[:, 2:] + 1)

proposal = MultipleIndependent(
[
Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
Beta(torch.tensor([2.0]), torch.tensor([2.0])),
Categorical(probs=torch.ones(1, 3)),
],
validate_args=False,
)

num_simulations = 10000
num_samples = 1000
theta = proposal.sample((num_simulations,))
x = sim_wrapper(theta)
assert x.shape == (num_simulations, 2)

num_trials = 10
theta_o = proposal.sample((1,))
theta_o[0, 2] = 2.0 # set condition to 2 as in original simulator.
x_o = sim_wrapper(theta_o.repeat(num_trials, 1))

# MNLE
trainer = MNLE(proposal)
estimator = trainer.append_simulations(theta, x).train(max_num_epochs=1)

potential_fn = MixedLikelihoodBasedPotential(estimator, proposal, x_o)

conditioned_potential_fn = ConditionedPotential(
potential_fn, condition=theta_o, dims_to_sample=[0, 1], allow_iid_x=True
)

# True posterior samples
prior = MultipleIndependent(
[
Gamma(torch.tensor([1.0]), torch.tensor([0.5])),
Beta(torch.tensor([2.0]), torch.tensor([2.0])),
],
validate_args=False,
)
prior_transform = mcmc_transform(prior)
true_posterior_samples = MCMCPosterior(
PotentialFunctionProvider(
prior,
atleast_2d(x_o),
concentration_scaling=float(theta_o[0, 2])
+ 1.0, # add one because the sim_wrapper adds one (see above)
),
theta_transform=prior_transform,
proposal=prior,
**mcmc_kwargs,
).sample((num_samples,), x=x_o)

mcmc_posterior = MCMCPosterior(
potential_fn=conditioned_potential_fn,
theta_transform=prior_transform,
proposal=prior,
)
cond_samples = mcmc_posterior.sample((num_samples,), x=x_o)

check_c2st(
cond_samples,
true_posterior_samples,
alg="MNLE with experiment conditions",
)

0 comments on commit e7f07bb

Please sign in to comment.