Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][B-20] Add type annotations for python/paddle/distribution/lognormal.py #65843

Merged
Merged 4 commits on Jul 12, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions python/paddle/distribution/lognormal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, Union

import paddle
from paddle.distribution.normal import Normal
from paddle.distribution.transform import ExpTransform
from paddle.distribution.transformed_distribution import TransformedDistribution

if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
from typing_extensions import TypeAlias

from paddle import Tensor
from paddle._typing import NestedSequence

# Accepted types for LogNormal's `loc`/`scale` parameters: a Python float,
# a (possibly nested) sequence of floats, a float32/float64 NumPy array,
# or a paddle Tensor. Evaluated only under TYPE_CHECKING, so it adds no
# runtime import cost.
_LognormalBoundary: TypeAlias = Union[
float,
Sequence[float],
NestedSequence[float],
npt.NDArray[Union[np.float32, np.float64]],
Tensor,
]


class LogNormal(TransformedDistribution):
r"""The LogNormal distribution with location `loc` and `scale` parameters.
Expand Down Expand Up @@ -88,15 +107,19 @@ class LogNormal(TransformedDistribution):
Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=True,
[0.34939718])
"""
loc: Tensor
scale: Tensor

def __init__(
    self, loc: _LognormalBoundary, scale: _LognormalBoundary
) -> None:
    """Create a LogNormal distribution from a Normal base.

    Args:
        loc: Location parameter of the underlying normal distribution.
        scale: Scale parameter of the underlying normal distribution.
    """
    self._base = Normal(loc=loc, scale=scale)
    base = self._base
    # Expose the base distribution's (tensor-converted) parameters
    # as public attributes of this distribution.
    self.loc, self.scale = base.loc, base.scale
    super().__init__(base, [ExpTransform()])

@property
def mean(self) -> Tensor:
    """Mean of the lognormal distribution.

    Returns:
        Tensor: the mean, computed as exp(mu + sigma^2 / 2) where mu and
        sigma^2 are the mean and variance of the underlying normal.
    """
    mu = self._base.mean
    half_variance = self._base.variance / 2
    return paddle.exp(mu + half_variance)

@property
def variance(self):
def variance(self) -> Tensor:
"""Variance of lognormal distribution.

Returns:
Expand All @@ -115,7 +138,7 @@ def variance(self):
2 * self._base.mean + self._base.variance
)

def entropy(self) -> Tensor:
    r"""Shannon entropy in nats.

    Returns:
        Tensor: the entropy, computed from the underlying normal's
        entropy plus its mean.
    """
    base = self._base
    return base.entropy() + base.mean

def probs(self, value: Tensor) -> Tensor:
    """Probability density function.

    Args:
        value (Tensor): Points at which to evaluate the density.

    Returns:
        Tensor: the density at ``value``, obtained by exponentiating
        the log-probability.
    """
    log_density = self.log_prob(value)
    return paddle.exp(log_density)

def kl_divergence(self, other):
def kl_divergence(self, other: LogNormal) -> Tensor:
r"""The KL-divergence between two lognormal distributions.

The probability density function (pdf) is
Expand Down