Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[export] Simplify the handling of shardings in Exported. #18761

Merged
merged 1 commit into from
Dec 2, 2023

Conversation

gnecula
Copy link
Collaborator

@gnecula gnecula commented Dec 1, 2023

Previously, Exported contained tuples of XlaCompatibleSharding for the input and output shardings. These shardings contain references to JAX devices, which is too much for exporting purposes and in fact it gets in the way when we want to serialize the Exported.

We change Exported to carry xla_client.HloSharding instead, which conveniently can be serialized to proto. We use the value None to denote an unspecified sharding. We also add nr_devices and then for exporting purposes we can construct actual XlaCompatibleSharding when we need to.

@gnecula gnecula self-assigned this Dec 1, 2023
@gnecula gnecula requested a review from yashk2810 December 1, 2023 10:32
@gnecula gnecula added the pull ready Ready for copybara import and testing label Dec 1, 2023
@gnecula gnecula force-pushed the export_sharding branch 4 times, most recently from 6528dca to ea1d3b4 Compare December 1, 2023 15:27
if apply_jit:
# Prepare a device assignment. For exporting purposes, all it matters
# is the number of devices.
device_assignment = jax.devices(jax.default_backend())[:nr_devices]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is correct. device_assignment is only always jax.devices(). The devices can be laid out in any order which is usually decided by mesh_utils.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which part is not correct?

I do not understand

. device_assignment is only always jax.devices()

For lowering purposes it does not matter what is the device assigment, e.g., what devices are in what order, because the HLO contains only logical device ids, from 0 to len(device_assignment) - 1. When you compile and run, you need the actual mapping of these logical devices, but not when you lower. This is why I think it is sufficient to work with nr_devices and the HloSharding protos for export purposes.

The code here is used in some rare circumstances, after you lowered a function this gives you a hook to lower the vjp, which will call export(pjit(vjp(f))). For the call to pjit we need actual XlaCompatibleSharding, for which we need a device assignment. But we are only going to lower this code now, we won't compile it, or run it, so it does not matter which devices are in the list.

In export.call_exported we actually use an Exported and there we will use an actual device assignment for the embedding JAX computation, from ModuleContext.axis_context.device_assignment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh, sorry I miswrote: I wanted to say: device_assignment is not always jax.devices().

Okay, if this is lowering time, then it's fine. I didn't know this is happening during lowering.

But one thing I would do is test this!

Try with something like: (note that you need to try this on TPU to actually get mesh_utils to return a different order than jax.devices())

devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, ('x', 'y')
NamedSharding(mesh, P(None, 'y'))  # use this sharding and check if the HloSharding is the correct one as you expect.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I shared a colab showing that a function results in the same exact lowering with two different device assignments.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Maybe add that as a test too?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really think that a test is warranted here, the fact that the StableHLO depends only on the HloSharding and not the actual device assignment is a property to HLO, not something that the export code ensures.

Previously, Exported contained tuples of `XlaCompatibleSharding`
for the input and output shardings. These shardings contain references
to JAX devices, which is too much for exporting purposes and in fact
it gets in the way when we want to serialize the Exported.

We change Exported to carry `xla_client.HloSharding` instead, which
conveniently can be serialized to proto. We use the value `None` to
denote an unspecified sharding. We also add `nr_devices`
and then for exporting purposes we can construct actual
`XlaCompatibleSharding` when we need to.
@copybara-service copybara-service bot merged commit b51b80e into jax-ml:main Dec 2, 2023
6 checks passed
@gnecula gnecula deleted the export_sharding branch December 2, 2023 18:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants