add ExponentialFamily and Dirichlet probability distribution (#38445)
* extend Distribution base class to support multivariate distributions and a `prob` method

* add ExponentialFamily base class and entropy using Bregman divergence

* add dirichlet probability distribution
cxxly authored Dec 30, 2021
1 parent c5bf09b commit 00cddf0
Showing 11 changed files with 822 additions and 6 deletions.
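
For orientation, a minimal usage sketch of the API this commit adds (adapted from the `Dirichlet` docstring below; exact output values depend on device and build):

import paddle

dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))
print(dirichlet.entropy())                             # ~[-1.24434423]
print(dirichlet.prob(paddle.to_tensor([.3, .5, .6])))  # ~[10.80000114]
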
16 changes: 12 additions & 4 deletions python/paddle/distribution/__init__.py
@@ -1,20 +1,28 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential_family import ExponentialFamily
from .normal import Normal
from .uniform import Uniform

__all__ = ['Categorical', 'Distribution', 'Normal', 'Uniform']
__all__ = [  # noqa
    'Categorical',
    'Distribution',
    'Normal',
    'Uniform',
    'ExponentialFamily',
    'Dirichlet',
]
162 changes: 162 additions & 0 deletions python/paddle/distribution/dirichlet.py
@@ -0,0 +1,162 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from ..fluid.data_feeder import check_variable_and_dtype
from ..fluid.framework import in_dygraph_mode
from ..fluid.layer_helper import LayerHelper
from .exponential_family import ExponentialFamily


class Dirichlet(ExponentialFamily):
r"""
Dirichlet distribution with parameter concentration
The Dirichlet distribution is defined over the `(k-1)-simplex` using a
positive, lenght-k vector concentration(`k > 1`).
The Dirichlet is identically the Beta distribution when `k = 2`.
The probability density function (pdf) is
.. math::
f(x_1,...,x_k; \alpha_1,...,\alpha_k) = \frac{1}{B(\alpha)} \prod_{i=1}^{k}x_i^{\alpha_i-1}
The normalizing constant is the multivariate beta function.
Args:
concentration (Tensor): concentration parameter of dirichlet
distribution
Examples:
.. code-block:: python
import paddle
dirichlet = paddle.distribution.Dirichlet(paddle.to_tensor([1., 2., 3.]))
print(dirichlet.entropy())
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [-1.24434423])
print(dirichlet.prob(paddle.to_tensor([.3, .5, .6])))
# Tensor(shape=[1], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [10.80000114])
"""

def __init__(self, concentration):
if concentration.dim() < 1:
raise ValueError(
"`concentration` parameter must be at least one dimensional")

self.concentration = concentration
super(Dirichlet, self).__init__(concentration.shape[:-1],
concentration.shape[-1:])

@property
def mean(self):
"""mean of Dirichelt distribution.
Returns:
mean value of distribution.
"""
return self.concentration / self.concentration.sum(-1, keepdim=True)

@property
def variance(self):
"""variance of Dirichlet distribution.
Returns:
variance value of distribution.
"""
concentration0 = self.concentration.sum(-1, keepdim=True)
return (self.concentration * (concentration0 - self.concentration)) / (
concentration0.pow(2) * (concentration0 + 1))

def sample(self, shape=()):
"""sample from dirichlet distribution.
Args:
shape (Tensor, optional): sample shape. Defaults to empty tuple.
"""
shape = shape if isinstance(shape, tuple) else tuple(shape)
return _dirichlet(self.concentration.expand(self._extend_shape(shape)))

def prob(self, value):
"""Probability density function(pdf) evaluated at value.
Args:
value (Tensor): value to be evaluated.
Returns:
pdf evaluated at value.
"""
return paddle.exp(self.log_prob(value))

def log_prob(self, value):
"""log of probability densitiy function.
Args:
value (Tensor): value to be evaluated.
"""
return ((paddle.log(value) * (self.concentration - 1.0)
).sum(-1) + paddle.lgamma(self.concentration.sum(-1)) -
paddle.lgamma(self.concentration).sum(-1))

def entropy(self):
"""entropy of Dirichlet distribution.
Returns:
entropy of distribution.
"""
concentration0 = self.concentration.sum(-1)
k = self.concentration.shape[-1]
return (paddle.lgamma(self.concentration).sum(-1) -
paddle.lgamma(concentration0) -
(k - concentration0) * paddle.digamma(concentration0) - (
(self.concentration - 1.0
) * paddle.digamma(self.concentration)).sum(-1))

@property
def _natural_parameters(self):
return (self.concentration, )

def _log_normalizer(self, x):
return x.lgamma().sum(-1) - paddle.lgamma(x.sum(-1))


def _dirichlet(concentration, name=None):
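    # The `dirichlet` sampling kernel is not wired up in this commit; the
    # intended implementation (a `dirichlet` op) is sketched in the
    # commented-out block below.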
raise NotImplementedError


# op_type = 'dirichlet'

# check_variable_and_dtype(concentration, 'concentration',
# ['float32', 'float64'], op_type)

# if in_dygraph_mode():
# return paddle._C_ops.dirichlet(concentration)

# else:
# helper = LayerHelper(op_type, **locals())
# out = helper.create_variable_for_type_inference(
# dtype=concentration.dtype)
# helper.append_op(
# type=op_type,
# inputs={"Alpha": concentration},
# outputs={'Out': out},
# attrs={})
# return out
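
Note that `log_prob` does not depend on the unimplemented `_dirichlet` sampling op, so it can already be exercised. A quick, hedged cross-check against SciPy (assumes SciPy is installed; `scipy.stats.dirichlet.logpdf` requires the evaluation point to lie on the simplex):

import numpy as np
import paddle
from scipy.stats import dirichlet as scipy_dirichlet

alpha = np.array([1., 2., 3.])
x = np.array([.2, .3, .5])  # sums to 1: a valid point on the 2-simplex

d = paddle.distribution.Dirichlet(paddle.to_tensor(alpha))
print(d.log_prob(paddle.to_tensor(x)).numpy())  # Paddle result
print(scipy_dirichlet.logpdf(x, alpha))         # SciPy reference; should agree
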
47 changes: 45 additions & 2 deletions python/paddle/distribution/distribution.py
@@ -42,10 +42,34 @@ class Distribution(object):
implemented in specific distributions.
"""

def __init__(self):
def __init__(self, batch_shape=(), event_shape=()):

self._batch_shape = batch_shape if isinstance(
batch_shape, tuple) else tuple(batch_shape)
self._event_shape = event_shape if isinstance(
event_shape, tuple) else tuple(event_shape)

super(Distribution, self).__init__()

def sample(self):
@property
def batch_shape(self):
"""Returns batch shape of distribution
Returns:
Tensor: batch shape
"""
return self._batch_shape

@property
def event_shape(self):
"""Returns event shape of distribution
Returns:
Tensor: event shape
"""
return self._event_shape

def sample(self, shape=()):
"""Sampling from the distribution."""
raise NotImplementedError

@@ -57,6 +81,14 @@ def kl_divergence(self, other):
"""The KL-divergence between self distributions and other."""
raise NotImplementedError

def prob(self, value):
"""Probability density/mass function evaluated at value.
Args:
value (Tensor): value which will be evaluated
"""
raise NotImplementedError

def log_prob(self, value):
"""Log probability density/mass function."""
raise NotImplementedError
@@ -65,6 +97,17 @@ def probs(self, value):
"""Probability density/mass function."""
raise NotImplementedError

def _extend_shape(self, sample_shape):
"""compute shape of the sample
Args:
sample_shape (Tensor): sample shape
Returns:
Tensor: generated sample data shape
"""
return sample_shape + self._batch_shape + self._event_shape

def _validate_args(self, *args):
"""
Argument validation for distribution args
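
To make the new shape bookkeeping concrete, a small sketch (assuming a rank-2 concentration, so the leading axis becomes the batch shape and the trailing axis the event shape, as in `Dirichlet.__init__` above):

import paddle

d = paddle.distribution.Dirichlet(paddle.ones([5, 3]))
print(d.batch_shape)  # (5,)
print(d.event_shape)  # (3,)
# _extend_shape simply concatenates: sample_shape + batch_shape + event_shape,
# e.g. (2,) + (5,) + (3,) -> (2, 5, 3)
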
60 changes: 60 additions & 0 deletions python/paddle/distribution/exponential_family.py
@@ -0,0 +1,60 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from ..fluid.framework import in_dygraph_mode
from .distribution import Distribution


class ExponentialFamily(Distribution):
""" Base class for exponential family distribution.
"""

@property
def _natural_parameters(self):
raise NotImplementedError

def _log_normalizer(self):
raise NotImplementedError

@property
def _mean_carrier_measure(self):
raise NotImplementedError

def entropy(self):
"""caculate entropy use `bregman divergence`
https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf
"""
entropy_value = -self._mean_carrier_measure

natural_parameters = []
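        # Detach each natural parameter and re-enable gradient tracking so
        # that autodiff differentiates the log normalizer A(eta) w.r.t. eta
        # alone, without backpropagating into the parameter's original graph.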
for parameter in self._natural_parameters:
parameter = parameter.detach()
parameter.stop_gradient = False
natural_parameters.append(parameter)

log_norm = self._log_normalizer(*natural_parameters)

if in_dygraph_mode():
grads = paddle.grad(
log_norm.sum(), natural_parameters, create_graph=True)
else:
grads = paddle.static.gradients(log_norm.sum(), natural_parameters)

entropy_value += log_norm
for p, g in zip(natural_parameters, grads):
entropy_value -= p * g

return entropy_value
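
For reference, the identity the method above implements: writing an exponential-family density as

.. math::

    p(x) = h(x) \exp\big(\langle \eta, t(x) \rangle - A(\eta)\big)

and using :math:`\mathbb{E}[t(x)] = \nabla A(\eta)`, the entropy has the closed form

.. math::

    H[p] = A(\eta) - \langle \eta, \nabla A(\eta) \rangle - \mathbb{E}[\log h(x)]

Here `_log_normalizer` computes :math:`A(\eta)`, the gradient :math:`\nabla A(\eta)` is obtained by autodiff (`paddle.grad` in dygraph mode, `paddle.static.gradients` otherwise), and `_mean_carrier_measure` supplies :math:`\mathbb{E}[\log h(x)]`.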
