[Core] separate distributed_init from worker #3904

Merged: 13 commits into vllm-project:main on Apr 9, 2024

Conversation

youkaichao
Member

Currently, distributed_init is coupled with the worker: it requires a parallel_config and only works with the GPU backend (nccl).

This PR is the first step in refactoring distributed_init to make it general (it only depends on the CPU and does not need any vLLM-internal types).

This way, distributed_init can be used standalone, e.g. in tests or on devices other than GPUs. As a demonstration, we also use it in the CPU backend.

Going further, I plan to move vllm.model_executor.parallel_utils to vllm.parallel_utils, because it should have nothing to do with model_executor.

@youkaichao youkaichao requested a review from zhuohan123 April 7, 2024 23:20
@youkaichao
Member Author

With this refactor, we have the following process groups (a usage sketch follows the list):

  • global process group, with the gloo backend, usable in any situation; initialized by vllm.model_executor.parallel_utils.parallel_state.init_distributed_environment
  • tensor parallel process group, with a user-specified backend (default nccl), used for tensor-parallel communication; initialized by vllm.model_executor.parallel_utils.parallel_state.ensure_model_parallel_initialized with a backend option
  • pipeline parallel process group, currently initialized together with the tensor parallel process group
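
As a rough illustration (not code from this PR), a test or a CPU-only backend could bring these groups up along the following lines; the exact argument names and defaults below are assumptions for illustration only:

# sketch.py (illustration only; argument names are assumptions)
from vllm.model_executor.parallel_utils.parallel_state import (
    init_distributed_environment, ensure_model_parallel_initialized)

# Global process group: gloo backend, CPU-only, so it works on any device.
init_distributed_environment(world_size=1, rank=0,
                             distributed_init_method="tcp://127.0.0.1:29500")

# Tensor/pipeline parallel groups: backend is user-specified (default nccl);
# gloo is chosen here so the sketch also runs on a CPU-only machine.
ensure_model_parallel_initialized(tensor_model_parallel_size=1,
                                  pipeline_model_parallel_size=1,
                                  backend="gloo")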

@youkaichao youkaichao requested a review from WoosukKwon April 8, 2024 01:39
@zhuohan123 zhuohan123 self-assigned this Apr 8, 2024
@youkaichao
Member Author

What caused confusion here is the advanced usage of torch.distributed.

I asked the PyTorch team and got the following very helpful guidance:

  • Inside every process, dist.init_process_group should be called only once.
  • After that, processes can form new groups via dist.new_group, with ranks or a backend different from those used in dist.init_process_group.
  • The nccl backend needs more care: it requires a proper torch.cuda.set_device(dist.get_rank()) and cannot be used for collectives among processes sharing a single GPU. gloo is more general and more stable, and works for both CPU and GPU tensors, but is slower than nccl in the multi-GPU setting (nccl's designed use case).

Here is a basic example:

# test.py
import torch
import torch.distributed as dist
dist.init_process_group(backend='gloo')
group0 = dist.group.WORLD
data = torch.ones((5, 5, 5))
dist.all_reduce(data)
# prints 4.0: gloo all-reduces the CPU tensor across the 4 processes
print(f"{data.mean().item()}")
if dist.get_rank() in [0, 1]:
    group1 = dist.new_group(ranks=[0, 1], backend="nccl")
    torch.cuda.set_device(dist.get_rank())
else:
    group1 = dist.new_group(ranks=[2, 3], backend="gloo")
data1 = torch.ones((5, 5, 5)).cuda()
# prints 2.0
# for ranks 0 and 1, the allreduce is done by nccl and requires the torch.cuda.set_device call above
# for ranks 2 and 3, the allreduce is done by gloo; no device setup is needed, and both CPU and GPU tensors work
dist.all_reduce(data1, group=group1)
print(f"{data1.mean().cpu().item()}")

Running this with torchrun --nproc-per-node 4 test.py, we get:

4.0
4.0
4.0
4.0
2.0
2.0
2.0
2.0

@youkaichao
Member Author

In fact, even when we call dist.init_process_group(backend='nccl'), the nccl communicator is not created until the first GPU communication happens:

# test.py
import torch
import torch.distributed as dist
dist.init_process_group(backend='nccl')  # no NCCL communicator is created yet
torch.cuda.set_device(dist.get_rank())
a = torch.ones((5, 5, 5)).cuda()
print("before broadcast")
dist.broadcast(a, src=0)  # the NCCL communicator is initialized lazily here
print(f"{a.mean().cpu().item()}")

Run with export NCCL_DEBUG=TRACE; torchrun --nproc-per-node 2 test.py:

--omitted--
before broadcast
before broadcast
--omitted--
flaminio:2341839:2341901 [0] NCCL INFO comm 0x6cc3ce0 rank 0 nranks 2 cudaDev 0 busId 6000 commId 0x82b0333d5875b736 - Init COMPLETE
flaminio:2341840:2341902 [1] NCCL INFO comm 0x85fbb30 rank 1 nranks 2 cudaDev 1 busId 7000 commId 0x82b0333d5875b736 - Init COMPLETE
1.0
1.0
flaminio:2341839:2341907 [0] NCCL INFO [Service thread] Connection closed by localRank 0
flaminio:2341840:2341908 [1] NCCL INFO [Service thread] Connection closed by localRank 1
flaminio:2341839:2341839 [0] NCCL INFO comm 0x6cc3ce0 rank 0 nranks 2 cudaDev 0 busId 6000 - Abort COMPLETE
flaminio:2341840:2341840 [1] NCCL INFO comm 0x85fbb30 rank 1 nranks 2 cudaDev 1 busId 7000 - Abort COMPLETE

@zhuohan123 (Member) left a comment

Please change based on our offline discussion. Specifically, make sure the default torch communication still uses the device-specific communicator; then we can have a separate communication group just for CPU (gloo) communication. A sketch of this layout follows.
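
A minimal sketch of that layout (illustration only, not code from this PR), assuming a GPU machine launched with torchrun: the default group uses the device-specific nccl backend, while an extra gloo group handles CPU-side communication.

# sketch.py (illustration only, not part of the PR)
import torch
import torch.distributed as dist

# Default group: device-specific backend, so a plain dist.all_reduce on GPU
# tensors goes through nccl.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())

# Separate gloo group covering the same ranks, used only for CPU tensors
# (e.g. small control messages or metadata).
cpu_group = dist.new_group(ranks=list(range(dist.get_world_size())),
                           backend="gloo")

gpu_data = torch.ones(4).cuda()
dist.all_reduce(gpu_data)                   # nccl, default group

cpu_data = torch.ones(4)
dist.all_reduce(cpu_data, group=cpu_group)  # gloo, CPU group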

Review threads (outdated, resolved): vllm/model_executor/parallel_utils/communication_op.py, vllm/worker/worker.py
@youkaichao
Member Author

The discussion was inspiring; indeed, the user experience is better if a simple call to torch.distributed.all_reduce uses the device-specific communication backend.

Refactored according to the discussion. @zhuohan123 PTAL.

@youkaichao youkaichao requested a review from zhuohan123 April 8, 2024 23:20
@zhuohan123 (Member) left a comment

LGTM! Left some small comments.

Review threads (outdated, resolved): vllm/model_executor/parallel_utils/parallel_state.py, vllm/worker/worker.py
@youkaichao youkaichao enabled auto-merge (squash) April 9, 2024 08:06
@youkaichao youkaichao merged commit 6d592eb into vllm-project:main Apr 9, 2024
35 checks passed
@youkaichao youkaichao deleted the standalone_distributed_init branch April 9, 2024 15:03