Skip to content

Commit

Permalink
fix slow tests and add missing MPS features.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Mar 7, 2024
1 parent 3c80e0b commit 0df847d
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 36 deletions.
17 changes: 14 additions & 3 deletions sbi/neural_nets/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Optional
from warnings import warn

import torch
from pyknos.nflows import distributions as distributions_
from pyknos.nflows import flows, transforms
from pyknos.nflows.nn import nets
Expand Down Expand Up @@ -180,7 +181,7 @@ def build_maf(
# Combine transforms.
transform = transforms.CompositeTransform(transform_list)

distribution = distributions_.StandardNormal((x_numel,))
distribution = get_base_dist(x_numel, **kwargs)
neural_net = flows.Flow(transform, distribution, embedding_net)

return neural_net
Expand Down Expand Up @@ -293,7 +294,7 @@ def build_maf_rqs(
# Combine transforms.
transform = transforms.CompositeTransform(transform_list)

distribution = distributions_.StandardNormal((x_numel,))
distribution = get_base_dist(x_numel, **kwargs)

Check warning on line 297 in sbi/neural_nets/flow.py

View check run for this annotation

Codecov / codecov/patch

sbi/neural_nets/flow.py#L297

Added line #L297 was not covered by tests
neural_net = flows.Flow(transform, distribution, embedding_net)

return neural_net
Expand Down Expand Up @@ -411,7 +412,7 @@ def mask_in_layer(i):
standardizing_net(batch_y, structured_y), embedding_net
)

distribution = distributions_.StandardNormal((x_numel,))
distribution = get_base_dist(x_numel, **kwargs)

# Combine transforms.
transform = transforms.CompositeTransform(transform_list)
Expand Down Expand Up @@ -480,3 +481,13 @@ def __call__(self, inputs: Tensor, context: Tensor, *args, **kwargs) -> Tensor:
Spline parameters.
"""
return self.spline_predictor(context)


def get_base_dist(
    num_dims: int, dtype: torch.dtype = torch.float32, **kwargs
) -> distributions_.Distribution:
    """Construct the standard-normal base distribution for a flow.

    Args:
        num_dims: Dimensionality of the distribution's event space.
        dtype: Float dtype to cast the distribution's cached log-normalizer
            to. NOTE(review): presumably needed for devices without float64
            support (e.g. MPS, per the commit context) — confirm.
        kwargs: Ignored; absorbs extra keyword arguments forwarded from the
            flow builder functions.

    Returns:
        A ``StandardNormal`` over ``num_dims`` dimensions whose ``_log_z``
        attribute has been cast to ``dtype``.
    """
    dist = distributions_.StandardNormal((num_dims,))
    # ``_log_z`` is a private pyknos attribute; casting it keeps the
    # log-prob normalization constant in the requested precision.
    dist._log_z = dist._log_z.to(dtype)
    return dist
4 changes: 3 additions & 1 deletion tests/analysis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from sbi.analysis import ActiveSubspace, conditional_corrcoeff, conditional_pairplot
from sbi.inference import SNPE
from sbi.utils import BoxUniform, get_1d_marginal_peaks_from_kde
from sbi.utils.torchutils import process_device


@pytest.mark.slow
@pytest.mark.gpu
@pytest.mark.parametrize("device", ["cpu", "cuda:0"])
@pytest.mark.parametrize("device", ["cpu", "gpu"])
def test_analysis_modules(device: str) -> None:
"""Tests sensitivity analysis and conditional posterior utils on GPU and CPU.
Expand All @@ -18,6 +19,7 @@ def test_analysis_modules(device: str) -> None:
device: Which device to run the inference on.
"""
num_dim = 3
device = process_device(device)
prior = BoxUniform(
low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim), device=device
)
Expand Down
43 changes: 15 additions & 28 deletions tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from contextlib import nullcontext
from typing import Optional, Tuple
from typing import Tuple

import pytest
import torch
Expand Down Expand Up @@ -110,12 +110,16 @@ def simulator(theta):

if method in [SNPE_A, SNPE_C]:
kwargs = dict(
density_estimator=utils.posterior_nn(model=model, num_transforms=2)
density_estimator=utils.posterior_nn(
model=model, num_transforms=2, dtype=torch.float32
)
)
train_kwargs = dict(force_first_round_loss=True)
elif method == SNLE:
kwargs = dict(
density_estimator=utils.likelihood_nn(model=model, num_transforms=2)
density_estimator=utils.likelihood_nn(
model=model, num_transforms=2, dtype=torch.float32
)
)
train_kwargs = dict()
elif method in (SNRE_A, SNRE_B, SNRE_C):
Expand Down Expand Up @@ -145,7 +149,7 @@ def simulator(theta):
sample_with="mcmc",
mcmc_method=sampling_method,
mcmc_parameters=dict(
thin=5,
thin=10 if sampling_method == "slice_np_vectorized" else 1,
num_chains=10 if sampling_method == "slice_np_vectorized" else 1,
),
)
Expand Down Expand Up @@ -183,27 +187,6 @@ def simulator(theta):
proposals[-1].potential(samples)


@pytest.mark.gpu
@pytest.mark.parametrize(
"device_input, device_target",
[
("cpu", "cpu"),
("cuda", "cuda:0"),
("cuda:0", "cuda:0"),
pytest.param("cuda:42", None, marks=pytest.mark.xfail),
pytest.param("qwerty", None, marks=pytest.mark.xfail),
],
)
def test_process_device(device_input: str, device_target: Optional[str]) -> None:
"""Test process_device with different device combinations."""
device_output = process_device(device_input)
assert device_output == device_target, (
f"Failure when processing device '{device_input}': "
f"result should have been '{device_target}' and is "
f"instead '{device_output}'"
)


@pytest.mark.gpu
@pytest.mark.parametrize("device_datum", ["cpu", "gpu"])
@pytest.mark.parametrize("device_embedding_net", ["cpu", "gpu"])
Expand Down Expand Up @@ -359,7 +342,7 @@ def test_embedding_nets_integration_training_device(
data_device = process_device(data_device)
training_device = process_device(training_device)

samples_per_round = 32
samples_per_round = 64
num_rounds = 2

x_o = torch.ones((1, x_dim))
Expand Down Expand Up @@ -447,8 +430,12 @@ def test_embedding_nets_integration_training_device(

posterior = inference.build_posterior(
density_estimator_train,
mcmc_method="slice_np_vectorized",
mcmc_parameters=dict(thin=5, num_chains=10, warmup_steps=10),
**{}
if inference_method == SNPE_A
else dict(
mcmc_method="slice_np_vectorized",
mcmc_parameters=dict(thin=10, num_chains=20, warmup_steps=10),
),
)
proposal = posterior.set_default_x(x_o)

Expand Down
2 changes: 1 addition & 1 deletion tests/linearGaussian_snle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_api_snle_multiple_trials_and_rounds_map(num_dim: int, prior_str: str):
x_o = zeros((num_trials, num_dim))
posterior = inference.build_posterior(
mcmc_method="slice_np_vectorized",
mcmc_parameters=dict(num_chains=10, thin=5, warmup_steps=10),
mcmc_parameters=dict(num_chains=10, thin=10, warmup_steps=10),
).set_default_x(x_o)
posterior.sample(sample_shape=(num_samples,))
proposals.append(posterior)
Expand Down
2 changes: 1 addition & 1 deletion tests/linearGaussian_snpe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ def simulator(theta):
proposal=restricted_prior,
**mcmc_parameters,
)
cond_samples = mcmc_posterior.sample((num_conditional_samples,))
cond_samples = mcmc_posterior.sample((num_conditional_samples,), x=x_o)

_ = analysis.pairplot(
cond_samples,
Expand Down
6 changes: 4 additions & 2 deletions tests/mnle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from sbi.utils import BoxUniform, likelihood_nn, mcmc_transform
from sbi.utils.conditional_density_utils import ConditionedPotential
from sbi.utils.torchutils import atleast_2d
from sbi.utils.torchutils import atleast_2d, process_device
from sbi.utils.user_input_checks_utils import MultipleIndependent
from tests.test_utils import check_c2st

Expand Down Expand Up @@ -46,9 +46,11 @@ def mixed_simulator(theta, stimulus_condition=2.0):


@pytest.mark.gpu
@pytest.mark.parametrize("device", ("cpu", "cuda"))
@pytest.mark.parametrize("device", ("cpu", "gpu"))
def test_mnle_on_device(device):
"""Test MNLE API on device."""

device = process_device(device)
# Generate mixed data.
num_simulations = 100
mcmc_method = "slice"
Expand Down

0 comments on commit 0df847d

Please sign in to comment.