Skip to content

Commit

Permalink
Added correction arg in jnp.var and jnp.std
Browse files Browse the repository at this point in the history
Description:
- Added correction arg in jnp.var and jnp.std
- Addresses jax-ml#21088
- Updated signatures in init.pyi
- Updated tests
  • Loading branch information
vfdev-5 committed May 24, 2024
1 parent 8f045ca commit 55f8284
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 22 deletions.
24 changes: 16 additions & 8 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,17 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
@implements(np.var, skip_params=['out'])
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
if correction is None:
correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
raise ValueError("ddof and correction can't be provided simultaneously.")
return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("var", a)
dtypes.check_user_dtype_supported(dtype, "var")
Expand All @@ -465,7 +469,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
dtype=computation_dtype, keepdims=keepdims)
normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype))
normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype))
result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
return lax.div(result, normalizer).astype(dtype)

Expand Down Expand Up @@ -494,21 +498,25 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
@implements(np.std, skip_params=['out'])
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array:
if correction is None:
correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
raise ValueError("ddof and correction can't be provided simultaneously.")
return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("std", a)
dtypes.check_user_dtype_supported(dtype, "std")
if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where))


@implements(np.ptp, skip_params=['out'])
Expand Down
7 changes: 3 additions & 4 deletions jax/experimental/array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,12 @@

import jax

# TODO(micky774): Remove after deprecating ddof-->correction in jnp.std and
# jnp.var

def std(x, /, *, axis=None, correction=0.0, keepdims=False):
"""Calculates the standard deviation of the input array x."""
return jax.numpy.std(x, axis=axis, ddof=correction, keepdims=keepdims)
return jax.numpy.std(x, axis=axis, correction=correction, keepdims=keepdims)


def var(x, /, *, axis=None, correction=0.0, keepdims=False):
"""Calculates the variance of the input array x."""
return jax.numpy.var(x, axis=axis, ddof=correction, keepdims=keepdims)
return jax.numpy.var(x, axis=axis, correction=correction, keepdims=keepdims)
4 changes: 2 additions & 2 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def stack(
) -> Array: ...
def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
where: Optional[ArrayLike] = ...) -> Array: ...
where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ...
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def sum(
a: ArrayLike,
Expand Down Expand Up @@ -894,7 +894,7 @@ def vander(
) -> Array: ...
def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
where: Optional[ArrayLike] = ...) -> Array: ...
where: Optional[ArrayLike] = ..., correction: int | float | None = ...) -> Array: ...
def vdot(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ...,
preferred_element_type: Optional[DTypeLike] = ...) -> Array: ...
Expand Down
23 changes: 17 additions & 6 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,31 +540,42 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims):
rtol=tol, atol=tol)

@jtu.sample_product(
test_fns=[(np.var, jnp.var), (np.std, jnp.std)],
shape=[(5,), (10, 5)],
dtype=all_dtypes,
out_dtype=inexact_dtypes,
axis=[None, 0, -1],
ddof=[0, 1, 2],
ddof_correction=[(0, None), (1, None), (1, 0), (0, 0), (0, 1), (0, 2)],
keepdims=[False, True],
)
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, ddof_correction, keepdims):
np_fn, jnp_fn = test_fns
ddof, correction = ddof_correction
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.")
@jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x):
# setup ddof and correction kwargs excluding case when correction is not specified
ddof_correction_kwargs = {"ddof": ddof}
if correction is not None:
key = "correction" if numpy_version >= (2, 0) else "ddof"
ddof_correction_kwargs[key] = correction
# Numpy fails with bfloat16 inputs
out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
out = np_fn(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
axis=axis, ddof=ddof, keepdims=keepdims)
axis=axis, keepdims=keepdims, **ddof_correction_kwargs)
return out.astype(out_dtype)
jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
jnp_fun = partial(jnp_fn, dtype=out_dtype, axis=axis, ddof=ddof, correction=correction,
keepdims=keepdims)
tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
np.float64: 1e-3, np.complex128: 1e-6})
if (jnp.issubdtype(dtype, jnp.complexfloating) and
not jnp.issubdtype(out_dtype, jnp.complexfloating)):
self.assertRaises(ValueError, lambda: jnp_fun(*args_maker()))
self.assertRaises(ValueError, jnp_fun, *args_maker())
elif (correction is not None and ddof != 0):
self.assertRaises(ValueError, jnp_fun, *args_maker())
else:
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker,
tol=tol)
Expand Down
4 changes: 2 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5960,9 +5960,9 @@ def testWrappedSignaturesMatch(self):
'reshape': ['shape', 'copy'],
'row_stack': ['casting'],
'stack': ['casting'],
'std': ['correction', 'mean'],
'std': ['mean'],
'tri': ['like'],
'var': ['correction', 'mean'],
'var': ['mean'],
'vstack': ['casting'],
'zeros_like': ['subok', 'order']
}
Expand Down

0 comments on commit 55f8284

Please sign in to comment.