Skip to content

Commit

Permalink
Reorganised tests, introduced support_is_bounded
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Jan 28, 2022
1 parent c01208f commit 8fd59c5
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 49 deletions.
18 changes: 17 additions & 1 deletion sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,8 +517,23 @@ def mcmc_transform(
)
has_support = False

# Prior with bounded support, e.g., uniform priors.
# If the distribution has a `support`, check if the support is bounded.
# If it is not bounded, we want to z-score the space. This is not done
# by `biject_to()`, so we have to deal with this case separately.
if has_support:
if hasattr(prior.support, "base_constraint"):
constraint = prior.support.base_constraint
else:
constraint = prior.support
if isinstance(constraint, constraints._Real):
support_is_bounded = False
else:
support_is_bounded = True
else:
support_is_bounded = False

# Prior with bounded support, e.g., uniform priors.
if has_support and support_is_bounded:
transform = biject_to(prior.support)
# For all other cases build affine transform with mean and std.
else:
Expand All @@ -531,6 +546,7 @@ def mcmc_transform(
prior_std = theta.std(dim=0).to(device)

transform = torch_tf.AffineTransform(loc=prior_mean, scale=prior_std)
print(transform, prior_mean)
else:
transform = torch_tf.identity_transform

Expand Down
2 changes: 1 addition & 1 deletion sbi/utils/user_input_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def check_prior_support(prior):
"""

try:
within_support(prior, prior.sample())
within_support(prior, prior.sample((1,)))
except NotImplementedError:
raise NotImplementedError(
"""The prior must implement the support property or allow to call
Expand Down
22 changes: 20 additions & 2 deletions sbi/utils/user_input_checks_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from torch import Tensor, float32
from torch.distributions import Distribution, constraints
from torch.distributions import Distribution, Independent, biject_to, constraints
from torch.distributions.constraints import cat, independent


class CustomPriorWrapper(Distribution):
Expand Down Expand Up @@ -325,8 +326,25 @@ def variance(self) -> Tensor:

@property
def support(self):
# return independent constraints for each distribution.
return constraints.cat([d.support for d in self.dists], dim=1)
# First, we remove all `independent` constraints. This applies to e.g.
# `MultivariateNormal`. An `independent` constraint returns a 1D `[True]`
# when `.support.check(sample)` is called, whereas distributions that are
# not `independent` (e.g. `Gamma`), return a 2D `[[True]]`. When such
# constraints would be combined with the `constraint.cat(..., dim=1)`, it
# fails because the `independent` constraint returned only a 1D `[True]`.
supports = []
for d in self.dists:
if isinstance(d.support, independent):
supports.append(d.support.base_constraint)
else:
supports.append(d.support)

# Wrap as `independent` in order to have the correct shape of the
# `log_abs_det`, i.e. summed over the parameter dimensions.
return independent(
cat(supports, dim=1, lengths=self.dims_per_dist),
reinterpreted_batch_ndims=1,
)


def build_support(
Expand Down
37 changes: 1 addition & 36 deletions tests/sbiutils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
import pytest
import torch
from torch import Tensor, eye, ones, zeros
from torch.distributions import Exponential, LogNormal, MultivariateNormal
from torch.distributions import MultivariateNormal
from torch.distributions.transforms import IndependentTransform, identity_transform

from sbi.inference import SNPE, SNPE_A
from sbi.inference.snpe.snpe_a import SNPE_A_MDN
from sbi.utils import (
BoxUniform,
MultipleIndependent,
get_kde,
mcmc_transform,
posterior_nn,
Expand Down Expand Up @@ -363,40 +362,6 @@ def log_prob(self, theta):
plt.show()


@pytest.mark.parametrize(
"prior, enable_transform",
(
(BoxUniform(zeros(5), ones(5)), True),
(BoxUniform(zeros(1), ones(1)), True),
(BoxUniform(zeros(5), ones(5)), False),
(MultivariateNormal(zeros(5), eye(5)), True),
(Exponential(rate=ones(1)), True),
(LogNormal(zeros(1), ones(1)), True),
(
MultipleIndependent(
[Exponential(rate=ones(1)), BoxUniform(zeros(5), ones(5))]
),
True,
),
),
)
def test_mcmc_transform(prior, enable_transform):
"""
Test whether the transform for MCMC returns the log_abs_det in the correct shape.
"""

num_samples = 1000
prior, _, _ = process_prior(prior)
tf = mcmc_transform(prior, enable_transform=enable_transform)

samples = prior.sample((num_samples,))
unconstrained_samples = tf(samples)
samples_original = tf.inv(unconstrained_samples)

log_abs_det = tf.log_abs_det_jacobian(samples_original, unconstrained_samples)
assert log_abs_det.shape == torch.Size([num_samples])


@pytest.mark.parametrize(
"transform",
(
Expand Down
58 changes: 49 additions & 9 deletions tests/transforms_test.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
from turtle import pd
import pytest
import torch

from torch.distributions import Uniform, MultivariateNormal, LogNormal
from torch import ones, zeros, eye
from torch.distributions import Uniform, MultivariateNormal, LogNormal, Exponential
from torch.distributions.transforms import (
AffineTransform,
ComposeTransform,
ExpTransform,
IndependentTransform,
SigmoidTransform,
)

from sbi.utils import BoxUniform, mcmc_transform, process_prior
from sbi.utils import BoxUniform, mcmc_transform, process_prior, MultipleIndependent
from tests.user_input_checks_test import UserNumpyUniform


@pytest.mark.parametrize(
"prior, target_transform",
(
(Uniform(-torch.ones(1), torch.ones(1)), ComposeTransform),
(BoxUniform(-torch.ones(2), torch.ones(2)), ComposeTransform),
(UserNumpyUniform(torch.zeros(2), torch.ones(2)), ComposeTransform),
(Uniform(-torch.ones(1), torch.ones(1)), SigmoidTransform),
(BoxUniform(-torch.ones(2), torch.ones(2)), SigmoidTransform),
(UserNumpyUniform(torch.zeros(2), torch.ones(2)), SigmoidTransform),
(MultivariateNormal(torch.zeros(2), torch.eye(2)), AffineTransform),
(LogNormal(loc=torch.zeros(1), scale=torch.ones(1)), AffineTransform),
(LogNormal(loc=torch.zeros(1), scale=torch.ones(1)), ExpTransform),
),
)
def test_transforms(prior, target_transform):
Expand All @@ -37,8 +38,47 @@ def test_transforms(prior, target_transform):
if isinstance(core_transform, IndependentTransform):
core_transform = core_transform.base_transform

assert isinstance(core_transform, target_transform)
if hasattr(core_transform, "parts"):
transform_to_inspect = core_transform.parts[0]
else:
transform_to_inspect = core_transform

assert isinstance(transform_to_inspect, target_transform)

samples = prior.sample((2,))
transformed_samples = transform(samples)
assert torch.allclose(samples, transform.inv(transformed_samples))


@pytest.mark.parametrize(
"prior, enable_transform",
(
(BoxUniform(zeros(5), ones(5)), True),
(BoxUniform(zeros(1), ones(1)), True),
(BoxUniform(zeros(5), ones(5)), False),
(MultivariateNormal(zeros(5), eye(5)), True),
(Exponential(rate=ones(1)), True),
(LogNormal(zeros(1), ones(1)), True),
(
MultipleIndependent(
[Exponential(rate=ones(1)), BoxUniform(zeros(5), ones(5))]
),
True,
),
),
)
def test_mcmc_transform(prior, enable_transform):
"""
Test whether the transform for MCMC returns the log_abs_det in the correct shape.
"""

num_samples = 1000
prior, _, _ = process_prior(prior)
tf = mcmc_transform(prior, enable_transform=enable_transform)

samples = prior.sample((num_samples,))
unconstrained_samples = tf(samples)
samples_original = tf.inv(unconstrained_samples)

log_abs_det = tf.log_abs_det_jacobian(samples_original, unconstrained_samples)
assert log_abs_det.shape == torch.Size([num_samples])

0 comments on commit 8fd59c5

Please sign in to comment.