diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index a6f4e1f7b8e8..0eca75fbbb23 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -33,7 +33,7 @@ _broadcast_to, check_arraylike, _complex_elem_type, promote_dtypes_inexact, promote_dtypes_numeric, _where, implements) from jax._src.lax import lax as lax_internal -from jax._src.typing import Array, ArrayLike, DType, DTypeLike +from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg from jax._src.util import ( canonicalize_axis as _canonicalize_axis, maybe_named_axis, NumpyComplexWarning) @@ -432,14 +432,22 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, @implements(np.var, skip_params=['out']) def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, - out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: ArrayLike | None = None) -> Array: - return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, + out: None = None, correction: int | float = 0, keepdims: bool = False, *, + where: ArrayLike | None = None, ddof: int | DeprecatedArg = DeprecatedArg()) -> Array: + # TODO(vfdev-5): deprecated 2024-5-16, remove after deprecation expires. + if not isinstance(ddof, DeprecatedArg): + warnings.warn( + "The ddof argument of jax.numpy.var is deprecated and setting it " + "will soon raise an error. To avoid an error in the future, and to " + "suppress this warning, please use the correction argument instead.", + DeprecationWarning, stacklevel=2) + correction = ddof + return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, - out: None = None, ddof: int = 0, keepdims: bool = False, *, + out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("var", a) dtypes.check_user_dtype_supported(dtype, "var") @@ -465,7 +473,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, else: normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=computation_dtype, keepdims=keepdims) - normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype)) + normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype)) result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where) return lax.div(result, normalizer).astype(dtype) @@ -493,14 +501,22 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy @implements(np.std, skip_params=['out']) def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, - out: None = None, ddof: int = 0, keepdims: bool = False, *, - where: ArrayLike | None = None) -> Array: - return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims, + out: None = None, correction: int | float = 0, keepdims: bool = False, *, + where: ArrayLike | None = None, ddof: int | DeprecatedArg = DeprecatedArg()) -> Array: + # TODO(vfdev-5): deprecated 2024-5-16, remove after deprecation expires. + if not isinstance(ddof, DeprecatedArg): + warnings.warn( + "The ddof argument of jax.numpy.var is deprecated and setting it " + "will soon raise an error. To avoid an error in the future, and to " + "suppress this warning, please use the correction argument instead.", + DeprecationWarning, stacklevel=2) + correction = ddof + return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims, where=where) @partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims')) def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, - out: None = None, ddof: int = 0, keepdims: bool = False, *, + out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: check_arraylike("std", a) dtypes.check_user_dtype_supported(dtype, "std") @@ -508,7 +524,7 @@ def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}") if out is not None: raise NotImplementedError("The 'out' argument to jnp.std is not supported.") - return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where)) + return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where)) @implements(np.ptp, skip_params=['out']) diff --git a/jax/experimental/array_api/_statistical_functions.py b/jax/experimental/array_api/_statistical_functions.py index c34fb1fc3af4..25e60ca130f7 100644 --- a/jax/experimental/array_api/_statistical_functions.py +++ b/jax/experimental/array_api/_statistical_functions.py @@ -18,9 +18,9 @@ # jnp.var def std(x, /, *, axis=None, correction=0.0, keepdims=False): """Calculates the standard deviation of the input array x.""" - return jax.numpy.std(x, axis=axis, ddof=correction, keepdims=keepdims) + return jax.numpy.std(x, axis=axis, correction=correction, keepdims=keepdims) def var(x, /, *, axis=None, correction=0.0, keepdims=False): """Calculates the variance of the input array x.""" - return jax.numpy.var(x, axis=axis, ddof=correction, keepdims=keepdims) + return jax.numpy.var(x, axis=axis, correction=correction, keepdims=keepdims) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 861b9014c589..bee4e239f370 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -540,14 +540,16 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims): rtol=tol, atol=tol) @jtu.sample_product( + test_fns=[(np.var, jnp.var), (np.std, jnp.std)], shape=[(5,), (10, 5)], dtype=all_dtypes, out_dtype=inexact_dtypes, axis=[None, 0, -1], - ddof=[0, 1, 2], + correction=[0, 1, 2], keepdims=[False, True], ) - def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): + def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, correction, keepdims): + np_fn, jnp_fn = test_fns rng = jtu.rand_default(self.rng()) args_maker = self._GetArgsMaker(rng, [shape], [dtype]) @jtu.ignore_warning(category=RuntimeWarning, @@ -555,11 +557,11 @@ def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims): @jtu.ignore_warning(category=NumpyComplexWarning) def np_fun(x): # Numpy fails with bfloat16 inputs - out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), + out = np_fn(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype), dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype, - axis=axis, ddof=ddof, keepdims=keepdims) + axis=axis, ddof=correction, keepdims=keepdims) return out.astype(out_dtype) - jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims) + jnp_fun = partial(jnp_fn, dtype=out_dtype, axis=axis, correction=correction, keepdims=keepdims) tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3, np.float64: 1e-3, np.complex128: 1e-6}) if (jnp.issubdtype(dtype, jnp.complexfloating) and