From 5c8586ecfdd8a590b831bb5f07814953da4230d8 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 25 Sep 2023 14:44:27 +0100 Subject: [PATCH] MAINT Clean up leftover `Array = Any` aliases in jax/_src/**.py I had to revert to using `Any` for `RaggedAxis.ragged_axes` because pytype found more latent type errors, which require the understanding of ragedness and dynamic shapes internals to fix properly. --- jax/BUILD | 1 + jax/_src/interpreters/batching.py | 10 +- jax/_src/lax/ann.py | 5 +- jax/_src/lax/control_flow/for_loop.py | 2 +- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/lax/convolution.py | 22 ++-- jax/_src/lax/windowed_reductions.py | 5 +- jax/_src/nn/functions.py | 165 +++++++++++++++----------- jax/_src/nn/initializers.py | 11 +- jax/_src/ops/scatter.py | 4 +- jax/_src/scipy/optimize/_lbfgs.py | 5 +- jax/_src/state/types.py | 3 +- 12 files changed, 128 insertions(+), 107 deletions(-) diff --git a/jax/BUILD b/jax/BUILD index cad6d7d9acb4..06be18b9e98a 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -666,6 +666,7 @@ pytype_strict_library( ":core", ":effects", ":pretty_printer", + ":typing", ":util", ], ) diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index 84e616a62d63..843f29fe395f 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -33,12 +33,12 @@ from jax._src.interpreters import partial_eval as pe from jax._src.tree_util import (tree_unflatten, tree_flatten, register_pytree_node) +from jax._src.typing import Array from jax._src.util import (unzip2, unzip3, safe_map, safe_zip, split_list, canonicalize_axis, moveaxis, as_hashable_function, curry, memoize, weakref_lru_cache) -Array = Any map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -116,7 +116,7 @@ class RaggedAxis: # For each axis, we store its index and the corresponding segment lengths. # For example, the jumble i:(Fin 3) => f32[lens1.i, 7, lens2.i] # would be represented with ragged_axes = [(1, lens1), (3, lens2)] - ragged_axes: tuple[tuple[int, Array], ...] + ragged_axes: tuple[tuple[int, Any], ...] @property def size(self): @@ -148,8 +148,10 @@ def _sorted_ragged_axis(stacked_axis, ragged_axes): return RaggedAxis(stacked_axis, tuple(sorted(ragged_axes, key=lambda p: p[0]))) def make_batch_axis( - ndim: int, stacked_axis: int, ragged_axes: list[tuple[int, Array]] - ) -> int | RaggedAxis: + ndim: int, + stacked_axis: int, + ragged_axes: list[tuple[int, Array | core.Var]], +) -> int | RaggedAxis: if ragged_axes: canonical = [(canonicalize_axis(ax, ndim), sz) for ax, sz in ragged_axes] return _sorted_ragged_axis(canonicalize_axis(stacked_axis, ndim), canonical) diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index ac1b5aec103c..22abf2c6b0a4 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -70,7 +70,6 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target): """ from functools import partial -from typing import Any import numpy as np @@ -88,9 +87,7 @@ def pmap_mips(qy, db, db_offset, db_size, k, recall_target): from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func from jax._src.lib.mlir.dialects import hlo - - -Array = Any +from jax._src.typing import Array def approx_max_k(operand: Array, diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 4a93b8b68ba2..d4635a8660b7 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -39,6 +39,7 @@ from jax._src.state import primitives as state_primitives from jax._src.state import utils as state_utils from jax._src.state import types as state_types +from jax._src.typing import Array from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip, split_list, split_dict) from jax._src.lax.control_flow import loops @@ -53,7 +54,6 @@ S = TypeVar('S') T = TypeVar('T') class Ref(Generic[T]): pass -Array = Any ref_set = state_primitives.ref_set ref_get = state_primitives.ref_get diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index bf118b83c817..823d7c7e2933 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -52,6 +52,7 @@ from jax._src.state import discharge as state_discharge from jax._src.numpy.ufuncs import logaddexp from jax._src.traceback_util import api_boundary +from jax._src.typing import Array from jax._src.util import (partition_list, safe_map, safe_zip, split_list, unzip2, weakref_lru_cache, merge_lists) import numpy as np @@ -64,7 +65,6 @@ zip = safe_zip T = TypeVar('T') -Array = Any BooleanNumeric = Any # A bool, or a Boolean array. ### Helper functions diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index b533508dc3dd..0aed39ce9533 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import builtins from collections.abc import Sequence from functools import partial import operator -from typing import Any, NamedTuple, Optional, Union +from typing import NamedTuple, Optional, Union import numpy as np @@ -28,14 +27,9 @@ from jax._src.interpreters import mlir from jax._src.lax import lax from jax._src.lib.mlir.dialects import hlo +from jax._src.typing import Array, DTypeLike -_max = builtins.max - -Array = Any -DType = Any -Shape = core.Shape - class ConvDimensionNumbers(NamedTuple): """Describes batch, spatial, and feature dimensions of a convolution. @@ -62,7 +56,7 @@ def conv_general_dilated( dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, feature_group_count: int = 1, batch_group_count: int = 1, precision: lax.PrecisionLike = None, - preferred_element_type: Optional[DType] = None) -> Array: + preferred_element_type: Optional[DTypeLike] = None) -> Array: """General n-dimensional convolution operator, with optional dilation. Wraps XLA's `Conv @@ -174,7 +168,7 @@ def conv_general_dilated( def conv(lhs: Array, rhs: Array, window_strides: Sequence[int], padding: str, precision: lax.PrecisionLike = None, - preferred_element_type: Optional[DType] = None) -> Array: + preferred_element_type: Optional[DTypeLike] = None) -> Array: """Convenience wrapper around `conv_general_dilated`. Args: @@ -204,7 +198,7 @@ def conv_with_general_padding(lhs: Array, rhs: Array, lhs_dilation: Optional[Sequence[int]], rhs_dilation: Optional[Sequence[int]], precision: lax.PrecisionLike = None, - preferred_element_type: Optional[DType] = None) -> Array: + preferred_element_type: Optional[DTypeLike] = None) -> Array: """Convenience wrapper around `conv_general_dilated`. Args: @@ -256,7 +250,7 @@ def _conv_transpose_padding(k, s, padding): else: pad_a = int(np.ceil(pad_len / 2)) elif padding == 'VALID': - pad_len = k + s - 2 + _max(k - s, 0) + pad_len = k + s - 2 + max(k - s, 0) pad_a = k - 1 else: raise ValueError('Padding mode must be `SAME` or `VALID`.') @@ -277,7 +271,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], dimension_numbers: ConvGeneralDilatedDimensionNumbers = None, transpose_kernel: bool = False, precision: lax.PrecisionLike = None, - preferred_element_type: Optional[DType] = None) -> Array: + preferred_element_type: Optional[DTypeLike] = None) -> Array: """Convenience wrapper for calculating the N-d convolution "transpose". This function directly calculates a fractionally strided conv rather than @@ -343,7 +337,7 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int], if transpose_kernel: # flip spatial dims and swap input / output channel axes rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:]) - rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1]) + rhs = rhs.swapaxes(dn.rhs_spec[0], dn.rhs_spec[1]) return conv_general_dilated(lhs, rhs, one, pads, strides, rhs_dilation, dn, precision=precision, preferred_element_type=preferred_element_type) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 1ac143b23692..e7fe2a604fd7 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -14,7 +14,7 @@ from collections.abc import Sequence from functools import partial -from typing import Any, Callable, Optional, Union +from typing import Callable, Optional, Union import warnings import numpy as np @@ -36,12 +36,11 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.ufuncs import logaddexp +from jax._src.typing import Array map = util.safe_map zip = util.safe_zip -Array = Any - def reduce_window(operand, init_value, computation: Callable, window_dimensions: core.Shape, window_strides: Sequence[int], diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 69b05d4b0496..62b87412e8cb 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -28,15 +28,16 @@ from jax._src import dtypes from jax._src import util from jax._src.core import AxisName +from jax._src.numpy import util as numpy_util +from jax._src.typing import Array, ArrayLike from jax._src.ops.special import logsumexp as _logsumexp -Array = Any # activations @custom_jvp @jax.jit -def relu(x: Array) -> Array: +def relu(x: ArrayLike) -> Array: r"""Rectified linear unit activation function. Computes the element-wise function: @@ -72,7 +73,7 @@ def relu(x: Array) -> Array: relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0))) @jax.jit -def softplus(x: Array) -> Array: +def softplus(x: ArrayLike) -> Array: r"""Softplus activation function. Computes the element-wise function @@ -86,7 +87,7 @@ def softplus(x: Array) -> Array: return jnp.logaddexp(x, 0) @jax.jit -def soft_sign(x: Array) -> Array: +def soft_sign(x: ArrayLike) -> Array: r"""Soft-sign activation function. Computes the element-wise function @@ -97,10 +98,12 @@ def soft_sign(x: Array) -> Array: Args: x : input array """ - return x / (jnp.abs(x) + 1) + numpy_util.check_arraylike("soft_sign", x) + x_arr = jnp.asarray(x) + return x_arr / (jnp.abs(x_arr) + 1) @jax.jit -def sigmoid(x: Array) -> Array: +def sigmoid(x: ArrayLike) -> Array: r"""Sigmoid activation function. Computes the element-wise function: @@ -121,7 +124,7 @@ def sigmoid(x: Array) -> Array: return lax.logistic(x) @jax.jit -def silu(x: Array) -> Array: +def silu(x: ArrayLike) -> Array: r"""SiLU (a.k.a. swish) activation function. Computes the element-wise function: @@ -140,12 +143,14 @@ def silu(x: Array) -> Array: See also: :func:`sigmoid` """ - return x * sigmoid(x) + numpy_util.check_arraylike("silu", x) + x_arr = jnp.asarray(x) + return x_arr * sigmoid(x_arr) swish = silu @jax.jit -def log_sigmoid(x: Array) -> Array: +def log_sigmoid(x: ArrayLike) -> Array: r"""Log-sigmoid activation function. Computes the element-wise function: @@ -162,10 +167,12 @@ def log_sigmoid(x: Array) -> Array: See also: :func:`sigmoid` """ - return -softplus(-x) + numpy_util.check_arraylike("log_sigmoid", x) + x_arr = jnp.asarray(x) + return -softplus(-x_arr) @jax.jit -def elu(x: Array, alpha: Array = 1.0) -> Array: +def elu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array: r"""Exponential linear unit activation function. Computes the element-wise function: @@ -186,11 +193,14 @@ def elu(x: Array, alpha: Array = 1.0) -> Array: See also: :func:`selu` """ - safe_x = jnp.where(x > 0, 0., x) - return jnp.where(x > 0, x, alpha * jnp.expm1(safe_x)) + numpy_util.check_arraylike("elu", x) + x_arr = jnp.asarray(x) + return jnp.where(x_arr > 0, + x_arr, + alpha * jnp.expm1(jnp.where(x_arr > 0, 0., x_arr))) @jax.jit -def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array: +def leaky_relu(x: ArrayLike, negative_slope: ArrayLike = 1e-2) -> Array: r"""Leaky rectified linear unit activation function. Computes the element-wise function: @@ -213,10 +223,12 @@ def leaky_relu(x: Array, negative_slope: Array = 1e-2) -> Array: See also: :func:`relu` """ - return jnp.where(x >= 0, x, negative_slope * x) + numpy_util.check_arraylike("leaky_relu", x) + x_arr = jnp.asarray(x) + return jnp.where(x_arr >= 0, x_arr, negative_slope * x_arr) @jax.jit -def hard_tanh(x: Array) -> Array: +def hard_tanh(x: ArrayLike) -> Array: r"""Hard :math:`\mathrm{tanh}` activation function. Computes the element-wise function: @@ -234,10 +246,12 @@ def hard_tanh(x: Array) -> Array: Returns: An array. """ - return jnp.where(x > 1, 1, jnp.where(x < -1, -1, x)) + numpy_util.check_arraylike("hard_tanh", x) + x_arr = jnp.asarray(x) + return jnp.where(x_arr > 1, 1, jnp.where(x_arr < -1, -1, x_arr)) @jax.jit -def celu(x: Array, alpha: Array = 1.0) -> Array: +def celu(x: ArrayLike, alpha: ArrayLike = 1.0) -> Array: r"""Continuously-differentiable exponential linear unit activation. Computes the element-wise function: @@ -262,7 +276,7 @@ def celu(x: Array, alpha: Array = 1.0) -> Array: return jnp.maximum(x, 0.0) + alpha * jnp.expm1(jnp.minimum(x, 0.0) / alpha) @jax.jit -def selu(x: Array) -> Array: +def selu(x: ArrayLike) -> Array: r"""Scaled exponential linear unit activation. Computes the element-wise function: @@ -295,7 +309,7 @@ def selu(x: Array) -> Array: # TODO(phawkins): this jit was found to change numerics in a test. Debug this. # @partial(jax.jit, static_argnames=("approximate",)) -def gelu(x: Array, approximate: bool = True) -> Array: +def gelu(x: ArrayLike, approximate: bool = True) -> Array: r"""Gaussian error linear unit activation function. If ``approximate=False``, computes the element-wise function: @@ -317,20 +331,18 @@ def gelu(x: Array, approximate: bool = True) -> Array: x : input array approximate: whether to use the approximate or exact formulation. """ - - # Promote to nearest float-like dtype. - x = x.astype(dtypes.to_inexact_dtype(x.dtype)) + [x_arr] = numpy_util.promote_args_inexact("gelu", x) if approximate: - sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype) - cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x + 0.044715 * (x ** 3)))) - return x * cdf + sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x_arr.dtype) + cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x_arr + 0.044715 * (x_arr ** 3)))) + return x_arr * cdf else: - sqrt_2 = np.sqrt(2).astype(x.dtype) - return jnp.array(x * (lax.erf(x / sqrt_2) + 1) / 2, dtype=x.dtype) + sqrt_2 = np.sqrt(2).astype(x_arr.dtype) + return jnp.array(x_arr * (lax.erf(x_arr / sqrt_2) + 1) / 2, dtype=x_arr.dtype) @partial(jax.jit, static_argnames=("axis",)) -def glu(x: Array, axis: int = -1) -> Array: +def glu(x: ArrayLike, axis: int = -1) -> Array: r"""Gated linear unit activation function. Computes the function: @@ -353,9 +365,11 @@ def glu(x: Array, axis: int = -1) -> Array: See also: :func:`sigmoid` """ - size = x.shape[axis] + numpy_util.check_arraylike("glu", x) + x_arr = jnp.asarray(x) + size = x_arr.shape[axis] assert size % 2 == 0, "axis size must be divisible by 2" - x1, x2 = jnp.split(x, 2, axis) + x1, x2 = jnp.split(x_arr, 2, axis) return x1 * sigmoid(x2) # other functions @@ -364,10 +378,10 @@ def glu(x: Array, axis: int = -1) -> Array: @partial(jax.jit, static_argnames=("axis",)) -def log_softmax(x: Array, +def log_softmax(x: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = -1, - where: Optional[Array] = None, - initial: Optional[Array] = None) -> Array: + where: Optional[ArrayLike] = None, + initial: Optional[ArrayLike] = None) -> Array: r"""Log-Softmax function. Computes the logarithm of the :code:`softmax` function, which rescales @@ -391,8 +405,10 @@ def log_softmax(x: Array, See also: :func:`softmax` """ - x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True) - shifted = x - lax.stop_gradient(x_max) + numpy_util.check_arraylike("log_softmax", x) + x_arr = jnp.asarray(x) + x_max = jnp.max(x_arr, axis, where=where, initial=initial, keepdims=True) + shifted = x_arr - lax.stop_gradient(x_max) shifted_logsumexp = jnp.log( jnp.sum(jnp.exp(shifted), axis, where=where, keepdims=True)) result = shifted - shifted_logsumexp @@ -403,10 +419,10 @@ def log_softmax(x: Array, # TODO(phawkins): this jit was found to change numerics in a test. Debug this. #@partial(jax.jit, static_argnames=("axis",)) -def softmax(x: Array, +def softmax(x: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = -1, - where: Optional[Array] = None, - initial: Optional[Array] = None) -> Array: + where: Optional[ArrayLike] = None, + initial: Optional[ArrayLike] = None) -> Array: r"""Softmax function. Computes the function which rescales elements to the range :math:`[0, 1]` @@ -431,17 +447,20 @@ def softmax(x: Array, :func:`log_softmax` """ if jax.config.jax_softmax_custom_jvp: - return _softmax(x, axis, where, initial) + # mypy is confused by the `functools.partial` application in the definition + # of `_softmax` and incorrectly concludes that `_softmax` returns + # `ReturnValue` -- the unsubstituted type parameter of `custom_jvp`. + return _softmax(x, axis, where, initial) # type: ignore[return-value] else: return _softmax_deprecated(x, axis, where, initial) # TODO(mattjj): replace softmax with _softmax when deprecation flag is removed @partial(jax.custom_jvp, nondiff_argnums=(1,)) def _softmax( - x, + x: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = -1, - where: Optional[Array] = None, - initial: Optional[Array] = None) -> Array: + where: Optional[ArrayLike] = None, + initial: Optional[ArrayLike] = None) -> Array: x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True) unnormalized = jnp.exp(x - x_max) result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True) @@ -455,7 +474,11 @@ def _softmax_jvp(axis, primals, tangents): y = _softmax(x, axis, where, initial) return y, y * (x_dot - (y * x_dot).sum(axis, where=where, keepdims=True)) -def _softmax_deprecated(x, axis, where, initial): +def _softmax_deprecated( + x: ArrayLike, + axis: Optional[Union[int, tuple[int, ...]]] = -1, + where: Optional[ArrayLike] = None, + initial: Optional[ArrayLike] = None) -> Array: x_max = jnp.max(x, axis, where=where, initial=initial, keepdims=True) unnormalized = jnp.exp(x - lax.stop_gradient(x_max)) result = unnormalized / jnp.sum(unnormalized, axis, where=where, keepdims=True) @@ -465,13 +488,15 @@ def _softmax_deprecated(x, axis, where, initial): @partial(jax.jit, static_argnames=("axis",)) -def standardize(x: Array, +def standardize(x: ArrayLike, axis: Optional[Union[int, tuple[int, ...]]] = -1, - mean: Optional[Array] = None, - variance: Optional[Array] = None, - epsilon: Array = 1e-5, - where: Optional[Array] = None) -> Array: + mean: Optional[ArrayLike] = None, + variance: Optional[ArrayLike] = None, + epsilon: ArrayLike = 1e-5, + where: Optional[ArrayLike] = None) -> Array: r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.""" + numpy_util.check_arraylike("standardize", x) + numpy_util.check_arraylike_or_none("standardize", mean, variance, where) if mean is None: mean = jnp.mean(x, axis, keepdims=True, where=where) if variance is None: @@ -481,43 +506,45 @@ def standardize(x: Array, # when used in neural network normalization layers variance = jnp.mean( jnp.square(x), axis, keepdims=True, where=where) - jnp.square(mean) - return (x - mean) * lax.rsqrt(variance + epsilon) + return jnp.subtract(x, jnp.asarray(mean)) * lax.rsqrt(jnp.asarray(variance) + epsilon) -def normalize(x: Array, - axis: Optional[Union[int, tuple[int, ...]]] = -1, - mean: Optional[Array] = None, - variance: Optional[Array] = None, - epsilon: Array = 1e-5, - where: Optional[Array] = None) -> Array: +def normalize(x: ArrayLike, + axis: Optional[Union[int, tuple[int, ...]]] = -1, + mean: Optional[ArrayLike] = None, + variance: Optional[ArrayLike] = None, + epsilon: ArrayLike = 1e-5, + where: Optional[ArrayLike] = None) -> Array: r"""Normalizes an array by subtracting ``mean`` and dividing by :math:`\sqrt{\mathrm{variance}}`.""" warnings.warn("jax.nn.normalize will be deprecated. Use jax.nn.standardize instead.", DeprecationWarning) return standardize(x, axis, mean, variance, epsilon, where) +# TODO(slebedev): Change the type of `x` to `ArrayLike`. @partial(jax.jit, static_argnames=("num_classes", "dtype", "axis")) -def _one_hot(x: Array, num_classes: int, *, +def _one_hot(x: Any, num_classes: int, *, dtype: Any, axis: Union[int, AxisName]) -> Array: num_classes = core.concrete_dim_or_error( num_classes, "The error arose in jax.nn.one_hot argument `num_classes`.") dtype = dtypes.canonicalize_dtype(dtype) - x = jnp.asarray(x) + x_arr = jnp.asarray(x) try: - output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1) + output_pos_axis = util.canonicalize_axis(axis, x_arr.ndim + 1) except TypeError: axis_size = lax.psum(1, axis) if num_classes != axis_size: raise ValueError(f"Expected num_classes to match the size of axis {axis}, " f"but {num_classes} != {axis_size}") from None axis_idx = lax.axis_index(axis) - return jnp.asarray(x == axis_idx, dtype=dtype) + return jnp.asarray(x_arr == axis_idx, dtype=dtype) axis = operator.index(axis) # type: ignore[arg-type] - lhs = lax.expand_dims(x, (axis,)) - rhs_shape = [1] * x.ndim + lhs = lax.expand_dims(x_arr, (axis,)) + rhs_shape = [1] * x_arr.ndim rhs_shape.insert(output_pos_axis, num_classes) - rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis) + rhs = lax.broadcasted_iota(x_arr.dtype, rhs_shape, output_pos_axis) return jnp.asarray(lhs == rhs, dtype=dtype) -def one_hot(x: Array, num_classes: int, *, +# TODO(slebedev): Change the type of `x` to `ArrayLike`. +def one_hot(x: Any, num_classes: int, *, dtype: Any = jnp.float_, axis: Union[int, AxisName] = -1) -> Array: """One-hot encodes the given indices. @@ -550,7 +577,7 @@ def one_hot(x: Array, num_classes: int, *, @jax.custom_jvp @jax.jit -def relu6(x: Array) -> Array: +def relu6(x: ArrayLike) -> Array: r"""Rectified Linear Unit 6 activation function. Computes the element-wise function @@ -582,7 +609,7 @@ def relu6(x: Array) -> Array: lax.select((x > 0) & (x < 6), g, lax.full_like(g, 0))) @jax.jit -def hard_sigmoid(x: Array) -> Array: +def hard_sigmoid(x: ArrayLike) -> Array: r"""Hard Sigmoid activation function. Computes the element-wise function @@ -602,7 +629,7 @@ def hard_sigmoid(x: Array) -> Array: return relu6(x + 3.) / 6. @jax.jit -def hard_silu(x: Array) -> Array: +def hard_silu(x: ArrayLike) -> Array: r"""Hard SiLU (swish) activation function Computes the element-wise function @@ -622,6 +649,8 @@ def hard_silu(x: Array) -> Array: See also: :func:`hard_sigmoid` """ - return x * hard_sigmoid(x) + numpy_util.check_arraylike("hard_silu", x) + x_arr = jnp.asarray(x) + return x_arr * hard_sigmoid(x_arr) hard_swish = hard_silu diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index add524b73b0b..297c330773a8 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -23,18 +23,17 @@ import numpy as np -import jax import jax.numpy as jnp from jax import lax from jax import random from jax._src import core from jax._src import dtypes +from jax._src.typing import Array, ArrayLike from jax._src.util import set_module export = set_module('jax.nn.initializers') -KeyArray = jax.Array -Array = Any +KeyArray = Array # TODO: Import or define these to match # https://github.com/numpy/numpy/blob/main/numpy/typing/_dtype_like.py. DTypeLikeFloat = Any @@ -48,7 +47,7 @@ class Initializer(Protocol): def __call__(key: KeyArray, shape: core.Shape, dtype: DTypeLikeInexact = jnp.float_) -> Array: - ... + raise NotImplementedError @export def zeros(key: KeyArray, @@ -82,7 +81,7 @@ def ones(key: KeyArray, return jnp.ones(shape, dtypes.canonicalize_dtype(dtype)) @export -def constant(value: Array, +def constant(value: ArrayLike, dtype: DTypeLikeInexact = jnp.float_ ) -> Initializer: """Builds an initializer that returns arrays full of a constant ``value``. @@ -240,7 +239,7 @@ def _complex_uniform(key: KeyArray, theta = 2 * jnp.pi * random.uniform(key_theta, shape, real_dtype).astype(dtype) return r * jnp.exp(1j * theta) -def _complex_truncated_normal(key: KeyArray, upper: Array, +def _complex_truncated_normal(key: KeyArray, upper: ArrayLike, shape: Union[Sequence[int], core.NamedShape], dtype: DTypeLikeInexact) -> Array: """ diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 0fa832aac155..47cba94c59d6 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -16,7 +16,7 @@ from collections.abc import Sequence import sys -from typing import Any, Callable, Optional, Union +from typing import Callable, Optional, Union import warnings import numpy as np @@ -31,9 +31,9 @@ from jax._src.numpy import lax_numpy as jnp from jax._src.numpy import reductions from jax._src.numpy.util import check_arraylike, promote_dtypes +from jax._src.typing import Array -Array = Any if sys.version_info >= (3, 10): from types import EllipsisType SingleIndex = Union[None, int, slice, Sequence[int], Array, EllipsisType] diff --git a/jax/_src/scipy/optimize/_lbfgs.py b/jax/_src/scipy/optimize/_lbfgs.py index 44862eee146b..b50c54ec1700 100644 --- a/jax/_src/scipy/optimize/_lbfgs.py +++ b/jax/_src/scipy/optimize/_lbfgs.py @@ -12,18 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. """The Limited-Memory Broyden-Fletcher-Goldfarb-Shanno minimization algorithm.""" -from typing import Any, Callable, NamedTuple, Optional, Union +from typing import Callable, NamedTuple, Optional, Union from functools import partial import jax import jax.numpy as jnp from jax import lax from jax._src.scipy.optimize.line_search import line_search +from jax._src.typing import Array + _dot = partial(jnp.dot, precision=lax.Precision.HIGHEST) -Array = Any class LBFGSResults(NamedTuple): """Results from L-BFGS optimization diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index be6c39177e8a..e7845f41a615 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -23,14 +23,13 @@ from jax._src import effects from jax._src import pretty_printer as pp from jax._src.util import safe_map, safe_zip +from jax._src.typing import Array ## JAX utilities map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip -Array = Any - _ref_effect_color = pp.Color.GREEN class RefEffect(effects.JaxprInputEffect):