From a299c5eab8aaaacbd5d1d2fc505e8839d3bfa473 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 26 Oct 2022 13:12:56 -0700 Subject: [PATCH] [typing] annotate jax.numpy reduction operations --- jax/_src/numpy/reductions.py | 247 +++++++++++------- jax/_src/ops/scatter.py | 2 +- jax/_src/scipy/special.py | 2 +- .../jax2tf/tests/shape_poly_test.py | 2 +- 4 files changed, 151 insertions(+), 102 deletions(-) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 99a22b62f6f2..49b70a4c3013 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -15,10 +15,11 @@ import builtins from functools import partial import operator -from typing import Optional, Tuple, Union +from typing import overload, Any, Callable, Optional, Sequence, Tuple, Union import warnings import numpy as np +from typing_extensions import Literal from jax import core from jax import lax @@ -27,6 +28,7 @@ from jax._src.numpy.ndarray import ndarray from jax._src.numpy.util import _broadcast_to, _check_arraylike, _complex_elem_type, _promote_dtypes_inexact, _promote_dtypes_numeric, _where, _wraps from jax._src.lax import lax as lax_internal +from jax._src.typing import Array, ArrayLike, DType, DTypeLike from jax._src.util import canonicalize_axis as _canonicalize_axis, maybe_named_axis, prod as _prod @@ -34,16 +36,19 @@ _lax_const = lax_internal._const -def _asarray(a): +Axis = Union[None, int, Sequence[int]] + + +def _asarray(a: ArrayLike) -> Array: # simplified version of jnp.asarray() for local use. return a if isinstance(a, ndarray) else api.device_put(a) -def _isscalar(element): +def _isscalar(element: Any) -> bool: if hasattr(element, '__jax_array__'): element = element.__jax_array__() return dtypes.is_python_scalar(element) or np.isscalar(element) -def _moveaxis(a, source: int, destination: int): +def _moveaxis(a: ArrayLike, source: int, destination: int) -> Array: # simplified version of jnp.moveaxis() for local use. _check_arraylike("moveaxis", a) a = _asarray(a) @@ -53,15 +58,23 @@ def _moveaxis(a, source: int, destination: int): perm.insert(destination, source) return lax.transpose(a, perm) -def _upcast_f16(dtype): - if dtype in [np.float16, dtypes.bfloat16]: +def _upcast_f16(dtype: DTypeLike) -> DType: + if np.dtype(dtype) in [np.float16, dtypes.bfloat16]: return np.dtype('float32') - return dtype - -def _reduction(a, name, np_fun, op, init_val, has_identity=True, - preproc=None, bool_op=None, upcast_f16_for_computation=False, - axis=None, dtype=None, out=None, keepdims=False, initial=None, - where_=None, parallel_reduce=None, promote_integers=False): + return np.dtype(dtype) + +ReductionOp = Callable[[Any, Any], Any] + +def _reduction(a: ArrayLike, name: str, np_fun: Any, op: ReductionOp, init_val: ArrayLike, + *, has_identity: bool = True, + preproc: Optional[Callable[[ArrayLike], ArrayLike]] = None, + bool_op: Optional[ReductionOp] = None, + upcast_f16_for_computation: bool = False, + axis: Axis = None, dtype: DTypeLike = None, out: None = None, + keepdims: bool = False, initial: Optional[ArrayLike] = None, + where_: Optional[ArrayLike] = None, + parallel_reduce: Optional[Callable[..., ArrayLike]] = None, + promote_integers: bool = False) -> Array: bool_op = bool_op or op # Note: we must accept out=None as an argument, because numpy reductions delegate to # object methods. For example `np.sum(x)` will call `x.sum()` if the `sum()` method @@ -119,7 +132,7 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True, else: result = lax.reduce(a, init_val, op, dims) if initial is not None: - result = op(lax.convert_element_type(initial, a.dtype), result) + result = op(lax.convert_element_type(initial, _asarray(a).dtype), result) if keepdims: result = lax.expand_dims(result, pos_dims) return lax.convert_element_type(result, dtype or result_dtype) @@ -127,13 +140,13 @@ def _reduction(a, name, np_fun, op, init_val, has_identity=True, def _canonicalize_axis_allow_named(x, rank): return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name) -def _reduction_dims(a, axis): +def _reduction_dims(a: ArrayLike, axis: Axis): if axis is None: return (tuple(range(np.ndim(a))),) * 2 elif not isinstance(axis, (np.ndarray, tuple, list)): - axis = (axis,) + axis = (axis,) # type: ignore[assignment] canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a)) - for x in axis) + for x in axis) # type: ignore[union-attr] if len(canon_axis) != len(set(canon_axis)): raise ValueError(f"duplicate value in 'axis': {axis}") canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int)) @@ -142,7 +155,7 @@ def _reduction_dims(a, axis): else: return canon_axis, canon_axis -def _reduction_init_val(a, init_val): +def _reduction_init_val(a: ArrayLike, init_val: Any) -> np.ndarray: # This function uses np.* functions because lax pattern matches against the # specific concrete values of the reduction inputs. a_dtype = dtypes.canonicalize_dtype(dtypes.dtype(a)) @@ -155,16 +168,16 @@ def _reduction_init_val(a, init_val): sign, info = np.sign(init_val), dtypes.iinfo(a_dtype) return np.array(info.min if sign < 0 else info.max, dtype=a_dtype) -def _cast_to_bool(operand): +def _cast_to_bool(operand: ArrayLike) -> Array: with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=np.ComplexWarning) return lax.convert_element_type(operand, np.bool_) -def _cast_to_numeric(operand): +def _cast_to_numeric(operand: ArrayLike) -> Array: return _promote_dtypes_numeric(operand)[0] -def _ensure_optional_axes(x): +def _ensure_optional_axes(x: Axis) -> Axis: def force(x): if x is None: return None @@ -186,8 +199,10 @@ def force(x): @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) -def _reduce_sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None, promote_integers=True): +def _reduce_sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, + out: None = None, keepdims: bool = False, + initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None, + promote_integers: bool = True) -> Array: return _reduction(a, "sum", np.sum, lax.add, 0, preproc=_cast_to_numeric, bool_op=lax.bitwise_or, upcast_f16_for_computation=True, axis=axis, dtype=dtype, out=out, keepdims=keepdims, @@ -195,76 +210,85 @@ def _reduce_sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=Non promote_integers=promote_integers) @_wraps(np.sum, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) -def sum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None, promote_integers=True): +def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, + out: None = None, keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None, promote_integers: bool = True) -> Array: return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True) -def _reduce_prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None, promote_integers=True): +def _reduce_prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, + out: None = None, keepdims: bool = False, + initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None, + promote_integers: bool = True) -> Array: return _reduction(a, "prod", np.prod, lax.mul, 1, preproc=_cast_to_numeric, bool_op=lax.bitwise_and, upcast_f16_for_computation=True, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where_=where, promote_integers=promote_integers) @_wraps(np.prod, skip_params=['out'], extra_params=_PROMOTE_INTEGERS_DOC) -def prod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None, promote_integers=True): +def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, + out: None = None, keepdims: bool = False, + initial: Optional[ArrayLike] = None, where: Optional[ArrayLike] = None, + promote_integers: bool = True) -> Array: return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, promote_integers=promote_integers) @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): +def _reduce_max(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None) -> Array: return _reduction(a, "max", np.max, lax.max, -np.inf, has_identity=False, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmax) @_wraps(np.max, skip_params=['out']) -def max(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): +def max(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None) -> Array: return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): +def _reduce_min(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None) -> Array: return _reduction(a, "min", np.min, lax.min, np.inf, has_identity=False, axis=axis, out=out, keepdims=keepdims, initial=initial, where_=where, parallel_reduce=lax.pmin) @_wraps(np.min, skip_params=['out']) -def min(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): +def min(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None) -> Array: return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, initial=initial, where=where) @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, *, where=None): +def _reduce_all(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array: return _reduction(a, "all", np.all, lax.bitwise_and, True, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) @_wraps(np.all, skip_params=['out']) -def all(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, *, where=None): +def all(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array: return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True) -def _reduce_any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, *, where=None): +def _reduce_any(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array: return _reduction(a, "any", np.any, lax.bitwise_or, False, preproc=_cast_to_bool, axis=axis, out=out, keepdims=keepdims, where_=where) @_wraps(np.any, skip_params=['out']) -def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, *, where=None): +def any(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, *, where: Optional[ArrayLike] = None) -> Array: return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out, keepdims=keepdims, where=where) @@ -274,24 +298,28 @@ def any(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, alltrue = all sometrue = any -def _axis_size(a, axis): +def _axis_size(a: ArrayLike, axis: Union[int, Sequence[int]]): if not isinstance(axis, (tuple, list)): - axis = (axis,) + axis_seq: Sequence[int] = (axis,) # type: ignore[assignment] + else: + axis_seq = axis size = 1 a_shape = np.shape(a) - for a in axis: + for a in axis_seq: size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name)) return size @_wraps(np.mean, skip_params=['out']) -def mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=False, *, where=None): +def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, + out: None = None, keepdims: bool = False, *, + where: Optional[ArrayLike] = None) -> Array: return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True) -def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=False, *, where=None): +def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, + out: None = None, keepdims: bool = False, *, + where: Optional[ArrayLike] = None) -> Array: _check_arraylike("mean", a) lax_internal._check_user_dtype_supported(dtype, "mean") if out is not None: @@ -313,14 +341,23 @@ def _mean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, sum(a, axis, dtype=dtype, keepdims=keepdims, where=where), lax.convert_element_type(normalizer, dtype)) +@overload +def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, + returned: Literal[False] = False, keepdims: bool = False) -> Array: ... +@overload +def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, *, + returned: Literal[True], keepdims: bool = False) -> Array: ... +@overload +def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, + returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]: ... @_wraps(np.average) -def average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None, - returned=False, keepdims=False): +def average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, + returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]: return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims) @partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True) -def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None, - returned=False, keepdims=False): +def _average(a: ArrayLike, axis: Axis = None, weights: Optional[ArrayLike] = None, + returned: bool = False, keepdims: bool = False) -> Union[Array, Tuple[Array, Array]]: if weights is None: # Treat all weights as 1 _check_arraylike("average", a) a, = _promote_dtypes_inexact(a) @@ -330,7 +367,7 @@ def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None elif isinstance(axis, tuple): weights_sum = lax.full_like(avg, _prod(core.dimension_as_value(a.shape[d]) for d in axis)) else: - weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) + weights_sum = lax.full_like(avg, core.dimension_as_value(a.shape[axis])) # type: ignore[index] else: _check_arraylike("average", a, weights) a, weights = _promote_dtypes_inexact(a, weights) @@ -374,21 +411,23 @@ def _average(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, weights=None @_wraps(np.var, skip_params=['out']) -def var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, *, where=None): +def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: Optional[ArrayLike] = None) -> Array: return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, *, where=None): +def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: Optional[ArrayLike] = None) -> Array: _check_arraylike("var", a) lax_internal._check_user_dtype_supported(dtype, "var") if out is not None: raise NotImplementedError("The 'out' argument to jnp.var is not supported.") computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype) - a = a.astype(computation_dtype) + a = _asarray(a).astype(computation_dtype) a_mean = mean(a, axis, dtype=computation_dtype, keepdims=True, where=where) centered = lax.sub(a, a_mean) if dtypes.issubdtype(centered.dtype, np.complexfloating): @@ -406,11 +445,11 @@ def _var(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, normalizer = normalizer - ddof result = sum(centered, axis, keepdims=keepdims, where=where) - out = lax.div(result, lax.convert_element_type(normalizer, result.dtype)) - return lax.convert_element_type(out, dtype) + result = lax.div(result, lax.convert_element_type(normalizer, result.dtype)) + return lax.convert_element_type(result, dtype) -def _var_promote_types(a_dtype, dtype): +def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike) -> Tuple[DType, DType]: if dtype: if (not dtypes.issubdtype(dtype, np.complexfloating) and dtypes.issubdtype(a_dtype, np.complexfloating)): @@ -428,18 +467,20 @@ def _var_promote_types(a_dtype, dtype): else: dtype = _complex_elem_type(a_dtype) computation_dtype = a_dtype - return _upcast_f16(computation_dtype), dtype + return _upcast_f16(computation_dtype), np.dtype(dtype) @_wraps(np.std, skip_params=['out']) -def std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, *, where=None): +def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: Optional[ArrayLike] = None) -> Array: return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def _std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, *, where=None): +def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: Optional[ArrayLike] = None) -> Array: _check_arraylike("std", a) lax_internal._check_user_dtype_supported(dtype, "std") if out is not None: @@ -448,13 +489,13 @@ def _std(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, @_wraps(np.ptp, skip_params=['out']) -def ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=False): +def ptp(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False) -> Array: return _ptp(a, _ensure_optional_axes(axis), out, keepdims) @partial(api.jit, static_argnames=('axis', 'keepdims')) -def _ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=False): +def _ptp(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False) -> Array: _check_arraylike("ptp", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.ptp is not supported.") @@ -465,15 +506,16 @@ def _ptp(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, @_wraps(np.count_nonzero) @partial(api.jit, static_argnames=('axis', 'keepdims')) -def count_nonzero(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, - keepdims=False): +def count_nonzero(a: ArrayLike, axis: Axis = None, + keepdims: bool = False) -> Array: _check_arraylike("count_nonzero", a) return sum(lax.ne(a, _lax_const(a, 0)), axis=axis, dtype=dtypes.canonicalize_dtype(np.int_), keepdims=keepdims) -def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan, - axis=None, keepdims=None, **kwargs): +def _nan_reduction(a: ArrayLike, name: str, jnp_reduction: Callable[..., Array], + init_val: ArrayLike, nan_if_all_nan: bool, + axis: Axis = None, keepdims: bool = False, **kwargs) -> Array: _check_arraylike(name, a) if not dtypes.issubdtype(dtypes.dtype(a), np.inexact): return jnp_reduction(a, axis=axis, keepdims=keepdims, **kwargs) @@ -488,24 +530,27 @@ def _nan_reduction(a, name, jnp_reduction, init_val, nan_if_all_nan, @_wraps(np.nanmin, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'keepdims')) -def nanmin(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): +def nanmin(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None) -> Array: return _nan_reduction(a, 'nanmin', min, np.inf, nan_if_all_nan=initial is None, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) @_wraps(np.nanmax, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'keepdims')) -def nanmax(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, out=None, - keepdims=None, initial=None, where=None): +def nanmax(a: ArrayLike, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None) -> Array: return _nan_reduction(a, 'nanmax', max, -np.inf, nan_if_all_nan=initial is None, axis=axis, out=out, keepdims=keepdims, initial=initial, where=where) @_wraps(np.nansum, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None): +def nansum(a: ArrayLike, axis: Axis = None, out: None = None, dtype: DTypeLike = None, + keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None) -> Array: lax_internal._check_user_dtype_supported(dtype, "nanprod") return _nan_reduction(a, 'nansum', sum, 0, nan_if_all_nan=False, axis=axis, dtype=dtype, out=out, keepdims=keepdims, @@ -517,8 +562,9 @@ def nansum(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, @_wraps(np.nanprod, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=None, initial=None, where=None): +def nanprod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None, + keepdims: bool = False, initial: Optional[ArrayLike] = None, + where: Optional[ArrayLike] = None) -> Array: lax_internal._check_user_dtype_supported(dtype, "nanprod") return _nan_reduction(a, 'nanprod', prod, 1, nan_if_all_nan=False, axis=axis, dtype=dtype, out=out, keepdims=keepdims, @@ -526,8 +572,8 @@ def nanprod(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, @_wraps(np.nanmean, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, keepdims=False, where=None): +def nanmean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None, + keepdims: bool = False, where: Optional[ArrayLike] = None) -> Array: _check_arraylike("nanmean", a) lax_internal._check_user_dtype_supported(dtype, "nanmean") if out is not None: @@ -545,15 +591,16 @@ def nanmean(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, @_wraps(np.nanvar, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, where=None): +def nanvar(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None, + ddof: int = 0, keepdims: bool = False, + where: Optional[ArrayLike] = None) -> Array: _check_arraylike("nanvar", a) lax_internal._check_user_dtype_supported(dtype, "nanvar") if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanvar is not supported.") computation_dtype, dtype = _var_promote_types(dtypes.dtype(a), dtype) - a = a.astype(computation_dtype) + a = _asarray(a).astype(computation_dtype) a_mean = nanmean(a, axis, dtype=computation_dtype, keepdims=True, where=where) centered = _where(lax_internal._isnan(a), 0, lax.sub(a, a_mean)) # double-where trick for gradients. @@ -569,14 +616,15 @@ def nanvar(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, result = sum(centered, axis, keepdims=keepdims, where=where) result = _where(normalizer_mask, np.nan, result) divisor = _where(normalizer_mask, 1, normalizer) - out = lax.div(result, lax.convert_element_type(divisor, result.dtype)) - return lax.convert_element_type(out, dtype) + result = lax.div(result, lax.convert_element_type(divisor, result.dtype)) + return lax.convert_element_type(result, dtype) @_wraps(np.nanstd, skip_params=['out']) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) -def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, - out=None, ddof=0, keepdims=False, where=None): +def nanstd(a: ArrayLike, axis: Axis = None, dtype: DTypeLike = None, out: None = None, + ddof: int = 0, keepdims: bool = False, + where: Optional[ArrayLike] = None) -> Array: _check_arraylike("nanstd", a) lax_internal._check_user_dtype_supported(dtype, "nanstd") if out is not None: @@ -584,17 +632,18 @@ def nanstd(a, axis: Optional[Union[int, Tuple[int, ...]]] = None, dtype=None, return lax.sqrt(nanvar(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where)) -def _make_cumulative_reduction(np_reduction, reduction, fill_nan=False, fill_value=0): +# TODO(jakevdp): use a protocol here for better typing? +def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array], + fill_nan: bool = False, fill_value: ArrayLike = 0) -> Callable[..., Array]: @_wraps(np_reduction, skip_params=['out']) def cumulative_reduction(a, - axis: Optional[Union[int, Tuple[int, ...]]] = None, + axis: Axis = None, dtype=None, out=None): return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out) @partial(api.jit, static_argnames=('axis', 'dtype')) - def _cumulative_reduction(a, - axis: Optional[Union[int, Tuple[int, ...]]] = None, - dtype=None, out=None): + def _cumulative_reduction(a: ArrayLike, axis: Axis = None, + dtype: DTypeLike = None, out: None = None) -> Array: _check_arraylike(np_reduction.__name__, a) if out is not None: raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} " diff --git a/jax/_src/ops/scatter.py b/jax/_src/ops/scatter.py index 5a2381ddcf30..84be65f0476e 100644 --- a/jax/_src/ops/scatter.py +++ b/jax/_src/ops/scatter.py @@ -161,7 +161,7 @@ def _segment_update(name: str, segment_ids = jnp.asarray(segment_ids) dtype = data.dtype if num_segments is None: - num_segments = jnp.max(segment_ids) + 1 + num_segments = np.max(segment_ids) + 1 num_segments = core.concrete_or_error(int, num_segments, "segment_sum() `num_segments` argument.") if num_segments is not None and num_segments < 0: raise ValueError("num_segments must be non-negative.") diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 4454cc657998..298abf9cf3df 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -1053,7 +1053,7 @@ def sph_harm(m: Array, phi = jnp.array([phi]) if n_max is None: - n_max = jnp.max(n) + n_max = np.max(n) n_max = core.concrete_or_error( int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must ' 'be statically specified to use `sph_harm` within JAX transformations.') diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index d5c8ee739be8..03245a3eacf9 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -1179,7 +1179,7 @@ def _make_harness(group_name: str, name: str, [RandArg((3, 4), _f32), RandArg((2, 3, 4), _f32)], poly_axes=[0, 1]), _make_harness("add_transpose", "", - jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=0) + x)), + jax.grad(lambda x: jnp.sum(jnp.sum(x, axis=0, keepdims=False) + x)), [RandArg((3, 4), _f32)], poly_axes=[0]), _make_harness("arange", "start",