From 1a774fd23f4bfcb145f19b9352bfc25acc7febbf Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Mon, 19 Feb 2024 22:02:00 -0800
Subject: [PATCH] Change seeded generate test to use mixed batch

---
 tests/samplers/test_seeded_generate.py | 58 ++++++++++++++++----------
 1 file changed, 36 insertions(+), 22 deletions(-)

diff --git a/tests/samplers/test_seeded_generate.py b/tests/samplers/test_seeded_generate.py
index 8b3fc4dda085a..fcb0e09d46143 100644
--- a/tests/samplers/test_seeded_generate.py
+++ b/tests/samplers/test_seeded_generate.py
@@ -12,7 +12,7 @@
 from vllm import SamplingParams
 
 MODEL = "facebook/opt-125m"
-RANDOM_SEEDS = list(range(3))
+RANDOM_SEEDS = list(range(5))
 
 
 @pytest.fixture
@@ -37,7 +37,7 @@ def test_random_sample_with_seed(
         top_k=random.randint(5, 20),
         n=random.randint(1, 10),
         presence_penalty=random.randint(0, 1),
-        max_tokens=4,
+        max_tokens=8,
         ignore_eos=True,
     )
 
@@ -46,23 +46,37 @@ def test_random_sample_with_seed(
     sampling_params_seed_2 = copy.deepcopy(sampling_params)
     sampling_params_seed_2.seed = 200
 
-    vllm_outputs_no_seed_1 = vllm_model.generate(example_prompts,
-                                                 sampling_params)
-    vllm_outputs_seed_1_1 = vllm_model.generate(example_prompts,
-                                                sampling_params_seed_1)
-    vllm_outputs_seed_2_1 = vllm_model.generate(example_prompts,
-                                                sampling_params_seed_2)
-    vllm_outputs_no_seed_2 = vllm_model.generate(example_prompts,
-                                                 sampling_params)
-    vllm_outputs_seed_1_2 = vllm_model.generate(example_prompts,
-                                                sampling_params_seed_1)
-    vllm_outputs_seed_2_2 = vllm_model.generate(example_prompts,
-                                                sampling_params_seed_2)
-
-    for output_a, output_b in combinations(
-        (vllm_outputs_no_seed_1, vllm_outputs_no_seed_2, vllm_outputs_seed_1_1,
-         vllm_outputs_seed_2_1), 2):
-        assert output_a != output_b
-
-    assert vllm_outputs_seed_1_1 == vllm_outputs_seed_1_2
-    assert vllm_outputs_seed_2_1 == vllm_outputs_seed_2_2
+    llm = vllm_model.model
+
+    for prompt in example_prompts:
+        for params in (
+                sampling_params,
+                sampling_params_seed_1,
+                sampling_params_seed_2,
+                sampling_params,
+                sampling_params_seed_1,
+                sampling_params_seed_2,
+        ):
+            llm._add_request(
+                prompt=prompt,
+                prompt_token_ids=None,
+                sampling_params=params,
+            )
+
+    results = llm._run_engine(use_tqdm=False)
+    all_outputs = [[out.token_ids for out in output.outputs]
+                   for output in results]
+
+    for i in range(0, len(all_outputs), 6):
+        outputs = all_outputs[i:i + 6]
+
+        # verify unseeded and differently-seeded outputs all differ
+        for output_a, output_b in combinations(
+            (outputs[0], outputs[1], outputs[2], outputs[3]),
+            2,
+        ):
+            assert output_a != output_b
+
+        # verify requests with the same seed match
+        assert outputs[1] == outputs[4]
+        assert outputs[2] == outputs[5]
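
Note (a minimal sketch, not part of the patch): the property this test locks
down is that a request-level `seed` in `SamplingParams` makes sampling
reproducible even when the request shares a batch with other requests.
`LLM`, `SamplingParams`, `generate`, and `token_ids` are vLLM's public names;
the model, prompts, and seed value below are illustrative assumptions.

    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")
    prompts = ["Hello, my name is", "The future of AI is"]

    # High-temperature params, identical except that one pins a seed.
    seeded = SamplingParams(temperature=2.0, max_tokens=8, seed=100)
    unseeded = SamplingParams(temperature=2.0, max_tokens=8)

    first = llm.generate(prompts, seeded)
    second = llm.generate(prompts, seeded)

    # The same seed reproduces the exact token ids for each prompt.
    for a, b in zip(first, second):
        assert a.outputs[0].token_ids == b.outputs[0].token_ids

    # Unseeded runs are free to diverge; at temperature 2.0 a collision
    # over 8 tokens is possible but very unlikely.
    third = llm.generate(prompts, unseeded)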