-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add ExponentialFamily and Dirichlet probability distribution #38445
add ExponentialFamily and Dirichlet probability distribution #38445
Conversation
Thanks for your contribution! |
ece7935
to
b9bdcfe
Compare
b9bdcfe
to
2119cb1
Compare
2119cb1
to
fd01cc0
Compare
fd01cc0
to
343bfce
Compare
343bfce
to
4e954e1
Compare
@@ -57,6 +81,14 @@ def kl_divergence(self, other): | |||
"""The KL-divergence between self distributions and other.""" | |||
raise NotImplementedError | |||
|
|||
def prob(self, value): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what's the difference between 'def prob' and 'def probs', why do we need 'def prob' ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This has been discussed in design document. Currently, Paddle's pdf and log pdf method named probs and log_prob. It's not consistent, so we add prob method and probs will be deprecated in future.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can probs() be deprecated?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, will be deprecated in future.
# inputs={"Alpha": concentration}, | ||
# outputs={'Out': out}, | ||
# attrs={}) | ||
# return out |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these annotations are expected?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why _dirichlet is commented?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This pr need other pr dirichlet op #38244 to be merged, or not unittest will failed. _dirichlet function will be uncomment in next pr.
"""Returns event shape of distribution | ||
|
||
Returns: | ||
Tensor: event shape |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think the return type of the function is a Tensor.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's not a Tensor. It's doc mistake, and will be fixed.
rtol=config.RTOL.get(config.DEFAULT_DTYPE), | ||
atol=config.ATOL.get(config.DEFAULT_DTYPE)) | ||
|
||
def test_entropy_expection(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
exception
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will be fixed in next pr.
|
||
natural_parameters = [] | ||
for parameter in self._natural_parameters: | ||
parameter = parameter.detach() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
detach is not viable in static mode.
Why it is implemented like this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now it's available in static mode.
('test-dirichlet-dist', | ||
paddle.distribution.Dirichlet(paddle.to_tensor(config.xrand())))]) | ||
class TestExponentialFamilyException(unittest.TestCase): | ||
def test_entropy_expection(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
exception
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
will be fixed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
TODO: fix them in next PR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
TODO:Fix API Docs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG API
PR types
New features
PR changes
APIs
Describe
Examples