Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for device and copy kwargs in from_dlpack to match Array API #20175

Merged
merged 1 commit into from
Apr 5, 2024

Conversation

Micky774
Copy link
Collaborator

@Micky774 Micky774 commented Mar 11, 2024

Towards #20200

cf. data-apis/array-api#741

Note

"In principle, arbitrary cross-device copies could be allowed too, but the consensus in data-apis/array-api#626 was that limiting to device-to-host copies is enough for now". This PR includes the optional device-to-device transfer.

Default behavior is preserved when device=None, copy=None

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! A few minor comments

jax/_src/dlpack.py Outdated Show resolved Hide resolved
jax/_src/dlpack.py Show resolved Hide resolved
jax/_src/dlpack.py Outdated Show resolved Hide resolved
jax/_src/dlpack.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/experimental/array_api/_creation_functions.py Outdated Show resolved Hide resolved
jax/numpy/__init__.pyi Outdated Show resolved Hide resolved
@Micky774 Micky774 force-pushed the array_api branch 2 times, most recently from cd8ec29 to 40cb7d8 Compare March 11, 2024 22:15
@jakevdp
Copy link
Collaborator

jakevdp commented Mar 11, 2024

Lint errors might indicate a problem:

jax/_src/dlpack.py:130: error: Argument 2 to "dlpack_managed_tensor_to_buffer" has incompatible type "Client"; expected "Device"  [arg-type]
jax/_src/dlpack.py:130: error: Argument 3 to "dlpack_managed_tensor_to_buffer" has incompatible type "Client | None"; expected "int | None"  [arg-type]

@Micky774
Copy link
Collaborator Author

Lint errors might indicate a problem:

jax/_src/dlpack.py:130: error: Argument 2 to "dlpack_managed_tensor_to_buffer" has incompatible type "Client"; expected "Device"  [arg-type]
jax/_src/dlpack.py:130: error: Argument 3 to "dlpack_managed_tensor_to_buffer" has incompatible type "Client | None"; expected "int | None"  [arg-type]

This is a bug in the xla_extension stub annotations. I'll open a PR in XLA to resolve this. For now, I've added an inline ignore.

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Mar 11, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented Mar 12, 2024

Internal pytype tests are failing with many variations of this error:

File "/jax/_src/third_party/scipy/interpolate.py", line 4, in <module>: Couldn't import pyi for 'jax.numpy' [pyi-error]
  No xla_client.Device in module jax._src.lib, referenced from 'jax.numpy'

@jakevdp
Copy link
Collaborator

jakevdp commented Mar 12, 2024

Can you change your commit message to something more informative? Thanks!

copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 12, 2024
…o_buffer` in `python/xla_extension`

Imported from GitHub PR openxla/xla#10433

Encountered bug in jax-ml/jax#20175 (see this [comment](jax-ml/jax#20175 (comment))).

This adjusts the stub file to properly overload `dlpack_managed_tensor_to_buffer` so that both signatures can be checked against.
Copybara import of the project:

--
75cabb5149b4a0bdd9e819fac0ea6a0ba756bff6 by Meekail Zain <[email protected]>:

Update

Merging this change closes #10433

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10433 from Micky774:type_update 75cabb5149b4a0bdd9e819fac0ea6a0ba756bff6
PiperOrigin-RevId: 615173747
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 14, 2024
…o_buffer` in `python/xla_extension`

Imported from GitHub PR openxla/xla#10433

Encountered bug in jax-ml/jax#20175 (see this [comment](jax-ml/jax#20175 (comment))).

This adjusts the stub file to properly overload `dlpack_managed_tensor_to_buffer` so that both signatures can be checked against.
Copybara import of the project:

--
75cabb5149b4a0bdd9e819fac0ea6a0ba756bff6 by Meekail Zain <[email protected]>:

Update

Merging this change closes #10433

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10433 from Micky774:type_update 75cabb5149b4a0bdd9e819fac0ea6a0ba756bff6
PiperOrigin-RevId: 615173747
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 14, 2024
…o_buffer` in `python/xla_extension`

Imported from GitHub PR openxla/xla#10433

Encountered bug in jax-ml/jax#20175 (see this [comment](jax-ml/jax#20175 (comment))).

This adjusts the stub file to properly overload `dlpack_managed_tensor_to_buffer` so that both signatures can be checked against.
Copybara import of the project:

--
75cabb5149b4a0bdd9e819fac0ea6a0ba756bff6 by Meekail Zain <[email protected]>:

Update

Merging this change closes #10433

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#10433 from Micky774:type_update 75cabb5149b4a0bdd9e819fac0ea6a0ba756bff6
PiperOrigin-RevId: 615173747
copybara-service bot pushed a commit to openxla/xla that referenced this pull request Mar 15, 2024
…o_buffer` in `python/xla_extension`

Imported from GitHub PR #10433

Encountered bug in jax-ml/jax#20175 (see this [comment](jax-ml/jax#20175 (comment))).

This adjusts the stub file to properly overload `dlpack_managed_tensor_to_buffer` so that both signatures can be checked against.
Copybara import of the project:

--
75cabb5 by Meekail Zain <[email protected]>:

Update

Merging this change closes #10433

COPYBARA_INTEGRATE_REVIEW=#10433 from Micky774:type_update 75cabb5
PiperOrigin-RevId: 615973838
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Mar 15, 2024
…o_buffer` in `python/xla_extension`

Imported from GitHub PR openxla/xla#10433

Encountered bug in jax-ml/jax#20175 (see this [comment](jax-ml/jax#20175 (comment))).

This adjusts the stub file to properly overload `dlpack_managed_tensor_to_buffer` so that both signatures can be checked against.
Copybara import of the project:

--
75cabb5149b4a0bdd9e819fac0ea6a0ba756bff6 by Meekail Zain <[email protected]>:

Update

Merging this change closes #10433

PiperOrigin-RevId: 615973838
@Micky774
Copy link
Collaborator Author

Internal pytype tests are failing with many variations of this error:

File "/jax/_src/third_party/scipy/interpolate.py", line 4, in <module>: Couldn't import pyi for 'jax.numpy' [pyi-error]
  No xla_client.Device in module jax._src.lib, referenced from 'jax.numpy'

@jakevdp are the internal tests still failing? If so, I will update the typing to use the _Device type to avoid letting this get stalled.

@jakevdp
Copy link
Collaborator

jakevdp commented Mar 18, 2024

Still failing:

File "third_party/py/jax/__init__.py", line 163, in <module>: Couldn't import pyi for 'jax.numpy' [pyi-error]
  Can't find pyi for 'jaxlib.xla_client', referenced from 'jax.numpy'

I think the issue is that trying to depend on xla_client doesn't have an interface file. It's why we use Device = Any in other similar locations, e.g. here: https://github.com/google/jax/blob/aaeeaf5f0caa497d2f6e33d995cdd88a07ee523a/jax/numpy/__init__.pyi#L21-L22
and here:
https://github.com/google/jax/blob/aaeeaf5f0caa497d2f6e33d995cdd88a07ee523a/jax/_src/basearray.pyi#L22-L23

Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good - could we add some test coverage for the new arguments in array_interoperability_test.py?

@jakevdp jakevdp self-assigned this Mar 18, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented Apr 4, 2024

This looks good, and is probably ready to merge more-or-less. However there's a subtlety here that I've been thinking about with @yashk2810 – the issue is that the behavior of device=XXX under jit and other transformations is kind of ambiguous: currently device_put is a no-op in this context, which means that this function will silently ignore the device. I think in the short-term we'd prefer to make that an error: from_dlpack should basically fail within any transformation, becuase its semantics are impure: it's reading an external buffer that's not tracked by JAX's normal tracing mechanisms, so e.g. if the buffer changes between the first and second call to the function, cacheing semantics may lead to incorrect outputs.

All of this is somewhat second-order though, so we should probably merge this change and iterate from there.

tests/array_interoperability_test.py Outdated Show resolved Hide resolved
@copybara-service copybara-service bot merged commit f37e503 into jax-ml:main Apr 5, 2024
12 of 14 checks passed
@Micky774 Micky774 deleted the array_api branch April 7, 2024 23:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants