diff --git a/distrax/_src/distributions/bernoulli.py b/distrax/_src/distributions/bernoulli.py index 662e764..6806346 100644 --- a/distrax/_src/distributions/bernoulli.py +++ b/distrax/_src/distributions/bernoulli.py @@ -14,7 +14,7 @@ # ============================================================================== """Bernoulli distribution.""" -from typing import Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union import chex from distrax._src.distributions import distribution @@ -44,7 +44,7 @@ class Bernoulli(distribution.Distribution): def __init__(self, logits: Optional[Numeric] = None, probs: Optional[Numeric] = None, - dtype: jnp.dtype = int): + dtype: jnp.dtype | type[Any] = int): """Initializes a Bernoulli distribution. Args: diff --git a/distrax/_src/distributions/categorical.py b/distrax/_src/distributions/categorical.py index 2ed8be4..55416b2 100644 --- a/distrax/_src/distributions/categorical.py +++ b/distrax/_src/distributions/categorical.py @@ -14,7 +14,7 @@ # ============================================================================== """Categorical distribution.""" -from typing import Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union import chex from distrax._src.distributions import distribution @@ -38,7 +38,7 @@ class Categorical(distribution.Distribution): def __init__(self, logits: Optional[Array] = None, probs: Optional[Array] = None, - dtype: jnp.dtype = int): + dtype: jnp.dtype | type[Any] = int): """Initializes a Categorical distribution. Args: diff --git a/distrax/_src/distributions/epsilon_greedy.py b/distrax/_src/distributions/epsilon_greedy.py index 14b077b..d738bac 100644 --- a/distrax/_src/distributions/epsilon_greedy.py +++ b/distrax/_src/distributions/epsilon_greedy.py @@ -14,6 +14,8 @@ # ============================================================================== """Epsilon-Greedy distributions with respect to a set of preferences.""" +from typing import Any + import chex from distrax._src.distributions import categorical from distrax._src.distributions import distribution @@ -46,7 +48,7 @@ class EpsilonGreedy(categorical.Categorical): def __init__(self, preferences: Array, epsilon: float, - dtype: jnp.dtype = int): + dtype: jnp.dtype | type[Any] = int): """Initializes an EpsilonGreedy distribution. Args: diff --git a/distrax/_src/distributions/greedy.py b/distrax/_src/distributions/greedy.py index e5a8873..0bff51d 100644 --- a/distrax/_src/distributions/greedy.py +++ b/distrax/_src/distributions/greedy.py @@ -14,6 +14,8 @@ # ============================================================================== """Greedy distributions with respect to a set of preferences.""" +from typing import Any + import chex from distrax._src.distributions import categorical from distrax._src.distributions import distribution @@ -37,7 +39,7 @@ class Greedy(categorical.Categorical): all other indices will be assigned a probability of zero. """ - def __init__(self, preferences: Array, dtype: jnp.dtype = int): + def __init__(self, preferences: Array, dtype: jnp.dtype | type[Any] = int): """Initializes a Greedy distribution. Args: diff --git a/distrax/_src/distributions/multinomial.py b/distrax/_src/distributions/multinomial.py index 3df2908..08ea0c6 100644 --- a/distrax/_src/distributions/multinomial.py +++ b/distrax/_src/distributions/multinomial.py @@ -17,7 +17,7 @@ import functools import operator -from typing import Tuple, Optional, Union +from typing import Any, Tuple, Optional, Union import chex from distrax._src.distributions import distribution @@ -46,7 +46,7 @@ def __init__(self, total_count: Numeric, logits: Optional[Array] = None, probs: Optional[Array] = None, - dtype: jnp.dtype = int): + dtype: jnp.dtype | type[Any] = int): """Initializes a Multinomial distribution. Args: @@ -162,7 +162,7 @@ def _sample_n(self, key: PRNGKey, n: int) -> Array: @staticmethod def _sample_n_scalar( - key: PRNGKey, total_count: int, n: int, logits: Array, + key: PRNGKey, total_count: int | Array, n: int, logits: Array, dtype: jnp.dtype) -> Array: """Sample method for a Multinomial with integer `total_count`.""" diff --git a/distrax/_src/distributions/one_hot_categorical.py b/distrax/_src/distributions/one_hot_categorical.py index 75004e5..f30bcf3 100644 --- a/distrax/_src/distributions/one_hot_categorical.py +++ b/distrax/_src/distributions/one_hot_categorical.py @@ -14,7 +14,7 @@ # ============================================================================== """OneHotCategorical distribution.""" -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import chex from distrax._src.distributions import categorical @@ -40,7 +40,7 @@ class OneHotCategorical(categorical.Categorical): def __init__(self, logits: Optional[Array] = None, probs: Optional[Array] = None, - dtype: jnp.dtype = int): + dtype: jnp.dtype | type[Any] = int): """Initializes a OneHotCategorical distribution. Args: @@ -80,6 +80,7 @@ def prob(self, value: EventT) -> Array: def mode(self) -> Array: """Calculates the mode.""" preferences = self._probs if self._logits is None else self._logits + assert preferences is not None greedy_index = jnp.argmax(preferences, axis=-1) return jax.nn.one_hot(greedy_index, self.num_categories).astype(self._dtype) diff --git a/distrax/_src/distributions/softmax.py b/distrax/_src/distributions/softmax.py index 86ded3f..5fc3b6b 100644 --- a/distrax/_src/distributions/softmax.py +++ b/distrax/_src/distributions/softmax.py @@ -14,6 +14,8 @@ # ============================================================================== """Softmax distribution.""" +from typing import Any + import chex from distrax._src.distributions import categorical from distrax._src.distributions import distribution @@ -35,7 +37,7 @@ class Softmax(categorical.Categorical): def __init__(self, logits: Array, temperature: float = 1., - dtype: jnp.dtype = int): + dtype: jnp.dtype | type[Any] = int): """Initializes a Softmax distribution. Args: diff --git a/distrax/_src/distributions/transformed.py b/distrax/_src/distributions/transformed.py index f066a28..538b59a 100644 --- a/distrax/_src/distributions/transformed.py +++ b/distrax/_src/distributions/transformed.py @@ -139,6 +139,7 @@ def dtype(self) -> jnp.dtype: """See `Distribution.dtype`.""" if self._dtype is None: self._infer_shapes_and_dtype() + assert self._dtype is not None # By _infer_shapes_and_dtype() return self._dtype @property