Add distributed model executor abstraction (#3191)
zhuohan123 authored Mar 11, 2024

1 parent 657061f commit 4c92270
Showing 13 changed files with 818 additions and 509 deletions.
2 changes: 1 addition & 1 deletion docs/source/dev/engine/llm_engine.rst
@@ -2,5 +2,5 @@ LLMEngine
=================================

.. autoclass:: vllm.engine.llm_engine.LLMEngine
:members: add_request, abort_request, step, _init_cache
:members: add_request, abort_request, step
:show-inheritance:
8 changes: 6 additions & 2 deletions format.sh
@@ -95,13 +95,17 @@ echo 'vLLM yapf: Done'
# echo 'vLLM mypy:'
# mypy

CODESPELL_EXCLUDES=(
'--skip' '*docs/source/_build/**'
)

# check spelling of specified files
spell_check() {
codespell "$@"
}

spell_check_all(){
codespell --toml pyproject.toml
codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}"
}

# Spelling check of files that differ from main branch.
@@ -116,7 +120,7 @@ spell_check_changed() {

if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then
git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \
codespell
codespell "${CODESPELL_EXCLUDES[@]}"
fi
}

3 changes: 2 additions & 1 deletion tests/lora/conftest.py
@@ -152,4 +152,5 @@ def get_model_patched(model_config, device_config, **kwargs):
@pytest.fixture
def llama_2_7b_model_extra_embeddings(
llama_2_7b_engine_extra_embeddings) -> nn.Module:
yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model
yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker.
model_runner.model)
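
For context, this fixture change reflects the new object layout: the engine no longer exposes the driver worker directly, but reaches it through a model_executor attribute. A minimal sketch of a helper following the new attribute chain (the helper name is hypothetical; the chain itself is taken from the diff above):

from torch import nn

from vllm.engine.llm_engine import LLMEngine


def get_driver_model(engine: LLMEngine) -> nn.Module:
    # New path after this commit: the engine delegates to a model executor,
    # which owns the driver worker and its model runner.
    return engine.model_executor.driver_worker.model_runner.model
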
4 changes: 2 additions & 2 deletions vllm/__init__.py
@@ -3,7 +3,7 @@
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster
from vllm.engine.ray_utils import initialize_ray_cluster
from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams
@@ -19,5 +19,5 @@
"EngineArgs",
"AsyncLLMEngine",
"AsyncEngineArgs",
"initialize_cluster",
"initialize_ray_cluster",
]
7 changes: 6 additions & 1 deletion vllm/config.py
@@ -1,4 +1,4 @@
from typing import Optional, Union, ClassVar
from typing import TYPE_CHECKING, Optional, Union, ClassVar
from dataclasses import dataclass
import os
from packaging.version import Version
@@ -10,6 +10,9 @@
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup

logger = init_logger(__name__)

_GB = 1 << 30
@@ -397,6 +400,7 @@ def __init__(
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
ray_workers_use_nsight: bool = False,
placement_group: Optional["PlacementGroup"] = None,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
if is_neuron():
@@ -412,6 +416,7 @@ def __init__(
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.ray_workers_use_nsight = ray_workers_use_nsight
self.placement_group = placement_group

self.world_size = pipeline_parallel_size * self.tensor_parallel_size
# Ray worker is not supported for Neuron backend.
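
The PlacementGroup import is guarded by TYPE_CHECKING so Ray stays an optional runtime dependency while the new placement_group field remains type-checkable. The same pattern in isolation, as a small sketch (class and attribute names here are placeholders, not vLLM code):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Imported only for type checkers; never executed at runtime, so the
    # optional dependency (Ray) is not needed just to import this module.
    from ray.util.placement_group import PlacementGroup


class Config:
    # The annotation must be a string ("PlacementGroup") because the name
    # exists only for the type checker, not at runtime.
    def __init__(self, placement_group: Optional["PlacementGroup"] = None) -> None:
        self.placement_group = placement_group
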
106 changes: 38 additions & 68 deletions vllm/engine/async_llm_engine.py
@@ -2,16 +2,16 @@
import os
import time
from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator, Callable)
from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type,
Union, AsyncIterator)

from transformers import PreTrainedTokenizer

from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray
from vllm.engine.ray_utils import initialize_ray_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
@@ -208,17 +208,10 @@ async def step_async(self) -> List[RequestOutput]:

if not scheduler_outputs.is_empty():
# Execute the model.
all_outputs = await self._run_workers_async(
"execute_model",
driver_kwargs={
"seq_group_metadata_list": seq_group_metadata_list,
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
})

# Only the driver worker returns the sampling results.
output = all_outputs[0]
output = await self.model_executor.execute_model_async(
seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in,
scheduler_outputs.blocks_to_swap_out,
scheduler_outputs.blocks_to_copy)
else:
output = []

@@ -268,37 +261,8 @@ async def add_request_async(
lora_request=lora_request,
)

async def _run_workers_async(
self,
method: str,
*args,
driver_args: Optional[List[Any]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Any:
"""Runs the given method on all workers."""
coros = []

if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs

# Run the driver worker asynchronously.
driver_executor = getattr(self.driver_worker, method)
coros.append(asyncio.get_event_loop().run_in_executor(
None, partial(driver_executor, *driver_args, **driver_kwargs)))

# Run the ray workers asynchronously.
for worker in self.workers:
coros.append(worker.execute_method.remote(method, *args, **kwargs))

all_outputs = await asyncio.gather(*coros)
return all_outputs

async def check_health_async(self):
"""Raises an error if engine is unhealthy."""
self._check_if_any_actor_is_dead()
async def check_health_async(self) -> None:
self.model_executor.check_health()


class AsyncLLMEngine:
@@ -353,6 +317,34 @@ def __init__(self,
self._request_tracker: Optional[RequestTracker] = None
self._errored_with: Optional[BaseException] = None

@classmethod
def from_engine_args(cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
if parallel_config.worker_use_ray or engine_args.engine_use_ray:
initialize_ray_cluster(parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
else:
assert parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
from vllm.executor.gpu_executor import GPUExecutorAsync
executor_class = GPUExecutorAsync
# Create the async LLM engine.
engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop)
return engine

@property
def is_running(self) -> bool:
return (self.background_loop is not None
@@ -670,35 +662,13 @@ async def get_model_config(self) -> ModelConfig:
else:
return self.engine.get_model_config()

@classmethod
def from_engine_args(cls,
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_configs = engine_args.create_engine_configs()
parallel_config = engine_configs[2]
# Initialize the cluster.
placement_group = initialize_cluster(parallel_config,
engine_args.engine_use_ray)
# Create the async LLM engine.
engine = cls(parallel_config.worker_use_ray,
engine_args.engine_use_ray,
*engine_configs,
placement_group,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop)
return engine

async def do_log_stats(self) -> None:
if self.engine_use_ray:
await self.engine.do_log_stats.remote()
else:
self.engine.do_log_stats()

async def check_health(self):
async def check_health(self) -> None:
"""Raises an error if engine is unhealthy."""
t = time.perf_counter()
logger.debug("Starting health check...")
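
A hedged usage sketch of the relocated from_engine_args classmethod: the model name and prompt are placeholders, and executor selection (RayGPUExecutorAsync vs. GPUExecutorAsync) happens inside from_engine_args exactly as in the hunk above.

import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams


async def main() -> None:
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    final_output = None
    # generate() is an async generator that streams partial RequestOutputs.
    async for output in engine.generate("Hello, my name is",
                                        SamplingParams(max_tokens=16),
                                        request_id="demo-0"):
        final_output = output
    print(final_output.outputs[0].text)


if __name__ == "__main__":
    asyncio.run(main())
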
446 changes: 44 additions & 402 deletions vllm/engine/llm_engine.py

Large diffs are not rendered by default.

58 changes: 26 additions & 32 deletions vllm/engine/ray_utils.py
@@ -1,6 +1,6 @@
import pickle

from typing import Optional, List, Tuple, TYPE_CHECKING
from typing import Optional, List, Tuple

from vllm.config import ParallelConfig
from vllm.logger import init_logger
@@ -65,45 +65,38 @@ def execute_model_compiled_dag_remote(self, ignored):
ray = None
RayWorkerVllm = None

if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup


def initialize_cluster(
def initialize_ray_cluster(
parallel_config: ParallelConfig,
engine_use_ray: bool = False,
ray_address: Optional[str] = None,
) -> Optional["PlacementGroup"]:
"""Initialize the distributed cluster probably with Ray.
):
"""Initialize the distributed cluster with Ray.
it will connect to the Ray cluster and create a placement group
for the workers, which includes the specification of the resources
for each distributed worker.
Args:
parallel_config: The configurations for parallel execution.
engine_use_ray: Whether to use Ray for async engine.
ray_address: The address of the Ray cluster. If None, uses
the default Ray cluster address.
Returns:
An optional `PlacementGroup`. It includes the specification
of the resources for each distributed worker. None if Ray is
not used.
"""
if parallel_config.worker_use_ray or engine_use_ray:
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")
# Connect to a ray cluster.
if is_hip():
ray.init(address=ray_address,
ignore_reinit_error=True,
num_gpus=parallel_config.world_size)
else:
ray.init(address=ray_address, ignore_reinit_error=True)

if not parallel_config.worker_use_ray:
assert parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
return None
if ray is None:
raise ImportError(
"Ray is not installed. Please install Ray to use distributed "
"serving.")

# Connect to a ray cluster.
if is_hip():
ray.init(address=ray_address,
ignore_reinit_error=True,
num_gpus=parallel_config.world_size)
else:
ray.init(address=ray_address, ignore_reinit_error=True)

if parallel_config.placement_group:
# Placement group is already set.
return

# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
@@ -138,4 +131,5 @@ def initialize_cluster(
# if they cannot be provisioned.
ray.get(current_placement_group.ready(), timeout=1800)

return current_placement_group
# Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group
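
A minimal sketch of the new calling convention, assuming Ray is installed and enough GPUs are visible; the required ParallelConfig constructor arguments beyond those shown in this diff are an assumption. The function no longer returns a placement group; it stores one on the config it was given.

from vllm.config import ParallelConfig
from vllm.engine.ray_utils import initialize_ray_cluster

parallel_config = ParallelConfig(pipeline_parallel_size=1,
                                 tensor_parallel_size=2,
                                 worker_use_ray=True)
initialize_ray_cluster(parallel_config)
# The placement group is now set as a side effect instead of returned.
assert parallel_config.placement_group is not None
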
Empty file added vllm/executor/__init__.py
Empty file.
75 changes: 75 additions & 0 deletions vllm/executor/executor_base.py
@@ -0,0 +1,75 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata


class ExecutorBase(ABC):
"""Base class for all executors.
An executor is responsible for executing the model on a specific device
type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor
that can execute the model on multiple devices.
"""

@abstractmethod
def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
) -> None:
raise NotImplementedError

@abstractmethod
def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
"""Executes one model step on the given sequences."""
raise NotImplementedError

@abstractmethod
def add_lora(self, lora_request: LoRARequest) -> bool:
raise NotImplementedError

@abstractmethod
def remove_lora(self, lora_id: int) -> bool:
raise NotImplementedError

@abstractmethod
def list_loras(self) -> List[int]:
raise NotImplementedError

@abstractmethod
def check_health(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise NotImplementedError


class ExecutorAsyncBase(ExecutorBase):

@abstractmethod
async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
"""Executes one model step on the given sequences."""
raise NotImplementedError

@abstractmethod
async def check_health_async(self) -> None:
"""Checks if the executor is healthy. If not, it should raise an
exception."""
raise NotImplementedError
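
To show what a new backend has to provide, here is a skeletal in-process executor implementing the interface above. Only the method set and signatures come from ExecutorBase as shown in this diff; the worker wiring is elided and the class name is hypothetical.

from typing import Dict, List, Optional

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata


class MyDeviceExecutor(ExecutorBase):

    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig,
                 parallel_config: ParallelConfig,
                 scheduler_config: SchedulerConfig,
                 device_config: DeviceConfig,
                 lora_config: Optional[LoRAConfig]) -> None:
        # A real executor would create its worker(s) and initialize the KV
        # cache here (compare GPUExecutor._init_worker/_init_cache below).
        self.model_config = model_config
        self.cache_config = cache_config

    def execute_model(self,
                      seq_group_metadata_list: List[SequenceGroupMetadata],
                      blocks_to_swap_in: Dict[int, int],
                      blocks_to_swap_out: Dict[int, int],
                      blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
        raise NotImplementedError("forward pass for the target device goes here")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return False

    def remove_lora(self, lora_id: int) -> bool:
        return False

    def list_loras(self) -> List[int]:
        return []

    def check_health(self) -> None:
        return  # in-process executor: healthy as long as it is running
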
163 changes: 163 additions & 0 deletions vllm/executor/gpu_executor.py
@@ -0,0 +1,163 @@
import importlib
from typing import Dict, List, Optional

from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
from vllm.executor.utils import check_block_size_valid
from vllm.logger import init_logger
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import (get_ip, get_open_port, get_distributed_init_method,
make_async)

logger = init_logger(__name__)

# A map between the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
"cuda": "vllm.worker.worker",
"neuron": "vllm.worker.neuron_worker",
}


class GPUExecutor(ExecutorBase):

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
) -> None:
self.model_config = model_config
self.cache_config = cache_config
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config

# Instantiate the worker and load the model to GPU.
self._init_worker()

# Profile the memory usage and initialize the cache.
self._init_cache()

def _dispatch_worker(self):
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
self.device_config.device_type]
imported_worker = importlib.import_module(worker_module)
Worker = imported_worker.Worker
return Worker

def _init_worker(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
Worker = self._dispatch_worker()

assert self.parallel_config.world_size == 1, (
"GPUExecutor only supports single GPU.")

distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
self.driver_worker = Worker(
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
kv_cache_dtype=self.cache_config.cache_dtype,
is_driver_worker=True,
)
self.driver_worker.init_model()
self.driver_worker.load_model()

def _init_cache(self) -> None:
"""Profiles the memory usage and initializes the KV cache.
The engine first profiles the existing memory usage.
Then, it allocates the remaining memory for KV blocks.
.. tip::
You may limit the usage of GPU memory
by adjusting the `gpu_memory_utilization` parameter.
"""
# Get the maximum number of blocks that can be allocated on GPU and CPU.
num_gpu_blocks, num_cpu_blocks = (
self.driver_worker.profile_num_available_blocks(
block_size=self.cache_config.block_size,
gpu_memory_utilization=self.cache_config.
gpu_memory_utilization,
cpu_swap_space=self.cache_config.swap_space_bytes,
cache_dtype=self.cache_config.cache_dtype,
))

logger.info(f"# GPU blocks: {num_gpu_blocks}, "
f"# CPU blocks: {num_cpu_blocks}")

check_block_size_valid(num_gpu_blocks, self.cache_config.block_size,
self.model_config.max_model_len)

self.cache_config.num_gpu_blocks = num_gpu_blocks
self.cache_config.num_cpu_blocks = num_cpu_blocks

# Initialize the cache.
self.driver_worker.init_cache_engine(cache_config=self.cache_config)
# Warm up the model. This includes capturing the model into CUDA graph
# if enforce_eager is False.
self.driver_worker.warm_up_model()

def execute_model(self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput:
output = self.driver_worker.execute_model(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy,
)
return output

def add_lora(self, lora_request: LoRARequest) -> bool:
assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
return self.driver_worker.add_lora(lora_request)

def remove_lora(self, lora_id: int) -> bool:
assert lora_id > 0, "lora_id must be greater than 0."
return self.driver_worker.remove_lora(lora_id)

def list_loras(self) -> List[int]:
return self.driver_worker.list_loras()

def check_health(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return


class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):

async def execute_model_async(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> SamplerOutput:
output = await make_async(self.driver_worker.execute_model)(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=blocks_to_swap_in,
blocks_to_swap_out=blocks_to_swap_out,
blocks_to_copy=blocks_to_copy)
return output

async def check_health_async(self) -> None:
# GPUExecutor will always be healthy as long as
# it's running.
return
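
The string-keyed worker dispatch is worth seeing in isolation: importing the worker module lazily keeps torch.cuda/xformers from being initialized before the worker sets CUDA_VISIBLE_DEVICES. A small sketch of the same pattern, with a hypothetical helper name and module names mirroring the map above:

import importlib

DEVICE_TO_WORKER_MODULE_MAP = {
    "cuda": "vllm.worker.worker",
    "neuron": "vllm.worker.neuron_worker",
}


def resolve_worker_class(device_type: str):
    # Same idea as GPUExecutor._dispatch_worker: defer the import until the
    # device type is known, then pull the Worker class out of that module.
    worker_module = importlib.import_module(
        DEVICE_TO_WORKER_MODULE_MAP[device_type])
    return worker_module.Worker
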
442 changes: 442 additions & 0 deletions vllm/executor/ray_gpu_executor.py

Large diffs are not rendered by default.

13 changes: 13 additions & 0 deletions vllm/executor/utils.py
@@ -0,0 +1,13 @@
def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None:
if num_gpu_blocks <= 0:
raise ValueError("No available memory for the cache blocks. "
"Try increasing `gpu_memory_utilization` when "
"initializing the engine.")
max_seq_len = block_size * num_gpu_blocks
if max_model_len > max_seq_len:
raise ValueError(
f"The model's max seq len ({max_model_len}) "
"is larger than the maximum number of tokens that can be "
f"stored in KV cache ({max_seq_len}). Try increasing "
"`gpu_memory_utilization` or decreasing `max_model_len` when "
"initializing the engine.")
