【Hackathon No.8】 add gumbel distribution api #46255

Merged: 69 commits, merged Oct 17, 2022

Changes from all commits
d4263c2
init gumbel api
PureNatural Sep 18, 2022
a29f111
commit: update test file
dasenCoding Sep 18, 2022
f5d62e1
fix:bug
PureNatural Sep 19, 2022
0b8faec
update Gumbel API
dasenCoding Sep 28, 2022
e6a8c1b
upgrade distribution/gumbel.py
dasenCoding Oct 4, 2022
2791493
add tests/test_distribution_gumbel.py
dasenCoding Oct 4, 2022
1541ecb
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 4, 2022
ef3dc50
fix:code style
PureNatural Oct 4, 2022
ddcd86a
fix:code style
PureNatural Oct 4, 2022
4e40718
fix:code style
PureNatural Oct 4, 2022
fff33ad
fix:code style
PureNatural Oct 4, 2022
517d053
fix bug
dasenCoding Oct 5, 2022
cedc871
fix:code style
dasenCoding Oct 5, 2022
72cd09b
fix:code style
PureNatural Oct 5, 2022
8e5bdc4
fix:rollback uniform
PureNatural Oct 5, 2022
cc3f783
fix:delete invalid code
PureNatural Oct 5, 2022
b6416cb
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 9, 2022
7f603a6
fix:bug and add static test
PureNatural Oct 9, 2022
6a06603
fix:code style
PureNatural Oct 9, 2022
06b9dc4
fix:code style
PureNatural Oct 9, 2022
d83d484
fix:delete init transforms
PureNatural Oct 9, 2022
db490e3
fix:bug
PureNatural Oct 9, 2022
381d059
fix:bug
PureNatural Oct 9, 2022
931f572
fix:code style
PureNatural Oct 9, 2022
b95dc13
fix:code style
PureNatural Oct 9, 2022
78b1b5b
fix:add transforms
PureNatural Oct 9, 2022
67047a2
fix:code style
PureNatural Oct 9, 2022
554a813
fix:code style
PureNatural Oct 9, 2022
c786d25
fix:bug
PureNatural Oct 9, 2022
c398fe3
fix:bug
PureNatural Oct 9, 2022
c713d81
fix:code style
PureNatural Oct 9, 2022
983a3f8
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 9, 2022
aea7df7
fix:code style
PureNatural Oct 9, 2022
33c780b
fix:bug
PureNatural Oct 9, 2022
da166ee
fix:code style
PureNatural Oct 9, 2022
6891ad8
fix:code style
PureNatural Oct 9, 2022
e10fd27
fix:bug for gumbel.py / add:judge transforms'len for transformed_dist…
dasenCoding Oct 10, 2022
8d5a83c
Merge branch 'gumbel_api' of https://github.com/PureNatural/Paddle in…
dasenCoding Oct 10, 2022
4e328be
update gumbel.py
dasenCoding Oct 11, 2022
9d89aac
fix:bug for test_distribution_gumbel.py
dasenCoding Oct 11, 2022
a0c357d
fix:bug for test_distribution_gumbel_static.py
dasenCoding Oct 11, 2022
db1cbfd
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 11, 2022
c6a2292
fix:code style
PureNatural Oct 11, 2022
c735592
fix:code style
PureNatural Oct 11, 2022
38530db
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 11, 2022
fc57abe
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 11, 2022
4bab5d1
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 11, 2022
98f8ed6
fix:coverage
PureNatural Oct 11, 2022
33a83fc
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 12, 2022
f2fa6dc
fix:bug
PureNatural Oct 12, 2022
a20a723
fix:bug
PureNatural Oct 12, 2022
0289b74
fix:code style
PureNatural Oct 12, 2022
fb972c3
fix:bug
PureNatural Oct 12, 2022
2f017a0
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 12, 2022
0e2892d
delete:no use package for gumbel.py
dasenCoding Oct 12, 2022
3261480
add:coverage transforms'len judge for test_distribution_gumbel.py
dasenCoding Oct 12, 2022
1ecfcc6
fix:code style for test_distribution_gumbel.py
dasenCoding Oct 12, 2022
6a9245e
fix:coverage
PureNatural Oct 12, 2022
e593fb4
fix:code style
PureNatural Oct 12, 2022
8c57748
fix:code style
PureNatural Oct 12, 2022
f7a0c36
fix:code style
PureNatural Oct 12, 2022
444454e
fix:code style
PureNatural Oct 12, 2022
3598ed5
fix:code style
PureNatural Oct 12, 2022
017f66c
Merge branch 'PaddlePaddle:develop' into gumbel_api
PureNatural Oct 14, 2022
e7108f0
fix:en doc
PureNatural Oct 14, 2022
069cadf
Merge branch 'gumbel_api' of github.com:PureNatural/Paddle into gumbe…
PureNatural Oct 14, 2022
e93ff40
fix:param
PureNatural Oct 14, 2022
6172d98
fix:copyright
PureNatural Oct 16, 2022
c957ab4
fixSample; test=document_fix
dasenCoding Oct 17, 2022
3 changes: 2 additions & 1 deletion python/paddle/distribution/__init__.py
@@ -17,6 +17,7 @@
from paddle.distribution.categorical import Categorical
from paddle.distribution.dirichlet import Dirichlet
from paddle.distribution.distribution import Distribution
from paddle.distribution.gumbel import Gumbel
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.independent import Independent
from paddle.distribution.kl import kl_divergence, register_kl
@@ -32,7 +33,7 @@
__all__ = [  # noqa
    'Beta', 'Categorical', 'Dirichlet', 'Distribution', 'ExponentialFamily',
    'Multinomial', 'Normal', 'Uniform', 'kl_divergence', 'register_kl',
    'Independent', 'TransformedDistribution', 'Laplace', 'LogNormal'
    'Independent', 'TransformedDistribution', 'Laplace', 'LogNormal', 'Gumbel'
]

__all__.extend(transform.__all__)
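With `Gumbel` added to the package exports, the class can be imported directly from `paddle.distribution`. A minimal usage sketch (mirroring the docstring example further down, not part of the diff):

import paddle
from paddle.distribution import Gumbel

dist = Gumbel(paddle.full([1], 0.0), paddle.full([1], 1.0))
print(dist.mean, dist.variance)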
242 changes: 242 additions & 0 deletions python/paddle/distribution/gumbel.py
@@ -0,0 +1,242 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import numbers
import math
import numpy as np

from paddle.distribution.transformed_distribution import TransformedDistribution
from paddle.fluid import framework as framework
class Gumbel(TransformedDistribution):
    r"""The Gumbel distribution with location `loc` and `scale` parameters.

    Mathematical details

    The probability density function (pdf) is

    .. math::

        pdf(x; mu, sigma) = exp(-(x - mu) / sigma - exp(-(x - mu) / sigma)) / sigma

    In the above equation:

    * :math:`loc = \mu`: is the location parameter.
    * :math:`scale = \sigma`: is the scale parameter.

    Args:
        loc(int|float|Tensor): The location parameter of the Gumbel distribution. The data type is int, float or Tensor.
        scale(int|float|Tensor): The scale parameter of the Gumbel distribution. The data type is int, float or Tensor.

Review comment (cxxly, Sep 26, 2022): int|float|Tensor; Tensor must be among the supported data types.

Reply: done

Review comment: The location parameter of gumbel distribution .....

Reply: done
    Examples:
        .. code-block:: python

            import paddle
            from paddle.distribution.gumbel import Gumbel

            # Gumbel distributed with loc=0, scale=1
            dist = Gumbel(paddle.full([1], 0.0), paddle.full([1], 1.0))
            dist.sample([2])
            # Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[-0.27544352], [-0.64499271]])
            value = paddle.full([1], 0.5)
            dist.prob(value)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.33070430])
            dist.log_prob(value)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [-1.10653067])
            dist.cdf(value)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.54523915])
            dist.entropy()
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [1.57721567])
            dist.rsample([2])
            # Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[0.80463481], [0.91893655]])

Review comment (cxxly, Oct 9, 2022):
1) Please write the example code following the API documentation conventions (https://github.com/PaddlePaddle/docs/wiki/飞桨API文档书写规范); the code examples of the methods below all have the same issue, please fix them consistently.
2) It is recommended to consolidate the code examples of each method here.

Reply: done
"""

    def __init__(self, loc, scale):

        if not isinstance(loc, (numbers.Real, framework.Variable)):
            raise TypeError(
                f"Expected type of loc is Real|Variable, but got {type(loc)}")
        if not isinstance(scale, (numbers.Real, framework.Variable)):
            raise TypeError(
                f"Expected type of scale is Real|Variable, but got {type(scale)}"
            )

        if isinstance(loc, numbers.Real):
            loc = paddle.full(shape=(), fill_value=loc)

        if isinstance(scale, numbers.Real):
            scale = paddle.full(shape=(), fill_value=scale)

        if loc.shape != scale.shape:
            self.loc, self.scale = paddle.broadcast_tensors([loc, scale])
        else:
            self.loc, self.scale = loc, scale

        finfo = np.finfo(dtype='float32')
        self.base_dist = paddle.distribution.Uniform(
            paddle.full_like(self.loc, float(finfo.tiny)),
            paddle.full_like(self.loc, float(1 - finfo.eps)))

        self.transforms = ()

        super(Gumbel, self).__init__(self.base_dist, self.transforms)
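Note: the base Uniform is confined to the open interval (finfo.tiny, 1 - finfo.eps) so the log transforms applied later in `rsample` never evaluate log(0). A small sketch of the bounds involved (illustrative, not part of the diff):

import numpy as np

finfo = np.finfo(dtype='float32')
print(float(finfo.tiny))      # smallest positive normal float32, about 1.18e-38
print(float(1 - finfo.eps))   # just below 1.0, about 0.99999988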

Review comment (cxxly, Oct 9, 2022): The logic in lines 62-82 is too redundant; it is recommended to convert everything to Tensor first and handle it uniformly, along the lines of the following (pseudocode):
if not isinstance(loc, (numbers.Real, framework.Variable)):
    raise TypeError(
        f"Expected type of loc is Real|Variable, but got {type(loc)}")
if not isinstance(scale, (numbers.Real, framework.Variable)):
    raise TypeError(
        f"Expected type of scale is Real|Variable, but got {type(scale)}"
    )

if isinstance(loc, numbers.Real):
    self.loc = paddle.full(shape=(), fill_value=loc)
if isinstance(scale, numbers.Real):
    self.scale = paddle.full(shape=(), fill_value=scale)

if self.loc.shape != self.scale.shape:
    self.loc, self.scale = paddle.broadcast_tensors([self.loc, self.scale])

self.base_dist = paddle.distribution.Uniform(
                 paddle.full_like(self.loc, float(finfo.tiny)),
                 paddle.full_like(self.loc, float(1 - finfo.eps)))

self.transforms = ....
super(Gumbel, self).__init__(...)

Reply: done

    @property
    def mean(self):
        """Mean of distribution

        The mean is

Review comment: The mean is ...

Reply: done
        .. math::

            mean = \mu + \sigma * γ

        In the above equation:

        * :math:`loc = \mu`: is the location parameter.
        * :math:`scale = \sigma`: is the scale parameter.
        * :math:`γ`: is Euler's constant.

        Returns:
            Tensor: mean value.

        """
        return self.loc + self.scale * np.euler_gamma
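A quick numerical check of this formula via inverse-CDF sampling (a NumPy sketch, not part of the diff; loc=1.0 and scale=2.0 are arbitrary):

import numpy as np

loc, scale = 1.0, 2.0
u = np.random.default_rng(0).uniform(size=1_000_000)
samples = loc - scale * np.log(-np.log(u))   # Gumbel(loc, scale) via inverse CDF
print(samples.mean())                        # close to loc + scale * np.euler_gamma, about 2.1544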

    @property
    def variance(self):
        """Variance of distribution.

        The variance is

        .. math::

            variance = \sigma^2 * \pi^2 / 6

        In the above equation:

        * :math:`scale = \sigma`: is the scale parameter.

        Returns:
            Tensor: The variance value.

        """
        temp = paddle.full(shape=self.loc.shape,
                           fill_value=math.pi * math.pi,
                           dtype=self.scale.dtype)

Review comment (cxxly, Oct 9, 2022): 1) paddle.to_tensor only supports dynamic graph mode; use paddle.full instead. 2) Keep the data type consistent with scale; it is not necessarily float32.

Reply: done
        return paddle.pow(self.scale, 2) * temp / 6

    @property
    def stddev(self):
        """Standard deviation of distribution

        The standard deviation is

        .. math::

            stddev = \sqrt{\sigma^2 * \pi^2 / 6}

        In the above equation:

        * :math:`scale = \sigma`: is the scale parameter.

        Returns:
            Tensor: std value
        """
        return paddle.sqrt(self.variance)
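For reference, with the class docstring's scale=1 the closed forms above evaluate as follows (a hand check, not output from the PR's tests):

import math

sigma = 1.0
variance = sigma**2 * math.pi**2 / 6   # about 1.6449
stddev = math.sqrt(variance)           # about 1.2825
print(variance, stddev)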

    def prob(self, value):
        """Probability density/mass function

        Args:
            value (Tensor): The input tensor.

        Returns:
            Tensor: probability. The data type is the same as value.

        """
        y = (self.loc - value) / self.scale

        return paddle.exp(y - paddle.exp(y)) / self.scale

    def log_prob(self, value):
        """Log probability density/mass function.

        Args:
            value (Tensor): The input tensor.

        Returns:
            Tensor: log probability. The data type is the same as value.

        """
        return paddle.log(self.prob(value))
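These two methods evaluate the closed-form density exp(y - exp(y)) / scale with y = (loc - value) / scale; plugging in the class docstring's loc=0, scale=1, value=0.5 reproduces the numbers shown there (a hand check, not part of the diff):

import math

loc, scale, value = 0.0, 1.0, 0.5
y = (loc - value) / scale
pdf = math.exp(y - math.exp(y)) / scale
print(pdf, math.log(pdf))   # about 0.3307043 and -1.1065307, matching dist.prob / dist.log_prob above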

    def cdf(self, value):
Review comment: To keep the RFC consistent with the code, this cdf function also needs to be added to the RFC's "API implementation plan" (API 实现方案) section.

"""Cumulative distribution function.
Args:
value (Tensor): value to be evaluated.

Returns:
Tensor: cumulative probability of value.

"""
return paddle.exp(-paddle.exp(-(value - self.loc) / self.scale))
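Likewise, the CDF exp(-exp(-(value - loc) / scale)) at the docstring's value=0.5 gives the 0.5452 shown in the example above (hand check, not part of the diff):

import math

loc, scale, value = 0.0, 1.0, 0.5
print(math.exp(-math.exp(-(value - loc) / scale)))   # about 0.5452392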

    def entropy(self):
        """Entropy of Gumbel distribution.

        Returns:
            Entropy of distribution.

        """
        return paddle.log(self.scale) + 1 + np.euler_gamma
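The returned expression is the standard Gumbel entropy log(scale) + 1 + γ; for the docstring's scale=1 it matches the 1.5772 shown in the example above (hand check, not part of the diff):

import math
import numpy as np

scale = 1.0
print(math.log(scale) + 1 + np.euler_gamma)   # about 1.5772157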

    def sample(self, shape):
        """Sample from ``Gumbel``.

        Args:
            shape (Sequence[int], optional): The sample shape. Defaults to ().

        Returns:
            Tensor: A tensor with prepended dimensions shape. The data type is float32.

        """
        with paddle.no_grad():
            return self.rsample(shape)

    def rsample(self, shape):
        """Reparameterized sample.

        Args:
            shape (Sequence[int]): 1D `int32`. Shape of the generated samples.

        Returns:
            Tensor: A tensor with prepended dimensions shape. The data type is float32.

        """
        exp_trans = paddle.distribution.ExpTransform()
        affine_trans_1 = paddle.distribution.AffineTransform(
            paddle.full(shape=self.scale.shape,
                        fill_value=0,
                        dtype=self.loc.dtype), -paddle.ones_like(self.scale))
        affine_trans_2 = paddle.distribution.AffineTransform(
            self.loc, -self.scale)

        return affine_trans_2.forward(
            exp_trans.inverse(
                affine_trans_1.forward(
                    exp_trans.inverse(self._base.sample(shape)))))
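The nested transforms compose to the usual inverse-CDF sampler: for u drawn from the base Uniform, exp_trans.inverse is log, affine_trans_1 negates, and affine_trans_2 maps z to loc - scale * z, so the whole chain evaluates loc - scale * log(-log(u)). A NumPy sketch of the equivalent computation (illustrative, not part of the diff):

import numpy as np

rng = np.random.default_rng(0)
finfo = np.finfo(np.float32)
u = rng.uniform(float(finfo.tiny), float(1 - finfo.eps), size=(2, 1))
loc, scale = 0.0, 1.0
print(loc - scale * np.log(-np.log(u)))   # same law as dist.rsample([2]) for Gumbel(0, 1)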
12 changes: 8 additions & 4 deletions python/paddle/distribution/transformed_distribution.py
@@ -62,15 +62,19 @@ def __init__(self, base, transforms):

        chain = transform.ChainTransform(transforms)
        base_shape = base.batch_shape + base.event_shape
        if len(base_shape) < chain._domain.event_rank:
        self._base = base
        self._transforms = transforms
        if not transforms:
            super(TransformedDistribution,
                  self).__init__(base.batch_shape, base.event_shape)
Review comment: This check logic could be placed at the very beginning:
if not transforms:
    ....
else:
    ...

            return
        if len(base.batch_shape + base.event_shape) < chain._domain.event_rank:
            raise ValueError(
                f"'base' needs to have shape with size at least {chain._domain.event_rank}, but got {len(base_shape)}."
                f"'base' needs to have shape with size at least {chain._domain.event_rank}, bug got {len(base_shape)}."
            )
        if chain._domain.event_rank > len(base.event_shape):
            base = independent.Independent(
                (base, chain._domain.event_rank - len(base.event_shape)))
Review comment (Contributor): What is the reason for deleting this code?

Author reply: Hi, the CI-Coverage check showed that this code was never covered, which made the file's coverage extremely low and CI could not pass.

Review comment (Contributor): This code is useful; it should not be deleted just because of coverage. Some test cases can be added instead.

Author reply: OK, thank you for the suggestion.

        self._base = base
        self._transforms = transforms

        transformed_shape = chain.forward_shape(base.batch_shape +
                                                base.event_shape)
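Why this change matters for the new API: Gumbel passes an empty transform tuple, so TransformedDistribution now returns early using the base distribution's batch/event shape instead of going through the chain-transform shape inference. A rough sketch of how the new branch is exercised (illustrative, assuming the behaviour described in this diff):

import paddle
from paddle.distribution.gumbel import Gumbel

dist = Gumbel(paddle.full([1], 0.0), paddle.full([1], 1.0))
# transforms=() takes the early-return branch, so the Gumbel simply adopts
# the base Uniform's batch_shape/event_shape.
print(dist.batch_shape, dist.event_shape)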
2 changes: 2 additions & 0 deletions python/paddle/distribution/uniform.py
@@ -123,6 +123,8 @@ def __init__(self, low, high, name=None):
        self.low = tensor.cast(self.low, dtype=self.dtype)
        self.high = tensor.cast(self.high, dtype=self.dtype)

        super(Uniform, self).__init__(self.low.shape)

    def sample(self, shape, seed=0):
        """Generate samples of the specified shape.

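The added super().__init__ call gives Uniform the batch_shape/event_shape bookkeeping of the Distribution base class, which the early-return path above relies on when Gumbel wraps a Uniform base. A minimal sketch (illustrative; the printed value is an assumption based on this diff):

import paddle

u = paddle.distribution.Uniform(paddle.zeros([2]), paddle.ones([2]))
# With this change the base class records the shape of `low`,
# so batch_shape is populated (expected: (2,)).
print(u.batch_shape)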