-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for max_version, dl_device, copy
kwargs in __dlpack__
to match Array API
#20198
Conversation
35543cf
to
b4b1164
Compare
b4b1164
to
bb27b92
Compare
Minor comment, but in general it's useful to have meaningful commit messages ("Update" doesn't communicate much about what the change includes). |
bb27b92
to
8235852
Compare
@yashk2810 @jakevdp This PR should be ready for review again. I've added tests and updated the documentation of |
4a85f4a
to
65be6b3
Compare
65be6b3
to
1e4cba7
Compare
Tests are failing because (jax2tf tests are not run in github CI because they are too expensive). |
1e4cba7
to
a2feff2
Compare
@jakevdp Done -- I've also simplified the tests a bit and refactored |
Towards #20200
cf. data-apis/array-api#741, data-apis/array-api#602, https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
Note
"In principle, arbitrary cross-device copies could be allowed too, but the consensus in data-apis/array-api#626 was that limiting to device-to-host copies is enough for now". This PR includes the optional device-to-device transfer.
Default behavior is preserved when
max_version=None, dl_device=None, copy=None,
This PR also adds new versioning information to
jax/_src/dlpack.py
: