【Hackathon No.8】 add gumbel distribution api #46255
@@ -0,0 +1,242 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numbers

import numpy as np

import paddle
from paddle.distribution.transformed_distribution import TransformedDistribution
from paddle.fluid import framework

class Gumbel(TransformedDistribution):
    r"""The Gumbel distribution with location parameter `loc` and scale parameter `scale`.

    Mathematical details

    The probability density function (pdf) is

    .. math::

        pdf(x; \mu, \sigma) = \exp(-(x - \mu) / \sigma - \exp(-(x - \mu) / \sigma)) / \sigma

    In the above equation:

    * :math:`loc = \mu`: is the location parameter.
    * :math:`scale = \sigma`: is the scale parameter.

    Args:
        loc (int|float|Tensor): The location parameter of the Gumbel distribution. The data type is int, float or Tensor.
        scale (int|float|Tensor): The scale parameter of the Gumbel distribution. The data type is int, float or Tensor.

    [Review] int|float|Tensor; Tensor must be among the supported data types. [Author] done
    [Review] "The location parameter of gumbel distribution ....." [Author] done

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distribution.gumbel import Gumbel

            # Gumbel distributed with loc=0, scale=1
            dist = Gumbel(paddle.full([1], 0.0), paddle.full([1], 1.0))
            dist.sample([2])
            # Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[-0.27544352], [-0.64499271]])
            value = paddle.full([1], 0.5)
            dist.prob(value)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.33070430])
            dist.log_prob(value)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [-1.10653067])
            dist.cdf(value)
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [0.54523915])
            dist.entropy()
            # Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [1.57721567])
            dist.rsample([2])
            # Tensor(shape=[2, 1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [[0.80463481], [0.91893655]])

    [Review] (1) Write the example code following the API documentation conventions (https://github.com/PaddlePaddle/docs/wiki/飞桨API文档书写规范 — the Paddle API documentation style guide); the code examples of each method below have the same problem, please fix them all consistently. [Author] done
    """

    def __init__(self, loc, scale):
        if not isinstance(loc, (numbers.Real, framework.Variable)):
            raise TypeError(
                f"Expected type of loc is Real|Variable, but got {type(loc)}")
        if not isinstance(scale, (numbers.Real, framework.Variable)):
            raise TypeError(
                f"Expected type of scale is Real|Variable, but got {type(scale)}"
            )

        if isinstance(loc, numbers.Real):
            loc = paddle.full(shape=(), fill_value=loc)
        if isinstance(scale, numbers.Real):
            scale = paddle.full(shape=(), fill_value=scale)

        if loc.shape != scale.shape:
            self.loc, self.scale = paddle.broadcast_tensors([loc, scale])
        else:
            self.loc, self.scale = loc, scale

        finfo = np.finfo(dtype='float32')
        self.base_dist = paddle.distribution.Uniform(
            paddle.full_like(self.loc, float(finfo.tiny)),
            paddle.full_like(self.loc, float(1 - finfo.eps)))

        self.transforms = ()

        super(Gumbel, self).__init__(self.base_dist, self.transforms)

    [Review] The logic in lines 62-82 is too redundant; convert everything to Tensor and handle it uniformly, along the lines of the following (pseudocode):

        if not isinstance(loc, (numbers.Real, framework.Variable)):
            raise TypeError(
                f"Expected type of loc is Real|Variable, but got {type(loc)}")
        if not isinstance(scale, (numbers.Real, framework.Variable)):
            raise TypeError(
                f"Expected type of scale is Real|Variable, but got {type(scale)}"
            )
        if isinstance(loc, numbers.Real):
            self.loc = paddle.full(shape=(), fill_value=loc)
        if isinstance(scale, numbers.Real):
            self.scale = paddle.full(shape=(), fill_value=scale)
        if self.loc.shape != self.scale.shape:
            self.loc, self.scale = paddle.broadcast_tensors([self.loc, self.scale])
        self.base_dist = paddle.distribution.Uniform(
            paddle.full_like(self.loc, float(finfo.tiny)),
            paddle.full_like(self.loc, float(1 - finfo.eps)))
        self.transforms = ....
        super(Gumbel, self).__init__(...)

    [Author] done
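A side note (an editorial sketch, not part of the PR): the `Uniform` base distribution is clamped to `[finfo.tiny, 1 - finfo.eps]` precisely so that the double logarithm later applied in `rsample` never sees 0 or 1. A minimal NumPy check of the two clamped endpoints, assuming float32 limits as in the PR:

```python
import numpy as np

# float32 limits, matching np.finfo(dtype='float32') used in __init__
finfo = np.finfo(np.float32)
u_lo, u_hi = float(finfo.tiny), float(1 - finfo.eps)

# the Gumbel inverse CDF applies -log(-log(u)); at u=0 or u=1 this would
# produce inf/nan, but both clamped endpoints stay finite
for u in (u_lo, u_hi):
    x = -np.log(-np.log(u))
    assert np.isfinite(x)
```

This is why the base interval is open on both sides in floating point, rather than the mathematical (0, 1).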

    @property
    def mean(self):
        r"""Mean of distribution

        The mean is

        .. math::

            mean = \mu + \sigma * \gamma

        In the above equation:

        * :math:`loc = \mu`: is the location parameter.
        * :math:`scale = \sigma`: is the scale parameter.
        * :math:`\gamma`: is Euler's constant.

        Returns:
            Tensor: mean value.

        """
        return self.loc + self.scale * np.euler_gamma

    [Review] "The mean is ..." [Author] done
    @property
    def variance(self):
        r"""Variance of distribution.

        The variance is

        .. math::

            variance = \sigma^2 * \pi^2 / 6

        In the above equation:

        * :math:`scale = \sigma`: is the scale parameter.

        Returns:
            Tensor: The variance value.

        """
        temp = paddle.full(shape=self.loc.shape,
                           fill_value=math.pi * math.pi,
                           dtype=self.scale.dtype)

        return paddle.pow(self.scale, 2) * temp / 6

    [Review] (1) paddle.to_tensor only supports dynamic graph mode; use paddle.full. (2) Keep the data type consistent with scale; it is not necessarily float32. [Author] done
    @property
    def stddev(self):
        r"""Standard deviation of distribution

        The standard deviation is

        .. math::

            stddev = \sqrt{\sigma^2 * \pi^2 / 6}

        In the above equation:

        * :math:`scale = \sigma`: is the scale parameter.

        Returns:
            Tensor: std value
        """
        return paddle.sqrt(self.variance)
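The closed forms above can be cross-checked numerically. An editorial sketch, independent of Paddle (plain NumPy, arbitrary example values `mu=0.5, sigma=2.0`): integrate the docstring's pdf on a wide grid and compare against `mean = mu + sigma * gamma` and `variance = sigma^2 * pi^2 / 6`, and also recompute the pdf at the docstring's example point.

```python
import numpy as np

mu, sigma = 0.5, 2.0
x = np.linspace(-40.0, 120.0, 400001)
dx = x[1] - x[0]
z = (x - mu) / sigma
pdf = np.exp(-z - np.exp(-z)) / sigma          # pdf from the class docstring

mean_num = float(np.sum(x * pdf) * dx)         # E[X] by grid quadrature
var_num = float(np.sum((x - mean_num) ** 2 * pdf) * dx)

mean_closed = mu + sigma * np.euler_gamma      # mean = mu + sigma * gamma
var_closed = sigma ** 2 * np.pi ** 2 / 6       # variance = sigma^2 * pi^2 / 6

# pdf at the docstring example point (loc=0, scale=1, value=0.5),
# which the example reports as roughly 0.33070430
p_doc = np.exp(-0.5 - np.exp(-0.5))
```

On this grid the quadrature agrees with the closed forms to several decimal places, which is a quick way to catch sign or constant errors in the formulas.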
    def prob(self, value):
        """Probability density/mass function.

        Args:
            value (Tensor): The input tensor.

        Returns:
            Tensor: probability. The data type is the same as ``value``.

        """
        y = (self.loc - value) / self.scale

        return paddle.exp(y - paddle.exp(y)) / self.scale

    def log_prob(self, value):
        """Log probability density/mass function.

        Args:
            value (Tensor): The input tensor.

        Returns:
            Tensor: log probability. The data type is the same as ``value``.

        """
        return paddle.log(self.prob(value))
    def cdf(self, value):
        """Cumulative distribution function.

        Args:
            value (Tensor): value to be evaluated.

        Returns:
            Tensor: cumulative probability of value.

        """
        return paddle.exp(-paddle.exp(-(value - self.loc) / self.scale))

    [Review] To keep the RFC consistent with the code, this function needs to be added.
    def entropy(self):
        """Entropy of Gumbel distribution.

        Returns:
            Entropy of distribution.

        """
        return paddle.log(self.scale) + 1 + np.euler_gamma
    def sample(self, shape=()):
        """Sample from ``Gumbel``.

        Args:
            shape (Sequence[int], optional): The sample shape. Defaults to ().

        Returns:
            Tensor: A tensor with prepended dimensions shape. The data type is float32.

        """
        with paddle.no_grad():
            return self.rsample(shape)
    def rsample(self, shape):
        """Reparameterized sample.

        Args:
            shape (Sequence[int]): 1D `int32`. Shape of the generated samples.

        Returns:
            Tensor: A tensor with prepended dimensions shape. The data type is float32.

        """
        exp_trans = paddle.distribution.ExpTransform()
        affine_trans_1 = paddle.distribution.AffineTransform(
            paddle.full(shape=self.scale.shape,
                        fill_value=0,
                        dtype=self.loc.dtype), -paddle.ones_like(self.scale))
        affine_trans_2 = paddle.distribution.AffineTransform(
            self.loc, -self.scale)

        return affine_trans_2.forward(
            exp_trans.inverse(
                affine_trans_1.forward(
                    exp_trans.inverse(self._base.sample(shape)))))
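As a sanity check on `rsample` (an editorial NumPy sketch, not Paddle code): composing `ExpTransform.inverse`, the `(0, -1)` affine transform, `ExpTransform.inverse` again, and the `(loc, -scale)` affine transform collapses to the standard Gumbel inverse-CDF formula `x = loc - scale * log(-log(u))` for `u` drawn from the uniform base.

```python
import numpy as np

rng = np.random.default_rng(0)
loc, scale = 1.0, 2.0
u = rng.uniform(1e-6, 1.0 - 1e-6, size=1000)  # stands in for the Uniform base

# mirror the transform chain in rsample, step by step
step1 = np.log(u)                   # exp_trans.inverse(u)
step2 = 0.0 + (-1.0) * step1        # affine_trans_1.forward: -log(u)
step3 = np.log(step2)               # exp_trans.inverse: log(-log(u))
x_chain = loc + (-scale) * step3    # affine_trans_2.forward

# closed-form Gumbel inverse CDF
x_closed = loc - scale * np.log(-np.log(u))

# both paths produce the same samples
```

This is why the chain of generic transforms is a valid (and differentiable) reparameterized sampler for the Gumbel distribution.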
@@ -62,15 +62,19 @@ def __init__(self, base, transforms):
         chain = transform.ChainTransform(transforms)
-        base_shape = base.batch_shape + base.event_shape
-        if len(base_shape) < chain._domain.event_rank:
+        self._base = base
+        self._transforms = transforms
+        if not transforms:
+            super(TransformedDistribution,
+                  self).__init__(base.batch_shape, base.event_shape)
+            return
+        if len(base.batch_shape + base.event_shape) < chain._domain.event_rank:
             raise ValueError(
                 f"'base' needs to have shape with size at least {chain._domain.event_rank}, but got {len(base_shape)}."
             )
-        if chain._domain.event_rank > len(base.event_shape):
-            base = independent.Independent(
-                (base, chain._domain.event_rank - len(base.event_shape)))
-        self._base = base
-        self._transforms = transforms

         transformed_shape = chain.forward_shape(base.batch_shape +
                                                 base.event_shape)

    [Review] This check can be placed at the very beginning:

        if not transforms:
            ....
        else:
            ...

    [Review] What is the reason for deleting this block of code?
    [Author] Hello, our CI-Coverage check found that this code was never exercised, which made the file's coverage extremely low and the CI could not pass.
    [Review] This code serves a purpose; it must not be deleted just for coverage. Add some test cases instead.
    [Author] OK, thank you for the suggestion.
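The reviewer's suggestion above (handle `if not transforms` at the very front of `__init__`) could look like the following minimal sketch. `BaseStub` and `TransformedStub` are hypothetical stand-ins for illustration only, not the real Paddle classes; only the control flow is the point.

```python
class BaseStub:
    """Hypothetical minimal base distribution: only carries shapes."""
    batch_shape = (2,)
    event_shape = ()


class TransformedStub:
    def __init__(self, base, transforms):
        # early return first, as the review suggests: with no transforms,
        # the distribution degenerates to the base distribution
        if not transforms:
            self._base = base
            self._transforms = ()
            self.batch_shape = base.batch_shape
            self.event_shape = base.event_shape
            return
        # ... the ChainTransform / shape-inference logic for non-empty
        # transforms would follow here, unchanged from the PR
        raise NotImplementedError("non-empty transforms not sketched here")


d = TransformedStub(BaseStub(), ())
```

Putting the degenerate case first keeps the main body free of the empty-tuple special case and avoids building a `ChainTransform` that is never used.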
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考google style组织包的结构 https://google.github.io/styleguide/pyguide.html#s3.13-imports-formatting