From ecce73327be3ddd867ab7ada1b7c82f27b14bddc Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Thu, 22 Sep 2022 14:59:49 +0000 Subject: [PATCH 01/39] add LogNormal API --- python/paddle/distribution/__init__.py | 2 +- python/paddle/distribution/kl.py | 8 + python/paddle/distribution/lognormal.py | 166 ++++++++++++++ python/paddle/distribution/normal.py | 29 ++- python/paddle/distribution/transform.py | 4 +- .../distribution/transformed_distribution.py | 38 +++- .../test_distribution_lognormal.py | 211 ++++++++++++++++++ .../test_distribution_lognormal_static.py | 197 ++++++++++++++++ 8 files changed, 642 insertions(+), 13 deletions(-) create mode 100644 python/paddle/distribution/lognormal.py create mode 100644 python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py create mode 100644 python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index 64d59b04864ba..41749a31a6f28 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -30,7 +30,7 @@ __all__ = [ # noqa 'Beta', 'Categorical', 'Dirichlet', 'Distribution', 'ExponentialFamily', 'Multinomial', 'Normal', 'Uniform', 'kl_divergence', 'register_kl', - 'Independent', 'TransformedDistribution' + 'Independent', 'TransformedDistribution', 'LogNormal' ] __all__.extend(transform.__all__) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index c5ad3f04358dc..12019e1dcb4d1 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -21,9 +21,12 @@ from paddle.distribution.distribution import Distribution from paddle.distribution.exponential_family import ExponentialFamily from paddle.distribution.normal import Normal +from paddle.distribution.lognormal import LogNormal from paddle.distribution.uniform import Uniform from paddle.fluid.framework import _non_static_mode, in_dygraph_mode + + __all__ = ["register_kl", "kl_divergence"] _REGISTER_TABLE = {} @@ -206,5 +209,10 @@ def _kl_expfamily_expfamily(p, q): return kl +@register_kl(LogNormal, LogNormal) +def _kl_normal_normal(p, q): + return p.base_dist.kl_divergence(q.base_dist) + + def _sum_rightmost(value, n): return value.sum(list(range(-n, 0))) if n > 0 else value diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py new file mode 100644 index 0000000000000..652ce59fcaea4 --- /dev/null +++ b/python/paddle/distribution/lognormal.py @@ -0,0 +1,166 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.distribution import Normal, TransformedDistribution, ExpTransform + + +class LogNormal(TransformedDistribution): + r"""The Normal distribution with location `loc` and `scale` parameters. + + Mathematical details + + The probability density function (pdf) is + + .. 
math:: + pdf(x; \mu, \sigma) = \\frac{1}{\sigma x \sqrt{2\pi}}e^{(-\\frac{(ln(x) - \mu)^2}{2\sigma^2})} + pdf(x; \mu, \sigma) = \\frac{1}{Z}e^{\\frac {-0.5 (x - \mu)^2} {\sigma^2} } + + In the above equation: + + * :math:`loc = \mu`: is the means of the underlying Normal distribution. + * :math:`scale = \sigma`: is the stddevs of the underlying Normal distribution. + + Args: + loc(int|float|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor. + scale(int|float|list|tuple|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor. + name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Examples: + .. code-block:: python + + import paddle + #from paddle.distribution import LogNormal + from lognormal import LogNormal + # Define a single scalar LogNormal distribution. + dist = LogNormal(loc=0., scale=3.) + # Define a batch of two scalar valued LogNormals. + # The underlying Normal of first has mean 1 and standard deviation 11, the underlying Normal of second 2 and 22. + dist = LogNormal(loc=[1., 2.], scale=[11., 22.]) + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + + # Define a batch of two scalar valued LogNormals. + # Their underlying Normal have mean 1, but different standard deviations. + dist = LogNormal(loc=1., scale=[11., 22.]) + + # Complete example + value_tensor = paddle.to_tensor([0.8], dtype="float32") + + lognormal_a = LogNormal([0.], [1.]) + lognormal_b = LogNormal([0.5], [2.]) + sample = lognormal_a.sample([2]) + # a random tensor created by normal distribution with shape: [2, 1] + entropy = lognormal_a.entropy() + # [1.4189385] with shape: [1] + lp = lognormal_a.log_prob(value_tensor) + # [-0.72069150] with shape: [1] + p = lognormal_a.probs(value_tensor) + # [0.48641577] with shape: [1] + kl = lognormal_a.kl_divergence(lognormal_b) + # [0.34939718] with shape: [1] + """ + + def __init__(self, loc, scale, name=None): + self.base_dist = Normal(loc=loc, scale=scale, name=name) + self.loc = self.base_dist.loc + self.scale = self.base_dist.scale + super(LogNormal, self).__init__(self.base_dist, [ExpTransform()]) + + @property + def mean(self): + """mean of lognormal distribuion. + + Returns: + Tensor: mean value. + """ + return paddle.exp(self.base_dist.mean + self.base_dist.variance / 2) + + @property + def variance(self): + """variance of lognormal distribution. + + Returns: + Tensor: variance value. + """ + return (paddle.expm1(self.base_dist.variance) * + paddle.exp(2 * self.base_dist.mean + self.base_dist.variance)) + + def entropy(self): + r"""Shannon entropy in nats. + + The entropy is + + .. math:: + + entropy(\sigma) = 0.5 \\log (2 \pi e \sigma^2) + \mu + + In the above equation: + + * :math:`scale = \sigma`: is the std. + + Returns: + Tensor: Shannon entropy of lognormal distribution.The data type is float32. + + """ + return self.base_dist.entropy() + self.base_dist.mean + + def probs(self, value): + """Probability density/mass function. + + Args: + value (Tensor): The input tensor. + + Returns: + Tensor: probability.The data type is same with value. + + """ + return paddle.exp(self.log_prob(value)) + + def kl_divergence(self, other): + r"""The KL-divergence between two lognormal distributions. + + The probability density function (pdf) is + + .. 
math:: + + KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\\frac{diff}{\sigma_1})^2 - 1 - 2 \\ln {ratio}) + + .. math:: + + ratio = \\frac{\sigma_0}{\sigma_1} + + .. math:: + + diff = \mu_1 - \mu_0 + + In the above equation: + + * :math:`loc = \mu_0`: is the mean of underlying Normal distribution. + * :math:`scale = \sigma_0`: is the std of underlying Normal distribution. + * :math:`loc = \mu_1`: is the mean of other underlying Normal distribution. + * :math:`scale = \sigma_1`: is the std of other underlying Normal distribution. + * :math:`ratio`: is the ratio of scales. + * :math:`diff`: is the difference between means. + + Args: + other (LogNormal): instance of LogNormal. + + Returns: + Tensor: kl-divergence between two lognormal distributions.The data type is float32. + + """ + return self.base_dist.kl_divergence(other.base_dist) + +# print(dist.prob(value_tensor)) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index f248e1a09273d..fedc20ec6e238 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -13,10 +13,8 @@ # limitations under the License. import math -import warnings import numpy as np -from paddle import _C_ops, _legacy_C_ops from paddle.distribution import distribution from paddle.fluid import core from paddle.fluid.data_feeder import (check_dtype, check_type, @@ -55,7 +53,7 @@ class Normal(distribution.Distribution): Examples: .. code-block:: python - + import paddle from paddle.distribution import Normal @@ -128,6 +126,14 @@ def __init__(self, loc, scale, name=None): self.scale = tensor.cast(self.scale, dtype=self.dtype) super(Normal, self).__init__(self.loc.shape) + @property + def mean(self): + return self.loc + + @property + def variance(self): + return self.scale.pow(2) + def sample(self, shape, seed=0): """Generate samples of the specified shape. @@ -163,13 +169,26 @@ def sample(self, shape, seed=0): else: output_shape = shape + batch_shape output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \ - (tensor.zeros(output_shape, dtype=self.dtype) + self.scale) + (tensor.zeros(output_shape, dtype=self.dtype) + self.scale) output = elementwise_add(output, self.loc, name=name) if self.all_arg_is_float: return nn.reshape(output, shape, name=name) else: return output + def rsample(self, shape, seed=0): + """Generate reparameterized samples of the specified shape. + + Args: + shape (list): 1D `int32`. Shape of the generated samples. + seed (int): Python integer number. + + Returns: + Tensor: A tensor with prepended dimensions shape.The data type is float32. + + """ + return self.sample(shape) + def entropy(self): r"""Shannon entropy in nats. @@ -248,7 +267,7 @@ def kl_divergence(self, other): .. math:: ratio = \\frac{\sigma_0}{\sigma_1} - + .. math:: diff = \mu_1 - \mu_0 diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py index d7a512aade2e5..ffdcc08831c8b 100644 --- a/python/paddle/distribution/transform.py +++ b/python/paddle/distribution/transform.py @@ -142,7 +142,7 @@ def __call__(self, input): input, [self]) if isinstance(input, Transform): return ChainTransform([self, input]) - return self.forward(x) + return self.forward(input) def forward(self, x): """Forward transformation with mapping :math:`y = f(x)`. 
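The `__call__` fix in the transform.py hunk above (returning `self.forward(input)` instead of the undefined name `x`) is easiest to see with a direct call. A minimal sketch, not part of the patch, assuming Paddle's public `ExpTransform`, `inverse`, and `allclose` APIs:

```python
# Illustrative sketch only: after the __call__ fix above, calling a
# Transform directly on a Tensor dispatches to forward(); before the fix,
# __call__ referenced the undefined name `x` and raised NameError.
import paddle
from paddle.distribution.transform import ExpTransform

t = ExpTransform()
x = paddle.to_tensor([0., 1., 2.])
y = t(x)                # equivalent to t.forward(x), i.e. exp(x)
x_back = t.inverse(y)   # log(y) recovers x up to floating-point error
print(paddle.allclose(x, x_back))
```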
@@ -285,7 +285,7 @@ def _call_forward_log_det_jacobian(self, x): if hasattr(self, '_forward_log_det_jacobian'): return self._forward_log_det_jacobian(x) if hasattr(self, '_inverse_log_det_jacobian'): - return -self._inverse_log_det_jacobian(self.forward(y)) + return -self._inverse_log_det_jacobian(self.forward(x)) raise NotImplementedError( 'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian' 'is implemented. One of them is required.') diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py index ce386971e5fcc..e082f148a21a2 100644 --- a/python/paddle/distribution/transformed_distribution.py +++ b/python/paddle/distribution/transformed_distribution.py @@ -30,7 +30,7 @@ class TransformedDistribution(distribution.Distribution): Examples: .. code-block:: python - + import paddle from paddle.distribution import transformed_distribution @@ -58,7 +58,8 @@ def __init__(self, base, transforms): f"Expected type of 'transforms' is Sequence[Transform] or Chain, but got {type(transforms)}." ) if not all(isinstance(t, transform.Transform) for t in transforms): - raise TypeError("All element of transforms must be Transform type.") + raise TypeError( + "All element of transforms must be Transform type.") chain = transform.ChainTransform(transforms) if len(base.batch_shape + base.event_shape) < chain._domain.event_rank: @@ -76,8 +77,9 @@ def __init__(self, base, transforms): transformed_event_rank = chain._codomain.event_rank + \ max(len(base.event_shape)-chain._domain.event_rank, 0) super(TransformedDistribution, self).__init__( - transformed_shape[:len(transformed_shape) - transformed_event_rank], - transformed_shape[:len(transformed_shape) - transformed_event_rank]) + transformed_shape[:len(transformed_shape) - + transformed_event_rank], + transformed_shape[len(transformed_shape) - transformed_event_rank:]) def sample(self, shape=()): """Sample from ``TransformedDistribution``. @@ -93,6 +95,20 @@ def sample(self, shape=()): x = t.forward(x) return x + def rsample(self, shape=()): + """Reparameterized sample from ``TransformedDistribution``. + + Args: + shape (tuple, optional): The sample shape. Defaults to (). + + Returns: + [Tensor]: The sample result. + """ + x = self._base.rsample(shape) + for t in self._transforms: + x = t.forward(x) + return x + def log_prob(self, value): """The log probability evaluated at value. @@ -110,12 +126,24 @@ def log_prob(self, value): event_rank += t._domain.event_rank - t._codomain.event_rank log_prob = log_prob - \ _sum_rightmost(t.forward_log_det_jacobian( - x), event_rank-t._domain.event_rank) + x), event_rank - t._domain.event_rank) y = x log_prob += _sum_rightmost(self._base.log_prob(y), event_rank - len(self._base.event_shape)) return log_prob + def _monotonize_cdf(self, value): + """ + This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is + monotone increasing. 
+ """ + sign = 1 + for t in self._transforms: + sign = sign * t.sign + if isinstance(sign, int) and sign == 1: + return value + return sign * (value - 0.5) + 0.5 + def _sum_rightmost(value, n): return value.sum(list(range(-n, 0))) if n > 0 else value diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py new file mode 100644 index 0000000000000..816ecab5a5cfc --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -0,0 +1,211 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import unittest +import scipy.stats + +import numpy as np +import paddle + + +import config +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +from paddle.distribution import LogNormal +from test_distribution import DistributionNumpy +from paddle.distribution.kl import kl_divergence + +np.random.seed(2022) + + +class LogNormalNumpy(DistributionNumpy): + + def __init__(self, loc, scale): + self.loc = np.array(loc) + self.scale = np.array(scale) + if str(self.loc.dtype) not in ['float32', 'float64']: + self.loc = self.loc.astype('float32') + self.scale = self.scale.astype('float32') + + @property + def mean(self): + var = self.scale * self.scale + return np.exp(self.loc + var / 2) + + @property + def variance(self): + var = self.scale * self.scale + return (np.exp(var) - 1) * np.exp(2 * self.loc + var) + + def log_prob(self, value): + var = self.scale * self.scale + log_scale = np.log(self.scale) + return -((np.log(value) - self.loc) * + (np.log(value) - self.loc)) / (2. * var) - log_scale - math.log( + math.sqrt(2. * math.pi)) - np.log(value) + + def probs(self, value): + var = self.scale * self.scale + return np.exp(-1. * ((np.log(value) - self.loc) * (np.log(value) - self.loc)) / + (2. * var)) / (math.sqrt(2 * math.pi) * self.scale * value) + + def entropy(self): + return 0.5 + self.loc + 0.5 * np.log( + np.array(2. 
* math.pi).astype(self.loc.dtype)) + np.log(self.scale) + + def kl_divergence(self, other): + var_ratio = (self.scale / other.scale) + var_ratio = var_ratio * var_ratio + t1 = ((self.loc - other.loc) / other.scale) + t1 = (t1 * t1) + return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio)) + + +@place(config.DEVICES) +@parameterize_cls( + (TEST_CASE_NAME, 'loc', 'scale'), [ + ('float', xrand(), xrand()), + ('one-dim', xrand((3, )), xrand((3, ))), + ('multi-dim', xrand((5, 5)), xrand((5, 5))) + ]) +class LogNormalTest(unittest.TestCase): + + def setUp(self): + self._paddle_lognormal = LogNormal( + loc=paddle.to_tensor(self.loc), + scale=paddle.to_tensor(self.scale)) + self._np_lognormal = LogNormalNumpy(self.loc, self.scale) + + def test_mean(self): + mean = self._paddle_lognormal.mean + np_mean = self._np_lognormal.mean + self.assertEqual(mean.numpy().dtype, np_mean.dtype) + np.testing.assert_allclose(mean, + np_mean, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_variance(self): + var = self._paddle_lognormal.variance + np_var = self._np_lognormal.variance + self.assertEqual(var.numpy().dtype, np_var.dtype) + np.testing.assert_allclose(var, + np_var, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_entropy(self): + entropy = self._paddle_lognormal.entropy() + np_entropy = self._np_lognormal.entropy() + self.assertEqual(entropy.numpy().dtype, np_entropy.dtype) + np.testing.assert_allclose(entropy, + np_entropy, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_probs(self): + value = [np.random.rand(*self.scale.shape)] + + for v in value: + with paddle.fluid.dygraph.guard(self.place): + probs = self._paddle_lognormal.probs(paddle.to_tensor(v)) + np_probs = self._np_lognormal.probs(v) + np.testing.assert_allclose( + probs, + np_probs, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_log_prob(self): + value = [np.random.rand(*self.scale.shape)] + for v in value: + with paddle.fluid.dygraph.guard(self.place): + log_prob = self._paddle_lognormal.log_prob(paddle.to_tensor(v)) + np_log_prob = self._np_lognormal.log_prob(v) + np.testing.assert_allclose( + log_prob, + np_log_prob, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + +@place(config.DEVICES) +@parameterize_cls( + (TEST_CASE_NAME, 'loc', 'scale'), [ + ('sample1', xrand((2, )), xrand((2, ))), + ('sample2', xrand((5, )), xrand((5, ))) + ]) +class LogNormalTestSample(unittest.TestCase): + def test_sample(self): + self._paddle_lognormal = LogNormal( + loc=self.loc, + scale=self.scale) + shape = [8000] + samples = self._paddle_lognormal.sample(shape) + for i in range(len(self.scale)): + self.assertTrue(self._kstest( + self.loc[i], self.scale[i], samples[:, i])) + + def _kstest(self, loc, scale, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. 
+ ks, _ = scipy.stats.kstest( + samples, + scipy.stats.lognorm(s=scale, scale=np.exp(loc)).cdf) + return ks < 0.02 + + +@place(config.DEVICES) +@parameterize_cls( + (TEST_CASE_NAME, 'loc1', 'scale1', + 'loc2', 'scale2'), [ + ('one-dim', xrand((2, )), xrand((2, )), + xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((2, 2)), xrand((2, 2)), + xrand((2, 2)), xrand((2, 2))) + ]) +class TestLognormalKL(unittest.TestCase): + + def setUp(self): + self._paddle_lognormal = LogNormal( + loc=paddle.to_tensor(self.loc1), + scale=paddle.to_tensor(self.scale1)) + self._paddle_lognormal_other = LogNormal( + loc=paddle.to_tensor(self.loc2), + scale=paddle.to_tensor(self.scale2)) + + def test_kl_divergence(self): + kl1 = kl_divergence(self._paddle_lognormal, + self._paddle_lognormal_other) + kl2 = self._kl(self._paddle_lognormal, self._paddle_lognormal_other) + np.testing.assert_allclose( + kl1, + kl2, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + + def _kl(self, dist1, dist2): + loc1 = np.array(dist1.loc) + loc2 = np.array(dist2.loc) + scale1 = np.array(dist1.scale) + scale2 = np.array(dist2.scale) + var_ratio = (scale1 / scale2) + var_ratio = var_ratio * var_ratio + t1 = ((loc1 - loc2) / scale2) + t1 = (t1 * t1) + return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py new file mode 100644 index 0000000000000..493c8dab45ba1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -0,0 +1,197 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
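The `_kstest` helper above (and the matching one reused in the static-graph tests below) leans on the fact that if log X ~ Normal(loc, scale), then X follows `scipy.stats.lognorm(s=scale, scale=exp(loc))`. A standalone sketch of the same goodness-of-fit check, with the sample size and 0.02 cutoff mirroring the tests:

```python
# Standalone sketch of the KS goodness-of-fit check used by _kstest above;
# samples are drawn with NumPy rather than paddle, for illustration only.
import numpy as np
import scipy.stats

loc, scale = 0.3, 0.8
samples = np.exp(np.random.normal(loc=loc, scale=scale, size=8000))
ks, _ = scipy.stats.kstest(
    samples, scipy.stats.lognorm(s=scale, scale=np.exp(loc)).cdf)
assert ks < 0.02
```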
+ +import unittest +import paddle +import numpy as np +import scipy.stats +import config + +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand + +from paddle.distribution import LogNormal +from test_distribution import DistributionNumpy +from test_distribution_lognormal import LogNormalNumpy +from paddle.distribution.kl import kl_divergence + +np.random.seed(2022) + +paddle.enable_static() + + +@place(config.DEVICES) +@parameterize_cls( + (TEST_CASE_NAME, 'loc', 'scale'), [ + ('one-dim', xrand((2, )), + xrand((2, ))), + ('multi-dim', xrand((3, 3)), + xrand((3, 3))) + ]) +class TestLogNormal(unittest.TestCase): + + def setUp(self): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + self._paddle_lognormal = LogNormal(loc=loc, scale=scale) + self._np_lognormal = LogNormalNumpy(loc=self.loc, scale=self.scale) + self.sample_shape = [8000] + + mean = self._paddle_lognormal.mean + var = self._paddle_lognormal.variance + entropy = self._paddle_lognormal.entropy() + samples = self._paddle_lognormal.sample(self.sample_shape) + fetch_list = [mean, var, entropy, samples] + self.feeds = {'loc': self.loc, 'scale': self.scale} + + executor.run(startup_program) + [self.mean, self.var, self.entropy, self.samples] = executor.run( + main_program, + feed=self.feeds, + fetch_list=fetch_list) + + def test_mean(self): + np_mean = self._np_lognormal.mean + self.assertEqual(str(self.mean.dtype).split('.')[-1], self.scale.dtype) + np.testing.assert_allclose(self.mean, + np_mean, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_var(self): + np_var = self._np_lognormal.variance + self.assertEqual(str(self.var.dtype).split('.')[-1], self.scale.dtype) + np.testing.assert_allclose(self.var, + np_var, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_entropy(self): + np_entropy = self._np_lognormal.entropy() + self.assertEqual(str(self.entropy.dtype).split('.') + [-1], self.scale.dtype) + np.testing.assert_allclose(self.entropy, + np_entropy, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + +@place(config.DEVICES) +@parameterize_cls( + (TEST_CASE_NAME, 'loc', 'scale'), [ + ('sample', xrand((5, )), + xrand((5, ))) + ]) +class LogNormalTestSample(unittest.TestCase): + + def setUp(self): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + loc = paddle.static.data('loc', self.loc.shape, + self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + self.sample_shape = [8000] + self._paddle_lognormal = LogNormal(loc=loc, scale=scale) + self.mean = self._paddle_lognormal.mean + self.samples = self._paddle_lognormal.sample(self.sample_shape) + fetch_list = [self.mean, self.samples] + self.feeds = {'loc': self.loc, 'scale': self.scale} + + executor.run(startup_program) + [self.mean, self.samples] = executor.run( + main_program, + feed=self.feeds, + fetch_list=fetch_list) + + def test_sample(self): + for i in range(len(self.scale)): + self.assertTrue(self._kstest( + self.loc[i], 
self.scale[i], self.samples[:, i])) + + def _kstest(self, loc, scale, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, _ = scipy.stats.kstest( + samples, + scipy.stats.lognorm(s=scale, scale=np.exp(loc)).cdf) + return ks < 0.02 + + +@place(config.DEVICES) +@parameterize_cls( + (TEST_CASE_NAME, 'loc1', 'scale1', + 'loc2', 'scale2'), [ + ('one-dim', xrand((2, )), xrand((2, )), + xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((2, 2)), xrand((2, 2)), + xrand((2, 2)), xrand((2, 2))) + ]) +class TestLognormalKL(unittest.TestCase): + + def setUp(self): + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + loc1 = paddle.static.data('loc1', self.loc1.shape, + self.loc1.dtype) + scale1 = paddle.static.data('scale1', self.scale1.shape, + self.scale1.dtype) + loc2 = paddle.static.data('loc2', self.loc2.shape, + self.loc2.dtype) + scale2 = paddle.static.data('scale2', self.scale2.shape, + self.scale2.dtype) + self._paddle_lognormal = LogNormal(loc=loc1, scale=scale1) + self._paddle_lognormal_other = LogNormal(loc=loc2, scale=scale2) + self.kl1 = kl_divergence( + self._paddle_lognormal, self._paddle_lognormal_other) + self.kl2 = self._kl(self._paddle_lognormal, + self._paddle_lognormal_other) + fetch_list = [self.kl1, self.kl2] + self.feeds = {'loc1': self.loc1, 'scale1': self.scale1, + 'loc2': self.loc2, 'scale2': self.scale2} + + executor.run(startup_program) + [self.kl1, self.kl2] = executor.run( + main_program, + feed=self.feeds, + fetch_list=fetch_list) + + def test_kl_divergence(self): + np.testing.assert_allclose(self.kl1, + self.kl2, + rtol=config.RTOL.get( + str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + + def _kl(self, dist1, dist2): + loc1 = dist1.loc + loc2 = dist2.loc + scale1 = (dist1.scale) + scale2 = (dist2.scale) + var_ratio = (scale1 / scale2) + var_ratio = var_ratio * var_ratio + t1 = ((loc1 - loc2) / scale2) + t1 = (t1 * t1) + return 0.5 * (var_ratio + t1 - 1 - np.log(var_ratio)) + + +if __name__ == '__main__': + unittest.main() From b3d17574987fa040c64dd688569df3b5f3d23e86 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Thu, 22 Sep 2022 15:47:04 +0000 Subject: [PATCH 02/39] fix bug --- python/paddle/distribution/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index 41749a31a6f28..bd70831a3bade 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -22,6 +22,7 @@ from paddle.distribution.kl import kl_divergence, register_kl from paddle.distribution.multinomial import Multinomial from paddle.distribution.normal import Normal +from paddle.distribution.lognormal import LogNormal from paddle.distribution.transform import * # noqa: F403 from paddle.distribution.transformed_distribution import \ TransformedDistribution From 211f42cb740da6c0a2aabfdd02a217a4d8ae502b Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Thu, 22 Sep 2022 16:13:24 +0000 Subject: [PATCH 03/39] fix bug --- python/paddle/distribution/lognormal.py | 2 - .../test_distribution_lognormal.py | 69 ++++++++----------- 2 files changed, 28 insertions(+), 43 deletions(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 652ce59fcaea4..a1cd2d3a2f145 100644 --- a/python/paddle/distribution/lognormal.py +++ 
b/python/paddle/distribution/lognormal.py @@ -162,5 +162,3 @@ def kl_divergence(self, other): """ return self.base_dist.kl_divergence(other.base_dist) - -# print(dist.prob(value_tensor)) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 816ecab5a5cfc..628fcbee3c755 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import math import unittest import scipy.stats @@ -20,7 +19,6 @@ import numpy as np import paddle - import config from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand from paddle.distribution import LogNormal @@ -52,14 +50,16 @@ def variance(self): def log_prob(self, value): var = self.scale * self.scale log_scale = np.log(self.scale) - return -((np.log(value) - self.loc) * - (np.log(value) - self.loc)) / (2. * var) - log_scale - math.log( - math.sqrt(2. * math.pi)) - np.log(value) + return -( + (np.log(value) - self.loc) * + (np.log(value) - self.loc)) / (2. * var) - log_scale - math.log( + math.sqrt(2. * math.pi)) - np.log(value) def probs(self, value): var = self.scale * self.scale - return np.exp(-1. * ((np.log(value) - self.loc) * (np.log(value) - self.loc)) / - (2. * var)) / (math.sqrt(2 * math.pi) * self.scale * value) + return np.exp( + -1. * ((np.log(value) - self.loc) * (np.log(value) - self.loc)) / + (2. * var)) / (math.sqrt(2 * math.pi) * self.scale * value) def entropy(self): return 0.5 + self.loc + 0.5 * np.log( @@ -74,18 +74,15 @@ def kl_divergence(self, other): @place(config.DEVICES) -@parameterize_cls( - (TEST_CASE_NAME, 'loc', 'scale'), [ - ('float', xrand(), xrand()), - ('one-dim', xrand((3, )), xrand((3, ))), - ('multi-dim', xrand((5, 5)), xrand((5, 5))) - ]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), + [('float', xrand(), xrand()), + ('one-dim', xrand((3, )), xrand((3, ))), + ('multi-dim', xrand((5, 5)), xrand((5, 5)))]) class LogNormalTest(unittest.TestCase): def setUp(self): - self._paddle_lognormal = LogNormal( - loc=paddle.to_tensor(self.loc), - scale=paddle.to_tensor(self.scale)) + self._paddle_lognormal = LogNormal(loc=paddle.to_tensor(self.loc), + scale=paddle.to_tensor(self.scale)) self._np_lognormal = LogNormalNumpy(self.loc, self.scale) def test_mean(self): @@ -142,21 +139,17 @@ def test_log_prob(self): @place(config.DEVICES) -@parameterize_cls( - (TEST_CASE_NAME, 'loc', 'scale'), [ - ('sample1', xrand((2, )), xrand((2, ))), - ('sample2', xrand((5, )), xrand((5, ))) - ]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample1', xrand( + (2, )), xrand((2, ))), ('sample2', xrand((5, )), xrand((5, )))]) class LogNormalTestSample(unittest.TestCase): + def test_sample(self): - self._paddle_lognormal = LogNormal( - loc=self.loc, - scale=self.scale) + self._paddle_lognormal = LogNormal(loc=self.loc, scale=self.scale) shape = [8000] samples = self._paddle_lognormal.sample(shape) for i in range(len(self.scale)): - self.assertTrue(self._kstest( - self.loc[i], self.scale[i], samples[:, i])) + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], samples[:, i])) def _kstest(self, loc, scale, samples): # Uses the Kolmogorov-Smirnov test for goodness of fit. 
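The reformatted `log_prob`/`probs` above implement the standard lognormal density; a quick NumPy/SciPy cross-check of that closed form (illustrative only):

```python
# Cross-check (illustrative): the closed-form density in LogNormalNumpy
# above agrees with scipy.stats.lognorm for the same parametrization.
import math
import numpy as np
import scipy.stats

loc, scale = 0.5, 1.2
value = np.array([0.3, 1.0, 2.5])

log_prob = (-((np.log(value) - loc) ** 2) / (2. * scale * scale)
            - np.log(scale) - math.log(math.sqrt(2. * math.pi))
            - np.log(value))
ref = scipy.stats.lognorm(s=scale, scale=np.exp(loc)).logpdf(value)
assert np.allclose(log_prob, ref)
```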
@@ -168,19 +161,14 @@ def _kstest(self, loc, scale, samples): @place(config.DEVICES) @parameterize_cls( - (TEST_CASE_NAME, 'loc1', 'scale1', - 'loc2', 'scale2'), [ - ('one-dim', xrand((2, )), xrand((2, )), - xrand((2, )), xrand((2, ))), - ('multi-dim', xrand((2, 2)), xrand((2, 2)), - xrand((2, 2)), xrand((2, 2))) - ]) + (TEST_CASE_NAME, 'loc1', 'scale1', 'loc2', 'scale2'), + [('one-dim', xrand((2, )), xrand((2, )), xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((2, 2)), xrand((2, 2)), xrand((2, 2)), xrand((2, 2)))]) class TestLognormalKL(unittest.TestCase): def setUp(self): - self._paddle_lognormal = LogNormal( - loc=paddle.to_tensor(self.loc1), - scale=paddle.to_tensor(self.scale1)) + self._paddle_lognormal = LogNormal(loc=paddle.to_tensor(self.loc1), + scale=paddle.to_tensor(self.scale1)) self._paddle_lognormal_other = LogNormal( loc=paddle.to_tensor(self.loc2), scale=paddle.to_tensor(self.scale2)) @@ -189,11 +177,10 @@ def test_kl_divergence(self): kl1 = kl_divergence(self._paddle_lognormal, self._paddle_lognormal_other) kl2 = self._kl(self._paddle_lognormal, self._paddle_lognormal_other) - np.testing.assert_allclose( - kl1, - kl2, - rtol=config.RTOL.get(str(self.scale1.dtype)), - atol=config.ATOL.get(str(self.scale1.dtype))) + np.testing.assert_allclose(kl1, + kl2, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) def _kl(self, dist1, dist2): loc1 = np.array(dist1.loc) From 6a95d568b47214a22d87207ddbb26b2746fcb6c1 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Thu, 22 Sep 2022 16:33:50 +0000 Subject: [PATCH 04/39] fix bug --- python/paddle/distribution/lognormal.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index a1cd2d3a2f145..0bc7ebb1c13e2 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -13,7 +13,8 @@ # limitations under the License. 
import paddle -from paddle.distribution import Normal, TransformedDistribution, ExpTransform +from paddle.distribution import TransformedDistribution, ExpTransform +from paddle.distribution.normal import Normal class LogNormal(TransformedDistribution): From acadbb8ee93f2db694a0f43f1aad2f39e0d828f4 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Thu, 22 Sep 2022 16:42:51 +0000 Subject: [PATCH 05/39] fix bug --- python/paddle/distribution/kl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 550384cee8600..759f846607cb0 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -25,8 +25,6 @@ from paddle.distribution.uniform import Uniform from paddle.fluid.framework import _non_static_mode, in_dygraph_mode - - __all__ = ["register_kl", "kl_divergence"] _REGISTER_TABLE = {} @@ -98,7 +96,7 @@ def decorator(f): def _dispatch(cls_p, cls_q): - """Multiple dispatch into concrete implement function""" + """Multiple dispatch into concrete implement function.""" # find all matched super class pair of p and q matchs = [(super_p, super_q) for super_p, super_q in _REGISTER_TABLE From 7af5a399fe3977aa63d413d300f8403300d29529 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Thu, 22 Sep 2022 16:48:50 +0000 Subject: [PATCH 06/39] fix bug --- python/paddle/distribution/__init__.py | 1 + python/paddle/distribution/lognormal.py | 1 + .../distribution/transformed_distribution.py | 11 ++- .../test_distribution_lognormal_static.py | 80 ++++++++----------- 4 files changed, 40 insertions(+), 53 deletions(-) diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index bd70831a3bade..871044360e8f4 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from paddle.distribution import transform + from paddle.distribution.beta import Beta from paddle.distribution.categorical import Categorical from paddle.distribution.dirichlet import Dirichlet diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 0bc7ebb1c13e2..b3a6b744e9ba7 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -13,6 +13,7 @@ # limitations under the License. import paddle + from paddle.distribution import TransformedDistribution, ExpTransform from paddle.distribution.normal import Normal diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py index 4e9e7336cd03e..b15ccd44a05ba 100644 --- a/python/paddle/distribution/transformed_distribution.py +++ b/python/paddle/distribution/transformed_distribution.py @@ -20,8 +20,8 @@ class TransformedDistribution(distribution.Distribution): - r""" - Applies a sequence of Transforms to a base distribution. + r""" + Applies a sequence of Transforms to a base distribution. Args: base (Distribution): The base distribution. @@ -31,11 +31,11 @@ class TransformedDistribution(distribution.Distribution): .. 
code-block:: python - import paddle + import paddle from paddle.distribution import transformed_distribution d = transformed_distribution.TransformedDistribution( - paddle.distribution.Normal(0., 1.), + paddle.distribution.Normal(0., 1.), [paddle.distribution.AffineTransform(paddle.to_tensor(1.), paddle.to_tensor(2.))] ) @@ -58,8 +58,7 @@ def __init__(self, base, transforms): f"Expected type of 'transforms' is Sequence[Transform] or Chain, but got {type(transforms)}." ) if not all(isinstance(t, transform.Transform) for t in transforms): - raise TypeError( - "All element of transforms must be Transform type.") + raise TypeError("All element of transforms must be Transform type.") chain = transform.ChainTransform(transforms) if len(base.batch_shape + base.event_shape) < chain._domain.event_rank: diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 493c8dab45ba1..d8f881da0a4a4 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -31,13 +31,8 @@ @place(config.DEVICES) -@parameterize_cls( - (TEST_CASE_NAME, 'loc', 'scale'), [ - ('one-dim', xrand((2, )), - xrand((2, ))), - ('multi-dim', xrand((3, 3)), - xrand((3, 3))) - ]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('one-dim', xrand( + (2, )), xrand((2, ))), ('multi-dim', xrand((3, 3)), xrand((3, 3)))]) class TestLogNormal(unittest.TestCase): def setUp(self): @@ -60,10 +55,10 @@ def setUp(self): self.feeds = {'loc': self.loc, 'scale': self.scale} executor.run(startup_program) - [self.mean, self.var, self.entropy, self.samples] = executor.run( - main_program, - feed=self.feeds, - fetch_list=fetch_list) + [self.mean, self.var, self.entropy, + self.samples] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) def test_mean(self): np_mean = self._np_lognormal.mean @@ -83,8 +78,8 @@ def test_var(self): def test_entropy(self): np_entropy = self._np_lognormal.entropy() - self.assertEqual(str(self.entropy.dtype).split('.') - [-1], self.scale.dtype) + self.assertEqual( + str(self.entropy.dtype).split('.')[-1], self.scale.dtype) np.testing.assert_allclose(self.entropy, np_entropy, rtol=config.RTOL.get(str(self.scale.dtype)), @@ -92,11 +87,8 @@ def test_entropy(self): @place(config.DEVICES) -@parameterize_cls( - (TEST_CASE_NAME, 'loc', 'scale'), [ - ('sample', xrand((5, )), - xrand((5, ))) - ]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( + (5, )), xrand((5, )))]) class LogNormalTestSample(unittest.TestCase): def setUp(self): @@ -104,8 +96,7 @@ def setUp(self): main_program = paddle.static.Program() executor = paddle.static.Executor(self.place) with paddle.static.program_guard(main_program, startup_program): - loc = paddle.static.data('loc', self.loc.shape, - self.loc.dtype) + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) scale = paddle.static.data('scale', self.scale.shape, self.scale.dtype) self.sample_shape = [8000] @@ -116,15 +107,14 @@ def setUp(self): self.feeds = {'loc': self.loc, 'scale': self.scale} executor.run(startup_program) - [self.mean, self.samples] = executor.run( - main_program, - feed=self.feeds, - fetch_list=fetch_list) + [self.mean, self.samples] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) def test_sample(self): for i in range(len(self.scale)): - 
self.assertTrue(self._kstest( - self.loc[i], self.scale[i], self.samples[:, i])) + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) def _kstest(self, loc, scale, samples): # Uses the Kolmogorov-Smirnov test for goodness of fit. @@ -136,13 +126,9 @@ def _kstest(self, loc, scale, samples): @place(config.DEVICES) @parameterize_cls( - (TEST_CASE_NAME, 'loc1', 'scale1', - 'loc2', 'scale2'), [ - ('one-dim', xrand((2, )), xrand((2, )), - xrand((2, )), xrand((2, ))), - ('multi-dim', xrand((2, 2)), xrand((2, 2)), - xrand((2, 2)), xrand((2, 2))) - ]) + (TEST_CASE_NAME, 'loc1', 'scale1', 'loc2', 'scale2'), + [('one-dim', xrand((2, )), xrand((2, )), xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((2, 2)), xrand((2, 2)), xrand((2, 2)), xrand((2, 2)))]) class TestLognormalKL(unittest.TestCase): def setUp(self): @@ -150,35 +136,35 @@ def setUp(self): main_program = paddle.static.Program() executor = paddle.static.Executor(self.place) with paddle.static.program_guard(main_program, startup_program): - loc1 = paddle.static.data('loc1', self.loc1.shape, - self.loc1.dtype) + loc1 = paddle.static.data('loc1', self.loc1.shape, self.loc1.dtype) scale1 = paddle.static.data('scale1', self.scale1.shape, self.scale1.dtype) - loc2 = paddle.static.data('loc2', self.loc2.shape, - self.loc2.dtype) + loc2 = paddle.static.data('loc2', self.loc2.shape, self.loc2.dtype) scale2 = paddle.static.data('scale2', self.scale2.shape, self.scale2.dtype) self._paddle_lognormal = LogNormal(loc=loc1, scale=scale1) self._paddle_lognormal_other = LogNormal(loc=loc2, scale=scale2) - self.kl1 = kl_divergence( - self._paddle_lognormal, self._paddle_lognormal_other) + self.kl1 = kl_divergence(self._paddle_lognormal, + self._paddle_lognormal_other) self.kl2 = self._kl(self._paddle_lognormal, self._paddle_lognormal_other) fetch_list = [self.kl1, self.kl2] - self.feeds = {'loc1': self.loc1, 'scale1': self.scale1, - 'loc2': self.loc2, 'scale2': self.scale2} + self.feeds = { + 'loc1': self.loc1, + 'scale1': self.scale1, + 'loc2': self.loc2, + 'scale2': self.scale2 + } executor.run(startup_program) - [self.kl1, self.kl2] = executor.run( - main_program, - feed=self.feeds, - fetch_list=fetch_list) + [self.kl1, self.kl2] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) def test_kl_divergence(self): np.testing.assert_allclose(self.kl1, self.kl2, - rtol=config.RTOL.get( - str(self.scale1.dtype)), + rtol=config.RTOL.get(str(self.scale1.dtype)), atol=config.ATOL.get(str(self.scale1.dtype))) def _kl(self, dist1, dist2): From 5c92dce48ce36962f5353591ce366cfc847d808b Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Thu, 22 Sep 2022 17:36:13 +0000 Subject: [PATCH 07/39] fix bug --- python/paddle/distribution/__init__.py | 1 - python/paddle/distribution/lognormal.py | 1 - python/paddle/distribution/normal.py | 2 +- python/paddle/distribution/transform.py | 4 ++-- python/paddle/distribution/transformed_distribution.py | 4 ++-- .../distribution/test_distribution_lognormal_static.py | 1 - 6 files changed, 5 insertions(+), 8 deletions(-) diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index 871044360e8f4..bd70831a3bade 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -13,7 +13,6 @@ # limitations under the License. 
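All of the static-graph tests above follow the same build-then-run pattern; a condensed sketch of it (illustrative, assuming the `paddle.distribution.LogNormal` export this series adds; the concrete loc/scale values are placeholders):

```python
# Condensed sketch of the static-graph pattern used by the tests above:
# declare feeds, build the distribution inside program_guard, then fetch.
import numpy as np
import paddle
from paddle.distribution import LogNormal

paddle.enable_static()
loc_np = np.array([0.5, 1.0], dtype='float32')
scale_np = np.array([1.0, 2.0], dtype='float32')

main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    loc = paddle.static.data('loc', loc_np.shape, loc_np.dtype)
    scale = paddle.static.data('scale', scale_np.shape, scale_np.dtype)
    dist = LogNormal(loc=loc, scale=scale)
    fetch_list = [dist.mean, dist.variance, dist.entropy()]

exe = paddle.static.Executor()
exe.run(startup)
mean, var, entropy = exe.run(main,
                             feed={'loc': loc_np, 'scale': scale_np},
                             fetch_list=fetch_list)
```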
from paddle.distribution import transform - from paddle.distribution.beta import Beta from paddle.distribution.categorical import Categorical from paddle.distribution.dirichlet import Dirichlet diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index b3a6b744e9ba7..0bc7ebb1c13e2 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -13,7 +13,6 @@ # limitations under the License. import paddle - from paddle.distribution import TransformedDistribution, ExpTransform from paddle.distribution.normal import Normal diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index a4222c09b31ab..9b9d063de8e21 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -13,8 +13,8 @@ # limitations under the License. import math - import numpy as np + from paddle.distribution import distribution from paddle.fluid import core from paddle.fluid.data_feeder import (check_dtype, check_type, diff --git a/python/paddle/distribution/transform.py b/python/paddle/distribution/transform.py index cd9f45a068be4..ff2e13f94acf9 100644 --- a/python/paddle/distribution/transform.py +++ b/python/paddle/distribution/transform.py @@ -1133,8 +1133,8 @@ def _forward(self, x): offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1) z = F.sigmoid(x - offset.log()) z_cumprod = (1 - z).cumprod(-1) - return F.pad(z, [0]*2*(len(x.shape)-1) + [0, 1], value=1) * \ - F.pad(z_cumprod, [0]*2*(len(x.shape)-1) + [1, 0], value=1) + return F.pad(z, [0] * 2 * (len(x.shape) - 1) + [0, 1], value=1) * \ + F.pad(z_cumprod, [0] * 2 * (len(x.shape) - 1) + [1, 0], value=1) def _inverse(self, y): y_crop = y[..., :-1] diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py index b15ccd44a05ba..6f03a39da12cb 100644 --- a/python/paddle/distribution/transformed_distribution.py +++ b/python/paddle/distribution/transformed_distribution.py @@ -63,7 +63,7 @@ def __init__(self, base, transforms): chain = transform.ChainTransform(transforms) if len(base.batch_shape + base.event_shape) < chain._domain.event_rank: raise ValueError( - f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base_shape)}." + f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base.batch_shape + base.event_shape)}." 
) if chain._domain.event_rank > len(base.event_shape): base = independent.Independent( @@ -74,7 +74,7 @@ def __init__(self, base, transforms): transformed_shape = chain.forward_shape(base.batch_shape + base.event_shape) transformed_event_rank = chain._codomain.event_rank + \ - max(len(base.event_shape)-chain._domain.event_rank, 0) + max(len(base.event_shape) - chain._domain.event_rank, 0) super(TransformedDistribution, self).__init__( transformed_shape[:len(transformed_shape) - transformed_event_rank], transformed_shape[len(transformed_shape) - transformed_event_rank:]) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index d8f881da0a4a4..2157e21115a58 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -17,7 +17,6 @@ import numpy as np import scipy.stats import config - from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand from paddle.distribution import LogNormal From 6ea8cf5e516d63827939cad81ab4e21a77af528c Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Thu, 22 Sep 2022 18:18:58 +0000 Subject: [PATCH 08/39] fix bug --- python/paddle/distribution/lognormal.py | 3 ++- .../unittests/distribution/test_distribution_lognormal.py | 2 +- .../distribution/test_distribution_lognormal_static.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 0bc7ebb1c13e2..75cbab3790338 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -13,7 +13,8 @@ # limitations under the License. 
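For LogNormal, the shape and log-det-Jacobian bookkeeping that `TransformedDistribution` performs above collapses to a simple identity: with y = exp(x), the ExpTransform Jacobian term is x itself, so log p_Y(y) = Normal.log_prob(log y) - log y. A numeric check of that identity using SciPy only (illustrative):

```python
# Illustrative check of the change-of-variables rule applied for
# LogNormal = exp(Normal): log p_Y(y) = log p_X(log y) - log y,
# because d/dx exp(x) = exp(x).
import numpy as np
import scipy.stats

mu, sigma = 0.4, 0.9
y = np.array([0.5, 1.5, 3.0])

lhs = scipy.stats.lognorm(s=sigma, scale=np.exp(mu)).logpdf(y)
rhs = scipy.stats.norm(loc=mu, scale=sigma).logpdf(np.log(y)) - np.log(y)
assert np.allclose(lhs, rhs)
```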
import paddle -from paddle.distribution import TransformedDistribution, ExpTransform +from paddle.distribution.transform import ExpTransform +from paddle.distribution.transformed_distribution import TransformedDistribution from paddle.distribution.normal import Normal diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 628fcbee3c755..13e6f5601289f 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -21,7 +21,7 @@ import config from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand -from paddle.distribution import LogNormal +from paddle.distribution.lognormal import LogNormal from test_distribution import DistributionNumpy from paddle.distribution.kl import kl_divergence diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 2157e21115a58..298d635612397 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -19,7 +19,7 @@ import config from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand -from paddle.distribution import LogNormal +from paddle.distribution.lognormal import LogNormal from test_distribution import DistributionNumpy from test_distribution_lognormal import LogNormalNumpy from paddle.distribution.kl import kl_divergence From 6f98baf36d1ff5c7dcce4267be60e61b366cdf09 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 01:20:56 +0000 Subject: [PATCH 09/39] fix bug --- .../test_distribution_lognormal.py | 18 ++++++++++++------ .../test_distribution_lognormal_static.py | 12 ++++++------ 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 13e6f5601289f..7a2fbac66d55c 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -81,6 +81,7 @@ def kl_divergence(self, other): class LogNormalTest(unittest.TestCase): def setUp(self): + paddle.disable_static() self._paddle_lognormal = LogNormal(loc=paddle.to_tensor(self.loc), scale=paddle.to_tensor(self.scale)) self._np_lognormal = LogNormalNumpy(self.loc, self.scale) @@ -141,15 +142,19 @@ def test_log_prob(self): @place(config.DEVICES) @parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample1', xrand( (2, )), xrand((2, ))), ('sample2', xrand((5, )), xrand((5, )))]) -class LogNormalTestSample(unittest.TestCase): +class TestLogNormalSample(unittest.TestCase): - def test_sample(self): + def setUp(self): + paddle.disable_static() self._paddle_lognormal = LogNormal(loc=self.loc, scale=self.scale) - shape = [8000] - samples = self._paddle_lognormal.sample(shape) + self.shape = [8000] + self.samples = self._paddle_lognormal.sample(self.shape) + + def test_sample(self): + for i in range(len(self.scale)): self.assertTrue( - self._kstest(self.loc[i], self.scale[i], samples[:, i])) + self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) def _kstest(self, 
loc, scale, samples): # Uses the Kolmogorov-Smirnov test for goodness of fit. @@ -164,9 +169,10 @@ def _kstest(self, loc, scale, samples): (TEST_CASE_NAME, 'loc1', 'scale1', 'loc2', 'scale2'), [('one-dim', xrand((2, )), xrand((2, )), xrand((2, )), xrand((2, ))), ('multi-dim', xrand((2, 2)), xrand((2, 2)), xrand((2, 2)), xrand((2, 2)))]) -class TestLognormalKL(unittest.TestCase): +class TestLogNormalKL(unittest.TestCase): def setUp(self): + paddle.disable_static() self._paddle_lognormal = LogNormal(loc=paddle.to_tensor(self.loc1), scale=paddle.to_tensor(self.scale1)) self._paddle_lognormal_other = LogNormal( diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 298d635612397..1db6a287985bd 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -17,17 +17,14 @@ import numpy as np import scipy.stats import config -from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand from paddle.distribution.lognormal import LogNormal -from test_distribution import DistributionNumpy from test_distribution_lognormal import LogNormalNumpy from paddle.distribution.kl import kl_divergence np.random.seed(2022) -paddle.enable_static() - @place(config.DEVICES) @parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('one-dim', xrand( @@ -35,6 +32,7 @@ class TestLogNormal(unittest.TestCase): def setUp(self): + paddle.enable_static() startup_program = paddle.static.Program() main_program = paddle.static.Program() executor = paddle.static.Executor(self.place) @@ -88,9 +86,10 @@ def test_entropy(self): @place(config.DEVICES) @parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( (5, )), xrand((5, )))]) -class LogNormalTestSample(unittest.TestCase): +class TestLogNormalSample(unittest.TestCase): def setUp(self): + paddle.enable_static() startup_program = paddle.static.Program() main_program = paddle.static.Program() executor = paddle.static.Executor(self.place) @@ -128,9 +127,10 @@ def _kstest(self, loc, scale, samples): (TEST_CASE_NAME, 'loc1', 'scale1', 'loc2', 'scale2'), [('one-dim', xrand((2, )), xrand((2, )), xrand((2, )), xrand((2, ))), ('multi-dim', xrand((2, 2)), xrand((2, 2)), xrand((2, 2)), xrand((2, 2)))]) -class TestLognormalKL(unittest.TestCase): +class TestLogNormalKL(unittest.TestCase): def setUp(self): + paddle.enable_static() startup_program = paddle.static.Program() main_program = paddle.static.Program() executor = paddle.static.Executor(self.place) From 124d8f95eb188d239e59d7b100ca4d09923b872f Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 01:55:32 +0000 Subject: [PATCH 10/39] fix bug --- python/paddle/distribution/lognormal.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 75cbab3790338..9bfe1f0795cb7 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -44,7 +44,6 @@ class LogNormal(TransformedDistribution): import paddle #from paddle.distribution import LogNormal - from lognormal import LogNormal # Define a single scalar LogNormal distribution. dist = LogNormal(loc=0., scale=3.) # Define a batch of two scalar valued LogNormals. 
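The class docstring example touched above reads more easily as one runnable snippet; a sketch assuming the `paddle.distribution.LogNormal` export added earlier in this series (the commented values are the ones quoted in the docstring):

```python
# Runnable version (illustrative) of the LogNormal docstring example above.
import paddle
from paddle.distribution import LogNormal

value_tensor = paddle.to_tensor([0.8], dtype="float32")

lognormal_a = LogNormal([0.], [1.])
lognormal_b = LogNormal([0.5], [2.])

sample = lognormal_a.sample([2])             # random, shape [2, 1]
entropy = lognormal_a.entropy()              # [1.4189385], shape [1]
lp = lognormal_a.log_prob(value_tensor)      # [-0.72069150], shape [1]
p = lognormal_a.probs(value_tensor)          # [0.48641577], shape [1]
kl = lognormal_a.kl_divergence(lognormal_b)  # [0.34939718], shape [1]
```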
From 0d79cc2ed06b602de0ec9e69bed742d31111d462 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 02:13:10 +0000 Subject: [PATCH 11/39] fix bug --- python/paddle/distribution/lognormal.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 9bfe1f0795cb7..fd7fdaa793f52 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -43,7 +43,8 @@ class LogNormal(TransformedDistribution): .. code-block:: python import paddle - #from paddle.distribution import LogNormal + from paddle.distribution import LogNormal + # Define a single scalar LogNormal distribution. dist = LogNormal(loc=0., scale=3.) # Define a batch of two scalar valued LogNormals. From b02255b43163518b79538a619530b8f3f45ad27c Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 03:50:21 +0000 Subject: [PATCH 12/39] fix bug --- python/paddle/distribution/lognormal.py | 3 +-- .../paddle/distribution/transformed_distribution.py | 12 ------------ 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index fd7fdaa793f52..24a52e24b671d 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -26,8 +26,7 @@ class LogNormal(TransformedDistribution): The probability density function (pdf) is .. math:: - pdf(x; \mu, \sigma) = \\frac{1}{\sigma x \sqrt{2\pi}}e^{(-\\frac{(ln(x) - \mu)^2}{2\sigma^2})} - pdf(x; \mu, \sigma) = \\frac{1}{Z}e^{\\frac {-0.5 (x - \mu)^2} {\sigma^2} } + pdf(x; \mu, \sigma) = \frac{1}{\sigma x \sqrt{2\pi}}e^{(-\frac{(ln(x) - \mu)^2}{2\sigma^2})} In the above equation: diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py index 6f03a39da12cb..221a13a067f05 100644 --- a/python/paddle/distribution/transformed_distribution.py +++ b/python/paddle/distribution/transformed_distribution.py @@ -130,18 +130,6 @@ def log_prob(self, value): event_rank - len(self._base.event_shape)) return log_prob - def _monotonize_cdf(self, value): - """ - This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is - monotone increasing. - """ - sign = 1 - for t in self._transforms: - sign = sign * t.sign - if isinstance(sign, int) and sign == 1: - return value - return sign * (value - 0.5) + 0.5 - def _sum_rightmost(value, n): return value.sum(list(range(-n, 0))) if n > 0 else value From 83e6d769a9c2a9f0b3ed4a518e18b515103e852b Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 03:57:45 +0000 Subject: [PATCH 13/39] add comment --- python/paddle/distribution/lognormal.py | 4 ++-- python/paddle/distribution/normal.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 24a52e24b671d..d345158e330ab 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -81,7 +81,7 @@ def __init__(self, loc, scale, name=None): @property def mean(self): - """mean of lognormal distribuion. + """Mean of lognormal distribuion. Returns: Tensor: mean value. @@ -90,7 +90,7 @@ def mean(self): @property def variance(self): - """variance of lognormal distribution. + """Variance of lognormal distribution. Returns: Tensor: variance value. 
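The `mean` and `variance` properties documented above are the standard lognormal moments, exp(mu + sigma^2/2) and (exp(sigma^2) - 1) * exp(2*mu + sigma^2); a NumPy/SciPy cross-check of those formulas (illustrative only):

```python
# Cross-check (illustrative) of the lognormal moment formulas used by
# LogNormal.mean and LogNormal.variance above.
import numpy as np
import scipy.stats

loc, scale = np.array([1., 2.]), np.array([0.5, 0.75])
dist = scipy.stats.lognorm(s=scale, scale=np.exp(loc))

mean = np.exp(loc + scale ** 2 / 2)
variance = np.expm1(scale ** 2) * np.exp(2 * loc + scale ** 2)

assert np.allclose(mean, dist.mean())
assert np.allclose(variance, dist.var())
```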
diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 9b9d063de8e21..2e41550e951d4 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -128,10 +128,20 @@ def __init__(self, loc, scale, name=None): @property def mean(self): + """Mean of multinomial distribuion. + + Returns: + Tensor: mean value. + """ return self.loc @property def variance(self): + """Variance of lognormal distribution. + + Returns: + Tensor: variance value. + """ return self.scale.pow(2) def sample(self, shape, seed=0): From fb23b198bc761678a7028cc5aa802eee051ae83f Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 04:14:29 +0000 Subject: [PATCH 14/39] fix bug --- python/paddle/distribution/normal.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 2e41550e951d4..f3df0ff94de5c 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -276,11 +276,7 @@ def kl_divergence(self, other): .. math:: -<<<<<<< HEAD ratio = \\frac{\sigma_0}{\sigma_1} -======= - ratio = \frac{\sigma_0}{\sigma_1} ->>>>>>> f778470061589ba4c396d7a2b56f2f94819ceb39 .. math:: From 1fa84b0ea264e4282e45e65836a8b9c8f50378cd Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 05:02:12 +0000 Subject: [PATCH 15/39] fix docs --- python/paddle/distribution/lognormal.py | 6 +++--- python/paddle/distribution/normal.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index d345158e330ab..b78dee56c77df 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -105,7 +105,7 @@ def entropy(self): .. math:: - entropy(\sigma) = 0.5 \\log (2 \pi e \sigma^2) + \mu + entropy(\sigma) = 0.5 \log (2 \pi e \sigma^2) + \mu In the above equation: @@ -136,11 +136,11 @@ def kl_divergence(self, other): .. math:: - KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\\frac{diff}{\sigma_1})^2 - 1 - 2 \\ln {ratio}) + KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\frac{diff}{\sigma_1})^2 - 1 - 2 \ln {ratio}) .. math:: - ratio = \\frac{\sigma_0}{\sigma_1} + ratio = \frac{\sigma_0}{\sigma_1} .. math:: diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index f3df0ff94de5c..44af005cc3e80 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -276,7 +276,7 @@ def kl_divergence(self, other): .. math:: - ratio = \\frac{\sigma_0}{\sigma_1} + ratio = \frac{\sigma_0}{\sigma_1} .. 
math:: From 558e784b62ba3fb37a6d1c8956f5553c7a12dbd7 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 06:53:58 +0000 Subject: [PATCH 16/39] fix bug --- .../unittests/distribution/test_distribution_lognormal.py | 6 +++--- .../distribution/test_distribution_lognormal_static.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 7a2fbac66d55c..292ec6e98933a 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -140,14 +140,14 @@ def test_log_prob(self): @place(config.DEVICES) -@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample1', xrand( - (2, )), xrand((2, ))), ('sample2', xrand((5, )), xrand((5, )))]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( + (4, )), xrand((4, )))]) class TestLogNormalSample(unittest.TestCase): def setUp(self): paddle.disable_static() self._paddle_lognormal = LogNormal(loc=self.loc, scale=self.scale) - self.shape = [8000] + self.shape = [9000] self.samples = self._paddle_lognormal.sample(self.shape) def test_sample(self): diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 1db6a287985bd..7a52e6db7e00e 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -85,7 +85,7 @@ def test_entropy(self): @place(config.DEVICES) @parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( - (5, )), xrand((5, )))]) + (4, )), xrand((4, )))]) class TestLogNormalSample(unittest.TestCase): def setUp(self): @@ -97,7 +97,7 @@ def setUp(self): loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) scale = paddle.static.data('scale', self.scale.shape, self.scale.dtype) - self.sample_shape = [8000] + self.sample_shape = [9000] self._paddle_lognormal = LogNormal(loc=loc, scale=scale) self.mean = self._paddle_lognormal.mean self.samples = self._paddle_lognormal.sample(self.sample_shape) From e32f57192279959bc8eaf2a8557a134aa9c39f6a Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 09:21:04 +0000 Subject: [PATCH 17/39] fix bug --- python/paddle/distribution/normal.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 44af005cc3e80..33abb8730fd6c 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -13,8 +13,10 @@ # limitations under the License. 
import math -import numpy as np +import warnings +import numpy as np +from paddle import _C_ops, _legacy_C_ops from paddle.distribution import distribution from paddle.fluid import core from paddle.fluid.data_feeder import (check_dtype, check_type, From 9224e4a333b2722dab8bf0dff6097137c98cc611 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 10:03:10 +0000 Subject: [PATCH 18/39] fix bug --- .../distribution/test_distribution_lognormal.py | 2 +- .../test_distribution_lognormal_static.py | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 292ec6e98933a..e03b40732a763 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -147,7 +147,7 @@ class TestLogNormalSample(unittest.TestCase): def setUp(self): paddle.disable_static() self._paddle_lognormal = LogNormal(loc=self.loc, scale=self.scale) - self.shape = [9000] + self.shape = [100000] self.samples = self._paddle_lognormal.sample(self.shape) def test_sample(self): diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 7a52e6db7e00e..11f5c204f4e68 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -97,17 +97,16 @@ def setUp(self): loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) scale = paddle.static.data('scale', self.scale.shape, self.scale.dtype) - self.sample_shape = [9000] + self.sample_shape = [100000] self._paddle_lognormal = LogNormal(loc=loc, scale=scale) - self.mean = self._paddle_lognormal.mean self.samples = self._paddle_lognormal.sample(self.sample_shape) - fetch_list = [self.mean, self.samples] + fetch_list = [self.samples] self.feeds = {'loc': self.loc, 'scale': self.scale} executor.run(startup_program) - [self.mean, self.samples] = executor.run(main_program, - feed=self.feeds, - fetch_list=fetch_list) + [self.samples] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) def test_sample(self): for i in range(len(self.scale)): From 37eeb2d4d85fc86de2fd0a21dce598bd0458bcdf Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Fri, 23 Sep 2022 15:36:38 +0000 Subject: [PATCH 19/39] add test --- python/paddle/distribution/kl.py | 2 +- python/paddle/distribution/lognormal.py | 13 ++++- .../test_distribution_lognormal.py | 37 ++++++++++++-- .../test_distribution_lognormal_static.py | 48 ++++++++++++++----- .../distribution/test_distribution_normal.py | 9 ++-- ...t_distribution_transformed_distribution.py | 7 +++ 6 files changed, 93 insertions(+), 23 deletions(-) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 03d68044d02f0..50de613b89c54 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -214,7 +214,7 @@ def _kl_expfamily_expfamily(p, q): @register_kl(LogNormal, LogNormal) -def _kl_normal_normal(p, q): +def _kl_lognormal_lognormal(p, q): return p.base_dist.kl_divergence(q.base_dist) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 
b78dee56c77df..e6c8589b921a1 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -21,6 +21,15 @@ class LogNormal(TransformedDistribution): r"""The Normal distribution with location `loc` and `scale` parameters. + .. math:: + + X \sim Normal(\mu, \sigma) + + Y = exp(X) \sim LogNormal(\mu, \sigma) + + + :math:`Normal(\mu, \sigma)` is the underlying distribution of :math:`LogNormal(\mu, \sigma)` + Mathematical details The probability density function (pdf) is @@ -148,8 +157,8 @@ def kl_divergence(self, other): In the above equation: - * :math:`loc = \mu_0`: is the mean of underlying Normal distribution. - * :math:`scale = \sigma_0`: is the std of underlying Normal distribution. + * :math:`loc = \mu_0`: is the mean of current underlying Normal distribution. + * :math:`scale = \sigma_0`: is the std of current underlying Normal distribution. * :math:`loc = \mu_1`: is the mean of other underlying Normal distribution. * :math:`scale = \sigma_1`: is the std of other underlying Normal distribution. * :math:`ratio`: is the ratio of scales. diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index e03b40732a763..eca545c43194b 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -140,21 +140,50 @@ def test_log_prob(self): @place(config.DEVICES) -@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( - (4, )), xrand((4, )))]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), + [('sample', xrand( + (4, ), min=0, max=1), xrand((4, ), min=0.01, max=1))]) class TestLogNormalSample(unittest.TestCase): def setUp(self): paddle.disable_static() self._paddle_lognormal = LogNormal(loc=self.loc, scale=self.scale) - self.shape = [100000] - self.samples = self._paddle_lognormal.sample(self.shape) + n = 80000 + self.sample_shape = [n] + self.rsample_shape = [n] + self.samples = self._paddle_lognormal.sample(self.sample_shape) + self.rsamples = self._paddle_lognormal.rsample(self.rsample_shape) def test_sample(self): + samples_mean = self.samples.mean(axis=0) + samples_var = self.samples.var(axis=0) + np.testing.assert_allclose(samples_mean, + self._paddle_lognormal.mean, + rtol=0.1, + atol=0) + np.testing.assert_allclose(samples_var, + self._paddle_lognormal.variance, + rtol=0.1, + atol=0) + + rsamples_mean = self.rsamples.mean(axis=0) + rsamples_var = self.rsamples.var(axis=0) + np.testing.assert_allclose(rsamples_mean, + self._paddle_lognormal.mean, + rtol=0.1, + atol=0) + np.testing.assert_allclose(rsamples_var, + self._paddle_lognormal.variance, + rtol=0.1, + atol=0) for i in range(len(self.scale)): + self.assertEqual(self.samples[:, i].shape, self.sample_shape) + self.assertEqual(self.rsamples[:, i].shape, self.rsample_shape) self.assertTrue( self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i])) def _kstest(self, loc, scale, samples): # Uses the Kolmogorov-Smirnov test for goodness of fit. 
diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 11f5c204f4e68..122983814d723 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -42,18 +42,15 @@ def setUp(self): self.scale.dtype) self._paddle_lognormal = LogNormal(loc=loc, scale=scale) self._np_lognormal = LogNormalNumpy(loc=self.loc, scale=self.scale) - self.sample_shape = [8000] - mean = self._paddle_lognormal.mean var = self._paddle_lognormal.variance entropy = self._paddle_lognormal.entropy() - samples = self._paddle_lognormal.sample(self.sample_shape) - fetch_list = [mean, var, entropy, samples] + fetch_list = [mean, var, entropy] self.feeds = {'loc': self.loc, 'scale': self.scale} executor.run(startup_program) - [self.mean, self.var, self.entropy, - self.samples] = executor.run(main_program, + [self.mean, self.var, + self.entropy] = executor.run(main_program, feed=self.feeds, fetch_list=fetch_list) @@ -84,8 +81,9 @@ def test_entropy(self): @place(config.DEVICES) -@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( - (4, )), xrand((4, )))]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), + [('sample', xrand( + (4, ), min=0, max=1), xrand((4, ), min=0.01, max=1))]) class TestLogNormalSample(unittest.TestCase): def setUp(self): @@ -97,21 +95,45 @@ def setUp(self): loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) scale = paddle.static.data('scale', self.scale.shape, self.scale.dtype) - self.sample_shape = [100000] + n = 80000 + self.sample_shape = [n] + self.rsample_shape = [n] self._paddle_lognormal = LogNormal(loc=loc, scale=scale) + self.mean = self._paddle_lognormal.mean + self.variance = self._paddle_lognormal.variance self.samples = self._paddle_lognormal.sample(self.sample_shape) - fetch_list = [self.samples] + self.rsamples = self._paddle_lognormal.rsample(self.rsample_shape) + fetch_list = [self.mean, self.variance, self.samples, self.rsamples] self.feeds = {'loc': self.loc, 'scale': self.scale} executor.run(startup_program) - [self.samples] = executor.run(main_program, - feed=self.feeds, - fetch_list=fetch_list) + [self.mean, self.variance, self.samples, + self.rsamples] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) def test_sample(self): + samples_mean = self.samples.mean(axis=0) + samples_var = self.samples.var(axis=0) + np.testing.assert_allclose(samples_mean, self.mean, rtol=0.1, atol=0) + np.testing.assert_allclose(samples_var, self.variance, rtol=0.1, atol=0) + + rsamples_mean = self.rsamples.mean(axis=0) + rsamples_var = self.rsamples.var(axis=0) + np.testing.assert_allclose(rsamples_mean, self.mean, rtol=0.1, atol=0) + np.testing.assert_allclose(rsamples_var, + self.variance, + rtol=0.1, + atol=0) + for i in range(len(self.scale)): + self.assertEqual(self.samples[:, i].shape, (self.sample_shape[0], )) + self.assertEqual(self.rsamples[:, i].shape, + (self.rsample_shape[0], )) self.assertTrue( self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i])) def _kstest(self, loc, scale, samples): # Uses the Kolmogorov-Smirnov test for goodness of fit. 
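The _kstest helper referenced in both test files is only mentioned in these hunks, not reproduced. A minimal sketch of such a Kolmogorov-Smirnov goodness-of-fit check follows; it is an assumption about what such a helper looks like, not the helper's actual body, and the threshold and parameters are illustrative.

.. code-block:: python

    import numpy as np
    import scipy.stats

    def ks_check(loc, scale, samples):
        # Null hypothesis: samples follow LogNormal(loc, scale), i.e. exp(X)
        # with X ~ Normal(loc, scale). scipy uses s=sigma and scale=exp(mu).
        ref_cdf = scipy.stats.lognorm(s=scale, scale=np.exp(loc)).cdf
        statistic, _ = scipy.stats.kstest(samples, ref_cdf)
        return statistic < 0.02  # loose tolerance for large sample sizes

    # Reference samples drawn with numpy rather than paddle, for illustration.
    rng = np.random.default_rng(2022)
    samples = rng.lognormal(mean=0.3, sigma=0.8, size=100000)
    print(ks_check(0.3, 0.8, samples))  # True for well-matched samples
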
diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py index 5023905caa744..5cabe0fa488f2 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py @@ -115,7 +115,7 @@ def init_static_data(self, batch_size, dims): dtype='float32') def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6): - sample, entropy, log_prob, probs, kl = fetch_list + sample, rsample, entropy, log_prob, probs, kl = fetch_list np_normal = NormalNumpy(self.loc_np, self.scale_np) np_sample = np_normal.sample([sample_shape]) @@ -133,6 +133,7 @@ def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6): log_tolerance = 1e-4 np.testing.assert_equal(sample.shape, np_sample.shape) + np.testing.assert_equal(rsample.shape, np_sample.shape) np.testing.assert_allclose(entropy, np_entropy, rtol=tolerance, @@ -155,13 +156,14 @@ def test_normal_distribution_dygraph(self, sample_shape=7, tolerance=1e-6): normal = Normal(self.dynamic_loc, self.dynamic_scale) sample = normal.sample([sample_shape]).numpy() + rsample = normal.rsample([sample_shape]).numpy() entropy = normal.entropy().numpy() log_prob = normal.log_prob(self.dynamic_values).numpy() probs = normal.probs(self.dynamic_values).numpy() other_normal = Normal(self.dynamic_other_loc, self.dynamic_other_scale) kl = normal.kl_divergence(other_normal).numpy() - fetch_list = [sample, entropy, log_prob, probs, kl] + fetch_list = [sample, rsample, entropy, log_prob, probs, kl] self.compare_with_numpy(fetch_list) def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6): @@ -170,6 +172,7 @@ def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6): normal = Normal(self.static_loc, self.static_scale) sample = normal.sample([sample_shape]) + rsample = normal.rsample([sample_shape]) entropy = normal.entropy() log_prob = normal.log_prob(self.static_values) probs = normal.probs(self.static_values) @@ -177,7 +180,7 @@ def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6): self.static_other_scale) kl = normal.kl_divergence(other_normal) - fetch_list = [sample, entropy, log_prob, probs, kl] + fetch_list = [sample, rsample, entropy, log_prob, probs, kl] feed_vars = { 'loc': self.loc_np, diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py index c47250195daab..c448e407b9dee 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py @@ -61,6 +61,13 @@ def test_sample(self): self.assertEqual(tuple(data.shape), expected_shape) self.assertEqual(data.dtype, self.base.loc.dtype) + def test_rsample(self): + shape = [5, 10, 8] + expected_shape = (5, 10, 8) + data = self._t.rsample(shape) + self.assertEqual(tuple(data.shape), expected_shape) + self.assertEqual(data.dtype, self.base.loc.dtype) + if __name__ == '__main__': unittest.main() From f700bd1468037351d6af0ed3c1b945994b90067a Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 00:03:32 +0000 Subject: [PATCH 20/39] add test --- python/paddle/distribution/normal.py | 5 +- .../test_distribution_lognormal.py | 35 ++++++++--- 
.../test_distribution_lognormal_static.py | 61 ++++++++++++------- 3 files changed, 68 insertions(+), 33 deletions(-) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 33abb8730fd6c..ee4118415b161 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -180,8 +180,9 @@ def sample(self, shape, seed=0): return output else: output_shape = shape + batch_shape - output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \ - (tensor.zeros(output_shape, dtype=self.dtype) + self.scale) + output = nn.gaussian_random( + output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * ( + tensor.zeros(output_shape, dtype=self.dtype) + self.scale) output = elementwise_add(output, self.loc, name=name) if self.all_arg_is_float: return nn.reshape(output, shape, name=name) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index eca545c43194b..460ed23f0cf5a 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -13,6 +13,7 @@ # limitations under the License. import math +from pprint import pprint import unittest import scipy.stats @@ -21,6 +22,7 @@ import config from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +from paddle.distribution.normal import Normal from paddle.distribution.lognormal import LogNormal from test_distribution import DistributionNumpy from paddle.distribution.kl import kl_divergence @@ -202,18 +204,33 @@ class TestLogNormalKL(unittest.TestCase): def setUp(self): paddle.disable_static() - self._paddle_lognormal = LogNormal(loc=paddle.to_tensor(self.loc1), - scale=paddle.to_tensor(self.scale1)) - self._paddle_lognormal_other = LogNormal( - loc=paddle.to_tensor(self.loc2), - scale=paddle.to_tensor(self.scale2)) + self.ln_a = LogNormal(loc=paddle.to_tensor(self.loc1), + scale=paddle.to_tensor(self.scale1)) + self.ln_b = LogNormal(loc=paddle.to_tensor(self.loc2), + scale=paddle.to_tensor(self.scale2)) + self.normal_a = Normal(loc=paddle.to_tensor(self.loc1), + scale=paddle.to_tensor(self.scale1)) + self.normal_b = Normal(loc=paddle.to_tensor(self.loc2), + scale=paddle.to_tensor(self.scale2)) def test_kl_divergence(self): - kl1 = kl_divergence(self._paddle_lognormal, - self._paddle_lognormal_other) - kl2 = self._kl(self._paddle_lognormal, self._paddle_lognormal_other) + kl0 = self.ln_a.kl_divergence(self.ln_b) + kl1 = kl_divergence(self.ln_a, self.ln_b) + kl_normal = kl_divergence(self.normal_a, self.normal_b) + kl_formula = self._kl(self.ln_a, self.ln_b) + + self.assertEqual(tuple(kl0.shape), self.scale1.shape) + self.assertEqual(tuple(kl1.shape), self.scale1.shape) + np.testing.assert_allclose(kl0, + kl_formula, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) np.testing.assert_allclose(kl1, - kl2, + kl_formula, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + np.testing.assert_allclose(kl_normal, + kl_formula, rtol=config.RTOL.get(str(self.scale1.dtype)), atol=config.ATOL.get(str(self.scale1.dtype))) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 
122983814d723..183f653022fde 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -19,6 +19,7 @@ import config from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +from paddle.distribution.normal import Normal from paddle.distribution.lognormal import LogNormal from test_distribution_lognormal import LogNormalNumpy from paddle.distribution.kl import kl_divergence @@ -40,11 +41,11 @@ def setUp(self): loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) scale = paddle.static.data('scale', self.scale.shape, self.scale.dtype) - self._paddle_lognormal = LogNormal(loc=loc, scale=scale) + self.ln_a = LogNormal(loc=loc, scale=scale) self._np_lognormal = LogNormalNumpy(loc=self.loc, scale=self.scale) - mean = self._paddle_lognormal.mean - var = self._paddle_lognormal.variance - entropy = self._paddle_lognormal.entropy() + mean = self.ln_a.mean + var = self.ln_a.variance + entropy = self.ln_a.entropy() fetch_list = [mean, var, entropy] self.feeds = {'loc': self.loc, 'scale': self.scale} @@ -98,11 +99,11 @@ def setUp(self): n = 80000 self.sample_shape = [n] self.rsample_shape = [n] - self._paddle_lognormal = LogNormal(loc=loc, scale=scale) - self.mean = self._paddle_lognormal.mean - self.variance = self._paddle_lognormal.variance - self.samples = self._paddle_lognormal.sample(self.sample_shape) - self.rsamples = self._paddle_lognormal.rsample(self.rsample_shape) + self.ln_a = LogNormal(loc=loc, scale=scale) + self.mean = self.ln_a.mean + self.variance = self.ln_a.variance + self.samples = self.ln_a.sample(self.sample_shape) + self.rsamples = self.ln_a.rsample(self.rsample_shape) fetch_list = [self.mean, self.variance, self.samples, self.rsamples] self.feeds = {'loc': self.loc, 'scale': self.scale} @@ -127,9 +128,9 @@ def test_sample(self): atol=0) for i in range(len(self.scale)): - self.assertEqual(self.samples[:, i].shape, (self.sample_shape[0], )) + self.assertEqual(self.samples[:, i].shape, tuple(self.sample_shape)) self.assertEqual(self.rsamples[:, i].shape, - (self.rsample_shape[0], )) + tuple(self.rsample_shape)) self.assertTrue( self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) self.assertTrue( @@ -162,13 +163,18 @@ def setUp(self): loc2 = paddle.static.data('loc2', self.loc2.shape, self.loc2.dtype) scale2 = paddle.static.data('scale2', self.scale2.shape, self.scale2.dtype) - self._paddle_lognormal = LogNormal(loc=loc1, scale=scale1) - self._paddle_lognormal_other = LogNormal(loc=loc2, scale=scale2) - self.kl1 = kl_divergence(self._paddle_lognormal, - self._paddle_lognormal_other) - self.kl2 = self._kl(self._paddle_lognormal, - self._paddle_lognormal_other) - fetch_list = [self.kl1, self.kl2] + + self.ln_a = LogNormal(loc=loc1, scale=scale1) + self.ln_b = LogNormal(loc=loc2, scale=scale2) + self.normal_a = Normal(loc=loc1, scale=scale1) + self.normal_b = Normal(loc=loc2, scale=scale2) + + self.kl0 = self.ln_a.kl_divergence(self.ln_b) + self.kl1 = kl_divergence(self.ln_a, self.ln_b) + self.kl_normal = kl_divergence(self.normal_a, self.normal_b) + self.kl_formula = self._kl(self.ln_a, self.ln_b) + + fetch_list = [self.kl0, self.kl1, self.kl_normal, self.kl_formula] self.feeds = { 'loc1': self.loc1, 'scale1': self.scale1, @@ -177,13 +183,24 @@ def setUp(self): } executor.run(startup_program) - [self.kl1, self.kl2] = executor.run(main_program, - feed=self.feeds, - fetch_list=fetch_list) + [self.kl0, self.kl1, 
self.kl_normal, + self.kl_formula] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) def test_kl_divergence(self): + np.testing.assert_allclose(self.kl0, + self.kl_formula, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + np.testing.assert_allclose(self.kl1, - self.kl2, + self.kl_formula, + rtol=config.RTOL.get(str(self.scale1.dtype)), + atol=config.ATOL.get(str(self.scale1.dtype))) + + np.testing.assert_allclose(self.kl_normal, + self.kl_formula, rtol=config.RTOL.get(str(self.scale1.dtype)), atol=config.ATOL.get(str(self.scale1.dtype))) From acb6d4ee3a900a381c0ae3f9ea16fe6800c67795 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 06:39:03 +0000 Subject: [PATCH 21/39] change the args type of Normal sample --- python/paddle/distribution/kl.py | 2 +- python/paddle/distribution/lognormal.py | 18 +++--- python/paddle/distribution/normal.py | 15 +++-- .../test_distribution_lognormal.py | 55 ++++++++++--------- .../test_distribution_lognormal_static.py | 39 +++++++------ 5 files changed, 64 insertions(+), 65 deletions(-) diff --git a/python/paddle/distribution/kl.py b/python/paddle/distribution/kl.py index 50de613b89c54..80a093ad8b491 100644 --- a/python/paddle/distribution/kl.py +++ b/python/paddle/distribution/kl.py @@ -215,7 +215,7 @@ def _kl_expfamily_expfamily(p, q): @register_kl(LogNormal, LogNormal) def _kl_lognormal_lognormal(p, q): - return p.base_dist.kl_divergence(q.base_dist) + return p._base.kl_divergence(q._base) def _sum_rightmost(value, n): diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index e6c8589b921a1..3f372cf37a4c5 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -83,10 +83,10 @@ class LogNormal(TransformedDistribution): """ def __init__(self, loc, scale, name=None): - self.base_dist = Normal(loc=loc, scale=scale, name=name) - self.loc = self.base_dist.loc - self.scale = self.base_dist.scale - super(LogNormal, self).__init__(self.base_dist, [ExpTransform()]) + self._base = Normal(loc=loc, scale=scale, name=name) + self.loc = self._base.loc + self.scale = self._base.scale + super(LogNormal, self).__init__(self._base, [ExpTransform()]) @property def mean(self): @@ -95,7 +95,7 @@ def mean(self): Returns: Tensor: mean value. """ - return paddle.exp(self.base_dist.mean + self.base_dist.variance / 2) + return paddle.exp(self._base.mean + self._base.variance / 2) @property def variance(self): @@ -104,8 +104,8 @@ def variance(self): Returns: Tensor: variance value. """ - return (paddle.expm1(self.base_dist.variance) * - paddle.exp(2 * self.base_dist.mean + self.base_dist.variance)) + return (paddle.expm1(self._base.variance) * + paddle.exp(2 * self._base.mean + self._base.variance)) def entropy(self): r"""Shannon entropy in nats. @@ -124,7 +124,7 @@ def entropy(self): Tensor: Shannon entropy of lognormal distribution.The data type is float32. """ - return self.base_dist.entropy() + self.base_dist.mean + return self._base.entropy() + self._base.mean def probs(self, value): """Probability density/mass function. @@ -171,4 +171,4 @@ def kl_divergence(self, other): Tensor: kl-divergence between two lognormal distributions.The data type is float32. 
""" - return self.base_dist.kl_divergence(other.base_dist) + return self._base.kl_divergence(other._base) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index ee4118415b161..60e791e062d41 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -146,11 +146,11 @@ def variance(self): """ return self.scale.pow(2) - def sample(self, shape, seed=0): + def sample(self, shape=(), seed=0): """Generate samples of the specified shape. Args: - shape (list): 1D `int32`. Shape of the generated samples. + shape (Sequence[int], optional): Sample shape. seed (int): Python integer number. Returns: @@ -158,12 +158,11 @@ def sample(self, shape, seed=0): """ if not _non_static_mode(): - check_type(shape, 'shape', (list), 'sample') + check_type(shape, 'shape', (tuple), 'sample') check_type(seed, 'seed', (int), 'sample') - batch_shape = list((self.loc + self.scale).shape) + batch_shape = tuple((self.loc + self.scale).shape) name = self.name + '_sample' - if self.batch_size_unknown: output_shape = shape + batch_shape zero_tmp = tensor.fill_constant_batch_size_like( @@ -189,11 +188,11 @@ def sample(self, shape, seed=0): else: return output - def rsample(self, shape, seed=0): + def rsample(self, shape=(), seed=0): """Generate reparameterized samples of the specified shape. Args: - shape (list): 1D `int32`. Shape of the generated samples. + shape (Sequence[int], optional): Sample shape. seed (int): Python integer number. Returns: @@ -220,7 +219,7 @@ def entropy(self): """ name = self.name + '_entropy' - batch_shape = list((self.loc + self.scale).shape) + batch_shape = tuple((self.loc + self.scale).shape) zero_tmp = tensor.fill_constant_batch_size_like(self.loc + self.scale, batch_shape, self.dtype, 0.) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 460ed23f0cf5a..25feaac13d0b7 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -13,7 +13,6 @@ # limitations under the License. 
import math -from pprint import pprint import unittest import scipy.stats @@ -27,8 +26,6 @@ from test_distribution import DistributionNumpy from paddle.distribution.kl import kl_divergence -np.random.seed(2022) - class LogNormalNumpy(DistributionNumpy): @@ -84,13 +81,13 @@ class LogNormalTest(unittest.TestCase): def setUp(self): paddle.disable_static() - self._paddle_lognormal = LogNormal(loc=paddle.to_tensor(self.loc), - scale=paddle.to_tensor(self.scale)) - self._np_lognormal = LogNormalNumpy(self.loc, self.scale) + self.paddle_lognormal = LogNormal(loc=paddle.to_tensor(self.loc), + scale=paddle.to_tensor(self.scale)) + self.np_lognormal = LogNormalNumpy(self.loc, self.scale) def test_mean(self): - mean = self._paddle_lognormal.mean - np_mean = self._np_lognormal.mean + mean = self.paddle_lognormal.mean + np_mean = self.np_lognormal.mean self.assertEqual(mean.numpy().dtype, np_mean.dtype) np.testing.assert_allclose(mean, np_mean, @@ -98,8 +95,8 @@ def test_mean(self): atol=config.ATOL.get(str(self.scale.dtype))) def test_variance(self): - var = self._paddle_lognormal.variance - np_var = self._np_lognormal.variance + var = self.paddle_lognormal.variance + np_var = self.np_lognormal.variance self.assertEqual(var.numpy().dtype, np_var.dtype) np.testing.assert_allclose(var, np_var, @@ -107,8 +104,8 @@ def test_variance(self): atol=config.ATOL.get(str(self.scale.dtype))) def test_entropy(self): - entropy = self._paddle_lognormal.entropy() - np_entropy = self._np_lognormal.entropy() + entropy = self.paddle_lognormal.entropy() + np_entropy = self.np_lognormal.entropy() self.assertEqual(entropy.numpy().dtype, np_entropy.dtype) np.testing.assert_allclose(entropy, np_entropy, @@ -120,8 +117,8 @@ def test_probs(self): for v in value: with paddle.fluid.dygraph.guard(self.place): - probs = self._paddle_lognormal.probs(paddle.to_tensor(v)) - np_probs = self._np_lognormal.probs(v) + probs = self.paddle_lognormal.probs(paddle.to_tensor(v)) + np_probs = self.np_lognormal.probs(v) np.testing.assert_allclose( probs, np_probs, @@ -132,8 +129,8 @@ def test_log_prob(self): value = [np.random.rand(*self.scale.shape)] for v in value: with paddle.fluid.dygraph.guard(self.place): - log_prob = self._paddle_lognormal.log_prob(paddle.to_tensor(v)) - np_log_prob = self._np_lognormal.log_prob(v) + log_prob = self.paddle_lognormal.log_prob(paddle.to_tensor(v)) + np_log_prob = self.np_lognormal.log_prob(v) np.testing.assert_allclose( log_prob, np_log_prob, @@ -149,39 +146,43 @@ class TestLogNormalSample(unittest.TestCase): def setUp(self): paddle.disable_static() - self._paddle_lognormal = LogNormal(loc=self.loc, scale=self.scale) + self.paddle_lognormal = LogNormal(loc=self.loc, scale=self.scale) n = 80000 - self.sample_shape = [n] - self.rsample_shape = [n] - self.samples = self._paddle_lognormal.sample(self.sample_shape) - self.rsamples = self._paddle_lognormal.rsample(self.rsample_shape) + self.sample_shape = (n, ) + self.rsample_shape = (n, ) + self.samples = self.paddle_lognormal.sample(self.sample_shape) + self.rsamples = self.paddle_lognormal.rsample(self.rsample_shape) def test_sample(self): samples_mean = self.samples.mean(axis=0) samples_var = self.samples.var(axis=0) np.testing.assert_allclose(samples_mean, - self._paddle_lognormal.mean, + self.paddle_lognormal.mean, rtol=0.1, atol=0) np.testing.assert_allclose(samples_var, - self._paddle_lognormal.variance, + self.paddle_lognormal.variance, rtol=0.1, atol=0) rsamples_mean = self.rsamples.mean(axis=0) rsamples_var = self.rsamples.var(axis=0) 
np.testing.assert_allclose(rsamples_mean, - self._paddle_lognormal.mean, + self.paddle_lognormal.mean, rtol=0.1, atol=0) np.testing.assert_allclose(rsamples_var, - self._paddle_lognormal.variance, + self.paddle_lognormal.variance, rtol=0.1, atol=0) + batch_shape = (self.loc + self.scale).shape + self.assertEqual(self.samples.shape, + list(self.sample_shape + batch_shape)) + self.assertEqual(self.rsamples.shape, + list(self.rsample_shape + batch_shape)) + for i in range(len(self.scale)): - self.assertEqual(self.samples[:, i].shape, self.sample_shape) - self.assertEqual(self.rsamples[:, i].shape, self.rsample_shape) self.assertTrue( self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) self.assertTrue( diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 183f653022fde..820f1c21c0e7d 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -24,8 +24,6 @@ from test_distribution_lognormal import LogNormalNumpy from paddle.distribution.kl import kl_divergence -np.random.seed(2022) - @place(config.DEVICES) @parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('one-dim', xrand( @@ -41,11 +39,11 @@ def setUp(self): loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) scale = paddle.static.data('scale', self.scale.shape, self.scale.dtype) - self.ln_a = LogNormal(loc=loc, scale=scale) - self._np_lognormal = LogNormalNumpy(loc=self.loc, scale=self.scale) - mean = self.ln_a.mean - var = self.ln_a.variance - entropy = self.ln_a.entropy() + self.paddle_lognormal = LogNormal(loc=loc, scale=scale) + self.np_lognormal = LogNormalNumpy(loc=self.loc, scale=self.scale) + mean = self.paddle_lognormal.mean + var = self.paddle_lognormal.variance + entropy = self.paddle_lognormal.entropy() fetch_list = [mean, var, entropy] self.feeds = {'loc': self.loc, 'scale': self.scale} @@ -56,7 +54,7 @@ def setUp(self): fetch_list=fetch_list) def test_mean(self): - np_mean = self._np_lognormal.mean + np_mean = self.np_lognormal.mean self.assertEqual(str(self.mean.dtype).split('.')[-1], self.scale.dtype) np.testing.assert_allclose(self.mean, np_mean, @@ -64,7 +62,7 @@ def test_mean(self): atol=config.ATOL.get(str(self.scale.dtype))) def test_var(self): - np_var = self._np_lognormal.variance + np_var = self.np_lognormal.variance self.assertEqual(str(self.var.dtype).split('.')[-1], self.scale.dtype) np.testing.assert_allclose(self.var, np_var, @@ -72,7 +70,7 @@ def test_var(self): atol=config.ATOL.get(str(self.scale.dtype))) def test_entropy(self): - np_entropy = self._np_lognormal.entropy() + np_entropy = self.np_lognormal.entropy() self.assertEqual( str(self.entropy.dtype).split('.')[-1], self.scale.dtype) np.testing.assert_allclose(self.entropy, @@ -97,13 +95,13 @@ def setUp(self): scale = paddle.static.data('scale', self.scale.shape, self.scale.dtype) n = 80000 - self.sample_shape = [n] - self.rsample_shape = [n] - self.ln_a = LogNormal(loc=loc, scale=scale) - self.mean = self.ln_a.mean - self.variance = self.ln_a.variance - self.samples = self.ln_a.sample(self.sample_shape) - self.rsamples = self.ln_a.rsample(self.rsample_shape) + self.sample_shape = (n, ) + self.rsample_shape = (n, ) + self.paddle_lognormal = LogNormal(loc=loc, scale=scale) + self.mean = self.paddle_lognormal.mean + self.variance = self.paddle_lognormal.variance 
+ self.samples = self.paddle_lognormal.sample(self.sample_shape) + self.rsamples = self.paddle_lognormal.rsample(self.rsample_shape) fetch_list = [self.mean, self.variance, self.samples, self.rsamples] self.feeds = {'loc': self.loc, 'scale': self.scale} @@ -127,10 +125,11 @@ def test_sample(self): rtol=0.1, atol=0) + batch_shape = (self.loc + self.scale).shape + self.assertEqual(self.samples.shape, self.sample_shape + batch_shape) + self.assertEqual(self.rsamples.shape, self.rsample_shape + batch_shape) + for i in range(len(self.scale)): - self.assertEqual(self.samples[:, i].shape, tuple(self.sample_shape)) - self.assertEqual(self.rsamples[:, i].shape, - tuple(self.rsample_shape)) self.assertTrue( self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) self.assertTrue( From 920e4fcf02dc741c1e7a0f12476df177238a6698 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 07:25:12 +0000 Subject: [PATCH 22/39] fix bug --- python/paddle/distribution/lognormal.py | 6 +++--- python/paddle/distribution/normal.py | 4 ++-- python/paddle/distribution/transformed_distribution.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 3f372cf37a4c5..36c10015f57b2 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -59,7 +59,7 @@ class LogNormal(TransformedDistribution): # The underlying Normal of first has mean 1 and standard deviation 11, the underlying Normal of second 2 and 22. dist = LogNormal(loc=[1., 2.], scale=[11., 22.]) # Get 3 samples, returning a 3 x 2 tensor. - dist.sample([3]) + dist.sample((3, )) # Define a batch of two scalar valued LogNormals. # Their underlying Normal have mean 1, but different standard deviations. @@ -70,8 +70,8 @@ class LogNormal(TransformedDistribution): lognormal_a = LogNormal([0.], [1.]) lognormal_b = LogNormal([0.5], [2.]) - sample = lognormal_a.sample([2]) - # a random tensor created by normal distribution with shape: [2, 1] + sample = lognormal_a.sample((2, )) + # a random tensor created by lognormal distribution with shape: [2, 1] entropy = lognormal_a.entropy() # [1.4189385] with shape: [1] lp = lognormal_a.log_prob(value_tensor) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 60e791e062d41..6f1ed51d45620 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -65,7 +65,7 @@ class Normal(distribution.Distribution): # The first has mean 1 and standard deviation 11, the second 2 and 22. dist = Normal(loc=[1., 2.], scale=[11., 22.]) # Get 3 samples, returning a 3 x 2 tensor. - dist.sample([3]) + dist.sample((3, )) # Define a batch of two scalar valued Normals. # Both have mean 1, but different standard deviations. @@ -76,7 +76,7 @@ class Normal(distribution.Distribution): normal_a = Normal([0.], [1.]) normal_b = Normal([0.5], [2.]) - sample = normal_a.sample([2]) + sample = normal_a.sample((2, )) # a random tensor created by normal distribution with shape: [2, 1] entropy = normal_a.entropy() # [1.4189385] with shape: [1] diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py index 221a13a067f05..38becc4d37320 100644 --- a/python/paddle/distribution/transformed_distribution.py +++ b/python/paddle/distribution/transformed_distribution.py @@ -83,7 +83,7 @@ def sample(self, shape=()): """Sample from ``TransformedDistribution``. 
Args: - shape (tuple, optional): The sample shape. Defaults to (). + shape ((Sequence[int], optional): The sample shape. Defaults to (). Returns: [Tensor]: The sample result. @@ -97,7 +97,7 @@ def rsample(self, shape=()): """Reparameterized sample from ``TransformedDistribution``. Args: - shape (tuple, optional): The sample shape. Defaults to (). + shape ((Sequence[int], optional): The sample shape. Defaults to (). Returns: [Tensor]: The sample result. From a463709a15c05c33d222d64a797a9a7d9d43aced Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 08:28:56 +0000 Subject: [PATCH 23/39] fix bug --- python/paddle/distribution/normal.py | 14 +++++++++++--- .../distribution/test_distribution_lognormal.py | 8 ++++---- .../test_distribution_lognormal_static.py | 10 ++++++---- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 6f1ed51d45620..f2960a4f1842d 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -25,6 +25,10 @@ from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div, elementwise_mul, elementwise_sub, nn, ops, tensor) +try: + from collections.abc import Iterable +except: + from collections import Iterable class Normal(distribution.Distribution): @@ -157,12 +161,16 @@ def sample(self, shape=(), seed=0): Tensor, A tensor with prepended dimensions shape.The data type is float32. """ + if not isinstance(shape, Iterable): + raise TypeError('sample shape must be Iterable object.') + if not _non_static_mode(): - check_type(shape, 'shape', (tuple), 'sample') check_type(seed, 'seed', (int), 'sample') - batch_shape = tuple((self.loc + self.scale).shape) + shape = list(shape) + batch_shape = list((self.loc + self.scale).shape) name = self.name + '_sample' + if self.batch_size_unknown: output_shape = shape + batch_shape zero_tmp = tensor.fill_constant_batch_size_like( @@ -219,7 +227,7 @@ def entropy(self): """ name = self.name + '_entropy' - batch_shape = tuple((self.loc + self.scale).shape) + batch_shape = list((self.loc + self.scale).shape) zero_tmp = tensor.fill_constant_batch_size_like(self.loc + self.scale, batch_shape, self.dtype, 0.) 
diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 25feaac13d0b7..39357b64c0f49 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -75,8 +75,8 @@ def kl_divergence(self, other): @place(config.DEVICES) @parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('float', xrand(), xrand()), - ('one-dim', xrand((3, )), xrand((3, ))), - ('multi-dim', xrand((5, 5)), xrand((5, 5)))]) + ('one-dim', xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((3, 3)), xrand((3, 3)))]) class LogNormalTest(unittest.TestCase): def setUp(self): @@ -141,13 +141,13 @@ def test_log_prob(self): @place(config.DEVICES) @parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( - (4, ), min=0, max=1), xrand((4, ), min=0.01, max=1))]) + (4, ), min=0, max=1), xrand((4, ), min=0, max=1))]) class TestLogNormalSample(unittest.TestCase): def setUp(self): paddle.disable_static() self.paddle_lognormal = LogNormal(loc=self.loc, scale=self.scale) - n = 80000 + n = 100000 self.sample_shape = (n, ) self.rsample_shape = (n, ) self.samples = self.paddle_lognormal.sample(self.sample_shape) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 820f1c21c0e7d..e2251f54083df 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -26,8 +26,10 @@ @place(config.DEVICES) -@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('one-dim', xrand( - (2, )), xrand((2, ))), ('multi-dim', xrand((3, 3)), xrand((3, 3)))]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), + [('float', xrand(), xrand()), + ('one-dim', xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((3, 3)), xrand((3, 3)))]) class TestLogNormal(unittest.TestCase): def setUp(self): @@ -82,7 +84,7 @@ def test_entropy(self): @place(config.DEVICES) @parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( - (4, ), min=0, max=1), xrand((4, ), min=0.01, max=1))]) + (4, ), min=0, max=1), xrand((4, ), min=0, max=1))]) class TestLogNormalSample(unittest.TestCase): def setUp(self): @@ -94,7 +96,7 @@ def setUp(self): loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) scale = paddle.static.data('scale', self.scale.shape, self.scale.dtype) - n = 80000 + n = 100000 self.sample_shape = (n, ) self.rsample_shape = (n, ) self.paddle_lognormal = LogNormal(loc=loc, scale=scale) From d88eb598e19ffdbee2fbc08492fff1786a1cfabb Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 08:31:54 +0000 Subject: [PATCH 24/39] fix bug --- python/paddle/distribution/normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index f2960a4f1842d..6c002fb9ede9a 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -207,7 +207,7 @@ def rsample(self, shape=(), seed=0): Tensor: A tensor with prepended dimensions shape.The data type is float32. """ - return self.sample(shape) + return self.sample(shape, seed) def entropy(self): r"""Shannon entropy in nats. 
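After PATCH 21 through PATCH 24, Normal.sample and Normal.rsample accept any Sequence[int] as the sample shape, default to an empty shape, and rsample forwards the seed to sample; LogNormal samples through its underlying Normal, so it follows the same convention. A small usage sketch with illustrative values, not an excerpt from the series; the returned shape is the sample shape followed by the batch shape.

.. code-block:: python

    import paddle
    from paddle.distribution import Normal
    from paddle.distribution.lognormal import LogNormal

    # Batch of two distributions, so the batch shape is [2].
    normal = Normal(loc=[1., 2.], scale=[11., 22.])
    lognormal = LogNormal(loc=[0.5, 1.], scale=[1., 2.])

    s = normal.sample((3, ))       # tuple shape -> Tensor of shape [3, 2]
    r = normal.rsample([3])        # list shape  -> Tensor of shape [3, 2]
    ls = lognormal.sample((3, ))   # Tensor of shape [3, 2]
    print(s.shape, r.shape, ls.shape)
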
From 062c64fde16c03b1f3eb22083e06bc0ae565f226 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 09:00:22 +0000 Subject: [PATCH 25/39] fix bug --- python/paddle/distribution/normal.py | 60 ++++++++++++++-------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 6c002fb9ede9a..e717d3739c641 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -60,36 +60,36 @@ class Normal(distribution.Distribution): Examples: .. code-block:: python - import paddle - from paddle.distribution import Normal - - # Define a single scalar Normal distribution. - dist = Normal(loc=0., scale=3.) - # Define a batch of two scalar valued Normals. - # The first has mean 1 and standard deviation 11, the second 2 and 22. - dist = Normal(loc=[1., 2.], scale=[11., 22.]) - # Get 3 samples, returning a 3 x 2 tensor. - dist.sample((3, )) - - # Define a batch of two scalar valued Normals. - # Both have mean 1, but different standard deviations. - dist = Normal(loc=1., scale=[11., 22.]) - - # Complete example - value_tensor = paddle.to_tensor([0.8], dtype="float32") - - normal_a = Normal([0.], [1.]) - normal_b = Normal([0.5], [2.]) - sample = normal_a.sample((2, )) - # a random tensor created by normal distribution with shape: [2, 1] - entropy = normal_a.entropy() - # [1.4189385] with shape: [1] - lp = normal_a.log_prob(value_tensor) - # [-1.2389386] with shape: [1] - p = normal_a.probs(value_tensor) - # [0.28969154] with shape: [1] - kl = normal_a.kl_divergence(normal_b) - # [0.34939718] with shape: [1] + import paddle + from paddle.distribution import Normal + + # Define a single scalar Normal distribution. + dist = Normal(loc=0., scale=3.) + # Define a batch of two scalar valued Normals. + # The first has mean 1 and standard deviation 11, the second 2 and 22. + dist = Normal(loc=[1., 2.], scale=[11., 22.]) + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + + # Define a batch of two scalar valued Normals. + # Both have mean 1, but different standard deviations. 
+ dist = Normal(loc=1., scale=[11., 22.]) + + # Complete example + value_tensor = paddle.to_tensor([0.8], dtype="float32") + + normal_a = Normal([0.], [1.]) + normal_b = Normal([0.5], [2.]) + sample = normal_a.sample([2]) + # a random tensor created by normal distribution with shape: [2, 1] + entropy = normal_a.entropy() + # [1.4189385] with shape: [1] + lp = normal_a.log_prob(value_tensor) + # [-1.2389386] with shape: [1] + p = normal_a.probs(value_tensor) + # [0.28969154] with shape: [1] + kl = normal_a.kl_divergence(normal_b) + # [0.34939718] with shape: [1] """ def __init__(self, loc, scale, name=None): From a649645121ad2e31a73df8b2d930a5047d050820 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 11:38:48 +0000 Subject: [PATCH 26/39] add test --- .../test_distribution_lognormal.py | 45 +++++++-------- .../test_distribution_lognormal_static.py | 57 ++++++++++++------- 2 files changed, 57 insertions(+), 45 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 39357b64c0f49..17f9c42c17480 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -73,10 +73,9 @@ def kl_divergence(self, other): @place(config.DEVICES) -@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), - [('float', xrand(), xrand()), - ('one-dim', xrand((2, )), xrand((2, ))), - ('multi-dim', xrand((3, 3)), xrand((3, 3)))]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale', 'value'), + [('one-dim', xrand((2, )), xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((3, 3)), xrand((3, 3)), xrand((3, 3)))]) class LogNormalTest(unittest.TestCase): def setUp(self): @@ -113,29 +112,25 @@ def test_entropy(self): atol=config.ATOL.get(str(self.scale.dtype))) def test_probs(self): - value = [np.random.rand(*self.scale.shape)] - - for v in value: - with paddle.fluid.dygraph.guard(self.place): - probs = self.paddle_lognormal.probs(paddle.to_tensor(v)) - np_probs = self.np_lognormal.probs(v) - np.testing.assert_allclose( - probs, - np_probs, - rtol=config.RTOL.get(str(self.scale.dtype)), - atol=config.ATOL.get(str(self.scale.dtype))) + with paddle.fluid.dygraph.guard(self.place): + probs = self.paddle_lognormal.probs(paddle.to_tensor(self.value)) + np_probs = self.np_lognormal.probs(self.value) + np.testing.assert_allclose( + probs, + np_probs, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) def test_log_prob(self): - value = [np.random.rand(*self.scale.shape)] - for v in value: - with paddle.fluid.dygraph.guard(self.place): - log_prob = self.paddle_lognormal.log_prob(paddle.to_tensor(v)) - np_log_prob = self.np_lognormal.log_prob(v) - np.testing.assert_allclose( - log_prob, - np_log_prob, - rtol=config.RTOL.get(str(self.scale.dtype)), - atol=config.ATOL.get(str(self.scale.dtype))) + with paddle.fluid.dygraph.guard(self.place): + log_prob = self.paddle_lognormal.log_prob( + paddle.to_tensor(self.value)) + np_log_prob = self.np_lognormal.log_prob(self.value) + np.testing.assert_allclose( + log_prob, + np_log_prob, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) @place(config.DEVICES) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py 
b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index e2251f54083df..897cdcb185155 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -26,10 +26,9 @@ @place(config.DEVICES) -@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), - [('float', xrand(), xrand()), - ('one-dim', xrand((2, )), xrand((2, ))), - ('multi-dim', xrand((3, 3)), xrand((3, 3)))]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale', 'value'), + [('one-dim', xrand((2, )), xrand((2, )), xrand((2, ))), + ('multi-dim', xrand((3, 3)), xrand((3, 3)), xrand((3, 3)))]) class TestLogNormal(unittest.TestCase): def setUp(self): @@ -41,19 +40,23 @@ def setUp(self): loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) scale = paddle.static.data('scale', self.scale.shape, self.scale.dtype) + value = paddle.static.data('value', self.value.shape, + self.value.dtype) self.paddle_lognormal = LogNormal(loc=loc, scale=scale) self.np_lognormal = LogNormalNumpy(loc=self.loc, scale=self.scale) mean = self.paddle_lognormal.mean var = self.paddle_lognormal.variance entropy = self.paddle_lognormal.entropy() - fetch_list = [mean, var, entropy] - self.feeds = {'loc': self.loc, 'scale': self.scale} + probs = self.paddle_lognormal.probs(value) + log_prob = self.paddle_lognormal.log_prob(value) + fetch_list = [mean, var, entropy, probs, log_prob] + self.feeds = {'loc': self.loc, 'scale': self.scale, 'value': self.value} executor.run(startup_program) - [self.mean, self.var, - self.entropy] = executor.run(main_program, - feed=self.feeds, - fetch_list=fetch_list) + [self.mean, self.var, self.entropy, self.probs, + self.log_prob] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) def test_mean(self): np_mean = self.np_lognormal.mean @@ -80,6 +83,20 @@ def test_entropy(self): rtol=config.RTOL.get(str(self.scale.dtype)), atol=config.ATOL.get(str(self.scale.dtype))) + def test_probs(self): + np_probs = self.np_lognormal.probs(self.value) + np.testing.assert_allclose(self.probs, + np_probs, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + + def test_log_prob(self): + np_log_prob = self.np_lognormal.log_prob(self.value) + np.testing.assert_allclose(self.log_prob, + np_log_prob, + rtol=config.RTOL.get(str(self.scale.dtype)), + atol=config.ATOL.get(str(self.scale.dtype))) + @place(config.DEVICES) @parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), @@ -100,11 +117,11 @@ def setUp(self): self.sample_shape = (n, ) self.rsample_shape = (n, ) self.paddle_lognormal = LogNormal(loc=loc, scale=scale) - self.mean = self.paddle_lognormal.mean - self.variance = self.paddle_lognormal.variance - self.samples = self.paddle_lognormal.sample(self.sample_shape) - self.rsamples = self.paddle_lognormal.rsample(self.rsample_shape) - fetch_list = [self.mean, self.variance, self.samples, self.rsamples] + mean = self.paddle_lognormal.mean + variance = self.paddle_lognormal.variance + samples = self.paddle_lognormal.sample(self.sample_shape) + rsamples = self.paddle_lognormal.rsample(self.rsample_shape) + fetch_list = [mean, variance, samples, rsamples] self.feeds = {'loc': self.loc, 'scale': self.scale} executor.run(startup_program) @@ -170,12 +187,12 @@ def setUp(self): self.normal_a = Normal(loc=loc1, scale=scale1) self.normal_b = Normal(loc=loc2, scale=scale2) - self.kl0 = self.ln_a.kl_divergence(self.ln_b) - self.kl1 = 
kl_divergence(self.ln_a, self.ln_b) - self.kl_normal = kl_divergence(self.normal_a, self.normal_b) - self.kl_formula = self._kl(self.ln_a, self.ln_b) + kl0 = self.ln_a.kl_divergence(self.ln_b) + kl1 = kl_divergence(self.ln_a, self.ln_b) + kl_normal = kl_divergence(self.normal_a, self.normal_b) + kl_formula = self._kl(self.ln_a, self.ln_b) - fetch_list = [self.kl0, self.kl1, self.kl_normal, self.kl_formula] + fetch_list = [kl0, kl1, kl_normal, kl_formula] self.feeds = { 'loc1': self.loc1, 'scale1': self.scale1, From 552515828c750b0a027f9ecf5ac231c131bfc2ce Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 12:18:14 +0000 Subject: [PATCH 27/39] add test --- .../unittests/distribution/test_distribution_lognormal.py | 5 ++--- .../distribution/test_distribution_lognormal_static.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 17f9c42c17480..b560c8766a2aa 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -134,9 +134,8 @@ def test_log_prob(self): @place(config.DEVICES) -@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), - [('sample', xrand( - (4, ), min=0, max=1), xrand((4, ), min=0, max=1))]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( + (4, )), xrand((4, ), min=0, max=1))]) class TestLogNormalSample(unittest.TestCase): def setUp(self): diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 897cdcb185155..0e7cf876ece65 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -99,9 +99,8 @@ def test_log_prob(self): @place(config.DEVICES) -@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), - [('sample', xrand( - (4, ), min=0, max=1), xrand((4, ), min=0, max=1))]) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( + (4, )), xrand((4, ), min=0, max=1))]) class TestLogNormalSample(unittest.TestCase): def setUp(self): From 19776eeb55e05fd2a2c30329bf07c6737d87c233 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 14:42:16 +0000 Subject: [PATCH 28/39] format --- python/paddle/distribution/normal.py | 60 ++++++++++++++-------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index e717d3739c641..63c56c16ed1f6 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -60,36 +60,36 @@ class Normal(distribution.Distribution): Examples: .. code-block:: python - import paddle - from paddle.distribution import Normal - - # Define a single scalar Normal distribution. - dist = Normal(loc=0., scale=3.) - # Define a batch of two scalar valued Normals. - # The first has mean 1 and standard deviation 11, the second 2 and 22. - dist = Normal(loc=[1., 2.], scale=[11., 22.]) - # Get 3 samples, returning a 3 x 2 tensor. - dist.sample([3]) - - # Define a batch of two scalar valued Normals. - # Both have mean 1, but different standard deviations. 
- dist = Normal(loc=1., scale=[11., 22.]) - - # Complete example - value_tensor = paddle.to_tensor([0.8], dtype="float32") - - normal_a = Normal([0.], [1.]) - normal_b = Normal([0.5], [2.]) - sample = normal_a.sample([2]) - # a random tensor created by normal distribution with shape: [2, 1] - entropy = normal_a.entropy() - # [1.4189385] with shape: [1] - lp = normal_a.log_prob(value_tensor) - # [-1.2389386] with shape: [1] - p = normal_a.probs(value_tensor) - # [0.28969154] with shape: [1] - kl = normal_a.kl_divergence(normal_b) - # [0.34939718] with shape: [1] + import paddle + from paddle.distribution import Normal + + # Define a single scalar Normal distribution. + dist = Normal(loc=0., scale=3.) + # Define a batch of two scalar valued Normals. + # The first has mean 1 and standard deviation 11, the second 2 and 22. + dist = Normal(loc=[1., 2.], scale=[11., 22.]) + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + + # Define a batch of two scalar valued Normals. + # Both have mean 1, but different standard deviations. + dist = Normal(loc=1., scale=[11., 22.]) + + # Complete example + value_tensor = paddle.to_tensor([0.8], dtype="float32") + + normal_a = Normal([0.], [1.]) + normal_b = Normal([0.5], [2.]) + sample = normal_a.sample([2]) + # a random tensor created by normal distribution with shape: [2, 1] + entropy = normal_a.entropy() + # [1.4189385] with shape: [1] + lp = normal_a.log_prob(value_tensor) + # [-1.2389386] with shape: [1] + p = normal_a.probs(value_tensor) + # [0.28969154] with shape: [1] + kl = normal_a.kl_divergence(normal_b) + # [0.34939718] with shape: [1] """ def __init__(self, loc, scale, name=None): From 6f176c661f363ba57600fe6fc71ee1e5830c4f54 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 16:28:28 +0000 Subject: [PATCH 29/39] add comment --- python/paddle/distribution/normal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 63c56c16ed1f6..889e81d08ef43 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -154,7 +154,7 @@ def sample(self, shape=(), seed=0): """Generate samples of the specified shape. Args: - shape (Sequence[int], optional): Sample shape. + shape (Sequence[int], optional): Shape of the generated samples. seed (int): Python integer number. Returns: From 0f8d1e6dda832eb35572debd90f31499b94b2993 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sat, 24 Sep 2022 16:32:43 +0000 Subject: [PATCH 30/39] add comment --- python/paddle/distribution/normal.py | 2 +- python/paddle/distribution/transformed_distribution.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 889e81d08ef43..87701cd4bacf2 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -200,7 +200,7 @@ def rsample(self, shape=(), seed=0): """Generate reparameterized samples of the specified shape. Args: - shape (Sequence[int], optional): Sample shape. + shape (Sequence[int], optional): Shape of the generated samples. seed (int): Python integer number. 
Returns: diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py index 38becc4d37320..c0ba50c83f68b 100644 --- a/python/paddle/distribution/transformed_distribution.py +++ b/python/paddle/distribution/transformed_distribution.py @@ -83,7 +83,7 @@ def sample(self, shape=()): """Sample from ``TransformedDistribution``. Args: - shape ((Sequence[int], optional): The sample shape. Defaults to (). + shape (Sequence[int], optional): The sample shape. Defaults to (). Returns: [Tensor]: The sample result. @@ -97,7 +97,7 @@ def rsample(self, shape=()): """Reparameterized sample from ``TransformedDistribution``. Args: - shape ((Sequence[int], optional): The sample shape. Defaults to (). + shape (Sequence[int], optional): The sample shape. Defaults to (). Returns: [Tensor]: The sample result. From 8be803e135c3b4628bc467591b61d87387dfcbfd Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Mon, 26 Sep 2022 10:29:33 +0000 Subject: [PATCH 31/39] add comment --- python/paddle/distribution/lognormal.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 36c10015f57b2..34a349761ef63 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -118,7 +118,8 @@ def entropy(self): In the above equation: - * :math:`scale = \sigma`: is the std. + * :math:`loc = \mu`: is the mean of the underlying Normal distribution. + * :math:`scale = \sigma`: is the stddevs of the underlying Normal distribution. Returns: Tensor: Shannon entropy of lognormal distribution.The data type is float32. From 73e2b4f63dbac329c96eeb0a4956b331509362cf Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Mon, 26 Sep 2022 10:52:57 +0000 Subject: [PATCH 32/39] add comment --- python/paddle/distribution/lognormal.py | 18 +++++++++--------- python/paddle/distribution/normal.py | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 34a349761ef63..3f97693bee9bd 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -43,8 +43,8 @@ class LogNormal(TransformedDistribution): * :math:`scale = \sigma`: is the stddevs of the underlying Normal distribution. Args: - loc(int|float|list|tuple|numpy.ndarray|Tensor): The mean of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor. - scale(int|float|list|tuple|numpy.ndarray|Tensor): The std of normal distribution.The data type is int, float, list, numpy.ndarray or Tensor. + loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of normal distribution. + scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of normal distribution. name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Examples: @@ -122,7 +122,7 @@ def entropy(self): * :math:`scale = \sigma`: is the stddevs of the underlying Normal distribution. Returns: - Tensor: Shannon entropy of lognormal distribution.The data type is float32. + Tensor: Shannon entropy of lognormal distribution. """ return self._base.entropy() + self._base.mean @@ -134,7 +134,7 @@ def probs(self, value): value (Tensor): The input tensor. Returns: - Tensor: probability.The data type is same with value. + Tensor: probability.The data type is same with :attr:`value` . 
""" return paddle.exp(self.log_prob(value)) @@ -158,10 +158,10 @@ def kl_divergence(self, other): In the above equation: - * :math:`loc = \mu_0`: is the mean of current underlying Normal distribution. - * :math:`scale = \sigma_0`: is the std of current underlying Normal distribution. - * :math:`loc = \mu_1`: is the mean of other underlying Normal distribution. - * :math:`scale = \sigma_1`: is the std of other underlying Normal distribution. + * :math:`loc = \mu_0`: is the means of current underlying Normal distribution. + * :math:`scale = \sigma_0`: is the stddevs of current underlying Normal distribution. + * :math:`loc = \mu_1`: is the means of other underlying Normal distribution. + * :math:`scale = \sigma_1`: is the stddevs of other underlying Normal distribution. * :math:`ratio`: is the ratio of scales. * :math:`diff`: is the difference between means. @@ -169,7 +169,7 @@ def kl_divergence(self, other): other (LogNormal): instance of LogNormal. Returns: - Tensor: kl-divergence between two lognormal distributions.The data type is float32. + Tensor: kl-divergence between two lognormal distributions. """ return self._base.kl_divergence(other._base) diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 87701cd4bacf2..3f42ee1414fc0 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -243,7 +243,7 @@ def log_prob(self, value): value (Tensor): The input tensor. Returns: - Tensor: log probability.The data type is same with value. + Tensor: log probability.The data type is same with :attr:`value` . """ name = self.name + '_log_prob' @@ -263,7 +263,7 @@ def probs(self, value): value (Tensor): The input tensor. Returns: - Tensor, probability. The data type is same with value. + Tensor, probability. The data type is same with :attr:`value` . """ name = self.name + '_probs' From 0f94b9a1e4df210e9999d9870deddfecbe4f7d97 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Wed, 28 Sep 2022 01:11:54 +0000 Subject: [PATCH 33/39] format code --- python/paddle/distribution/__init__.py | 2 +- python/paddle/distribution/lognormal.py | 2 +- .../distribution/test_distribution_lognormal.py | 11 +++++------ .../test_distribution_lognormal_static.py | 13 ++++++------- .../test_distribution_multinomial_static.py | 4 ++-- 5 files changed, 15 insertions(+), 17 deletions(-) diff --git a/python/paddle/distribution/__init__.py b/python/paddle/distribution/__init__.py index 8a41028bd81a0..0e77febe55191 100644 --- a/python/paddle/distribution/__init__.py +++ b/python/paddle/distribution/__init__.py @@ -20,9 +20,9 @@ from paddle.distribution.exponential_family import ExponentialFamily from paddle.distribution.independent import Independent from paddle.distribution.kl import kl_divergence, register_kl +from paddle.distribution.lognormal import LogNormal from paddle.distribution.multinomial import Multinomial from paddle.distribution.normal import Normal -from paddle.distribution.lognormal import LogNormal from paddle.distribution.transform import * # noqa: F403 from paddle.distribution.transformed_distribution import \ TransformedDistribution diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 3f97693bee9bd..1286205814f0d 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -13,9 +13,9 @@ # limitations under the License. 
import paddle +from paddle.distribution.normal import Normal from paddle.distribution.transform import ExpTransform from paddle.distribution.transformed_distribution import TransformedDistribution -from paddle.distribution.normal import Normal class LogNormal(TransformedDistribution): diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index b560c8766a2aa..17a1b772920bc 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -13,18 +13,17 @@ # limitations under the License. import math -import unittest -import scipy.stats +import config import numpy as np import paddle - -import config -from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +import scipy.stats +import unittest +from paddle.distribution.kl import kl_divergence from paddle.distribution.normal import Normal from paddle.distribution.lognormal import LogNormal from test_distribution import DistributionNumpy -from paddle.distribution.kl import kl_divergence +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand class LogNormalNumpy(DistributionNumpy): diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 0e7cf876ece65..192616f220806 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import unittest -import paddle +import config import numpy as np +import paddle import scipy.stats -import config - -from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand -from paddle.distribution.normal import Normal +import unittest +from paddle.distribution.kl import kl_divergence from paddle.distribution.lognormal import LogNormal +from paddle.distribution.normal import Normal +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand from test_distribution_lognormal import LogNormalNumpy -from paddle.distribution.kl import kl_divergence @place(config.DEVICES) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py index f9beb6b7702f8..56341d7fc0ef8 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_multinomial_static.py @@ -154,13 +154,13 @@ def setUp(self): self.main_program = paddle.static.Program() self.executor = paddle.static.Executor(self.place) - with paddle.static.program_guard(main_program, startup_program): + with paddle.static.program_guard(self.main_program, startup_program): probs = paddle.static.data('probs', self.probs.shape, self.probs.dtype) dist = paddle.distribution.Multinomial(self.total_count, probs) self.feed = {'probs': self.probs} - executor.run(startup_program) + self.executor.run(startup_program) def TestInit(self): with self.assertRaises(ValueError): From 7713ad89cc15ce8fd17ee55a7cbe7367b3fcd5bd Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Wed, 28 Sep 2022 14:35:17 +0000 Subject: [PATCH 34/39] fix bug --- python/paddle/distribution/lognormal.py | 6 +- python/paddle/distribution/normal.py | 11 +- .../distribution/transformed_distribution.py | 5 +- .../test_distribution_lognormal.py | 2 +- .../distribution/test_distribution_normal.py | 134 ++++++++++++++++-- 5 files changed, 140 insertions(+), 18 deletions(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 1286205814f0d..6f466f7948566 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -19,7 +19,7 @@ class LogNormal(TransformedDistribution): - r"""The Normal distribution with location `loc` and `scale` parameters. + r"""The LogNormal distribution with location `loc` and `scale` parameters. .. math:: @@ -43,8 +43,8 @@ class LogNormal(TransformedDistribution): * :math:`scale = \sigma`: is the stddevs of the underlying Normal distribution. Args: - loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of normal distribution. - scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of normal distribution. + loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of the underlying Normal distribution. + scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of the underlying Normal distribution. name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. 
Examples: diff --git a/python/paddle/distribution/normal.py b/python/paddle/distribution/normal.py index 3f42ee1414fc0..33e36fbe72dac 100644 --- a/python/paddle/distribution/normal.py +++ b/python/paddle/distribution/normal.py @@ -16,6 +16,7 @@ import warnings import numpy as np +import paddle from paddle import _C_ops, _legacy_C_ops from paddle.distribution import distribution from paddle.fluid import core @@ -196,18 +197,22 @@ def sample(self, shape=(), seed=0): else: return output - def rsample(self, shape=(), seed=0): + def rsample(self, shape=()): """Generate reparameterized samples of the specified shape. Args: shape (Sequence[int], optional): Shape of the generated samples. - seed (int): Python integer number. Returns: Tensor: A tensor with prepended dimensions shape.The data type is float32. """ - return self.sample(shape, seed) + if not isinstance(shape, Iterable): + raise TypeError('sample shape must be Iterable object.') + + shape = self._extend_shape(tuple(shape)) + eps = paddle.normal(shape=shape) + return (self.loc + eps * self.scale) def entropy(self): r"""Shannon entropy in nats. diff --git a/python/paddle/distribution/transformed_distribution.py b/python/paddle/distribution/transformed_distribution.py index c0ba50c83f68b..da0e5908f0ce1 100644 --- a/python/paddle/distribution/transformed_distribution.py +++ b/python/paddle/distribution/transformed_distribution.py @@ -61,9 +61,10 @@ def __init__(self, base, transforms): raise TypeError("All element of transforms must be Transform type.") chain = transform.ChainTransform(transforms) - if len(base.batch_shape + base.event_shape) < chain._domain.event_rank: + base_shape = base.batch_shape + base.event_shape + if len(base_shape) < chain._domain.event_rank: raise ValueError( - f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base.batch_shape + base.event_shape)}." + f"'base' needs to have shape with size at least {chain._domain.event_rank}, but got {len(base_shape)}." ) if chain._domain.event_rank > len(base.event_shape): base = independent.Independent( diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 17a1b772920bc..7e270172c631f 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -22,8 +22,8 @@ from paddle.distribution.kl import kl_divergence from paddle.distribution.normal import Normal from paddle.distribution.lognormal import LogNormal -from test_distribution import DistributionNumpy from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +from test_distribution import DistributionNumpy class LogNormalNumpy(DistributionNumpy): diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py index 5cabe0fa488f2..ff96d7489a61a 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py @@ -13,15 +13,17 @@ # limitations under the License. 
import math -import unittest +import config import numpy as np import paddle from paddle import fluid from paddle.distribution import * from paddle.fluid import layers - +from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +import scipy.stats from test_distribution import DistributionNumpy +import unittest np.random.seed(2022) @@ -115,7 +117,7 @@ def init_static_data(self, batch_size, dims): dtype='float32') def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6): - sample, rsample, entropy, log_prob, probs, kl = fetch_list + sample, entropy, log_prob, probs, kl = fetch_list np_normal = NormalNumpy(self.loc_np, self.scale_np) np_sample = np_normal.sample([sample_shape]) @@ -131,9 +133,7 @@ def compare_with_numpy(self, fetch_list, sample_shape=7, tolerance=1e-6): # There is a loss of accuracy in this conversion. # So set the tolerance from 1e-6 to 1e-4. log_tolerance = 1e-4 - np.testing.assert_equal(sample.shape, np_sample.shape) - np.testing.assert_equal(rsample.shape, np_sample.shape) np.testing.assert_allclose(entropy, np_entropy, rtol=tolerance, @@ -156,14 +156,13 @@ def test_normal_distribution_dygraph(self, sample_shape=7, tolerance=1e-6): normal = Normal(self.dynamic_loc, self.dynamic_scale) sample = normal.sample([sample_shape]).numpy() - rsample = normal.rsample([sample_shape]).numpy() entropy = normal.entropy().numpy() log_prob = normal.log_prob(self.dynamic_values).numpy() probs = normal.probs(self.dynamic_values).numpy() other_normal = Normal(self.dynamic_other_loc, self.dynamic_other_scale) kl = normal.kl_divergence(other_normal).numpy() - fetch_list = [sample, rsample, entropy, log_prob, probs, kl] + fetch_list = [sample, entropy, log_prob, probs, kl] self.compare_with_numpy(fetch_list) def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6): @@ -172,7 +171,6 @@ def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6): normal = Normal(self.static_loc, self.static_scale) sample = normal.sample([sample_shape]) - rsample = normal.rsample([sample_shape]) entropy = normal.entropy() log_prob = normal.log_prob(self.static_values) probs = normal.probs(self.static_values) @@ -180,7 +178,7 @@ def test_normal_distribution_static(self, sample_shape=7, tolerance=1e-6): self.static_other_scale) kl = normal.kl_divergence(other_normal) - fetch_list = [sample, rsample, entropy, log_prob, probs, kl] + fetch_list = [sample, entropy, log_prob, probs, kl] feed_vars = { 'loc': self.loc_np, @@ -502,5 +500,123 @@ def init_static_data(self, batch_size, dims): dtype='float32') +@place(config.DEVICES) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( + (4, )), xrand((4, )))]) +class TestNormalSampleDygraph(unittest.TestCase): + + def setUp(self): + paddle.disable_static() + self.paddle_normal = Normal(loc=self.loc, scale=self.scale) + n = 100000 + self.sample_shape = (n, ) + self.rsample_shape = (n, ) + self.samples = self.paddle_normal.sample(self.sample_shape) + self.rsamples = self.paddle_normal.rsample(self.rsample_shape) + + def test_sample(self): + samples_mean = self.samples.mean(axis=0) + samples_var = self.samples.var(axis=0) + np.testing.assert_allclose(samples_mean, + self.paddle_normal.mean, + rtol=0.1, + atol=0) + np.testing.assert_allclose(samples_var, + self.paddle_normal.variance, + rtol=0.1, + atol=0) + + rsamples_mean = self.rsamples.mean(axis=0) + rsamples_var = self.rsamples.var(axis=0) + np.testing.assert_allclose(rsamples_mean, + self.paddle_normal.mean, + rtol=0.1, + atol=0) + 
np.testing.assert_allclose(rsamples_var, + self.paddle_normal.variance, + rtol=0.1, + atol=0) + + batch_shape = (self.loc + self.scale).shape + self.assertEqual(self.samples.shape, + list(self.sample_shape + batch_shape)) + self.assertEqual(self.rsamples.shape, + list(self.rsample_shape + batch_shape)) + + for i in range(len(self.scale)): + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i])) + + def _kstest(self, loc, scale, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. + ks, _ = scipy.stats.kstest(samples, + scipy.stats.norm(loc=loc, scale=scale).cdf) + return ks < 0.02 + + +@place(config.DEVICES) +@parameterize_cls((TEST_CASE_NAME, 'loc', 'scale'), [('sample', xrand( + (4, )), xrand((4, )))]) +class TestNormalSampleStaic(unittest.TestCase): + + def setUp(self): + paddle.enable_static() + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + executor = paddle.static.Executor(self.place) + with paddle.static.program_guard(main_program, startup_program): + loc = paddle.static.data('loc', self.loc.shape, self.loc.dtype) + scale = paddle.static.data('scale', self.scale.shape, + self.scale.dtype) + n = 100000 + self.sample_shape = (n, ) + self.rsample_shape = (n, ) + self.paddle_lognormal = Normal(loc=loc, scale=scale) + mean = self.paddle_lognormal.mean + variance = self.paddle_lognormal.variance + samples = self.paddle_lognormal.sample(self.sample_shape) + rsamples = self.paddle_lognormal.rsample(self.rsample_shape) + fetch_list = [mean, variance, samples, rsamples] + self.feeds = {'loc': self.loc, 'scale': self.scale} + + executor.run(startup_program) + [self.mean, self.variance, self.samples, + self.rsamples] = executor.run(main_program, + feed=self.feeds, + fetch_list=fetch_list) + + def test_sample(self): + samples_mean = self.samples.mean(axis=0) + samples_var = self.samples.var(axis=0) + np.testing.assert_allclose(samples_mean, self.mean, rtol=0.1, atol=0) + np.testing.assert_allclose(samples_var, self.variance, rtol=0.1, atol=0) + + rsamples_mean = self.rsamples.mean(axis=0) + rsamples_var = self.rsamples.var(axis=0) + np.testing.assert_allclose(rsamples_mean, self.mean, rtol=0.1, atol=0) + np.testing.assert_allclose(rsamples_var, + self.variance, + rtol=0.1, + atol=0) + + batch_shape = (self.loc + self.scale).shape + self.assertEqual(self.samples.shape, self.sample_shape + batch_shape) + self.assertEqual(self.rsamples.shape, self.rsample_shape + batch_shape) + + for i in range(len(self.scale)): + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.samples[:, i])) + self.assertTrue( + self._kstest(self.loc[i], self.scale[i], self.rsamples[:, i])) + + def _kstest(self, loc, scale, samples): + # Uses the Kolmogorov-Smirnov test for goodness of fit. 
+        ks, _ = scipy.stats.kstest(samples,
+                                   scipy.stats.norm(loc=loc, scale=scale).cdf)
+        return ks < 0.02
+
+
 if __name__ == '__main__':
     unittest.main()

From 16765f7e4ff60b9336de7d4221c5c3e644ff23bb Mon Sep 17 00:00:00 2001
From: MayYouBeProsperous
Date: Wed, 28 Sep 2022 14:48:14 +0000
Subject: [PATCH 35/39] fix bug

---
 .../unittests/distribution/test_distribution_normal.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py
index ff96d7489a61a..925a49439e776 100644
--- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py
+++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py
@@ -573,11 +573,11 @@ def setUp(self):
             n = 100000
             self.sample_shape = (n, )
             self.rsample_shape = (n, )
-            self.paddle_lognormal = Normal(loc=loc, scale=scale)
-            mean = self.paddle_lognormal.mean
-            variance = self.paddle_lognormal.variance
-            samples = self.paddle_lognormal.sample(self.sample_shape)
-            rsamples = self.paddle_lognormal.rsample(self.rsample_shape)
+            self.paddle_normal = Normal(loc=loc, scale=scale)
+            mean = self.paddle_normal.mean
+            variance = self.paddle_normal.variance
+            samples = self.paddle_normal.sample(self.sample_shape)
+            rsamples = self.paddle_normal.rsample(self.rsample_shape)
             fetch_list = [mean, variance, samples, rsamples]
             self.feeds = {'loc': self.loc, 'scale': self.scale}

From f1bf7111768c1913f8fe9b06069d3c8d962001a1 Mon Sep 17 00:00:00 2001
From: MayYouBeProsperous
Date: Wed, 28 Sep 2022 16:23:41 +0000
Subject: [PATCH 36/39] fix bug

---
 .../distribution/test_distribution_transformed_distribution.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py
index c448e407b9dee..15fd94117f008 100644
--- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py
+++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_transformed_distribution.py
@@ -63,7 +63,7 @@ def test_sample(self):

     def test_rsample(self):
         shape = [5, 10, 8]
-        expected_shape = (5, 10, 8)
+        expected_shape = (5, 10, 8, 1)
         data = self._t.rsample(shape)
         self.assertEqual(tuple(data.shape), expected_shape)
         self.assertEqual(data.dtype, self.base.loc.dtype)

From 9f3a98b5f187a1d32e1bcd545965d72a8f0f0e48 Mon Sep 17 00:00:00 2001
From: MayYouBeProsperous
Date: Thu, 29 Sep 2022 04:09:41 +0000
Subject: [PATCH 37/39] add comment

---
 python/paddle/distribution/lognormal.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py
index 6f466f7948566..f0c516f7f473a 100644
--- a/python/paddle/distribution/lognormal.py
+++ b/python/paddle/distribution/lognormal.py
@@ -28,7 +28,7 @@ class LogNormal(TransformedDistribution):

         Y = exp(X) \sim LogNormal(\mu, \sigma)

-    :math:`Normal(\mu, \sigma)` is the underlying distribution of :math:`LogNormal(\mu, \sigma)`
+    Since the LogNormal distribution is defined as a transformation of the Normal distribution, we refer to :math:`Normal(\mu, \sigma)` as the underlying distribution of :math:`LogNormal(\mu, \sigma)`

     Mathematical details

     The probability density function (pdf) is

From 4745c39a7723c5815dea8f3c1fbe01097e7deb09 Mon Sep 17 00:00:00 2001
From: MayYouBeProsperous
Date: Sun, 9 Oct 2022
06:23:31 +0000 Subject: [PATCH 38/39] remove name parameter for LogNormal --- python/paddle/distribution/lognormal.py | 11 ++++++----- .../distribution/test_distribution_lognormal.py | 6 +++--- .../test_distribution_lognormal_static.py | 5 +++-- .../distribution/test_distribution_normal.py | 2 +- 4 files changed, 13 insertions(+), 11 deletions(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index f0c516f7f473a..683f550f95ea6 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import paddle from paddle.distribution.normal import Normal from paddle.distribution.transform import ExpTransform -from paddle.distribution.transformed_distribution import TransformedDistribution +from paddle.distribution.transformed_distribution import \ + TransformedDistribution + +import paddle class LogNormal(TransformedDistribution): @@ -45,7 +47,6 @@ class LogNormal(TransformedDistribution): Args: loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of the underlying Normal distribution. scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of the underlying Normal distribution. - name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Examples: .. code-block:: python @@ -82,8 +83,8 @@ class LogNormal(TransformedDistribution): # [0.34939718] with shape: [1] """ - def __init__(self, loc, scale, name=None): - self._base = Normal(loc=loc, scale=scale, name=name) + def __init__(self, loc, scale): + self._base = Normal(loc=loc, scale=scale) self.loc = self._base.loc self.scale = self._base.scale super(LogNormal, self).__init__(self._base, [ExpTransform()]) diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index 7e270172c631f..ce80355aeba13 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -13,15 +13,15 @@ # limitations under the License. import math +import unittest import config import numpy as np -import paddle import scipy.stats -import unittest +import paddle from paddle.distribution.kl import kl_divergence -from paddle.distribution.normal import Normal from paddle.distribution.lognormal import LogNormal +from paddle.distribution.normal import Normal from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand from test_distribution import DistributionNumpy diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index 192616f220806..fdb724b50ee7e 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import unittest + import config import numpy as np -import paddle import scipy.stats -import unittest +import paddle from paddle.distribution.kl import kl_divergence from paddle.distribution.lognormal import LogNormal from paddle.distribution.normal import Normal diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py index 925a49439e776..1873ac7efa6b5 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_normal.py @@ -13,6 +13,7 @@ # limitations under the License. import math +import unittest import config import numpy as np @@ -23,7 +24,6 @@ from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand import scipy.stats from test_distribution import DistributionNumpy -import unittest np.random.seed(2022) From c97a9fe5b4ef5b54dd992dd8b5860dc972232458 Mon Sep 17 00:00:00 2001 From: MayYouBeProsperous Date: Sun, 9 Oct 2022 06:48:22 +0000 Subject: [PATCH 39/39] organize imports --- python/paddle/distribution/lognormal.py | 3 +-- .../unittests/distribution/test_distribution_lognormal.py | 3 ++- .../distribution/test_distribution_lognormal_static.py | 3 ++- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/python/paddle/distribution/lognormal.py b/python/paddle/distribution/lognormal.py index 683f550f95ea6..b171e1ecbc61e 100644 --- a/python/paddle/distribution/lognormal.py +++ b/python/paddle/distribution/lognormal.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import paddle from paddle.distribution.normal import Normal from paddle.distribution.transform import ExpTransform from paddle.distribution.transformed_distribution import \ TransformedDistribution -import paddle - class LogNormal(TransformedDistribution): r"""The LogNormal distribution with location `loc` and `scale` parameters. 
diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py index ce80355aeba13..a7c97047505c3 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal.py @@ -17,12 +17,13 @@ import config import numpy as np -import scipy.stats import paddle from paddle.distribution.kl import kl_divergence from paddle.distribution.lognormal import LogNormal from paddle.distribution.normal import Normal from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +import scipy.stats + from test_distribution import DistributionNumpy diff --git a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py index fdb724b50ee7e..75a9e497f34b7 100644 --- a/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py +++ b/python/paddle/fluid/tests/unittests/distribution/test_distribution_lognormal_static.py @@ -16,12 +16,13 @@ import config import numpy as np -import scipy.stats import paddle from paddle.distribution.kl import kl_divergence from paddle.distribution.lognormal import LogNormal from paddle.distribution.normal import Normal from parameterize import TEST_CASE_NAME, parameterize_cls, place, xrand +import scipy.stats + from test_distribution_lognormal import LogNormalNumpy
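
A minimal usage sketch of the LogNormal API added by this series. This is illustrative only: it assumes a Paddle build that already contains these patches, mirrors the docstring examples in the diffs above, and the sample shape [100] is an arbitrary choice.

    import paddle
    from paddle.distribution import LogNormal, kl_divergence

    # loc and scale parameterize the underlying Normal distribution.
    lognormal_a = LogNormal([0.], [1.])
    lognormal_b = LogNormal([0.5], [2.])

    value = paddle.to_tensor([0.8], dtype="float32")

    samples = lognormal_a.sample([100])    # plain sampling, shape [100, 1]
    rsamples = lognormal_a.rsample([100])  # reparameterized (differentiable) sampling
    entropy = lognormal_a.entropy()        # entropy of the underlying Normal plus its mean
    lp = lognormal_a.log_prob(value)
    p = lognormal_a.probs(value)

    # The registered KL rule reduces to the KL between the underlying Normals,
    # so the instance method and the functional form agree.
    kl = lognormal_a.kl_divergence(lognormal_b)
    kl_fn = kl_divergence(lognormal_a, lognormal_b)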