
[Bug]: Dead lock in distributed inference when ray worker raises an exception #3455

Closed
youkaichao opened this issue Mar 17, 2024 · 9 comments
Labels
bug Something isn't working stale


@youkaichao
Member

Your current environment

Any distributed inference task that uses Ray currently suffers from this issue.

🐛 Describe the bug

Basic background of ray

Ray provides an easy-to-use asynchronous execution framework:

def f():
    print(1)

import ray
ray.init()
marked_function = ray.remote(f) # mark `f` as a remote function that can be asynchronously executed
handle = marked_function.remote() # schedule a worker to asynchronously execute the function, immediately return a handle
result = ray.get(handle) # synchronously wait for the worker to finish and return the result

The way Ray handles exceptions is noteworthy; see the comments below:

def f():
    print(1)
    raise RuntimeError("test")
    # the following line will not be executed
    print(2)

import ray
ray.init()
marked_function = ray.remote(f) # mark `f` as a remote function that can be asynchronously executed
handle = marked_function.remote() # schedule a worker to asynchronously execute the function, immediately return a handle

# ... do other work in the meantime ...
# the main process will not be notified if the worker fails

# only when we call `ray.get` will we be notified of the error
result = ray.get(handle) # raise the error that was thrown in the worker, wrapping it in a RayTaskError
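The same deferred-error behavior can be seen without a Ray cluster using Python's standard concurrent.futures module: an exception raised in a worker is stored in the future and re-raised only when `.result()` is called. This is a sketch that assumes a thread pool is a fair stand-in for Ray's task scheduling; it does not use Ray's API, and Ray additionally wraps the error in a RayTaskError.

```python
from concurrent.futures import ThreadPoolExecutor


def f():
    print(1)
    raise RuntimeError("test")
    # the following line will not be executed
    print(2)


def run_and_collect():
    events = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Analogous to marked_function.remote(): returns a handle immediately.
        future = pool.submit(f)
        # The caller keeps running; nothing is raised here even if f() has failed.
        events.append("submitted")
        try:
            # Analogous to ray.get(handle): re-raises the worker's exception.
            future.result()
        except RuntimeError as exc:
            events.append(f"caught: {exc}")
    return events


if __name__ == "__main__":
    print(run_and_collect())
```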

The deadlock in distributed inference

The deadlock happens during initialization of distributed inference, i.e. while creating the process group that the processes use to communicate.

A minimal reproducible example looks like this:

import torch
import torch.distributed as dist

def f(rank, world_size, distributed_init_method):
    # raise RuntimeError # uncomment this line to see a deadlock
    dist.init_process_group(
        backend="gloo",
        init_method=distributed_init_method,
        world_size=world_size,
        rank=rank,
    )
    tensor = torch.zeros(1)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    print(f"Rank {rank} has data {tensor.item()}")

import ray
ray.init()
marked_function = ray.remote(f)

distributed_init_method = "tcp://127.0.0.1:29500"
world_size = 2

# start the first process
handle = marked_function.remote(rank=0, world_size=world_size, distributed_init_method=distributed_init_method)

# the main process is the second process
# wait for the first process to join here to initialize the process group for distributed environment
dist.init_process_group(backend="gloo", init_method=distributed_init_method, world_size=world_size, rank=1)

# two processes are ready to communicate
tensor = torch.ones(1)
dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
print(f"Rank 1 has data {tensor.item()}")

result = ray.get(handle)

Normally it works with the following output:

2024-03-17 10:24:23,293 INFO worker.py:1724 -- Started a local Ray instance.
Rank 1 has data 1.0
(f pid=14616) Rank 0 has data 1.0

However, if f raises an exception before calling dist.init_process_group, the worker task fails and its error is held until the main process calls ray.get. Meanwhile, the main process is blocked in dist.init_process_group, waiting for the worker to join the process group, so it never reaches ray.get. The worker will never join, and the main process never learns that it failed: together they cause a deadlock.
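The interaction can be reproduced with the standard library alone. In the sketch below (an illustration, not vLLM or Ray code), a two-party threading.Barrier stands in for the rendezvous inside dist.init_process_group, and a ThreadPoolExecutor future stands in for the Ray handle. The timeout is artificial and exists only so the demo terminates.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


def worker(rendezvous: threading.Barrier, fail_early: bool) -> str:
    if fail_early:
        # Stand-in for an exception raised before dist.init_process_group.
        raise RuntimeError("worker failed before joining")
    rendezvous.wait()  # stand-in for joining the process group
    return "worker joined"


def main_process(fail_early: bool, timeout: float = 1.0):
    # Two parties, like world_size = 2: the worker and the main process.
    rendezvous = threading.Barrier(2)
    with ThreadPoolExecutor(max_workers=1) as pool:
        handle = pool.submit(worker, rendezvous, fail_early)  # like .remote()
        try:
            # Stand-in for dist.init_process_group in the main process.
            rendezvous.wait(timeout=timeout)
        except threading.BrokenBarrierError:
            # The main process gave up waiting; only now can it inspect the
            # worker's stored error (the equivalent of calling ray.get).
            return "stuck", handle.exception()
        return "ok", handle.result()


if __name__ == "__main__":
    print(main_process(fail_early=False))
    print(main_process(fail_early=True, timeout=0.5))
```

Without the artificial timeout, the failing case would block for as long as the real rendezvous is willing to wait.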

How is this related to vLLM

vLLM uses Ray for distributed inference; the core code is shown below:

def _run_workers(
    self,
    method: str,
    *args,
    driver_args: Optional[List[Any]] = None,
    driver_kwargs: Optional[Dict[str, Any]] = None,
    max_concurrent_workers: Optional[int] = None,
    use_ray_compiled_dag: bool = False,
    **kwargs,
) -> Any:
    """Runs the given method on all workers."""
    if max_concurrent_workers:
        raise NotImplementedError(
            "max_concurrent_workers is not supported yet.")
    if use_ray_compiled_dag:
        # Right now, compiled DAG can only accept a single
        # input. TODO(sang): Fix it.
        output_channels = self.forward_dag.execute(1)
    else:
        # Start the ray workers first.
        ray_worker_outputs = [
            worker.execute_method.remote(method, *args, **kwargs)
            for worker in self.workers
        ]
    if driver_args is None:
        driver_args = args
    if driver_kwargs is None:
        driver_kwargs = kwargs
    # Start the driver worker after all the ray workers.
    driver_worker_output = getattr(self.driver_worker,
                                   method)(*driver_args, **driver_kwargs)
    # Get the results of the ray workers.
    if self.workers:
        if use_ray_compiled_dag:
            try:
                ray_worker_outputs = [
                    pickle.loads(chan.begin_read())
                    for chan in output_channels
                ]
            finally:
                # Has to call end_read in order to reuse the DAG.
                for chan in output_channels:
                    chan.end_read()
        else:
            ray_worker_outputs = ray.get(ray_worker_outputs)
    return [driver_worker_output] + ray_worker_outputs

When init_model is called, both the Ray workers and the main process reach the following method:

def init_model(self, cupy_port: Optional[int] = None) -> None:
    if self.device_config.device.type == "cuda":
        # torch.distributed.all_reduce does not free the input tensor until
        # the synchronization point. This causes the memory usage to grow
        # as the number of all_reduce calls increases. This env var disables
        # this behavior.
        # Related issue:
        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
        # This env var set by Ray causes exceptions with graph building.
        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
        self.device = torch.device(f"cuda:{self.local_rank}")
        torch.cuda.set_device(self.device)
        _check_if_gpu_supports_dtype(self.model_config.dtype)
        torch.cuda.empty_cache()
        self.init_gpu_memory = torch.cuda.mem_get_info()[0]
    else:
        raise RuntimeError(
            f"Not support device type: {self.device_config.device}")
    # Initialize the distributed environment.
    init_distributed_environment(self.parallel_config, self.rank,
                                 cupy_port, self.distributed_init_method)
    # Initialize the model.
    set_random_seed(self.model_config.seed)

So we are essentially back to the minimal reproducible example above: any exception raised before init_distributed_environment can cause a deadlock.

In my case, my GPU driver had a problem, so torch.cuda.set_device raised an exception, causing the deadlock.

Solution to be discussed

Any suggestion to fix this is welcome.

Possibly related: #2466.

@youkaichao youkaichao added the bug Something isn't working label Mar 17, 2024
@youkaichao
Member Author

What's worse, there are many places inside init_distributed_environment that can raise an exception, and many synchronization points at which the main process and the Ray workers wait for each other.

Any control divergence during this period (e.g. a Ray worker raises an exception while the main process is waiting to create the process group) causes a deadlock.

vllm/vllm/worker/worker.py

Lines 252 to 305 in abfc4f3

def init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    cupy_port: Optional[int],
    distributed_init_method: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        torch.distributed.init_process_group(
            backend="nccl",
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )
    if cupy_utils.is_initialized():
        cupy_world_size = cupy_utils.get_world_size()
        if cupy_world_size != parallel_config.world_size:
            raise RuntimeError(
                "cupy.distributed is already initialized but the cupy world "
                "size does not match parallel_config.world_size "
                f"({cupy_world_size} vs. {parallel_config.world_size}).")
    elif (parallel_config.world_size > 1 and cupy_port is not None):
        # NOTE(woosuk): We don't initialize CuPy process group when world size
        # is 1.
        # TODO(woosuk): Support multi-node connection.
        cupy_utils.init_process_group(
            world_size=parallel_config.world_size,
            rank=rank,
            host="localhost",
            port=cupy_port,
        )
    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1).cuda())
    if cupy_utils.is_initialized():
        cupy_utils.all_reduce(torch.zeros(1).cuda())
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
    # Initialize a custom fast all-reduce implementation.
    if not parallel_config.disable_custom_all_reduce:
        init_custom_ar()

def init_process_group(world_size: int, rank: int, host: str,
                       port: int) -> None:
    """Initializes the CuPy NCCL backend.
    # TODO: handle NCCL timeouts.
    """
    assert not is_initialized()
    if isinstance(cupy, Exception):
        raise ImportError(
            "NCCLBackend is not available. Please install cupy.") from cupy
    # TODO(woosuk): Create TP and PP process groups for CuPy.
    global _NCCL_BACKEND
    global _WORLD_SIZE
    assert world_size > 0, f"{world_size=} should be a positive integer"
    assert 0 <= rank < world_size, (
        f"{rank=} should be a integer between [0, {world_size})")
    cupy.cuda.runtime.setDevice(torch.cuda.current_device())
    _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port)
    _WORLD_SIZE = world_size
    # Stop the TCP store to prevent the deadlock issues at termination time.
    # FIXME(woosuk): This is hacky. Find a more robust solution.
    if rank == 0 and hasattr(_NCCL_BACKEND, "_store"):
        _NCCL_BACKEND._store.stop()

The core of init_distributed_environment involves the two functions above, and together they contain many possible exceptions and synchronization points.

We need to come up with a better way for initializing distributed inference.

@richardliaw
Collaborator

You should probably just have Ray pick up the first raised exception (via ray.wait) and then kill the rest of the workers when that happens.
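A minimal sketch of this suggestion using the standard library: concurrent.futures.wait with return_when=FIRST_EXCEPTION plays the role of ray.wait, and a shared threading.Event plays the role of killing the remaining workers (with Ray, one would more likely call ray.kill on actor handles). Both substitutions are assumptions made so the example runs without a cluster; the names task and run_all are hypothetical.

```python
import threading
from concurrent.futures import FIRST_EXCEPTION, ThreadPoolExecutor, wait


def task(rank: int, failing_rank: int, stop: threading.Event,
         work_time: float) -> int:
    if rank == failing_rank:
        raise RuntimeError(f"rank {rank} failed")
    # Healthy workers "work" until they finish or are told to stop.
    stop.wait(timeout=work_time)
    return rank


def run_all(world_size: int, failing_rank: int, work_time: float = 5.0) -> str:
    stop = threading.Event()  # stand-in for killing the remaining workers
    with ThreadPoolExecutor(max_workers=world_size) as pool:
        futures = [
            pool.submit(task, rank, failing_rank, stop, work_time)
            for rank in range(world_size)
        ]
        # Returns as soon as any task raises, or when all tasks have finished.
        done, _ = wait(futures, return_when=FIRST_EXCEPTION)
        for fut in done:
            exc = fut.exception()
            if exc is not None:
                stop.set()  # tear down the others instead of waiting on them
                return f"aborted: {exc}"
    return "all finished"


if __name__ == "__main__":
    print(run_all(4, failing_rank=2))
```

This only helps when the waiting process is not itself a participant blocked at the same rendezvous, which is the objection raised in the next comment.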

@youkaichao
Member Author

youkaichao commented Mar 18, 2024

You should probably just have Ray pick up the first raised exception (via ray.wait)

The problem is that we don't know whether a worker will raise an exception. Normally we expect all workers (plus the main process) to run smoothly and initialize a process group, but here the main process faces a difficult choice: it cannot check for worker exceptions while it is blocked waiting to initialize the process group.

@youkaichao
Member Author

For future reference:

Some nightly builds of PyTorch contain a bug that initializes the CUDA context during import torch. This makes the module not picklable, which causes an error. Combined with the deadlock mechanism discussed in this issue, these buggy torch versions cause a deadlock when used with vLLM, as demonstrated in #3457.

The code to detect whether we have a buggy torch version is:

# code borrowed from https://github.com/pytorch/pytorch/pull/117010

import torch
import ctypes
x = ctypes.c_int(-1)
# `ans` holds the error code, and `x` holds the device count
ans = ctypes.CDLL('libcuda.so.1').cuDeviceGetCount(ctypes.byref(x))

# normally, `import torch` does not initialize cuda, so we get CUDA_ERROR_NOT_INITIALIZED, which is 3
# check https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html for detailed error codes
if ans == 3 and x.value == -1:
    print("your torch version is good!")

if ans == 0:
    print("your torch version contains a bug!")

It seems some nightly builds of PyTorch are affected: those from torch-2.2.0.dev20231116 to torch-2.3.0.dev20231224, or more precisely, any torch version that contains code from PR pytorch/pytorch#112623.

@jon-chuang
Contributor

jon-chuang commented Aug 5, 2024

It cannot wait and test worker exception while waiting for initializing a process group at the same time.

Can't you just use multithreading, with one thread calling ray.wait and the other calling dist.init_process_group?
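A minimal sketch of that idea, again with standard-library stand-ins (assumptions, not vLLM or Ray code): the blocking rendezvous runs on a helper thread while the calling thread waits for whichever finishes first, the worker or the rendezvous. If the worker fails first, Barrier.abort() releases the blocked helper and the worker's error is raised instead of hanging. The function names are hypothetical.

```python
import threading
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait


def worker(rendezvous: threading.Barrier, fail_early: bool) -> str:
    if fail_early:
        # Stand-in for an exception raised before joining the process group.
        raise RuntimeError("worker failed before joining")
    rendezvous.wait()
    return "worker joined"


def guarded_init(fail_early: bool) -> str:
    rendezvous = threading.Barrier(2)  # world_size = 2: worker + main process
    with ThreadPoolExecutor(max_workers=2) as pool:
        worker_future = pool.submit(worker, rendezvous, fail_early)
        # Run the blocking "init_process_group" stand-in on a helper thread so
        # this thread stays free to watch the worker (the ray.wait side).
        init_future = pool.submit(rendezvous.wait)
        done, _ = wait([worker_future, init_future],
                       return_when=FIRST_COMPLETED)
        err = worker_future.exception() if worker_future in done else None
        if err is not None:
            # Release the helper thread, then surface the worker's error.
            rendezvous.abort()
            raise RuntimeError("worker failed during init") from err
        init_future.result()  # rendezvous completed normally
        return worker_future.result()


if __name__ == "__main__":
    print(guarded_init(fail_early=False))
```

torch.distributed does not appear to offer a public way to cancel an init_process_group call that is already waiting, so in practice the blocked helper thread would likely have to be abandoned (for example as a daemon thread) and the process torn down; the sketch only shows the control flow.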

@jon-chuang
Contributor

jon-chuang commented Aug 5, 2024

An easier and cleaner alternative is simply to put the driver's result into the Ray object store and then always call ray.get() on all result ObjectRefs.

This idea does not work: concurrent polling is always needed if len(workers) > 0.

@jon-chuang
Contributor

jon-chuang commented Aug 5, 2024

Will be fixed by: #6556


github-actions bot commented Nov 5, 2024

This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!

@github-actions github-actions bot added the stale label Nov 5, 2024

github-actions bot commented Dec 5, 2024

This issue has been automatically closed due to inactivity. Please feel free to reopen if you feel it is still relevant. Thank you!

@github-actions github-actions bot closed this as not planned Won't fix, can't repro, duplicate, stale Dec 5, 2024