From dc74d2470401b7eda34720ce1c7f177ffaa1b1d7 Mon Sep 17 00:00:00 2001
From: youkaichao <youkaichao@126.com>
Date: Sun, 24 Mar 2024 17:09:29 -0700
Subject: [PATCH] Fix flaky model runner tests by parametrizing batch size

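test_prepare_prompt and test_prepare_decode_cuda_graph drew their batch
size from random.randint(1, 256), so a failing size reproduced only on
some runs. Parametrize over every batch size in [1, 256] instead, so
each case is deterministic, and assert against _get_graph_batch_size,
the same helper ModelRunner uses to pad decode batches for CUDA graph
capture, rather than a local get_aligned_size helper that does not
match it for small batch sizes.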
---
 tests/worker/test_model_runner.py | 26 ++++++++++----------------
 1 file changed, 10 insertions(+), 16 deletions(-)
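
Note for reviewers: the padding rule the new assertions rely on is
sketched below. This paraphrases _get_graph_batch_size from
vllm/worker/model_runner.py as of this patch; treat the exact
thresholds as this note's assumption, not something the diff changes.

    _BATCH_SIZE_ALIGNMENT = 8  # captured CUDA graph batch sizes align to this

    def _get_graph_batch_size(batch_size: int) -> int:
        # Pad the runtime batch size up to the nearest captured CUDA
        # graph size: 1, 2, 4, then multiples of _BATCH_SIZE_ALIGNMENT.
        if batch_size <= 2:
            return batch_size
        elif batch_size <= 4:
            return 4
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

    # The removed get_aligned_size helper rounded every batch size up to
    # a multiple of 8, so the two disagree for batch_size in {1, 2, 3, 4}
    # (the old helper expected 8 where the runner pads 1 to 1, for
    # example). random.randint(1, 256) landed in that range on only ~1.6%
    # of runs, which is why failures looked intermittent.
    assert _get_graph_batch_size(1) == 1   # old helper: 8
    assert _get_graph_batch_size(3) == 4   # old helper: 8
    assert _get_graph_batch_size(9) == 16  # helpers agree above 4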

diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index 44b22c2bd8a21..01066ef796d67 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -1,20 +1,16 @@
-import random
+import pytest
 import torch
 
 from vllm.config import ModelConfig
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.model_runner import ModelRunner, _BATCH_SIZE_ALIGNMENT
+from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
 
 
-def get_aligned_size(batch_size: int, alignment: int):
-    return ((batch_size + alignment - 1) // alignment * alignment)
-
-
-def test_prepare_prompt():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_prompt(batch_size):
     model_runner = ModelRunner(None, None, None, None, None)
     model_runner.set_block_size(16)
 
-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     block_tables = {0: [1]}
@@ -111,7 +107,8 @@ def test_prepare_prompt():
     torch.testing.assert_close(actual, expected)
 
 
-def test_prepare_decode_cuda_graph():
+@pytest.mark.parametrize("batch_size", list(range(1, 257)))
+def test_prepare_decode_cuda_graph(batch_size):
     model_config = ModelConfig(
         "facebook/opt-125m",
         "facebook/opt-125m",
@@ -127,7 +124,6 @@ def test_prepare_decode_cuda_graph():
     model_runner = ModelRunner(model_config, None, None, None, None)
     model_runner.set_block_size(16)
 
-    batch_size = random.randint(1, 256)
     prompt_lens = []
     seq_group_metadata_list = []
     for i in range(batch_size):
@@ -147,13 +143,13 @@ def test_prepare_decode_cuda_graph():
     input_tokens, input_positions, input_metadata, _, _, _ = (
         model_runner._prepare_decode(seq_group_metadata_list))
 
+    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for decode.
     device = model_runner.device
     assert input_metadata.is_prompt is False
     assert input_metadata.prompt_lens is None
     assert input_metadata.num_prompt_tokens == 0
-    assert input_metadata.num_generation_tokens == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT))
+    assert input_metadata.num_generation_tokens == expected_bs
     assert input_metadata.max_seq_len is None
     assert input_metadata.subquery_start_loc is None
     assert input_metadata.seq_start_loc is None
@@ -173,10 +169,8 @@ def test_prepare_decode_cuda_graph():
     assert input_metadata.use_cuda_graph is True
     assert input_metadata.kv_cache_dtype == "auto"
 
-    assert input_tokens.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
-    assert input_positions.shape == (get_aligned_size(
-        len(seq_group_metadata_list), _BATCH_SIZE_ALIGNMENT), )
+    assert input_tokens.shape == (expected_bs, )
+    assert input_positions.shape == (expected_bs, )
     torch.testing.assert_close(input_tokens, input_positions)
 
     # Verify Sampling