From f3a6ed580634dad0e7769e13db7dc42ab6ac3a28 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Mon, 11 Nov 2024 20:24:36 +0000 Subject: [PATCH 01/19] Fix for Spec model TP + Chunked Prefill Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/spec_decode/spec_decode_worker.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index b57742c2ebfdd..aa1e4dddf4213 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -641,6 +641,12 @@ def _run_non_driver_rank(self) -> bool: # that the hidden states can be propagated to proposer when needed. if data["no_spec"]: self.scorer_worker.execute_model() + # If no spec case we still want to run the proposer model + # but ONLY once to match `not skip_proposer` in + # driver `_run_no_spec` + if not data["disable_all_speculation"]: + self.proposer_worker.execute_model() + return True if not data["disable_all_speculation"]: # Even if num_lookahead_slots is zero, we want to run the From 8aacb663d4e56068b6ea01749da1b9738d9956ec Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Tue, 12 Nov 2024 19:12:08 +0000 Subject: [PATCH 02/19] Revert "Fix for Spec model TP + Chunked Prefill" This reverts commit 6863d1f364eec79deb5dd7bce143e47d81670d87. Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/spec_decode/spec_decode_worker.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index aa1e4dddf4213..b57742c2ebfdd 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -641,12 +641,6 @@ def _run_non_driver_rank(self) -> bool: # that the hidden states can be propagated to proposer when needed. if data["no_spec"]: self.scorer_worker.execute_model() - # If no spec case we still want to run the proposer model - # but ONLY once to match `not skip_proposer` in - # driver `_run_no_spec` - if not data["disable_all_speculation"]: - self.proposer_worker.execute_model() - return True if not data["disable_all_speculation"]: # Even if num_lookahead_slots is zero, we want to run the From 086811e35ecd816aa55dc951408f833d25f5138e Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Tue, 12 Nov 2024 19:12:20 +0000 Subject: [PATCH 03/19] Move fix to scheduler Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/core/scheduler.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 530cbdc3a9190..feb3040568abc 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1201,15 +1201,23 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) # Put prefills first due to Attention backend ordering assumption. 
+ scheduled_seq_groups = (prefills.seq_groups + + running_scheduled.prefill_seq_groups + + swapped_in.prefill_seq_groups + + running_scheduled.decode_seq_groups + + swapped_in.decode_seq_groups) + num_prefill_groups = (len(prefills.seq_groups) + + len(swapped_in.prefill_seq_groups) + + len(running_scheduled.prefill_seq_groups)) + # If all prompts, then we set num_lookahead_slots to 0 + # this alloows us to go through the `no_spec` path in + # `spec_decode_worker.py` + all_prefills = (len(scheduled_seq_groups) == num_prefill_groups) + num_lookahead_slots = (0 if all_prefills else + running_scheduled.num_lookahead_slots) return SchedulerOutputs( - scheduled_seq_groups=(prefills.seq_groups + - running_scheduled.prefill_seq_groups + - swapped_in.prefill_seq_groups + - running_scheduled.decode_seq_groups + - swapped_in.decode_seq_groups), - num_prefill_groups=(len(prefills.seq_groups) + - len(swapped_in.prefill_seq_groups) + - len(running_scheduled.prefill_seq_groups)), + scheduled_seq_groups=scheduled_seq_groups, + num_prefill_groups=num_prefill_groups, num_batched_tokens=budget.num_batched_tokens + budget.num_cached_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, @@ -1218,7 +1226,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: swapped_in.blocks_to_copy, ignored_seq_groups=prefills.ignored_seq_groups + swapped_in.infeasible_seq_groups, - num_lookahead_slots=running_scheduled.num_lookahead_slots, + num_lookahead_slots=num_lookahead_slots, running_queue_size=len(self.running), preempted=(len(running_scheduled.preempted) + len(running_scheduled.swapped_out)), From aaa7884a33011b8dacc29fe38a532e02b90135f3 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Tue, 12 Nov 2024 19:31:01 +0000 Subject: [PATCH 04/19] Add assert Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/spec_decode/spec_decode_worker.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index b57742c2ebfdd..b63a01fa5c591 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -409,6 +409,12 @@ def execute_model( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots + if (all(sgm.is_prompt + for sgm in execute_model_req.seq_group_metadata_list)): + assert num_lookahead_slots == 0, ( + "Prompt only runs should have num_lookahead_slots equal to 0. " + "This should never happen, please file a bug at " + "https://github.com/vllm-project/vllm/issues") # Speculative decoding is disabled in the following cases: # 1. Prefill phase: Speculative decoding is not # used during the prefill phase. @@ -419,9 +425,7 @@ def execute_model( # In any of these cases, the proposer and scorer workers # are called normally. # We expect `num_speculative_tokens` to be None for prefills. 
- no_spec = all( - sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list - ) or num_lookahead_slots == 0 or disable_all_speculation or all( + no_spec = num_lookahead_slots == 0 or disable_all_speculation or all( sgm.num_speculative_tokens == 0 for sgm in execute_model_req.seq_group_metadata_list) From a43b19b698995e454060e7d788ffb4b11b89e702 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Tue, 12 Nov 2024 19:35:32 +0000 Subject: [PATCH 05/19] Small cleanup Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/spec_decode/spec_decode_worker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index b63a01fa5c591..7278674a00623 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -409,8 +409,10 @@ def execute_model( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots - if (all(sgm.is_prompt - for sgm in execute_model_req.seq_group_metadata_list)): + all_prompt = (all( + sgm.is_prompt + for sgm in execute_model_req.seq_group_metadata_list)) + if (all_prompt): assert num_lookahead_slots == 0, ( "Prompt only runs should have num_lookahead_slots equal to 0. " "This should never happen, please file a bug at " From e7576f76f3e5fc05710a957dd80b722672709a9c Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Tue, 12 Nov 2024 19:36:27 +0000 Subject: [PATCH 06/19] Small cleanup Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/spec_decode/spec_decode_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 7278674a00623..67efe22be4a29 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -412,7 +412,7 @@ def execute_model( all_prompt = (all( sgm.is_prompt for sgm in execute_model_req.seq_group_metadata_list)) - if (all_prompt): + if all_prompt: assert num_lookahead_slots == 0, ( "Prompt only runs should have num_lookahead_slots equal to 0. 
" "This should never happen, please file a bug at " From 99d6e1bd0f13305ea4bee0dc26058eebf5580c49 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Wed, 13 Nov 2024 18:38:31 +0000 Subject: [PATCH 07/19] Typo fix Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/core/scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index feb3040568abc..29e09a1d8a69a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1210,7 +1210,7 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)) # If all prompts, then we set num_lookahead_slots to 0 - # this alloows us to go through the `no_spec` path in + # this allows us to go through the `no_spec` path in # `spec_decode_worker.py` all_prefills = (len(scheduled_seq_groups) == num_prefill_groups) num_lookahead_slots = (0 if all_prefills else From c2a0b8191859330334efd467b03b1508ad3820e0 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Wed, 13 Nov 2024 18:44:21 +0000 Subject: [PATCH 08/19] Docs change Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- docs/source/serving/compatibility_matrix.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/serving/compatibility_matrix.rst b/docs/source/serving/compatibility_matrix.rst index fa03d2cde1486..a93632ff36fb8 100644 --- a/docs/source/serving/compatibility_matrix.rst +++ b/docs/source/serving/compatibility_matrix.rst @@ -118,7 +118,7 @@ Feature x Feature - - * - :ref:`SD ` - - ✗ + - ✅ - ✅ - ✗ - ✅ From 7b5cafd41a1bf234d93e221e313280f19bea4f96 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Wed, 13 Nov 2024 19:06:57 +0000 Subject: [PATCH 09/19] Removed unnecessary checks Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- tests/spec_decode/e2e/test_compatibility.py | 46 --------------------- vllm/config.py | 10 ----- 2 files changed, 56 deletions(-) diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index a3f0464e79675..af8397c235f48 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ b/tests/spec_decode/e2e/test_compatibility.py @@ -50,49 +50,3 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator): with pytest.raises(ValueError, match="cannot be larger than"): get_output_from_llm_generator(test_llm_generator, prompts, sampling_params) - - -@pytest.mark.parametrize("common_llm_kwargs", - [{ - "model": "meta-llama/Llama-2-7b-chat-hf", - "speculative_model": "JackFram/llama-68m", - "num_speculative_tokens": 5, - "enable_chunked_prefill": "True", - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [ - { - "tensor_parallel_size": 2, - "speculative_draft_tensor_parallel_size": 2, - }, - { - "tensor_parallel_size": 4, - "speculative_draft_tensor_parallel_size": 4, - }, - { - "tensor_parallel_size": 8, - "speculative_draft_tensor_parallel_size": 8, - }, -]) -@pytest.mark.parametrize("test_llm_kwargs", [{}]) -@pytest.mark.parametrize("seed", [1]) -def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one( - test_llm_generator): - """Verify that speculative decoding fails if chunked prefill is enabled for - draft model with tensor parallelism of more than 1. 
- """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - ] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - ) - - with pytest.raises(ValueError, match="with tensor parallel size 1"): - get_output_from_llm_generator(test_llm_generator, prompts, - sampling_params) diff --git a/vllm/config.py b/vllm/config.py index c87feaec3e5f6..eae6f909e3933 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1409,16 +1409,6 @@ def maybe_create_spec_config( draft_hf_config ) - if (enable_chunked_prefill and \ - speculative_draft_tensor_parallel_size != 1): - # TODO - Investigate why the error reported in - # https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258 - # is happening and re-enable it. - raise ValueError( - "Chunked prefill and speculative decoding can be enabled " - "simultaneously only for draft models with tensor " - "parallel size 1.") - draft_model_config.max_model_len = ( SpeculativeConfig._maybe_override_draft_max_model_len( speculative_max_model_len, From 5ca1bb85af69fca17806118b86d541d357de4906 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Mon, 18 Nov 2024 18:56:09 +0000 Subject: [PATCH 10/19] E2E Test Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- .../e2e/test_integration_dist_tp2.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index 25562ca85adf4..244c6b12a56d3 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -115,3 +115,59 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, max_output_len=32, seed=seed, temperature=0.0) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +@pytest.mark.parametrize( + "common_llm_kwargs", + [[ + # Skip cuda graph recording for fast test. + "--enforce-eager", + "--tensor_parallel_size", + "2", + + # precision + "--dtype", + "bfloat16", + ]]) +@pytest.mark.parametrize( + "per_test_common_llm_kwargs", + [["--enable-chunked-prefill", "False"], + [ + "--enable-chunked-prefill", "True", "--max-num-batched-tokens", "4", + "--max-num-seqs", "4" + ]]) +@pytest.mark.parametrize("baseline_llm_kwargs", [[]]) +@pytest.mark.parametrize("model, test_llm_kwargs", + [("JackFram/llama-68m", [ + "--speculative-model", + "JackFram/llama-68m", + "--num_speculative-tokens", + "3", + ]), + ("JackFram/llama-68m", [ + "--speculative-model", + "JackFram/llama-68m", + "--num_speculative-tokens", + "3", + "--speculative-draft-tensor-parallel-size", + "1", + ])]) +@pytest.mark.parametrize("batch_size", [2]) +@pytest.mark.parametrize("seed", [1]) +def test_spec_decode_chunked_prefill__tp2(model, common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, seed: int): + """Verify spec decode works well with smaller tp for draft models. 
+ """ + run_equality_correctness_test_tp(model, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=32, + seed=seed, + temperature=0.0) From 18187d483557968572d7881fd0c13f09da9da245 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Mon, 18 Nov 2024 20:21:34 +0000 Subject: [PATCH 11/19] Change condition to exclude multi step Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/core/scheduler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 29e09a1d8a69a..d23009dae01ee 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1213,8 +1213,10 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # this allows us to go through the `no_spec` path in # `spec_decode_worker.py` all_prefills = (len(scheduled_seq_groups) == num_prefill_groups) - num_lookahead_slots = (0 if all_prefills else - running_scheduled.num_lookahead_slots) + num_lookahead_slots = (0 if + (all_prefills + and not self.scheduler_config.is_multi_step) + else running_scheduled.num_lookahead_slots) return SchedulerOutputs( scheduled_seq_groups=scheduled_seq_groups, num_prefill_groups=num_prefill_groups, From 2d3d16f87c29b4c949f68a0f2fdf4059dbacad28 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Mon, 18 Nov 2024 20:50:35 +0000 Subject: [PATCH 12/19] Add chunked prefill scheduler unit test Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- tests/core/test_chunked_prefill_scheduler.py | 33 ++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index acd82065ae457..2c7c47412f23d 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -413,6 +413,39 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots): assert out.num_batched_tokens == max_num_batched_tokens +def test_chunked_prefill_spec_prefill(): + """Verify preempt works with chunked prefill requests""" + block_size = 4 + max_seqs = 30 + max_model_len = 200 + max_num_batched_tokens = 30 + scheduler_config = SchedulerConfig( + "generate", + max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True, + num_lookahead_slots=5, + ) + cache_config = CacheConfig(block_size, 1.0, 1, "auto") + cache_config.num_cpu_blocks = 16 + cache_config.num_gpu_blocks = 16 + scheduler = Scheduler(scheduler_config, cache_config, None) + + _, seq_group = create_dummy_prompt("1", + prompt_length=60, + block_size=block_size) + scheduler.add_seq_group(seq_group) + _, out = schedule_and_update_computed_tokens(scheduler) + # The request is chunked. + # prefill scheduled now. 
+ assert len(out.scheduled_seq_groups) == 1 + assert out.num_prefill_groups == 1 + assert seq_group.is_prefill() + assert out.num_batched_tokens == max_num_batched_tokens + assert out.num_lookahead_slots == 0 + + def test_chunked_prefill_max_seqs(): block_size = 4 max_seqs = 2 From 964e9f69a4779c5a3021df73ed572fba041497ec Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Tue, 19 Nov 2024 18:56:56 +0000 Subject: [PATCH 13/19] Fix multiple batch chunked prefill + TP + spec Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- tests/spec_decode/e2e/test_integration_dist_tp2.py | 2 +- vllm/spec_decode/spec_decode_worker.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index 244c6b12a56d3..ab83cd09fb522 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -156,7 +156,7 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, ])]) @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seed", [1]) -def test_spec_decode_chunked_prefill__tp2(model, common_llm_kwargs, +def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, seed: int): diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 67efe22be4a29..522c41332f560 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -659,6 +659,9 @@ def _run_non_driver_rank(self) -> bool: if not data["no_spec"]: self.scorer_worker.execute_model() + data = broadcast_tensor_dict(src=self._driver_rank) + if data["run_spec_proposer"]: + self.proposer_worker.execute_model() return True @@ -712,6 +715,10 @@ def _run_speculative_decoding_step( idx for idx in non_spec_indices if execute_model_req.seq_group_metadata_list[idx].is_prompt ] + broadcast_dict = dict( + run_spec_proposer=bool(non_spec_indices) + ) + broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) if len(non_spec_indices): all_hidden_states = proposal_scores.hidden_states # TODO fix `return_hidden_states`, same as in `_run_no_spec` From b517bbebfc0274cc8b226dec1f7ea83e2b08e691 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Tue, 19 Nov 2024 19:27:35 +0000 Subject: [PATCH 14/19] Format Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- tests/spec_decode/e2e/test_integration_dist_tp2.py | 6 +++--- vllm/spec_decode/spec_decode_worker.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index ab83cd09fb522..7fea161a4a7be 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -157,9 +157,9 @@ def test_draft_model_tp_lt_target_model_tp2(model, common_llm_kwargs, @pytest.mark.parametrize("batch_size", [2]) @pytest.mark.parametrize("seed", [1]) def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs, - per_test_common_llm_kwargs, - baseline_llm_kwargs, test_llm_kwargs, - batch_size: int, seed: int): + per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, + batch_size: int, seed: int): """Verify spec decode works well with smaller tp for draft models. 
""" run_equality_correctness_test_tp(model, diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 522c41332f560..35e2eb625705a 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -715,9 +715,7 @@ def _run_speculative_decoding_step( idx for idx in non_spec_indices if execute_model_req.seq_group_metadata_list[idx].is_prompt ] - broadcast_dict = dict( - run_spec_proposer=bool(non_spec_indices) - ) + broadcast_dict = dict(run_spec_proposer=bool(non_spec_indices)) broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) if len(non_spec_indices): all_hidden_states = proposal_scores.hidden_states From c6127ebea7d9c7f2ff6305db8b8dadbca5add55f Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Thu, 21 Nov 2024 02:09:35 +0000 Subject: [PATCH 15/19] Nits and add multi step test Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- tests/core/test_chunked_prefill_scheduler.py | 18 ++++++++++++------ .../e2e/test_integration_dist_tp2.py | 3 ++- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 2c7c47412f23d..eaaf004df38b2 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -413,19 +413,24 @@ def cannot_append_second_group2(seq_group, num_lookahead_slots): assert out.num_batched_tokens == max_num_batched_tokens -def test_chunked_prefill_spec_prefill(): - """Verify preempt works with chunked prefill requests""" +@pytest.mark.parametrize("num_scheduler_steps", [1, 5]) +def test_chunked_prefill_spec_prefill(num_scheduler_steps): + """Verify that the num_lookahead_slots is set appropriately for an all""" + """prefill batch depending on whether multi-step scheduling is enabled""" + """or not""" block_size = 4 max_seqs = 30 max_model_len = 200 max_num_batched_tokens = 30 + num_lookahead_slots = 4 scheduler_config = SchedulerConfig( "generate", max_num_batched_tokens, max_seqs, max_model_len, enable_chunked_prefill=True, - num_lookahead_slots=5, + num_lookahead_slots=num_lookahead_slots, + num_scheduler_steps=num_scheduler_steps, ) cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = 16 @@ -433,7 +438,7 @@ def test_chunked_prefill_spec_prefill(): scheduler = Scheduler(scheduler_config, cache_config, None) _, seq_group = create_dummy_prompt("1", - prompt_length=60, + prompt_length=30, block_size=block_size) scheduler.add_seq_group(seq_group) _, out = schedule_and_update_computed_tokens(scheduler) @@ -441,9 +446,10 @@ def test_chunked_prefill_spec_prefill(): # prefill scheduled now. 
assert len(out.scheduled_seq_groups) == 1 assert out.num_prefill_groups == 1 - assert seq_group.is_prefill() assert out.num_batched_tokens == max_num_batched_tokens - assert out.num_lookahead_slots == 0 + print(out.num_lookahead_slots) + assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else + num_lookahead_slots) def test_chunked_prefill_max_seqs(): diff --git a/tests/spec_decode/e2e/test_integration_dist_tp2.py b/tests/spec_decode/e2e/test_integration_dist_tp2.py index 7fea161a4a7be..02cba92795142 100644 --- a/tests/spec_decode/e2e/test_integration_dist_tp2.py +++ b/tests/spec_decode/e2e/test_integration_dist_tp2.py @@ -160,7 +160,8 @@ def test_spec_decode_chunked_prefill_tp2(model, common_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, seed: int): - """Verify spec decode works well with smaller tp for draft models. + """Verify spec decode works well with same and different TP size for + the draft model with chunked prefill. """ run_equality_correctness_test_tp(model, common_llm_kwargs, From 31b0ddf3c016ca6949af2cc6ac948706e9957372 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Fri, 22 Nov 2024 21:32:57 +0000 Subject: [PATCH 16/19] Remove additional broadcast needed for proposer prefill pass --- vllm/spec_decode/spec_decode_worker.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 35e2eb625705a..70f15c72fbee4 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -408,10 +408,13 @@ def execute_model( disable_all_speculation = self._should_disable_all_speculation( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots + all_prompt = None + atleast_one_prompt = False + for sgm in execute_model_req.seq_group_metadata_list: + all_prompt = (sgm.is_prompt if all_prompt is None else all_prompt + and sgm.is_prompt) + atleast_one_prompt = atleast_one_prompt or sgm.is_prompt - all_prompt = (all( - sgm.is_prompt - for sgm in execute_model_req.seq_group_metadata_list)) if all_prompt: assert num_lookahead_slots == 0, ( "Prompt only runs should have num_lookahead_slots equal to 0. " @@ -448,6 +451,15 @@ def execute_model( num_lookahead_slots=num_lookahead_slots, no_spec=no_spec, disable_all_speculation=disable_all_speculation, + # When both chunked prefill and speculative decoding are enabled + # it is possible that the same batch contains both prefill + # and decodes. If that happens in the scorer we run the batch + # as one single forward pass. However, in the proposer we + # run them as 2 different batches - one for prefill and + # the other for decodes. The variable indicates to the non-driver + # worker that there are prefills as part of the speculative batch + # and hence it needs to run an extra prefill forward pass. 
+ run_spec_proposer_for_prefill=atleast_one_prompt, ) broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) @@ -659,8 +671,7 @@ def _run_non_driver_rank(self) -> bool: if not data["no_spec"]: self.scorer_worker.execute_model() - data = broadcast_tensor_dict(src=self._driver_rank) - if data["run_spec_proposer"]: + if data["run_spec_proposer_for_prefill"]: self.proposer_worker.execute_model() return True @@ -715,8 +726,6 @@ def _run_speculative_decoding_step( idx for idx in non_spec_indices if execute_model_req.seq_group_metadata_list[idx].is_prompt ] - broadcast_dict = dict(run_spec_proposer=bool(non_spec_indices)) - broadcast_tensor_dict(broadcast_dict, src=self._driver_rank) if len(non_spec_indices): all_hidden_states = proposal_scores.hidden_states # TODO fix `return_hidden_states`, same as in `_run_no_spec` From c232b847e463b353540c58b8aa1b2299f6f7e766 Mon Sep 17 00:00:00 2001 From: Sourashis Roy Date: Sat, 23 Nov 2024 02:17:26 +0000 Subject: [PATCH 17/19] Fix failing test Signed-off-by: Sourashis Roy --- tests/spec_decode/test_spec_decode_worker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 8df143104c279..d7caf57147278 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -867,7 +867,8 @@ def test_chunked_prefill_flow(k: int, batch_size: int, batch_composition: str): target_group_metadata_list = prefill + decodes execute_model_req = ExecuteModelRequest( seq_group_metadata_list=target_group_metadata_list, - num_lookahead_slots=k) + # For prefill only batches we expect num_lookahead_slots = 0. + num_lookahead_slots=k if n_decodes > 0 else 0) target_token_ids = torch.randint(low=0, high=vocab_size, From 2d99f393766978fc52d982514bb5238e56229401 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Mon, 25 Nov 2024 18:49:17 +0000 Subject: [PATCH 18/19] Address Nick comments Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/spec_decode/spec_decode_worker.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 70f15c72fbee4..8f3c824c4d2c3 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -408,12 +408,14 @@ def execute_model( disable_all_speculation = self._should_disable_all_speculation( execute_model_req) num_lookahead_slots = execute_model_req.num_lookahead_slots - all_prompt = None + all_prompt = True atleast_one_prompt = False + all_zero_spec_tokens = True for sgm in execute_model_req.seq_group_metadata_list: - all_prompt = (sgm.is_prompt if all_prompt is None else all_prompt - and sgm.is_prompt) + all_prompt = all_prompt and sgm.is_prompt atleast_one_prompt = atleast_one_prompt or sgm.is_prompt + all_zero_spec_tokens = all_zero_spec_tokens and ( + sgm.num_speculative_tokens == 0) if all_prompt: assert num_lookahead_slots == 0, ( @@ -430,9 +432,8 @@ def execute_model( # In any of these cases, the proposer and scorer workers # are called normally. # We expect `num_speculative_tokens` to be None for prefills. 
- no_spec = num_lookahead_slots == 0 or disable_all_speculation or all( - sgm.num_speculative_tokens == 0 - for sgm in execute_model_req.seq_group_metadata_list) + no_spec = (num_lookahead_slots == 0 or disable_all_speculation + or all_zero_spec_tokens) # Broadcast how many lookahead slots are scheduled for this step, and # whether all speculation is disabled, to all non-driver workers. From 01b43aa8353e178b57a790bba39e2935b36e8431 Mon Sep 17 00:00:00 2001 From: andoorve <37849411+andoorve@users.noreply.github.com> Date: Tue, 26 Nov 2024 01:34:41 +0000 Subject: [PATCH 19/19] Fix test failure Signed-off-by: andoorve <37849411+andoorve@users.noreply.github.com> --- vllm/spec_decode/spec_decode_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 8f3c824c4d2c3..b279931ca4b02 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -417,7 +417,7 @@ def execute_model( all_zero_spec_tokens = all_zero_spec_tokens and ( sgm.num_speculative_tokens == 0) - if all_prompt: + if all_prompt and execute_model_req.seq_group_metadata_list: assert num_lookahead_slots == 0, ( "Prompt only runs should have num_lookahead_slots equal to 0. " "This should never happen, please file a bug at "
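
For readers following the series end-to-end, below is a minimal, self-contained sketch of the control flow the commits above converge on. It is illustrative only: SeqGroup, scheduler_num_lookahead_slots, and driver_flags are assumed stand-ins for vLLM's SequenceGroupMetadata, Scheduler._schedule_chunked_prefill, and SpecDecodeWorker.execute_model, and the real code carries much more state than is modeled here.

    # Minimal sketch (assumed names, not vLLM's actual classes) of the logic
    # this patch series converges on. A simplified SeqGroup stands in for
    # SequenceGroupMetadata; only the fields the patches read are modeled.

    from dataclasses import dataclass
    from typing import List, Tuple


    @dataclass
    class SeqGroup:
        is_prompt: bool                  # still in the prefill phase
        num_speculative_tokens: int = 0  # 0 means no speculation for this group


    def scheduler_num_lookahead_slots(seq_groups: List[SeqGroup],
                                      configured_slots: int,
                                      is_multi_step: bool) -> int:
        # Patches 03/11: an all-prefill chunked-prefill batch gets zero
        # lookahead slots (unless multi-step scheduling needs them), which
        # steers the spec-decode worker onto its `no_spec` path.
        all_prefills = all(g.is_prompt for g in seq_groups)
        return 0 if (all_prefills and not is_multi_step) else configured_slots


    def driver_flags(seq_groups: List[SeqGroup],
                     num_lookahead_slots: int,
                     disable_all_speculation: bool) -> Tuple[bool, bool]:
        # Patches 16/18: one pass over the batch computes everything the
        # driver needs to broadcast to non-driver ranks.
        all_prompt = True
        atleast_one_prompt = False
        all_zero_spec_tokens = True
        for g in seq_groups:
            all_prompt = all_prompt and g.is_prompt
            atleast_one_prompt = atleast_one_prompt or g.is_prompt
            all_zero_spec_tokens = all_zero_spec_tokens and (
                g.num_speculative_tokens == 0)

        no_spec = (num_lookahead_slots == 0 or disable_all_speculation
                   or all_zero_spec_tokens)
        # Non-driver ranks use this flag to run one extra proposer prefill
        # pass when a speculative batch also contains prefill requests.
        run_spec_proposer_for_prefill = atleast_one_prompt
        return no_spec, run_spec_proposer_for_prefill


    # An all-prefill batch takes the no_spec path; a mixed prefill+decode
    # batch keeps speculation but asks workers to run the proposer's prefill
    # pass as well.
    prefill_only = [SeqGroup(True), SeqGroup(True)]
    mixed = [SeqGroup(True), SeqGroup(False, num_speculative_tokens=3)]

    slots = scheduler_num_lookahead_slots(prefill_only, 3, is_multi_step=False)
    assert slots == 0
    assert driver_flags(prefill_only, slots, False) == (True, True)

    slots = scheduler_num_lookahead_slots(mixed, 3, is_multi_step=False)
    assert slots == 3
    assert driver_flags(mixed, slots, False) == (False, True)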