Skip to content

Commit

Permalink
[Typing][B-29] Add type annotations for `python/paddle/distribution/v…
Browse files Browse the repository at this point in the history
…ariable.py` (#65620)
  • Loading branch information
NKNaN authored Jul 2, 2024
1 parent f6ae216 commit 6d59406
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 18 deletions.
19 changes: 13 additions & 6 deletions python/paddle/distribution/constraint.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,38 +11,45 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

import paddle

if TYPE_CHECKING:
from paddle import Tensor


class Constraint:
"""Constraint condition for random variable."""

def __call__(self, value):
def __call__(self, value: Tensor) -> Tensor:
raise NotImplementedError


class Real(Constraint):
def __call__(self, value):
def __call__(self, value: Tensor) -> Tensor:
return value == value


class Range(Constraint):
def __init__(self, lower, upper):
def __init__(self, lower: Tensor, upper: Tensor) -> None:
self._lower = lower
self._upper = upper
super().__init__()

def __call__(self, value):
def __call__(self, value: Tensor) -> Tensor:
return self._lower <= value <= self._upper


class Positive(Constraint):
def __call__(self, value):
def __call__(self, value: Tensor) -> Tensor:
return value >= 0.0


class Simplex(Constraint):
def __call__(self, value):
def __call__(self, value: Tensor) -> Tensor:
return paddle.all(value >= 0, axis=-1) and (
(value.sum(-1) - 1).abs() < 1e-6
)
Expand Down
38 changes: 26 additions & 12 deletions python/paddle/distribution/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence

import paddle
from paddle.distribution import constraint

if TYPE_CHECKING:
from paddle import Tensor
from paddle.distribution.constraint import Constraint


class Variable:
"""Random variable of probability distribution.
Expand All @@ -24,32 +32,38 @@ class Variable:
event_rank (int): The rank of event dimensions.
"""

def __init__(self, is_discrete=False, event_rank=0, constraint=None):
def __init__(
self,
is_discrete: bool = False,
event_rank: int = 0,
constraint: Constraint | None = None,
) -> None:
self._is_discrete = is_discrete
self._event_rank = event_rank
self._constraint = constraint

@property
def is_discrete(self):
def is_discrete(self) -> bool:
return self._is_discrete

@property
def event_rank(self):
def event_rank(self) -> int:
return self._event_rank

def constraint(self, value):
def constraint(self, value: Tensor) -> Tensor:
"""Check whether the 'value' meet the constraint conditions of this
random variable."""
assert self._constraint is not None
return self._constraint(value)


class Real(Variable):
def __init__(self, event_rank=0):
def __init__(self, event_rank: int = 0) -> None:
super().__init__(False, event_rank, constraint.real)


class Positive(Variable):
def __init__(self, event_rank=0):
def __init__(self, event_rank: int = 0) -> None:
super().__init__(False, event_rank, constraint.positive)


Expand All @@ -62,14 +76,14 @@ class Independent(Variable):
reinterpreted.
"""

def __init__(self, base, reinterpreted_batch_rank):
def __init__(self, base: Variable, reinterpreted_batch_rank: int) -> None:
self._base = base
self._reinterpreted_batch_rank = reinterpreted_batch_rank
super().__init__(
base.is_discrete, base.event_rank + reinterpreted_batch_rank
)

def constraint(self, value):
def constraint(self, value: Tensor) -> Tensor:
ret = self._base.constraint(value)
if ret.dim() < self._reinterpreted_batch_rank:
raise ValueError(
Expand All @@ -81,22 +95,22 @@ def constraint(self, value):


class Stack(Variable):
def __init__(self, vars, axis=0):
def __init__(self, vars: Sequence[Variable], axis: int = 0) -> None:
self._vars = vars
self._axis = axis

@property
def is_discrete(self):
def is_discrete(self) -> bool:
return any(var.is_discrete for var in self._vars)

@property
def event_rank(self):
def event_rank(self) -> int:
rank = max(var.event_rank for var in self._vars)
if self._axis + rank < 0:
rank += 1
return rank

def constraint(self, value):
def constraint(self, value: Tensor) -> Tensor:
if not (-value.dim() <= self._axis < value.dim()):
raise ValueError(
f'Input dimensions {value.dim()} should be grater than stack '
Expand Down

0 comments on commit 6d59406

Please sign in to comment.