Skip to content

Commit

Permalink
Deprecated ddof in jnp.var and jnp.std
Browse files Browse the repository at this point in the history
Description:
- Deprecated ddof in jnp.var and jnp.std
- Addresses jax-ml#21088
  • Loading branch information
vfdev-5 committed May 16, 2024
1 parent 8f045ca commit 4dfc76e
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 18 deletions.
38 changes: 27 additions & 11 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
_broadcast_to, check_arraylike, _complex_elem_type,
promote_dtypes_inexact, promote_dtypes_numeric, _where, implements)
from jax._src.lax import lax as lax_internal
from jax._src.typing import Array, ArrayLike, DType, DTypeLike
from jax._src.typing import Array, ArrayLike, DType, DTypeLike, DeprecatedArg
from jax._src.util import (
canonicalize_axis as _canonicalize_axis, maybe_named_axis,
NumpyComplexWarning)
Expand Down Expand Up @@ -432,14 +432,22 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,

@implements(np.var, skip_params=['out'])
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, ddof: int | DeprecatedArg = DeprecatedArg()) -> Array:
# TODO(vfdev-5): deprecated 2024-5-16, remove after deprecation expires.
if not isinstance(ddof, DeprecatedArg):
warnings.warn(
"The ddof argument of jax.numpy.var is deprecated and setting it "
"will soon raise an error. To avoid an error in the future, and to "
"suppress this warning, please use the correction argument instead.",
DeprecationWarning, stacklevel=2)
correction = ddof
return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("var", a)
dtypes.check_user_dtype_supported(dtype, "var")
Expand All @@ -465,7 +473,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
dtype=computation_dtype, keepdims=keepdims)
normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype))
normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype))
result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
return lax.div(result, normalizer).astype(dtype)

Expand Down Expand Up @@ -493,22 +501,30 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy

@implements(np.std, skip_params=['out'])
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, ddof: int | DeprecatedArg = DeprecatedArg()) -> Array:
# TODO(vfdev-5): deprecated 2024-5-16, remove after deprecation expires.
if not isinstance(ddof, DeprecatedArg):
warnings.warn(
"The ddof argument of jax.numpy.var is deprecated and setting it "
"will soon raise an error. To avoid an error in the future, and to "
"suppress this warning, please use the correction argument instead.",
DeprecationWarning, stacklevel=2)
correction = ddof
return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("std", a)
dtypes.check_user_dtype_supported(dtype, "std")
if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where))


@implements(np.ptp, skip_params=['out'])
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
# jnp.var
def std(x, /, *, axis=None, correction=0.0, keepdims=False):
"""Calculates the standard deviation of the input array x."""
return jax.numpy.std(x, axis=axis, ddof=correction, keepdims=keepdims)
return jax.numpy.std(x, axis=axis, correction=correction, keepdims=keepdims)


def var(x, /, *, axis=None, correction=0.0, keepdims=False):
"""Calculates the variance of the input array x."""
return jax.numpy.var(x, axis=axis, ddof=correction, keepdims=keepdims)
return jax.numpy.var(x, axis=axis, correction=correction, keepdims=keepdims)
12 changes: 7 additions & 5 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,26 +540,28 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims):
rtol=tol, atol=tol)

@jtu.sample_product(
test_fns=[(np.var, jnp.var), (np.std, jnp.std)],
shape=[(5,), (10, 5)],
dtype=all_dtypes,
out_dtype=inexact_dtypes,
axis=[None, 0, -1],
ddof=[0, 1, 2],
correction=[0, 1, 2],
keepdims=[False, True],
)
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, correction, keepdims):
np_fn, jnp_fn = test_fns
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.")
@jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x):
# Numpy fails with bfloat16 inputs
out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
out = np_fn(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
axis=axis, ddof=ddof, keepdims=keepdims)
axis=axis, ddof=correction, keepdims=keepdims)
return out.astype(out_dtype)
jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
jnp_fun = partial(jnp_fn, dtype=out_dtype, axis=axis, correction=correction, keepdims=keepdims)
tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
np.float64: 1e-3, np.complex128: 1e-6})
if (jnp.issubdtype(dtype, jnp.complexfloating) and
Expand Down

0 comments on commit 4dfc76e

Please sign in to comment.