MAINT Clean up leftover Array = Any aliases in jax/_src/**.py

This seemingly minor change uncovered a few inconsistencies in existing code.
I silenced the new type errors via ``type: ignore`` and added TODOs to
investigate later.
superbobry committed Sep 25, 2023
1 parent c478282 commit 737e6a2
Showing 11 changed files with 45 additions and 49 deletions.
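
For context, a minimal sketch (not part of the commit) of why the old aliases hid type errors: with Array = Any, a checker such as mypy accepts any argument, whereas the real jax._src.typing.Array alias makes it enforce the annotation — which is what surfaced the inconsistencies mentioned above.

from typing import Any

Array = Any  # old-style alias: Any swallows everything, so no call is flagged

def first_row(x: Array) -> Array:
    return x[0]

first_row("not an array")  # accepted by mypy under the Any alias

# After this cleanup the alias is the real type:
# from jax._src.typing import Array   # i.e. jax.Array
# first_row("not an array")           # now rejected: str is not Array
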
11 changes: 7 additions & 4 deletions jax/_src/interpreters/batching.py
@@ -33,12 +33,12 @@
 from jax._src.interpreters import partial_eval as pe
 from jax._src.tree_util import (tree_unflatten, tree_flatten,
                                 register_pytree_node)
+from jax._src.typing import Array
 from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list,
                            canonicalize_axis, moveaxis, as_hashable_function,
                            curry, memoize, weakref_lru_cache)


-Array = Any
 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip

@@ -172,7 +172,8 @@ def shape_as_bdim(
   # This assumes that there is only one binder in the data_shape.
   ragged_axes = [(i, size.lengths) for i, size in enumerate(data_shape)
                  if isinstance(size, IndexedAxisSize)]
-  return make_batch_axis(len(data_shape), stacked_axis, ragged_axes)
+  # TODO(slebedev): This is a genuine type error, Array | T is not <: Array.
+  return make_batch_axis(len(data_shape), stacked_axis, ragged_axes)  # type: ignore[arg-type]


 def _update_annotation(
@@ -242,7 +243,8 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt:
     (d, ias), = ((i, sz)  # type: ignore
                  for i, sz in enumerate(x.aval.elt_ty.shape)
                  if type(sz) is IndexedAxisSize)
-    batch_axis = make_batch_axis(x.data.ndim, 0, [(d+1, ias.lengths)])
+    # TODO(slebedev): This is a genuine type error, Array | T is not <: Array.
+    batch_axis = make_batch_axis(x.data.ndim, 0, [(d+1, ias.lengths)])  # type: ignore[list-item]
     return BatchTracer(trace, x.data, batch_axis)  # type: ignore
   elif isinstance(spec, int) or spec is None:
     spec = spec and canonicalize_axis(spec, len(np.shape(x)))
@@ -695,7 +697,8 @@ def canonicalize_segment_lengths(d: RaggedAxis) -> RaggedAxis:
         id(core.get_referent(segment_lengths)),
         (segment_lengths, pe.DBIdx(len(axis_map))))
     new_ragged_axes.append((ragged_axis, dbidx))
-    return RaggedAxis(d.stacked_axis, tuple(new_ragged_axes))
+    # TODO(slebedev): This is a genuine type error, pe.DBIdx is not <: Array.
+    return RaggedAxis(d.stacked_axis, tuple(new_ragged_axes))  # type: ignore[arg-type]
   new_dims = [canonicalize_segment_lengths(d)
               if isinstance(d, RaggedAxis) else d for d in dims]
   segment_lens = [s for s, _ in axis_map.values()]
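
The TODOs above note that Array | T is not a subtype of Array. A minimal sketch (mine, not the commit's) of that variance rule as a type checker sees it:

from typing import Union
import jax
import jax.numpy as jnp

def stacked(x: jax.Array) -> jax.Array:
  return x

def f(v: Union[jax.Array, int]) -> jax.Array:
  # mypy flags this: "Union[Array, int]" is not assignable to "Array",
  # even though Array is one member of the union.
  return stacked(v)  # type: ignore[arg-type]

print(f(jnp.zeros(3)))  # fine at runtime; only the static type is too wide
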
5 changes: 1 addition & 4 deletions jax/_src/lax/ann.py
@@ -70,7 +70,6 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
   """

 from functools import partial
-from typing import Any

 import numpy as np

@@ -88,9 +87,7 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target):
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import func
 from jax._src.lib.mlir.dialects import hlo
-
-
-Array = Any
+from jax._src.typing import Array


 def approx_max_k(operand: Array,
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/for_loop.py
@@ -39,6 +39,7 @@
 from jax._src.state import primitives as state_primitives
 from jax._src.state import utils as state_utils
 from jax._src.state import types as state_types
+from jax._src.typing import Array
 from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
                            split_list, split_dict)
 from jax._src.lax.control_flow import loops

@@ -53,7 +54,6 @@
 S = TypeVar('S')
 T = TypeVar('T')
 class Ref(Generic[T]): pass
-Array = Any

 ref_set = state_primitives.ref_set
 ref_get = state_primitives.ref_get
2 changes: 1 addition & 1 deletion jax/_src/lax/control_flow/loops.py
@@ -52,6 +52,7 @@
 from jax._src.state import discharge as state_discharge
 from jax._src.numpy.ufuncs import logaddexp
 from jax._src.traceback_util import api_boundary
+from jax._src.typing import Array
 from jax._src.util import (partition_list, safe_map, safe_zip, split_list,
                            unzip2, weakref_lru_cache, merge_lists)
 import numpy as np

@@ -64,7 +65,6 @@
 zip = safe_zip

 T = TypeVar('T')
-Array = Any
 BooleanNumeric = Any  # A bool, or a Boolean array.

 ### Helper functions
22 changes: 8 additions & 14 deletions jax/_src/lax/convolution.py
@@ -12,11 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import builtins
 from collections.abc import Sequence
 from functools import partial
 import operator
-from typing import Any, NamedTuple, Optional, Union
+from typing import NamedTuple, Optional, Union

 import numpy as np

@@ -28,14 +27,9 @@
 from jax._src.interpreters import mlir
 from jax._src.lax import lax
 from jax._src.lib.mlir.dialects import hlo
+from jax._src.typing import Array, DTypeLike
-
-
-_max = builtins.max
-
-Array = Any
-DType = Any

 Shape = core.Shape

 class ConvDimensionNumbers(NamedTuple):
   """Describes batch, spatial, and feature dimensions of a convolution.
@@ -62,7 +56,7 @@ def conv_general_dilated(
     dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
     feature_group_count: int = 1, batch_group_count: int = 1,
     precision: lax.PrecisionLike = None,
-    preferred_element_type: Optional[DType] = None) -> Array:
+    preferred_element_type: Optional[DTypeLike] = None) -> Array:
   """General n-dimensional convolution operator, with optional dilation.
   Wraps XLA's `Conv
@@ -174,7 +168,7 @@ def conv_general_dilated(

 def conv(lhs: Array, rhs: Array, window_strides: Sequence[int],
          padding: str, precision: lax.PrecisionLike = None,
-         preferred_element_type: Optional[DType] = None) -> Array:
+         preferred_element_type: Optional[DTypeLike] = None) -> Array:
   """Convenience wrapper around `conv_general_dilated`.
   Args:
@@ -204,7 +198,7 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
                               lhs_dilation: Optional[Sequence[int]],
                               rhs_dilation: Optional[Sequence[int]],
                               precision: lax.PrecisionLike = None,
-                              preferred_element_type: Optional[DType] = None) -> Array:
+                              preferred_element_type: Optional[DTypeLike] = None) -> Array:
   """Convenience wrapper around `conv_general_dilated`.
   Args:
@@ -256,7 +250,7 @@ def _conv_transpose_padding(k, s, padding):
     else:
       pad_a = int(np.ceil(pad_len / 2))
   elif padding == 'VALID':
-    pad_len = k + s - 2 + _max(k - s, 0)
+    pad_len = k + s - 2 + max(k - s, 0)
     pad_a = k - 1
   else:
     raise ValueError('Padding mode must be `SAME` or `VALID`.')
@@ -277,7 +271,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
                    dimension_numbers: ConvGeneralDilatedDimensionNumbers = None,
                    transpose_kernel: bool = False,
                    precision: lax.PrecisionLike = None,
-                   preferred_element_type: Optional[DType] = None) -> Array:
+                   preferred_element_type: Optional[DTypeLike] = None) -> Array:
   """Convenience wrapper for calculating the N-d convolution "transpose".
   This function directly calculates a fractionally strided conv rather than
@@ -343,7 +337,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
   if transpose_kernel:
     # flip spatial dims and swap input / output channel axes
     rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
-    rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
+    rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])  # type: ignore[assignment]
   return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn,
                               precision=precision,
                               preferred_element_type=preferred_element_type)
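
The DType = Any → DTypeLike swap above tightens preferred_element_type without breaking callers, since DTypeLike covers np.dtype instances, dtype names, and scalar types alike. A small usage sketch under that assumption, via the public jax.lax.conv:

import numpy as np
import jax.numpy as jnp
from jax import lax

lhs = jnp.ones((1, 1, 8, 8), jnp.bfloat16)   # NCHW input
rhs = jnp.ones((1, 1, 3, 3), jnp.bfloat16)   # OIHW kernel
out = lax.conv(lhs, rhs, window_strides=(1, 1), padding='SAME',
               preferred_element_type=np.float32)  # any DTypeLike works here
print(out.dtype)  # float32
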
3 changes: 1 addition & 2 deletions jax/_src/lax/windowed_reductions.py
@@ -36,12 +36,11 @@
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy.ufuncs import logaddexp
+from jax._src.typing import Array

 map = util.safe_map
 zip = util.safe_zip

-Array = Any
-

 def reduce_window(operand, init_value, computation: Callable,
                   window_dimensions: core.Shape, window_strides: Sequence[int],
30 changes: 17 additions & 13 deletions jax/_src/nn/functions.py
@@ -28,9 +28,9 @@
 from jax._src import dtypes
 from jax._src import util
 from jax._src.core import AxisName
+from jax._src.typing import Array, ArrayLike
 from jax._src.ops.special import logsumexp as _logsumexp

-Array = Any

 # activations

@@ -165,7 +165,7 @@ def log_sigmoid(x: Array) -> Array:
   return -softplus(-x)

 @jax.jit
-def elu(x: Array, alpha: Array = 1.0) -> Array:
+def elu(x: Array, alpha: ArrayLike = 1.0) -> Array:
   r"""Exponential linear unit activation function.
   Computes the element-wise function:
@@ -190,7 +190,7 @@ def elu(x: Array, alpha: Array = 1.0) -> Array:
   return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x))

 @jax.jit
-def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array:
+def leaky_relu(x: Array, negative_slope: ArrayLike = 1e-2) -> Array:
   r"""Leaky rectified linear unit activation function.
   Computes the element-wise function:
@@ -237,7 +237,7 @@ def hard_tanh(x: Array) -> Array:
   return jnp.where(x > 1, 1, jnp.where(x < -1, -1, x))

 @jax.jit
-def celu(x: Array, alpha: Array = 1.0) -> Array:
+def celu(x: Array, alpha: ArrayLike = 1.0) -> Array:
   r"""Continuously-differentiable exponential linear unit activation.
   Computes the element-wise function:
@@ -431,14 +431,14 @@ def softmax(x: Array,
     :func:`log_softmax`
   """
   if jax.config.jax_softmax_custom_jvp:
-    return _softmax(x, axis, where, initial)
+    return _softmax(x, axis, where, initial)  # type: ignore[return-value]
   else:
     return _softmax_deprecated(x, axis, where, initial)

 # TODO(mattjj): replace softmax with _softmax when deprecation flag is removed
 @partial(jax.custom_jvp, nondiff_argnums=(1,))
 def _softmax(
-    x,
+    x: Array,
     axis: Optional[Union[int, tuple[int, ...]]] = -1,
     where: Optional[Array] = None,
     initial: Optional[Array] = None) -> Array:
@@ -455,7 +455,11 @@ def _softmax_jvp(axis, primals, tangents):
   y = _softmax(x, axis, where, initial)
   return y, y * (x_dot - (y * x_dot).sum(axis, where=where, keepdims=True))

-def _softmax_deprecated(x, axis, where, initial):
+def _softmax_deprecated(
+    x: Array,
+    axis: Optional[Union[int, tuple[int, ...]]] = -1,
+    where: Optional[Array] = None,
+    initial: Optional[Array] = None) -> Array:
   x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True)
   unnormalized = jnp.exp(x - lax.stop_gradient(x_max))
   result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True)
@@ -469,7 +473,7 @@ def standardize(x: Array,
                 axis: Optional[Union[int, tuple[int, ...]]] = -1,
                 mean: Optional[Array] = None,
                 variance: Optional[Array] = None,
-                epsilon: Array = 1e-5,
+                epsilon: ArrayLike = 1e-5,
                 where: Optional[Array] = None) -> Array:
   r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
   if mean is None:
@@ -484,11 +488,11 @@ def standardize(x: Array,
   return (x - mean) * lax.rsqrt(variance + epsilon)

 def normalize(x: Array,
-          axis: Optional[Union[int, tuple[int, ...]]] = -1,
-          mean: Optional[Array] = None,
-          variance: Optional[Array] = None,
-          epsilon: Array = 1e-5,
-          where: Optional[Array] = None) -> Array:
+              axis: Optional[Union[int, tuple[int, ...]]] = -1,
+              mean: Optional[Array] = None,
+              variance: Optional[Array] = None,
+              epsilon: ArrayLike = 1e-5,
+              where: Optional[Array] = None) -> Array:
   r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`."""
   warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning)
   return standardize(x, axis, mean, variance, epsilon, where)
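
With alpha and friends typed as ArrayLike rather than the old Any alias, plain Python scalars still type-check at call sites. A quick usage sketch (not from the commit) through the public jax.nn API:

import jax.numpy as jnp
from jax import nn

x = jnp.array([-2.0, -0.5, 1.0])
print(nn.elu(x, alpha=0.5))       # scalar alpha matches ArrayLike
print(nn.leaky_relu(x))           # default negative_slope=1e-2 is also ArrayLike
print(nn.celu(x, alpha=jnp.array(1.0)))  # arrays work too
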
7 changes: 3 additions & 4 deletions jax/_src/nn/initializers.py
@@ -23,18 +23,17 @@

 import numpy as np

-import jax
 import jax.numpy as jnp
 from jax import lax
 from jax import random
 from jax._src import core
 from jax._src import dtypes
+from jax._src.typing import Array, ArrayLike
 from jax._src.util import set_module

 export = set_module('jax.nn.initializers')

-KeyArray = jax.Array
-Array = Any
+KeyArray = Array
 # TODO: Import or define these to match
 # https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py.
 DTypeLikeFloat = Any
@@ -240,7 +239,7 @@ def _complex_uniform(key: KeyArray,
   theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype)
   return r * jnp.exp(1j * theta)

-def _complex_truncated_normal(key: KeyArray, upper: Array,
+def _complex_truncated_normal(key: KeyArray, upper: ArrayLike,
                               shape: Union[Sequence[int], core.NamedShape],
                               dtype: DTypeLikeInexact) -> Array:
   """
4 changes: 2 additions & 2 deletions jax/_src/ops/scatter.py
@@ -16,7 +16,7 @@

 from collections.abc import Sequence
 import sys
-from typing import Any, Callable, Optional, Union
+from typing import Callable, Optional, Union
 import warnings

 import numpy as np

@@ -31,9 +31,9 @@
 from jax._src.numpy import lax_numpy as jnp
 from jax._src.numpy import reductions
 from jax._src.numpy.util import check_arraylike, promote_dtypes
+from jax._src.typing import Array


-Array = Any
 if sys.version_info >= (3, 10):
   from types import EllipsisType
   SingleIndex = Union[None, int, slice, Sequence[int], Array, EllipsisType]
5 changes: 3 additions & 2 deletions jax/_src/scipy/optimize/_lbfgs.py
@@ -12,18 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm."""
-from typing import Any, Callable, NamedTuple, Optional, Union
+from typing import Callable, NamedTuple, Optional, Union
 from functools import partial

 import jax
 import jax.numpy as jnp
 from jax import lax
 from jax._src.scipy.optimize.line_search import line_search
+from jax._src.typing import Array
+

 _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST)


-Array = Any

 class LBFGSResults(NamedTuple):
   """Results from L-BFGS optimization
3 changes: 1 addition & 2 deletions jax/_src/state/types.py
@@ -23,14 +23,13 @@
 from jax._src import effects
 from jax._src import pretty_printer as pp
 from jax._src.util import safe_map, safe_zip
+from jax._src.typing import Array

 ## JAX utilities

 map, unsafe_map = safe_map, map
 zip, unsafe_zip = safe_zip, zip

-Array = Any
-
 _ref_effect_color = pp.Color.GREEN

 class RefEffect(effects.JaxprInputEffect):
