[Bug]: Deadlock in distributed inference when ray worker raises an exception #3455
Comments
What's worse, there are many places inside this initialization code where control can diverge. Any control divergence during this period (e.g. a ray worker raises an exception while the main process is waiting to create the process group) causes a deadlock. (Lines 252 to 305 in abfc4f3)

vllm/vllm/model_executor/parallel_utils/cupy_utils.py, lines 70 to 96 in abfc4f3

The core code is in the snippets referenced above. We need to come up with a better way of initializing distributed inference.
You should probably just have Ray pick up the first raised exception (via …).
The problem is we don't know whether the worker will raise an exception. Normally we expect all workers (plus the main process) to run smoothly and initialize a process group, but here the main process has a difficult decision to make: it cannot poll workers for exceptions while it is blocked waiting to initialize the process group.
For future reference: some nightly builds of pytorch contain a bug that initializes the cuda context during import torch. The code to detect whether we have a buggy torch version is:

```python
# code borrowed from https://github.com/pytorch/pytorch/pull/117010
import torch
import ctypes

x = ctypes.c_int(-1)
# `ans` holds the error code, and `x` holds the device count
ans = ctypes.CDLL('libcuda.so.1').cuDeviceGetCount(ctypes.byref(x))
# normally, `import torch` does not initialize cuda, so we get
# CUDA_ERROR_NOT_INITIALIZED, which is 3; see
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html
# for the detailed error codes
if ans == 3 and x.value == -1:
    print("your torch version is good!")
if ans == 0:
    print("your torch version contains a bug!")
```

It seems some nightly builds of pytorch (from …
Can't you just use multithreading, one to do …?
This idea does not work; we would always need concurrent polling whenever len(workers) > 0.
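For concreteness, here is a rough sketch of what such concurrent polling could look like: the driver joins the process group in a background thread while the main thread polls the worker ObjectRefs. worker_refs, the gloo backend, and the timeouts are all hypothetical choices, not code from vLLM or the referenced PR.

```python
import threading
import ray
import torch.distributed as dist

def _join_process_group(init_method: str, world_size: int) -> None:
    # rank 0 (the driver) blocks here until every worker joins
    dist.init_process_group(backend="gloo", init_method=init_method,
                            world_size=world_size, rank=0)

def init_with_polling(worker_refs, init_method: str, world_size: int) -> None:
    t = threading.Thread(target=_join_process_group,
                         args=(init_method, world_size), daemon=True)
    t.start()
    pending = list(worker_refs)
    while t.is_alive():
        if not pending:
            t.join(timeout=1.0)
            continue
        # returns as soon as one ref is ready (or failed) or the timeout expires
        done, pending = ray.wait(pending, num_returns=1, timeout=1.0)
        for ref in done:
            ray.get(ref)  # re-raises the worker's exception instead of hanging
    t.join()
```

Whether this pattern is robust enough in practice (e.g. the background thread stays stuck in the rendezvous if a worker fails) is exactly the trade-off discussed above.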
Will be fixed by: #6556
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!
Your current environment
Any distributed inference task with ray currently suffers from this issue.
🐛 Describe the bug
Basic background of ray

ray provides an easy-to-use asynchronous execution framework. The way it deals with Exception is noteworthy: an exception raised inside a remote task is not surfaced immediately, but only when the caller retrieves the result with ray.get, as shown in the sketch below.
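A minimal sketch of this deferred-exception behavior (the remote function f and its error message are just placeholders):

```python
import ray

ray.init()

@ray.remote
def f():
    # the exception is recorded inside the worker, not raised to the caller yet
    raise RuntimeError("worker failed")

ref = f.remote()   # returns an ObjectRef immediately; no error is visible here
# ... the driver keeps running, unaware that the task has already failed ...
ray.get(ref)       # the RuntimeError is only re-raised at this point
```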
The deadlock in distributed inference

The deadlock happens during initialization of distributed inference, i.e. when creating the process group that the processes use to collaborate.
A minimal reproducible example looks like this:
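The sketch below assumes the gloo backend, a free local TCP port, world_size=2, and a remote function named f; these details are illustrative, not taken from the original script.

```python
import ray
import torch.distributed as dist

init_method = "tcp://127.0.0.1:29500"

@ray.remote
def f(fail: bool):
    if fail:
        # Any exception raised before init_process_group leaves the driver
        # waiting forever for rank 1 to join.
        raise RuntimeError("worker failed before joining the process group")
    dist.init_process_group(backend="gloo", init_method=init_method,
                            world_size=2, rank=1)
    return "worker joined"

ray.init()
ref = f.remote(fail=True)   # flip to False for the healthy run

# The driver blocks here until rank 1 joins; if the worker already raised,
# nobody will ever join, and the ray.get(ref) that would surface the error
# is never reached -> deadlock.
dist.init_process_group(backend="gloo", init_method=init_method,
                        world_size=2, rank=0)
print(ray.get(ref))
```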
Normally it works: both the main process and the ray worker finish dist.init_process_group and the script exits cleanly. However, if the f function throws an exception before calling dist.init_process_group, the worker is kept in an error state, waiting for the main process to call ray.get to surface the error; meanwhile, the main process is stuck at dist.init_process_group, waiting for the worker process to join and initialize the process group for the distributed environment. Together they cause a deadlock.

How is this related with vLLM

vLLM uses ray for distributed inference, and the core code is attached below:

vllm/vllm/executor/ray_gpu_executor.py, lines 299 to 351 in 6b78837
When calling init_model, both the ray worker and the main process will reach the following function:

vllm/vllm/worker/worker.py, lines 71 to 96 in abfc4f3
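Roughly, the order of operations in that function looks like the sketch below; this is a simplified illustration, not the actual vLLM code, and init_model_sketch is a made-up name.

```python
import torch
import torch.distributed as dist

def init_model_sketch(rank: int, distributed_init_method: str, world_size: int):
    # Steps like this run *before* the process-group rendezvous; if any of
    # them raises (e.g. a broken GPU driver making set_device fail), this
    # process never reaches init_process_group ...
    torch.cuda.set_device(rank)

    # ... while every other participant is already blocked here waiting for
    # all ranks to join, which reproduces the deadlock above.
    dist.init_process_group(backend="nccl",
                            init_method=distributed_init_method,
                            world_size=world_size, rank=rank)
```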
And essentially we are back to the minimal reproducible example mentioned before: any exception raised before init_distributed_environment can cause a deadlock. In my case, my GPU driver has some problem, and torch.cuda.set_device raises an exception, causing the deadlock.

Solution to be discussed
Any suggestion to fix this is welcome.
Might be related: #2466.