-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added correction
arg to jnp.var and jnp.std for array-api compliance
#21262
Added correction
arg to jnp.var and jnp.std for array-api compliance
#21262
Conversation
1d25490
to
4dfc76e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good! One minor comment below; also we need to mention this in the CHANGELOG
, and also update jax/numpy/__init__.pyi
to reflect the new type signatures.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great - last thing we need is to squash the changes into a single commit (see https://jax.readthedocs.io/en/latest/contributing.html#single-change-commits-and-pull-requests).
Thanks!
8d42350
to
8f079b3
Compare
Thanks for the review, @jakevdp ! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just needs a minor correction in the changelog
Tests are currently failing because of a signature mismatch between the new |
@Micky774 yes, right now removing "correction" from unmatched list I have this:
it means that I should reorder the incorrect order I've done in jnp: - def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
- out: None = None, correction: int | float = 0, keepdims: bool = False, *,
- where: ArrayLike | None = None, ddof: int | DeprecatedArg = DeprecatedArg()) -> Array:
+ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
+ out: None = None, ddof: int | DeprecatedArg = DeprecatedArg(), keepdims: bool = False, *,
+ where: ArrayLike | None = None, correction: int | float = 0) -> Array: However, given that numpy does not deprecate ddof arg: https://numpy.org/devdocs/reference/generated/numpy.std.html |
Ah good point, it seems that they do intend to maintain both as valid API (discussion) so we ought to do the same. It will still be fully array API compliant, so indeed let's keep |
8f079b3
to
83b06b2
Compare
correction
arg to jnp.var and jnp.std for array-api compliance
d878da3
to
20b87e1
Compare
20b87e1
to
4e4afbe
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One minor suggestion below.
Also, regarding the default value quesion: a possibly brilliant, possibly terrible idea. What if we do this:
class _int(int): pass
_zero = _int(0)
def var(..., ddof: int = _zero,...):
if correction is None:
correction = int(ddof) if isinstance(ddof, _int) else ddof
elif ddof is zero:
raise ValueError(...)
I can't think of any reasons why this wouldn't work, can you?
Yes, this approach works, but looks like a hack. By the way, numpy in this case does not raise the error:
I wonder where exactly the boundary for type hints is passing between jax.numpy and numpy. For example, instead of introducing
Jake, I'm happy to implement your solution with |
Let's stick with |
4e4afbe
to
7666284
Compare
Sounds good, if I understand correctly your comment that we keep the following check: if correction is None:
correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
raise ValueError("ddof and correction can't be provided simultaneously.") and do not change I updated the test according to your review |
Description: - Added correction arg in jnp.var and jnp.std - Addresses jax-ml#21088 - Updated signatures in init.pyi - Updated tests
7666284
to
55f8284
Compare
Description:
correction
arg to jnp.var and jnp.std for array-api compliance