[Core] remove cupy dependency #3625

Merged
29 commits merged on Mar 27, 2024

Changes from 21 of 29 commits

Commits
24868ee  update CMakeLists.txt (youkaichao, Mar 26, 2024)
0af327b  import file from huge PR (youkaichao, Mar 26, 2024)
d3439b9  isort (youkaichao, Mar 26, 2024)
37fd9fd  support amd rccl (youkaichao, Mar 26, 2024)
9ec0b67  leave todo (youkaichao, Mar 26, 2024)
4904b72  add pynccl into test (youkaichao, Mar 26, 2024)
53e2ca3  linter (youkaichao, Mar 26, 2024)
5cdeb59  update and merge from vllm/worker/model_runner.py in main (youkaichao, Mar 26, 2024)
af94dc6  fix isort (youkaichao, Mar 26, 2024)
af8e254  restore and merge vllm/worker/worker.py (youkaichao, Mar 26, 2024)
db3044a  fix hip condition (youkaichao, Mar 26, 2024)
f078110  fix cuda condition (youkaichao, Mar 26, 2024)
18c9437  add error message when libnccl cannot be found (youkaichao, Mar 26, 2024)
3f1db02  fix test (youkaichao, Mar 26, 2024)
6c7082d  fix lint (youkaichao, Mar 26, 2024)
80eec0b  restore test of model and llava (youkaichao, Mar 26, 2024)
efb52bf  try to clean up cupy (youkaichao, Mar 26, 2024)
e6fb64d  lint (youkaichao, Mar 26, 2024)
5eb972e  unify init_method (youkaichao, Mar 26, 2024)
915213c  further cleanup cupy (youkaichao, Mar 26, 2024)
9b4d6fc  do not know why, but this fixes ray test (youkaichao, Mar 26, 2024)
983243e  merge distributed tests (youkaichao, Mar 26, 2024)
d2c9b4b  update comment for allreduce graph capture (youkaichao, Mar 26, 2024)
cefae38  explain why delete CUDA_VISIBLE_DEVICES (youkaichao, Mar 26, 2024)
da58b34  explain update_env (youkaichao, Mar 26, 2024)
850eca1  move successful load message to debug level (youkaichao, Mar 26, 2024)
0ed6527  fix lint (youkaichao, Mar 26, 2024)
0361127  restore CMakeLists.txt (youkaichao, Mar 27, 2024)
6d18496  leave a todo for __del__ (youkaichao, Mar 27, 2024)
14 changes: 12 additions & 2 deletions .buildkite/test-pipeline.yaml
@@ -22,8 +22,18 @@ steps:
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.

- label: Distributed Correctness Test
command: pytest -v -s --forked test_basic_distributed_correctness.py
- label: Distributed pynccl Test
command: pytest -v -s --forked test_pynccl.py
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.

- label: Distributed Correctness Test-facebook/opt-125m
command: TEST_DIST_MODEL=facebook/opt-125m pytest -v -s --forked test_basic_distributed_correctness.py
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.

- label: Distributed Correctness Test-meta-llama/Llama-2-7b-hf
command: TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py
working_dir: "/vllm-workspace/tests/distributed"
num_gpus: 2 # only support 1 or 2 for now.

4 changes: 4 additions & 0 deletions CMakeLists.txt
@@ -12,6 +12,10 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
#
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")

# used when building pytorch-related extensions
# TODO(youkaichao): only compute this if we are building wheels
# TODO(youkaichao): otherwise autodetect the version from current GPUs
set(TORCH_CUDA_ARCH_LIST "7.0;7.5;8.0;8.6;8.9;9.0")
# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")

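The TODO in the CMakeLists.txt hunk above asks for autodetecting the arch list from the GPUs present instead of hard-coding `TORCH_CUDA_ARCH_LIST`. A minimal sketch of what such detection could look like, using only `torch.cuda` APIs; this helper is hypothetical and not part of the PR:

```python
# Hypothetical helper (not part of this PR): derive a TORCH_CUDA_ARCH_LIST
# value from the GPUs visible at build time instead of hard-coding it.
import torch


def detect_torch_cuda_arch_list() -> str:
    # get_device_capability(i) returns e.g. (8, 6); the set drops duplicates
    caps = {
        torch.cuda.get_device_capability(i)
        for i in range(torch.cuda.device_count())
    }
    return ";".join(f"{major}.{minor}" for major, minor in sorted(caps))


if __name__ == "__main__":
    # Prints e.g. "8.0;9.0" on a machine with both A100 and H100 GPUs
    print(detect_torch_cuda_arch_list())
```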
2 changes: 1 addition & 1 deletion Dockerfile
@@ -97,7 +97,7 @@ RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip instal

#################### RUNTIME BASE IMAGE ####################
# We used base cuda image because pytorch installs its own cuda libraries.
# However cupy depends on cuda libraries so we had to switch to the runtime image
# However pynccl depends on cuda libraries so we had to switch to the runtime image
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base

20 changes: 0 additions & 20 deletions Dockerfile.rocm
@@ -23,9 +23,6 @@ RUN echo "FA_BRANCH is $FA_BRANCH"
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"

# whether to build cupy on rocm
ARG BUILD_CUPY="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

@@ -78,23 +75,6 @@ RUN if [ "$BUILD_FA" = "1" ]; then \
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

# build cupy
RUN if [ "$BUILD_CUPY" = "1" ]; then \
mkdir -p libs \
&& cd libs \
&& git clone -b hipgraph_enablement --recursive https://github.com/ROCm/cupy.git \
&& cd cupy \
&& pip install mpi4py-mpich \
&& pip install scipy==1.9.3 \
&& pip install cython==0.29.* \
&& env CC=$MPI_HOME/bin/mpicc python -m pip install mpi4py \
&& export CUPY_INSTALL_USE_HIP=1 \
&& export ROCM_HOME=/opt/rocm \
&& export HCC_AMDGPU_TARGET="gfx90a,gfx942,gfx1100" \
&& pip install . \
&& cd ..; \
fi

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
1 change: 0 additions & 1 deletion requirements.txt
@@ -14,4 +14,3 @@ prometheus_client >= 0.18.0
pynvml == 11.5.0
triton >= 2.1.0
outlines == 0.0.34
cupy-cuda12x == 12.1.0 # Required for CUDA graphs. CUDA 11.8 users should install cupy-cuda11x instead.
6 changes: 0 additions & 6 deletions setup.py
@@ -306,12 +306,6 @@ def get_requirements() -> List[str]:
if _is_cuda():
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
if get_nvcc_cuda_version() <= Version("11.8"):
# replace cupy-cuda12x with cupy-cuda11x for cuda 11.x
for i in range(len(requirements)):
if requirements[i].startswith("cupy-cuda12x"):
requirements[i] = "cupy-cuda11x"
break
elif _is_hip():
with open(get_path("requirements-rocm.txt")) as f:
requirements = f.read().strip().split("\n")
17 changes: 13 additions & 4 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -1,13 +1,22 @@
"""Compare the outputs of HF and distributed vLLM when using greedy sampling.

Run `pytest tests/distributed/test_basic_distributed_correctness.py --forked`.
vLLM will allocate all the available memory, so we need to run the tests one
by one. The solution is to pass arguments (model name) by environment
variables.
Run:
```sh
TEST_DIST_MODEL=facebook/opt-125m pytest \
test_basic_distributed_correctness.py
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest \
test_basic_distributed_correctness.py
```
"""
import os

import pytest
import torch

MODELS = [
"facebook/opt-125m",
"meta-llama/Llama-2-7b-hf",
os.environ["TEST_DIST_MODEL"],
]


12 changes: 12 additions & 0 deletions tests/distributed/test_comm_ops.py
@@ -16,6 +16,10 @@
@ray.remote(num_gpus=1, max_calls=1)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
import os
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
num_elements = 8
@@ -32,6 +36,10 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
import os
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
num_dimensions = 3
@@ -54,6 +62,10 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@ray.remote(num_gpus=1, max_calls=1)
def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
import os
del os.environ["CUDA_VISIBLE_DEVICES"]
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
test_dict = {
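Each worker above now drops CUDA_VISIBLE_DEVICES and binds its device by rank. A small sketch of that setup pattern follows, with the presumed rationale spelled out in comments; the rationale is an assumption on my part, since this view (changes from 21 commits) does not include the explanatory comment added by the later "explain why delete CUDA_VISIBLE_DEVICES" commit:

```python
# Sketch of the per-worker GPU setup used by the test workers above.
# Assumption (not stated in this diff): Ray masks each @ray.remote(num_gpus=1)
# worker to a single GPU via CUDA_VISIBLE_DEVICES, while the test wants every
# worker to see all GPUs and select cuda:<rank> explicitly.
import os

import torch


def setup_worker_device(rank: int) -> torch.device:
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)  # undo per-worker GPU masking
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    return device
```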
88 changes: 88 additions & 0 deletions tests/distributed/test_pynccl.py
@@ -0,0 +1,88 @@
import multiprocessing
import os

import pytest
import torch

from vllm.model_executor.parallel_utils.pynccl import (NCCLCommunicator,
ncclGetUniqueId)


def distributed_run(fn, world_size):
number_of_processes = world_size
processes = []
for i in range(number_of_processes):
env = os.environ.copy()
env['RANK'] = str(i)
env['WORLD_SIZE'] = str(number_of_processes)
env['MASTER_ADDR'] = 'localhost'
env['MASTER_PORT'] = '12345'
p = multiprocessing.Process(target=fn, args=(env, ))
processes.append(p)
p.start()

for p in processes:
p.join()


def update_env(fn):

def wrapper(env):
import os
os.environ.update(env)
fn()

return wrapper


@update_env
def worker_fn():
comm = NCCLCommunicator()
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
comm.all_reduce(tensor)
result = tensor.mean().cpu().item()
assert result == comm.world_size


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl():
distributed_run(worker_fn, 2)


@update_env
def worker_fn_with_cudagraph():
with torch.no_grad():
graph = torch.cuda.CUDAGraph()
comm = NCCLCommunicator()
# run something in the default stream to initialize torch engine
a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
torch.cuda.synchronize()
with torch.cuda.graph(graph, stream=comm.stream):
# TODO(youkaichao)
# seems like this operation is not performed
comm.all_reduce(a)
comm.stream.synchronize()
assert a.mean().cpu().item() == comm.world_size**0
graph.replay()
comm.stream.synchronize()
assert a.mean().cpu().item() == comm.world_size**1


@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
distributed_run(worker_fn_with_cudagraph, 2)


def test_ncclGetUniqueId():
unique_id = ncclGetUniqueId()
# `list(unique_id.internal)` is something like this:
# [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
# as long as the function doesn't raise an exception, we're good
assert unique_id is not None
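The TODO inside worker_fn_with_cudagraph observes that the all_reduce "seems like this operation is not performed" during capture: kernels launched while a CUDA graph is being captured are only recorded, and they execute when the graph is replayed, which is why the tensor mean equals world_size**0 before replay and world_size**1 after. A minimal single-GPU sketch of that capture-then-replay behavior, independent of NCCL and assuming one CUDA device is available:

```python
# Minimal sketch: CUDA graph capture records kernels without executing them;
# the work only runs on replay. One CUDA device is assumed.
import torch

x = torch.zeros(4, device="cuda")
x += 1          # warm-up outside capture so kernels/allocator are initialized
x.zero_()
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    x += 1      # recorded into the graph, not executed yet

torch.cuda.synchronize()
assert x.sum().item() == 0  # capture alone did not modify the tensor

graph.replay()              # now the captured kernel actually runs
torch.cuda.synchronize()
assert x.sum().item() == 4
```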
6 changes: 2 additions & 4 deletions vllm/executor/ray_gpu_executor.py
@@ -188,11 +188,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
is_driver_worker=True,
)

# FIXME(woosuk): We are not properly initializing cupy NCCL when
# FIXME(woosuk): We are not properly initializing pynccl when
# we have multiple nodes.
self._run_workers("init_device",
cupy_port=get_open_port()
if not model_config.enforce_eager else None)
self._run_workers("init_device")
self._run_workers(
"load_model",
max_concurrent_workers=self.parallel_config.
8 changes: 4 additions & 4 deletions vllm/model_executor/parallel_utils/communication_op.py
@@ -4,12 +4,12 @@
import torch
from torch.distributed import ProcessGroup

from vllm.model_executor.parallel_utils import cupy_utils
from vllm.model_executor.parallel_utils import pynccl_utils
from vllm.model_executor.parallel_utils.custom_all_reduce import (
custom_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, is_cupy_nccl_enabled_for_all_reduce)
get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce)


def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
@@ -30,9 +30,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
out = custom_all_reduce(input_)
if out is not None:
return out
if is_cupy_nccl_enabled_for_all_reduce():
if is_pynccl_enabled_for_all_reduce():
# TODO: support multiple parallel groups.
cupy_utils.all_reduce(input_)
pynccl_utils.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())