From 4ef41b84766670c1bd8079f58d35bf32b5bcb3ab Mon Sep 17 00:00:00 2001 From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com> Date: Sun, 8 Sep 2024 00:01:51 -0400 Subject: [PATCH] [Bugfix] Fix async postprocessor in case of preemption (#8267) --- vllm/core/scheduler.py | 87 ++++++++------- vllm/engine/async_llm_engine.py | 24 ++-- vllm/engine/llm_engine.py | 149 ++++++++++++++++--------- vllm/worker/multi_step_model_runner.py | 26 +++-- 4 files changed, 172 insertions(+), 114 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 81c78bda3b505..c3fa95f57b737 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -537,13 +537,6 @@ def _schedule_running( preempted: List[SequenceGroup] = ret.preempted swapped_out: List[SequenceGroup] = ret.swapped_out - # NOTE(woosuk): Preemption happens only when there is no available slot - # to keep all the sequence groups in the RUNNING state. - - # Store original running requests for the case of async + preemption - if self.use_async_output_proc: - orig_running = self.running.copy() - running_queue = self.running assert len(self._async_stopped) == 0 while running_queue: @@ -552,6 +545,7 @@ def _schedule_running( seq_group, SequenceStatus.RUNNING, enable_chunking, budget) if num_running_tokens == 0: + # No budget => Stop break running_queue.popleft() @@ -565,18 +559,8 @@ def _schedule_running( self._async_stopped.append(seq_group) continue - # With async postprocessor, when preemption kicks in, we need - # first to drain the async postprocessor, so that all async - # block_table freeing is applied before the preemption freeing - # is applied. - if self.use_async_output_proc and not self._can_append_slots( - seq_group): - tmp = self.running - self.running = orig_running - assert self.output_proc_callback is not None - self.output_proc_callback() - self.running = tmp - + # NOTE(woosuk): Preemption happens only when there is no available + # slot to keep all the sequence groups in the RUNNING state. while not self._can_append_slots(seq_group): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) @@ -588,24 +572,43 @@ def _schedule_running( and seq_group.lora_int_id in curr_loras): curr_loras.remove(seq_group.lora_int_id) + # Determine victim sequence + cont_loop = True if running_queue: - # Preempt the lowest-priority sequence groups. + # Preempt the lowest-priority sequence group. victim_seq_group = running_queue.pop() + else: + # No other sequence group can be preempted. + # Preempt the current sequence group. + # Note: This is also where we stop this loop + # (since there is nothing else to preempt) + victim_seq_group = seq_group + cont_loop = False + + # With async postprocessor, before preempting a sequence + # we need to ensure it has no pending async postprocessor + do_preempt = True + if self.use_async_output_proc: + assert self.output_proc_callback is not None + self.output_proc_callback( + request_id=victim_seq_group.request_id) + + # It may be that the async pending "victim_seq_group" + # becomes finished, in which case we simply free it. + if victim_seq_group.is_finished(): + self._free_finished_seq_group(victim_seq_group) + do_preempt = False + + # Do preemption + if do_preempt: preempted_mode = self._preempt(victim_seq_group, blocks_to_swap_out) if preempted_mode == PreemptionMode.RECOMPUTE: preempted.append(victim_seq_group) else: swapped_out.append(victim_seq_group) - else: - # No other sequence groups can be preempted. - # Preempt the current sequence group. - preempted_mode = self._preempt(seq_group, - blocks_to_swap_out) - if preempted_mode == PreemptionMode.RECOMPUTE: - preempted.append(seq_group) - else: - swapped_out.append(seq_group) + + if not cont_loop: break else: self._append_slots(seq_group, blocks_to_copy) @@ -1264,22 +1267,26 @@ def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: if seq.is_finished(): self.free_seq(seq) + def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: + if seq_group.is_finished(): + # Free cross-attention block table, if it exists + self._free_seq_group_cross_attn_blocks(seq_group) + + # Add the finished requests to the finished requests list. + # This list will be used to update the Mamba cache in the + # next step. + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + def free_finished_seq_groups(self) -> None: remaining: Deque[SequenceGroup] = deque() for seq_group in self.running: - if seq_group.is_finished(): - # Free cross-attention block table, if it exists - self._free_seq_group_cross_attn_blocks(seq_group) - # Add the finished requests to the finished requests list. - # This list will be used to update the Mamba cache in the - # next step. - self._finished_requests_ids.append(seq_group.request_id) - else: + self._free_finished_seq_group(seq_group) + if not seq_group.is_finished(): remaining.append(seq_group) - # Free finished seqs - self._free_finished_seqs(seq_group) - self.running = remaining # Handle async stopped sequence groups diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7fe8053fffb7b..6ed1a6bba08ea 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -342,17 +342,17 @@ async def step_async( virtual_engine] # Execute the model. - output = await self.model_executor.execute_model_async( + outputs = await self.model_executor.execute_model_async( execute_model_req) # we need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, output) + self._update_cached_scheduler_output(virtual_engine, outputs) else: if len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) - output = [] + outputs = [] # Finish the current step for all the sequence groups. if self.scheduler_config.is_multi_step: @@ -365,25 +365,25 @@ async def step_async( self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - is_async = allow_async_output_proc - is_last_step = True - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step)) + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True) - if output and allow_async_output_proc: + if outputs and allow_async_output_proc: assert len( - output + outputs ) == 1, "Async postprocessor expects only a single output set" self._advance_to_next_step( - output[0], seq_group_metadata_list, + outputs[0], seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: self._process_model_outputs(ctx=ctx) # Log stats. - self.do_log_stats(scheduler_outputs, output) + self.do_log_stats(scheduler_outputs, outputs) # Tracing self.do_tracing(scheduler_outputs) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 78ddcd1daaf69..94271c4a93151 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2,9 +2,9 @@ import time from collections import deque from contextlib import contextmanager -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List, - Mapping, Optional) + Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, Union @@ -90,17 +90,36 @@ class SchedulerOutputState: last_output: Optional[SamplerOutput] = None -@dataclass +class OutputData(NamedTuple): + outputs: List[SamplerOutput] + seq_group_metadata_list: List[SequenceGroupMetadata] + scheduler_outputs: SchedulerOutputs + is_async: bool + is_last_step: bool + skip: List[int] + + class SchedulerContext: - output_queue: Deque[Tuple[Optional[List[SamplerOutput]], - List[SequenceGroupMetadata], SchedulerOutputs, - bool, - bool]] = field(default_factory=lambda: deque()) - request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = field( - default_factory=lambda: []) - seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None - scheduler_outputs: Optional[SchedulerOutputs] = None + + def __init__(self): + self.output_queue: Deque[OutputData] = deque() + self.request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] + self.seq_group_metadata_list: Optional[ + List[SequenceGroupMetadata]] = None + self.scheduler_outputs: Optional[SchedulerOutputs] = None + + def append_output(self, outputs: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, is_async: bool, + is_last_step: bool): + self.output_queue.append( + OutputData(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=is_async, + is_last_step=is_last_step, + skip=[])) class LLMEngine: @@ -1246,23 +1265,15 @@ def _process_sequence_group_outputs( return - def _process_model_outputs(self, ctx: SchedulerContext) -> None: - """Apply the model output to the sequences in the scheduled seq groups. + def _process_model_outputs(self, + ctx: SchedulerContext, + request_id: Optional[str] = None) -> None: + """Apply the model output to the sequences in the scheduled seq groups + and return responses. - virtual_engine: The engine id to operate on + ctx: The virtual engine context to work on + request_id: If provided, then only this request is going to be processed - is_async: Indicates whether this postprocessor runs in - parallel with the GPU forward pass and is processing - tokens from the previous step. If this is true, then - no tokens need to be appended since it is already done - externally (before the next schedule() call) - - sampler_output: Used with multi-step execution to provide - sampler_output of each step - is_last_output: Used with multi-step execution to indicate - the last step (of each multi-step group) - - Returns RequestOutputs that can be returned to the client. """ now = time.time() @@ -1270,9 +1281,14 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None: return None # Get pending async postprocessor - (outputs, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step) = ctx.output_queue.popleft() - assert outputs is not None + if request_id: + # When we process only one request, no pop is required + # (since later we will process all of the rest) + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, skip) = ctx.output_queue[0] + else: + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, skip) = ctx.output_queue.popleft() # Sanity check assert len(seq_group_metadata_list) == len( @@ -1286,9 +1302,30 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None: else: outputs_by_sequence_group = outputs + # Determine the requests we need to operate on + if request_id: + indices = [] + for i, seq_group_meta in enumerate(seq_group_metadata_list): + if seq_group_meta.request_id == request_id: + assert i not in skip # Cannot be called twice + indices.append(i) + break + + # If the request_id was not found, then it means that + # this is a new request that has no pending async + # postprocessor + if not indices: + return + else: + indices = range(len(seq_group_metadata_list)) # type: ignore + finished_before: List[int] = [] finished_now: List[int] = [] - for i, seq_group_meta in enumerate(seq_group_metadata_list): + for i in indices: + if i in skip: + continue + + seq_group_meta = seq_group_metadata_list[i] scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] seq_group = scheduled_seq_group.seq_group @@ -1343,6 +1380,18 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None: request_output = RequestOutputFactory.create(seq_group) ctx.request_outputs.append(request_output) + # When we process a single request, we skip it for the next time, + # and invoke the request output callback (if there was final output) + if request_id: + assert len(indices) == 1 + skip.append(indices[0]) + + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + # Free currently finished requests if finished_now: for scheduler in self.scheduler: @@ -1354,17 +1403,16 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None: if (finished_now and self.process_request_outputs_callback is not None): self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() return # Create the outputs - # Note: scheduled_seq_groups and seq_group_metadata_list - # must match with the indices - for i, scheduled_seq_group in enumerate( - scheduler_outputs.scheduled_seq_groups): - - if i in finished_before or i in finished_now: + for i in indices: + if i in skip or i in finished_before or i in finished_now: continue # Avoids double processing + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) if (seq_group.is_finished() @@ -1380,6 +1428,7 @@ def _process_model_outputs(self, ctx: SchedulerContext) -> None: if (ctx.request_outputs and self.process_request_outputs_callback is not None): self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() # For async case, we need to record the stats here. # For non-async case, the stats are done in the @@ -1548,20 +1597,20 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: execute_model_req.async_callback = self.async_callbacks[ virtual_engine] - output = self.model_executor.execute_model( + outputs = self.model_executor.execute_model( execute_model_req=execute_model_req) # We need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, output) + self._update_cached_scheduler_output(virtual_engine, outputs) else: # Nothing scheduled => If there is pending async postprocessor, # then finish it here. if len(ctx.output_queue) > 0: self._process_model_outputs(ctx=ctx) # No outputs in this case - output = [] + outputs = [] # Finish the current step for all the sequence groups. if self.scheduler_config.is_multi_step: @@ -1574,18 +1623,18 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self.cached_scheduler_outputs[0] = SchedulerOutputState() # Add results to the output_queue - is_async = allow_async_output_proc - is_last_step = True - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs, is_async, - is_last_step)) - - if output and allow_async_output_proc: - assert len(output) == 1, ( + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True) + + if outputs and allow_async_output_proc: + assert len(outputs) == 1, ( "Async postprocessor expects only a single output set") self._advance_to_next_step( - output[0], seq_group_metadata_list, + outputs[0], seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups) # Check if need to run the usual non-async path @@ -1593,7 +1642,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self._process_model_outputs(ctx=ctx) # Log stats. - self.do_log_stats(scheduler_outputs, output) + self.do_log_stats(scheduler_outputs, outputs) # Tracing self.do_tracing(scheduler_outputs) diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index b52f2a07e344e..b13cf39bd846e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -274,12 +274,13 @@ def _async_process_outputs(self, model_input: StatefulModelInput, self.pinned_sampled_token_ids) if model_output.pythonized: ctx = output_proc_callback.keywords["ctx"] - is_async = False - is_last_step = False - ctx.output_queue.append( - ([model_output.sampler_output - ], ctx.seq_group_metadata_list, - ctx.scheduler_outputs, is_async, is_last_step)) + ctx.append_output( + outputs=[model_output.sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False) + output_proc_callback() else: cont = False @@ -319,12 +320,13 @@ def _final_process_outputs(self, model_input: StatefulModelInput, if not is_last_step: ctx = output_proc_callback.keywords[ # type: ignore "ctx"] # type: ignore - is_async = False - is_last_step = False - ctx.output_queue.append( - ([output.sampler_output - ], ctx.seq_group_metadata_list, - ctx.scheduler_outputs, is_async, is_last_step)) + ctx.append_output( + outputs=[output.sampler_output], + seq_group_metadata_list=ctx. + seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False) else: outputs.append(output.sampler_output) else: