Add distributed model executor abstraction (vllm-project#3191)
zhuohan123 authored and starmpcc committed Mar 14, 2024
1 parent 509e1a6 commit ec75970
Showing 13 changed files with 818 additions and 509 deletions.
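
This commit replaces the engine's inline worker fan-out with a model executor object that the engine calls through a narrow interface. The hunks below show the call sites (model_executor.execute_model_async, model_executor.check_health) and the executor selection in from_engine_args (GPUExecutorAsync vs. RayGPUExecutorAsync), but not the executor classes themselves. As a rough sketch only, assuming the interface mirrors those call sites, such an abstraction could look like the following; the class name ExecutorAsyncBase and the type annotations are illustrative, not taken from the commit:

# Illustrative sketch of a distributed model executor interface, inferred from
# the call sites in this diff. Names and annotations are assumptions.
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class ExecutorAsyncBase(ABC):
    """Hides whether the model runs in-process on one GPU or on Ray workers."""

    @abstractmethod
    async def execute_model_async(
        self,
        seq_group_metadata_list: List[Any],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> Any:
        """Run one scheduling step and return the driver worker's sampler output."""
        ...

    @abstractmethod
    def check_health(self) -> None:
        """Raise an exception if any worker has died; otherwise return."""
        ...

The per-file hunks that follow show how the engine code was rewired to call such an object.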
2 changes: 1 addition & 1 deletion docs/source/dev/engine/llm_engine.rst

@@ -2,5 +2,5 @@ LLMEngine
 =================================

 .. autoclass:: vllm.engine.llm_engine.LLMEngine
-    :members: add_request, abort_request, step, _init_cache
+    :members: add_request, abort_request, step
     :show-inheritance:
8 changes: 6 additions & 2 deletions format.sh

@@ -95,13 +95,17 @@ echo 'vLLM yapf: Done'
 # echo 'vLLM mypy:'
 # mypy

+CODESPELL_EXCLUDES=(
+    '--skip' '*docs/source/_build/**'
+)
+
 # check spelling of specified files
 spell_check() {
     codespell "$@"
 }

 spell_check_all(){
-    codespell --toml pyproject.toml
+    codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}"
 }

 # Spelling check of files that differ from main branch.
@@ -116,7 +120,7 @@ spell_check_changed() {

    if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
        git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
-            codespell
+            codespell "${CODESPELL_EXCLUDES[@]}"
    fi
 }
3 changes: 2 additions & 1 deletion tests/lora/conftest.py

@@ -152,4 +152,5 @@ def get_model_patched(model_config, device_config, **kwargs):
 @pytest.fixture
 def llama_2_7b_model_extra_embeddings(
         llama_2_7b_engine_extra_embeddings) -> nn.Module:
-    yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model
+    yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
+           model_runner.model)
4 changes: 2 additions & 2 deletions vllm/__init__.py

@@ -3,7 +3,7 @@
 from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
-from vllm.engine.ray_utils import initialize_cluster
+from vllm.engine.ray_utils import initialize_ray_cluster
 from vllm.entrypoints.llm import LLM
 from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -19,5 +19,5 @@
     "EngineArgs",
     "AsyncLLMEngine",
     "AsyncEngineArgs",
-    "initialize_cluster",
+    "initialize_ray_cluster",
 ]
7 changes: 6 additions & 1 deletion vllm/config.py

@@ -1,4 +1,4 @@
-from typing import Optional, Union, ClassVar
+from typing import TYPE_CHECKING, Optional, Union, ClassVar
 from dataclasses import dataclass
 import os
 from packaging.version import Version
@@ -10,6 +10,9 @@
 from vllm.transformers_utils.config import get_config
 from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version

+if TYPE_CHECKING:
+    from ray.util.placement_group import PlacementGroup
+
 logger = init_logger(__name__)

 _GB = 1 << 30
@@ -397,6 +400,7 @@ def __init__(
         max_parallel_loading_workers: Optional[int] = None,
         disable_custom_all_reduce: bool = False,
         ray_workers_use_nsight: bool = False,
+        placement_group: Optional["PlacementGroup"] = None,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
         if is_neuron():
@@ -412,6 +416,7 @@ def __init__(
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce
         self.ray_workers_use_nsight = ray_workers_use_nsight
+        self.placement_group = placement_group

         self.world_size = pipeline_parallel_size * self.tensor_parallel_size
         # Ray worker is not supported for Neuron backend.
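
For context, a hedged sketch of how a caller might hand an existing Ray placement group to ParallelConfig through the new placement_group argument. Only the placement_group parameter is shown in the hunk above; the bundle layout and the leading constructor arguments (pipeline_parallel_size, tensor_parallel_size, worker_use_ray) are assumptions for illustration.

# Hypothetical usage sketch; bundle shape and leading arguments are assumptions.
import ray
from ray.util.placement_group import placement_group

from vllm.config import ParallelConfig

ray.init()
# Reserve two single-GPU bundles and wait until Ray has placed them.
pg = placement_group([{"GPU": 1}] * 2, strategy="PACK")
ray.get(pg.ready())

parallel_config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=2,
    worker_use_ray=True,
    placement_group=pg,
)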
106 changes: 38 additions & 68 deletions vllm/engine/async_llm_engine.py

@@ -2,16 +2,16 @@
 import os
 import time
 from functools import partial
-from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union, AsyncIterator, Callable)
+from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type,
+                    Union, AsyncIterator)

 from transformers import PreTrainedTokenizer

 from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.llm_engine import LLMEngine
-from vllm.engine.ray_utils import initialize_cluster, ray
+from vllm.engine.ray_utils import initialize_ray_cluster, ray
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -208,17 +208,10 @@ async def step_async(self) -> List[RequestOutput]:

         if not scheduler_outputs.is_empty():
             # Execute the model.
-            all_outputs = await self._run_workers_async(
-                "execute_model",
-                driver_kwargs={
-                    "seq_group_metadata_list": seq_group_metadata_list,
-                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
-                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
-                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
-                })
-
-            # Only the driver worker returns the sampling results.
-            output = all_outputs[0]
+            output = await self.model_executor.execute_model_async(
+                seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
+                scheduler_outputs.blocks_to_swap_out,
+                scheduler_outputs.blocks_to_copy)
         else:
             output = []

@@ -268,37 +261,8 @@ async def add_request_async(
             lora_request=lora_request,
         )

-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[List[Any]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        # Run the driver worker asynchronously.
-        driver_executor = getattr(self.driver_worker, method)
-        coros.append(asyncio.get_event_loop().run_in_executor(
-            None, partial(driver_executor, *driver_args, **driver_kwargs)))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
-
-    async def check_health_async(self):
-        """Raises an error if engine is unhealthy."""
-        self._check_if_any_actor_is_dead()
+    async def check_health_async(self) -> None:
+        self.model_executor.check_health()


 class AsyncLLMEngine:
@@ -353,6 +317,34 @@ def __init__(self,
         self._request_tracker: Optional[RequestTracker] = None
         self._errored_with: Optional[BaseException] = None

+    @classmethod
+    def from_engine_args(cls,
+                         engine_args: AsyncEngineArgs,
+                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
+        """Creates an async LLM engine from the engine arguments."""
+        # Create the engine configs.
+        engine_configs = engine_args.create_engine_configs()
+        parallel_config = engine_configs[2]
+        if parallel_config.worker_use_ray or engine_args.engine_use_ray:
+            initialize_ray_cluster(parallel_config)
+            from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
+            executor_class = RayGPUExecutorAsync
+        else:
+            assert parallel_config.world_size == 1, (
+                "Ray is required if parallel_config.world_size > 1.")
+            from vllm.executor.gpu_executor import GPUExecutorAsync
+            executor_class = GPUExecutorAsync
+        # Create the async LLM engine.
+        engine = cls(parallel_config.worker_use_ray,
+                     engine_args.engine_use_ray,
+                     *engine_configs,
+                     executor_class,
+                     log_requests=not engine_args.disable_log_requests,
+                     log_stats=not engine_args.disable_log_stats,
+                     max_log_len=engine_args.max_log_len,
+                     start_engine_loop=start_engine_loop)
+        return engine
+
     @property
     def is_running(self) -> bool:
         return (self.background_loop is not None
@@ -670,35 +662,13 @@ async def get_model_config(self) -> ModelConfig:
         else:
             return self.engine.get_model_config()

-    @classmethod
-    def from_engine_args(cls,
-                         engine_args: AsyncEngineArgs,
-                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
-        """Creates an async LLM engine from the engine arguments."""
-        # Create the engine configs.
-        engine_configs = engine_args.create_engine_configs()
-        parallel_config = engine_configs[2]
-        # Initialize the cluster.
-        placement_group = initialize_cluster(parallel_config,
-                                             engine_args.engine_use_ray)
-        # Create the async LLM engine.
-        engine = cls(parallel_config.worker_use_ray,
-                     engine_args.engine_use_ray,
-                     *engine_configs,
-                     placement_group,
-                     log_requests=not engine_args.disable_log_requests,
-                     log_stats=not engine_args.disable_log_stats,
-                     max_log_len=engine_args.max_log_len,
-                     start_engine_loop=start_engine_loop)
-        return engine
-
     async def do_log_stats(self) -> None:
         if self.engine_use_ray:
             await self.engine.do_log_stats.remote()
         else:
             self.engine.do_log_stats()

-    async def check_health(self):
+    async def check_health(self) -> None:
         """Raises an error if engine is unhealthy."""
         t = time.perf_counter()
         logger.debug("Starting health check...")
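
To make the relocated from_engine_args path above concrete, a minimal usage sketch follows. Only from_engine_args itself appears in this diff; the generate() call and the model name are assumptions drawn from vLLM's public API of the same period.

# Hedged usage sketch; only from_engine_args is taken from the hunks above.
import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


async def main() -> None:
    # world_size == 1 here, so from_engine_args picks GPUExecutorAsync;
    # with worker_use_ray/engine_use_ray it would pick RayGPUExecutorAsync.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))

    params = SamplingParams(temperature=0.8, max_tokens=32)
    final = None
    async for output in engine.generate("Hello, my name is", params,
                                        request_id="req-0"):
        final = output
    print(final.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())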
(Diffs for the remaining 7 of the 13 changed files were not loaded on this page.)
