Skip to content

Commit

Permalink
Change seeded generate test to use mixed batch
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill committed Feb 21, 2024
1 parent a0eebef commit 1a774fd
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions tests/samplers/test_seeded_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from vllm import SamplingParams

MODEL = "facebook/opt-125m"
RANDOM_SEEDS = list(range(3))
RANDOM_SEEDS = list(range(5))


@pytest.fixture
Expand All @@ -37,7 +37,7 @@ def test_random_sample_with_seed(
top_k=random.randint(5, 20),
n=random.randint(1, 10),
presence_penalty=random.randint(0, 1),
max_tokens=4,
max_tokens=8,
ignore_eos=True,
)

Expand All @@ -46,23 +46,37 @@ def test_random_sample_with_seed(
sampling_params_seed_2 = copy.deepcopy(sampling_params)
sampling_params_seed_2.seed = 200

vllm_outputs_no_seed_1 = vllm_model.generate(example_prompts,
sampling_params)
vllm_outputs_seed_1_1 = vllm_model.generate(example_prompts,
sampling_params_seed_1)
vllm_outputs_seed_2_1 = vllm_model.generate(example_prompts,
sampling_params_seed_2)
vllm_outputs_no_seed_2 = vllm_model.generate(example_prompts,
sampling_params)
vllm_outputs_seed_1_2 = vllm_model.generate(example_prompts,
sampling_params_seed_1)
vllm_outputs_seed_2_2 = vllm_model.generate(example_prompts,
sampling_params_seed_2)

for output_a, output_b in combinations(
(vllm_outputs_no_seed_1, vllm_outputs_no_seed_2, vllm_outputs_seed_1_1,
vllm_outputs_seed_2_1), 2):
assert output_a != output_b

assert vllm_outputs_seed_1_1 == vllm_outputs_seed_1_2
assert vllm_outputs_seed_2_1 == vllm_outputs_seed_2_2
llm = vllm_model.model

for prompt in example_prompts:
for params in (
sampling_params,
sampling_params_seed_1,
sampling_params_seed_2,
sampling_params,
sampling_params_seed_1,
sampling_params_seed_2,
):
llm._add_request(
prompt=prompt,
prompt_token_ids=None,
sampling_params=params,
)

results = llm._run_engine(use_tqdm=False)
all_outputs = [[out.token_ids for out in output.outputs]
for output in results]

for i in range(0, len(example_prompts), 6):
outputs = all_outputs[i:i + 6]

# verify all non-seeded requests differ
for output_a, output_b in combinations(
(outputs[0], outputs[1], outputs[2], outputs[3]),
2,
):
assert output_a != output_b

# verify requests with the same seed match
assert outputs[1] == outputs[4]
assert outputs[2] == outputs[5]

0 comments on commit 1a774fd

Please sign in to comment.