API for sharded linear. #3125

Open
wujingyue (Collaborator) opened this issue on Oct 7, 2024 · 0 comments

This feature request is to create a sharded drop-in replacement for https://pytorch.org/docs/stable/generated/torch.nn.functional.linear.html.

A linear layer can be sharded in several ways, for example (a single-process sketch of case 1 follows the list):

  1. [b, s, DIDx{4h}] = linear([b, s, h], [DIDx{4h}, h], [DIDx{4h}])
  2. [b, s, h] = linear([b, s, DIDx{4h}], [h, DIDx{4h}], [h])
  3. With data parallelism, b is expected to be DIDy-parallel in addition to the sharded hidden dimension in cases 1 and 2.
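
For reference, a minimal single-process sketch of case 1 in plain PyTorch; shards are simulated with `torch.chunk`, and the sizes `b, s, h` and device count `d = 4` are illustrative assumptions, not part of any nvFuser API:

```python
# Case 1: weight and bias sharded on the 4h (output) dimension; each "device"
# produces a [b, s, 4h/d] slice of the output and no communication is needed.
# Single-process illustration only: shards are simulated with torch.chunk.
import torch
import torch.nn.functional as F

b, s, h, d = 2, 8, 16, 4                 # illustrative sizes; d = number of devices
x = torch.randn(b, s, h)                 # replicated input
w = torch.randn(4 * h, h)                # unsharded weight in torch's [out, in] layout
bias = torch.randn(4 * h)
ref = F.linear(x, w, bias)               # [b, s, 4h]

outs = [F.linear(x, w_i, b_i)            # per-device local linear
        for w_i, b_i in zip(w.chunk(d, dim=0), bias.chunk(d, dim=0))]
assert torch.allclose(torch.cat(outs, dim=-1), ref, atol=1e-4)
```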

Due to #2563, we have to manually split the device dimension in the logical domain instead of expressing it as logical-to-loop transforms. This has prevented us from having a true drop-in replacement: for example, the weight has to be 3D, whereas torch.nn.functional.linear takes a 2D weight.
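
To make the shape mismatch concrete, a hypothetical illustration (d = 4 devices is an assumption):

```python
# torch.nn.functional.linear takes a 2D weight [4h, h]; because of #2563 the
# sharded definition needs the device dimension split out in the logical
# domain, i.e. a 3D weight [d, 4h/d, h].
import torch

h, d = 16, 4
w_torch = torch.randn(4 * h, h)               # 2D: what torch.linear expects
w_sharded = w_torch.view(d, 4 * h // d, h)    # 3D: device dim made explicit
assert w_sharded[1].equal(w_torch[4 * h // d : 2 * (4 * h // d)])
```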

#3073 is an attempt to support case 1 within the limitation of #2563. I did case 1 first because having it fused turns out to be important for performance.

The other cases are yet to be done, and should probably wait for #2563 to avoid accumulating too much tech debt. Note that case 2 in particular will also require decomposing a sharded linear into matmul + collective + biasadd, as sketched below.
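
A minimal single-process sketch of that case-2 decomposition, again with shards simulated by `torch.chunk` and a stacked sum standing in for the all-reduce collective (sizes are illustrative assumptions):

```python
# Case 2: input and weight sharded on the 4h (contraction) dimension.
# Each device computes a partial [b, s, h] matmul; summing the partials is the
# collective (all-reduce), and the bias is added once afterwards.
import torch
import torch.nn.functional as F

b, s, h, d = 2, 8, 16, 4
y = torch.randn(b, s, 4 * h)             # input sharded on its last dimension
w = torch.randn(h, 4 * h)                # weight in torch's [out, in] layout
bias = torch.randn(h)
ref = F.linear(y, w, bias)               # [b, s, h]

partials = [y_i @ w_i.t()                # local matmul on each shard
            for y_i, w_i in zip(y.chunk(d, dim=-1), w.chunk(d, dim=1))]
out = torch.stack(partials).sum(dim=0) + bias  # all-reduce + biasadd stand-in
assert torch.allclose(out, ref, atol=1e-4)
```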
