-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for device kwarg in astype, and add matching utility func #21086
base: main
Are you sure you want to change the base?
Conversation
593f01c
to
5445230
Compare
@yashk2810 Sorry for the delay in getting this message out. I've updated the PR with simplifications to the Is there already a mechanism to predict when two shards are equivalent in the sense that calling cc: @jakevdp if you have any insights regarding how we want to handle |
The easiest thing would just be to error if |
f0f52a5
to
60eaabc
Compare
sorry, I don't understand what this means. Also, why do you need such a mechanism? |
I was looking for a way to tell if |
device_put should do that for you :) |
Yash - the issue is that the semantics of |
Ok, then why not error instead of doing the no-op logic here? If you want to transfer or have it be a no-op should be device_put's job. If device is not None and copy=True, then we should error since it makes no sense right? |
That's what I suggested above: the most conservative thing would be to error if |
We can abstract away a function which can determine that which we can share here but device_put is complex so let's error for now and file a bug against me to give you such a function which you can call here. |
e3fdebb
to
33d2df7
Compare
@@ -37,6 +37,9 @@ Remember to align the itemized text with the first line of an item within a list | |||
|
|||
## jax 0.4.28 (May 9, 2024) | |||
|
|||
* New Functionality | |||
* {func}`jax.numpy.astype` supports a new `device` keyword argument. | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Move up to 0.4.29
Note that JAX may not always strictly adhere to array API device semantics when | ||
using ``jax.jit``. In particular, specifying the ``device`` argument is | ||
equivalent to calling ``jax.device_put(x, device)``. For up-to-date details on | ||
device placement, see the documentation of ``jax.device_put`` for more details. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually references go at the bottom of the docstring. You can put this paragraph above the .. _Python...
line
|
||
This utility uses `jax.device_put` for device placement. | ||
""" | ||
out = x |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why out=x?
) | ||
out = api.device_put(out, device) | ||
|
||
# TODO(micky774): Avoid copy if data has already been copied via device |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This todo doesn't make sense? In this branch the device is None. Remove the todo?
device = jax.devices("cpu")[0] | ||
expected_sharding = SingleDeviceSharding(device) | ||
else: | ||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we use (2, 2) mesh to coverage across more hardware?
expected_sharding = SingleDeviceSharding(device) | ||
else: | ||
global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) | ||
device = jax.sharding.NamedSharding(global_mesh, P('x', 'y')) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be nice to try other PartitionSpecs too (can you parameterize the test on that too)
Towards #20200
This PR adds a private device-placement utility
jax._src.numpy.util._place_array
to manage array API compliant array placement behavior for use in thejax.numpy
namespace. Copies are mediated by thelax._array_copy
utility, while device transfer is performed viaapi.device_put
.cc: @jakevdp