Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for copy kwarg in astype to match Array API #20195

Merged
merged 1 commit into from
Apr 23, 2024

Conversation

Micky774
Copy link
Collaborator

@Micky774 Micky774 commented Mar 12, 2024

Towards #20200

cf. data-apis/array-api#665

Note

This PR includes a placeholder utility function for managing device transfer, which will be developed in a follow-up PR

Default behavior is preserved when copy=None

@jakevdp
Copy link
Collaborator

jakevdp commented Mar 12, 2024

Suggestion: let's implement these changes in jax.numpy.astype here: https://github.com/google/jax/blob/f8af4ef81646a29ce17942136b5f9cfaf86a225f/jax/_src/numpy/lax_numpy.py#L2212

Then we can just call jax.numpy.astype in the current location.

Also we could add some test coverage in lax_numpy_test.py

@Micky774 Micky774 marked this pull request as draft March 18, 2024 12:23
@Micky774 Micky774 force-pushed the array_api_astype branch 2 times, most recently from 2395e76 to 46e1060 Compare March 20, 2024 19:12
@Micky774 Micky774 marked this pull request as ready for review March 20, 2024 19:13
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
@jakevdp
Copy link
Collaborator

jakevdp commented Apr 11, 2024

One small suggestion: it may be helpful to separate the copy change from the device change, since the way forward on the former is clear, and the way forward on the latter is less certain. Do you want to do a separate PR changing the semantics of the copy parameter in astype?

@Micky774
Copy link
Collaborator Author

One small suggestion: it may be helpful to separate the copy change from the device change, since the way forward on the former is clear, and the way forward on the latter is less certain. Do you want to do a separate PR changing the semantics of the copy parameter in astype?

Agreed. For now I'm just factoring out the device semantics as mentioned here and leaving only the changes to copy so I'll update the PR title accordingly. I'm hoping eventually we can have a common utility for general memory management for both device/copy semantics.

@Micky774 Micky774 changed the title Add support for device kwarg in astype to match Array API Add support for copy kwarg in astype to match Array API Apr 11, 2024
@Micky774
Copy link
Collaborator Author

I've created this PR in my fork to track discussion regarding device movement semantics using jnp.astype as an example. Now this PR passes through the device kwarg until it is ignored in the _place_array utility, which will eventually house the genuine device moment functionality.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Apr 15, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented Apr 15, 2024

Oh, it looks like this branch has conflicts that need to be resolved before we can merge.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I'm a bit worried about this change. In a lot of places we implicitly depend on x.astype(x.dtype) being a no-op, and now that the default semantics are copy=True, I think this assumption will no longer hold.

Maybe we should default to copy=None so that the default is backward-compatible?

NumPy has made a similar exception for copy in astype: https://numpy.org/neps/nep-0056-array-api-main-namespace.html#copy-keyword-semantics

@Micky774
Copy link
Collaborator Author

Micky774 commented Apr 15, 2024

If it is causing that many problems, then I agree going the gentler deprecation route as opposed to the "bug-fix" route should be preferred.

With that being said, I'm suspecting that the array API standard for astype is actually "mistaken" about their copy semantics. The original array API specification was based on NumPy's -- and other contemporary libraries' -- implementations and copy semantics. For reference, I've opened an issue discussing updating the copy semantics to be made similar to the other functions/methods: data-apis/array-api#788.

Maybe we wait to see what the community consensus on the proposal is, that way we can tune the numpy namespace to the False/None/True semantics if its accepted, and then we can stick to the "old" semantics in the array_api namespace just for technical compliance during the short term.

@jakevdp
Copy link
Collaborator

jakevdp commented Apr 15, 2024

It seems like the array API will be sticking with copy=False/True.

I think what we should do here is follow NumPy and the Array API on the semantics of the copy argument for astype, except we should default to copy=False (which behaves like copy=None in other contexts).

In JAX, there's almost never a reason to force a copy, because arrays are immutible and copies are expensive. In NumPy, it's the opposite: you almost never want to silently return a view, because arrays are mutable.

@Micky774
Copy link
Collaborator Author

Micky774 commented Apr 19, 2024

@jakevdp Updated w/ a default of copy=False

@jakevdp
Copy link
Collaborator

jakevdp commented Apr 19, 2024

Thanks - I'm still a bit confused about the state of things upstream. It sounds like NumPy does support copy in {True, False, None} for this function? Is that your understanding too?

@Micky774
Copy link
Collaborator Author

Thanks - I'm still a bit confused about the state of things upstream. It sounds like NumPy does support copy in {True, False, None} for this function? Is that your understanding too?

That seems not to be the case in NumPy 1.26, but is indeed true in NumPy 2.0.0rc1 with the caveat that copy=None is equivalent to copy=False. This is the same functionality that is enabled in this PR, except we do not publicly document the fact that copy=None is a valid choice. FWIW, I prefer not documenting it since I think having two different-yet-equivalent values for a keyword argument makes for a confusing API.

I'm personally hoping we see a change in NumPy and the array API to converge the astype copy semantics to be equivalent to the other copy-enabled functions/methods, but that probably won't be too soon, if it happens at all.

jax/_src/numpy/array_methods.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
tests/lax_numpy_test.py Outdated Show resolved Hide resolved
jax/numpy/__init__.pyi Outdated Show resolved Hide resolved
jax/numpy/__init__.pyi Outdated Show resolved Hide resolved
@copybara-service copybara-service bot merged commit 493698e into jax-ml:main Apr 23, 2024
14 checks passed
@Micky774 Micky774 deleted the array_api_astype branch April 23, 2024 14:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants