Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added correction arg to jnp.var and jnp.std for array-api compliance #21262

Merged

Conversation

vfdev-5
Copy link
Collaborator

@vfdev-5 vfdev-5 commented May 16, 2024

Description:

@vfdev-5 vfdev-5 force-pushed the depr-change-ddof-to-correction-21088 branch from 1d25490 to 4dfc76e Compare May 16, 2024 15:45
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good! One minor comment below; also we need to mention this in the CHANGELOG, and also update jax/numpy/__init__.pyi to reflect the new type signatures.

jax/_src/numpy/reductions.py Outdated Show resolved Hide resolved
tests/lax_numpy_reducers_test.py Show resolved Hide resolved
@jakevdp jakevdp self-assigned this May 16, 2024
@vfdev-5 vfdev-5 marked this pull request as ready for review May 16, 2024 16:09
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great - last thing we need is to squash the changes into a single commit (see https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests).

Thanks!

@vfdev-5 vfdev-5 force-pushed the depr-change-ddof-to-correction-21088 branch from 8d42350 to 8f079b3 Compare May 16, 2024 16:16
@vfdev-5
Copy link
Collaborator Author

vfdev-5 commented May 16, 2024

Thanks for the review, @jakevdp !

Copy link
Collaborator

@Micky774 Micky774 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just needs a minor correction in the changelog

CHANGELOG.md Outdated Show resolved Hide resolved
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels May 16, 2024
@vfdev-5 vfdev-5 marked this pull request as draft May 17, 2024 06:46
@Micky774
Copy link
Collaborator

Tests are currently failing because of a signature mismatch between the new std, var funcs and their numpy counterparts. This can be fixed by updating extra_params in lax_numpy_test::testWrappedSignaturesMatch to include the new correction arg for both functions. See: https://github.com/google/jax/blob/5e2710c2c28a6f5bc2d6c89cf7148ea254685c30/tests/lax_numpy_test.py#L5970-L5973

@vfdev-5
Copy link
Collaborator Author

vfdev-5 commented May 17, 2024

@Micky774 yes, right now removing "correction" from unmatched list I have this:
Missing entries:

'std': {'np_params': ['a', 'axis', 'dtype', 'out', 'ddof', 'keepdims', 'where', 'correction'], 'jnp_params': ['a', 'axis', 'dtype', 'out', 'correction', 'keepdims', 'where', 'ddof']}
'var': {'np_params': ['a', 'axis', 'dtype', 'out', 'ddof', 'keepdims', 'where', 'correction'], 'jnp_params': ['a', 'axis', 'dtype', 'out', 'correction', 'keepdims', 'where', 'ddof']}

it means that I should reorder the incorrect order I've done in jnp:

- def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
-        out: None = None, correction: int | float = 0, keepdims: bool = False, *,
-        where: ArrayLike | None = None, ddof: int | DeprecatedArg = DeprecatedArg()) -> Array:

+ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
+        out: None = None, ddof: int | DeprecatedArg = DeprecatedArg(), keepdims: bool = False, *,
+        where: ArrayLike | None = None, correction: int | float = 0) -> Array:

However, given that numpy does not deprecate ddof arg: https://numpy.org/devdocs/reference/generated/numpy.std.html
Does it make sense to deprecate it in jnp, especially when it will be removed there will be once again a failure of signature mismatch ? We may finally wanted just to add new arg ("correction" to jnp.std / jnp.var) similarly to numpy ?

@Micky774
Copy link
Collaborator

Ah good point, it seems that they do intend to maintain both as valid API (discussion) so we ought to do the same. It will still be fully array API compliant, so indeed let's keep ddof.

@vfdev-5 vfdev-5 force-pushed the depr-change-ddof-to-correction-21088 branch from 8f079b3 to 83b06b2 Compare May 17, 2024 13:03
jax/_src/numpy/reductions.py Outdated Show resolved Hide resolved
@vfdev-5 vfdev-5 changed the title Deprecated ddof in jnp.var and jnp.std Added correction arg to jnp.var and jnp.std for array-api compliance May 22, 2024
@vfdev-5 vfdev-5 force-pushed the depr-change-ddof-to-correction-21088 branch 2 times, most recently from d878da3 to 20b87e1 Compare May 22, 2024 14:12
@vfdev-5 vfdev-5 marked this pull request as ready for review May 22, 2024 14:13
@vfdev-5 vfdev-5 requested a review from Micky774 May 22, 2024 14:14
@vfdev-5 vfdev-5 force-pushed the depr-change-ddof-to-correction-21088 branch from 20b87e1 to 4e4afbe Compare May 22, 2024 20:01
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One minor suggestion below.

Also, regarding the default value quesion: a possibly brilliant, possibly terrible idea. What if we do this:

class _int(int): pass
_zero = _int(0)
def var(..., ddof: int = _zero,...):
  if correction is None:
    correction = int(ddof) if isinstance(ddof, _int) else ddof
  elif ddof is zero:
    raise ValueError(...)

I can't think of any reasons why this wouldn't work, can you?

@vfdev-5
Copy link
Collaborator Author

vfdev-5 commented May 22, 2024

Yes, this approach works, but looks like a hack. By the way, numpy in this case does not raise the error:

>>> import numpy as np                                  
>>> np.var(np.array([1.0, 2.0]), ddof=0, correction=0)                                                               
np.float64(0.25)  

I wonder where exactly the boundary for type hints is passing between jax.numpy and numpy. For example, instead of introducing _zero = _int(0), we could change ddof: int = 0 into ddof: int | None = None and set it internally to zero by default. Docs will be still showing what numpy has:

var(a: 'ArrayLike', axis: 'Axis' = None, dtype: 'DTypeLike | None' = None, out: 'None' = None, ddof: 'int | None' = None, keepdims: 'bool' = False, *, where: 'ArrayLike | None' = None, correction: 'int | float | None' = None) -> 'Ar
ray'

    ddof : {int, float}, optional
        Means Delta Degrees of Freedom.  The divisor used in calculations
        is ``N - ddof``, where ``N`` represents the number of elements.
        By default `ddof` is zero. See Notes for details about use of `ddof`.

    correction : {int, float}, optional
        Array API compatible name for the ``ddof`` parameter. Only one of them
        can be provided at the same time.

Jake, I'm happy to implement your solution with _zero = _int(0) if there is no way of ddof type hint using None. I'm just trying to figure out the project conventions and limitations on the proposed changes.

@jakevdp
Copy link
Collaborator

jakevdp commented May 22, 2024

Let's stick with ddof=0 as a default. Simplest is best here I think

@vfdev-5 vfdev-5 force-pushed the depr-change-ddof-to-correction-21088 branch from 4e4afbe to 7666284 Compare May 24, 2024 13:32
@vfdev-5
Copy link
Collaborator Author

vfdev-5 commented May 24, 2024

Let's stick with ddof=0 as a default. Simplest is best here I think

Sounds good, if I understand correctly your comment that we keep the following check:

if correction is None:
  correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
  raise ValueError("ddof and correction can't be provided simultaneously.")

and do not change ddof.

I updated the test according to your review

Description:
- Added correction arg in jnp.var and jnp.std
- Addresses jax-ml#21088
- Updated signatures in init.pyi
- Updated tests
@vfdev-5 vfdev-5 force-pushed the depr-change-ddof-to-correction-21088 branch from 7666284 to 55f8284 Compare May 24, 2024 16:16
@copybara-service copybara-service bot merged commit bab7f40 into jax-ml:main May 24, 2024
12 of 14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants