
Don't use cupy NCCL for AMD backends #2855

Merged 2 commits on Feb 14, 2024
4 changes: 4 additions & 0 deletions vllm/model_executor/parallel_utils/custom_all_reduce.py
@@ -67,6 +67,10 @@ def get_handle() -> Optional["CustomAllreduce"]:
     return _CA_HANDLE
 
 
+def is_initialized() -> bool:
+    return _CA_HANDLE is not None
+
+
 @contextmanager
 def capture():
     try:
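
For context on how this new query is consumed: the `_maybe_cupy_nccl()` helper added to model_runner.py later in this diff checks it together with `cupy_utils.is_initialized()`. Below is a minimal illustrative call site that assumes only the two `is_initialized()` helpers visible in this diff; the function name itself is hypothetical and not part of the PR:

from vllm.model_executor.parallel_utils import custom_all_reduce, cupy_utils


def pick_graph_allreduce_backend() -> str:
    # Mirrors the priority described in the NOTE in capture_model():
    # prefer the custom all-reduce kernel, fall back to CuPy NCCL for
    # CUDA graphs, and otherwise let PyTorch NCCL handle all-reduce.
    if custom_all_reduce.is_initialized():
        return "custom all-reduce kernel"
    if cupy_utils.is_initialized():
        return "CuPy NCCL"
    return "PyTorch NCCL"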
22 changes: 16 additions & 6 deletions vllm/worker/model_runner.py
@@ -1,3 +1,4 @@
+import contextlib
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
 
@@ -9,9 +10,9 @@
                          SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
-from vllm.model_executor.parallel_utils.cupy_utils import get_nccl_backend
 from vllm.model_executor.parallel_utils.parallel_state import (
     with_cupy_nccl_for_all_reduce)
 from vllm.model_executor.parallel_utils import custom_all_reduce
@@ -650,7 +651,7 @@ def list_loras(self) -> Set[int]:
     def capture_model(self, kv_caches: List[KVCache]) -> None:
         # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
         # deleted before the CUDA graphs.
-        self.cupy_nccl_backend = get_nccl_backend()
+        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
 
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
@@ -680,15 +681,15 @@ def capture_model(self, kv_caches: List[KVCache]) -> None:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
 
-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
         # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
         # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
         # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
         # graph, we use either custom all-reduce kernel or PyTorch NCCL.
         # We always prioritize using custom all-reduce kernel but fall back
         # to PyTorch or CuPy NCCL if it is disabled or not supported.
         with custom_all_reduce.capture():
+            # NOTE: Capturing the largest batch size first may help reduce the
+            # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
                 # Create dummy input_metadata.
                 input_metadata = InputMetadata(
@@ -756,7 +757,7 @@ def capture(
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        with with_cupy_nccl_for_all_reduce():
+        with _maybe_cupy_nccl():
             self.model(
                 input_ids,
                 positions,
@@ -770,7 +771,7 @@ def capture(
         # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
         self.graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117
-            with with_cupy_nccl_for_all_reduce():
+            with _maybe_cupy_nccl():
                 hidden_states = self.model(
                     input_ids,
                     positions,
@@ -821,6 +822,15 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
 
+@contextlib.contextmanager
+def _maybe_cupy_nccl():
+    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
+        with with_cupy_nccl_for_all_reduce():
+            yield
+    else:
+        yield
+
+
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return x + [pad] * (max_len - len(x))
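
The new `_maybe_cupy_nccl()` helper uses the standard `contextlib` idiom for conditionally entering another context manager, so call sites keep a single `with` statement regardless of which backend is active. A self-contained sketch of the same idiom with placeholder names (none of the names below are vLLM APIs):

import contextlib
from typing import Iterator


@contextlib.contextmanager
def fallback_backend() -> Iterator[None]:
    # Stand-in for with_cupy_nccl_for_all_reduce(); prints to show when it is active.
    print("entering fallback backend")
    try:
        yield
    finally:
        print("leaving fallback backend")


@contextlib.contextmanager
def maybe_fallback(use_fallback: bool) -> Iterator[None]:
    # Enter the inner context manager only when the condition holds;
    # otherwise yield immediately with no extra setup or teardown.
    if use_fallback:
        with fallback_backend():
            yield
    else:
        yield


with maybe_fallback(use_fallback=True):
    print("running the captured region")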
4 changes: 3 additions & 1 deletion vllm/worker/worker.py
@@ -19,6 +19,7 @@
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
 from vllm.lora.request import LoRARequest
+from vllm.utils import is_hip
 
 
 class Worker:
@@ -264,7 +265,8 @@ def init_distributed_environment(
                 "cupy.distributed is already initialized but the cupy world "
                 "size does not match parallel_config.world_size "
                 f"({cupy_world_size} vs. {parallel_config.world_size}).")
-    elif parallel_config.world_size > 1 and cupy_port is not None:
+    elif (parallel_config.world_size > 1 and cupy_port is not None
+          and not is_hip()):
         # NOTE(woosuk): We don't initialize CuPy process group when world size
         # is 1.
         # TODO(woosuk): Support multi-node connection.
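
This guard is the substance of the PR: on ROCm builds the CuPy NCCL process group is skipped entirely, so the CUDA-graph path never routes all-reduce through CuPy on AMD GPUs. A hedged sketch of the kind of check `is_hip()` performs and of the resulting condition; the `torch.version.hip` basis and the `should_init_cupy` helper are assumptions for illustration, not code from this diff:

from typing import Optional

import torch


def is_hip() -> bool:
    # Assumed shape of vllm.utils.is_hip: ROCm builds of PyTorch expose a
    # non-None torch.version.hip, while CUDA builds report None.
    return torch.version.hip is not None


def should_init_cupy(world_size: int, cupy_port: Optional[int]) -> bool:
    # Hypothetical helper mirroring the new elif condition in worker.py.
    return world_size > 1 and cupy_port is not None and not is_hip()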