
[Usage]: OpenRLHF: How can I create a second NCCL Group in a vLLM v0.4.3+ Ray worker? #5477

Open
hijkzzz opened this issue Jun 13, 2024 · 8 comments
Labels
usage How to use vllm

Comments

hijkzzz commented Jun 13, 2024

Your current environment

We are working on accelerating RLHF algorithms and need to broadcast the weights of the DeepSpeed engine to the vLLM Ray worker. In v0.4.2, we were able to create an additional NCCL group to achieve this. However, after updating to v0.4.3 and incorporating the changes from this MR, we found that doing so causes NCCL errors during broadcast.

Our weight synchronization code is located at: https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/vllm_engine.py.
and
https://github.com/OpenLLMAI/OpenRLHF/blob/main/openrlhf/trainer/ray/vllm_worker_wrap.py

See init_process_group (which builds the NCCL group between vLLM and DeepSpeed, stored as self._model_update_group)

and update_weight (which broadcasts the weights from DeepSpeed to vLLM via torch.distributed.broadcast(weight, 0, group=self._model_update_group)).
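
For reference, the worker-side half of this exchange looks roughly like the sketch below. This is a simplified reconstruction, not the exact code in the linked vllm_worker_wrap.py; self.model_runner.model.load_weights follows the vLLM 0.4.x model API, and the argument names are illustrative.

```python
import torch
import torch.distributed as dist

class WorkerWrap:
    def update_weight(self, name: str, dtype: torch.dtype, shape: tuple):
        # Receive the updated parameter from DeepSpeed rank 0 over the extra
        # NCCL group created in init_process_group.
        weight = torch.empty(shape, dtype=dtype, device="cuda")
        dist.broadcast(weight, 0, group=self._model_update_group)
        # Load the received tensor into the vLLM model.
        self.model_runner.model.load_weights(weights=[(name, weight)])
        del weight
```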

We temporarily replaced the NCCL backend with Gloo to make it work, but the performance was poor.

The error message is:

(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148] Error executing method start_worker_execution_loop. This might cause deadlock in distributed execution.
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148] Traceback (most recent call last):
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 140, in execute_method
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     return executor(*args, **kwargs)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     return func(*args, **kwargs)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/vllm/worker/worker.py", line 286, in start_worker_execution_loop
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     while self._execute_model_non_driver():
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/vllm/worker/worker.py", line 295, in _execute_model_non_driver
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     data = broadcast_tensor_dict(src=0)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/vllm/distributed/communication_op.py", line 284, in broadcast_tensor_dict
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     torch.distributed.broadcast_object_list(recv_metadata_list,
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     return func(*args, **kwargs)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2649, in broadcast_object_list
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     broadcast(object_sizes_tensor, src=src, group=group)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     return func(*args, **kwargs)
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]   File "/opt/conda/envs/ptca/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2144, in broadcast
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148]     work.wait()
(RayWorkerWrapper pid=4183, ip=10.3.32.122) ERROR 06-03 23:13:39 worker_base.py:148] RuntimeError: [../third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:81] Timed out waiting 1800000ms for recv operation to complete

Even if we call self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop() before the broadcast, another NCCL error still occurs:

(LLMRayActor pid=116812) Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] 1/-1/-1->0->-1 [3] -1/-1/-1->0->1 [4] -1/-1/-1->0->1 [5] -1/-1/-1->0->1 [6] 1/-1/-1->0->-1 [7] 1/-1/-1->0->-1 [8] 1/-1/-1->0->-1 [9] -1/-1/-1->0->1 [10] -1/-1/-1->0->1 [11] -1/-1/-1->0->1
(LLMRayActor pid=116812) a5fa65866c9c:116812:120170 [0] proxy.cc:1336 NCCL WARN Cuda failure 1 'invalid argument'
(LLMRayActor pid=116812) a5fa65866c9c:116812:120158 [0] transport/p2p.cc:272 NCCL WARN Cuda failure 'invalid argument'
(LLMRayActor pid=116812) a5fa65866c9c:116812:120158 [0] NCCL INFO transport/p2p.cc:327 -> 1
(LLMRayActor pid=116812) a5fa65866c9c:116812:120158 [0] NCCL INFO transport/p2p.cc:507 -> 1
(LLMRayActor pid=116812) a5fa65866c9c:116812:120158 [0] NCCL INFO transport.cc:183 -> 1
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] Error executing method update_weight. This might cause deadlock in distributed execution.
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] Traceback (most recent call last):
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148]   File "/home/jianh/.local/lib/python3.10/site-packages/vllm/worker/worker_base.py", line 140, in execute_method
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148]     return executor(*args, **kwargs)
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148]   File "/tmp/ray/session_2024-06-13_13-16-35_468561_107280/runtime_resources/working_dir_files/_ray_pkg_d1835c417c453aec/openrlhf/trainer/ray/vllm_worker_wrap.py", line 39, in update_weight
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148]     torch.distributed.broadcast(weight, 0, group=self._model_update_group)
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/c10d_logger.py", line 75, in wrapper
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148]     return func(*args, **kwargs)
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148]   File "/usr/local/lib/python3.10/dist-packages/torch/distributed/distributed_c10d.py", line 2140, in broadcast
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148]     work = group.broadcast([tensor], opts)
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] torch.distributed.DistBackendError: NCCL error in: ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1970, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.20.5
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] ncclUnhandledCudaError: Call to CUDA function failed.
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] Last error:
(LLMRayActor pid=116812) ERROR 06-13 13:24:49 worker_base.py:148] Cuda failure 'invalid argument'
(RayWorkerWrapper pid=117813) a5fa65866c9c:117813:120165 [1] NCCL INF
(RayWorkerWrapper pid=117839) ERROR 06-13 13:24:49 worker_base.py:148] Error executing method update_weight. This might cause deadlock in distributed execution.

I think our call to torch.distributed.broadcast(weight, 0, group=self._model_update_group) may conflict with this MR, but I'm not sure how to fix it.
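
For clarity, the driver-side ordering we attempted looks roughly like the sketch below. It is a hedged sketch: vllm_engines, model, and model_update_group stand in for the actual OpenRLHF objects, and stop_remote_worker_execution_loop is assumed here to be exposed as a Ray actor method wrapping the call quoted above.

```python
import ray
import torch.distributed as dist

def broadcast_to_vllm(model, vllm_engines, model_update_group):
    # 1) Park the vLLM remote-worker execution loops before touching NCCL.
    ray.get([engine.stop_remote_worker_execution_loop.remote()
             for engine in vllm_engines])

    # 2) Push every parameter from DeepSpeed rank 0 to the vLLM workers.
    for name, param in model.named_parameters():
        handles = [
            engine.update_weight.remote(name, dtype=param.dtype, shape=param.shape)
            for engine in vllm_engines
        ]
        # The workers block inside update_weight on the matching broadcast.
        dist.broadcast(param.data, 0, group=model_update_group)
        ray.get(handles)
```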

hijkzzz added the usage label Jun 13, 2024
hijkzzz (Author) commented Jun 13, 2024

@njhill Do you have any insights? Thanks.

njhill (Member) commented Jun 17, 2024

@hijkzzz I don't have any immediate insight. I can take a closer look but can't promise how soon.

We could also consider adding a flag to disable the behaviour introduced in #4894, in particular to have the remote worker "loop" always exit after a single iteration. There would be a performance downside to that but it may help with cases like yours.

youkaichao (Member) commented

Actually, I'm quite surprised that it worked previously. vLLM should take control of all distributed initialization and destruction. How can you add another process into the group?

hijkzzz (Author) commented Jun 18, 2024

> Actually, I'm quite surprised that it worked previously. vLLM should take control of all distributed initialization and destruction. How can you add another process into the group?

We hacked the init_process_group API and created a new group for the vLLM engines and rank 0 of DeepSpeed.
See here: https://github.com/OpenLLMAI/OpenRLHF/blob/188139f809d9d14a8b1d8210f9e6746e2422e4e0/openrlhf/utils/distributed_util.py#L20
and
https://github.com/OpenLLMAI/OpenRLHF/blob/188139f809d9d14a8b1d8210f9e6746e2422e4e0/openrlhf/trainer/ray/ppo_actor.py#L89
Thanks
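
To make the pattern concrete, here is a toy, self-contained illustration of the idea behind that hack: a dedicated process group in which rank 0 (standing in for DeepSpeed rank 0) broadcasts weights to the other ranks (standing in for the vLLM workers). It uses the stock init_process_group with the Gloo backend so it runs anywhere; in the real setup the group is NCCL and has to be created alongside the groups that vLLM and DeepSpeed already own, which is exactly what the patched helper linked above is for.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29517"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)
    # Rank 0 plays DeepSpeed rank 0; the other ranks play vLLM workers.
    weight = torch.arange(4, dtype=torch.float32) if rank == 0 else torch.empty(4)
    dist.broadcast(weight, src=0)
    print(f"rank {rank} received {weight.tolist()}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)
```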

youkaichao (Member) commented

This is quite hacky. If possible, I suggest sharing CUDA tensors across processes; e.g., if vLLM has TP processes and your DeepSpeed process group also has TP processes, they can share CUDA tensors without copying them around. It requires that the two groups own the same set of tensors, though.
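
For anyone exploring this route, below is a minimal, single-node illustration of the mechanism being suggested: sharing a CUDA tensor across processes via torch.multiprocessing, which exchanges CUDA IPC handles instead of copying the data. It is not the vLLM/DeepSpeed integration itself, and the sending process must keep the tensor alive while the receiver uses it.

```python
import torch
import torch.multiprocessing as mp

def consumer(queue):
    # The received tensor aliases the producer's GPU memory (no copy is made).
    weight = queue.get()
    print(weight.device, weight.sum().item())

if __name__ == "__main__":
    mp.set_start_method("spawn")  # required for sharing CUDA tensors
    queue = mp.Queue()
    proc = mp.Process(target=consumer, args=(queue,))
    proc.start()
    weight = torch.ones(4, 4, device="cuda")
    queue.put(weight)   # sends a CUDA IPC handle, not a copy
    proc.join()         # keep `weight` alive until the consumer is done
```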

hijkzzz (Author) commented Jun 19, 2024

> This is quite hacky. If possible, I suggest sharing CUDA tensors across processes; e.g., if vLLM has TP processes and your DeepSpeed process group also has TP processes, they can share CUDA tensors without copying them around. It requires that the two groups own the same set of tensors, though.

This cannot meet the requirements for multi-machine distributed training in RLHF.

hijkzzz (Author) commented Sep 1, 2024

This bug can be worked around with export NCCL_P2P_DISABLE=1; however, this results in a loss of performance.
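
If exporting the variable on every node is inconvenient, one way to propagate it to the Ray workers is via runtime_env when initializing Ray from the driver (setting it in the shell before launch works just as well):

```python
import ray

ray.init(runtime_env={"env_vars": {"NCCL_P2P_DISABLE": "1"}})
```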

vwxyzjn commented Oct 28, 2024

Hi there, has there been an update or workaround for this issue?

Thanks!
