From 2e7796f2cf4537dd3b08d5de3aa8349f5db1a168 Mon Sep 17 00:00:00 2001
From: heeju-kim2 <157340754+heeju-kim2@users.noreply.github.com>
Date: Sat, 11 May 2024 02:36:25 +0900
Subject: [PATCH] [Speculative decoding] CUDA graph support (#4295)

Co-authored-by: Cade Daniel
---
 .../e2e/test_multistep_correctness.py         | 37 +++++++++++++++++++
 1 file changed, 37 insertions(+)

diff --git a/tests/spec_decode/e2e/test_multistep_correctness.py b/tests/spec_decode/e2e/test_multistep_correctness.py
index 94d71fb012727..d2da039e84c07 100644
--- a/tests/spec_decode/e2e/test_multistep_correctness.py
+++ b/tests/spec_decode/e2e/test_multistep_correctness.py
@@ -611,3 +611,40 @@ def test_many_k(baseline_llm_generator, test_llm_generator, batch_size: int,
                                          batch_size,
                                          max_output_len=output_len,
                                          force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Verify equality when cuda graphs allowed.
+        "enforce_eager": False,
+        "model": "JackFram/llama-68m",
+    }])
+@pytest.mark.parametrize(
+    "per_test_common_llm_kwargs",
+    [
+        {
+            # Identical models.
+            "speculative_model": "JackFram/llama-68m",
+            "num_speculative_tokens": 5,
+        },
+    ])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("batch_size", [8])
+@pytest.mark.parametrize("output_len", [32])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
+                                batch_size, output_len):
+    """Verify spec decode equality when cuda graphs are enabled.
+    """
+    run_greedy_equality_correctness_test(
+        baseline_llm_generator,
+        test_llm_generator,
+        batch_size,
+        max_output_len=output_len,
+        force_output_len=True,
+    )
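
Usage note (not part of the patch): the test above checks that speculative decoding produces greedy-equal output when CUDA graphs are allowed, i.e. enforce_eager=False. Below is a minimal sketch of the equivalent offline configuration through vLLM's LLM entrypoint, assuming a vLLM build that includes this change; the model names and speculative settings simply mirror the test's kwargs, and the prompt is illustrative.

from vllm import LLM, SamplingParams

# Draft and target are the same 68M model, as in the test. CUDA graphs are
# used because enforce_eager is left False.
llm = LLM(
    model="JackFram/llama-68m",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    use_v2_block_manager=True,  # required for spec decode
    enforce_eager=False,        # allow CUDA graph capture
)

# Greedy sampling, matching the deterministic setup the equality test relies on.
params = SamplingParams(temperature=0.0, max_tokens=32)
outputs = llm.generate(["The future of AI is"], params)
print(outputs[0].outputs[0].text)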