diff --git a/sbi/neural_nets/flow.py b/sbi/neural_nets/flow.py index 03ca1dda6..33a9443aa 100644 --- a/sbi/neural_nets/flow.py +++ b/sbi/neural_nets/flow.py @@ -6,6 +6,7 @@ from typing import Optional from warnings import warn +import torch from pyknos.nflows import distributions as distributions_ from pyknos.nflows import flows, transforms from pyknos.nflows.nn import nets @@ -180,7 +181,7 @@ def build_maf( # Combine transforms. transform = transforms.CompositeTransform(transform_list) - distribution = distributions_.StandardNormal((x_numel,)) + distribution = get_base_dist(x_numel, **kwargs) neural_net = flows.Flow(transform, distribution, embedding_net) return neural_net @@ -293,7 +294,7 @@ def build_maf_rqs( # Combine transforms. transform = transforms.CompositeTransform(transform_list) - distribution = distributions_.StandardNormal((x_numel,)) + distribution = get_base_dist(x_numel, **kwargs) neural_net = flows.Flow(transform, distribution, embedding_net) return neural_net @@ -411,7 +412,7 @@ def mask_in_layer(i): standardizing_net(batch_y, structured_y), embedding_net ) - distribution = distributions_.StandardNormal((x_numel,)) + distribution = get_base_dist(x_numel, **kwargs) # Combine transforms. transform = transforms.CompositeTransform(transform_list) @@ -480,3 +481,13 @@ def __call__(self, inputs: Tensor, context: Tensor, *args, **kwargs) -> Tensor: Spline parameters. """ return self.spline_predictor(context) + + +def get_base_dist( + num_dims: int, dtype: torch.dtype = torch.float32, **kwargs +) -> distributions_.Distribution: + """Returns the base distribution for a flow with given float dtype.""" + + base = distributions_.StandardNormal((num_dims,)) + base._log_z = base._log_z.to(dtype) + return base diff --git a/tests/analysis_test.py b/tests/analysis_test.py index 13f1c1351..b69e654c7 100644 --- a/tests/analysis_test.py +++ b/tests/analysis_test.py @@ -4,11 +4,12 @@ from sbi.analysis import ActiveSubspace, conditional_corrcoeff, conditional_pairplot from sbi.inference import SNPE from sbi.utils import BoxUniform, get_1d_marginal_peaks_from_kde +from sbi.utils.torchutils import process_device @pytest.mark.slow @pytest.mark.gpu -@pytest.mark.parametrize("device", ["cpu", "cuda:0"]) +@pytest.mark.parametrize("device", ["cpu", "gpu"]) def test_analysis_modules(device: str) -> None: """Tests sensitivity analysis and conditional posterior utils on GPU and CPU. @@ -18,6 +19,7 @@ def test_analysis_modules(device: str) -> None: device: Which device to run the inference on. """ num_dim = 3 + device = process_device(device) prior = BoxUniform( low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim), device=device ) diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index 531e0b17a..41f673dea 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -4,7 +4,7 @@ from __future__ import annotations from contextlib import nullcontext -from typing import Optional, Tuple +from typing import Tuple import pytest import torch @@ -110,12 +110,16 @@ def simulator(theta): if method in [SNPE_A, SNPE_C]: kwargs = dict( - density_estimator=utils.posterior_nn(model=model, num_transforms=2) + density_estimator=utils.posterior_nn( + model=model, num_transforms=2, dtype=torch.float32 + ) ) train_kwargs = dict(force_first_round_loss=True) elif method == SNLE: kwargs = dict( - density_estimator=utils.likelihood_nn(model=model, num_transforms=2) + density_estimator=utils.likelihood_nn( + model=model, num_transforms=2, dtype=torch.float32 + ) ) train_kwargs = dict() elif method in (SNRE_A, SNRE_B, SNRE_C): @@ -145,7 +149,7 @@ def simulator(theta): sample_with="mcmc", mcmc_method=sampling_method, mcmc_parameters=dict( - thin=5, + thin=10 if sampling_method == "slice_np_vectorized" else 1, num_chains=10 if sampling_method == "slice_np_vectorized" else 1, ), ) @@ -183,27 +187,6 @@ def simulator(theta): proposals[-1].potential(samples) -@pytest.mark.gpu -@pytest.mark.parametrize( - "device_input, device_target", - [ - ("cpu", "cpu"), - ("cuda", "cuda:0"), - ("cuda:0", "cuda:0"), - pytest.param("cuda:42", None, marks=pytest.mark.xfail), - pytest.param("qwerty", None, marks=pytest.mark.xfail), - ], -) -def test_process_device(device_input: str, device_target: Optional[str]) -> None: - """Test process_device with different device combinations.""" - device_output = process_device(device_input) - assert device_output == device_target, ( - f"Failure when processing device '{device_input}': " - f"result should have been '{device_target}' and is " - f"instead '{device_output}'" - ) - - @pytest.mark.gpu @pytest.mark.parametrize("device_datum", ["cpu", "gpu"]) @pytest.mark.parametrize("device_embedding_net", ["cpu", "gpu"]) @@ -359,7 +342,7 @@ def test_embedding_nets_integration_training_device( data_device = process_device(data_device) training_device = process_device(training_device) - samples_per_round = 32 + samples_per_round = 64 num_rounds = 2 x_o = torch.ones((1, x_dim)) @@ -447,8 +430,12 @@ def test_embedding_nets_integration_training_device( posterior = inference.build_posterior( density_estimator_train, - mcmc_method="slice_np_vectorized", - mcmc_parameters=dict(thin=5, num_chains=10, warmup_steps=10), + **{} + if inference_method == SNPE_A + else dict( + mcmc_method="slice_np_vectorized", + mcmc_parameters=dict(thin=10, num_chains=20, warmup_steps=10), + ), ) proposal = posterior.set_default_x(x_o) diff --git a/tests/linearGaussian_snle_test.py b/tests/linearGaussian_snle_test.py index 3a51d5283..92f440e32 100644 --- a/tests/linearGaussian_snle_test.py +++ b/tests/linearGaussian_snle_test.py @@ -71,7 +71,7 @@ def test_api_snle_multiple_trials_and_rounds_map(num_dim: int, prior_str: str): x_o = zeros((num_trials, num_dim)) posterior = inference.build_posterior( mcmc_method="slice_np_vectorized", - mcmc_parameters=dict(num_chains=10, thin=5, warmup_steps=10), + mcmc_parameters=dict(num_chains=10, thin=10, warmup_steps=10), ).set_default_x(x_o) posterior.sample(sample_shape=(num_samples,)) proposals.append(posterior) diff --git a/tests/linearGaussian_snpe_test.py b/tests/linearGaussian_snpe_test.py index bb05c6f5f..527c41265 100644 --- a/tests/linearGaussian_snpe_test.py +++ b/tests/linearGaussian_snpe_test.py @@ -573,7 +573,7 @@ def simulator(theta): proposal=restricted_prior, **mcmc_parameters, ) - cond_samples = mcmc_posterior.sample((num_conditional_samples,)) + cond_samples = mcmc_posterior.sample((num_conditional_samples,), x=x_o) _ = analysis.pairplot( cond_samples, diff --git a/tests/mnle_test.py b/tests/mnle_test.py index 92cbad7dc..b2af74793 100644 --- a/tests/mnle_test.py +++ b/tests/mnle_test.py @@ -15,7 +15,7 @@ ) from sbi.utils import BoxUniform, likelihood_nn, mcmc_transform from sbi.utils.conditional_density_utils import ConditionedPotential -from sbi.utils.torchutils import atleast_2d +from sbi.utils.torchutils import atleast_2d, process_device from sbi.utils.user_input_checks_utils import MultipleIndependent from tests.test_utils import check_c2st @@ -46,9 +46,11 @@ def mixed_simulator(theta, stimulus_condition=2.0): @pytest.mark.gpu -@pytest.mark.parametrize("device", ("cpu", "cuda")) +@pytest.mark.parametrize("device", ("cpu", "gpu")) def test_mnle_on_device(device): """Test MNLE API on device.""" + + device = process_device(device) # Generate mixed data. num_simulations = 100 mcmc_method = "slice"