Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ExponentialFamily and Dirichlet probability distribution #38445

Merged
merged 3 commits into from
Dec 30, 2021

Conversation

cxxly
Copy link
Contributor

@cxxly cxxly commented Dec 24, 2021

PR types

New features

PR changes

APIs

Describe

  • extend Distribution baseclass for supporting multivariant distribution and prob method
  • add ExponentialFamily base class and entropy using Bregman divergence
  • add Dirichlet probability distribution, including mean, variance, sample, entropy, prob, log_prob method

Examples

import paddle
# dirichlet distribution
dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))
print(dirichlet.entropy())

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@cxxly cxxly force-pushed the distribution-expfamily branch 2 times, most recently from ece7935 to b9bdcfe Compare December 27, 2021 04:42
@cxxly cxxly changed the title add probability distribution base class Distribution and ExponentialFamily add ExponentialFamily and Dirichlet probability distribution Dec 27, 2021
@cxxly cxxly force-pushed the distribution-expfamily branch from b9bdcfe to 2119cb1 Compare December 27, 2021 09:22
@cxxly cxxly force-pushed the distribution-expfamily branch from 2119cb1 to fd01cc0 Compare December 27, 2021 10:01
@cxxly cxxly force-pushed the distribution-expfamily branch from fd01cc0 to 343bfce Compare December 27, 2021 11:22
@cxxly cxxly force-pushed the distribution-expfamily branch from 343bfce to 4e954e1 Compare December 27, 2021 11:52
@@ -57,6 +81,14 @@ def kl_divergence(self, other):
"""The KL-divergence between self distributions and other."""
raise NotImplementedError

def prob(self, value):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the difference between 'def prob' and 'def probs', why do we need 'def prob' ?

Copy link
Contributor Author

@cxxly cxxly Dec 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been discussed in design document. Currently, Paddle's pdf and log pdf method named probs and log_prob. It's not consistent, so we add prob method and probs will be deprecated in future.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can probs() be deprecated?

Copy link
Contributor Author

@cxxly cxxly Dec 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, will be deprecated in future.

# inputs={"Alpha": concentration},
# outputs={'Out': out},
# attrs={})
# return out
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these annotations are expected?

Copy link

@iclementine iclementine Dec 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why _dirichlet is commented?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pr need other pr dirichlet op #38244 to be merged, or not unittest will failed. _dirichlet function will be uncomment in next pr.

"""Returns event shape of distribution

Returns:
Tensor: event shape

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the return type of the function is a Tensor.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's not a Tensor. It's doc mistake, and will be fixed.

rtol=config.RTOL.get(config.DEFAULT_DTYPE),
atol=config.ATOL.get(config.DEFAULT_DTYPE))

def test_entropy_expection(self):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exception

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will be fixed in next pr.


natural_parameters = []
for parameter in self._natural_parameters:
parameter = parameter.detach()
Copy link

@iclementine iclementine Dec 28, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

detach is not viable in static mode.

Why it is implemented like this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it's available in static mode.

('test-dirichlet-dist',
paddle.distribution.Dirichlet(paddle.to_tensor(config.xrand())))])
class TestExponentialFamilyException(unittest.TestCase):
def test_entropy_expection(self):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exception

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will be fixed.

Copy link

@iclementine iclementine left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

TODO: fix them in next PR.

Copy link
Contributor

@jeff41404 jeff41404 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@TCChenlong TCChenlong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
TODO:Fix API Docs

Copy link
Contributor

@XiaoguangHu01 XiaoguangHu01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG API

@iclementine iclementine merged commit 00cddf0 into PaddlePaddle:develop Dec 30, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants