【Hackathon No.10】Add LogNormal API #46426

Merged
merged 44 commits into from
Oct 10, 2022
Changes from 42 commits
Commits
44 commits
ecce733
add LogNormal API
MayYouBeProsperous Sep 22, 2022
7ee7b73
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
MayYouBeProsperous Sep 22, 2022
b3d1757
fix bug
MayYouBeProsperous Sep 22, 2022
211f42c
fix bug
MayYouBeProsperous Sep 22, 2022
6a95d56
fix bug
MayYouBeProsperous Sep 22, 2022
acadbb8
fix bug
MayYouBeProsperous Sep 22, 2022
7af5a39
fix bug
MayYouBeProsperous Sep 22, 2022
5c92dce
fix bug
MayYouBeProsperous Sep 22, 2022
6ea8cf5
fix bug
MayYouBeProsperous Sep 22, 2022
6f98baf
fix bug
MayYouBeProsperous Sep 23, 2022
124d8f9
fix bug
MayYouBeProsperous Sep 23, 2022
0d79cc2
fix bug
MayYouBeProsperous Sep 23, 2022
b02255b
fix bug
MayYouBeProsperous Sep 23, 2022
83e6d76
add comment
MayYouBeProsperous Sep 23, 2022
57c06a8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
MayYouBeProsperous Sep 23, 2022
fb23b19
fix bug
MayYouBeProsperous Sep 23, 2022
1fa84b0
fix docs
MayYouBeProsperous Sep 23, 2022
558e784
fix bug
MayYouBeProsperous Sep 23, 2022
63a3e61
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
MayYouBeProsperous Sep 23, 2022
e32f571
fix bug
MayYouBeProsperous Sep 23, 2022
30228fb
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
MayYouBeProsperous Sep 23, 2022
9224e4a
fix bug
MayYouBeProsperous Sep 23, 2022
37eeb2d
add test
MayYouBeProsperous Sep 23, 2022
f700bd1
add test
MayYouBeProsperous Sep 24, 2022
acb6d4e
change the args type of Normal sample
MayYouBeProsperous Sep 24, 2022
920e4fc
fix bug
MayYouBeProsperous Sep 24, 2022
a463709
fix bug
MayYouBeProsperous Sep 24, 2022
d88eb59
fix bug
MayYouBeProsperous Sep 24, 2022
062c64f
fix bug
MayYouBeProsperous Sep 24, 2022
e7077b0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
MayYouBeProsperous Sep 24, 2022
a649645
add test
MayYouBeProsperous Sep 24, 2022
5525158
add test
MayYouBeProsperous Sep 24, 2022
19776ee
format
MayYouBeProsperous Sep 24, 2022
6f176c6
add comment
MayYouBeProsperous Sep 24, 2022
0f8d1e6
add comment
MayYouBeProsperous Sep 24, 2022
8be803e
add comment
MayYouBeProsperous Sep 26, 2022
73e2b4f
add comment
MayYouBeProsperous Sep 26, 2022
0f94b9a
format code
MayYouBeProsperous Sep 28, 2022
7713ad8
fix bug
MayYouBeProsperous Sep 28, 2022
16765f7
fix bug
MayYouBeProsperous Sep 28, 2022
f1bf711
fix bug
MayYouBeProsperous Sep 28, 2022
9f3a98b
add comment
MayYouBeProsperous Sep 29, 2022
4745c39
remove name parameter for LogNormal
MayYouBeProsperous Oct 9, 2022
c97a9fe
organize imports
MayYouBeProsperous Oct 9, 2022
3 changes: 2 additions & 1 deletion python/paddle/distribution/__init__.py
@@ -20,6 +20,7 @@
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.independent import Independent
from paddle.distribution.kl import kl_divergence, register_kl
from paddle.distribution.lognormal import LogNormal
from paddle.distribution.multinomial import Multinomial
from paddle.distribution.normal import Normal
from paddle.distribution.transform import * # noqa: F403
@@ -31,7 +32,7 @@
__all__ = [ # noqa
'Beta', 'Categorical', 'Dirichlet', 'Distribution', 'ExponentialFamily',
'Multinomial', 'Normal', 'Uniform', 'kl_divergence', 'register_kl',
'Independent', 'TransformedDistribution', 'Laplace'
'Independent', 'TransformedDistribution', 'Laplace', 'LogNormal'
]

__all__.extend(transform.__all__)
8 changes: 7 additions & 1 deletion python/paddle/distribution/kl.py
@@ -21,6 +21,7 @@
from paddle.distribution.distribution import Distribution
from paddle.distribution.exponential_family import ExponentialFamily
from paddle.distribution.normal import Normal
from paddle.distribution.lognormal import LogNormal
from paddle.distribution.uniform import Uniform
from paddle.distribution.laplace import Laplace
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
@@ -96,7 +97,7 @@ def decorator(f):


def _dispatch(cls_p, cls_q):
"""Multiple dispatch into concrete implement function"""
"""Multiple dispatch into concrete implement function."""

# find all matched super class pair of p and q
matchs = [(super_p, super_q) for super_p, super_q in _REGISTER_TABLE
@@ -212,5 +213,10 @@ def _kl_expfamily_expfamily(p, q):
return kl


@register_kl(LogNormal, LogNormal)
def _kl_lognormal_lognormal(p, q):
return p._base.kl_divergence(q._base)


def _sum_rightmost(value, n):
return value.sum(list(range(-n, 0))) if n > 0 else value
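A usage sketch of my own (not part of the diff), assuming the LogNormal API lands as shown in this PR: once this rule is registered, `paddle.distribution.kl_divergence` should dispatch a pair of LogNormal instances to it, and it simply delegates to the KL divergence of the underlying Normal distributions.

from paddle.distribution import LogNormal, kl_divergence

# Sketch only: both calls are expected to return the same closed-form value.
p = LogNormal(loc=0., scale=1.)
q = LogNormal(loc=0.5, scale=2.)
print(kl_divergence(p, q))
print(p.kl_divergence(q))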
175 changes: 175 additions & 0 deletions python/paddle/distribution/lognormal.py
@@ -0,0 +1,175 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distribution.normal import Normal
from paddle.distribution.transform import ExpTransform
from paddle.distribution.transformed_distribution import TransformedDistribution


class LogNormal(TransformedDistribution):
r"""The LogNormal distribution with location `loc` and `scale` parameters.

.. math::

X \sim Normal(\mu, \sigma)

Y = exp(X) \sim LogNormal(\mu, \sigma)


Because the LogNormal distribution is defined by applying a transformation to the Normal distribution, we call :math:`Normal(\mu, \sigma)` the underlying distribution of :math:`LogNormal(\mu, \sigma)`.

Mathematical details

The probability density function (pdf) is

.. math::
pdf(x; \mu, \sigma) = \frac{1}{\sigma x \sqrt{2\pi}}e^{(-\frac{(ln(x) - \mu)^2}{2\sigma^2})}

In the above equation:

* :math:`loc = \mu`: is the mean of the underlying Normal distribution.
* :math:`scale = \sigma`: is the stddev of the underlying Normal distribution.
Contributor:
LogNormal

Contributor Author (@MayYouBeProsperous, Sep 28, 2022):
Here this refers to the underlying distribution of the LogNormal.

Args:
loc(int|float|list|tuple|numpy.ndarray|Tensor): The means of the underlying Normal distribution.
scale(int|float|list|tuple|numpy.ndarray|Tensor): The stddevs of the underlying Normal distribution.
name(str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Contributor:
Same as above.

Contributor Author:
Done.

Contributor:
Does the name parameter need to be kept here? I thought it was said earlier that it is not needed? @cxxly

Contributor:
Not needed. The name parameter exists to trace execution in static-graph mode; every method would have to handle it, and it is common logic that will be handled uniformly later. It can be removed here for now.

Contributor:
OK, then please remove this parameter. @MayYouBeProsperous

Contributor Author (@MayYouBeProsperous, Oct 9, 2022):
@Ligoml @cxxly Removed; please review again.

Examples:
.. code-block:: python

import paddle
from paddle.distribution import LogNormal

# Define a single scalar LogNormal distribution.
dist = LogNormal(loc=0., scale=3.)
# Define a batch of two scalar valued LogNormals.
# The underlying Normal of the first has mean 1 and standard deviation 11; the underlying Normal of the second has mean 2 and standard deviation 22.
dist = LogNormal(loc=[1., 2.], scale=[11., 22.])
# Get 3 samples, returning a 3 x 2 tensor.
dist.sample((3, ))

# Define a batch of two scalar valued LogNormals.
# Their underlying Normal have mean 1, but different standard deviations.
dist = LogNormal(loc=1., scale=[11., 22.])

# Complete example
value_tensor = paddle.to_tensor([0.8], dtype="float32")

lognormal_a = LogNormal([0.], [1.])
lognormal_b = LogNormal([0.5], [2.])
sample = lognormal_a.sample((2, ))
# a random tensor created by lognormal distribution with shape: [2, 1]
entropy = lognormal_a.entropy()
# [1.4189385] with shape: [1]
lp = lognormal_a.log_prob(value_tensor)
# [-0.72069150] with shape: [1]
p = lognormal_a.probs(value_tensor)
# [0.48641577] with shape: [1]
kl = lognormal_a.kl_divergence(lognormal_b)
# [0.34939718] with shape: [1]
"""

def __init__(self, loc, scale, name=None):
self._base = Normal(loc=loc, scale=scale, name=name)
self.loc = self._base.loc
self.scale = self._base.scale
super(LogNormal, self).__init__(self._base, [ExpTransform()])

@property
def mean(self):
"""Mean of lognormal distribuion.

Returns:
Tensor: mean value.
"""
return paddle.exp(self._base.mean + self._base.variance / 2)

@property
def variance(self):
"""Variance of lognormal distribution.

Returns:
Tensor: variance value.
"""
return (paddle.expm1(self._base.variance) *
paddle.exp(2 * self._base.mean + self._base.variance))
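A side note of mine (not from the diff): these two properties implement the standard log-normal moments, mean = exp(mu + sigma^2 / 2) and variance = (exp(sigma^2) - 1) * exp(2 * mu + sigma^2). A rough Monte Carlo sanity check, assuming the API as shown here, could look like:

import paddle
from paddle.distribution import LogNormal

# Sketch only: compare the closed-form moments against sample estimates.
d = LogNormal(loc=0.5, scale=0.3)
samples = d.sample((200000, ))
print(d.mean, paddle.mean(samples))     # both should be close to exp(0.5 + 0.3**2 / 2)
print(d.variance, paddle.var(samples))  # both should be close to (exp(0.09) - 1) * exp(1.09)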

def entropy(self):
r"""Shannon entropy in nats.

The entropy is

.. math::

entropy(\sigma) = 0.5 \log (2 \pi e \sigma^2) + \mu

In the above equation:

* :math:`loc = \mu`: is the mean of the underlying Normal distribution.
* :math:`scale = \sigma`: is the stddev of the underlying Normal distribution.

Returns:
Tensor: Shannon entropy of lognormal distribution.

"""
return self._base.entropy() + self._base.mean

def probs(self, value):
"""Probability density/mass function.

Args:
value (Tensor): The input tensor.

Returns:
Tensor: probability. The data type is same with :attr:`value` .

"""
return paddle.exp(self.log_prob(value))

def kl_divergence(self, other):
r"""The KL-divergence between two lognormal distributions.

The KL-divergence is computed as

.. math::

KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = 0.5 (ratio^2 + (\frac{diff}{\sigma_1})^2 - 1 - 2 \ln {ratio})

.. math::

ratio = \frac{\sigma_0}{\sigma_1}

.. math::

diff = \mu_1 - \mu_0

In the above equation:

* :math:`loc = \mu_0`: is the mean of the current underlying Normal distribution.
* :math:`scale = \sigma_0`: is the stddev of the current underlying Normal distribution.
* :math:`loc = \mu_1`: is the mean of the other underlying Normal distribution.
* :math:`scale = \sigma_1`: is the stddev of the other underlying Normal distribution.
* :math:`ratio`: is the ratio of scales.
* :math:`diff`: is the difference between means.

Args:
other (LogNormal): instance of LogNormal.

Returns:
Tensor: kl-divergence between two lognormal distributions.

"""
return self._base.kl_divergence(other._base)
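To close the loop on the figures quoted in the class docstring, a small sanity check of my own (a sketch, assuming the API as shown above):

import math
import paddle
from paddle.distribution import LogNormal

# Sketch only: reproduce the entropy and log_prob values from the docstring example.
a = LogNormal([0.], [1.])
print(a.entropy())                           # approx [1.4189385], i.e. 0.5 * log(2 * pi * e) + 0
print(0.5 * math.log(2 * math.pi * math.e))  # 1.4189385...
print(a.log_prob(paddle.to_tensor([0.8])))   # approx [-0.72069150]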
58 changes: 51 additions & 7 deletions python/paddle/distribution/normal.py
@@ -16,6 +16,7 @@
import warnings

import numpy as np
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.distribution import distribution
from paddle.fluid import core
@@ -25,6 +26,10 @@
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div,
elementwise_mul, elementwise_sub, nn, ops,
tensor)
try:
from collections.abc import Iterable
except:
from collections import Iterable


class Normal(distribution.Distribution):
@@ -128,21 +133,42 @@ def __init__(self, loc, scale, name=None):
self.scale = tensor.cast(self.scale, dtype=self.dtype)
super(Normal, self).__init__(self.loc.shape)

def sample(self, shape, seed=0):
@property
def mean(self):
"""Mean of multinomial distribuion.

Returns:
Tensor: mean value.
"""
return self.loc

@property
def variance(self):
"""Variance of lognormal distribution.

Returns:
Tensor: variance value.
"""
return self.scale.pow(2)

def sample(self, shape=(), seed=0):
"""Generate samples of the specified shape.

Args:
shape (list): 1D `int32`. Shape of the generated samples.
shape (Sequence[int], optional): Shape of the generated samples.
seed (int): Python integer number.

Returns:
Tensor, A tensor with prepended dimensions shape.The data type is float32.

"""
if not isinstance(shape, Iterable):
raise TypeError('sample shape must be Iterable object.')

if not _non_static_mode():
check_type(shape, 'shape', (list), 'sample')
check_type(seed, 'seed', (int), 'sample')

shape = list(shape)
batch_shape = list((self.loc + self.scale).shape)
name = self.name + '_sample'

@@ -162,14 +188,32 @@ def sample(self, shape, seed=0):
return output
else:
output_shape = shape + batch_shape
output = nn.gaussian_random(output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * \
(tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
output = nn.gaussian_random(
output_shape, mean=0., std=1., seed=seed, dtype=self.dtype) * (
tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
output = elementwise_add(output, self.loc, name=name)
if self.all_arg_is_float:
return nn.reshape(output, shape, name=name)
else:
return output
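A quick usage note of mine (not from the diff): with the relaxed signature above, any iterable of ints is accepted as the sample shape, while a non-iterable argument is rejected with a TypeError. For example:

import paddle
from paddle.distribution import Normal

# Sketch only, assuming the signature change above.
n = Normal(loc=0., scale=1.)
print(n.sample((3, )).shape)   # tuples are accepted now, not only lists
print(n.sample([2, 3]).shape)  # lists keep working as before
# n.sample(3) raises TypeError('sample shape must be Iterable object.')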

def rsample(self, shape=()):
"""Generate reparameterized samples of the specified shape.

Args:
shape (Sequence[int], optional): Shape of the generated samples.

Returns:
Tensor: A tensor with prepended dimensions shape. The data type is float32.

"""
if not isinstance(shape, Iterable):
raise TypeError('sample shape must be Iterable object.')

shape = self._extend_shape(tuple(shape))
eps = paddle.normal(shape=shape)
return (self.loc + eps * self.scale)

Contributor (@cxxly, Sep 28, 2022):
rsample cannot simply call sample, because sample does not support backpropagation; it needs to be implemented via reparameterization of the standard normal distribution. If Paddle's current functionality really cannot support a reparameterized Normal sample yet, it is fine to raise NotImplementedError for now.

Contributor Author (@MayYouBeProsperous, Sep 28, 2022):
The new commit adds an rsample implementation, but I am not sure whether it is correct. What method should be used to verify that rsample supports backpropagation?

Contributor:
param = paddle.rand(...)
param.stop_gradient = False
d = paddle.distribution.xxx(param = param)
y = d.rsample(...)
paddle.grad(y, param)

You can verify it in a new PR and add the corresponding test cases.

Contributor Author:
OK, thank you.
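Turning the reviewer's suggestion into a runnable sketch (my own illustration, assuming the rsample implementation above; the shapes and names are placeholders):

import paddle
from paddle.distribution import Normal

# Sketch only: check that gradients flow through the reparameterized sample path.
loc = paddle.rand([2])
loc.stop_gradient = False
scale = paddle.rand([2]) + 0.1     # keep the scale strictly positive

d = Normal(loc=loc, scale=scale)
y = d.rsample((4, ))
grads = paddle.grad(y.sum(), loc)  # non-empty gradients indicate rsample is differentiable
print(grads)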

def entropy(self):
r"""Shannon entropy in nats.

@@ -204,7 +248,7 @@ def log_prob(self, value):
value (Tensor): The input tensor.

Returns:
Tensor: log probability.The data type is same with value.
Tensor: log probability.The data type is same with :attr:`value` .

"""
name = self.name + '_log_prob'
@@ -224,7 +268,7 @@ def probs(self, value):
value (Tensor): The input tensor.

Returns:
Tensor, probability. The data type is same with value.
Tensor, probability. The data type is same with :attr:`value` .

"""
name = self.name + '_probs'
8 changes: 4 additions & 4 deletions python/paddle/distribution/transform.py
@@ -142,7 +142,7 @@ def __call__(self, input):
input, [self])
if isinstance(input, Transform):
return ChainTransform([self, input])
return self.forward(x)
return self.forward(input)

def forward(self, x):
"""Forward transformation with mapping :math:`y = f(x)`.
@@ -285,7 +285,7 @@ def _call_forward_log_det_jacobian(self, x):
if hasattr(self, '_forward_log_det_jacobian'):
return self._forward_log_det_jacobian(x)
if hasattr(self, '_inverse_log_det_jacobian'):
return -self._inverse_log_det_jacobian(self.forward(y))
return -self._inverse_log_det_jacobian(self.forward(x))
raise NotImplementedError(
'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian'
'is implemented. One of them is required.')
@@ -1133,8 +1133,8 @@ def _forward(self, x):
offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
z = F.sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1)
return F.pad(z, [0]*2*(len(x.shape)-1) + [0, 1], value=1) * \
F.pad(z_cumprod, [0]*2*(len(x.shape)-1) + [1, 0], value=1)
return F.pad(z, [0] * 2 * (len(x.shape) - 1) + [0, 1], value=1) * \
F.pad(z_cumprod, [0] * 2 * (len(x.shape) - 1) + [1, 0], value=1)

def _inverse(self, y):
y_crop = y[..., :-1]