From 6b30dd59f9b48b769a9197f280ab7b8c82076aec Mon Sep 17 00:00:00 2001
From: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Date: Thu, 12 Sep 2024 23:48:59 -0400
Subject: [PATCH] [Bugfix] Fix async log stats (#8417)

Signed-off-by: Alvant
---
 tests/basic_correctness/test_preemption.py |  1 +
 vllm/engine/llm_engine.py                  | 20 ++++++++++++++++----
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py
index 7e77037da07d3..50d399bef1878 100644
--- a/tests/basic_correctness/test_preemption.py
+++ b/tests/basic_correctness/test_preemption.py
@@ -64,6 +64,7 @@ def test_chunked_prefill_recompute(
         enable_chunked_prefill=enable_chunked_prefill,
         max_num_seqs=max_num_seqs,
         worker_use_ray=worker_use_ray,
+        disable_log_stats=False,
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 1745dc5c09803..cbff8d867ef83 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1103,7 +1103,8 @@ def _process_model_outputs(self,
         # LLMEngine/AsyncLLMEngine directly
         if is_async:
             # Log stats.
-            self.do_log_stats(scheduler_outputs, outputs, finished_before)
+            self.do_log_stats(scheduler_outputs, outputs, finished_before,
+                              skip)
 
             # Tracing
             self.do_tracing(scheduler_outputs)
@@ -1410,18 +1411,20 @@ def remove_logger(self, logger_name: str) -> None:
     def do_log_stats(self,
                      scheduler_outputs: Optional[SchedulerOutputs] = None,
                      model_output: Optional[List[SamplerOutput]] = None,
-                     finished_before: Optional[List[int]] = None) -> None:
+                     finished_before: Optional[List[int]] = None,
+                     skip: Optional[List[int]] = None) -> None:
         """Forced log when no requests active."""
         if self.log_stats:
             stats = self._get_stats(scheduler_outputs, model_output,
-                                    finished_before)
+                                    finished_before, skip)
             for logger in self.stat_loggers.values():
                 logger.log(stats)
 
     def _get_stats(self,
                    scheduler_outputs: Optional[SchedulerOutputs],
                    model_output: Optional[List[SamplerOutput]] = None,
-                   finished_before: Optional[List[int]] = None) -> Stats:
+                   finished_before: Optional[List[int]] = None,
+                   skip: Optional[List[int]] = None) -> Stats:
         """Get Stats to be Logged to Prometheus.
 
         Args:
@@ -1429,6 +1432,10 @@ def _get_stats(self,
                 the scheduled batch,
             model_output: Optional, used to emit speculative decoding
                 metrics which are created by the workers.
+            finished_before: Optional, indices of sequences that finished
+                in a previous step; these sequences are ignored.
+            skip: Optional, indices of sequences that were preempted;
+                these sequences are ignored.
         """
         now = time.time()
 
@@ -1503,6 +1510,11 @@ def _get_stats(self,
                 actual_num_batched_tokens -= 1
                 continue
 
+            # skip currently holds the indices of preempted sequence
+            # groups, whose stats must not be logged for this step.
+            if skip and idx in skip:
+                continue
+
             group_was_prefill = idx < scheduler_outputs.num_prefill_groups
             seq_group = scheduled_seq_group.seq_group
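
For illustration, a minimal self-contained sketch of the filtering pattern
the patch introduces (the helper filter_logged_indices and its inputs are
hypothetical, not vLLM's actual API): when per-step stats are assembled,
indices listed in finished_before (finished in a previous step) or in skip
(preempted in this step) are excluded so they do not distort the metrics.

    from typing import List, Optional

    def filter_logged_indices(
        num_groups: int,
        finished_before: Optional[List[int]] = None,
        skip: Optional[List[int]] = None,
    ) -> List[int]:
        """Return the indices whose stats should be logged this step."""
        kept = []
        for idx in range(num_groups):
            # Sequence groups that finished in an earlier step are ignored.
            if finished_before and idx in finished_before:
                continue
            # Preempted sequence groups are ignored as well, which is the
            # guard this patch adds to _get_stats.
            if skip and idx in skip:
                continue
            kept.append(idx)
        return kept

    # Of five scheduled groups, group 1 finished earlier and group 3 was
    # preempted, so only groups 0, 2, and 4 contribute to the stats.
    assert filter_logged_indices(5, finished_before=[1], skip=[3]) == [0, 2, 4]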