diff --git a/tests/samplers/test_logits_processor.py b/tests/samplers/test_logits_processor.py
new file mode 100644
index 0000000000000..3788e9e9752ff
--- /dev/null
+++ b/tests/samplers/test_logits_processor.py
@@ -0,0 +1,62 @@
+import pytest
+import torch
+
+from vllm import SamplingParams
+
+MODELS = ["facebook/opt-125m"]
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+def test_logits_processor_force_generate(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+) -> None:
+    vllm_model = vllm_runner(model, dtype=dtype)
+    tokenizer = vllm_model.model.get_tokenizer()
+    repeat_times = 2
+    enforced_answers = " vLLM"
+    vllm_token_ids = tokenizer.encode(enforced_answers,
+                                      add_special_tokens=False)
+    max_tokens = len(vllm_token_ids) * repeat_times
+
+    def pick_vllm(token_ids, logits):
+        token_id = vllm_token_ids[len(token_ids) % len(vllm_token_ids)]
+        logits[token_id] = torch.finfo(logits.dtype).max
+        return logits
+
+    params_with_logprobs = SamplingParams(
+        logits_processors=[pick_vllm],
+        prompt_logprobs=3,
+        max_tokens=max_tokens,
+    )
+
+    # test logits_processors when prompt_logprobs is not None
+    vllm_model.model._add_request(
+        prompt=example_prompts[0],
+        sampling_params=params_with_logprobs,
+        prompt_token_ids=None,
+    )
+
+    # test prompt_logprobs is not None
+    vllm_model.model._add_request(
+        prompt=example_prompts[1],
+        sampling_params=SamplingParams(
+            prompt_logprobs=3,
+            max_tokens=max_tokens,
+        ),
+        prompt_token_ids=None,
+    )
+
+    # test grouped requests
+    vllm_model.model._add_request(
+        prompt=example_prompts[2],
+        sampling_params=SamplingParams(max_tokens=max_tokens),
+        prompt_token_ids=None,
+    )
+
+    outputs = vllm_model.model._run_engine(False)
+
+    assert outputs[0].outputs[0].text == enforced_answers * repeat_times
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 28e8f6bb7e638..ec531f79ced52 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -86,8 +86,16 @@ def _apply_logits_processors(
 ) -> torch.Tensor:
     logits_row_idx = 0
     found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
+        seq_ids, sampling_params = seq_group
         logits_processors = sampling_params.logits_processors
+        # handle prompt_logprobs by skipping rows in logits added for
+        # the prompt tokens (prompt logprobs are not processed)
+        if (i < sampling_metadata.num_prompts
+                and sampling_params.prompt_logprobs is not None):
+            assert len(seq_ids) == 1
+            logits_row_idx += sampling_metadata.prompt_lens[i] - 1
+
         if logits_processors:
             found_logits_processors = True
             for seq_id in seq_ids:
@@ -100,5 +108,6 @@ def _apply_logits_processors(
         else:
             logits_row_idx += len(seq_ids)
     if found_logits_processors:
+        # verifies that no rows in logits were missed unexpectedly
         assert logits_row_idx == logits.shape[0]
     return logits
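
For reference, a minimal sketch of how a per-request logits processor such as pick_vllm is used through vLLM's public offline API (LLM.generate), rather than the internal _add_request/_run_engine helpers the test exercises; the model, prompt string, and printed expectation are illustrative only:

    import torch
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", dtype="half")
    tokenizer = llm.get_tokenizer()
    forced_ids = tokenizer.encode(" vLLM", add_special_tokens=False)

    def pick_vllm(token_ids, logits):
        # force the next token to cycle through the " vLLM" token ids
        logits[forced_ids[len(token_ids) % len(forced_ids)]] = torch.finfo(
            logits.dtype).max
        return logits

    params = SamplingParams(logits_processors=[pick_vllm],
                            prompt_logprobs=3,
                            max_tokens=len(forced_ids) * 2)
    # hypothetical prompt; any text works since the processor overrides sampling
    outputs = llm.generate(["Hello, my name is"], params)
    print(outputs[0].outputs[0].text)  # expected: " vLLM vLLM"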