Explicitly import distributions from torch #3333

Merged (3 commits), Feb 27, 2024
Changes from 2 commits
2 changes: 1 addition & 1 deletion Makefile
@@ -3,7 +3,7 @@
all: docs test

install: FORCE
pip install -e .[dev,profile]
pip install -e .[dev,profile] --config-settings editable_mode=strict

uninstall: FORCE
pip uninstall pyro-ppl
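The --config-settings editable_mode=strict flag requests setuptools' strict PEP 660 editable install, which exposes only the files that exist at install time; modules added later require re-running make install. As a rough illustration (not part of this PR), a hypothetical snippet to see where the editable package resolves from:

# Hypothetical check of where an editable install of pyro resolves from.
import importlib.util

spec = importlib.util.find_spec("pyro")
print(spec.origin)  # filesystem path the import system uses for pyro/__init__.py
print(list(spec.submodule_search_locations))  # package search locations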
84 changes: 84 additions & 0 deletions pyro/distributions/__init__.py
@@ -2,7 +2,51 @@
# SPDX-License-Identifier: Apache-2.0

import pyro.distributions.torch_patch # noqa F403

# Import both * to get new distributions and explicitly to help mypy.
from pyro.distributions.torch import * # noqa F403
from pyro.distributions.torch import (
Bernoulli,
Beta,
Binomial,
Categorical,
Cauchy,
Chi2,
ContinuousBernoulli,
Dirichlet,
ExponentialFamily,
Exponential,
FisherSnedecor,
Gamma,
Geometric,
Gumbel,
HalfCauchy,
HalfNormal,
Independent,
Kumaraswamy,
Laplace,
LKJCholesky,
LogNormal,
LogisticNormal,
LowRankMultivariateNormal,
MixtureSameFamily,
Multinomial,
MultivariateNormal,
NegativeBinomial,
Normal,
OneHotCategorical,
OneHotCategoricalStraightThrough,
Pareto,
Poisson,
RelaxedBernoulli,
RelaxedOneHotCategorical,
StudentT,
TransformedDistribution,
Uniform,
VonMises,
Weibull,
Wishart,
)

# isort: split

Expand Down Expand Up @@ -99,7 +143,13 @@
"AVFMultivariateNormal",
"AffineBeta",
"AsymmetricLaplace",
"Bernoulli",
"Beta",
"BetaBinomial",
"Binomial",
"Categorical",
"Cauchy",
"Chi2",
"CoalescentRateLikelihood",
"CoalescentTimes",
"CoalescentTimesWithRate",
@@ -108,43 +158,71 @@
"ConditionalTransform",
"ConditionalTransformModule",
"ConditionalTransformedDistribution",
"ContinuousBernoulli",
"Delta",
"Dirichlet",
"DirichletMultinomial",
"DiscreteHMM",
"Distribution",
"Empirical",
"ExpandedDistribution",
"Exponential",
"ExponentialFamily",
"ExtendedBetaBinomial",
"ExtendedBinomial",
"FisherSnedecor",
"FoldedDistribution",
"Gamma",
"GammaGaussianHMM",
"GammaPoisson",
"GaussianHMM",
"GaussianMRF",
"GaussianScaleMixture",
"Geometric",
"GroupedNormalNormal",
"Gumbel",
"HalfCauchy",
"HalfNormal",
"ImproperUniform",
"Independent",
"IndependentHMM",
"InverseGamma",
"Kumaraswamy",
"LKJ",
"LKJCholesky",
"LKJCorrCholesky",
"Laplace",
"LinearHMM",
"LogNormal",
"LogNormalNegativeBinomial",
"Logistic",
"LogisticNormal",
"LowRankMultivariateNormal",
"MaskedDistribution",
"MaskedMixture",
"MixtureOfDiagNormals",
"MixtureOfDiagNormalsSharedCovariance",
"MixtureSameFamily",
"Multinomial",
"MultivariateNormal",
"MultivariateStudentT",
"NanMaskedMultivariateNormal",
"NanMaskedNormal",
"NegativeBinomial",
"Normal",
"OMTMultivariateNormal",
"OneHotCategorical",
"OneHotCategoricalStraightThrough",
"OneOneMatching",
"OneTwoMatching",
"OrderedLogistic",
"Pareto",
"Poisson",
"ProjectedNormal",
"Rejector",
"RelaxedBernoulli",
"RelaxedBernoulliStraightThrough",
"RelaxedOneHotCategorical",
"RelaxedOneHotCategoricalStraightThrough",
"SineBivariateVonMises",
"SineSkewed",
@@ -153,11 +231,17 @@
"SoftLaplace",
"SpanningTree",
"Stable",
"StudentT",
"TorchDistribution",
"TransformModule",
"TransformedDistribution",
"TruncatedPolyaGamma",
"Uniform",
"Unit",
"VonMises",
"VonMises3D",
"Weibull",
"Wishart",
"ZeroInflatedDistribution",
"ZeroInflatedNegativeBinomial",
"ZeroInflatedPoisson",
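The pattern above keeps the star import so that distributions added in future PyTorch releases still flow through, while the explicit import list gives static type checkers concrete names to resolve. A small hypothetical snippet (not part of the PR) of the kind of code mypy can now check against pyro.distributions:

# Hypothetical snippet: with the explicit re-exports, a type checker can
# resolve dist.Normal and its methods instead of an opaque star import.
import torch

import pyro.distributions as dist


def standard_normal_logp(d: dist.Normal) -> torch.Tensor:
    # log_prob is inherited from torch.distributions.Normal.
    return d.log_prob(torch.tensor(0.0))


print(standard_normal_logp(dist.Normal(0.0, 1.0)))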
49 changes: 47 additions & 2 deletions pyro/distributions/torch.py
@@ -346,8 +346,52 @@ def _cat_docstrings(*docstrings):
return result


# Programmatically load all distributions from PyTorch.
__all__ = []
# Add static imports to help mypy.
__all__ = [ # noqa: F822
"Bernoulli",
"Beta",
"Binomial",
"Categorical",
"Cauchy",
"Chi2",
"ContinuousBernoulli",
"Dirichlet",
"ExponentialFamily",
"Exponential",
"FisherSnedecor",
"Gamma",
"Geometric",
"Gumbel",
"HalfCauchy",
"HalfNormal",
"Independent",
"Kumaraswamy",
"Laplace",
"LKJCholesky",
"LogNormal",
"LogisticNormal",
"LowRankMultivariateNormal",
"MixtureSameFamily",
"Multinomial",
"MultivariateNormal",
"NegativeBinomial",
"Normal",
"OneHotCategorical",
"OneHotCategoricalStraightThrough",
"Pareto",
"Poisson",
"RelaxedBernoulli",
"RelaxedOneHotCategorical",
"StudentT",
"TransformedDistribution",
"Uniform",
"VonMises",
"Weibull",
"Wishart",
]

# Programmatically load all distributions from PyTorch,
# updating __all__ to include any new distributions.
for _name, _Dist in torch.distributions.__dict__.items():
if not isinstance(_Dist, type):
continue
@@ -372,6 +416,7 @@ def _cat_docstrings(*docstrings):
)
_PyroDist.__doc__ = _cat_docstrings(_PyroDist.__doc__, _Dist.__doc__)
__all__.append(_name)
__all__ = sorted(set(__all__))


# Create sphinx documentation.
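The change above takes a hybrid approach: a hand-written __all__ that mypy can see statically, followed by the existing programmatic loop over torch.distributions, with a final sorted(set(...)) to deduplicate the two sources. A minimal sketch of that pattern, as an illustration rather than Pyro's actual wrapping code:

# Minimal sketch: static names for type checkers plus a programmatic scan of
# torch.distributions, deduplicated at the end. Illustration only.
import torch

__all__ = ["Normal", "Beta"]  # statically visible to type checkers

for _name, _Dist in torch.distributions.__dict__.items():
    if not isinstance(_Dist, type):
        continue
    if not issubclass(_Dist, torch.distributions.Distribution):
        continue
    if _Dist is torch.distributions.Distribution:
        continue
    __all__.append(_name)  # picks up distributions added in newer PyTorch

__all__ = sorted(set(__all__))  # deduplicate static and scanned names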
5 changes: 4 additions & 1 deletion setup.py
@@ -88,7 +88,10 @@
long_description=long_description,
long_description_content_type="text/markdown",
packages=find_packages(include=["pyro", "pyro.*"]),
package_data={"pyro.distributions": ["*.cpp"]},
package_data={
"pyro": ["py.typed"],
"pyro.distributions": ["*.cpp"],
},
author="Uber AI Labs",
url="http://pyro.ai",
project_urls={
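Registering pyro/py.typed in package_data ships the empty PEP 561 marker file, which tells downstream type checkers that the installed pyro package provides inline type information. A hypothetical post-install check (not part of the PR):

# Hypothetical check that the PEP 561 marker is packaged with pyro.
from importlib import resources

marker = resources.files("pyro").joinpath("py.typed")
print("py.typed present:", marker.is_file())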
5 changes: 3 additions & 2 deletions tests/infer/test_gradient.py
@@ -94,8 +94,9 @@ def guide():
if reparameterized and has_rsample is not False:
# pathwise gradient estimator
expected_grads = {
"scale": -(-z * (z - loc) + (x - z) * (z - loc) + 1).sum(0, keepdim=True)
/ scale,
"scale": (
-(-z * (z - loc) + (x - z) * (z - loc) + 1).sum(0, keepdim=True) / scale
),
"loc": -(-z + (x - z)),
}
else: