Skip to content

Commit

Permalink
Added logits processor API to sampling params (vllm-project#1469)
Browse files Browse the repository at this point in the history
  • Loading branch information
noamgat authored Nov 3, 2023
1 parent 54ca1ba commit 555bdcc
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 2 deletions.
34 changes: 34 additions & 0 deletions tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,37 @@ def test_sampler_mixed(seed: int):
continue
for nth_output in sequence_output.samples:
assert nth_output.output_token in expected_tokens


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
    """Check that a user-supplied logits processor steers greedy sampling.

    Every request in the batch carries the same processor, which pushes one
    logit to +inf; with temperature 0 the sampler is expected to emit exactly
    that token.
    """
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
    input_tensor, _, sampler, worker = _prepare_test(batch_size)

    def pick_ith(token_ids, logits):
        # Boost the token whose id equals the number of token ids the
        # processor was handed, so the forced token is fully predictable.
        logits[len(token_ids)] = float("inf")
        return logits

    # One single-sequence prompt group per batch entry, all sharing the
    # processor above.
    seq_group_metadata_list = [
        SequenceGroupMetadata(
            request_id=f"test_{request_idx}",
            is_prompt=True,
            seq_data={0: SequenceData([1, 2, 3])},
            sampling_params=SamplingParams(temperature=0,
                                           logits_processors=[pick_ith]),
            block_tables={0: [1]},
        ) for request_idx in range(batch_size)
    ]

    _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list)
    sampler_output = sampler(embedding=None,
                             hidden_states=input_tensor,
                             input_metadata=input_metadata)

    # NOTE(review): the expected token is compared against the sample's
    # position within its group; presumably each group yields one sample, so
    # this amounts to expecting token 0 - confirm against the sampler.
    for sequence_output in sampler_output:
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx
24 changes: 24 additions & 0 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def forward(
logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size)

# Apply logits processors (if any).
logits = _apply_logits_processors(logits, input_metadata)
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
Expand Down Expand Up @@ -155,6 +157,28 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
return output_tokens


def _apply_logits_processors(logits: torch.Tensor,
                             input_metadata: InputMetadata) -> torch.Tensor:
    """Run each sequence group's logits processors over its logits rows.

    Walks ``input_metadata.seq_groups`` in order, assuming ``logits`` holds
    one row per sequence id in that same order. For a group whose sampling
    params carry processors, every row is passed through the processors in
    list order (each receives the sequence's ``output_token_ids`` and the
    current row) and the result is written back into ``logits``. Groups
    without processors only advance the row cursor.

    Args:
        logits: Tensor indexed by row along its first dimension; modified
            in place.
        input_metadata: Provides ``seq_groups`` (pairs of sequence ids and
            sampling params) and ``seq_data`` (keyed by sequence id).

    Returns:
        The same ``logits`` object that was passed in.

    Raises:
        AssertionError: If at least one processor list was applied and the
            number of rows visited differs from ``logits.shape[0]``.
    """
    row = 0
    any_processor_applied = False
    for seq_ids, sampling_params in input_metadata.seq_groups:
        processors = sampling_params.logits_processors
        if not processors:
            # Nothing to apply: skip over this group's rows.
            row += len(seq_ids)
            continue

        any_processor_applied = True
        for seq_id in seq_ids:
            current = logits[row]
            generated = input_metadata.seq_data[seq_id].output_token_ids
            # Chain the processors: each one sees the previous one's output.
            for process in processors:
                current = process(generated, current)
            logits[row] = current
            row += 1

    # The row-count invariant is only enforced when a processor actually ran.
    assert not any_processor_applied or row == logits.shape[0]
    return logits


def _apply_penalties(
logits: torch.Tensor,
output_tokens: List[List[int]],
Expand Down
14 changes: 12 additions & 2 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Sampling parameters for text generation."""
from enum import IntEnum
from functools import cached_property
from typing import List, Optional, Union
from typing import Callable, List, Optional, Union
import torch

_SAMPLING_EPS = 1e-5

Expand All @@ -12,6 +13,12 @@ class SamplingType(IntEnum):
BEAM = 2


# Call signature: processor(token_ids: List[int], logits: torch.Tensor)
# -> torch.Tensor. The returned tensor is what the caller uses afterwards.
LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a list of previously generated
tokens and a tensor of the logits for the next token, and returns a modified
tensor of logits to sample from."""


class SamplingParams:
"""Sampling parameters for text generation.
Expand Down Expand Up @@ -73,6 +80,8 @@ class SamplingParams:
skip_special_tokens: Whether to skip special tokens in the output.
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
"""

def __init__(
Expand All @@ -96,6 +105,7 @@ def __init__(
prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[LogitsProcessor]] = None,
) -> None:
self.n = n
self.best_of = best_of if best_of is not None else n
Expand Down Expand Up @@ -124,7 +134,7 @@ def __init__(
self.prompt_logprobs = prompt_logprobs
self.skip_special_tokens = skip_special_tokens
self.spaces_between_special_tokens = spaces_between_special_tokens

self.logits_processors = logits_processors
self._verify_args()
if self.use_beam_search:
self._verify_beam_search()
Expand Down

0 comments on commit 555bdcc

Please sign in to comment.