add multinomial probability distribution #38820
Thanks for your contribution!
Sorry to inform you that 7e37d2b's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
7e37d2b
to
b4d9a37
Compare
b4d9a37
to
3baa872
Compare
3f921de
to
154a928
Compare
154a928
to
4d2b27a
Compare
python/paddle/distribution/kl.py
Outdated
@@ -68,6 +68,11 @@ def kl_divergence(p, q):
def register_kl(cls_p, cls_q):
"""Decorator for register a KL divergence implemention function.

when call ``kl_divergence(p, q)`` , will search concrete implemention
Mind the grammar.
updated
python/paddle/distribution/beta.py
Outdated
@@ -37,8 +44,15 @@ class Beta(ExponentialFamily):

Args:
alpha (float|Tensor): alpha parameter of beta distribution, positive(>0).
alpha (float|Tensor): alpha parameter of beta distribution,
positive(>0), support broadcast semantic. when the parameter is
Mind the English syntax: capitalize the first letter of sentences, and so on.
updated
@@ -200,7 +208,8 @@ def kl_divergence(self, other):
if not in_dygraph_mode():
check_type(other, 'other', Categorical, 'kl_divergence')

logits = self.logits - nn.reduce_max(self.logits, dim=-1, keep_dim=True)
logits = self.logits - \
nn.reduce_max(self.logits, dim=-1, keep_dim=True)
For these, it is recommended to use paddle.max rather than the functions in nn.
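This is legacy code left by a previous contributor. I will update this part first; the remaining legacy code is covered by the follow-up plan and will be updated together.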
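For context, the line under review shifts the logits by their maximum before exponentiating, which keeps `exp` from overflowing without changing the resulting probabilities. The following is a minimal pure-Python sketch of that idea, not the PR's code; `log_softmax` and `categorical_kl` are illustrative names. In Paddle the shift would presumably be written with `paddle.max(self.logits, axis=-1, keepdim=True)`, as the reviewer recommends.

```python
import math


def log_softmax(logits):
    """Log-probabilities from unnormalized logits, shifted by the max for stability."""
    m = max(logits)                      # the shift that reduce_max / paddle.max provides
    shifted = [x - m for x in logits]    # largest entry becomes 0, so exp() cannot overflow
    log_z = math.log(sum(math.exp(s) for s in shifted))
    return [s - log_z for s in shifted]


def categorical_kl(p_logits, q_logits):
    """KL(p || q) between two categorical distributions given as logits."""
    log_p = log_softmax(p_logits)
    log_q = log_softmax(q_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))
```

Without the shift, `math.exp(1000.0)` raises `OverflowError`; with it, logits such as `[1000.0, 1001.0]` are handled normally.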
Args:
concentration (Tensor): concentration parameter of dirichlet
distribution
distribution, also called :math:`\alpha`. when concentration over
Capitalize the first letter of each sentence. Use English commas in English sentences, each followed by a space.
Args:
total_count (int): Number of trials.
probs (Tensor): Probability of a trail falling into each category. Last
Mind the spelling: it should be "trial", not "trail".
updated
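To make the roles of `total_count` and `probs` concrete, here is a hedged pure-Python sketch of the multinomial log-probability for a single (unbatched) distribution. The function name `multinomial_log_prob` is hypothetical, and rejecting counts that do not sum to `total_count` is a choice made for this sketch; the thread does not show how the PR handles that case.

```python
import math


def multinomial_log_prob(total_count, probs, value):
    """log P(value) for counts `value` over len(probs) categories (no batch axes)."""
    if sum(value) != total_count:
        # Assumption of this sketch: such input is treated as invalid.
        raise ValueError("counts in `value` must sum to total_count")
    # log of n! / (x_1! * ... * x_k!), using lgamma(n + 1) == log(n!)
    log_coeff = math.lgamma(total_count + 1) - sum(math.lgamma(v + 1) for v in value)
    # sum of x_i * log(p_i); zero counts are skipped so that p_i == 0 is tolerated
    log_terms = sum(v * math.log(p) for v, p in zip(value, probs) if v > 0)
    return log_coeff + log_terms
```

For the `value-int` case from the tests (`total_count=10`, `probs=[0.2, 0.3, 0.5]`, `value=[2, 3, 5]`), the probability is 2520 × 0.2² × 0.3³ × 0.5⁵ = 0.08505.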
samples = self._dist.sample(sample_shape)
sample_mean = samples.mean(axis=0)
np.testing.assert_allclose(
sample_mean, self._dist.mean, atol=0, rtol=0.20)
A tolerance like this may be considered unreasonable by the CI system. Could you state the reason for it?
updated
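One way to justify a loose tolerance in a sampling test is to compare it against the expected statistical error of the sample mean. The sketch below uses only the standard library and hypothetical helper names; it is not the PR's test. For a category with probability p, the count has variance n·p·(1−p), so the relative standard error of the mean over N samples is sqrt((1−p) / (n·p·N)).

```python
import math
import random


def sample_multinomial(total_count, probs, rng):
    """Draw one vector of category counts from Multinomial(total_count, probs)."""
    counts = [0] * len(probs)
    for k in rng.choices(range(len(probs)), weights=probs, k=total_count):
        counts[k] += 1
    return counts


def sample_mean(total_count, probs, num_samples, seed=0):
    """Average the count vectors of num_samples independent draws."""
    rng = random.Random(seed)
    sums = [0] * len(probs)
    for _ in range(num_samples):
        for i, c in enumerate(sample_multinomial(total_count, probs, rng)):
            sums[i] += c
    return [s / num_samples for s in sums]


def relative_standard_error(total_count, p, num_samples):
    """Standard deviation of the mean count divided by the expected count n * p."""
    return math.sqrt((1.0 - p) / (total_count * p * num_samples))
```

With `total_count=10`, `p=0.2`, and 5000 samples the relative standard error is below 1%, so `rtol=0.20` leaves a wide margin against flaky failures; a comment in the test stating this kind of reasoning would address the reviewer's concern.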
('value-int', 10, np.array([0.2, 0.3, 0.5]), np.array([2, 3, 5])),
('value-multi-dim', 10, np.array([[0.3, 0.7], [0.5, 0.5]]),
np.array([[4., 6], [8, 2]])),
# ('value-sum-non-n', 10, np.array([0.5, 0.2, 0.3]), np.array([4,5,2])),
Could you add a case that draws multiple samples? For example, the batch shape is () while the sample shape is (2,), or something similar.
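The case the reviewer asks for can be illustrated with a small stand-alone sketch. It assumes the common convention that a sample has shape `sample_shape + batch_shape + event_shape`; the helpers `sample` and `shape_of` are hypothetical and use nested lists instead of tensors.

```python
import random


def sample(total_count, probs, sample_shape=(), seed=0):
    """Draw samples for an unbatched multinomial (batch_shape == ())."""
    rng = random.Random(seed)

    def draw_one():
        counts = [0] * len(probs)
        for k in rng.choices(range(len(probs)), weights=probs, k=total_count):
            counts[k] += 1
        return counts

    def build(shape):
        # Recurse over the sample dimensions; the innermost level is one event.
        if not shape:
            return draw_one()
        return [build(shape[1:]) for _ in range(shape[0])]

    return build(tuple(sample_shape))


def shape_of(nested):
    """Shape of a regular nested list, e.g. [[1, 2, 3], [4, 5, 6]] -> (2, 3)."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        nested = nested[0]
    return tuple(shape)
```

With `sample_shape=(2,)`, an empty batch shape, and three categories, the result has shape `(2, 3)`, and every row sums to `total_count`.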
def setUp(self):
self.prog = paddle.static.Program()
self.exe = paddle.static.Executor()
with paddle.static.program_guard(prog):
It is better to write two programs, even though no parameters are created here.
4d2b27a
to
2a43a55
Compare
1547cc0
to
efdea2e
Compare
efdea2e
to
89ce615
Compare
89ce615
to
138fcd0
Compare
359cdd6
to
8654cf4
Compare
python/paddle/distribution/kl.py
Outdated
function registered by ``register_kl``, according to multi-dispatch pattern.
If find the implemention function, it will return the result, or not will
raise ``NotImplementError`` exception. User can register implemention
funciton by the decorator.
- implementation functions (plural form);
- If an implementation function is found;
- otherwise, it will raise a `NotImplementError` exception;
- Users
- functions
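The multi-dispatch behavior that this docstring describes can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the code in `kl.py`: the `_REGISTRY` dictionary, the MRO-distance tie-breaking rule, and the `Uniform` class are all hypothetical, and the sketch raises Python's built-in `NotImplementedError`.

```python
import math

_REGISTRY = {}  # maps (type of p, type of q) -> implementation function


def register_kl(cls_p, cls_q):
    """Decorator that registers a KL implementation for the pair (cls_p, cls_q)."""
    def decorator(fn):
        _REGISTRY[(cls_p, cls_q)] = fn
        return fn
    return decorator


def kl_divergence(p, q):
    """Dispatch on the types of both arguments, preferring the most specific match."""
    type_p, type_q = type(p), type(q)
    matches = [
        (type_p.__mro__.index(cp) + type_q.__mro__.index(cq), fn)
        for (cp, cq), fn in _REGISTRY.items()
        if issubclass(type_p, cp) and issubclass(type_q, cq)
    ]
    if not matches:
        raise NotImplementedError(
            f"no KL registered for ({type_p.__name__}, {type_q.__name__})")
    # Smallest total MRO distance wins (an assumption of this sketch).
    return min(matches, key=lambda m: m[0])[1](p, q)


class Uniform:
    """Toy distribution used only to exercise the registry."""
    def __init__(self, low, high):
        self.low, self.high = low, high


@register_kl(Uniform, Uniform)
def _kl_uniform_uniform(p, q):
    # Valid when the support of p lies inside the support of q.
    return math.log((q.high - q.low) / (p.high - p.low))
```

Calling `kl_divergence(Uniform(0, 1), Uniform(0, 2))` finds the registered function and returns `log 2`; a pair with no registered implementation raises the exception.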
python/paddle/distribution/kl.py
Outdated
@@ -167,7 +170,7 @@ def _kl_uniform_uniform(p, q):

@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
"""compute kl-divergence using `Bregman divergences`
"""Compute kl-divergence using `Bregman divergences`
https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf
Use correct hyperlink format of rst.
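As a hedged illustration of both points, the sketch below writes the link as an rst external hyperlink inside a docstring and evaluates the Bregman-divergence form of the KL divergence, KL(p‖q) = A(θ_q) − A(θ_p) − (θ_q − θ_p)·∇A(θ_p), for a scalar natural parameter. The function names and the Poisson example are assumptions chosen for the demo, not code from the PR.

```python
import math


def kl_expfamily(theta_p, theta_q, log_normalizer, grad_log_normalizer):
    """Compute KL(p || q) for two members of one exponential family.

    Uses `Bregman divergences
    <https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf>`_
    of the log-normalizer A, evaluated at scalar natural parameters.
    """
    return (log_normalizer(theta_q) - log_normalizer(theta_p)
            - (theta_q - theta_p) * grad_log_normalizer(theta_p))


def kl_poisson(rate_p, rate_q):
    """Poisson case: natural parameter log(rate), log-normalizer A(theta) = exp(theta)."""
    return kl_expfamily(math.log(rate_p), math.log(rate_q), math.exp, math.exp)
```

For the Poisson family this reproduces the known closed form λ_p·log(λ_p/λ_q) + λ_q − λ_p.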
python/paddle/distribution/kl.py
Outdated
@@ -205,5 +208,5 @@ def _kl_expfamily_expfamily(p, q):

def _sum_rightmost(value, n):
"""sum value along rightmost n dim"""
"""Sum value along rightmost n dim"""
elements ...(plural form)
dimensions.
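To pin down what "sum elements along the rightmost n dimensions" means, here is a pure-Python sketch over nested lists. It is an analogue for illustration only; the PR's `_sum_rightmost` operates on tensors, and its body is not shown in this thread.

```python
def _ndim(value):
    """Number of nesting levels of a regular, non-empty nested list."""
    n = 0
    while isinstance(value, list):
        value = value[0]
        n += 1
    return n


def _total(value):
    """Sum of every scalar contained in a nested list (or the scalar itself)."""
    if isinstance(value, list):
        return sum(_total(v) for v in value)
    return value


def sum_rightmost(value, n):
    """Sum elements along the rightmost n dimensions of a nested list."""
    keep = _ndim(value) - n
    if n < 0 or keep < 0:
        raise ValueError("n must be between 0 and the number of dimensions")

    def reduce(v, k):
        # Keep the first k dimensions, collapse everything below them.
        if k == 0:
            return _total(v)
        return [reduce(x, k - 1) for x in v]

    return reduce(value, keep)
```

For a 2×2 input `[[1, 2], [3, 4]]`, `n=1` gives `[3, 7]`, `n=2` gives `10`, and `n=0` returns the values unchanged.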
python/paddle/distribution/beta.py
Outdated
@@ -37,8 +44,14 @@ class Beta(ExponentialFamily):

Args:
alpha (float|Tensor): alpha parameter of beta distribution, positive(>0).
beta (float|Tensor): beta parameter of beta distribution, positive(>0).
alpha (float|Tensor): Alpha parameter. It support broadcast semantic.
It supports
Semantics.
is a tensor
represents
distributions
a
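What "supports broadcast semantics" implies for the shapes of `alpha` and `beta` can be sketched as follows. The sketch assumes NumPy-style broadcasting rules (align shapes from the right; sizes must match or one of them must be 1), which is an assumption about Paddle's behavior, and `broadcast_shape` is a hypothetical helper.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Resulting shape when two shapes are broadcast together, NumPy style."""
    out = []
    # Walk both shapes from the last axis; missing leading axes count as size 1.
    for x, y in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {shape_a} and {shape_b} cannot be broadcast")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))
```

A scalar `alpha` (shape `()`) combined with a `beta` of shape `(3,)` would then describe a batch of three distributions.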
concentration (Tensor): concentration parameter of dirichlet
distribution
concentration (Tensor): "Concentration" parameter of dirichlet
distribution, also called :math:`\alpha`. When It's over one
When it's.
concentration (Tensor): "Concentration" parameter of dirichlet
distribution, also called :math:`\alpha`. When It's over one
dimension, the last axis is parameter of distribution,
``event_shape=concentration.shape[-1:]`` , other axes is batch
are
Axes other than the last are considered batch dimensions.
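The shape convention in this docstring can be made concrete with a short sketch. `split_shape` follows the quoted rule ``event_shape=concentration.shape[-1:]``; `dirichlet_mean` uses the standard Dirichlet mean α_i / Σα over the last axis. Both are hypothetical helpers on plain Python lists, not the PR's implementation.

```python
def split_shape(shape):
    """Split a concentration shape into (batch_shape, event_shape)."""
    shape = tuple(shape)
    if not shape:
        raise ValueError("concentration needs at least one dimension")
    # The last axis parameterizes one distribution; the rest index the batch.
    return shape[:-1], shape[-1:]


def dirichlet_mean(concentration):
    """Mean of each Dirichlet in a 2-D batch, normalizing over the last axis."""
    return [[a / sum(row) for a in row] for row in concentration]
```

For a concentration of shape `(4, 2, 3)` this gives `batch_shape=(4, 2)` and `event_shape=(3,)`.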
Args:
total_count (int): Number of trials.
probs (Tensor): Probability of a trial falling into each category. Last
axis of probs indexes over categories, other axes index over batches.
The last axis
LGTM
TODO: refine documentation in next PR!
8654cf4
to
852300d
Compare
approve
LGTM
PR types
New features
PR changes
APIs
Describe