-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon No.8】 add gumbel distribution api #46255
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
✅ This PR's description meets the template requirements! |
@cxxly 麻烦您 review 一下。 |
处理CI失败问题,原则所有CI通过,才允许request review |
from paddle.distribution.uniform import Uniform | ||
from paddle.distribution.transformed_distribution import TransformedDistribution | ||
from paddle.distribution.transform import AffineTransform, ExpTransform | ||
from paddle.fluid import framework as framework |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考google style组织包的结构 https://google.github.io/styleguide/pyguide.html#s3.13-imports-formatting
Args: | ||
loc(int|float): The mean of normal distribution.The data type is int, float. | ||
scale(int|float): The std of normal distribution.The data type is int, float. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
int|float|Tensor,Tensor为必须支持数据类型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
Args: | ||
loc(int|float): The mean of normal distribution.The data type is int, float. | ||
scale(int|float): The std of normal distribution.The data type is int, float. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The location parameter of gumbel distribution .....
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/distribution/gumbel.py
Outdated
from paddle.distribution import Gumbel | ||
|
||
# Define a single scalar Gumbel distribution. | ||
dist = Gumbel(loc=0., scale=1.) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
示例代码中增加一些方法调用示例
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/distribution/gumbel.py
Outdated
f"Expected type of scale is Real|Variable, but got {type(scale)}" | ||
) | ||
self.loc, self.scale = paddle.broadcast_tensors([loc, scale]) | ||
finfo = np.finfo(type(self.loc)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果输入类型是 Real, 此处broadcast会报错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/distribution/gumbel.py
Outdated
f"Expected type of scale is Real|Variable, but got {type(scale)}" | ||
) | ||
self.loc, self.scale = paddle.broadcast_tensors([loc, scale]) | ||
finfo = np.finfo(type(self.loc)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果 loc为Variable,此处np.finfo(type(self.loc))会报错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/distribution/gumbel.py
Outdated
Tensor: The variance value. | ||
|
||
""" | ||
return math.pow(self.scale, 2) * math.pow(math.pi, 2) / 6 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
使用Paddle API运算,math.pow不支持Tensor类型
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
"""Mean of distribution | ||
|
||
The variance is | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The mean is ...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/distribution/gumbel.py
Outdated
Tensor: std value. | ||
|
||
""" | ||
return math.sqrt(self.variance) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/distribution/gumbel.py
Outdated
|
||
""" | ||
y = (self.loc - value) / self.scale | ||
return math.exp(y - math.exp(y)) / self.scale |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/distribution/gumbel.py
Outdated
Tensor: log probability.The data type is same with value. | ||
|
||
""" | ||
return math.log(self.prob(value)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
python/paddle/distribution/gumbel.py
Outdated
Tensor: Shannon entropy of gumbel distribution. | ||
|
||
""" | ||
return math.log(self.scale) + 1 + np.euler_gamma |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
非常感谢您的审查意见!我们会逐一修改完善。
增加动态图/静态图下相关测试用例 |
@cxxly 您好,针对上次您提出的review意见我们已经修改,劳烦再次审查一下。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for docs
python/paddle/distribution/gumbel.py
Outdated
from paddle.distribution.gumbel import Gumbel | ||
|
||
# Gumbel distributed with loc=0, scale=1 | ||
dist = Gumbel(paddle.full([0.0]), paddle.full([1.0])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle.full([0.0])
will report error? use paddle.to_tensor(0.0)
or paddle.full([1], 0.0)
instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle.full([1.0])
is also needed to correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.
python/paddle/distribution/gumbel.py
Outdated
dist = Gumbel(paddle.full([0.0]), paddle.full([1.0])) | ||
dist.sample() | ||
# Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=True, [4.14814520]) | ||
value = paddle.full([0.5]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use paddle.to_tensor(0.5)
or paddle.full([1], 0.5)
or instead?
""" | ||
return paddle.log(self.prob(value)) | ||
|
||
def cdf(self, value): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in order to have rfc same with code, need to add this functioncdf
in rfc API 实现方案
According to your comments, we have fixed parameters and add cdf function in rfc(PaddlePaddle/community#298). Would you please review again ?@jeff41404 .Thanks. |
|
Sorry about that mistake. We have fixed that. Would you please review again ?@jeff41404 .Thanks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for docs
copyright随后再改吧
python/paddle/distribution/gumbel.py
Outdated
@@ -0,0 +1,242 @@ | |||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | |
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,我们随后会跟进修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Ligoml 您好,我们对copyright也进行了修改,麻烦您再review一下,非常感谢。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
APIs
Describe
新增gumbel分布API
设计文档:PaddlePaddle/community#254
中文api文档:PaddlePaddle/docs#5290