[Core] remove cupy dependency (#3625)
youkaichao authored Mar 27, 2024
1 parent e66b629 commit 8f44fac
Showing 17 changed files with 506 additions and 223 deletions.
7 changes: 5 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -22,10 +22,13 @@ steps:
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.

- label: Distributed Correctness Test
command: pytest -v -s --forked test_basic_distributed_correctness.py
- label: Distributed Tests
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.
commands:
- pytest -v -s --forked test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s --forked test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py

- label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py
2 changes: 1 addition & 1 deletion Dockerfile
@@ -97,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip instal

#################### RUNTIME BASE IMAGE ####################
# We used base cuda image because pytorch installs its own cuda libraries.
# However cupy depends on cuda libraries so we had to switch to the runtime image
# However pynccl depends on cuda libraries so we had to switch to the runtime image
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base

20 changes: 0 additions & 20 deletions Dockerfile.rocm
@@ -23,9 +23,6 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"

# whether to build cupy on rocm
ARG BUILD_CUPY="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

@@ -78,23 +75,6 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

# build cupy
RUN if [ "$BUILD_CUPY" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone -b hipgraph_enablement --recursive https://github.com/ROCm/cupy.git \
&& cd cupy \
&& pip install mpi4py-mpich \
&& pip install scipy==1.9.3 \
&& pip install cython==0.29.* \
&& env CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \
&& export CUPY_INSTALL_USE_HIP=1 \
&& export ROCM_HOME=/opt/rocm \
&& export HCC_AMDGPU_TARGET="gfx90a,gfx942,gfx1100" \
&& pip install . \
&& cd ..; \
fi

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
1 change: 0 additions & 1 deletion requirements.txt
@@ -14,4 +14,3 @@ prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
outlines == 0.0.34
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
6 changes: 0 additions & 6 deletions setup.py
@@ -306,12 +306,6 @@ def get_requirements() -> List[str]:
if _is_cuda():
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
if get_nvcc_cuda_version() <= Version("11.8"):
# replace cupy-cuda12x with cupy-cuda11x for cuda 11.x
for i in range(len(requirements)):
if requirements[i].startswith("cupy-cuda12x"):
requirements[i] = "cupy-cuda11x"
break
elif _is_hip():
with open(get_path("requirements-rocm.txt")) as f:
requirements = f.read().strip().split("\n")
17 changes: 13 additions & 4 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -1,13 +1,22 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.
Run `pytest tests/distributed/test_basic_distributed_correctness.py --forked`.
vLLM will allocate all the available GPU memory, so we need to run the tests
one by one. The solution is to pass the model name via an environment
variable.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
test_basic_distributed_correctness.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
test_basic_distributed_correctness.py
```
"""
import os

import pytest
import torch

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
os.environ["TEST_DIST_MODEL"],
]


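The model under test is now chosen solely through the `TEST_DIST_MODEL` environment variable. As a minimal sketch of that selection pattern, the helper below adds a fallback default; the `get_test_model` name and its default value are illustrative only, not part of this commit.

```python
import os


# Illustrative only: the actual test requires TEST_DIST_MODEL to be set and
# reads it directly; this variant merely shows the pattern with a fallback.
def get_test_model(default: str = "facebook/opt-125m") -> str:
    return os.environ.get("TEST_DIST_MODEL", default)


MODELS = [get_test_model()]
```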
20 changes: 20 additions & 0 deletions tests/distributed/test_comm_ops.py
@@ -2,6 +2,8 @@
Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
import os

import pytest
import ray
import torch
@@ -16,6 +18,12 @@
@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
# It is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs and bind itself to the GPU
# that matches its rank.
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
num_elements = 8
@@ -32,6 +40,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
# It is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs and bind itself to the GPU
# that matches its rank.
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
num_dimensions = 3
@@ -54,6 +68,12 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
# It is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs and bind itself to the GPU
# that matches its rank.
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
test_dict = {
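The same three-line setup (drop `CUDA_VISIBLE_DEVICES`, then bind the worker to the GPU matching its rank) is repeated in each Ray worker above. Here is a hedged sketch of it factored into a helper; the `setup_worker_device` name is hypothetical and not part of this commit.

```python
import os

import torch


def setup_worker_device(rank: int) -> torch.device:
    # Ray pins each worker to a single GPU via CUDA_VISIBLE_DEVICES; dropping
    # the variable makes every GPU visible again, after which the worker
    # binds itself to the GPU matching its rank, as the tests above do inline.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device
```

Using `pop` rather than `del` avoids a `KeyError` when the variable is not set, at the cost of hiding the case where Ray did not export it.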
90 changes: 90 additions & 0 deletions tests/distributed/test_pynccl.py
@@ -0,0 +1,90 @@
import multiprocessing
import os

import pytest
import torch

from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
ncclGetUniqueId)


def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
env = os.environ.copy()
env['RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes)
env['MASTER_ADDR'] = 'localhost'
env['MASTER_PORT'] = '12345'
p = multiprocessing.Process(target=fn, args=(env, ))
processes.append(p)
p.start()

for p in processes:
p.join()


def update_env(fn):
# `multiprocessing.Process` cannot accept environment variables directly,
# so we pass them as an argument and update `os.environ` inside the
# wrapped function.
def wrapper(env):
import os
os.environ.update(env)
fn()

return wrapper


@update_env
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == comm.world_size


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
distributed_run(worker_fn, 2)


@update_env
def worker_fn_with_cudagraph():
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
comm = NCCLCommunicator()
# run something in the default stream to initialize the CUDA context
a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
torch.cuda.synchronize()
with torch.cuda.graph(graph, stream=comm.stream):
# operations during graph capture are recorded but not executed
# see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa
comm.all_reduce(a)
comm.stream.synchronize()
assert a.mean().cpu().item() == comm.world_size**0
graph.replay()
comm.stream.synchronize()
assert a.mean().cpu().item() == comm.world_size**1


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
distributed_run(worker_fn_with_cudagraph, 2)


def test_ncclGetUniqueId():
unique_id = ncclGetUniqueId()
# `list(unique_id.internal)` is something like this:
# [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# as long as the function doesn't raise an exception, we're good
assert unique_id is not None
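
For readers unfamiliar with the graph-capture semantics that `worker_fn_with_cudagraph` relies on, here is a self-contained, single-GPU sketch in plain PyTorch (no NCCL, not part of the commit): work recorded inside `torch.cuda.graph` only runs when `graph.replay()` is called.

```python
import torch


def cudagraph_capture_demo():
    # Requires one CUDA GPU; illustrative only.
    a = torch.zeros(4, device="cuda")
    graph = torch.cuda.CUDAGraph()

    # Warm up on a side stream before capture, as PyTorch recommends.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        a += 1
    torch.cuda.current_stream().wait_stream(s)
    a.zero_()

    with torch.cuda.graph(graph):
        a += 1  # recorded, not executed
    torch.cuda.synchronize()
    assert a.sum().item() == 0.0  # capture alone did not modify `a`

    graph.replay()
    torch.cuda.synchronize()
    assert a.sum().item() == 4.0  # replay executed the captured kernel


if __name__ == "__main__":
    cudagraph_capture_demo()
```

This mirrors the two assertions in `worker_fn_with_cudagraph`: the mean stays at `world_size**0` right after capture and only reaches `world_size**1` after `graph.replay()`.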
6 changes: 2 additions & 4 deletions vllm/executor/ray_gpu_executor.py
@@ -188,11 +188,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
is_driver_worker=True,
)

# FIXME(woosuk): We are not properly initializing cupy NCCL when
# FIXME(woosuk): We are not properly initializing pynccl when
# we have multiple nodes.
self._run_workers("init_device",
cupy_port=get_open_port()
if not model_config.enforce_eager else None)
self._run_workers("init_device")
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.
8 changes: 4 additions & 4 deletions vllm/model_executor/parallel_utils/communication_op.py
@@ -4,12 +4,12 @@
import torch
from torch.distributed import ProcessGroup

from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils import pynccl_utils
from vllm.model_executor.parallel_utils.custom_all_reduce import (
custom_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, is_cupy_nccl_enabled_for_all_reduce)
get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -30,9 +30,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
out = custom_all_reduce(input_)
if out is not None:
return out
if is_cupy_nccl_enabled_for_all_reduce():
if is_pynccl_enabled_for_all_reduce():
# TODO: support multiple parallel groups.
cupy_utils.all_reduce(input_)
pynccl_utils.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
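Finally, a condensed, runnable sketch of the dispatch order shown in `tensor_model_parallel_all_reduce` above: try a fast in-process path first and fall back to `torch.distributed`. The `fast_path` callable stands in for vLLM's `custom_all_reduce` / `pynccl_utils.all_reduce` and is purely illustrative, not part of the commit.

```python
from typing import Callable, Optional

import torch
import torch.distributed as dist


def all_reduce_with_fallback(
    t: torch.Tensor,
    fast_path: Optional[Callable[[torch.Tensor], Optional[torch.Tensor]]] = None,
) -> torch.Tensor:
    # Mirrors the control flow above: a fast path may decline by returning
    # None, in which case the regular torch.distributed all-reduce is used.
    if fast_path is not None:
        out = fast_path(t)
        if out is not None:
            return out
    dist.all_reduce(t)
    return t


if __name__ == "__main__":
    # Single-process demo on CPU with the gloo backend.
    dist.init_process_group("gloo",
                            init_method="tcp://127.0.0.1:29501",
                            rank=0,
                            world_size=1)
    x = torch.ones(4)
    print(all_reduce_with_fallback(x))  # unchanged when world_size == 1
    dist.destroy_process_group()
```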