pytorch distribution.support.check() fails for TransformedDistributions #738

Open
rdgao opened this issue Sep 5, 2022 · 3 comments
Labels
blocked Something is in the way of fixing this. Refer to it in the issue bug Something isn't working

Comments

@rdgao
Contributor

rdgao commented Sep 5, 2022

This is not really an sbi bug but a PyTorch one; still, it is good to be aware of when using custom prior distributions:

In PyTorch's TransformedDistribution, the support is computed as support = self.transforms[-1].codomain. This ignores any constraints of the base distribution and of intermediate transformations, which is problematic if the base distribution is bounded, e.g., Uniform. As a result, posterior samples that lie outside the prior bounds are accepted when checked with sbiutils.within_support(). This is a known issue in PyTorch.
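To make this concrete, here is a minimal sketch, assuming a recent PyTorch; the variable names are only illustrative. It checks that the reported support is literally the last transform's codomain, and that a value it accepts maps back to a point the base Uniform rejects:

```python
import torch
from torch.distributions import ExpTransform, TransformedDistribution, Uniform

base = Uniform(torch.zeros(1), torch.ones(1))
dist = TransformedDistribution(base, ExpTransform())

# The support reported by the TransformedDistribution is the codomain of the
# last transform (here: constraints.positive), not exp([0, 1]) = [1, e].
same_object = dist.support is dist.transforms[-1].codomain

x = torch.tensor([100.0])
outer_ok = bool(dist.support.check(x).all())        # 100 > 0, so accepted
inner_ok = bool(base.support.check(x.log()).all())  # log(100) ~ 4.6 is not in [0, 1]
print(same_object, outer_ok, inner_ok)
```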

To fix it, one can either:

  1. manually reassign transforms[-1].codomain to, e.g., the transformed bounds of the Uniform distribution (hacky but fast), or
  2. in sbiutils.within_support(), call .log_prob() and flag the samples for which it raises an error, which happens when a sample lies outside the distribution's support (works in general, but ugly and slow).
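Option 2 could look roughly like the sketch below. The helper within_support_via_log_prob is hypothetical (it is not sbi's within_support()), and it assumes argument validation is enabled, so validate_args=True is passed explicitly: only with validation on does log_prob() raise ValueError for out-of-support values.

```python
import torch
from torch.distributions import Distribution, ExpTransform, TransformedDistribution, Uniform


def within_support_via_log_prob(dist: Distribution, samples: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: True for each sample whose log_prob() call does not raise."""
    mask = torch.zeros(samples.shape[0], dtype=torch.bool)
    # One call per sample: a single batched call would raise if *any* sample
    # is invalid, without telling us which one. This loop is why it is slow.
    for i, sample in enumerate(samples):
        try:
            dist.log_prob(sample)
        except ValueError:
            continue  # raised by the base Uniform's sample validation
        mask[i] = True
    return mask


base = Uniform(torch.zeros(1), torch.ones(1), validate_args=True)
dist = TransformedDistribution(base, ExpTransform(), validate_args=True)

samples = torch.tensor([[2.0], [100.0], [0.5]])  # exp maps [0, 1] to [1, e]
print(within_support_via_log_prob(dist, samples).tolist())  # [True, False, False]
```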

@michaeldeistler anything else to add?

@michaeldeistler
Contributor

Thanks for creating this!

Just to reiterate: this error will occur, e.g., during .train() of SNPE-C (second round), when prior.log_prob() is called. It raises an error because it is evaluated on samples outside the prior support, which within_support() failed to reject.
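A minimal sketch of this mismatch, using a prior built like the reproduction below (validate_args=True is set explicitly so the result does not depend on the global default): the support check accepts theta, while log_prob() raises ValueError from the base Uniform's own validation.

```python
import torch
from torch.distributions import AffineTransform, TransformedDistribution, Uniform

prior = TransformedDistribution(
    Uniform(torch.zeros(1), torch.ones(1), validate_args=True),
    AffineTransform(torch.zeros(1), torch.ones(1)),  # identity map
    validate_args=True,
)
theta = 100 * torch.ones(1)  # far outside the Uniform(0, 1) bounds

# The support is AffineTransform's codomain, i.e. all reals, so theta passes.
accepted = bool(prior.support.check(theta).all())
try:
    prior.log_prob(theta)
    raised = False
except ValueError:
    raised = True  # the base Uniform rejects the pulled-back value
print(accepted, raised)
```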

@michaeldeistler
Contributor

michaeldeistler commented Sep 5, 2022

And to reproduce:

```python
from torch import ones, zeros
from torch.distributions import AffineTransform, TransformedDistribution, Uniform

base = Uniform(zeros(1), ones(1))
dist = TransformedDistribution(base, AffineTransform(zeros(1), ones(1)))
dist.support.check(100*ones(1))
# -> returns tensor([True])
```
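A further workaround, which is not one of the two options listed in this thread, is to pull a value back through the inverse transforms and check it against the base distribution's support. A minimal sketch follows; check_via_base_support is a hypothetical helper, and it assumes elementwise transforms (event_dim 0) with a usable inv(), such as AffineTransform and ExpTransform:

```python
import torch
from torch.distributions import AffineTransform, TransformedDistribution, Uniform


def check_via_base_support(dist: TransformedDistribution, value: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: check `value` against the last codomain, every
    intermediate domain, and finally the base distribution's support."""
    ok = dist.support.check(value)  # codomain of the last transform
    for transform in reversed(dist.transforms):
        value = transform.inv(value)  # pull back one step
        ok = ok & transform.domain.check(value)
    return ok & dist.base_dist.support.check(value)


base = Uniform(torch.zeros(1), torch.ones(1))
dist = TransformedDistribution(base, AffineTransform(torch.zeros(1), torch.ones(1)))
print(dist.support.check(100 * torch.ones(1)).tolist())            # [True]
print(check_via_base_support(dist, 100 * torch.ones(1)).tolist())  # [False]
print(check_via_base_support(dist, 0.5 * torch.ones(1)).tolist())  # [True]
```

Unlike option 2, this stays vectorized, because constraint checks return a boolean mask instead of raising.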

@rdgao
Contributor Author

rdgao commented Sep 5, 2022

Interestingly, if you try to set the codomain as suggested above (option 1), via

```python
dist.transforms[-1].codomain = interval(0, 1)
```

it throws an error, perhaps appropriately: AttributeError: can't set attribute
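The AttributeError is expected: AffineTransform defines codomain as a read-only property, whereas ExpTransform defines it as a plain class attribute. When the codomain cannot be overwritten, one alternative (not proposed in this thread) is to subclass TransformedDistribution and report an explicit support. FixedSupportTransformedDistribution is a hypothetical name, and the caller must supply the correct transformed bounds:

```python
import torch
from torch.distributions import AffineTransform, TransformedDistribution, Uniform, constraints


class FixedSupportTransformedDistribution(TransformedDistribution):
    """Hypothetical subclass: reports a user-supplied support instead of
    transforms[-1].codomain."""

    def __init__(self, base_distribution, transforms, support, validate_args=None):
        self._fixed_support = support  # set before the parent __init__ runs
        super().__init__(base_distribution, transforms, validate_args=validate_args)

    @property
    def support(self):
        return self._fixed_support


base = Uniform(torch.zeros(1), torch.ones(1))
# AffineTransform(loc=2, scale=3) maps [0, 1] to [2, 5].
dist = FixedSupportTransformedDistribution(
    base,
    AffineTransform(2 * torch.ones(1), 3 * torch.ones(1)),
    support=constraints.interval(2.0, 5.0),
)
print(dist.support.check(100 * torch.ones(1)).tolist())  # [False]
print(dist.support.check(3 * torch.ones(1)).tolist())    # [True]
```

Because this subclass overrides __init__, the inherited expand() raises NotImplementedError; a complete version would also need to implement expand().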

Another example, where setting the codomain is possible:

```python
from torch import ones, zeros
from torch.distributions import ExpTransform, TransformedDistribution, Uniform
from torch.distributions.constraints import interval

base = Uniform(zeros(1), ones(1))
dist = TransformedDistribution(base, ExpTransform())
# True: the support is just ExpTransform's codomain (positive reals)
print(dist.support.check(100*ones(1)))

# ExpTransform's codomain is a plain attribute, so it can be overwritten
dist.transforms[-1].codomain = interval(0, 1)
# False: 100 is outside the overridden codomain
print(dist.support.check(100*ones(1)))
```
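Option 1 can be made less ad hoc by computing the transformed bounds instead of hard-coding them; note that exp maps [0, 1] to [1, e], so interval(0, 1) in the snippet above is only a stand-in. A minimal sketch with a hypothetical helper name, assuming a Uniform base, elementwise monotone transforms, and a last transform whose codomain is writable (true for ExpTransform, not for AffineTransform):

```python
import torch
from torch.distributions import ExpTransform, TransformedDistribution, Uniform, constraints


def set_codomain_to_transformed_bounds(dist: TransformedDistribution) -> None:
    """Hypothetical helper: push the Uniform bounds through every transform and
    store the result as the last transform's codomain."""
    lo, hi = dist.base_dist.low, dist.base_dist.high
    for t in dist.transforms:
        a, b = t(lo), t(hi)
        # min/max keeps the bounds ordered for decreasing transforms too
        lo, hi = torch.minimum(a, b), torch.maximum(a, b)
    dist.transforms[-1].codomain = constraints.interval(lo, hi)


base = Uniform(torch.zeros(1), torch.ones(1))
dist = TransformedDistribution(base, ExpTransform())
set_codomain_to_transformed_bounds(dist)  # codomain becomes [1, e]
print(dist.support.check(100 * torch.ones(1)).tolist())  # [False]
print(dist.support.check(2 * torch.ones(1)).tolist())    # [True]
```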

@michaeldeistler michaeldeistler added bug Something isn't working blocked Something is in the way of fixing this. Refer to it in the issue labels Sep 19, 2022
Projects
None yet
Development

No branches or pull requests

2 participants