feat: implement the min_tokens sampling parameter #3124

Merged (14 commits) on Mar 25, 2024
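For context, a minimal usage sketch of the new parameter through vLLM's offline LLM API. The model name, prompt, and parameter values below are placeholders, not taken from this PR; the min_tokens behavior and its interaction with stop come from the diff that follows.

from vllm import LLM, SamplingParams

# Placeholder model; any model served by vLLM works the same way.
llm = LLM(model="facebook/opt-125m")

# min_tokens suppresses EOS/stop tokens and skips stop-string checks until at
# least 16 tokens have been generated; max_tokens still caps the output.
params = SamplingParams(
    temperature=0.8,
    min_tokens=16,
    max_tokens=64,
    stop=["\n\n"],
)

outputs = llm.generate(["Write a limerick about inference servers:"], params)
print(outputs[0].outputs[0].text)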
224 changes: 222 additions & 2 deletions tests/samplers/test_sampler.py
@@ -11,6 +11,7 @@
from vllm.model_executor.utils import set_random_seed
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.worker.model_runner import ModelRunner
from vllm.utils import Counter


class MockLogitsSampler(Sampler):
@@ -26,16 +27,16 @@ def forward(self, *args, **kwargs):
def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
- vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
- fake_logits = torch.full((batch_size, vocab_size),
fake_logits = torch.full((batch_size, VOCAB_SIZE),
1e-2,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(fake_logits)
model_runner = ModelRunner(None, None, None, None, None)
return input_tensor, fake_logits, sampler, model_runner


VOCAB_SIZE = 32000
RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
@@ -185,6 +186,225 @@ def test_sampler_all_beam(seed: int, device: str):
del model_runner


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_min_tokens_penalty(seed: int, device: str):
seq_id_counter = Counter(start=random.randint(0, 100))
set_random_seed(seed)
torch.set_default_device(device)

def create_sampling_params(min_tokens,
eos_token_id=0,
stop_token_ids=None):
sampling_params = SamplingParams(
min_tokens=min_tokens,
max_tokens=9999, # keep higher than max of min_tokens
stop_token_ids=stop_token_ids,
)
sampling_params.eos_token_id = eos_token_id
return sampling_params

def create_sequence_data(num_input=3, num_generated=0):
seq_data = SequenceData(
random.choices(range(0, VOCAB_SIZE), k=num_input))
if num_generated > 0:
seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE),
k=num_generated)
return seq_data

def generate_test_case():
# generate multiple seq groups but limit total batch size
batch_size = random.randint(1, 128)

expected_penalization = []
sequence_metadata_list = []
while batch_size > 0:
# 20% chance to generate prompt seq group with single sequence
is_prompt = random.random() < 0.2
num_seqs = 1 if is_prompt else random.randint(1, batch_size)

eos_token_id = random.randint(0, VOCAB_SIZE - 1)
min_tokens = random.randint(0, 50)
num_stop_tokens = random.randint(0, 8)
if num_stop_tokens > 0:
stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1),
k=num_stop_tokens)
else:
stop_token_ids = None

sampling_params = create_sampling_params(
min_tokens=min_tokens,
eos_token_id=eos_token_id,
stop_token_ids=stop_token_ids)

seq_data = {}
seq_group_penalization = []
for _ in range(num_seqs):
num_input = random.randint(1, 100)
num_generated = random.randint(1, 100) if not is_prompt else 0
seq_data[next(seq_id_counter)] = create_sequence_data(
num_input=num_input, num_generated=num_generated)
seq_group_penalization.append(num_generated < min_tokens)

expected_penalization.extend(seq_group_penalization)
sequence_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{batch_size}",
is_prompt=is_prompt,
seq_data=seq_data,
sampling_params=sampling_params,
block_tables={},
))
batch_size -= num_seqs

return {
"expected_penalization": expected_penalization,
"seq_group_metadata_list": sequence_metadata_list,
}

# define some explicit test cases for edge case behavior
prompt_without_penalization = {
"expected_penalization": [False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(0),
block_tables={},
),
]
}

prompt_with_penalization = {
"expected_penalization": [True],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(1),
block_tables={},
),
]
}

stop_penalizing_after_min_tokens = {
"expected_penalization": [False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=False,
seq_data={
next(seq_id_counter):
create_sequence_data(num_generated=1),
},
sampling_params=create_sampling_params(1),
block_tables={},
)
]
}

stop_token_ids = [42, 99, 42, 0] # intentional duplication
simple_combination = {
"expected_penalization": [True, False, False],
"seq_group_metadata_list": [
SequenceGroupMetadata(
request_id="test_1",
is_prompt=False,
seq_data={
next(seq_id_counter):
create_sequence_data(num_generated=1),
next(seq_id_counter):
create_sequence_data(num_generated=100),
},
sampling_params=create_sampling_params(
2, stop_token_ids=stop_token_ids),
block_tables={},
),
SequenceGroupMetadata(
request_id="test_2",
is_prompt=True,
seq_data={
next(seq_id_counter): create_sequence_data(),
},
sampling_params=create_sampling_params(
0, stop_token_ids=stop_token_ids),
block_tables={},
)
]
}

if seed == 0:
test_cases = [
prompt_without_penalization,
prompt_with_penalization,
stop_penalizing_after_min_tokens,
simple_combination,
]
else:
test_cases = [generate_test_case()]

def run_test_case(*,
expected_penalization=None,
seq_group_metadata_list=None):
assert expected_penalization, "Invalid test case"
assert seq_group_metadata_list, "Invalid test case"

batch_size = 0
prompt_lens = []
sampling_params_per_seq = []
for sgm in seq_group_metadata_list:
num_seqs = len(sgm.seq_data)
batch_size += num_seqs
sampling_params = sgm.sampling_params
for seq_id in sgm.seq_data:
prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len())
sampling_params_per_seq.append(sampling_params)

_, fake_logits, sampler, model_runner = _prepare_test(batch_size)
sampling_metadata = model_runner._prepare_sample(
seq_group_metadata_list,
prompt_lens=prompt_lens,
subquery_lens=prompt_lens)
# the logits tensor is modified in-place by the sampler
_ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)

for logits_idx, (should_penalize, sampling_params) in enumerate(
zip(expected_penalization, sampling_params_per_seq)):

tokens_to_check = [sampling_params.eos_token_id]
if sampling_params.stop_token_ids:
tokens_to_check.extend(sampling_params.stop_token_ids)
tokens_to_check = set(tokens_to_check)

if should_penalize:
for token_id in tokens_to_check:
assert fake_logits[logits_idx, token_id] == -float('inf'), (
    f"Expected token {token_id} for logits row {logits_idx}"
    " to be penalized")
# no other tokens should be set to -inf
assert torch.count_nonzero(
fake_logits[logits_idx, :] == -float('inf')) == len(
tokens_to_check
), f"Expected only {len(tokens_to_check)} to be penalized"
else:
# no tokens should be set to -inf
assert torch.count_nonzero(
fake_logits[logits_idx, :] ==
-float('inf')) == 0, "No tokens should have been penalized"

del model_runner

for test_case in test_cases:
run_test_case(**test_case)


@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
28 changes: 18 additions & 10 deletions vllm/engine/llm_engine.py
@@ -277,6 +277,9 @@ def add_request(
# Defensive copy of SamplingParams, which are used by the sampler;
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# inject the eos token id into the sampling_params to support min_tokens
# processing
sampling_params.eos_token_id = seq.eos_token_id

# Create the sequence group.
seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -757,6 +760,21 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None:
def _check_stop(self, seq: Sequence,
sampling_params: SamplingParams) -> None:
"""Stop the finished sequences."""
# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the minimum number of tokens has been generated yet;
# skip the stop string/token checks if not
if seq.get_output_len() < sampling_params.min_tokens:
return

for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
self._finalize_sequence(seq, sampling_params, stop_str)
@@ -769,16 +787,6 @@ def _check_stop(self, seq: Sequence,
seq.status = SequenceStatus.FINISHED_STOPPED
return

- # Check if the sequence has reached max_model_len.
- if seq.get_len() > self.scheduler_config.max_model_len:
- seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
- return
-
- # Check if the sequence has reached max_tokens.
- if seq.get_output_len() == sampling_params.max_tokens:
- seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
- return

# Check if the sequence has generated the EOS token.
if ((not sampling_params.ignore_eos)
and seq.get_last_token_id() == seq.eos_token_id):
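The reordering above establishes a clear precedence: the max_model_len and max_tokens checks now run before the min_tokens gate, so length caps still apply even while stop strings, stop tokens, and EOS are being suppressed. A self-contained sketch that mirrors this ordering; all names below are hypothetical helpers for illustration, not vLLM APIs.

def check_stop_order(output_len: int, total_len: int, last_token: int, *,
                     max_model_len: int, max_tokens: int, min_tokens: int,
                     stop_token_ids: set) -> str:
    # 1. Length caps always win, even if min_tokens has not been reached.
    if total_len > max_model_len or output_len == max_tokens:
        return "finished (length capped)"
    # 2. Below min_tokens, skip the stop-string/stop-token/EOS checks entirely.
    if output_len < min_tokens:
        return "running"
    # 3. Only now may a stop token finish the sequence.
    if last_token in stop_token_ids:
        return "finished (stopped)"
    return "running"

# A stop token seen after 3 output tokens is ignored when min_tokens=5 ...
assert check_stop_order(3, 40, 2, max_model_len=2048, max_tokens=16,
                        min_tokens=5, stop_token_ids={2}) == "running"
# ... but max_tokens still ends generation regardless of min_tokens.
assert check_stop_order(16, 53, 7, max_model_len=2048, max_tokens=16,
                        min_tokens=32, stop_token_ids={2}).startswith("finished")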
4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -89,6 +89,7 @@ class ChatCompletionRequest(BaseModel):
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
@@ -166,6 +167,7 @@ def logit_bias_logits_processor(
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of,
@@ -225,6 +227,7 @@ class CompletionRequest(BaseModel):
early_stopping: Optional[bool] = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
# doc: end-completion-sampling-params
@@ -297,6 +300,7 @@ def logit_bias_logits_processor(
stop_token_ids=self.stop_token_ids,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
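Since both request schemas above now accept min_tokens, the field can be set directly on a request to the OpenAI-compatible server. A minimal sketch, assuming a locally running server on the default port and a placeholder model name:

import requests

# Endpoint and model name are assumptions for illustration; the min_tokens
# field itself comes from the CompletionRequest schema extended above.
response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",
        "prompt": "Write a haiku about GPUs:",
        "max_tokens": 64,
        "min_tokens": 16,  # do not honor stop/EOS before 16 generated tokens
        "stop": ["\n\n"],
    },
)
print(response.json()["choices"][0]["text"])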
41 changes: 41 additions & 0 deletions vllm/model_executor/layers/sampler.py
@@ -1,4 +1,5 @@
"""A layer that samples the next tokens from the model's outputs."""
import itertools
from typing import Dict, List, Optional, Tuple

import torch
@@ -36,6 +37,10 @@ def forward(
assert logits is not None
_, vocab_size = logits.shape

# Apply min_tokens penalty which sets stop tokens to -inf if min_tokens
# have not been generated yet
logits = _apply_min_tokens_penalty(logits, sampling_metadata)

# Prepare sampling tensors with pinned memory to avoid blocking.
(sampling_tensors, do_penalties, do_top_p_top_k,
do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -94,6 +99,42 @@ def _get_bin_counts_and_mask(
return bin_counts, mask


Review comment (Member):
This looks great @tjohnson31415. But I think technically the token_ids_to_penalize should be determined per seq_group (i.e. also within the loop), since they may be different per seq group. The indexing gets a bit trickier, but I think it might be possible with scatter_ with src=-torch.inf. Or else you could group the sequences that share the same list of tokens to penalize.

Reply (tjohnson31415, Contributor Author, Mar 4, 2024):
Heh, yup. Thanks for pointing that out. I still need to write some tests for this 😅. I pushed a fix that builds a list of coordinates to penalize within the loop, so the stop ids are per seq_group.

I was trying to use scatter initially, but couldn't figure out how to get it to work. In particular, scatter uses a rectangular tensor and doesn't seem to have a way to "skip" rows where we don't want to scatter into. So I think a gather-modify-scatter (where we gather across all sequences and stop token ids) would work, but we'd still need to index into the gathered tensor to set the -inf values.

def _apply_min_tokens_penalty(
logits: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
# list of indices in logits that will be set to -inf
logits_to_penalize = []
start_idx = 0
for seq_ids, sampling_params in sampling_metadata.seq_groups:
min_tokens = sampling_params.min_tokens
if min_tokens > 0:
seqs_to_penalize = []
for i, seq_id in enumerate(seq_ids):
seq_data = sampling_metadata.seq_data[seq_id]
if len(seq_data.output_token_ids) < min_tokens:
seqs_to_penalize.append(i)

if seqs_to_penalize:
# convert to the index into logits
seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
# use set() to remove any duplicates
token_ids_to_penalize = set(sampling_params.stop_token_ids +
[sampling_params.eos_token_id])
# itertools.product pairs each seq index with every token id
logits_to_penalize.extend(
itertools.product(seqs_to_penalize, token_ids_to_penalize))

start_idx += len(seq_ids)

if logits_to_penalize:
# use zip and * to group indices along each dimension
# eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) )
logits[tuple(zip(*logits_to_penalize))] = -float("inf")

return logits
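
To make the final indexing step concrete: the accumulated (sequence row, token id) pairs are regrouped per dimension and used as advanced indices into logits. A toy, self-contained sketch with made-up coordinates:

import torch

# (row, token_id) coordinates to penalize, e.g. collected across seq_groups.
logits_to_penalize = [(1, 2), (1, 3), (5, 6)]

logits = torch.zeros(8, 10)
# zip(*pairs) regroups the coordinates per dimension:
# [(1, 2), (1, 3), (5, 6)] -> ((1, 1, 5), (2, 3, 6))
logits[tuple(zip(*logits_to_penalize))] = -float("inf")

# Exactly the three chosen entries are now -inf.
assert torch.count_nonzero(torch.isinf(logits)) == 3
assert torch.isinf(logits[1, 2]) and torch.isinf(logits[5, 6])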


def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor: torch.Tensor,
presence_penalties: torch.Tensor,