diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 8a37bac02823a..5a15ed67e3327 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -234,6 +234,14 @@ async def step_async(
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
 
+        if not request_outputs:
+            # Stop the execute model loop in parallel workers until there are
+            # more requests to process. This avoids waiting indefinitely in
+            # torch.distributed ops which may otherwise time out, and unblocks
+            # the RPC thread in the workers so that they can process any other
+            # queued control plane messages, such as adding/removing LoRA adapters.
+            await self.model_executor.stop_remote_worker_execution_loop_async()
+
         return request_outputs
 
     async def encode_request_async(
@@ -687,7 +695,7 @@ async def encode(
             multi_modal_data: Multi modal data per request.
 
         Yields:
-            The output `EmbeddingRequestOutput` objects from the LLMEngine 
+            The output `EmbeddingRequestOutput` objects from the LLMEngine
             for the request.
 
         Details:
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 60e23d4df15bb..0631c0de76822 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -692,6 +692,14 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         # Log stats.
         self.do_log_stats(scheduler_outputs, output)
 
+        if not request_outputs:
+            # Stop the execute model loop in parallel workers until there are
+            # more requests to process. This avoids waiting indefinitely in
+            # torch.distributed ops which may otherwise time out, and unblocks
+            # the RPC thread in the workers so that they can process any other
+            # queued control plane messages, such as adding/removing LoRA adapters.
+            self.model_executor.stop_remote_worker_execution_loop()
+
         return request_outputs
 
     def do_log_stats(
diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py
index c5b1e61112afb..f7c608af1ad39 100644
--- a/vllm/executor/distributed_gpu_executor.py
+++ b/vllm/executor/distributed_gpu_executor.py
@@ -1,11 +1,12 @@
+import asyncio
 from abc import abstractmethod
-from typing import Any, Dict, List, Optional, Set, Tuple
+from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple, Union
 
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
-from vllm.sequence import SamplerOutput
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 
 logger = init_logger(__name__)
 
@@ -13,6 +14,16 @@
 class DistributedGPUExecutor(GPUExecutor):
     """Abstract superclass of multi-GPU executor implementations."""
 
+    def __init__(self, *args, **kwargs):
+        # This is non-None when the execute model loop is running
+        # in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
+        self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None
+        # Updated by implementations that require additional args to be passed
+        # to the _run_workers execute_model call
+        self.extra_execute_model_run_workers_kwargs: Dict[str, Any] = {}
+
+        super().__init__(*args, **kwargs)
+
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks.
@@ -52,13 +63,28 @@ def initialize_cache(self, num_gpu_blocks: int,
                           num_gpu_blocks=num_gpu_blocks,
                           num_cpu_blocks=num_cpu_blocks)
 
-    def execute_model(self, *args, **kwargs) -> List[SamplerOutput]:
-        all_outputs = self._run_workers("execute_model",
-                                        driver_args=args,
-                                        driver_kwargs=kwargs)
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        if self.parallel_worker_tasks is None:
+            self.parallel_worker_tasks = self._run_workers(
+                "start_worker_execution_loop",
+                async_run_remote_workers_only=True,
+                **self.extra_execute_model_run_workers_kwargs)
 
         # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+        return self._driver_execute_model(execute_model_req)
+
+    def stop_remote_worker_execution_loop(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+
+        self._driver_execute_model()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit the model loop cleanly
+        # (this will raise an exception otherwise)
+        self._wait_for_tasks_completion(parallel_worker_tasks)
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
@@ -88,39 +114,84 @@ def save_sharded_state(
                              pattern=pattern,
                              max_size=max_size)
 
+    @abstractmethod
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers.
+
+        Args:
+            async_run_remote_workers_only: If True, the method will be run only
+                in the remote workers, not the driver worker. It will also be
+                run asynchronously and return a list of futures rather than
+                blocking on the results.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
         raise NotImplementedError
 
 
 class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase):
 
+    async def execute_model_async(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        if self.parallel_worker_tasks is None:
+            # Start the model execution loop running in the parallel workers
+            self.parallel_worker_tasks = asyncio.create_task(
+                self._start_worker_execution_loop())
+
+        # Only the driver worker returns the sampling results.
+        return await self._driver_execute_model_async(execute_model_req)
+
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+
+        await self._driver_execute_model_async()
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit the model loop cleanly
+        # (this will raise an exception otherwise)
+        await parallel_worker_tasks
+
     @abstractmethod
-    async def _run_workers_async(
+    async def _driver_execute_model_async(
         self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        raise NotImplementedError
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Execute the model asynchronously in the driver worker.
 
-    async def execute_model_async(self, *args,
-                                  **kwargs) -> List[SamplerOutput]:
-        all_outputs = await self._run_workers_async("execute_model",
-                                                    driver_args=args,
-                                                    driver_kwargs=kwargs)
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        raise NotImplementedError
 
-        # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+    @abstractmethod
+    async def _start_worker_execution_loop(self):
+        """Run the execution loop on all workers. It guarantees that either
+        all workers are running the loop or none of them is. The loop can be
+        stopped by `stop_remote_worker_execution_loop`.
+        The API is idempotent (only one loop runs at any moment)."""
+        raise NotImplementedError
diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py
index 08aa58999b1ec..4d01939c2e38b 100644
--- a/vllm/executor/executor_base.py
+++ b/vllm/executor/executor_base.py
@@ -74,6 +74,10 @@ def execute_model(
         """Executes at least one model step on the given sequences."""
         raise NotImplementedError
 
+    def stop_remote_worker_execution_loop(self) -> None:
+        """Releases parallel workers from the model loop."""
+        return
+
     @abstractmethod
     def add_lora(self, lora_request: LoRARequest) -> bool:
         raise NotImplementedError
@@ -109,6 +113,10 @@ async def execute_model_async(
         """Executes one model step on the given sequences."""
         raise NotImplementedError
 
+    async def stop_remote_worker_execution_loop_async(self) -> None:
+        """Releases parallel workers from the model loop."""
+        return
+
     async def check_health_async(self) -> None:
         """Checks if the executor is healthy.
         If not, it should raise an exception."""
diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py
index 2a7b99c9dcbe1..8fa54454907b5 100644
--- a/vllm/executor/multiproc_gpu_executor.py
+++ b/vllm/executor/multiproc_gpu_executor.py
@@ -1,13 +1,14 @@
 import asyncio
 import os
 from functools import partial
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, List, Optional
 
 from vllm.executor.distributed_gpu_executor import (  # yapf: disable
     DistributedGPUExecutor, DistributedGPUExecutorAsync)
 from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
                                                   ResultHandler, WorkerMonitor)
 from vllm.logger import init_logger
+from vllm.sequence import ExecuteModelRequest, SamplerOutput
 from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
                         get_vllm_instance_id, make_async)
 
@@ -71,16 +72,34 @@ def shutdown(self):
                                       None)) is not None:
             worker_monitor.close()
 
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
+
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        return self.driver_worker.execute_model(
+            execute_model_req=execute_model_req)
+
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
-        """Runs the given method on all workers."""
+        """Runs the given method on all workers.
+
+        Args:
+            async_run_remote_workers_only: If True, the method will be run only
+                in the remote workers, not the driver worker. It will also be
+                run asynchronously and return a list of futures rather than
+                blocking on the results.
+        """
 
         if max_concurrent_workers:
             raise NotImplementedError(
@@ -92,15 +111,12 @@ def _run_workers(
             for worker in self.workers
         ]
 
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
+        if async_run_remote_workers_only:
+            # Just return futures
+            return worker_outputs
 
-        # Start the driver worker after all the ray workers.
         driver_worker_method = getattr(self.driver_worker, method)
-        driver_worker_output = driver_worker_method(*driver_args,
-                                                    **driver_kwargs)
+        driver_worker_output = driver_worker_method(*args, **kwargs)
 
         # Get the results of the workers.
         return [driver_worker_output
@@ -111,30 +127,29 @@ def check_health(self) -> None:
         if not self.worker_monitor.is_alive():
             raise RuntimeError("Worker processes are not running")
 
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        for result in parallel_worker_tasks:
+            result.get()
+
 
 class MultiprocessingGPUExecutorAsync(MultiprocessingGPUExecutor,
                                       DistributedGPUExecutorAsync):
 
-    async def _run_workers_async(
-        self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.driver_exec_model = make_async(self.driver_worker.execute_model)
 
-        driver_executor = make_async(getattr(self.driver_worker, method))
+    async def _driver_execute_model_async(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        return await self.driver_exec_model(execute_model_req)
 
-        # Run all the workers asynchronously.
-        coros = [driver_executor(*driver_args, **driver_kwargs)] + [
-            worker.execute_method_async(method, *args, **kwargs)
+    async def _start_worker_execution_loop(self):
+        coros = [
+            worker.execute_method_async("start_worker_execution_loop")
             for worker in self.workers
         ]
-
         return await asyncio.gather(*coros)
diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py
index dd3ee60682d30..bed356d1b6e58 100644
--- a/vllm/executor/ray_gpu_executor.py
+++ b/vllm/executor/ray_gpu_executor.py
@@ -42,6 +42,8 @@ def _init_executor(self) -> None:
         self.forward_dag = None
         if USE_RAY_COMPILED_DAG:
             self.forward_dag = self._compiled_ray_dag()
+            self.extra_execute_model_run_workers_kwargs[
+                "use_ray_compiled_dag"] = True
 
     def _configure_ray_workers_use_nsight(self,
                                           ray_remote_kwargs) -> Dict[str, Any]:
@@ -171,23 +173,23 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
             max_concurrent_workers=self.parallel_config.
             max_parallel_loading_workers)
 
-    def execute_model(
-        self,
-        execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
-        all_outputs = self._run_workers(
-            "execute_model",
-            driver_kwargs={"execute_model_req": execute_model_req},
-            use_ray_compiled_dag=USE_RAY_COMPILED_DAG)
+    def _driver_execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        """Run execute_model in the driver worker.
 
-        # Only the driver worker returns the sampling results.
-        return all_outputs[0]
+        Passing None will cause the driver to stop the model execution
+        loop running in each of the remote workers.
+        """
+        return self.driver_worker.execute_method("execute_model",
+                                                 execute_model_req)
 
     def _run_workers(
         self,
         method: str,
         *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
+        async_run_remote_workers_only: bool = False,
         all_args: Optional[List[Tuple[Any, ...]]] = None,
         all_kwargs: Optional[List[Dict[str, Any]]] = None,
         use_dummy_driver: bool = False,
@@ -198,9 +200,11 @@ def _run_workers(
         """Runs the given method on all workers. Can be used in the following
         ways:
 
+        - async_run_remote_workers_only: If True, the method will be run only
+          in the remote workers, not the driver worker. It will also be
+          run asynchronously and return a list of futures rather than blocking
+          on the results.
         - args/kwargs: All workers share the same args/kwargs
-        - args/kwargs and driver_args/driver_kwargs: Driver worker has
-          different args
         - all_args/all_kwargs: args/kwargs for each worker are specified
           individually
         """
@@ -209,11 +213,6 @@ def _run_workers(
             raise NotImplementedError(
                 "max_concurrent_workers is not supported yet.")
 
-        if driver_args is None:
-            driver_args = args if all_args is None else all_args[0]
-        if driver_kwargs is None:
-            driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
-
         count = len(self.workers)
         all_worker_args = repeat(args, count) if all_args is None \
             else islice(all_args, 1, None)
@@ -225,6 +224,7 @@ def _run_workers(
             # input. TODO(sang): Fix it.
             assert self.forward_dag is not None
             output_channels = self.forward_dag.execute(1)
+            ray_worker_outputs = []
         else:
             # Start the ray workers first.
             ray_worker_outputs = [
@@ -234,6 +234,13 @@ def _run_workers(
             ) in zip(self.workers, all_worker_args, all_worker_kwargs)
         ]
 
+        if async_run_remote_workers_only:
+            # Just return futures
+            return ray_worker_outputs
+
+        driver_args = args if all_args is None else all_args[0]
+        driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
+
         # Start the driver worker after all the ray workers.
         if not use_dummy_driver:
             driver_worker_output = self.driver_worker.execute_method(
@@ -260,6 +267,11 @@ def _run_workers(
 
         return [driver_worker_output] + ray_worker_outputs
 
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        ray.get(parallel_worker_tasks)
+
     def _compiled_ray_dag(self):
         import pkg_resources
         required_version = "2.9"
@@ -303,30 +315,18 @@ class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.driver_executor = make_async(self.driver_worker.execute_method)
+        self.driver_exec_method = make_async(self.driver_worker.execute_method)
 
-    async def _run_workers_async(
+    async def _driver_execute_model_async(
         self,
-        method: str,
-        *args,
-        driver_args: Optional[Tuple[Any, ...]] = None,
-        driver_kwargs: Optional[Dict[str, Any]] = None,
-        **kwargs,
-    ) -> Any:
-        """Runs the given method on all workers."""
-        coros = []
-
-        if driver_args is None:
-            driver_args = args
-        if driver_kwargs is None:
-            driver_kwargs = kwargs
-
-        coros.append(
-            self.driver_executor(method, *driver_args, **driver_kwargs))
-
-        # Run the ray workers asynchronously.
-        for worker in self.workers:
-            coros.append(worker.execute_method.remote(method, *args, **kwargs))
-
-        all_outputs = await asyncio.gather(*coros)
-        return all_outputs
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
+        return await self.driver_exec_method("execute_model",
+                                             execute_model_req)
+
+    async def _start_worker_execution_loop(self):
+        coros = [
+            worker.execute_method.remote("start_worker_execution_loop")
+            for worker in self.workers
+        ]
+        return await asyncio.gather(*coros)
diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py
index 9628f7af5315a..c2b22f2acd7b4 100644
--- a/vllm/spec_decode/ngram_worker.py
+++ b/vllm/spec_decode/ngram_worker.py
@@ -47,7 +47,9 @@ def set_include_gpu_probs_tensor(self):
         # NGram don't need gpu sampler
         pass
 
-    def execute_model(self, execute_model_req: ExecuteModelRequest) -> None:
+    def execute_model(
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None) -> None:
         """NGram doesn't depend on model execution, just pass this function"""
         pass
 
diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py
index ef17b8c1e2cc0..3462a876c3e90 100644
--- a/vllm/spec_decode/spec_decode_worker.py
+++ b/vllm/spec_decode/spec_decode_worker.py
@@ -231,35 +231,6 @@ def initialize_cache(self, num_gpu_blocks: int,
         self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                               num_cpu_blocks=num_cpu_blocks)
 
-    def _broadcast_control_flow_decision(
-            self,
-            execute_model_req: Optional[ExecuteModelRequest] = None,
-            disable_all_speculation: bool = False) -> Tuple[int, bool]:
-        """Broadcast how many lookahead slots are scheduled for this step, and
-        whether all speculation is disabled, to all non-driver workers.
-
-        This is required as if the number of draft model runs changes
-        dynamically, the non-driver workers won't know unless we perform a
-        communication to inform then.
-
-        Returns the broadcasted num_lookahead_slots and disable_all_speculation.
-        """
-
-        if self.rank == self._driver_rank:
-            assert execute_model_req is not None
-
-            broadcast_dict = dict(
-                num_lookahead_slots=execute_model_req.num_lookahead_slots,
-                disable_all_speculation=disable_all_speculation,
-            )
-            broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
-        else:
-            assert execute_model_req is None
-            broadcast_dict = broadcast_tensor_dict(src=self._driver_rank)
-
-        return (broadcast_dict["num_lookahead_slots"],
-                broadcast_dict["disable_all_speculation"])
-
     @torch.inference_mode()
     def execute_model(
         self,
@@ -267,39 +238,58 @@ def execute_model(
     ) -> List[SamplerOutput]:
         """Perform speculative decoding on the input batch.
         """
+        if self.rank != self._driver_rank:
+            self._run_non_driver_rank()
+            return []
 
-        disable_all_speculation = False
-        if self.rank == self._driver_rank:
-            disable_all_speculation = self._should_disable_all_speculation(
-                execute_model_req)
-
-        (num_lookahead_slots,
-         disable_all_speculation) = self._broadcast_control_flow_decision(
-             execute_model_req, disable_all_speculation)
-
-        if self.rank == self._driver_rank:
-            assert execute_model_req is not None
-            assert execute_model_req.seq_group_metadata_list is not None, (
-                "speculative decoding requires non-None seq_group_metadata_list"
-            )
-
-            self._maybe_disable_speculative_tokens(
-                disable_all_speculation,
-                execute_model_req.seq_group_metadata_list)
-
-            # If no spec tokens, call the proposer and scorer workers normally.
-            # Used for prefill.
-            if num_lookahead_slots == 0 or len(
-                    execute_model_req.seq_group_metadata_list) == 0:
-                return self._run_no_spec(execute_model_req,
-                                         skip_proposer=disable_all_speculation)
-
-            return self._run_speculative_decoding_step(execute_model_req,
-                                                       num_lookahead_slots)
-        else:
-            self._run_non_driver_rank(num_lookahead_slots)
+        if execute_model_req is None:
+            # This signals that there are no more requests to process for now.
+            # All workers are running an infinite loop with broadcast_tensor_dict,
+            # and they stop the loop when the driver broadcasts an empty input.
+            # Send an empty input to notify all other workers to stop their
+            # execution loop.
+            broadcast_tensor_dict({}, src=0)
             return []
 
+        disable_all_speculation = self._should_disable_all_speculation(
+            execute_model_req)
+        num_lookahead_slots = execute_model_req.num_lookahead_slots
+
+        # Broadcast how many lookahead slots are scheduled for this step, and
+        # whether all speculation is disabled, to all non-driver workers.
+
+        # This is required as if the number of draft model runs changes
+        # dynamically, the non-driver workers won't know unless we perform a
+        # communication to inform them.
+        broadcast_dict = dict(
+            num_lookahead_slots=num_lookahead_slots,
+            disable_all_speculation=disable_all_speculation,
+        )
+        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
+
+        assert execute_model_req.seq_group_metadata_list is not None, (
+            "speculative decoding requires non-None seq_group_metadata_list")
+
+        self._maybe_disable_speculative_tokens(
+            disable_all_speculation, execute_model_req.seq_group_metadata_list)
+
+        # If no spec tokens, call the proposer and scorer workers normally.
+        # Used for prefill.
+        if num_lookahead_slots == 0 or len(
+                execute_model_req.seq_group_metadata_list) == 0:
+            return self._run_no_spec(execute_model_req,
+                                     skip_proposer=disable_all_speculation)
+
+        return self._run_speculative_decoding_step(execute_model_req,
+                                                   num_lookahead_slots)
+
+    @torch.inference_mode()
+    def start_worker_execution_loop(self) -> None:
+        """Execute the model loop to perform speculative decoding
+        in a parallel worker."""
+        while self._run_non_driver_rank():
+            pass
+
     def _should_disable_all_speculation(
             self, execute_model_req: ExecuteModelRequest) -> bool:
         # When the batch size is too large, disable speculative decoding
@@ -346,13 +336,19 @@ def _run_no_spec(self, execute_model_req: ExecuteModelRequest,
             sampler_output.logprobs = None
         return [sampler_output]
 
-    def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
+    def _run_non_driver_rank(self) -> bool:
         """Run proposer and verifier model in non-driver workers. This is used
         for both speculation cases (num_lookahead_slots>0) and non-speculation
         cases (e.g. prefill).
+
+        Returns True iff there are remaining sequences to process.
         """
-        # In non-driver workers the input is None
-        execute_model_req = None
+        assert self.rank != self._driver_rank
+
+        data = broadcast_tensor_dict(src=self._driver_rank)
+        if not data:
+            return False
+        num_lookahead_slots = data["num_lookahead_slots"]
 
         # Even if num_lookahead_slots is zero, we want to run the proposer model
         # as it may have KV.
@@ -360,9 +356,10 @@ def _run_non_driver_rank(self, num_lookahead_slots: int) -> None:
 
         # We run the proposer once per lookahead slot. In the future we should
         # delegate how many times it runs to the proposer.
         for _ in range(max(num_lookahead_slots, 1)):
-            self.proposer_worker.execute_model(execute_model_req)
+            self.proposer_worker.execute_model()
 
-        self.scorer_worker.execute_model(execute_model_req)
+        self.scorer_worker.execute_model()
+        return True
 
     @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
     def _run_speculative_decoding_step(
diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py
index 91f30978ead87..ef02de95fc54e 100644
--- a/vllm/worker/embedding_model_runner.py
+++ b/vllm/worker/embedding_model_runner.py
@@ -47,7 +47,7 @@ def __init__(
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[torch.Tensor],
     ) -> Optional[PoolerOutput]:
         (input_tokens, input_positions, attn_metadata, pooling_metadata,
@@ -84,10 +84,11 @@ def execute_model(
 
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, PoolingMetadata,
                Set[LoRARequest], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
             # Prepare input tensors.
             (
                 input_tokens,
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 9720363ac300e..87d5f5c1b9d67 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -609,10 +609,11 @@ def _prepare_model_input(
 
     def prepare_input_tensors(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
     ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, SamplingMetadata,
                Set[LoRARequest], LoRAMapping, torch.Tensor]:
         if self.is_driver_worker:
+            assert seq_group_metadata_list is not None
             # Prepare input tensors.
             (
                 input_tokens,
@@ -676,7 +677,7 @@ def prepare_input_tensors(
     @torch.inference_mode()
     def execute_model(
         self,
-        seq_group_metadata_list: List[SequenceGroupMetadata],
+        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[torch.Tensor],
     ) -> Optional[SamplerOutput]:
         (input_tokens, input_positions, attn_metadata, sampling_metadata,
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 97b3873b2a9f6..10411a2bf7a10 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -226,48 +226,42 @@ def execute_model(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
     ) -> List[Union[SamplerOutput, PoolerOutput]]:
+        if not self.is_driver_worker:
+            self._execute_model_non_driver()
+            return []
 
         if execute_model_req is None:
-            seq_group_metadata_list = None
-        else:
-            seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+            # This signals that there are no more requests to process for now.
+            # All workers are running an infinite loop with broadcast_tensor_dict,
+            # and they stop the loop when the driver broadcasts an empty input.
+            # Send an empty input to notify all other workers to stop their
+            # execution loop.
+            broadcast_tensor_dict({}, src=0)
+            return []
 
-        blocks_to_swap_in: torch.Tensor
-        blocks_to_swap_out: torch.Tensor
-        blocks_to_copy: torch.Tensor
-        if self.is_driver_worker:
-            assert seq_group_metadata_list is not None
-            assert execute_model_req is not None
-            num_seq_groups = len(seq_group_metadata_list)
-            # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
-            # they contain parameters to launch cudamemcpyasync.
-            blocks_to_swap_in = torch.tensor(
-                execute_model_req.blocks_to_swap_in,
-                device="cpu",
-                dtype=torch.int64).view(-1, 2)
-            blocks_to_swap_out = torch.tensor(
-                execute_model_req.blocks_to_swap_out,
-                device="cpu",
-                dtype=torch.int64).view(-1, 2)
-            # `blocks_to_copy` is a gpu tensor. The src and tgt of
-            # blocks to copy are in the same device, and `blocks_to_copy`
-            # can be used directly within cuda kernels.
-            blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
-                                          device=self.device,
+        seq_group_metadata_list = execute_model_req.seq_group_metadata_list
+        num_seq_groups = len(seq_group_metadata_list)
+        # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors.
+        # They contain parameters to launch cudaMemcpyAsync.
+        blocks_to_swap_in = torch.tensor(execute_model_req.blocks_to_swap_in,
+                                         device="cpu",
+                                         dtype=torch.int64).view(-1, 2)
+        blocks_to_swap_out = torch.tensor(execute_model_req.blocks_to_swap_out,
+                                          device="cpu",
                                           dtype=torch.int64).view(-1, 2)
-            data: Dict[str, Any] = {
-                "num_seq_groups": num_seq_groups,
-                "blocks_to_swap_in": blocks_to_swap_in,
-                "blocks_to_swap_out": blocks_to_swap_out,
-                "blocks_to_copy": blocks_to_copy,
-            }
-            broadcast_tensor_dict(data, src=0)
-        else:
-            data = broadcast_tensor_dict(src=0)
-            num_seq_groups = data["num_seq_groups"]
-            blocks_to_swap_in = data["blocks_to_swap_in"]
-            blocks_to_swap_out = data["blocks_to_swap_out"]
-            blocks_to_copy = data["blocks_to_copy"]
+        # `blocks_to_copy` is a gpu tensor. The src and tgt of
+        # blocks to copy are in the same device, and `blocks_to_copy`
+        # can be used directly within cuda kernels.
+        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
+                                      device=self.device,
+                                      dtype=torch.int64).view(-1, 2)
+        data: Dict[str, Any] = {
+            "num_seq_groups": num_seq_groups,
+            "blocks_to_swap_in": blocks_to_swap_in,
+            "blocks_to_swap_out": blocks_to_swap_out,
+            "blocks_to_copy": blocks_to_copy,
+        }
+        broadcast_tensor_dict(data, src=0)
 
         self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
 
@@ -282,6 +276,39 @@ def execute_model(
         # to conform to interface.
         return [output]
 
+    @torch.inference_mode()
+    def start_worker_execution_loop(self) -> None:
+        """Execute the model loop in a parallel (non-driver) worker.
+
+        The loop can be stopped by executing the driver worker with an empty
+        input; see `stop_remote_worker_execution_loop` for more details.
+        """
+        while self._execute_model_non_driver():
+            pass
+
+    def _execute_model_non_driver(self) -> bool:
+        """Execute the model in a parallel worker.
+
+        Returns True iff there are remaining sequences to process.
+        """
+        assert not self.is_driver_worker
+        data = broadcast_tensor_dict(src=0)
+        if not data:
+            return False
+
+        num_seq_groups = data.get("num_seq_groups", 0)
+        blocks_to_swap_in = data.get("blocks_to_swap_in")
+        blocks_to_swap_out = data.get("blocks_to_swap_out")
+        blocks_to_copy = data.get("blocks_to_copy")
+        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
+
+        # If there is no input, we don't need to execute the model.
+        if num_seq_groups == 0:
+            return False
+
+        self.model_runner.execute_model(None, self.gpu_cache)
+        return True
+
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_runner.add_lora(lora_request)
 
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index 1f04f821eb0f0..dbac1b5ba339b 100644
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -1,7 +1,7 @@
 import importlib
 import os
 from abc import ABC, abstractmethod
-from typing import Dict, List, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple
 
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -48,8 +48,9 @@ def initialize_cache(self, num_gpu_blocks: int,
 
     @abstractmethod
     def execute_model(
-            self,
-            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        self,
+        execute_model_req: Optional[ExecuteModelRequest] = None
+    ) -> List[SamplerOutput]:
        """Executes at least one model step on the given sequences, unless no
        sequences are provided."""
        raise NotImplementedError
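
Editor's note: the control-flow pattern the diff introduces is the same in both modified workers: non-driver ranks spin inside `start_worker_execution_loop`, blocking on a broadcast from the driver, and exit the loop only when the driver broadcasts an empty payload (which `stop_remote_worker_execution_loop` triggers by calling the driver's `execute_model` with no request). Below is a minimal, self-contained sketch of that stop-via-empty-broadcast idea using plain `torch.distributed` on the gloo backend; the address, port, payload contents, and use of `broadcast_object_list` are illustrative assumptions, not vLLM's actual `broadcast_tensor_dict` helper.

```python
# Illustrative sketch (not vLLM code): driver/worker execution-loop protocol
# where an empty broadcast releases the workers from their loop.
import torch.distributed as dist
import torch.multiprocessing as mp


def worker_loop() -> None:
    """Non-driver ranks block on a broadcast; an empty payload ends the loop."""
    while True:
        payload = [None]
        dist.broadcast_object_list(payload, src=0)
        if not payload[0]:
            # Empty dict broadcast by the driver: exit the loop cleanly.
            return
        # ... a real worker would run one model step with payload[0] here ...


def main(rank: int, world_size: int) -> None:
    dist.init_process_group("gloo",
                            init_method="tcp://127.0.0.1:29555",
                            rank=rank,
                            world_size=world_size)
    if rank == 0:
        # Driver: broadcast a few "steps", then an empty dict to release the
        # workers (the analogue of stop_remote_worker_execution_loop).
        for step in range(3):
            dist.broadcast_object_list([{"num_seq_groups": step + 1}], src=0)
        dist.broadcast_object_list([{}], src=0)
    else:
        worker_loop()
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(main, args=(2, ), nprocs=2)
```

The property mirrored here is that workers only leave the loop after receiving a complete (empty) broadcast, so no collective op is ever left half-finished and the worker thread becomes free to service other control-plane calls.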