From 8d4235030a8858e728e7df014a3417b00a49b6c9 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 16 May 2024 16:06:24 +0000 Subject: [PATCH] Updated signatures in init.pyi, CHANGELOG.md and fixed copy-paste in the message --- CHANGELOG.md | 2 ++ jax/_src/numpy/reductions.py | 6 +++--- jax/numpy/__init__.pyi | 8 ++++---- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e789f28249e..a348346fae30 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ Remember to align the itemized text with the first line of an item within a list deprecated and will soon be removed. Use `rtol` instead. * The ``rcond`` argument of {func}`jax.numpy.linalg.pinv` is being deprecated and will soon be removed. Use `rtol` instead. + * The ``ddof`` argument of {func}`jax.numpy.std` and {func}`jax.numpy.std` is being + deprecated and will soon be removed. Use `correction` instead. ## jaxlib 0.4.29 diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index 0eca75fbbb23..c0bfdf5bd854 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -434,7 +434,7 @@ def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None, def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None, ddof: int | DeprecatedArg = DeprecatedArg()) -> Array: - # TODO(vfdev-5): deprecated 2024-5-16, remove after deprecation expires. + # TODO(vfdev-5): Remove after deprecation is completed (began 2024-5-16) if not isinstance(ddof, DeprecatedArg): warnings.warn( "The ddof argument of jax.numpy.var is deprecated and setting it " @@ -503,10 +503,10 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, correction: int | float = 0, keepdims: bool = False, *, where: ArrayLike | None = None, ddof: int | DeprecatedArg = DeprecatedArg()) -> Array: - # TODO(vfdev-5): deprecated 2024-5-16, remove after deprecation expires. + # TODO(vfdev-5): Remove after deprecation is completed (began 2024-5-16) if not isinstance(ddof, DeprecatedArg): warnings.warn( - "The ddof argument of jax.numpy.var is deprecated and setting it " + "The ddof argument of jax.numpy.std is deprecated and setting it " "will soon raise an error. To avoid an error in the future, and to " "suppress this warning, please use the correction argument instead.", DeprecationWarning, stacklevel=2) diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 36a6347d35ba..97b7f34e94e9 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -782,8 +782,8 @@ def stack( dtype: Optional[DTypeLike] = ..., ) -> Array: ... def std(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., - out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ...) -> Array: ... + out: None = ..., correction: int | float = ..., keepdims: builtins.bool = ..., *, + where: Optional[ArrayLike] = ..., ddof: int | DeprecatedArg = ...) -> Array: ... def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: ... def sum( a: ArrayLike, @@ -893,8 +893,8 @@ def vander( x: ArrayLike, N: Optional[int] = ..., increasing: builtins.bool = ... ) -> Array: ... def var(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., - out: None = ..., ddof: int = ..., keepdims: builtins.bool = ..., *, - where: Optional[ArrayLike] = ...) -> Array: ... + out: None = ..., correction: int | float = ..., keepdims: builtins.bool = ..., *, + where: Optional[ArrayLike] = ..., ddof: int | DeprecatedArg = ...) -> Array: ... def vdot( a: ArrayLike, b: ArrayLike, *, precision: PrecisionLike = ..., preferred_element_type: Optional[DTypeLike] = ...) -> Array: ...