Skip to content

Commit

Permalink
[JAX] Fix incorrect type annotations.
Browse files Browse the repository at this point in the history
An upcoming change to JAX will teach pytype more accurate types for functions in the jax.numpy module. This reveals a number of type errors in downstream users of JAX. In particular, pytype is able to infer `jax.Array` accurately as a type in many more cases.

PiperOrigin-RevId: 556042684
  • Loading branch information
hawkinsp authored and DistraxDev committed Aug 14, 2023
1 parent 0984e67 commit 7b74cee
Show file tree
Hide file tree
Showing 8 changed files with 22 additions and 12 deletions.
4 changes: 2 additions & 2 deletions distrax/_src/distributions/bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Bernoulli distribution."""

from typing import Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union

import chex
from distrax._src.distributions import distribution
Expand Down Expand Up @@ -44,7 +44,7 @@ class Bernoulli(distribution.Distribution):
def __init__(self,
logits: Optional[Numeric] = None,
probs: Optional[Numeric] = None,
dtype: jnp.dtype = int):
dtype: Union[jnp.dtype, type[Any]] = int):
"""Initializes a Bernoulli distribution.
Args:
Expand Down
4 changes: 2 additions & 2 deletions distrax/_src/distributions/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""Categorical distribution."""

from typing import Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union

import chex
from distrax._src.distributions import distribution
Expand All @@ -38,7 +38,7 @@ class Categorical(distribution.Distribution):
def __init__(self,
logits: Optional[Array] = None,
probs: Optional[Array] = None,
dtype: jnp.dtype = int):
dtype: Union[jnp.dtype, type[Any]] = int):
"""Initializes a Categorical distribution.
Args:
Expand Down
4 changes: 3 additions & 1 deletion distrax/_src/distributions/epsilon_greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Epsilon-Greedy distributions with respect to a set of preferences."""

from typing import Any, Union

import chex
from distrax._src.distributions import categorical
from distrax._src.distributions import distribution
Expand Down Expand Up @@ -46,7 +48,7 @@ class EpsilonGreedy(categorical.Categorical):
def __init__(self,
preferences: Array,
epsilon: float,
dtype: jnp.dtype = int):
dtype: Union[jnp.dtype, type[Any]] = int):
"""Initializes an EpsilonGreedy distribution.
Args:
Expand Down
6 changes: 5 additions & 1 deletion distrax/_src/distributions/greedy.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Greedy distributions with respect to a set of preferences."""

from typing import Any, Union

import chex
from distrax._src.distributions import categorical
from distrax._src.distributions import distribution
Expand All @@ -37,7 +39,9 @@ class Greedy(categorical.Categorical):
all other indices will be assigned a probability of zero.
"""

def __init__(self, preferences: Array, dtype: jnp.dtype = int):
def __init__(
self, preferences: Array, dtype: Union[jnp.dtype, type[Any]] = int
):
"""Initializes a Greedy distribution.
Args:
Expand Down
6 changes: 3 additions & 3 deletions distrax/_src/distributions/multinomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import functools
import operator

from typing import Tuple, Optional, Union
from typing import Any, Tuple, Optional, Union

import chex
from distrax._src.distributions import distribution
Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(self,
total_count: Numeric,
logits: Optional[Array] = None,
probs: Optional[Array] = None,
dtype: jnp.dtype = int):
dtype: Union[jnp.dtype, type[Any]] = int):
"""Initializes a Multinomial distribution.
Args:
Expand Down Expand Up @@ -162,7 +162,7 @@ def _sample_n(self, key: PRNGKey, n: int) -> Array:

@staticmethod
def _sample_n_scalar(
key: PRNGKey, total_count: int, n: int, logits: Array,
key: PRNGKey, total_count: Union[int, Array], n: int, logits: Array,
dtype: jnp.dtype) -> Array:
"""Sample method for a Multinomial with integer `total_count`."""

Expand Down
5 changes: 3 additions & 2 deletions distrax/_src/distributions/one_hot_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================
"""OneHotCategorical distribution."""

from typing import Optional, Tuple
from typing import Any, Optional, Tuple, Union

import chex
from distrax._src.distributions import categorical
Expand All @@ -40,7 +40,7 @@ class OneHotCategorical(categorical.Categorical):
def __init__(self,
logits: Optional[Array] = None,
probs: Optional[Array] = None,
dtype: jnp.dtype = int):
dtype: Union[jnp.dtype, type[Any]] = int):
"""Initializes a OneHotCategorical distribution.
Args:
Expand Down Expand Up @@ -80,6 +80,7 @@ def prob(self, value: EventT) -> Array:
def mode(self) -> Array:
"""Calculates the mode."""
preferences = self._probs if self._logits is None else self._logits
assert preferences is not None
greedy_index = jnp.argmax(preferences, axis=-1)
return jax.nn.one_hot(greedy_index, self.num_categories).astype(self._dtype)

Expand Down
4 changes: 3 additions & 1 deletion distrax/_src/distributions/softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""Softmax distribution."""

from typing import Any, Union

import chex
from distrax._src.distributions import categorical
from distrax._src.distributions import distribution
Expand All @@ -35,7 +37,7 @@ class Softmax(categorical.Categorical):
def __init__(self,
logits: Array,
temperature: float = 1.,
dtype: jnp.dtype = int):
dtype: Union[jnp.dtype, type[Any]] = int):
"""Initializes a Softmax distribution.
Args:
Expand Down
1 change: 1 addition & 0 deletions distrax/_src/distributions/transformed.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def dtype(self) -> jnp.dtype:
"""See `Distribution.dtype`."""
if self._dtype is None:
self._infer_shapes_and_dtype()
assert self._dtype is not None # By _infer_shapes_and_dtype()
return self._dtype

@property
Expand Down

0 comments on commit 7b74cee

Please sign in to comment.