diff --git a/Makefile b/Makefile index f0585be7dc..b531c2a515 100644 --- a/Makefile +++ b/Makefile @@ -3,7 +3,7 @@ all: docs test install: FORCE - pip install -e .[dev,profile] + pip install -e .[dev,profile] --config-settings editable_mode=strict uninstall: FORCE pip uninstall pyro-ppl diff --git a/pyro/distributions/__init__.py b/pyro/distributions/__init__.py index b648d0d66a..edfe6e85f4 100644 --- a/pyro/distributions/__init__.py +++ b/pyro/distributions/__init__.py @@ -2,8 +2,57 @@ # SPDX-License-Identifier: Apache-2.0 import pyro.distributions.torch_patch # noqa F403 + +# Import * to get the latest upstream distributions. from pyro.distributions.torch import * # noqa F403 +# Additionally try to import explicitly to help mypy static analysis. +try: + from pyro.distributions.torch import ( + Bernoulli, + Beta, + Binomial, + Categorical, + Cauchy, + Chi2, + ContinuousBernoulli, + Dirichlet, + Exponential, + ExponentialFamily, + FisherSnedecor, + Gamma, + Geometric, + Gumbel, + HalfCauchy, + HalfNormal, + Independent, + Kumaraswamy, + Laplace, + LKJCholesky, + LogisticNormal, + LogNormal, + LowRankMultivariateNormal, + MixtureSameFamily, + Multinomial, + MultivariateNormal, + NegativeBinomial, + Normal, + OneHotCategorical, + OneHotCategoricalStraightThrough, + Pareto, + Poisson, + RelaxedBernoulli, + RelaxedOneHotCategorical, + StudentT, + TransformedDistribution, + Uniform, + VonMises, + Weibull, + Wishart, + ) +except ImportError: + pass + # isort: split from pyro.distributions.affine_beta import AffineBeta @@ -99,7 +148,13 @@ "AVFMultivariateNormal", "AffineBeta", "AsymmetricLaplace", + "Bernoulli", + "Beta", "BetaBinomial", + "Binomial", + "Categorical", + "Cauchy", + "Chi2", "CoalescentRateLikelihood", "CoalescentTimes", "CoalescentTimesWithRate", @@ -108,43 +163,71 @@ "ConditionalTransform", "ConditionalTransformModule", "ConditionalTransformedDistribution", + "ContinuousBernoulli", "Delta", + "Dirichlet", "DirichletMultinomial", "DiscreteHMM", "Distribution", "Empirical", "ExpandedDistribution", + "Exponential", + "ExponentialFamily", "ExtendedBetaBinomial", "ExtendedBinomial", + "FisherSnedecor", "FoldedDistribution", + "Gamma", "GammaGaussianHMM", "GammaPoisson", "GaussianHMM", "GaussianMRF", "GaussianScaleMixture", + "Geometric", "GroupedNormalNormal", + "Gumbel", + "HalfCauchy", + "HalfNormal", "ImproperUniform", + "Independent", "IndependentHMM", "InverseGamma", + "Kumaraswamy", "LKJ", + "LKJCholesky", "LKJCorrCholesky", + "Laplace", "LinearHMM", + "LogNormal", "LogNormalNegativeBinomial", "Logistic", + "LogisticNormal", + "LowRankMultivariateNormal", "MaskedDistribution", "MaskedMixture", "MixtureOfDiagNormals", "MixtureOfDiagNormalsSharedCovariance", + "MixtureSameFamily", + "Multinomial", + "MultivariateNormal", "MultivariateStudentT", "NanMaskedMultivariateNormal", "NanMaskedNormal", + "NegativeBinomial", + "Normal", "OMTMultivariateNormal", + "OneHotCategorical", + "OneHotCategoricalStraightThrough", "OneOneMatching", "OneTwoMatching", "OrderedLogistic", + "Pareto", + "Poisson", "ProjectedNormal", "Rejector", + "RelaxedBernoulli", "RelaxedBernoulliStraightThrough", + "RelaxedOneHotCategorical", "RelaxedOneHotCategoricalStraightThrough", "SineBivariateVonMises", "SineSkewed", @@ -153,11 +236,17 @@ "SoftLaplace", "SpanningTree", "Stable", + "StudentT", "TorchDistribution", "TransformModule", + "TransformedDistribution", "TruncatedPolyaGamma", + "Uniform", "Unit", + "VonMises", "VonMises3D", + "Weibull", + "Wishart", "ZeroInflatedDistribution", "ZeroInflatedNegativeBinomial", "ZeroInflatedPoisson", @@ -171,4 +260,5 @@ # Import all torch distributions from `pyro.distributions.torch_distribution` __all__.extend(torch_dists) +__all__[:] = sorted(set(__all__)) del torch_dists diff --git a/pyro/distributions/constraints.py b/pyro/distributions/constraints.py index 3f8026f2e0..7e6d3072bf 100644 --- a/pyro/distributions/constraints.py +++ b/pyro/distributions/constraints.py @@ -1,18 +1,50 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +# Import * to get the latest upstream constraints. from torch.distributions.constraints import * # noqa F403 +# Additionally try to import explicitly to help mypy static analysis. +try: + from torch.distributions.constraints import ( + Constraint, + boolean, + cat, + corr_cholesky, + dependent, + dependent_property, + greater_than, + greater_than_eq, + half_open_interval, + independent, + integer_interval, + interval, + is_dependent, + less_than, + lower_cholesky, + lower_triangular, + multinomial, + nonnegative, + nonnegative_integer, + one_hot, + positive, + positive_definite, + positive_integer, + positive_semidefinite, + real, + real_vector, + simplex, + square, + stack, + symmetric, + unit_interval, + ) +except ImportError: + pass + # isort: split import torch -from torch.distributions.constraints import ( - Constraint, - independent, - lower_cholesky, - positive, - positive_definite, -) from torch.distributions.constraints import __all__ as torch_constraints @@ -129,19 +161,50 @@ def check(self, value): corr_cholesky_constraint = corr_cholesky # noqa: F405 DEPRECATED __all__ = [ + "Constraint", + "boolean", + "cat", + "corr_cholesky", "corr_cholesky_constraint", "corr_matrix", + "dependent", + "dependent_property", + "greater_than", + "greater_than_eq", + "half_open_interval", + "independent", "integer", + "integer_interval", + "interval", + "is_dependent", + "less_than", + "lower_cholesky", + "lower_triangular", + "multinomial", + "nonnegative", + "nonnegative_integer", + "one_hot", "ordered_vector", + "positive", + "positive_definite", + "positive_integer", "positive_ordered_vector", + "positive_semidefinite", + "real", + "real_vector", + "simplex", "softplus_lower_cholesky", "softplus_positive", "sphere", + "square", + "stack", + "symmetric", + "unit_interval", "unit_lower_cholesky", ] __all__.extend(torch_constraints) -__all__ = sorted(set(__all__)) +__all__[:] = sorted(set(__all__)) del torch_constraints diff --git a/pyro/distributions/torch.py b/pyro/distributions/torch.py index 902602de1a..2f3f255d97 100644 --- a/pyro/distributions/torch.py +++ b/pyro/distributions/torch.py @@ -346,8 +346,52 @@ def _cat_docstrings(*docstrings): return result -# Programmatically load all distributions from PyTorch. -__all__ = [] +# Add static imports to help mypy. +__all__ = [ # noqa: F822 + "Bernoulli", + "Beta", + "Binomial", + "Categorical", + "Cauchy", + "Chi2", + "ContinuousBernoulli", + "Dirichlet", + "ExponentialFamily", + "Exponential", + "FisherSnedecor", + "Gamma", + "Geometric", + "Gumbel", + "HalfCauchy", + "HalfNormal", + "Independent", + "Kumaraswamy", + "Laplace", + "LKJCholesky", + "LogNormal", + "LogisticNormal", + "LowRankMultivariateNormal", + "MixtureSameFamily", + "Multinomial", + "MultivariateNormal", + "NegativeBinomial", + "Normal", + "OneHotCategorical", + "OneHotCategoricalStraightThrough", + "Pareto", + "Poisson", + "RelaxedBernoulli", + "RelaxedOneHotCategorical", + "StudentT", + "TransformedDistribution", + "Uniform", + "VonMises", + "Weibull", + "Wishart", +] + +# Programmatically load all distributions from PyTorch, +# updating __all__ to include any new distributions. for _name, _Dist in torch.distributions.__dict__.items(): if not isinstance(_Dist, type): continue @@ -372,6 +416,7 @@ def _cat_docstrings(*docstrings): ) _PyroDist.__doc__ = _cat_docstrings(_PyroDist.__doc__, _Dist.__doc__) __all__.append(_name) +__all__ = sorted(set(__all__)) # Create sphinx documentation. diff --git a/pyro/distributions/transforms/__init__.py b/pyro/distributions/transforms/__init__.py index d2a2382974..89375afab4 100644 --- a/pyro/distributions/transforms/__init__.py +++ b/pyro/distributions/transforms/__init__.py @@ -1,16 +1,39 @@ # Copyright (c) 2017-2019 Uber Technologies, Inc. # SPDX-License-Identifier: Apache-2.0 +# Import * to get the latest upstream transforms. from torch.distributions.transforms import * # noqa F403 +# Additionally try to import explicitly to help mypy static analysis. +try: + from torch.distributions.transforms import ( + AbsTransform, + AffineTransform, + CatTransform, + ComposeTransform, + # CorrCholeskyTransform, # Use Pyro's version below. + CumulativeDistributionTransform, + ExpTransform, + IndependentTransform, + LowerCholeskyTransform, + PositiveDefiniteTransform, + PowerTransform, + ReshapeTransform, + SigmoidTransform, + SoftmaxTransform, + # SoftplusTransform, # Use Pyro's version below. + StackTransform, + StickBreakingTransform, + TanhTransform, + Transform, + identity_transform, + ) +except ImportError: + pass + # isort: split from torch.distributions import biject_to, transform_to -from torch.distributions.transforms import ( - ComposeTransform, - ExpTransform, - LowerCholeskyTransform, -) from torch.distributions.transforms import __all__ as torch_transforms from .. import constraints @@ -150,12 +173,15 @@ def iterated(repeats, base_fn, *args, **kwargs): __all__ = [ - "iterated", + "AbsTransform", "AffineAutoregressive", "AffineCoupling", + "AffineTransform", "BatchNorm", "BlockAutoregressive", + "CatTransform", "CholeskyTransform", + "ComposeTransform", "ComposeTransformModule", "ConditionalAffineAutoregressive", "ConditionalAffineCoupling", @@ -167,15 +193,20 @@ def iterated(repeats, base_fn, *args, **kwargs): "ConditionalRadial", "ConditionalSpline", "ConditionalSplineAutoregressive", + "CorrCholeskyTransform", "CorrLCholeskyTransform", "CorrMatrixCholeskyTransform", + "CumulativeDistributionTransform", "DiscreteCosineTransform", "ELUTransform", + "ExpTransform", "GeneralizedChannelPermute", "HaarTransform", "Householder", + "IndependentTransform", "LeakyReLUTransform", "LowerCholeskyAffine", + "LowerCholeskyTransform", "MatrixExponential", "NeuralAutoregressive", "Normalize", @@ -183,15 +214,24 @@ def iterated(repeats, base_fn, *args, **kwargs): "Permute", "Planar", "Polynomial", + "PositiveDefiniteTransform", "PositivePowerTransform", + "PowerTransform", "Radial", + "ReshapeTransform", + "SigmoidTransform", "SimplexToOrderedTransform", + "SoftmaxTransform", "SoftplusLowerCholeskyTransform", "SoftplusTransform", "Spline", "SplineAutoregressive", "SplineCoupling", + "StackTransform", + "StickBreakingTransform", "Sylvester", + "TanhTransform", + "Transform", "affine_autoregressive", "affine_coupling", "batchnorm", @@ -209,6 +249,8 @@ def iterated(repeats, base_fn, *args, **kwargs): "elu", "generalized_channel_permute", "householder", + "identity_transform", + "iterated", "leaky_relu", "matrix_exponential", "neural_autoregressive", @@ -223,4 +265,5 @@ def iterated(repeats, base_fn, *args, **kwargs): ] __all__.extend(torch_transforms) +__all__[:] = sorted(set(__all__)) del torch_transforms diff --git a/setup.py b/setup.py index e8b075d146..c2470e07b3 100644 --- a/setup.py +++ b/setup.py @@ -88,7 +88,10 @@ long_description=long_description, long_description_content_type="text/markdown", packages=find_packages(include=["pyro", "pyro.*"]), - package_data={"pyro.distributions": ["*.cpp"]}, + package_data={ + "pyro": ["py.typed"], + "pyro.distributions": ["*.cpp"], + }, author="Uber AI Labs", url="http://pyro.ai", project_urls={ diff --git a/tests/infer/test_gradient.py b/tests/infer/test_gradient.py index 69501cf561..f6bd6f3024 100644 --- a/tests/infer/test_gradient.py +++ b/tests/infer/test_gradient.py @@ -94,8 +94,9 @@ def guide(): if reparameterized and has_rsample is not False: # pathwise gradient estimator expected_grads = { - "scale": -(-z * (z - loc) + (x - z) * (z - loc) + 1).sum(0, keepdim=True) - / scale, + "scale": ( + -(-z * (z - loc) + (x - z) * (z - loc) + 1).sum(0, keepdim=True) / scale + ), "loc": -(-z + (x - z)), } else: