diff --git a/docs/source/dev/engine/llm_engine.rst b/docs/source/dev/engine/llm_engine.rst index b550a9b5faa62..1de6d7adc87c6 100644 --- a/docs/source/dev/engine/llm_engine.rst +++ b/docs/source/dev/engine/llm_engine.rst @@ -2,5 +2,5 @@ LLMEngine ================================= .. autoclass:: vllm.engine.llm_engine.LLMEngine - :members: add_request, abort_request, step, _init_cache + :members: add_request, abort_request, step :show-inheritance: \ No newline at end of file diff --git a/format.sh b/format.sh index eb2c5ab031626..ff30111123bee 100755 --- a/format.sh +++ b/format.sh @@ -95,13 +95,17 @@ echo 'vLLM yapf: Done' # echo 'vLLM mypy:' # mypy +CODESPELL_EXCLUDES=( + '--skip' '*docs/source/_build/**' +) + # check spelling of specified files spell_check() { codespell "$@" } spell_check_all(){ - codespell --toml pyproject.toml + codespell --toml pyproject.toml "${CODESPELL_EXCLUDES[@]}" } # Spelling check of files that differ from main branch. @@ -116,7 +120,7 @@ spell_check_changed() { if ! git diff --diff-filter=ACM --quiet --exit-code "$MERGEBASE" -- '*.py' '*.pyi' &>/dev/null; then git diff --name-only --diff-filter=ACM "$MERGEBASE" -- '*.py' '*.pyi' | xargs \ - codespell + codespell "${CODESPELL_EXCLUDES[@]}" fi } diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 67273144ecd02..30a8ad03c8ada 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -152,4 +152,5 @@ def get_model_patched(model_config, device_config, **kwargs): @pytest.fixture def llama_2_7b_model_extra_embeddings( llama_2_7b_engine_extra_embeddings) -> nn.Module: - yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model + yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. + model_runner.model) diff --git a/vllm/__init__.py b/vllm/__init__.py index f1e30f5eb6e6e..5e40c3c20fcd2 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -3,7 +3,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_cluster +from vllm.engine.ray_utils import initialize_ray_cluster from vllm.entrypoints.llm import LLM from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams @@ -19,5 +19,5 @@ "EngineArgs", "AsyncLLMEngine", "AsyncEngineArgs", - "initialize_cluster", + "initialize_ray_cluster", ] diff --git a/vllm/config.py b/vllm/config.py index e893fe702c975..d2b68b6fa1fe2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, ClassVar +from typing import TYPE_CHECKING, Optional, Union, ClassVar from dataclasses import dataclass import os from packaging.version import Version @@ -10,6 +10,9 @@ from vllm.transformers_utils.config import get_config from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + logger = init_logger(__name__) _GB = 1 << 30 @@ -397,6 +400,7 @@ def __init__( max_parallel_loading_workers: Optional[int] = None, disable_custom_all_reduce: bool = False, ray_workers_use_nsight: bool = False, + placement_group: Optional["PlacementGroup"] = None, ) -> None: self.pipeline_parallel_size = pipeline_parallel_size if is_neuron(): @@ -412,6 +416,7 @@ def __init__( self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce self.ray_workers_use_nsight = 
ray_workers_use_nsight + self.placement_group = placement_group self.world_size = pipeline_parallel_size * self.tensor_parallel_size # Ray worker is not supported for Neuron backend. diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5629d1a863d04..0cee604c14d45 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,8 +2,8 @@ import os import time from functools import partial -from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, - Union, AsyncIterator, Callable) +from typing import (Callable, Dict, Iterable, List, Optional, Set, Tuple, Type, + Union, AsyncIterator) from transformers import PreTrainedTokenizer @@ -11,7 +11,7 @@ from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine -from vllm.engine.ray_utils import initialize_cluster, ray +from vllm.engine.ray_utils import initialize_ray_cluster, ray from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams @@ -208,17 +208,10 @@ async def step_async(self) -> List[RequestOutput]: if not scheduler_outputs.is_empty(): # Execute the model. - all_outputs = await self._run_workers_async( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, - "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, - "blocks_to_copy": scheduler_outputs.blocks_to_copy, - }) - - # Only the driver worker returns the sampling results. - output = all_outputs[0] + output = await self.model_executor.execute_model_async( + seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, + scheduler_outputs.blocks_to_swap_out, + scheduler_outputs.blocks_to_copy) else: output = [] @@ -268,37 +261,8 @@ async def add_request_async( lora_request=lora_request, ) - async def _run_workers_async( - self, - method: str, - *args, - driver_args: Optional[List[Any]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - coros = [] - - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - - # Run the driver worker asynchronously. - driver_executor = getattr(self.driver_worker, method) - coros.append(asyncio.get_event_loop().run_in_executor( - None, partial(driver_executor, *driver_args, **driver_kwargs))) - - # Run the ray workers asynchronously. - for worker in self.workers: - coros.append(worker.execute_method.remote(method, *args, **kwargs)) - - all_outputs = await asyncio.gather(*coros) - return all_outputs - - async def check_health_async(self): - """Raises an error if engine is unhealthy.""" - self._check_if_any_actor_is_dead() + async def check_health_async(self) -> None: + self.model_executor.check_health() class AsyncLLMEngine: @@ -353,6 +317,34 @@ def __init__(self, self._request_tracker: Optional[RequestTracker] = None self._errored_with: Optional[BaseException] = None + @classmethod + def from_engine_args(cls, + engine_args: AsyncEngineArgs, + start_engine_loop: bool = True) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. 
+ engine_configs = engine_args.create_engine_configs() + parallel_config = engine_configs[2] + if parallel_config.worker_use_ray or engine_args.engine_use_ray: + initialize_ray_cluster(parallel_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync + executor_class = RayGPUExecutorAsync + else: + assert parallel_config.world_size == 1, ( + "Ray is required if parallel_config.world_size > 1.") + from vllm.executor.gpu_executor import GPUExecutorAsync + executor_class = GPUExecutorAsync + # Create the async LLM engine. + engine = cls(parallel_config.worker_use_ray, + engine_args.engine_use_ray, + *engine_configs, + executor_class, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + max_log_len=engine_args.max_log_len, + start_engine_loop=start_engine_loop) + return engine + @property def is_running(self) -> bool: return (self.background_loop is not None @@ -670,35 +662,13 @@ async def get_model_config(self) -> ModelConfig: else: return self.engine.get_model_config() - @classmethod - def from_engine_args(cls, - engine_args: AsyncEngineArgs, - start_engine_loop: bool = True) -> "AsyncLLMEngine": - """Creates an async LLM engine from the engine arguments.""" - # Create the engine configs. - engine_configs = engine_args.create_engine_configs() - parallel_config = engine_configs[2] - # Initialize the cluster. - placement_group = initialize_cluster(parallel_config, - engine_args.engine_use_ray) - # Create the async LLM engine. - engine = cls(parallel_config.worker_use_ray, - engine_args.engine_use_ray, - *engine_configs, - placement_group, - log_requests=not engine_args.disable_log_requests, - log_stats=not engine_args.disable_log_stats, - max_log_len=engine_args.max_log_len, - start_engine_loop=start_engine_loop) - return engine - async def do_log_stats(self) -> None: if self.engine_use_ray: await self.engine.do_log_stats.remote() else: self.engine.do_log_stats() - async def check_health(self): + async def check_health(self) -> None: """Raises an error if engine is unhealthy.""" t = time.perf_counter() logger.debug("Starting health check...") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6e045cd6d73c6..4cdad4180aa14 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,11 +1,5 @@ -import copy -from collections import defaultdict -import os import time -import pickle -import importlib -from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, - Union) +from typing import Dict, Iterable, List, Optional, Tuple, Type, Union from transformers import PreTrainedTokenizer @@ -15,8 +9,9 @@ ParallelConfig, SchedulerConfig, LoRAConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs +from vllm.executor.executor_base import ExecutorBase from vllm.engine.metrics import StatLogger, Stats -from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray +from vllm.engine.ray_utils import initialize_ray_cluster from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams @@ -24,29 +19,11 @@ SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, TokenizerGroup) -from vllm.utils import (Counter, set_cuda_visible_devices, get_ip, - get_open_port, get_distributed_init_method) - -if ray: - from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy - -if TYPE_CHECKING: - 
from ray.util.placement_group import PlacementGroup +from vllm.utils import Counter logger = init_logger(__name__) _LOCAL_LOGGING_INTERVAL_SEC = 5 -# A map between the device type (in device config) to its worker module. -DEVICE_TO_WORKER_MODULE_MAP = { - "cuda": "vllm.worker.worker", - "neuron": "vllm.worker.neuron_worker", -} - -# If the env var is set, it uses the Ray's compiled DAG API -# which optimizes the control plane overhead. -# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. -USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) - class LLMEngine: """An LLM engine that receives requests and generates texts. @@ -71,8 +48,8 @@ class LLMEngine: parallel_config: The configuration related to distributed execution. scheduler_config: The configuration related to the request scheduler. device_config: The configuration related to the device. - placement_group: Ray placement group for distributed execution. - Required for distributed execution. + executor_class: The model executor class for managing distributed + execution. log_stats: Whether to log statistics. """ @@ -84,7 +61,7 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], - placement_group: Optional["PlacementGroup"], + executor_class: Type[ExecutorBase], log_stats: bool, ) -> None: logger.info( @@ -121,33 +98,13 @@ def __init__( self._init_tokenizer() self.seq_counter = Counter() - # Create the parallel GPU workers. - if self.parallel_config.worker_use_ray: - # Disable Ray usage stats collection. - ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") - if ray_usage != "1": - os.environ["RAY_USAGE_STATS_ENABLED"] = "0" - # Pass additional arguments to initialize the worker - additional_ray_args = {} - if self.parallel_config.ray_workers_use_nsight: - logger.info("Configuring Ray workers to use nsight.") - additional_ray_args = { - "runtime_env": { - "nsight": { - "t": "cuda,cudnn,cublas", - "o": "'worker_process_%p'", - "cuda-graph-trace": "node", - } - } - } - self._init_workers_ray(placement_group, **additional_ray_args) - else: - self._init_workers() - - # Profile the memory usage and initialize the cache. - self._init_cache() + self.model_executor = executor_class(model_config, cache_config, + parallel_config, scheduler_config, + device_config, lora_config) # Create the scheduler. + # NOTE: the cache_config here has been updated with the numbers of + # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) # Metric Logging. @@ -157,9 +114,29 @@ def __init__( labels=dict(model_name=model_config.model)) self.stat_logger.info("cache_config", self.cache_config) - self.forward_dag = None - if USE_RAY_COMPILED_DAG: - self.forward_dag = self._compiled_ray_dag() + @classmethod + def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_configs = engine_args.create_engine_configs() + parallel_config = engine_configs[2] + + # Initialize the cluster and specify the executor class.
+ if parallel_config.worker_use_ray: + initialize_ray_cluster(parallel_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutor + executor_class = RayGPUExecutor + else: + assert parallel_config.world_size == 1, ( + "Ray is required if parallel_config.world_size > 1.") + from vllm.executor.gpu_executor import GPUExecutor + executor_class = GPUExecutor + + # Create the LLM engine. + engine = cls(*engine_configs, + executor_class=executor_class, + log_stats=not engine_args.disable_log_stats) + return engine def __reduce__(self): # This is to ensure that the LLMEngine is not referenced in @@ -173,39 +150,6 @@ def get_tokenizer_for_seq(self, sequence: Sequence) -> "PreTrainedTokenizer": return self.tokenizer.get_lora_tokenizer(sequence.lora_request) - def _dispatch_worker(self): - worker_module = DEVICE_TO_WORKER_MODULE_MAP[ - self.device_config.device_type] - imported_worker = importlib.import_module(worker_module) - Worker = imported_worker.Worker - return Worker - - def _init_workers(self): - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - Worker = self._dispatch_worker() - - assert self.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") - - self.workers: List[Worker] = [] - distributed_init_method = get_distributed_init_method( - get_ip(), get_open_port()) - self.driver_worker = Worker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - local_rank=0, - rank=0, - distributed_init_method=distributed_init_method, - lora_config=self.lora_config, - kv_cache_dtype=self.cache_config.cache_dtype, - is_driver_worker=True, - ) - self._run_workers("init_model") - self._run_workers("load_model") - def _init_tokenizer(self, **tokenizer_init_kwargs): init_kwargs = dict( enable_lora=bool(self.lora_config), @@ -218,126 +162,6 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): self.tokenizer: TokenizerGroup = TokenizerGroup( self.model_config.tokenizer, **init_kwargs) - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): - if self.parallel_config.tensor_parallel_size == 1: - num_gpus = self.cache_config.gpu_memory_utilization - else: - num_gpus = 1 - - self.driver_dummy_worker: RayWorkerVllm = None - self.workers: List[RayWorkerVllm] = [] - - driver_ip = get_ip() - for bundle_id, bundle in enumerate(placement_group.bundle_specs): - if not bundle.get("GPU", 0): - continue - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=placement_group, - placement_group_capture_child_tasks=True, - placement_group_bundle_index=bundle_id, - ) - worker = ray.remote( - num_cpus=0, - num_gpus=num_gpus, - scheduling_strategy=scheduling_strategy, - **ray_remote_kwargs, - )(RayWorkerVllm).remote(self.model_config.trust_remote_code) - - worker_ip = ray.get(worker.get_node_ip.remote()) - if worker_ip == driver_ip and self.driver_dummy_worker is None: - # If the worker is on the same node as the driver, we use it - # as the resource holder for the driver process. - self.driver_dummy_worker = worker - else: - self.workers.append(worker) - - if self.driver_dummy_worker is None: - raise ValueError( - "Ray does not allocate any GPUs on the driver node. 
Consider " - "adjusting the Ray placement group or running the driver on a " - "GPU node.") - - driver_node_id, driver_gpu_ids = ray.get( - self.driver_dummy_worker.get_node_and_gpu_ids.remote()) - worker_node_and_gpu_ids = ray.get( - [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) - - node_workers = defaultdict(list) - node_gpus = defaultdict(list) - - node_workers[driver_node_id].append(0) - node_gpus[driver_node_id].extend(driver_gpu_ids) - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, - start=1): - node_workers[node_id].append(i) - node_gpus[node_id].extend(gpu_ids) - for node_id, gpu_ids in node_gpus.items(): - node_gpus[node_id] = sorted(gpu_ids) - - # Set CUDA_VISIBLE_DEVICES for the driver. - set_cuda_visible_devices(node_gpus[driver_node_id]) - for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): - worker.set_cuda_visible_devices.remote(node_gpus[node_id]) - - distributed_init_method = get_distributed_init_method( - driver_ip, get_open_port()) - - # Lazy import the Worker to avoid importing torch.cuda/xformers - # before CUDA_VISIBLE_DEVICES is set in the Worker - Worker = self._dispatch_worker() - - # Initialize torch distributed process group for the workers. - model_config = copy.deepcopy(self.model_config) - parallel_config = copy.deepcopy(self.parallel_config) - scheduler_config = copy.deepcopy(self.scheduler_config) - device_config = copy.deepcopy(self.device_config) - lora_config = copy.deepcopy(self.lora_config) - kv_cache_dtype = self.cache_config.cache_dtype - - for rank, (worker, (node_id, - _)) in enumerate(zip(self.workers, - worker_node_and_gpu_ids), - start=1): - local_rank = node_workers[node_id].index(rank) - worker.init_worker.remote( - lambda rank=rank, local_rank=local_rank: Worker( - model_config, - parallel_config, - scheduler_config, - device_config, - local_rank, - rank, - distributed_init_method, - lora_config=lora_config, - kv_cache_dtype=kv_cache_dtype, - )) - - driver_rank = 0 - driver_local_rank = node_workers[driver_node_id].index(driver_rank) - self.driver_worker = Worker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - driver_local_rank, - driver_rank, - distributed_init_method, - lora_config=self.lora_config, - kv_cache_dtype=kv_cache_dtype, - is_driver_worker=True, - ) - - # don't use cupy for eager mode - self._run_workers("init_model", - cupy_port=get_open_port() - if not model_config.enforce_eager else None) - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) - def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) @@ -346,81 +170,6 @@ def _verify_args(self) -> None: self.lora_config.verify_with_scheduler_config( self.scheduler_config) - def _init_cache(self) -> None: - """Profiles the memory usage and initializes the KV cache. - - The engine will first conduct a profiling of the existing memory usage. - Then, it calculate the maximum possible number of GPU and CPU blocks - that can be allocated with the remaining free memory. - More details can be found in the - :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method - from class :class:`~vllm.worker.Worker`. - - Afterwards, as there may be multiple workers, - we take the minimum number of blocks across all workers - to ensure this can be applied to all of them. 
- - Finally, the engine will initialize the KV cache - with the calculated number of blocks. - - .. tip:: - You may limit the usage of GPU memory - by adjusting the `gpu_memory_utilization` parameters. - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. - num_blocks = self._run_workers( - "profile_num_available_blocks", - block_size=self.cache_config.block_size, - gpu_memory_utilization=self.cache_config.gpu_memory_utilization, - cpu_swap_space=self.cache_config.swap_space_bytes, - cache_dtype=self.cache_config.cache_dtype, - ) - - # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory - # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - # FIXME(woosuk): Change to debug log. - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") - - if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") - max_seq_len = self.cache_config.block_size * num_gpu_blocks - if self.model_config.max_model_len > max_seq_len: - raise ValueError( - f"The model's max seq len ({self.model_config.max_model_len}) " - "is larger than the maximum number of tokens that can be " - f"stored in KV cache ({max_seq_len}). Try increasing " - "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - # Initialize the cache. - self._run_workers("init_cache_engine", cache_config=self.cache_config) - # Warm up the model. This includes capturing the model into CUDA graph - # if enforce_eager is False. - self._run_workers("warm_up_model") - - @classmethod - def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": - """Creates an LLM engine from the engine arguments.""" - # Create the engine configs. - engine_configs = engine_args.create_engine_configs() - parallel_config = engine_configs[2] - # Initialize the cluster. - placement_group = initialize_cluster(parallel_config) - # Create the LLM engine. - engine = cls(*engine_configs, - placement_group, - log_stats=not engine_args.disable_log_stats) - return engine - def encode_request( self, request_id: str, # pylint: disable=unused-argument @@ -826,7 +575,7 @@ def step(self) -> List[RequestOutput]: - A Sequence Group (SG) refer to a group of sequences that are generated from the same prompt. - - Step 2: Calls the workers to execute the model. + - Step 2: Calls the distributed executor to execute the model. - Step 3: Processes the model output. This mainly includes: - Decodes the relevant outputs. @@ -862,19 +611,10 @@ def step(self) -> List[RequestOutput]: seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if not scheduler_outputs.is_empty(): - # Execute the model. - all_outputs = self._run_workers( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, - "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, - "blocks_to_copy": scheduler_outputs.blocks_to_copy, - }, - use_ray_compiled_dag=USE_RAY_COMPILED_DAG) - - # Only the driver worker returns the sampling results. 
- output = all_outputs[0] + output = self.model_executor.execute_model( + seq_group_metadata_list, scheduler_outputs.blocks_to_swap_in, + scheduler_outputs.blocks_to_swap_out, + scheduler_outputs.blocks_to_copy) else: output = [] @@ -1043,111 +783,13 @@ def _finalize_sequence(self, seq: Sequence, seq.output_text = seq.output_text[:-len(stop_string)] def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "add_lora", - lora_request=lora_request, - ) + return self.model_executor.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "remove_lora", - lora_id=lora_id, - ) + return self.model_executor.remove_lora(lora_id) def list_loras(self) -> List[int]: - return self._run_workers("list_loras") - - def _run_workers( - self, - method: str, - *args, - driver_args: Optional[List[Any]] = None, - driver_kwargs: Optional[Dict[str, Any]] = None, - max_concurrent_workers: Optional[int] = None, - use_ray_compiled_dag: bool = False, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - - if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") - - if use_ray_compiled_dag: - # Right now, compiled DAG can only accept a single - # input. TODO(sang): Fix it. - output_channels = self.forward_dag.execute(1) - else: - # Start the ray workers first. - ray_worker_outputs = [ - worker.execute_method.remote(method, *args, **kwargs) - for worker in self.workers - ] - - if driver_args is None: - driver_args = args - if driver_kwargs is None: - driver_kwargs = kwargs - - # Start the driver worker after all the ray workers. - driver_worker_output = getattr(self.driver_worker, - method)(*driver_args, **driver_kwargs) - - # Get the results of the ray workers. - if self.workers: - if use_ray_compiled_dag: - try: - ray_worker_outputs = [ - pickle.loads(chan.begin_read()) - for chan in output_channels - ] - finally: - # Has to call end_read in order to reuse the DAG. - for chan in output_channels: - chan.end_read() - else: - ray_worker_outputs = ray.get(ray_worker_outputs) - - return [driver_worker_output] + ray_worker_outputs - - def _compiled_ray_dag(self): - import pkg_resources - required_version = "2.9" - current_version = pkg_resources.get_distribution("ray").version - if current_version < required_version: - raise ValueError(f"Ray version {required_version} or greater is " - f"required, but found {current_version}") - - from ray.dag import MultiOutputNode, InputNode - assert self.parallel_config.worker_use_ray - - # Right now, compiled DAG requires at least 1 arg. We send - # a dummy value for now. It will be fixed soon. 
- with InputNode() as input_data: - forward_dag = MultiOutputNode([ - worker.execute_model_compiled_dag_remote.bind(input_data) - for worker in self.workers - ]) - return forward_dag.experimental_compile() + return self.model_executor.list_loras() def check_health(self) -> None: - """Raises an error if engine is unhealthy.""" - self._check_if_any_actor_is_dead() - - def _check_if_any_actor_is_dead(self): - if not self.parallel_config.worker_use_ray: - return - - if not self.workers: - return - - dead_actors = [] - for actor in self.workers: - actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access - if actor_state["State"] == "DEAD": - dead_actors.append(actor) - if dead_actors: - raise RuntimeError("At least one Worker is dead. " - f"Dead Workers: {dead_actors}. ") + self.model_executor.check_health() diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index bbcbbdfea2f00..742f3dc575190 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -1,6 +1,6 @@ import pickle -from typing import Optional, List, Tuple, TYPE_CHECKING +from typing import Optional, List, Tuple from vllm.config import ParallelConfig from vllm.logger import init_logger @@ -65,45 +65,38 @@ def execute_model_compiled_dag_remote(self, ignored): ray = None RayWorkerVllm = None -if TYPE_CHECKING: - from ray.util.placement_group import PlacementGroup - -def initialize_cluster( +def initialize_ray_cluster( parallel_config: ParallelConfig, - engine_use_ray: bool = False, ray_address: Optional[str] = None, -) -> Optional["PlacementGroup"]: - """Initialize the distributed cluster probably with Ray. +): + """Initialize the distributed cluster with Ray. + + It will connect to the Ray cluster and create a placement group + for the workers, which includes the specification of the resources + for each distributed worker. Args: parallel_config: The configurations for parallel execution. - engine_use_ray: Whether to use Ray for async engine. ray_address: The address of the Ray cluster. If None, uses the default Ray cluster address. - - Returns: - An optional `PlacementGroup`. It includes the specification - of the resources for each distributed worker. None if Ray is - not used. """ - if parallel_config.worker_use_ray or engine_use_ray: - if ray is None: - raise ImportError( - "Ray is not installed. Please install Ray to use distributed " - "serving.") - # Connect to a ray cluster. - if is_hip(): - ray.init(address=ray_address, - ignore_reinit_error=True, - num_gpus=parallel_config.world_size) - else: - ray.init(address=ray_address, ignore_reinit_error=True) - - if not parallel_config.worker_use_ray: - assert parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") - return None + if ray is None: + raise ImportError( + "Ray is not installed. Please install Ray to use distributed " + "serving.") + + # Connect to a ray cluster. + if is_hip(): + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) + else: + ray.init(address=ray_address, ignore_reinit_error=True) + + if parallel_config.placement_group: + # Placement group is already set. + return # Create placement group for worker processes current_placement_group = ray.util.get_current_placement_group() @@ -138,4 +131,5 @@ # if they cannot be provisioned.
ray.get(current_placement_group.ready(), timeout=1800) - return current_placement_group + # Set the placement group in the parallel config + parallel_config.placement_group = current_placement_group diff --git a/vllm/executor/__init__.py b/vllm/executor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py new file mode 100644 index 0000000000000..30717e8a87358 --- /dev/null +++ b/vllm/executor/executor_base.py @@ -0,0 +1,75 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput, SequenceGroupMetadata + + +class ExecutorBase(ABC): + """Base class for all executors. + + An executor is responsible for executing the model on a specific device + type (e.g., CPU, GPU, Neuron, etc.). Or it can be a distributed executor + that can execute the model on multiple devices. + """ + + @abstractmethod + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + ) -> None: + raise NotImplementedError + + @abstractmethod + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def list_loras(self) -> List[int]: + raise NotImplementedError + + @abstractmethod + def check_health(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + raise NotImplementedError + + +class ExecutorAsyncBase(ExecutorBase): + + @abstractmethod + async def execute_model_async( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> SamplerOutput: + """Executes one model step on the given sequences.""" + raise NotImplementedError + + @abstractmethod + async def check_health_async(self) -> None: + """Checks if the executor is healthy. If not, it should raise an + exception.""" + raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py new file mode 100644 index 0000000000000..9019ee7763c77 --- /dev/null +++ b/vllm/executor/gpu_executor.py @@ -0,0 +1,163 @@ +import importlib +from typing import Dict, List, Optional + +from vllm.lora.request import LoRARequest +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.executor.utils import check_block_size_valid +from vllm.logger import init_logger +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import (get_ip, get_open_port, get_distributed_init_method, + make_async) + +logger = init_logger(__name__) + +# A map from the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = { + "cuda": "vllm.worker.worker", + "neuron": "vllm.worker.neuron_worker", +} + + +class GPUExecutor(ExecutorBase): + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + ) -> None: + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + + # Instantiate the worker and load the model to GPU. + self._init_worker() + + # Profile the memory usage and initialize the cache. + self._init_cache() + + def _dispatch_worker(self): + worker_module = DEVICE_TO_WORKER_MODULE_MAP[ + self.device_config.device_type] + imported_worker = importlib.import_module(worker_module) + Worker = imported_worker.Worker + return Worker + + def _init_worker(self): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + Worker = self._dispatch_worker() + + assert self.parallel_config.world_size == 1, ( + "GPUExecutor only supports single GPU.") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = Worker( + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=True, + ) + self.driver_worker.init_model() + self.driver_worker.load_model() + + def _init_cache(self) -> None: + """Profiles the memory usage and initializes the KV cache. + + The engine first profiles the existing memory usage. + Then, it allocates the remaining memory for KV blocks. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_gpu_blocks, num_cpu_blocks = ( + self.driver_worker.profile_num_available_blocks( + block_size=self.cache_config.block_size, + gpu_memory_utilization=self.cache_config. + gpu_memory_utilization, + cpu_swap_space=self.cache_config.swap_space_bytes, + cache_dtype=self.cache_config.cache_dtype, + )) + + logger.info(f"# GPU blocks: {num_gpu_blocks}, " + f"# CPU blocks: {num_cpu_blocks}") + + check_block_size_valid(num_gpu_blocks, self.cache_config.block_size, + self.model_config.max_model_len) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + # Initialize the cache. + self.driver_worker.init_cache_engine(cache_config=self.cache_config) + # Warm up the model. This includes capturing the model into CUDA graph + # if enforce_eager is False. + self.driver_worker.warm_up_model() + + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + output = self.driver_worker.execute_model( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." 
+ return self.driver_worker.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self.driver_worker.remove_lora(lora_id) + + def list_loras(self) -> List[int]: + return self.driver_worker.list_loras() + + def check_health(self) -> None: + # GPUExecutor will always be healthy as long as + # it's running. + return + + +class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> SamplerOutput: + output = await make_async(self.driver_worker.execute_model)( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy) + return output + + async def check_health_async(self) -> None: + # GPUExecutor will always be healthy as long as + # it's running. + return diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py new file mode 100644 index 0000000000000..261fcfb7dad9b --- /dev/null +++ b/vllm/executor/ray_gpu_executor.py @@ -0,0 +1,442 @@ +import asyncio +import copy +from collections import defaultdict +import os +import pickle +import importlib +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) +from vllm.engine.ray_utils import RayWorkerVllm, ray +from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.executor.utils import check_block_size_valid +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import (set_cuda_visible_devices, get_ip, get_open_port, + get_distributed_init_method, make_async) + +if ray is not None: + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + +# A map from the device type (in device config) to its worker module. +DEVICE_TO_WORKER_MODULE_MAP = { + "cuda": "vllm.worker.worker", + "neuron": "vllm.worker.neuron_worker", +} + +# If the env var is set, it uses Ray's compiled DAG API +# which optimizes the control plane overhead. +# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. +USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) + + +class RayGPUExecutor(ExecutorBase): + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + ) -> None: + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + + assert self.parallel_config.worker_use_ray + placement_group = self.parallel_config.placement_group + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + # Create the parallel GPU workers.
+ self._init_workers_ray(placement_group) + + # Profile the memory usage and initialize the cache. + self._init_cache() + + self.forward_dag = None + if USE_RAY_COMPILED_DAG: + self.forward_dag = self._compiled_ray_dag() + + def _dispatch_worker(self): + worker_module = DEVICE_TO_WORKER_MODULE_MAP[ + self.device_config.device_type] + imported_worker = importlib.import_module(worker_module) + Worker = imported_worker.Worker + return Worker + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + if self.parallel_config.tensor_parallel_size == 1: + # For single GPU case, we use a ray worker with constrained memory. + num_gpus = self.cache_config.gpu_memory_utilization + else: + # Otherwise, the ray workers are allocated with a full GPU. + num_gpus = 1 + + # The driver dummy worker does not actually use any resources. + # It holds the resource for the driver worker. + self.driver_dummy_worker: RayWorkerVllm = None + # The remaining workers are the actual ray actors. + self.workers: List[RayWorkerVllm] = [] + + # Create the workers. + driver_ip = get_ip() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("GPU", 0): + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + worker = ray.remote( + num_cpus=0, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerVllm).remote(self.model_config.trust_remote_code) + + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + else: + # Else, added to the list of workers. + self.workers.append(worker) + + if self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + + # Get the set of GPU IDs used on each node. + driver_node_id, driver_gpu_ids = ray.get( + self.driver_dummy_worker.get_node_and_gpu_ids.remote()) + worker_node_and_gpu_ids = ray.get( + [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + + node_workers = defaultdict(list) + node_gpus = defaultdict(list) + + node_workers[driver_node_id].append(0) + node_gpus[driver_node_id].extend(driver_gpu_ids) + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, + start=1): + node_workers[node_id].append(i) + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + # Set CUDA_VISIBLE_DEVICES for the driver and workers. 
+ set_cuda_visible_devices(node_gpus[driver_node_id]) + for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): + worker.set_cuda_visible_devices.remote(node_gpus[node_id]) + + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + Worker = self._dispatch_worker() + + model_config = copy.deepcopy(self.model_config) + parallel_config = copy.deepcopy(self.parallel_config) + scheduler_config = copy.deepcopy(self.scheduler_config) + device_config = copy.deepcopy(self.device_config) + lora_config = copy.deepcopy(self.lora_config) + kv_cache_dtype = self.cache_config.cache_dtype + + # Initialize the actual workers with the Worker class. + for rank, (worker, (node_id, _)) in enumerate( + zip(self.workers, worker_node_and_gpu_ids), + start=1, + ): + local_rank = node_workers[node_id].index(rank) + worker.init_worker.remote( + lambda rank=rank, local_rank=local_rank: Worker( + model_config, + parallel_config, + scheduler_config, + device_config, + local_rank, + rank, + distributed_init_method, + lora_config=lora_config, + kv_cache_dtype=kv_cache_dtype, + )) + + # Initialize the driver worker with the Worker class. + driver_rank = 0 + driver_local_rank = node_workers[driver_node_id].index(driver_rank) + self.driver_worker = Worker( + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + driver_local_rank, + driver_rank, + distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=True, + ) + + # FIXME(woosuk): We are not properly initializing cupy NCCL when + # we have multiple nodes. + self._run_workers("init_model", + cupy_port=get_open_port() + if not model_config.enforce_eager else None) + self._run_workers( + "load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers, + ) + + def _init_cache(self) -> None: + """Profiles the memory usage and initializes the KV cache. + + The engine will first profile the existing memory usage. + Then, it calculates the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + More details can be found in the + :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method + from class :class:`~vllm.worker.Worker`. + + Afterwards, as there may be multiple workers, + we take the minimum number of blocks across all workers + to ensure this can be applied to all of them. + + Finally, the engine will initialize the KV cache + with the calculated number of blocks. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameter. + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers( + "profile_num_available_blocks", + block_size=self.cache_config.block_size, + gpu_memory_utilization=self.cache_config.gpu_memory_utilization, + cpu_swap_space=self.cache_config.swap_space_bytes, + cache_dtype=self.cache_config.cache_dtype, + ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers.
+ num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + logger.info(f"# GPU blocks: {num_gpu_blocks}, " + f"# CPU blocks: {num_cpu_blocks}") + + check_block_size_valid(num_gpu_blocks, self.cache_config.block_size, + self.model_config.max_model_len) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + # Initialize the cache. + self._run_workers("init_cache_engine", cache_config=self.cache_config) + # Warm up the model. This includes capturing the model into CUDA graph + # if enforce_eager is False. + self._run_workers("warm_up_model") + + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + all_outputs = self._run_workers( + "execute_model", + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": blocks_to_swap_in, + "blocks_to_swap_out": blocks_to_swap_out, + "blocks_to_copy": blocks_to_copy, + }, + use_ray_compiled_dag=USE_RAY_COMPILED_DAG) + + # Only the driver worker returns the sampling results. + output = all_outputs[0] + return output + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> List[int]: + return self._run_workers("list_loras") + + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[List[Any]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + use_ray_compiled_dag: bool = False, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + if use_ray_compiled_dag: + # Right now, compiled DAG can only accept a single + # input. TODO(sang): Fix it. + output_channels = self.forward_dag.execute(1) + else: + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Start the driver worker after all the ray workers. + driver_worker_output = getattr(self.driver_worker, + method)(*driver_args, **driver_kwargs) + + # Get the results of the ray workers. + if self.workers: + if use_ray_compiled_dag: + try: + ray_worker_outputs = [ + pickle.loads(chan.begin_read()) + for chan in output_channels + ] + finally: + # Has to call end_read in order to reuse the DAG. 
+ for chan in output_channels: + chan.end_read() + else: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return [driver_worker_output] + ray_worker_outputs + + def _compiled_ray_dag(self): + import pkg_resources + required_version = "2.9" + current_version = pkg_resources.get_distribution("ray").version + if current_version < required_version: + raise ValueError(f"Ray version {required_version} or greater is " + f"required, but found {current_version}") + + from ray.dag import MultiOutputNode, InputNode + assert self.parallel_config.worker_use_ray + + # Right now, compiled DAG requires at least 1 arg. We send + # a dummy value for now. It will be fixed soon. + with InputNode() as input_data: + forward_dag = MultiOutputNode([ + worker.execute_model_compiled_dag_remote.bind(input_data) + for worker in self.workers + ]) + return forward_dag.experimental_compile() + + def check_health(self) -> None: + """Raises an error if engine is unhealthy.""" + self._check_if_any_actor_is_dead() + + def _check_if_any_actor_is_dead(self): + if not self.workers: + return + + dead_actors = [] + for actor in self.workers: + actor_state = ray.state.actors(actor._ray_actor_id.hex()) # pylint: disable=protected-access + if actor_state["State"] == "DEAD": + dead_actors.append(actor) + if dead_actors: + raise RuntimeError("At least one Worker is dead. " + f"Dead Workers: {dead_actors}. ") + + +class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): + + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[List[Any]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + coros = [] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Run the driver worker asynchronously. + driver_executor = make_async(getattr(self.driver_worker, method)) + coros.append(driver_executor(*driver_args, **driver_kwargs)) + + # Run the ray workers asynchronously. + for worker in self.workers: + coros.append(worker.execute_method.remote(method, *args, **kwargs)) + + all_outputs = await asyncio.gather(*coros) + return all_outputs + + async def execute_model_async( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> SamplerOutput: + all_outputs = await self._run_workers_async( + "execute_model", + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": blocks_to_swap_in, + "blocks_to_swap_out": blocks_to_swap_out, + "blocks_to_copy": blocks_to_copy, + }, + use_ray_compiled_dag=USE_RAY_COMPILED_DAG) + + # Only the driver worker returns the sampling results. + output = all_outputs[0] + return output + + async def check_health_async(self) -> None: + """Raises an error if engine is unhealthy.""" + self._check_if_any_actor_is_dead() diff --git a/vllm/executor/utils.py b/vllm/executor/utils.py new file mode 100644 index 0000000000000..44976696a77c6 --- /dev/null +++ b/vllm/executor/utils.py @@ -0,0 +1,13 @@ +def check_block_size_valid(num_gpu_blocks, block_size, max_model_len) -> None: + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. 
" + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = block_size * num_gpu_blocks + if max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.")