From cb3fe15072856b6deb994add09d309796cc2b4cb Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Sep 2024 23:58:17 +0000 Subject: [PATCH 1/6] fix num computed tokens --- tests/core/test_num_computed_tokens_update.py | 70 ++++++++ tests/core/utils.py | 6 +- vllm/engine/llm_engine.py | 153 ++++++++---------- vllm/engine/output_processor/interfaces.py | 8 +- vllm/engine/output_processor/multi_step.py | 15 +- 5 files changed, 150 insertions(+), 102 deletions(-) create mode 100644 tests/core/test_num_computed_tokens_update.py diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py new file mode 100644 index 0000000000000..6cae95f1c9629 --- /dev/null +++ b/tests/core/test_num_computed_tokens_update.py @@ -0,0 +1,70 @@ +import pytest +from dataclasses import dataclass + +from tests.conftest import VllmRunner +from tests.core.utils import create_dummy_prompt +from vllm.engine.llm_engine import LLMEngine +from vllm.sequence import SequenceGroup + +MODEL = "JackFram/llama-160m" + +def add_seq_group_to_engine(engine: LLMEngine, + seq_group: SequenceGroup): + scheduler = engine.scheduler[0] + scheduler.add_seq_group(seq_group) + +@pytest.mark.parametrize("num_scheduler_steps", [1, 8]) +@pytest.mark.parametrize("enable_chunked_prefill", [False, True]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +def test_num_computed_tokens_update(num_scheduler_steps: int, + enable_chunked_prefill: bool, + enforce_eager: bool): + + # Make a vllm engine + runner = VllmRunner(model_name = MODEL, + gpu_memory_utilization=0.7, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + enable_chunked_prefill=enable_chunked_prefill, + enforce_eager=enforce_eager) + engine : LLMEngine = runner.model.llm_engine + + + is_multi_step = num_scheduler_steps > 1 + is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill + # In multi-step + chunked-prefill there is no separate single prompt step. + # What is scheduled will run for num_scheduler_steps always. 
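    # The arithmetic behind the assertions below can be captured in a tiny
    # standalone model (illustrative only; the helper name is invented here).
    # A token counts as computed only once it has been run through the model,
    # so the most recently sampled token is always still pending and every
    # total trails the appended length by one:
    #
    #     def expected_num_computed_tokens(prompt_len: int,
    #                                      num_prompt_steps: int,
    #                                      num_decode_steps: int) -> int:
    #         # After the prompt phase the engine has processed the whole
    #         # prompt plus all but the last of the tokens sampled during
    #         # those steps.
    #         after_prompt = prompt_len + num_prompt_steps - 1
    #         # Every decode step then computes exactly one more token.
    #         return after_prompt + num_decode_steps
    #
    #     # e.g. a 10-token prompt, one prompt step, three decode steps:
    #     # 10 prompt tokens + 3 decoded tokens computed, 1 token pending.
    #     assert expected_num_computed_tokens(10, 1, 3) == 13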
+ num_prompt_steps = num_scheduler_steps if is_multi_step_chunked_prefill else 1 + + num_output_tokens_list = [4, 8, 12, 15, 16, 17] + + # Create sequence and add to engine + prompt_len = 10 + + for req_idx, num_output_tokens in enumerate(num_output_tokens_list): + seq, seq_group = create_dummy_prompt(request_id=str(req_idx), + prompt_length=prompt_len, + min_tokens = num_output_tokens, + max_tokens = num_output_tokens) + add_seq_group_to_engine(engine, seq_group) + + assert seq.data.get_num_computed_tokens() == 0 + + for _ in range(num_prompt_steps): + # prompt steps + engine.step() + + if not seq.is_finished(): + assert seq.data.get_num_computed_tokens() == prompt_len + num_prompt_steps - 1 + + prompt_num_computed_tokens = seq.data.get_num_computed_tokens() + + decode_step_counter = 0 + while not seq.is_finished(): + assert seq.data.get_num_computed_tokens() == prompt_num_computed_tokens + decode_step_counter + for _ in range(num_scheduler_steps): + # decode step + engine.step() + decode_step_counter += 1 + + assert seq.data.get_num_computed_tokens() == prompt_len + num_output_tokens - 1 diff --git a/tests/core/utils.py b/tests/core/utils.py index 40d8f51fc186e..1e4332268c2f3 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -16,6 +16,8 @@ def create_dummy_prompt( use_beam_search: bool = False, best_of: int = 1, prompt_tokens: Optional[List[int]] = None, + min_tokens: int = 0, + max_tokens: int = 16, ) -> Tuple[Sequence, SequenceGroup]: if not block_size: block_size = prompt_length @@ -36,7 +38,9 @@ def create_dummy_prompt( arrival_time=time.time(), sampling_params=SamplingParams( use_beam_search=use_beam_search, - best_of=best_of), + best_of=best_of, + max_tokens=max_tokens, + min_tokens=min_tokens), lora_request=lora_request) return prompt, seq_group diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index d6258c6413d87..ca5d05743b9e6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -965,6 +965,45 @@ def _process_sequence_group_outputs( return + def _update_num_computed_tokens_for_multi_step_prefill( + self, + seq_group: SequenceGroup, + seq_group_meta: SequenceGroupMetadata, + is_first_step_output: Optional[bool]): + """ + This function updates num_computed_tokens for prompt sequences + when Multi-Step is enabled. + + seq_group: SequenceGroup to update the num_computed_tokens for. + seq_group_meta: Metadata of the given SequenceGroup. + is_first_step_output: Optional[bool] - + When available, is_first_step_output indicates if the appended + output token is the output of the first-step in multi-step. + A value of None indicates that outputs from all steps in + in multi-step are submitted in a single burst. + """ + + assert self.scheduler_config.is_multi_step + + if not seq_group_meta.is_prompt: + # num_computed_token updates for multi-step decodes happen after + # the tokens are appended to the sequence. + return + + do_update: bool = False + if self.scheduler_config.chunked_prefill_enabled: + # In multi-step + chunked-prefill case, the prompt sequences + # that are scheduled are fully processed in the first step. + do_update = is_first_step_output is None or is_first_step_output == True + else: + # Normal multi-step decoding case. In this case prompt-sequences + # are actually single-stepped. Always update in this case. 
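            # Condensed into one place, the decision made by the two branches
            # of this helper could be sketched as follows (a restatement for
            # clarity only; should_update is an invented name, not engine
            # code):
            #
            #     from typing import Optional
            #
            #     def should_update(chunked_prefill_enabled: bool,
            #                       is_first_step_output: Optional[bool]
            #                       ) -> bool:
            #         if chunked_prefill_enabled:
            #             # The scheduled prompt chunk is fully processed in
            #             # the first step, so update on the first step's
            #             # output (True) or when all step outputs arrive in
            #             # a single burst (None); skip later outputs (False).
            #             return is_first_step_output in (None, True)
            #         # Plain multi-step: prompt sequences are single-stepped,
            #         # so always update.
            #         return True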
+ assert seq_group.state.num_steps == 1 + do_update = True + + if do_update: + seq_group.update_num_computed_tokens(seq_group_meta.token_chunk_size) + def _process_model_outputs(self, ctx: SchedulerContext, request_id: Optional[str] = None) -> None: @@ -975,64 +1014,6 @@ def _process_model_outputs(self, request_id: If provided, then only this request is going to be processed """ - def update_prefill_num_computed_tokens( - seq_group: SequenceGroup, - seq_group_meta: SequenceGroupMetadata, num_outputs: int, - is_first_step_output: Optional[bool]) -> None: - """ - When multi-step and chunked-prefill are enabled together, the - prefill sequence scheduled for multi-step execution turn into - decodes in the first step itself. This function accounts - for that conversion. - - seq_group: SequenceGroup - A prefill seq_group - seq_group_meta: SequenceGroupMetadata - Metadata of the given - prefill seq_group - num_outputs: int - number of output tokens being processed for the - given seq_group - is_first_step_output: Optional[bool] - - If multi-step is enabled and num_outputs is 1, this value - indicates if this outputs belongs to the first step in the - multi-step. - If multi-step is enabled and num_outputs > 1, this value - must be None, as num_outputs > 1 indicates that outputs from - all the steps in multi-step are submitted in a single burst. - When multi-step is disabled, this value is always True. - """ - - assert seq_group_meta.is_prompt - - token_chunk_size = seq_group_meta.token_chunk_size - - if num_outputs == 1: - assert is_first_step_output is not None - - if seq_group_meta.state.num_steps == 1: - assert is_first_step_output is True - seq_group.update_num_computed_tokens(token_chunk_size) - return - - # multi-step prefill is only supported when multi-step is - # enabled with chunked prefill - assert self.scheduler_config.is_multi_step and \ - self.scheduler_config.chunked_prefill_enabled - if is_first_step_output is True: - # This sequence is a prompt during the first step only. - seq_group.update_num_computed_tokens(token_chunk_size) - return - - assert is_first_step_output is None - - # multi-step prefill is only supported when multi-step is - # enabled with chunked prefill. Outputs from all the steps are - # submitted in a single burst. - assert self.scheduler_config.is_multi_step and \ - self.scheduler_config.chunked_prefill_enabled - assert num_outputs == seq_group_meta.state.num_steps, \ - f"#outputs {len(outputs)} - num steps {seq_group_meta.state.num_steps}" #noqa - # This sequence is a prompt during the first step only. - seq_group.update_num_computed_tokens(token_chunk_size) - now = time.time() if len(ctx.output_queue) == 0: @@ -1093,7 +1074,7 @@ def update_prefill_num_computed_tokens( seq_group_meta = seq_group_metadata_list[i] scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - seq_group = scheduled_seq_group.seq_group + seq_group: SequenceGroup = scheduled_seq_group.seq_group if seq_group.is_finished(): finished_before.append(i) @@ -1104,14 +1085,14 @@ def update_prefill_num_computed_tokens( else: output = [outputs_by_sequence_group[0][i]] - if not is_async and seq_group_meta.is_prompt: - # Updates for all decodes happen when we actually append the - # token ids to the seq in process_outputs. 
- update_prefill_num_computed_tokens(seq_group, seq_group_meta, - len(output), - is_first_step_output) - elif not is_async: - seq_group.update_num_computed_tokens(1) + if not is_async: + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_meta, is_first_step_output) + else: + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size) if outputs: for o in outputs: @@ -1137,14 +1118,6 @@ def update_prefill_num_computed_tokens( if seq_group_meta.do_sample: output_token_num = self.output_processor.process_outputs( seq_group, output, is_async) - if self.speculative_config: - # We -1 here because we always - # (w/o speculative decoding) add the number of - # computed tokens by one in the decoding phase. - # Therefore, we remove that one token that - # is already added. - seq_group.update_num_computed_tokens(output_token_num - - 1) if seq_group.is_finished(): finished_now.append(i) @@ -1253,20 +1226,15 @@ def _advance_to_next_step( if seq_group.is_finished(): continue - if seq_group_metadata.is_prompt: - if self.scheduler_config.is_multi_step and \ - self.scheduler_config.chunked_prefill_enabled: - # Prompts are scheduled in multi-step only when - # chunking is enabled. These prompts turn into - # decodes after the very first step. Therefore, - # we skip the update to the num_computed_tokens - # here. - seq_group.update_num_computed_tokens(1) - else: - seq_group.update_num_computed_tokens( - seq_group_metadata.token_chunk_size) + if self.scheduler_config.is_multi_step: + # Updates happen only if the sequence is prefill + self._update_num_computed_tokens_for_multi_step_prefill( + seq_group, seq_group_metadata, + seq_group.state.num_steps == 1) else: - seq_group.update_num_computed_tokens(1) + seq_group.update_num_computed_tokens( + seq_group_metadata.token_chunk_size) + if seq_group_metadata.do_sample: assert len(sequence_group_outputs.samples) == 1, ( "Async output processor expects a single sample" @@ -1276,7 +1244,14 @@ def _advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] - seq.append_token_id(sample.output_token, sample.logprobs) + + if self.scheduler_config.is_multi_step: + is_prefill_append = seq.data.get_num_uncomputed_tokens() == 0 + seq.append_token_id(sample.output_token, sample.logprobs) + if not is_prefill_append: + seq_group.update_num_computed_tokens(1) + else: + seq.append_token_id(sample.output_token, sample.logprobs) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 554880a3cc438..50adaf4e59188 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, List, Optional +from typing import Callable, List from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler @@ -58,14 +58,10 @@ def create_output_processor( @abstractmethod def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput], - is_async: bool) -> Optional[int]: + is_async: bool) -> None: """Process new token ids for the sequence group. Handles logic such as detokenization, stop checking, and freeing/forking sequences in the scheduler. 
- - Return the number of new tokens generated in the sequence group. - The returned value is optional because it is only used for - speculative decoding mqa scorer. """ pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index f35b1ba9c2bdd..f5f42d1a9b9fe 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -69,7 +69,7 @@ def _log_prompt_logprob_unsupported_warning_once(): def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput], - is_async: bool = False) -> Optional[int]: + is_async: bool = False) -> None: """Append new tokens in the outputs to sequences in the sequence group. This only supports sequence groups of size 1. It supports greater than @@ -84,10 +84,6 @@ def process_outputs(self, tokens from the previous step. If this is true, then no tokens need to be appended since it is already done externally (before the next schedule() call) - - Returns: - The number of tokens appended to the sequence. This is optional - because only speculative decode uses this return value. """ # Sequences can be in RUNNING or FINISHED_ABORTED state # once scheduled, as a sequence is moved to FINSIHED_ABORTED @@ -168,6 +164,7 @@ def _process_seq_outputs(self, seq: Sequence, output_token_ids = output_token_ids[:i + 1] break + is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0 # Incrementally append tokens to the sequence, as if we had only one new # token. for output_token_id, output_logprob in zip(output_token_ids, @@ -177,8 +174,14 @@ def _process_seq_outputs(self, seq: Sequence, logprobs=output_logprob, ) + if is_prefill_sampled_token: + is_prefill_sampled_token = False + else: + # Update num_computed_tokens iff the sampled token is not from + # a prefill step. 
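                # The token excluded here is the one sampled while the prompt
                # chunk itself was being computed; the prompt-chunk update in
                # the engine already counted that forward pass, so the
                # increment must be skipped exactly once per sequence. A
                # standalone toy version of the accounting (invented values,
                # not engine code):
                #
                #     # Prompt of 4 tokens, 3 sampled tokens appended by the
                #     # multi-step output processor; the prompt-chunk update
                #     # already set num_computed to 4.
                #     prompt_len, sampled_tokens = 4, ["t5", "t6", "t7"]
                #     num_computed = prompt_len
                #     num_appended = prompt_len
                #     for i, _token in enumerate(sampled_tokens):
                #         num_appended += 1
                #         if i > 0:  # skip the token sampled during prefill
                #             num_computed += 1
                #     # One token is always left pending: 7 appended,
                #     # 6 computed.
                #     assert (num_appended, num_computed) == (7, 6)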
+ seq.data.update_num_computed_tokens(1) + self._process_decode_and_stop(seq, sampling_params) if seq.is_finished(): break - return len(output_token_ids) From 736f2fb5834dec189dfc40872ac6471d1211ce47 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Thu, 3 Oct 2024 07:35:21 +0000 Subject: [PATCH 2/6] format --- tests/core/test_num_computed_tokens_update.py | 41 ++++++++++--------- vllm/engine/llm_engine.py | 17 ++++---- vllm/engine/output_processor/multi_step.py | 9 ++-- 3 files changed, 35 insertions(+), 32 deletions(-) diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py index 6cae95f1c9629..3281cd3fd45fe 100644 --- a/tests/core/test_num_computed_tokens_update.py +++ b/tests/core/test_num_computed_tokens_update.py @@ -1,18 +1,18 @@ import pytest -from dataclasses import dataclass from tests.conftest import VllmRunner -from tests.core.utils import create_dummy_prompt +from tests.core.utils import create_dummy_prompt from vllm.engine.llm_engine import LLMEngine from vllm.sequence import SequenceGroup MODEL = "JackFram/llama-160m" -def add_seq_group_to_engine(engine: LLMEngine, - seq_group: SequenceGroup): + +def add_seq_group_to_engine(engine: LLMEngine, seq_group: SequenceGroup): scheduler = engine.scheduler[0] scheduler.add_seq_group(seq_group) + @pytest.mark.parametrize("num_scheduler_steps", [1, 8]) @pytest.mark.parametrize("enable_chunked_prefill", [False, True]) @pytest.mark.parametrize("enforce_eager", [False, True]) @@ -21,20 +21,20 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, enforce_eager: bool): # Make a vllm engine - runner = VllmRunner(model_name = MODEL, - gpu_memory_utilization=0.7, - use_v2_block_manager=True, - num_scheduler_steps=num_scheduler_steps, - enable_chunked_prefill=enable_chunked_prefill, - enforce_eager=enforce_eager) - engine : LLMEngine = runner.model.llm_engine - + runner = VllmRunner(model_name=MODEL, + gpu_memory_utilization=0.7, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + enable_chunked_prefill=enable_chunked_prefill, + enforce_eager=enforce_eager) + engine: LLMEngine = runner.model.llm_engine is_multi_step = num_scheduler_steps > 1 is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill # In multi-step + chunked-prefill there is no separate single prompt step. # What is scheduled will run for num_scheduler_steps always. 
- num_prompt_steps = num_scheduler_steps if is_multi_step_chunked_prefill else 1 + num_prompt_steps = num_scheduler_steps \ + if is_multi_step_chunked_prefill else 1 num_output_tokens_list = [4, 8, 12, 15, 16, 17] @@ -43,9 +43,9 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, for req_idx, num_output_tokens in enumerate(num_output_tokens_list): seq, seq_group = create_dummy_prompt(request_id=str(req_idx), - prompt_length=prompt_len, - min_tokens = num_output_tokens, - max_tokens = num_output_tokens) + prompt_length=prompt_len, + min_tokens=num_output_tokens, + max_tokens=num_output_tokens) add_seq_group_to_engine(engine, seq_group) assert seq.data.get_num_computed_tokens() == 0 @@ -55,16 +55,19 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, engine.step() if not seq.is_finished(): - assert seq.data.get_num_computed_tokens() == prompt_len + num_prompt_steps - 1 + assert seq.data.get_num_computed_tokens( + ) == prompt_len + num_prompt_steps - 1 prompt_num_computed_tokens = seq.data.get_num_computed_tokens() decode_step_counter = 0 while not seq.is_finished(): - assert seq.data.get_num_computed_tokens() == prompt_num_computed_tokens + decode_step_counter + assert seq.data.get_num_computed_tokens( + ) == prompt_num_computed_tokens + decode_step_counter for _ in range(num_scheduler_steps): # decode step engine.step() decode_step_counter += 1 - assert seq.data.get_num_computed_tokens() == prompt_len + num_output_tokens - 1 + assert seq.data.get_num_computed_tokens( + ) == prompt_len + num_output_tokens - 1 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index ca5d05743b9e6..62fb0aa5f859f 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -966,8 +966,7 @@ def _process_sequence_group_outputs( return def _update_num_computed_tokens_for_multi_step_prefill( - self, - seq_group: SequenceGroup, + self, seq_group: SequenceGroup, seq_group_meta: SequenceGroupMetadata, is_first_step_output: Optional[bool]): """ @@ -990,19 +989,20 @@ def _update_num_computed_tokens_for_multi_step_prefill( # the tokens are appended to the sequence. return - do_update: bool = False + do_update: bool = False if self.scheduler_config.chunked_prefill_enabled: # In multi-step + chunked-prefill case, the prompt sequences # that are scheduled are fully processed in the first step. - do_update = is_first_step_output is None or is_first_step_output == True + do_update = is_first_step_output is None or is_first_step_output else: # Normal multi-step decoding case. In this case prompt-sequences # are actually single-stepped. Always update in this case. 
assert seq_group.state.num_steps == 1 - do_update = True + do_update = True if do_update: - seq_group.update_num_computed_tokens(seq_group_meta.token_chunk_size) + seq_group.update_num_computed_tokens( + seq_group_meta.token_chunk_size) def _process_model_outputs(self, ctx: SchedulerContext, @@ -1116,7 +1116,7 @@ def _process_model_outputs(self, else: self.output_processor.process_prompt_logprob(seq_group, output) if seq_group_meta.do_sample: - output_token_num = self.output_processor.process_outputs( + self.output_processor.process_outputs( seq_group, output, is_async) if seq_group.is_finished(): @@ -1246,7 +1246,8 @@ def _advance_to_next_step( seq = seq_group.seqs[0] if self.scheduler_config.is_multi_step: - is_prefill_append = seq.data.get_num_uncomputed_tokens() == 0 + is_prefill_append = seq.data.get_num_uncomputed_tokens( + ) == 0 seq.append_token_id(sample.output_token, sample.logprobs) if not is_prefill_append: seq_group.update_num_computed_tokens(1) diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index f5f42d1a9b9fe..47de3656ca892 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,5 +1,5 @@ import functools -from typing import Callable, List, Optional +from typing import Callable, List from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( @@ -106,7 +106,6 @@ def process_outputs(self, # was already appended, so we only need to do the rest of the # postprocessor: Detokenization + stopping logic self._process_decode_and_stop(seq, sequence_group.sampling_params) - return None else: # Standard multi-step case @@ -122,8 +121,8 @@ def process_outputs(self, ] assert valid_samples - return self._process_seq_outputs(seq, valid_samples, - sequence_group.sampling_params) + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) def _process_decode_and_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: @@ -141,7 +140,7 @@ def _process_decode_and_stop(self, seq: Sequence, def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], - sampling_params: SamplingParams) -> int: + sampling_params: SamplingParams) -> None: output_token_ids = [sample.output_token for sample in valid_samples] output_logprobs = [sample.logprobs for sample in valid_samples] From 76d3e24fa598c691c98b6f1bcb6c9c454c4a4715 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 4 Oct 2024 14:57:42 +0000 Subject: [PATCH 3/6] review comments --- tests/core/test_num_computed_tokens_update.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py index 3281cd3fd45fe..47a3796158c32 100644 --- a/tests/core/test_num_computed_tokens_update.py +++ b/tests/core/test_num_computed_tokens_update.py @@ -55,13 +55,14 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, engine.step() if not seq.is_finished(): - assert seq.data.get_num_computed_tokens( - ) == prompt_len + num_prompt_steps - 1 - prompt_num_computed_tokens = seq.data.get_num_computed_tokens() + # Test correctness of num_computed_tokens after the prompt steps + assert prompt_num_computed_tokens == \ + prompt_len + num_prompt_steps - 1 decode_step_counter = 0 while not seq.is_finished(): + # Test correctness of num_computed_tokens after the decode steps assert seq.data.get_num_computed_tokens( ) == prompt_num_computed_tokens + 
decode_step_counter for _ in range(num_scheduler_steps): @@ -69,5 +70,6 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, engine.step() decode_step_counter += 1 + # Test correctness of num_computed_tokens after the sequence finish. assert seq.data.get_num_computed_tokens( ) == prompt_len + num_output_tokens - 1 From a1cb9963402e3bfdf0c6ad39d0337d6d30d28791 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 6 Oct 2024 00:56:01 +0000 Subject: [PATCH 4/6] skip tests for any attn other than flash-attn --- tests/core/test_num_computed_tokens_update.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py index 47a3796158c32..a9628a082cf0b 100644 --- a/tests/core/test_num_computed_tokens_update.py +++ b/tests/core/test_num_computed_tokens_update.py @@ -1,3 +1,5 @@ +import os + import pytest from tests.conftest import VllmRunner @@ -20,6 +22,14 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, enable_chunked_prefill: bool, enforce_eager: bool): + is_multi_step = num_scheduler_steps > 1 + is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill + + attention_backend = os.getenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN") + if is_multi_step_chunked_prefill and attention_backend != "FLASH_ATTN": + pytest.skip("Multi-step with Chunked-Prefill only supports" + " FLASH_ATTN backend") + # Make a vllm engine runner = VllmRunner(model_name=MODEL, gpu_memory_utilization=0.7, @@ -29,8 +39,6 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, enforce_eager=enforce_eager) engine: LLMEngine = runner.model.llm_engine - is_multi_step = num_scheduler_steps > 1 - is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill # In multi-step + chunked-prefill there is no separate single prompt step. # What is scheduled will run for num_scheduler_steps always. 
num_prompt_steps = num_scheduler_steps \ From 6cad135f010aea167dbbc64d9366be3f9d4f0744 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 6 Oct 2024 02:41:52 +0000 Subject: [PATCH 5/6] skip rocm --- tests/core/test_num_computed_tokens_update.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/core/test_num_computed_tokens_update.py b/tests/core/test_num_computed_tokens_update.py index a9628a082cf0b..f3ec24e7bee3e 100644 --- a/tests/core/test_num_computed_tokens_update.py +++ b/tests/core/test_num_computed_tokens_update.py @@ -1,10 +1,9 @@ -import os - import pytest from tests.conftest import VllmRunner from tests.core.utils import create_dummy_prompt from vllm.engine.llm_engine import LLMEngine +from vllm.platforms import current_platform from vllm.sequence import SequenceGroup MODEL = "JackFram/llama-160m" @@ -25,10 +24,9 @@ def test_num_computed_tokens_update(num_scheduler_steps: int, is_multi_step = num_scheduler_steps > 1 is_multi_step_chunked_prefill = is_multi_step and enable_chunked_prefill - attention_backend = os.getenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN") - if is_multi_step_chunked_prefill and attention_backend != "FLASH_ATTN": - pytest.skip("Multi-step with Chunked-Prefill only supports" - " FLASH_ATTN backend") + if is_multi_step_chunked_prefill and current_platform.is_rocm(): + pytest.skip("Multi-step with Chunked-Prefill does not support " + "rocm_flash_attn backend") # Make a vllm engine runner = VllmRunner(model_name=MODEL, From dc4caa7a4e41d13b43f2845edbc2fa6a383e8151 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 6 Oct 2024 03:33:33 +0000 Subject: [PATCH 6/6] fix multi-step + rocm_flash_attn support --- vllm/attention/backends/rocm_flash_attn.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index fb5cd11ec033a..7456aab8b8d2a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -191,12 +191,22 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: ) return self._cached_decode_metadata - def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", + def advance_step(self, + model_input: "ModelInputForGPUWithSamplingMetadata", sampled_token_ids: Optional[torch.Tensor], - block_size: int, num_seqs: int, num_queries: int): + block_size: int, + num_seqs: int, + num_queries: int, + turn_prefills_into_decodes: bool = False): """ Update metadata in-place to advance one decode step. """ + + assert not turn_prefills_into_decodes, \ + ("Chunked prefill is not supported with rocm_flash_attn yet." + "turn_prefills_into_decodes is a Multi-Step + Chunked-Prefill " + "specific parameter.") + # When using cudagraph, the num_seqs is padded to the next captured # batch sized, but num_queries tracks the actual number of requests in # the batch. For --enforce-eager mode, num_seqs == num_queries
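
With the widened signature above, rocm_flash_attn can be called with the same
advance_step keyword that the multi-step + chunked-prefill path introduces for
other backends. A hedged sketch of the expected call shape (hypothetical
wrapper with invented names; on ROCm the new keyword has to stay at its
default of False, anything else trips the assert added above):

    def advance_rocm_decode_step(attn_metadata, model_input,
                                 sampled_token_ids, block_size: int,
                                 num_seqs: int, num_queries: int) -> None:
        # turn_prefills_into_decodes exists on this backend only for
        # signature compatibility; multi-step + chunked-prefill is not
        # supported with rocm_flash_attn yet, so it must remain False.
        attn_metadata.advance_step(model_input,
                                   sampled_token_ids,
                                   block_size,
                                   num_seqs=num_seqs,
                                   num_queries=num_queries,
                                   turn_prefills_into_decodes=False)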