Skip to content

Commit

Permalink
add full integration testwith embedders
Browse files Browse the repository at this point in the history
  • Loading branch information
LouisRouillard committed Jan 28, 2022
1 parent c5b94ab commit 593d0cd
Showing 1 changed file with 71 additions and 21 deletions.
92 changes: 71 additions & 21 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
prepare_for_sbi,
check_embedding_net_device,
)
from sbi.utils.get_nn_models import posterior_nn
from sbi.utils.get_nn_models import posterior_nn, likelihood_nn, classifier_nn


@pytest.mark.slow
Expand Down Expand Up @@ -279,18 +279,20 @@ def test_train_with_different_data_and_training_device(

inference = inference_method(
prior,
density_estimator=(
"mdn_snpe_a"
if inference_method == SNPE_A
else "resnet"
**(
dict(classifier="resnet")
if inference_method in [SNRE_A, SNRE_B]
else "maf"
else dict(
density_estimator=(
"mdn_snpe_a" if inference_method == SNPE_A else "maf"
)
)
),
show_progress_bars=False,
device=training_device,
)

theta, x = simulate_for_sbi(simulator, prior, 5)
theta, x = simulate_for_sbi(simulator, prior, 32)
theta, x = theta.to(data_device), x.to(data_device)
x_o = torch.zeros(x.shape[1])
inference = inference.append_simulations(theta, x)
Expand All @@ -307,50 +309,98 @@ def test_train_with_different_data_and_training_device(


@pytest.mark.gpu
@pytest.mark.parametrize("inference_method", [SNPE_C])
@pytest.mark.parametrize("inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNLE])
@pytest.mark.parametrize("prior_device", ("cpu", "cuda"))
@pytest.mark.parametrize("embedding_net_device", ("cpu", "cuda"))
@pytest.mark.parametrize("data_device", ("cpu", "cuda"))
@pytest.mark.parametrize("training_device", ("cpu", "cuda"))
def test_custom_density_estimator_training_device(
def test_embedding_nets_integration_training_device(
inference_method: NeuralInference,
prior_device: str,
embedding_net_device: str,
data_device: str,
training_device: str,
) -> None:

# add warnings checks and other methods
# add other methods

D_theta = 2
D_x = 3
samples_per_round = 100
samples_per_round = 32
num_rounds = 2

x_o = torch.ones((1, D_x))

prior = utils.BoxUniform(
low=-torch.ones((D_theta,)), high=torch.ones((D_theta,)), device=prior_device
)
embedding_net = nn.Linear(in_features=D_x, out_features=2).to(embedding_net_device)
density_estimator = posterior_nn(model="maf", embedding_net=embedding_net)

with pytest.raises(Exception) if prior_device != training_device else nullcontext():
inference = inference_method(
prior=prior, density_estimator=density_estimator, device=training_device
if inference_method in [SNRE_A, SNRE_B]:
embedding_net_theta = nn.Linear(in_features=D_theta, out_features=2).to(
embedding_net_device
)
embedding_net_x = nn.Linear(in_features=D_x, out_features=2).to(
embedding_net_device
)
nn_kwargs = dict(
classifier=classifier_nn(
model="resnet",
embedding_net_x=embedding_net_x,
embedding_net_theta=embedding_net_theta,
hidden_features=4,
)
)
elif inference_method == SNLE:
embedding_net = nn.Linear(in_features=D_theta, out_features=2).to(
embedding_net_device
)
nn_kwargs = dict(
density_estimator=likelihood_nn(
model="maf",
embedding_net=embedding_net,
hidden_features=4,
num_transforms=2,
)
)
else:
embedding_net = nn.Linear(in_features=D_x, out_features=2).to(
embedding_net_device
)
nn_kwargs = dict(
density_estimator=posterior_nn(
model="mdn_snpe_a" if inference_method == SNPE_A else "maf",
embedding_net=embedding_net,
hidden_features=4,
num_transforms=2,
)
)

with pytest.raises(Exception) if prior_device != training_device else nullcontext():
inference = inference_method(prior=prior, **nn_kwargs, device=training_device)

if prior_device != training_device:
return
pytest.xfail("We do not correct the case of invalid prior device")

theta = prior.sample((samples_per_round,)).to(data_device)

proposal = prior
for _ in range(num_rounds):
X = MultivariateNormal(torch.zeros((D_x,)), torch.eye(D_x)).sample(
(samples_per_round,)
for round_idx in range(num_rounds):
X = (
MultivariateNormal(torch.zeros((D_x,)), torch.eye(D_x))
.sample((samples_per_round,))
.to(data_device)
)

density_estimator_train = inference.append_simulations(theta, X).train()
with pytest.warns(
UserWarning
) if data_device != training_device else nullcontext():
density_estimator_append = inference.append_simulations(theta, X)

with pytest.warns(UserWarning) if (round_idx == 0) and (
embedding_net_device != training_device
) else nullcontext():
density_estimator_train = density_estimator_append.train(max_num_epochs=2)

posterior = inference.build_posterior(density_estimator_train)
proposal = posterior.set_default_x(x_o)
theta = proposal.sample((samples_per_round,))

0 comments on commit 593d0cd

Please sign in to comment.