From a2d5cc553951a01f845de71c0bc09ef74c06c341 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 25 Aug 2023 15:01:05 -0700 Subject: [PATCH 01/10] Refactor AsyncLLMEngine Signed-off-by: Antoni Baum --- vllm/core/scheduler.py | 13 +- vllm/engine/async_llm_engine.py | 238 ++++++++++++++++++-------------- vllm/engine/llm_engine.py | 117 ++++++++++++---- 3 files changed, 234 insertions(+), 134 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f59102cb13087..06c0f83a6721a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,6 +1,6 @@ import enum import time -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union, Iterable from vllm.config import CacheConfig, SchedulerConfig from vllm.core.block_manager import BlockSpaceManager @@ -84,17 +84,22 @@ def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) - def abort_seq_group(self, request_id: str) -> None: + def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) for state_queue in [self.waiting, self.running, self.swapped]: for seq_group in state_queue: - if seq_group.request_id == request_id: + if seq_group.request_id in request_ids: # Remove the sequence group from the state queue. state_queue.remove(seq_group) for seq in seq_group.seqs: if seq.is_finished(): continue self.free_seq(seq, SequenceStatus.FINISHED_ABORTED) - return + request_ids.remove(seq_group.request_id) + if not request_ids: + return def has_unfinished_seqs(self) -> bool: return self.waiting or self.running or self.swapped diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9049505d1f8c1..fa42cf47ad337 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Iterable from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -14,6 +14,44 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds +class Stream: + """A stream of outputs for a request that can be + iterated over asynchronously.""" + def __init__(self, request_id: str) -> None: + self.request_id = request_id + self._queue = asyncio.Queue() + self._finished = False + + def put(self, item: RequestOutput) -> None: + if self._finished: + return + self._queue.put_nowait(item) + + def finish(self) -> None: + self._queue.put_nowait(StopIteration) + self._finished = True + + @property + def finished(self) -> bool: + return self._finished + + def __aiter__(self): + return self + + async def __anext__(self) -> RequestOutput: + result = await self._queue.get() + if result is StopIteration: + raise StopAsyncIteration + return result + +def _raise_exception_on_finish( + task: asyncio.Task, +) -> None: + try: + task.result() + except Exception as e: + raise RuntimeError("Task finished unexpectedly.") from e + raise RuntimeError("Task finished unexpectedly.") class AsyncLLMEngine: """An asynchronous wrapper for LLMEngine. 
@@ -42,6 +80,7 @@ def __init__(self, engine_use_ray: bool, *args, log_requests: bool = True, + inline: bool = False, **kwargs) -> None: self.worker_use_ray = worker_use_ray self.engine_use_ray = engine_use_ray @@ -53,33 +92,84 @@ def __init__(self, else: engine_class = ray.remote(num_gpus=1)(LLMEngine).remote self.engine = engine_class(*args, **kwargs) - # Request id -> request output. - self.request_outputs: Dict[str, RequestOutput] = {} - # Request id -> event to notify that there is new output. - self.request_events: Dict[str, asyncio.Event] = {} - self.is_engine_running = False - self.kicking_request_id: Optional[str] = None - - async def engine_step(self, kicking_request_id: Optional[str] = None): + # Request id -> stream. + self.request_streams: Dict[str, Stream] = {} + self.background_loop = None + if not inline: + # Start the background loop. + self.background_loop = asyncio.get_event_loop().create_task(self.run_engine_loop()) + self.background_loop.add_done_callback(_raise_exception_on_finish) + + async def engine_step(self): """Kick the engine to process the waiting requests.""" - self.is_engine_running = True - self.kicking_request_id = kicking_request_id if self.engine_use_ray: request_outputs = await self.engine.step.remote() else: - # Yield to the event loop to allow other coroutines to run - # while is_engine_running is True. This let the engine to add new - # requests into the queue. - await asyncio.sleep(0) - request_outputs = self.engine.step() - self.is_engine_running = False - self.kicking_request_id = None + request_outputs = await self.engine.step_async() - # Notify the waiting coroutines that there are new outputs ready. + # Put the outputs into the corresponding streams. for request_output in request_outputs: request_id = request_output.request_id - self.request_outputs[request_id] = request_output - self.request_events[request_id].set() + self.request_streams[request_id].put(request_output) + if request_output.finished: + if self.log_requests: + logger.info(f"Finished request {request_id}.") + self.request_streams[request_id].finish() + + # Clean up aborted and finished requests. + finished_requests = set() + for stream in self.request_streams.values(): + if stream.finished: + finished_requests.add(stream.request_id) + + await self.engine_abort(finished_requests) + for request_id in finished_requests: + del self.request_streams[request_id] + + async def engine_abort(self, request_ids: Iterable[str]): + if self.engine_use_ray: + await self.engine.abort_request.remote(request_ids) + else: + self.engine.abort_request(request_ids) + + async def run_engine_loop(self): + while True: + await self.engine_step() + await asyncio.sleep(0) + + async def add_request( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + ) -> Stream: + if self.log_requests: + logger.info(f"Received request {request_id}: " + f"prompt: {prompt!r}, " + f"sampling params: {sampling_params}, " + f"prompt token ids: {prompt_token_ids}.") + + stream = Stream(request_id) + self.request_streams[request_id] = stream + + # Add the request into the vLLM engine's waiting queue. 
+ if self.engine_use_ray: + await self.engine.add_request.remote( + request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time) + else: + self.engine.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time) + + return stream async def generate( self, @@ -108,78 +198,32 @@ async def generate( # Preprocess the request. arrival_time = time.time() - # Create an event to notify us that there is new output from the - # vLLM engine. - request_event = asyncio.Event() - self.request_events[request_id] = request_event + stream = await self.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time) - if self.log_requests: - logger.info(f"Received request {request_id}: " - f"prompt: {prompt!r}, " - f"sampling params: {sampling_params}, " - f"prompt token ids: {prompt_token_ids}.") + try: + async for request_output in stream: + yield request_output + except Exception as e: + # If there is an exception, abort the request. + self._abort(request_id) + raise e - # Add the request into the vLLM engine's waiting queue. - if self.engine_use_ray: - await self.engine.add_request.remote( - request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) - else: - self.engine.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + async def abort(self, request_id: str) -> None: + """Abort a request. - # The vLLM engine does not have a background loop that keeps - # processing incoming requests. Therefore, we need to keep kicking - # the engine to process the requests. - while True: - if request_id not in self.request_events: - # The request has been aborted. - return - - # Kick the engine if the engine is not running. - if not self.is_engine_running: - try: - await self.engine_step(request_id) - except RuntimeError as e: - await self.abort(request_id) - raise e - - # Wait for new output. The group_event will be set in engine_step - # when there is new output available for the sequence group. - # Added a timeout to prevent deadlock. - try: - await asyncio.wait_for(request_event.wait(), - timeout=TIMEOUT_TO_PREVENT_DEADLOCK) - except asyncio.TimeoutError: - continue - # Reset the event to wait for the next output. - request_event.clear() - - # Decode and return new outputs. - request_output = self.request_outputs[request_id] - yield request_output - - # Once finished, release the resources of the sequence group. - if request_output.finished: - if self.log_requests: - logger.info(f"Finished request {request_id}.") + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. - del self.request_outputs[request_id] - del self.request_events[request_id] - # Kick the engine if the engine is not running. This is to - # prevent that there are still requests in engine's waiting - # queue to be executed. - if not self.is_engine_running: - await self.engine_step() - break + Args: + request_id: The unique id of the request. + """ + return self._abort(request_id) - async def abort(self, request_id: str) -> None: + def _abort(self, request_id: str) -> None: """Abort a request. Abort a submitted request. If the request is finished or not found, @@ -188,28 +232,14 @@ async def abort(self, request_id: str) -> None: Args: request_id: The unique id of the request. 
""" - if request_id not in self.request_events: + if request_id not in self.request_streams or self.request_streams[request_id].finished: # The request has already finished or been aborted. return if self.log_requests: logger.info(f"Aborted request {request_id}.") - if self.engine_use_ray: - await self.engine.abort_request.remote(request_id) - else: - self.engine.abort_request(request_id) - - if request_id in self.request_events: - del self.request_events[request_id] - if request_id in self.request_outputs: - del self.request_outputs[request_id] - - # To prevent deadlock when a request is aborted while the engine is - # running. - if self.kicking_request_id == request_id: - self.is_engine_running = False - self.kicking_request_id = None + self.request_streams[request_id].finish() async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 908d01d959fd8..5401f507e6fe2 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,17 +1,17 @@ import time import copy from functools import partial -from typing import Any, List, Optional, Tuple, TYPE_CHECKING - +from typing import Any, List, Optional, Tuple, Union, Iterable, TYPE_CHECKING +import asyncio from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.core.scheduler import Scheduler +from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, SequenceGroupMetadata from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) from vllm.utils import Counter @@ -268,11 +268,11 @@ def add_request( # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) - def abort_request(self, request_id: str) -> None: - """Aborts a request with the given ID. + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. Args: - request_id: The ID of the request to abort. + request_id: The ID(s) of the request to abort. """ self.scheduler.abort_seq_group(request_id) @@ -288,35 +288,21 @@ def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() - def step(self) -> List[RequestOutput]: - """Performs one decoding iteration and returns newly generated results. - - This function performs one decoding iteration of the engine. It first - schedules the sequences to be executed in the next iteration and the - token blocks to be swapped in/out/copy. Then, it executes the model - and updates the scheduler with the model outputs. Finally, it decodes - the sequences and returns the newly generated results. - """ + def _schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Optional[List[RequestOutput]]]: seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if scheduler_outputs.is_empty(): if not scheduler_outputs.ignored_seq_groups: # Nothing to do. - return [] + return seq_group_metadata_list, scheduler_outputs, [] # If there are ignored seq groups, we need to return them as the # request outputs. 
- return [ + return seq_group_metadata_list, scheduler_outputs, [ RequestOutput.from_seq_group(seq_group) for seq_group in scheduler_outputs.ignored_seq_groups ] + return seq_group_metadata_list, scheduler_outputs, None - # Execute the model. - output = self._run_workers( - "execute_model", - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - ) + def _process_worker_outputs(self, output, scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: # Update the scheduler with the model outputs. seq_groups = self.scheduler.update(output) @@ -339,6 +325,55 @@ def step(self) -> List[RequestOutput]: scheduler_outputs.num_batched_tokens) return request_outputs + async def step_async(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + The workers are ran asynchronously if possible. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + seq_group_metadata_list, scheduler_outputs, early_return = self._schedule() + if early_return is not None: + return early_return + + # Execute the model. + output = await self._run_workers_async( + "execute_model", + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + ) + + return self._process_worker_outputs(output, scheduler_outputs) + + def step(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + seq_group_metadata_list, scheduler_outputs, early_return = self._schedule() + if early_return is not None: + return early_return + + # Execute the model. 
+ output = self._run_workers( + "execute_model", + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + ) + + return self._process_worker_outputs(output, scheduler_outputs) + def _log_system_stats( self, prompt_run: bool, @@ -481,3 +516,33 @@ def _run_workers( for other_output in all_outputs[1:]: assert output == other_output return output + + async def _run_workers_async( + self, + method: str, + *args, + get_all_outputs: bool = False, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + all_outputs = [] + for worker in self.workers: + if self.parallel_config.worker_use_ray: + executor = partial(worker.execute_method.remote, method) + else: + executor = getattr(worker, method) + + output = executor(*args, **kwargs) + all_outputs.append(output) + + if self.parallel_config.worker_use_ray: + all_outputs = await asyncio.gather(*all_outputs) + + if get_all_outputs: + return all_outputs + + # Make sure all workers have the same results. + output = all_outputs[0] + for other_output in all_outputs[1:]: + assert output == other_output + return output From 9107bc93d086b5d94834abbeb664ebf6fd2634b9 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 25 Aug 2023 15:17:22 -0700 Subject: [PATCH 02/10] Lint Signed-off-by: Antoni Baum --- vllm/engine/async_llm_engine.py | 28 +++++++++++++++++----------- vllm/engine/llm_engine.py | 15 +++++++++++---- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index fa42cf47ad337..f8ff1ebea9b31 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -14,9 +14,11 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds + class Stream: """A stream of outputs for a request that can be iterated over asynchronously.""" + def __init__(self, request_id: str) -> None: self.request_id = request_id self._queue = asyncio.Queue() @@ -44,15 +46,15 @@ async def __anext__(self) -> RequestOutput: raise StopAsyncIteration return result -def _raise_exception_on_finish( - task: asyncio.Task, -) -> None: + +def _raise_exception_on_finish(task: asyncio.Task, ) -> None: try: task.result() except Exception as e: raise RuntimeError("Task finished unexpectedly.") from e raise RuntimeError("Task finished unexpectedly.") + class AsyncLLMEngine: """An asynchronous wrapper for LLMEngine. @@ -97,7 +99,8 @@ def __init__(self, self.background_loop = None if not inline: # Start the background loop. 
- self.background_loop = asyncio.get_event_loop().create_task(self.run_engine_loop()) + self.background_loop = asyncio.get_event_loop().create_task( + self.run_engine_loop()) self.background_loop.add_done_callback(_raise_exception_on_finish) async def engine_step(self): @@ -121,12 +124,14 @@ async def engine_step(self): for stream in self.request_streams.values(): if stream.finished: finished_requests.add(stream.request_id) + if finished_requests: + print(finished_requests) - await self.engine_abort(finished_requests) + await self._engine_abort(finished_requests) for request_id in finished_requests: del self.request_streams[request_id] - async def engine_abort(self, request_ids: Iterable[str]): + async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: await self.engine.abort_request.remote(request_ids) else: @@ -199,10 +204,10 @@ async def generate( arrival_time = time.time() stream = await self.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time) try: async for request_output in stream: @@ -232,7 +237,8 @@ def _abort(self, request_id: str) -> None: Args: request_id: The unique id of the request. """ - if request_id not in self.request_streams or self.request_streams[request_id].finished: + if request_id not in self.request_streams or self.request_streams[ + request_id].finished: # The request has already finished or been aborted. return diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 5401f507e6fe2..8cc43524cb8d1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -288,7 +288,10 @@ def has_unfinished_requests(self) -> bool: """Returns True if there are unfinished requests.""" return self.scheduler.has_unfinished_seqs() - def _schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Optional[List[RequestOutput]]]: + def _schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, + Optional[List[RequestOutput]]]: seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if scheduler_outputs.is_empty(): if not scheduler_outputs.ignored_seq_groups: @@ -302,7 +305,9 @@ def _schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Opti ] return seq_group_metadata_list, scheduler_outputs, None - def _process_worker_outputs(self, output, scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + def _process_worker_outputs( + self, output, + scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: # Update the scheduler with the model outputs. seq_groups = self.scheduler.update(output) @@ -335,7 +340,8 @@ async def step_async(self) -> List[RequestOutput]: and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs, early_return = self._schedule() + (seq_group_metadata_list, scheduler_outputs, + early_return) = self._schedule() if early_return is not None: return early_return @@ -359,7 +365,8 @@ def step(self) -> List[RequestOutput]: and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. 
""" - seq_group_metadata_list, scheduler_outputs, early_return = self._schedule() + (seq_group_metadata_list, scheduler_outputs, + early_return) = self._schedule() if early_return is not None: return early_return From 90ffbccb0c3403ce972b7977457247abec403720 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 25 Aug 2023 15:20:53 -0700 Subject: [PATCH 03/10] Rename Signed-off-by: Antoni Baum --- vllm/engine/async_llm_engine.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index f8ff1ebea9b31..5a832a502f8b4 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -15,8 +15,8 @@ TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds -class Stream: - """A stream of outputs for a request that can be +class AsyncStream: + """A stream of RequestOutputs for a request that can be iterated over asynchronously.""" def __init__(self, request_id: str) -> None: @@ -95,7 +95,7 @@ def __init__(self, engine_class = ray.remote(num_gpus=1)(LLMEngine).remote self.engine = engine_class(*args, **kwargs) # Request id -> stream. - self.request_streams: Dict[str, Stream] = {} + self.request_streams: Dict[str, AsyncStream] = {} self.background_loop = None if not inline: # Start the background loop. @@ -149,14 +149,14 @@ async def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, - ) -> Stream: + ) -> AsyncStream: if self.log_requests: logger.info(f"Received request {request_id}: " f"prompt: {prompt!r}, " f"sampling params: {sampling_params}, " f"prompt token ids: {prompt_token_ids}.") - stream = Stream(request_id) + stream = AsyncStream(request_id) self.request_streams[request_id] = stream # Add the request into the vLLM engine's waiting queue. 
From 302d6fa28b2a43939e1efc0d9804355ebe9f1518 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 25 Aug 2023 15:39:21 -0700 Subject: [PATCH 04/10] Add ray_remote_kwargs Signed-off-by: Antoni Baum --- vllm/engine/llm_engine.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8cc43524cb8d1..a5a38d2f1e3f3 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -135,7 +135,8 @@ def _init_workers(self, distributed_init_method: str): get_all_outputs=True, ) - def _init_workers_ray(self, placement_group: "PlacementGroup"): + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel @@ -150,6 +151,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup"): scheduling_strategy=PlacementGroupSchedulingStrategy( placement_group=placement_group, placement_group_capture_child_tasks=True), + **ray_remote_kwargs, )(RayWorker).remote() self.workers.append(worker) From de307249eba8aa75b65d2260a2a2d409811d7bd0 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 25 Aug 2023 15:41:02 -0700 Subject: [PATCH 05/10] Nit Signed-off-by: Antoni Baum --- vllm/engine/async_llm_engine.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 5a832a502f8b4..fea4840140ff6 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import Dict, List, Optional, Iterable +from typing import Dict, List, Optional, Iterable, Type from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -77,6 +77,8 @@ class AsyncLLMEngine: *args, *kwargs: Arguments for LLMEngine. """ + _engine_class: Type[LLMEngine] = LLMEngine + def __init__(self, worker_use_ray: bool, engine_use_ray: bool, @@ -88,11 +90,11 @@ def __init__(self, self.engine_use_ray = engine_use_ray self.log_requests = log_requests if not self.engine_use_ray: - engine_class = LLMEngine + engine_class = self._engine_class elif self.worker_use_ray: - engine_class = ray.remote(num_cpus=0)(LLMEngine).remote + engine_class = ray.remote(num_cpus=0)(self._engine_class).remote else: - engine_class = ray.remote(num_gpus=1)(LLMEngine).remote + engine_class = ray.remote(num_gpus=1)(self._engine_class).remote self.engine = engine_class(*args, **kwargs) # Request id -> stream. 
self.request_streams: Dict[str, AsyncStream] = {} From e44fca4ae8d9afa50b90a9568b6249bef6abb76b Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 25 Aug 2023 16:25:49 -0700 Subject: [PATCH 06/10] Remove debug Signed-off-by: Antoni Baum --- vllm/engine/async_llm_engine.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index fea4840140ff6..9c9f22c411a74 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -126,8 +126,6 @@ async def engine_step(self): for stream in self.request_streams.values(): if stream.finished: finished_requests.add(stream.request_id) - if finished_requests: - print(finished_requests) await self._engine_abort(finished_requests) for request_id in finished_requests: From 9b1b34a590c35b8a5ee9354e5194c82c5aecf7d7 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 25 Aug 2023 16:57:18 -0700 Subject: [PATCH 07/10] Nit Signed-off-by: Antoni Baum --- vllm/engine/async_llm_engine.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 9c9f22c411a74..4e33d7251b588 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -89,13 +89,8 @@ def __init__(self, self.worker_use_ray = worker_use_ray self.engine_use_ray = engine_use_ray self.log_requests = log_requests - if not self.engine_use_ray: - engine_class = self._engine_class - elif self.worker_use_ray: - engine_class = ray.remote(num_cpus=0)(self._engine_class).remote - else: - engine_class = ray.remote(num_gpus=1)(self._engine_class).remote - self.engine = engine_class(*args, **kwargs) + self.engine = self._init_engine(*args, **kwargs) + # Request id -> stream. 
self.request_streams: Dict[str, AsyncStream] = {} self.background_loop = None @@ -105,6 +100,15 @@ def __init__(self, self.run_engine_loop()) self.background_loop.add_done_callback(_raise_exception_on_finish) + def _init_engine(self, *args, **kwargs) -> LLMEngine: + if not self.engine_use_ray: + engine_class = self._engine_class + elif self.worker_use_ray: + engine_class = ray.remote(num_cpus=0)(self._engine_class).remote + else: + engine_class = ray.remote(num_gpus=1)(self._engine_class).remote + return engine_class(*args, **kwargs) + async def engine_step(self): """Kick the engine to process the waiting requests.""" if self.engine_use_ray: From 29a0ad4ccbcf6c780bb911d2e4d3ba25d19ac797 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 31 Aug 2023 15:32:17 -0700 Subject: [PATCH 08/10] Apply feedback from code review Signed-off-by: Antoni Baum --- vllm/core/scheduler.py | 2 +- vllm/engine/async_llm_engine.py | 83 +++++++++++++++++++++++++++------ vllm/engine/llm_engine.py | 78 ++++--------------------------- 3 files changed, 81 insertions(+), 82 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 06c0f83a6721a..cb4b7988bb89a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,6 +1,6 @@ import enum import time -from typing import Dict, List, Optional, Tuple, Union, Iterable +from typing import Dict, Iterable, List, Optional, Tuple, Union from vllm.config import CacheConfig, SchedulerConfig from vllm.core.block_manager import BlockSpaceManager diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 4e33d7251b588..82c9e10ffdc0f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,6 +1,7 @@ import asyncio import time -from typing import Dict, List, Optional, Iterable, Type +from functools import partial +from typing import Any, Dict, Iterable, List, Optional, Set, Type from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -12,8 +13,6 @@ logger = init_logger(__name__) -TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds - class AsyncStream: """A stream of RequestOutputs for a request that can be @@ -47,7 +46,7 @@ async def __anext__(self) -> RequestOutput: return result -def _raise_exception_on_finish(task: asyncio.Task, ) -> None: +def _raise_exception_on_finish(task: asyncio.Task) -> None: try: task.result() except Exception as e: @@ -55,6 +54,66 @@ def _raise_exception_on_finish(task: asyncio.Task, ) -> None: raise RuntimeError("Task finished unexpectedly.") +class _AsyncLLMEngine(LLMEngine): + """Extension of LLMEngine to add async methods.""" + + async def step_async(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + The workers are ran asynchronously if possible. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + (seq_group_metadata_list, scheduler_outputs, + early_return) = self._schedule() + if early_return is not None: + return early_return + + # Execute the model. 
+ output = await self._run_workers_async( + "execute_model", + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + ) + + return self._process_worker_outputs(output, scheduler_outputs) + + async def _run_workers_async( + self, + method: str, + *args, + get_all_outputs: bool = False, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + all_outputs = [] + for worker in self.workers: + if self.parallel_config.worker_use_ray: + executor = partial(worker.execute_method.remote, method) + else: + executor = getattr(worker, method) + + output = executor(*args, **kwargs) + all_outputs.append(output) + + if self.parallel_config.worker_use_ray: + all_outputs = await asyncio.gather(*all_outputs) + + if get_all_outputs: + return all_outputs + + # Make sure all workers have the same results. + output = all_outputs[0] + for other_output in all_outputs[1:]: + assert output == other_output + return output + + class AsyncLLMEngine: """An asynchronous wrapper for LLMEngine. @@ -77,7 +136,7 @@ class AsyncLLMEngine: *args, *kwargs: Arguments for LLMEngine. """ - _engine_class: Type[LLMEngine] = LLMEngine + _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine def __init__(self, worker_use_ray: bool, @@ -93,6 +152,7 @@ def __init__(self, # Request id -> stream. self.request_streams: Dict[str, AsyncStream] = {} + self.finished_requests: Set[str] = set() self.background_loop = None if not inline: # Start the background loop. @@ -124,16 +184,12 @@ async def engine_step(self): if self.log_requests: logger.info(f"Finished request {request_id}.") self.request_streams[request_id].finish() + self.finished_requests.add(request_id) - # Clean up aborted and finished requests. 
- finished_requests = set() - for stream in self.request_streams.values(): - if stream.finished: - finished_requests.add(stream.request_id) - - await self._engine_abort(finished_requests) - for request_id in finished_requests: + await self._engine_abort(self.finished_requests) + for request_id in self.finished_requests: del self.request_streams[request_id] + self.finished_requests.clear() async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: @@ -250,6 +306,7 @@ def _abort(self, request_id: str) -> None: logger.info(f"Aborted request {request_id}.") self.request_streams[request_id].finish() + self.finished_requests.add(request_id) async def get_model_config(self) -> ModelConfig: """Get the model configuration of the vLLM engine.""" diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a5a38d2f1e3f3..e098c372dd31a 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,17 +1,18 @@ -import time import copy +import time from functools import partial -from typing import Any, List, Optional, Tuple, Union, Iterable, TYPE_CHECKING -import asyncio +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union + from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs -from vllm.engine.ray_utils import initialize_cluster, ray, RayWorker +from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import Sequence, SequenceGroup, SequenceStatus, SequenceGroupMetadata +from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupMetadata, + SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, get_tokenizer) from vllm.utils import Counter @@ -116,7 +117,8 @@ def __init__( def _init_workers(self, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel + from vllm.worker.worker import \ + Worker # pylint: disable=import-outside-toplevel assert self.parallel_config.world_size == 1, ( "Ray is required if parallel_config.world_size > 1.") @@ -139,7 +141,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel + from vllm.worker.worker import \ + Worker # pylint: disable=import-outside-toplevel self.workers: List[Worker] = [] for bundle in placement_group.bundle_specs: @@ -296,11 +299,6 @@ def _schedule( Optional[List[RequestOutput]]]: seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() if scheduler_outputs.is_empty(): - if not scheduler_outputs.ignored_seq_groups: - # Nothing to do. - return seq_group_metadata_list, scheduler_outputs, [] - # If there are ignored seq groups, we need to return them as the - # request outputs. 
return seq_group_metadata_list, scheduler_outputs, [ RequestOutput.from_seq_group(seq_group) for seq_group in scheduler_outputs.ignored_seq_groups @@ -332,32 +330,6 @@ def _process_worker_outputs( scheduler_outputs.num_batched_tokens) return request_outputs - async def step_async(self) -> List[RequestOutput]: - """Performs one decoding iteration and returns newly generated results. - The workers are ran asynchronously if possible. - - This function performs one decoding iteration of the engine. It first - schedules the sequences to be executed in the next iteration and the - token blocks to be swapped in/out/copy. Then, it executes the model - and updates the scheduler with the model outputs. Finally, it decodes - the sequences and returns the newly generated results. - """ - (seq_group_metadata_list, scheduler_outputs, - early_return) = self._schedule() - if early_return is not None: - return early_return - - # Execute the model. - output = await self._run_workers_async( - "execute_model", - seq_group_metadata_list=seq_group_metadata_list, - blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, - blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, - blocks_to_copy=scheduler_outputs.blocks_to_copy, - ) - - return self._process_worker_outputs(output, scheduler_outputs) - def step(self) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. @@ -525,33 +497,3 @@ def _run_workers( for other_output in all_outputs[1:]: assert output == other_output return output - - async def _run_workers_async( - self, - method: str, - *args, - get_all_outputs: bool = False, - **kwargs, - ) -> Any: - """Runs the given method on all workers.""" - all_outputs = [] - for worker in self.workers: - if self.parallel_config.worker_use_ray: - executor = partial(worker.execute_method.remote, method) - else: - executor = getattr(worker, method) - - output = executor(*args, **kwargs) - all_outputs.append(output) - - if self.parallel_config.worker_use_ray: - all_outputs = await asyncio.gather(*all_outputs) - - if get_all_outputs: - return all_outputs - - # Make sure all workers have the same results. - output = all_outputs[0] - for other_output in all_outputs[1:]: - assert output == other_output - return output From 217d2b48290d7921d5c58900d8476a64a25c5faa Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Sat, 2 Sep 2023 23:02:08 -0700 Subject: [PATCH 09/10] Handle cancellations better Signed-off-by: Antoni Baum --- vllm/engine/async_llm_engine.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 82c9e10ffdc0f..2002d16494b1f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -263,13 +263,13 @@ async def generate( # Preprocess the request. 
arrival_time = time.time() - stream = await self.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) - try: + stream = await self.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time) + async for request_output in stream: yield request_output except Exception as e: From d8fb81129dcf290537511dfa1acc2341d5c30cf5 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Sun, 3 Sep 2023 18:51:00 -0700 Subject: [PATCH 10/10] Apply feedback from code review Signed-off-by: Antoni Baum --- vllm/engine/async_llm_engine.py | 24 +++++++++++++++--------- vllm/engine/llm_engine.py | 6 ++---- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 2002d16494b1f..54f3867694e5e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,7 +1,7 @@ import asyncio import time from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Set, Type +from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs @@ -143,7 +143,7 @@ def __init__(self, engine_use_ray: bool, *args, log_requests: bool = True, - inline: bool = False, + start_engine_loop: bool = False, **kwargs) -> None: self.worker_use_ray = worker_use_ray self.engine_use_ray = engine_use_ray @@ -154,13 +154,19 @@ def __init__(self, self.request_streams: Dict[str, AsyncStream] = {} self.finished_requests: Set[str] = set() self.background_loop = None - if not inline: - # Start the background loop. - self.background_loop = asyncio.get_event_loop().create_task( - self.run_engine_loop()) - self.background_loop.add_done_callback(_raise_exception_on_finish) - - def _init_engine(self, *args, **kwargs) -> LLMEngine: + if start_engine_loop: + self._start_background_loop() + + def _start_background_loop(self) -> None: + """Start the background loop.""" + if self.background_loop is not None: + raise RuntimeError("Background loop is already running.") + self.background_loop = asyncio.get_event_loop().create_task( + self.run_engine_loop()) + self.background_loop.add_done_callback(_raise_exception_on_finish) + + def _init_engine(self, *args, + **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: if not self.engine_use_ray: engine_class = self._engine_class elif self.worker_use_ray: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e098c372dd31a..54141bbe551e7 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -117,8 +117,7 @@ def __init__( def _init_workers(self, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import \ - Worker # pylint: disable=import-outside-toplevel + from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel assert self.parallel_config.world_size == 1, ( "Ray is required if parallel_config.world_size > 1.") @@ -141,8 +140,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker - from vllm.worker.worker import \ - Worker # pylint: disable=import-outside-toplevel + from vllm.worker.worker import Worker # pylint: disable=import-outside-toplevel self.workers: 
List[Worker] = []
         for bundle in placement_group.bundle_specs: