-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[export] Simplify the handling of shardings in Exported. #18761
Conversation
6528dca
to
ea1d3b4
Compare
if apply_jit: | ||
# Prepare a device assignment. For exporting purposes, all it matters | ||
# is the number of devices. | ||
device_assignment = jax.devices(jax.default_backend())[:nr_devices] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think this is correct. device_assignment
is only always jax.devices(). The devices can be laid out in any order which is usually decided by mesh_utils.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Which part is not correct?
I do not understand
. device_assignment is only always jax.devices()
For lowering purposes it does not matter what is the device assigment, e.g., what devices are in what order, because the HLO contains only logical device ids, from 0 to len(device_assignment) - 1
. When you compile and run, you need the actual mapping of these logical devices, but not when you lower. This is why I think it is sufficient to work with nr_devices
and the HloSharding
protos for export purposes.
The code here is used in some rare circumstances, after you lowered a function this gives you a hook to lower the vjp
, which will call export(pjit(vjp(f)))
. For the call to pjit
we need actual XlaCompatibleSharding, for which we need a device assignment. But we are only going to lower this code now, we won't compile it, or run it, so it does not matter which devices are in the list.
In export.call_exported
we actually use an Exported
and there we will use an actual device assignment for the embedding JAX computation, from ModuleContext.axis_context.device_assignment
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ohh, sorry I miswrote: I wanted to say: device_assignment is not always jax.devices()
.
Okay, if this is lowering time, then it's fine. I didn't know this is happening during lowering.
But one thing I would do is test this!
Try with something like: (note that you need to try this on TPU to actually get mesh_utils to return a different order than jax.devices())
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('x', 'y')
NamedSharding(mesh, P(None, 'y')) # use this sharding and check if the HloSharding is the correct one as you expect.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I shared a colab showing that a function results in the same exact lowering with two different device assignments.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! Maybe add that as a test too?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really think that a test is warranted here, the fact that the StableHLO depends only on the HloSharding and not the actual device assignment is a property to HLO, not something that the export code ensures.
Previously, Exported contained tuples of `XlaCompatibleSharding` for the input and output shardings. These shardings contain references to JAX devices, which is too much for exporting purposes and in fact it gets in the way when we want to serialize the Exported. We change Exported to carry `xla_client.HloSharding` instead, which conveniently can be serialized to proto. We use the value `None` to denote an unspecified sharding. We also add `nr_devices` and then for exporting purposes we can construct actual `XlaCompatibleSharding` when we need to.
ea1d3b4
to
3eb3e2d
Compare
Previously, Exported contained tuples of
XlaCompatibleSharding
for the input and output shardings. These shardings contain references to JAX devices, which is too much for exporting purposes and in fact it gets in the way when we want to serialize the Exported.We change Exported to carry
xla_client.HloSharding
instead, which conveniently can be serialized to proto. We use the valueNone
to denote an unspecified sharding. We also addnr_devices
and then for exporting purposes we can construct actualXlaCompatibleSharding
when we need to.