Skip to content

Commit

Permalink
[Bugfix] Fix logits processor when prompt_logprobs is not None (vllm-…
Browse files Browse the repository at this point in the history
  • Loading branch information
huyiwen authored Apr 10, 2024
1 parent d5272ac commit 9515d11
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 1 deletion.
62 changes: 62 additions & 0 deletions tests/samplers/test_logits_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pytest
import torch

from vllm import SamplingParams

# Model names supplied to the "model" parametrization of the test below.
MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_logits_processor_force_generate(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Check that a logits processor steers generation with prompt_logprobs.

    Three requests are queued in the same engine run:

    1. a request with both a logits processor and ``prompt_logprobs``,
    2. a request with ``prompt_logprobs`` only,
    3. a request with neither.

    Only the first request's text is asserted: it must equal the target
    string repeated ``cycles`` times. The other two requests are queued
    alongside it but their outputs are not inspected.
    """
    runner = vllm_runner(model, dtype=dtype)
    engine = runner.model

    cycles = 2
    target_text = " vLLM"
    target_ids = engine.get_tokenizer().encode(target_text,
                                               add_special_tokens=False)
    # Generate exactly as many tokens as `cycles` repetitions of the target.
    budget = cycles * len(target_ids)

    def force_target_token(token_ids, logits):
        # The number of tokens passed in selects, cyclically, which target
        # token to boost next; its logit is set to the dtype's largest
        # finite value.
        position = len(token_ids) % len(target_ids)
        logits[target_ids[position]] = torch.finfo(logits.dtype).max
        return logits

    def enqueue(prompt, params):
        # Every request is submitted as text; no pre-tokenized ids are given.
        engine._add_request(
            prompt=prompt,
            sampling_params=params,
            prompt_token_ids=None,
        )

    forced_params = SamplingParams(
        logits_processors=[force_target_token],
        prompt_logprobs=3,
        max_tokens=budget,
    )

    # test logits_processors when prompt_logprobs is not None
    enqueue(example_prompts[0], forced_params)

    # test prompt_logprobs is not None
    enqueue(
        example_prompts[1],
        SamplingParams(
            prompt_logprobs=3,
            max_tokens=budget,
        ),
    )

    # test grouped requests
    enqueue(example_prompts[2], SamplingParams(max_tokens=budget))

    results = engine._run_engine(False)

    forced_text = results[0].outputs[0].text
    assert forced_text == target_text * cycles
11 changes: 10 additions & 1 deletion vllm/model_executor/layers/logits_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,16 @@ def _apply_logits_processors(
) -> torch.Tensor:
logits_row_idx = 0
found_logits_processors = False
for seq_ids, sampling_params in sampling_metadata.seq_groups:
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
logits_processors = sampling_params.logits_processors
# handle prompt_logprobs by skipping rows in logits added for
# the prompt tokens (prompt logprobs are not processed)
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
assert len(seq_ids) == 1
logits_row_idx += sampling_metadata.prompt_lens[i] - 1

if logits_processors:
found_logits_processors = True
for seq_id in seq_ids:
Expand All @@ -100,5 +108,6 @@ def _apply_logits_processors(
else:
logits_row_idx += len(seq_ids)
if found_logits_processors:
# verifies that no rows in logits were missed unexpectedly
assert logits_row_idx == logits.shape[0]
return logits

0 comments on commit 9515d11

Please sign in to comment.