Skip to content

Commit

Permalink
Added correction arg in jnp.var and jnp.std
Browse files Browse the repository at this point in the history
Description:
- Added correction arg in jnp.var and jnp.std
- Addresses jax-ml#21088
- Updated signatures in init.pyi
  • Loading branch information
vfdev-5 committed May 17, 2024
1 parent 8f045ca commit 83b06b2
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 19 deletions.
28 changes: 20 additions & 8 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,13 +433,19 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
@implements(np.var, skip_params=['out'])
def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
return _var(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where: ArrayLike | None = None, correction: int | float = 0) -> Array:
if ddof != 0:
if correction != 0:
raise ValueError(
"ddof and correction can't be provided simultaneously."
)
correction = ddof
return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("var", a)
dtypes.check_user_dtype_supported(dtype, "var")
Expand All @@ -465,7 +471,7 @@ def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
else:
normalizer = sum(_broadcast_to(where, np.shape(a)), axis,
dtype=computation_dtype, keepdims=keepdims)
normalizer = lax.sub(normalizer, lax.convert_element_type(ddof, computation_dtype))
normalizer = lax.sub(normalizer, lax.convert_element_type(correction, computation_dtype))
result = sum(centered, axis, dtype=computation_dtype, keepdims=keepdims, where=where)
return lax.div(result, normalizer).astype(dtype)

Expand Down Expand Up @@ -494,21 +500,27 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy
@implements(np.std, skip_params=['out'])
def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
return _std(a, _ensure_optional_axes(axis), dtype, out, ddof, keepdims,
where: ArrayLike | None = None, correction: int | float = 0) -> Array:
if ddof != 0:
if correction != 0:
raise ValueError(
"ddof and correction can't be provided simultaneously."
)
correction = ddof
return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
out: None = None, correction: int | float = 0, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array:
check_arraylike("std", a)
dtypes.check_user_dtype_supported(dtype, "std")
if dtype is not None and not dtypes.issubdtype(dtype, np.inexact):
raise ValueError(f"dtype argument to jnp.std must be inexact; got {dtype}")
if out is not None:
raise NotImplementedError("The 'out' argument to jnp.std is not supported.")
return lax.sqrt(var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, where=where))
return lax.sqrt(var(a, axis=axis, dtype=dtype, correction=correction, keepdims=keepdims, where=where))


@implements(np.ptp, skip_params=['out'])
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/array_api/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
# jnp.var
def std(x, /, *, axis=None, correction=0.0, keepdims=False):
"""Calculates the standard deviation of the input array x."""
return jax.numpy.std(x, axis=axis, ddof=correction, keepdims=keepdims)
return jax.numpy.std(x, axis=axis, correction=correction, keepdims=keepdims)


def var(x, /, *, axis=None, correction=0.0, keepdims=False):
"""Calculates the variance of the input array x."""
return jax.numpy.var(x, axis=axis, ddof=correction, keepdims=keepdims)
return jax.numpy.var(x, axis=axis, correction=correction, keepdims=keepdims)
4 changes: 2 additions & 2 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ def stack(
) -> Array: ...
def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
where: Optional[ArrayLike] = ...) -> Array: ...
where: Optional[ArrayLike] = ..., correction: int | float = ...) -> Array: ...
def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ...
def sum(
a: ArrayLike,
Expand Down Expand Up @@ -894,7 +894,7 @@ def vander(
) -> Array: ...
def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ...,
out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *,
where: Optional[ArrayLike] = ...) -> Array: ...
where: Optional[ArrayLike] = ..., correction: int | float = ...) -> Array: ...
def vdot(
a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ...,
preferred_element_type: Optional[DTypeLike] = ...) -> Array: ...
Expand Down
12 changes: 7 additions & 5 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,26 +540,28 @@ def testAverage(self, shape, dtype, axis, weights_shape, returned, keepdims):
rtol=tol, atol=tol)

@jtu.sample_product(
test_fns=[(np.var, jnp.var), (np.std, jnp.std)],
shape=[(5,), (10, 5)],
dtype=all_dtypes,
out_dtype=inexact_dtypes,
axis=[None, 0, -1],
ddof=[0, 1, 2],
correction=[0, 1, 2],
keepdims=[False, True],
)
def testVar(self, shape, dtype, out_dtype, axis, ddof, keepdims):
def testStdOrVar(self, test_fns, shape, dtype, out_dtype, axis, correction, keepdims):
np_fn, jnp_fn = test_fns
rng = jtu.rand_default(self.rng())
args_maker = self._GetArgsMaker(rng, [shape], [dtype])
@jtu.ignore_warning(category=RuntimeWarning,
message="Degrees of freedom <= 0 for slice.")
@jtu.ignore_warning(category=NumpyComplexWarning)
def np_fun(x):
# Numpy fails with bfloat16 inputs
out = np.var(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
out = np_fn(x.astype(np.float32 if dtype == dtypes.bfloat16 else dtype),
dtype=np.float32 if out_dtype == dtypes.bfloat16 else out_dtype,
axis=axis, ddof=ddof, keepdims=keepdims)
axis=axis, ddof=correction, keepdims=keepdims)
return out.astype(out_dtype)
jnp_fun = partial(jnp.var, dtype=out_dtype, axis=axis, ddof=ddof, keepdims=keepdims)
jnp_fun = partial(jnp_fn, dtype=out_dtype, axis=axis, correction=correction, keepdims=keepdims)
tol = jtu.tolerance(out_dtype, {np.float16: 1e-1, np.float32: 1e-3,
np.float64: 1e-3, np.complex128: 1e-6})
if (jnp.issubdtype(dtype, jnp.complexfloating) and
Expand Down
4 changes: 2 additions & 2 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5960,9 +5960,9 @@ def testWrappedSignaturesMatch(self):
'reshape': ['shape', 'copy'],
'row_stack': ['casting'],
'stack': ['casting'],
'std': ['correction', 'mean'],
'std': ['mean'],
'tri': ['like'],
'var': ['correction', 'mean'],
'var': ['mean'],
'vstack': ['casting'],
'zeros_like': ['subok', 'order']
}
Expand Down

0 comments on commit 83b06b2

Please sign in to comment.