
[core][distributed] fix custom allreduce in pytorch 2.5 #9815

Merged
4 commits merged into vllm-project:main on Oct 30, 2024

Conversation

youkaichao (Member) commented Oct 29, 2024

Fixes #9774.

PyTorch 2.5 changed the binary format of the IPC handle.
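A minimal sketch (not part of this PR) of how one might observe the change, assuming the private `_share_cuda_()` storage API that the custom allreduce path relies on; the handle's byte length differs between PyTorch 2.4 and 2.5:

```python
# Illustrative only: inspect the raw CUDA IPC handle returned by PyTorch's
# private storage-sharing API. The handle's length is what changed in 2.5.
import torch

t = torch.empty(1024, device="cuda")
# _share_cuda_() is a private API; its tuple layout may vary across versions.
meta = t.untyped_storage()._share_cuda_()
handle = meta[1]  # raw IPC handle bytes (index assumed from current layouts)
print(f"torch {torch.__version__}: IPC handle is {len(handle)} bytes")
```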

Signed-off-by: youkaichao <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run the other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

youkaichao (Member, Author) commented Oct 29, 2024

cc @hanzhi713

Can you please improve the code if you have time? Ideally we should use PyTorch's user-facing API rather than this private API.

Example of the user-facing API:

```python
# producer:
import torch
from torch.multiprocessing.reductions import reduce_tensor

inp = torch.randn(5, 5).cuda()
out = reduce_tensor(inp)
# send `out` to the consumer

# consumer:
func = out[0]
tensor = func(*out[1])
```

This way, we won't depend on so many PyTorch internal details.

And once we share the tensor from the Python side, the C++ code will be much simpler, and we can also benefit from expandable segments in the future.
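For illustration, a hedged sketch of what that simplification could look like from the Python side; `custom_ar_ext.register_buffer` is a hypothetical binding, not vLLM's actual API:

```python
# Illustrative sketch: once the tensor has been rebuilt in the consumer via
# reduce_tensor() above, the C++ extension only needs a raw device pointer;
# all IPC plumbing stays in Python.
import torch

def register_peer_buffer(custom_ar_ext, shared: torch.Tensor, rank: int) -> None:
    # custom_ar_ext.register_buffer is hypothetical; the point is that it takes
    # a plain pointer and size and never parses IPC handles itself.
    nbytes = shared.numel() * shared.element_size()
    custom_ar_ext.register_buffer(rank, shared.data_ptr(), nbytes)
```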

cedonley commented

FYI - I can confirm this fixes the issue with TP=2 on NVLink A6000s that was introduced with the upgrade to PyTorch 2.5. Nice catch on the handle format change. I thought I was going crazy, but I hadn't noticed the 2-byte change in the length of the handle itself when I compared the data going into the IPC functions in the two versions.

youkaichao (Member, Author) commented

@cedonley Thanks for your report and investigation!

youkaichao merged commit 1ab6f6b into vllm-project:main on Oct 30, 2024
29 of 31 checks passed
youkaichao deleted the fix_ca branch on October 30, 2024 at 00:06
hanzhi713 (Contributor) commented

@youkaichao There's no user-facing API for getting a shareable handle. To avoid using internal PyTorch APIs, I think we can just call cudaIpcGetMemHandle directly on PyTorch-allocated tensors, like here:

`CUDACHECK(cudaIpcGetMemHandle(`

The downside would be that we lose PyTorch's safeguards against leaks, but I think that might not be a problem, since allocations in custom allreduce are one-time allocations.

hanzhi713 (Contributor) commented

I will have some time this weekend to work on this.

youkaichao (Member, Author) commented

Do you really need a handle?

We can get an IPC-shared tensor directly:

```python
# producer:
import torch
from torch.multiprocessing.reductions import reduce_tensor

inp = torch.randn(5, 5).cuda()
out = reduce_tensor(inp)
# send `out` to the consumer

# consumer:
func = out[0]
tensor = func(*out[1])
```

youkaichao (Member, Author) commented

@hanzhi713 You're welcome to join our new Slack at https://slack.vllm.ai for chatting and collaborating!

hanzhi713 (Contributor) commented Oct 30, 2024

How do you share such a handle with other processes? I think vLLM still runs one process per GPU, right?

Is it just sending pickled data through torch.distributed? IMHO torch.multiprocessing is designed for sharing handles through multiprocessing.Process. It may work with generic processes, but I'm not sure if there are any caveats.

youkaichao (Member, Author) commented

> Is it just sending pickled data through torch.distributed?

Yes.

> IMHO torch.multiprocessing is designed for sharing handles through multiprocessing.Process

It applies to general processes, no matter how the process is launched.
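For context, a minimal sketch of how that exchange might look with torch.distributed's object collectives (process-group setup elided; this is not the exact code in this PR):

```python
# Illustrative sketch: rank 0 produces the shareable metadata, every other rank
# rebuilds the same CUDA allocation via IPC. Assumes dist.init_process_group()
# has already been called and each rank has set its CUDA device.
import torch
import torch.distributed as dist
from torch.multiprocessing.reductions import reduce_tensor

if dist.get_rank() == 0:
    inp = torch.randn(5, 5, device="cuda")
    obj = [reduce_tensor(inp)]   # (rebuild_fn, args), picklable
else:
    obj = [None]

# broadcast_object_list pickles the object on rank 0 and sends it to all ranks
dist.broadcast_object_list(obj, src=0)

if dist.get_rank() != 0:
    rebuild_fn, args = obj[0]
    shared = rebuild_fn(*args)   # maps rank 0's allocation into this process
```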

Successfully merging this pull request may close: [Bug]: [help wanted] MoE + TP + custom allreduce bug (#9774)