Skip to content

Commit

Permalink
feat: implement the min_tokens sampling parameter (#3124)
Browse files Browse the repository at this point in the history
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
tjohnson31415 and njhill authored Mar 25, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 819924e commit c13ad1b
Showing 5 changed files with 299 additions and 12 deletions.
224 changes: 222 additions & 2 deletions tests/samplers/test_sampler.py
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
from vllm.utils import Counter


class MockLogitsSampler(Sampler):
@@ -25,16 +26,16 @@ def forward(self, *args, **kwargs):
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
fake_logits = torch.full((batch_size, VOCAB_SIZE),
1e-2,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None)
return input_tensor, fake_logits, sampler, model_runner


VOCAB_SIZE = 32000
RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -184,6 +185,225 @@ def test_sampler_all_beam(seed: int, device: str):
del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
seq_id_counter = Counter(start=random.randint(0, 100))
set_random_seed(seed)
torch.set_default_device(device)

def create_sampling_params(min_tokens,
eos_token_id=0,
stop_token_ids=None):
sampling_params = SamplingParams(
min_tokens=min_tokens,
max_tokens=9999, # keep higher than max of min_tokens
stop_token_ids=stop_token_ids,
)
sampling_params.eos_token_id = eos_token_id
return sampling_params

def create_sequence_data(num_input=3, num_generated=0):
seq_data = SequenceData(
random.choices(range(0, VOCAB_SIZE), k=num_input))
if num_generated > 0:
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
k=num_generated)
return seq_data

def generate_test_case():
# generate multiple seq groups but limit total batch size
batch_size = random.randint(1, 128)

expected_penalization = []
sequence_metadata_list = []
while batch_size > 0:
# 20% chance to generate prompt seq group with single sequence
is_prompt = random.random() < 0.2
num_seqs = 1 if is_prompt else random.randint(1, batch_size)

eos_token_id = random.randint(0, VOCAB_SIZE - 1)
min_tokens = random.randint(0, 50)
num_stop_tokens = random.randint(0, 8)
if num_stop_tokens > 0:
stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
k=num_stop_tokens)
else:
stop_token_ids = None

sampling_params = create_sampling_params(
min_tokens=min_tokens,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids)

seq_data = {}
seq_group_penalization = []
for _ in range(num_seqs):
num_input = random.randint(1, 100)
num_generated = random.randint(1, 100) if not is_prompt else 0
seq_data[next(seq_id_counter)] = create_sequence_data(
num_input=num_input, num_generated=num_generated)
seq_group_penalization.append(num_generated < min_tokens)

expected_penalization.extend(seq_group_penalization)
sequence_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{batch_size}",
is_prompt=is_prompt,
seq_data=seq_data,
sampling_params=sampling_params,
block_tables={},
))
batch_size -= num_seqs

return {
"expected_penalization": expected_penalization,
"seq_group_metadata_list": sequence_metadata_list,
}

# define some explicit test cases for edge case behavior
prompt_without_penalization = {
"expected_penalization": [False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(0),
block_tables={},
),
]
}

prompt_with_penalization = {
"expected_penalization": [True],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(1),
block_tables={},
),
]
}

stop_penalizing_after_min_tokens = {
"expected_penalization": [False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=False,
seq_data={
next(seq_id_counter):
create_sequence_data(num_generated=1),
},
sampling_params=create_sampling_params(1),
block_tables={},
)
]
}

stop_token_ids = [42, 99, 42, 0] # intentional duplication
simple_combination = {
"expected_penalization": [True, False, False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=False,
seq_data={
next(seq_id_counter):
create_sequence_data(num_generated=1),
next(seq_id_counter):
create_sequence_data(num_generated=100),
},
sampling_params=create_sampling_params(
2, stop_token_ids=stop_token_ids),
block_tables={},
),
SequenceGroupMetadata(
request_id="test_2",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(
0, stop_token_ids=stop_token_ids),
block_tables={},
)
]
}

if seed == 0:
test_cases = [
prompt_without_penalization,
prompt_with_penalization,
stop_penalizing_after_min_tokens,
simple_combination,
]
else:
test_cases = [generate_test_case()]

def run_test_case(*,
expected_penalization=None,
seq_group_metadata_list=None):
assert expected_penalization, "Invalid test case"
assert seq_group_metadata_list, "Invalid test case"

batch_size = 0
prompt_lens = []
sampling_params_per_seq = []
for sgm in seq_group_metadata_list:
num_seqs = len(sgm.seq_data)
batch_size += num_seqs
sampling_params = sgm.sampling_params
for seq_id in sgm.seq_data:
prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
sampling_params_per_seq.append(sampling_params)

_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list,
prompt_lens=prompt_lens,
subquery_lens=prompt_lens)
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

for logits_idx, (should_penalize, sampling_params) in enumerate(
zip(expected_penalization, sampling_params_per_seq)):

tokens_to_check = [sampling_params.eos_token_id]
if sampling_params.stop_token_ids:
tokens_to_check.extend(sampling_params.stop_token_ids)
tokens_to_check = set(tokens_to_check)

if should_penalize:
for token_id in tokens_to_check:
assert fake_logits[logits_idx, token_id] == -float(
'inf'
), f"Expected token {token_id} for logits row {logits_idx}"
" to be penalized"
# no other tokens should be set to -inf
assert torch.count_nonzero(
fake_logits[logits_idx, :] == -float('inf')) == len(
tokens_to_check
), f"Expected only {len(tokens_to_check)} to be penalized"
else:
# no tokens should be set to -inf
assert torch.count_nonzero(
fake_logits[logits_idx, :] ==
-float('inf')) == 0, "No tokens should have been penalized"

del model_runner

for test_case in test_cases:
run_test_case(**test_case)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
28 changes: 18 additions & 10 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
@@ -282,6 +282,9 @@ def add_request(
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# inject the eos token id into the sampling_params to support min_tokens
# processing
sampling_params.eos_token_id = seq.eos_token_id

# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -713,6 +716,21 @@ def _get_stats(self,
def _check_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences."""
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return

for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
self._finalize_sequence(seq, sampling_params, stop_str)
@@ -725,16 +743,6 @@ def _check_stop(self, seq: Sequence,
seq.status = SequenceStatus.FINISHED_STOPPED
return

# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/protocol.py
Original file line number Diff line number Diff line change
@@ -88,6 +88,7 @@ class ChatCompletionRequest(BaseModel):
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
@@ -165,6 +166,7 @@ def logit_bias_logits_processor(
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of,
@@ -224,6 +226,7 @@ class CompletionRequest(BaseModel):
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
# doc: end-completion-sampling-params
@@ -296,6 +299,7 @@ def logit_bias_logits_processor(
stop_token_ids=self.stop_token_ids,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
41 changes: 41 additions & 0 deletions vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
from typing import Dict, List, Optional, Tuple

import torch
@@ -36,6 +37,10 @@ def forward(
assert logits is not None
_, vocab_size = logits.shape

# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
# have not been generated yet
logits = _apply_min_tokens_penalty(logits, sampling_metadata)

# Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -94,6 +99,42 @@ def _get_bin_counts_and_mask(
return bin_counts, mask


def _apply_min_tokens_penalty(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# list of indices in logits that will be set to -inf
logits_to_penalize = []
start_idx = 0
for seq_ids, sampling_params in sampling_metadata.seq_groups:
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
seqs_to_penalize = []
for i, seq_id in enumerate(seq_ids):
seq_data = sampling_metadata.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens:
seqs_to_penalize.append(i)

if seqs_to_penalize:
# convert to the index into logits
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
# use set() to remove any duplicates
token_ids_to_penalize = set(sampling_params.stop_token_ids +
[sampling_params.eos_token_id])
# itertools.product pairs each seq index with every token id
logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize))

start_idx += len(seq_ids)

if logits_to_penalize:
# use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf")

return logits


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor: torch.Tensor,
presence_penalties: torch.Tensor,
14 changes: 14 additions & 0 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
@@ -79,6 +79,8 @@ class SamplingParams:
ignore_eos: Whether to ignore the EOS token and continue generating
tokens after the EOS token is generated.
max_tokens: Maximum number of tokens to generate per output sequence.
min_tokens: Minimum number of tokens to generate per output sequence
before EOS or stop_token_ids can be generated
logprobs: Number of log probabilities to return per output token.
Note that the implementation follows the OpenAI API: The return
result includes the log probabilities on the `logprobs` most likely
@@ -113,6 +115,7 @@ def __init__(
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
min_tokens: int = 0,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
@@ -144,6 +147,7 @@ def __init__(
self.stop_token_ids = list(stop_token_ids)
self.ignore_eos = ignore_eos
self.max_tokens = max_tokens
self.min_tokens = min_tokens
self.logprobs = logprobs
self.prompt_logprobs = prompt_logprobs
self.skip_special_tokens = skip_special_tokens
@@ -161,6 +165,8 @@ def __init__(
self.top_k = -1
self.min_p = 0.0
self._verify_greedy_sampling()
# injected by the engine
self.eos_token_id = None

def _verify_args(self) -> None:
if self.n < 1:
@@ -191,6 +197,13 @@ def _verify_args(self) -> None:
if self.max_tokens is not None and self.max_tokens < 1:
raise ValueError(
f"max_tokens must be at least 1, got {self.max_tokens}.")
if self.min_tokens < 0:
raise ValueError(f"min_tokens must be greater than or equal to 0, "
f"got {self.min_tokens}.")
if self.max_tokens is not None and self.min_tokens > self.max_tokens:
raise ValueError(
f"min_tokens must be less than or equal to "
f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
if self.logprobs is not None and self.logprobs < 0:
raise ValueError(
f"logprobs must be non-negative, got {self.logprobs}.")
@@ -272,6 +285,7 @@ def __repr__(self) -> str:
f"include_stop_str_in_output={self.include_stop_str_in_output}, "
f"ignore_eos={self.ignore_eos}, "
f"max_tokens={self.max_tokens}, "
f"min_tokens={self.min_tokens}, "
f"logprobs={self.logprobs}, "
f"prompt_logprobs={self.prompt_logprobs}, "
f"skip_special_tokens={self.skip_special_tokens}, "

0 comments on commit c13ad1b

Please sign in to comment.