diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index a44a3d138..8570ac8d2 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -34,7 +34,7 @@ prepare_for_sbi, check_embedding_net_device, ) -from sbi.utils.get_nn_models import posterior_nn +from sbi.utils.get_nn_models import posterior_nn, likelihood_nn, classifier_nn @pytest.mark.slow @@ -279,18 +279,20 @@ def test_train_with_different_data_and_training_device( inference = inference_method( prior, - density_estimator=( - "mdn_snpe_a" - if inference_method == SNPE_A - else "resnet" + **( + dict(classifier="resnet") if inference_method in [SNRE_A, SNRE_B] - else "maf" + else dict( + density_estimator=( + "mdn_snpe_a" if inference_method == SNPE_A else "maf" + ) + ) ), show_progress_bars=False, device=training_device, ) - theta, x = simulate_for_sbi(simulator, prior, 5) + theta, x = simulate_for_sbi(simulator, prior, 32) theta, x = theta.to(data_device), x.to(data_device) x_o = torch.zeros(x.shape[1]) inference = inference.append_simulations(theta, x) @@ -307,12 +309,12 @@ def test_train_with_different_data_and_training_device( @pytest.mark.gpu -@pytest.mark.parametrize("inference_method", [SNPE_C]) +@pytest.mark.parametrize("inference_method", [SNPE_A, SNPE_C, SNRE_A, SNRE_B, SNLE]) @pytest.mark.parametrize("prior_device", ("cpu", "cuda")) @pytest.mark.parametrize("embedding_net_device", ("cpu", "cuda")) @pytest.mark.parametrize("data_device", ("cpu", "cuda")) @pytest.mark.parametrize("training_device", ("cpu", "cuda")) -def test_custom_density_estimator_training_device( +def test_embedding_nets_integration_training_device( inference_method: NeuralInference, prior_device: str, embedding_net_device: str, @@ -320,11 +322,11 @@ def test_custom_density_estimator_training_device( training_device: str, ) -> None: - # add warnings checks and other methods + # add other methods D_theta = 2 D_x = 3 - samples_per_round = 100 + samples_per_round = 32 num_rounds = 2 x_o = torch.ones((1, D_x)) @@ -332,25 +334,73 @@ def test_custom_density_estimator_training_device( prior = utils.BoxUniform( low=-torch.ones((D_theta,)), high=torch.ones((D_theta,)), device=prior_device ) - embedding_net = nn.Linear(in_features=D_x, out_features=2).to(embedding_net_device) - density_estimator = posterior_nn(model="maf", embedding_net=embedding_net) - with pytest.raises(Exception) if prior_device != training_device else nullcontext(): - inference = inference_method( - prior=prior, density_estimator=density_estimator, device=training_device + if inference_method in [SNRE_A, SNRE_B]: + embedding_net_theta = nn.Linear(in_features=D_theta, out_features=2).to( + embedding_net_device + ) + embedding_net_x = nn.Linear(in_features=D_x, out_features=2).to( + embedding_net_device + ) + nn_kwargs = dict( + classifier=classifier_nn( + model="resnet", + embedding_net_x=embedding_net_x, + embedding_net_theta=embedding_net_theta, + hidden_features=4, + ) + ) + elif inference_method == SNLE: + embedding_net = nn.Linear(in_features=D_theta, out_features=2).to( + embedding_net_device + ) + nn_kwargs = dict( + density_estimator=likelihood_nn( + model="maf", + embedding_net=embedding_net, + hidden_features=4, + num_transforms=2, + ) ) + else: + embedding_net = nn.Linear(in_features=D_x, out_features=2).to( + embedding_net_device + ) + nn_kwargs = dict( + density_estimator=posterior_nn( + model="mdn_snpe_a" if inference_method == SNPE_A else "maf", + embedding_net=embedding_net, + hidden_features=4, + num_transforms=2, + ) + ) + + with pytest.raises(Exception) if prior_device != training_device else nullcontext(): + inference = inference_method(prior=prior, **nn_kwargs, device=training_device) + if prior_device != training_device: - return + pytest.xfail("We do not correct the case of invalid prior device") theta = prior.sample((samples_per_round,)).to(data_device) proposal = prior - for _ in range(num_rounds): - X = MultivariateNormal(torch.zeros((D_x,)), torch.eye(D_x)).sample( - (samples_per_round,) + for round_idx in range(num_rounds): + X = ( + MultivariateNormal(torch.zeros((D_x,)), torch.eye(D_x)) + .sample((samples_per_round,)) + .to(data_device) ) - density_estimator_train = inference.append_simulations(theta, X).train() + with pytest.warns( + UserWarning + ) if data_device != training_device else nullcontext(): + density_estimator_append = inference.append_simulations(theta, X) + + with pytest.warns(UserWarning) if (round_idx == 0) and ( + embedding_net_device != training_device + ) else nullcontext(): + density_estimator_train = density_estimator_append.train(max_num_epochs=2) + posterior = inference.build_posterior(density_estimator_train) proposal = posterior.set_default_x(x_o) theta = proposal.sample((samples_per_round,))