[Core] Multiprocessing executor for single-node multi-GPU deployment #3466

Closed · wants to merge 8 commits
9 changes: 5 additions & 4 deletions .buildkite/test-pipeline.yaml
@@ -33,10 +33,11 @@ steps:
num_gpus: 2 # only support 1 or 2 for now.
commands:
- pytest -v -s test_pynccl.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m pytest -v -s test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s test_chunked_prefill_distributed.py
# Use spawn to avoid CUDA re-init issues
- TEST_DIST_MODEL=facebook/opt-125m MULTIPROC_METHOD=spawn pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf MULTIPROC_METHOD=spawn pytest -v -s test_basic_distributed_correctness.py
- TEST_DIST_MODEL=facebook/opt-125m MULTIPROC_METHOD=spawn pytest -v -s test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf MULTIPROC_METHOD=spawn pytest -v -s test_chunked_prefill_distributed.py

- label: Engine Test
command: pytest -v -s engine tokenization test_sequence.py test_config.py
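The new pipeline comment notes that spawn is used to avoid CUDA re-initialization issues in forked worker processes. A minimal sketch of how a test harness might honor such an environment variable, assuming MULTIPROC_METHOD is read in a pytest conftest (the actual consumer of the variable is not shown in this hunk):

import multiprocessing
import os


def pytest_configure(config):
    # Hypothetical hook: switch the multiprocessing start method when
    # MULTIPROC_METHOD is set, e.g. "spawn" so that CUDA is not
    # re-initialized inside forked child processes.
    method = os.environ.get("MULTIPROC_METHOD")
    if method:
        multiprocessing.set_start_method(method, force=True)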
2 changes: 1 addition & 1 deletion Dockerfile
@@ -113,7 +113,7 @@ RUN ldconfig /usr/local/cuda-12.1/compat/
# install vllm wheel first, so that torch etc will be installed
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
--mount=type=cache,target=/root/.cache/pip \
pip install dist/*.whl --verbose
pip install "$(echo dist/*.whl)[ray]" --verbose

RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
--mount=type=cache,target=/root/.cache/pip \
1 change: 0 additions & 1 deletion requirements-cuda.txt
@@ -2,7 +2,6 @@
-r requirements-common.txt

# Dependencies for NVIDIA GPUs
ray >= 2.9
pynvml == 11.5.0
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.2.1
4 changes: 2 additions & 2 deletions requirements-rocm.txt
@@ -1,5 +1,5 @@
# Common dependencies
-r requirements-common.txt

# Dependencies for AMD GPUs
ray == 2.9.3
# No specific dependencies currently for AMD GPUs

20 changes: 16 additions & 4 deletions setup.py
@@ -5,7 +5,7 @@
import subprocess
import sys
from shutil import which
from typing import Dict, List
from typing import Dict, List, Optional

import torch
from packaging.version import Version, parse
@@ -361,6 +361,20 @@ def _read_requirements(filename: str) -> List[str]:
return requirements


def get_extra_requirements() -> Optional[Dict[str, List[str]]]:
extras = {"tensorizer": ["tensorizer==2.9.0a1"]}
if _is_cuda():
extras["ray"] = ["ray>=2.9"]
elif _is_hip():
extras["ray"] = ["ray==2.9.3"]
elif _is_neuron() or _is_cpu():
pass
else:
raise ValueError(
"Unsupported platform, please use CUDA, ROCM or Neuron.")
return extras


ext_modules = []

if _is_cuda():
@@ -405,9 +419,7 @@ def _read_requirements(filename: str) -> List[str]:
python_requires=">=3.8",
install_requires=get_requirements(),
ext_modules=ext_modules,
extras_require={
"tensorizer": ["tensorizer==2.9.0a1"],
},
extras_require=get_extra_requirements(),
cmdclass={"build_ext": cmake_build_ext} if not _is_neuron() else {},
package_data=package_data,
)
11 changes: 6 additions & 5 deletions tests/distributed/test_basic_distributed_correctness.py
@@ -25,24 +25,25 @@
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("worker_use_ray", [False, True])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
worker_use_ray: bool,
) -> None:

hf_model = hf_runner(model, dtype=dtype)
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
del hf_model

vllm_model = vllm_runner(
model,
dtype=dtype,
tensor_parallel_size=2,
)
vllm_model = vllm_runner(model,
dtype=dtype,
tensor_parallel_size=2,
worker_use_ray=worker_use_ray)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model

3 changes: 3 additions & 0 deletions tests/distributed/test_chunked_prefill_distributed.py
@@ -27,6 +27,7 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("chunked_prefill_token_size", [16])
@pytest.mark.parametrize("worker_use_ray", [False, True])
def test_models(
hf_runner,
vllm_runner,
@@ -35,6 +36,7 @@ def test_models(
dtype: str,
max_tokens: int,
chunked_prefill_token_size: int,
worker_use_ray: bool,
) -> None:
# Add a chunked prefill config.
max_num_seqs = min(chunked_prefill_token_size, 256)
@@ -53,6 +55,7 @@
max_num_seqs=max_num_seqs,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
worker_use_ray=worker_use_ray,
)
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
del vllm_model
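Both distributed correctness tests are now parametrized over worker_use_ray, so each model runs once against the Ray executor and once against the new multiprocessing executor. A hedged usage sketch of the same switch outside the test harness, with the argument name taken from this diff and the model name purely illustrative:

from vllm import LLM

# worker_use_ray=False with tensor_parallel_size > 1 exercises the
# single-node multiprocessing executor instead of Ray.
llm = LLM(model="facebook/opt-125m",
          tensor_parallel_size=2,
          worker_use_ray=False)
outputs = llm.generate(["Hello, my name is"])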
179 changes: 179 additions & 0 deletions tests/engine/test_local_workers.py
@@ -0,0 +1,179 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from time import sleep
from typing import Any, List, Tuple

import pytest

from vllm.engine.local_worker_utils import (LocalWorkerVllm, ResultHandler,
WorkerMonitor)


class DummyWorker:
"""Dummy version of vllm.worker.worker"""

def __init__(self, rank: int):
self.rank = rank

def worker_method(self, worker_input: Any) -> Tuple[int, Any]:
sleep(0.05)

if isinstance(worker_input, Exception):
# simulate error case
raise worker_input

return self.rank, worker_input


def _start_workers() -> Tuple[List[LocalWorkerVllm], WorkerMonitor]:
result_handler = ResultHandler()
workers = [
LocalWorkerVllm(result_handler, partial(DummyWorker, rank=rank))
for rank in range(8)
]

for worker in workers:
worker.start()

worker_monitor = WorkerMonitor(workers, result_handler)
assert not worker_monitor.is_alive()

result_handler.start()
worker_monitor.start()
assert worker_monitor.is_alive()

return workers, worker_monitor


def test_local_workers() -> None:
"""Test workers with sync task submission"""

workers, worker_monitor = _start_workers()

def execute_workers(worker_input: str) -> None:
worker_outputs = [
worker.execute_method("worker_method", worker_input)
for worker in workers
]

for rank, output in enumerate(worker_outputs):
assert output.get() == (rank, worker_input)

executor = ThreadPoolExecutor(max_workers=4)

# Test concurrent submission from different threads
futures = [
executor.submit(partial(execute_workers, f"thread {thread_num}"))
for thread_num in range(4)
]

for future in futures:
future.result()

# Test error case
exception = ValueError("fake error")
result = workers[0].execute_method("worker_method", exception)
try:
result.get()
pytest.fail("task should have failed")
except Exception as e:
assert isinstance(e, ValueError)
assert str(e) == "fake error"

# Test cleanup when a worker fails
assert worker_monitor.is_alive()
workers[3].kill()

# Other workers should get shut down here
worker_monitor.join(2)

# Ensure everything is stopped
assert not worker_monitor.is_alive()
assert all(not worker.is_alive() for worker in workers)

# Further attempts to submit tasks should fail
try:
_result = workers[0].execute_method("worker_method", "test")
pytest.fail("task should fail once workers have been shut down")
except Exception as e:
assert isinstance(e, ChildProcessError)


def test_local_workers_clean_shutdown() -> None:
"""Test clean shutdown"""

workers, worker_monitor = _start_workers()

assert worker_monitor.is_alive()
assert all(worker.is_alive() for worker in workers)

# Clean shutdown
worker_monitor.close()

worker_monitor.join(2)

# Ensure everything is stopped
assert not worker_monitor.is_alive()
assert all(not worker.is_alive() for worker in workers)

# Further attempts to submit tasks should fail
try:
_result = workers[0].execute_method("worker_method", "test")
pytest.fail("task should fail once workers have been shut down")
except Exception as e:
assert isinstance(e, ChildProcessError)


@pytest.mark.asyncio
async def test_local_workers_async() -> None:
"""Test local workers with async task submission"""

workers, worker_monitor = _start_workers()

async def execute_workers(worker_input: str) -> None:
worker_coros = [
worker.execute_method_async("worker_method", worker_input)
for worker in workers
]

results = await asyncio.gather(*worker_coros)
for rank, result in enumerate(results):
assert result == (rank, worker_input)

tasks = [
asyncio.create_task(execute_workers(f"task {task_num}"))
for task_num in range(4)
]

for task in tasks:
await task

# Test error case
exception = ValueError("fake error")
try:
_result = await workers[0].execute_method_async(
"worker_method", exception)
pytest.fail("task should have failed")
except Exception as e:
assert isinstance(e, ValueError)
assert str(e) == "fake error"

# Test cleanup when a worker fails
assert worker_monitor.is_alive()
workers[3].kill()

# Other workers should get shut down here
worker_monitor.join(2)

# Ensure everything is stopped
assert not worker_monitor.is_alive()
assert all(not worker.is_alive() for worker in workers)

# Further attempts to submit tasks should fail
try:
_result = await workers[0].execute_method_async(
"worker_method", "test")
pytest.fail("task should fail once workers have been shut down")
except Exception as e:
assert isinstance(e, ChildProcessError)
19 changes: 11 additions & 8 deletions vllm/config.py
@@ -1,4 +1,5 @@
import enum
import importlib.util
import io
import json
import os
@@ -422,7 +423,7 @@ def verify_with_parallel_config(
@dataclass
class TokenizerPoolConfig:
"""Configuration for the tokenizer pool.

Args:
pool_size: Number of tokenizer workers in the pool.
pool_type: Type of the pool.
@@ -446,9 +447,9 @@ def create_config(
tokenizer_pool_extra_config: Optional[Union[str, dict]]
) -> Optional["TokenizerPoolConfig"]:
"""Create a TokenizerPoolConfig from the given parameters.

If tokenizer_pool_size is 0, return None.

Args:
tokenizer_pool_size: Number of tokenizer workers in the pool.
tokenizer_pool_type: Type of the pool.
@@ -477,9 +478,9 @@ class ParallelConfig:
Args:
pipeline_parallel_size: Number of pipeline parallel groups.
tensor_parallel_size: Number of tensor parallel groups.
worker_use_ray: Whether to use Ray for model workers. Will be set to
worker_use_ray: Whether to use Ray for model workers. Will default to
True if either pipeline_parallel_size or tensor_parallel_size is
greater than 1.
greater than 1 and Ray is installed.
max_parallel_loading_workers: Maximum number of parallel loading
workers when loading the model sequentially in multiple batches,
to avoid RAM OOM when using tensor parallelism with large models.
@@ -495,7 +496,7 @@ def __init__(
self,
pipeline_parallel_size: int,
tensor_parallel_size: int,
worker_use_ray: bool,
worker_use_ray: Optional[bool] = None,
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
@@ -512,8 +513,10 @@ def __init__(
self.placement_group = placement_group

self.world_size = pipeline_parallel_size * self.tensor_parallel_size
if self.world_size > 1:
self.worker_use_ray = True
if self.worker_use_ray is None:
ray_found = importlib.util.find_spec("ray") is not None
self.worker_use_ray = ray_found and self.world_size > 1

Comment on lines +516 to +519
Member: Can we check whether ray is successfully imported in vllm/engine/ray_utils.py?
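A sketch of the check the reviewer suggests, assuming vllm/engine/ray_utils.py already wraps its ray import in a try/except (the helper name here is illustrative, not part of this diff):

try:
    import ray
except ImportError:
    ray = None


def ray_is_available() -> bool:
    # True when the optional ray dependency could be imported.
    return ray is not None

ParallelConfig could then call such a helper instead of probing importlib.util.find_spec directly.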

self._verify_args()

def _verify_args(self) -> None: