[Core] separate distributed_init from worker #3904
Conversation
With this refactor, we have the following process group:
What caused confusion here is the advanced usage of `dist.new_group`. I asked the PyTorch team and got the following very helpful message.
Here is a basic example:

```python
# test.py
import torch
import torch.distributed as dist

dist.init_process_group(backend='gloo')
group0 = dist.group.WORLD
data = torch.ones((5, 5, 5))
dist.all_reduce(data)
# prints 4: gloo all-reduces the CPU tensors across the 4 processes
print(f"{data.mean().item()}")

if dist.get_rank() in [0, 1]:
    group1 = dist.new_group(ranks=[0, 1], backend="nccl")
    torch.cuda.set_device(dist.get_rank())
else:
    group1 = dist.new_group(ranks=[2, 3], backend="gloo")
data1 = torch.ones((5, 5, 5)).cuda()
# prints 2
# for ranks 0 and 1, the all-reduce is done by nccl, and requires setting `torch.cuda.set_device`
# for ranks 2 and 3, the all-reduce is done by gloo; it does not require setting anything,
# and both CPU and GPU tensors are fine
dist.all_reduce(data1, group=group1)
print(f"{data1.mean().cpu().item()}")
```

Run this with 4 processes, e.g. `torchrun --nproc_per_node=4 test.py`.
In fact, even when we call `dist.init_process_group(backend='nccl')`, the NCCL communicator is only created lazily, when the first communication call happens:

```python
# test.py
import torch
import torch.distributed as dist

dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank())
a = torch.ones((5, 5, 5)).cuda()
print("before broadcast")
dist.broadcast(a, src=0)
print(f"{a.mean().cpu().item()}")
```

Run with the same `torchrun` command as above.
Please change based on our offline discussion. More specifically, make sure the default torch communication still uses the device-specific communicator. Then, we can have a separate communication group just for CPU (gloo) communication.
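A minimal sketch of that layering, assuming a torchrun-style launch (the structure below is illustrative, not the PR's actual code):

```python
# Default process group uses the device-specific backend (nccl) for GPU
# tensors; an extra gloo group handles CPU-side communication.
import torch
import torch.distributed as dist

# Assumes torchrun sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# Separate group for CPU tensors; gloo does not need a GPU at all.
cpu_group = dist.new_group(ranks=list(range(dist.get_world_size())),
                           backend="gloo")

# GPU collectives go through the default (nccl) group ...
gpu_data = torch.ones(4, device="cuda")
dist.all_reduce(gpu_data)

# ... while CPU-only metadata can use the gloo group.
cpu_data = torch.tensor([dist.get_rank()])
dist.all_reduce(cpu_data, group=cpu_group)
```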
The discussion is inspiring; indeed, the user experience would be better if a simple call to the default `torch.distributed` collectives just works. Refactored according to the discussion, @zhuohan123 PTAL.
LGTM! Left some small comments.
Currently, `distributed_init` is coupled with the worker, e.g. it requires a `parallel_config` and only works with the GPU backend (NCCL).
This PR is the first step in refactoring `distributed_init` to make it general (it only depends on CPU and does not need any vLLM-internal types).
This way, `distributed_init` can be used in a standalone way, e.g. in tests, on devices other than GPU, etc. As a demonstration, we also use it in the CPU backend.
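As a rough illustration of the decoupling (the function name and signature here are hypothetical, not the PR's exact API), such an init only needs plain values instead of a `parallel_config`:

```python
# Hypothetical sketch of a worker-independent init: it takes plain values
# (no vLLM-internal types) and picks a backend per device type.
import torch
import torch.distributed as dist

def init_distributed_environment(world_size: int, rank: int,
                                 init_method: str,
                                 backend: str = "nccl") -> None:
    # "gloo" covers CPU-only runs; "nccl" covers GPU runs.
    dist.init_process_group(backend=backend,
                            init_method=init_method,
                            world_size=world_size,
                            rank=rank)
    if backend == "nccl":
        # Single-node assumption: rank doubles as the local GPU index.
        torch.cuda.set_device(rank)

# Standalone usage, e.g. in a single-process test on the CPU backend:
# init_distributed_environment(world_size=1, rank=0,
#                              init_method="tcp://127.0.0.1:29500",
#                              backend="gloo")
```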
Going further, I plan to move `vllm.model_executor.parallel_utils` to `vllm.parallel_utils`, because it should have nothing to do with `model_executor`.