diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 3310f190e87a7..65b14fa340ee6 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -10,6 +10,7 @@ from vllm.model_executor.utils import set_random_seed from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.worker.model_runner import ModelRunner +from vllm.utils import Counter class MockLogitsSampler(Sampler): @@ -25,9 +26,8 @@ def forward(self, *args, **kwargs): def _prepare_test( batch_size: int ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]: - vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) - fake_logits = torch.full((batch_size, vocab_size), + fake_logits = torch.full((batch_size, VOCAB_SIZE), 1e-2, dtype=input_tensor.dtype) sampler = MockLogitsSampler(fake_logits) @@ -35,6 +35,7 @@ def _prepare_test( return input_tensor, fake_logits, sampler, model_runner +VOCAB_SIZE = 32000 RANDOM_SEEDS = list(range(128)) CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) @@ -184,6 +185,225 @@ def test_sampler_all_beam(seed: int, device: str): del model_runner +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_min_tokens_penalty(seed: int, device: str): + seq_id_counter = Counter(start=random.randint(0, 100)) + set_random_seed(seed) + torch.set_default_device(device) + + def create_sampling_params(min_tokens, + eos_token_id=0, + stop_token_ids=None): + sampling_params = SamplingParams( + min_tokens=min_tokens, + max_tokens=9999, # keep higher than max of min_tokens + stop_token_ids=stop_token_ids, + ) + sampling_params.eos_token_id = eos_token_id + return sampling_params + + def create_sequence_data(num_input=3, num_generated=0): + seq_data = SequenceData( + random.choices(range(0, VOCAB_SIZE), k=num_input)) + if num_generated > 0: + seq_data.output_token_ids = random.choices(range(0, VOCAB_SIZE), + k=num_generated) + return seq_data + + def generate_test_case(): + # generate multiple seq groups but limit total batch size + batch_size = random.randint(1, 128) + + expected_penalization = [] + sequence_metadata_list = [] + while batch_size > 0: + # 20% chance to generate prompt seq group with single sequence + is_prompt = random.random() < 0.2 + num_seqs = 1 if is_prompt else random.randint(1, batch_size) + + eos_token_id = random.randint(0, VOCAB_SIZE - 1) + min_tokens = random.randint(0, 50) + num_stop_tokens = random.randint(0, 8) + if num_stop_tokens > 0: + stop_token_ids = random.choices(range(0, VOCAB_SIZE - 1), + k=num_stop_tokens) + else: + stop_token_ids = None + + sampling_params = create_sampling_params( + min_tokens=min_tokens, + eos_token_id=eos_token_id, + stop_token_ids=stop_token_ids) + + seq_data = {} + seq_group_penalization = [] + for _ in range(num_seqs): + num_input = random.randint(1, 100) + num_generated = random.randint(1, 100) if not is_prompt else 0 + seq_data[next(seq_id_counter)] = create_sequence_data( + num_input=num_input, num_generated=num_generated) + seq_group_penalization.append(num_generated < min_tokens) + + expected_penalization.extend(seq_group_penalization) + sequence_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{batch_size}", + is_prompt=is_prompt, + seq_data=seq_data, + sampling_params=sampling_params, + block_tables={}, + )) + batch_size -= num_seqs + + return { + "expected_penalization": expected_penalization, + 
"seq_group_metadata_list": sequence_metadata_list, + } + + # define some explicit test cases for edge case behavior + prompt_without_penalization = { + "expected_penalization": [False], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_1", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(), + }, + sampling_params=create_sampling_params(0), + block_tables={}, + ), + ] + } + + prompt_with_penalization = { + "expected_penalization": [True], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_1", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(), + }, + sampling_params=create_sampling_params(1), + block_tables={}, + ), + ] + } + + stop_penalizing_after_min_tokens = { + "expected_penalization": [False], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_1", + is_prompt=False, + seq_data={ + next(seq_id_counter): + create_sequence_data(num_generated=1), + }, + sampling_params=create_sampling_params(1), + block_tables={}, + ) + ] + } + + stop_token_ids = [42, 99, 42, 0] # intentional duplication + simple_combination = { + "expected_penalization": [True, False, False], + "seq_group_metadata_list": [ + SequenceGroupMetadata( + request_id="test_1", + is_prompt=False, + seq_data={ + next(seq_id_counter): + create_sequence_data(num_generated=1), + next(seq_id_counter): + create_sequence_data(num_generated=100), + }, + sampling_params=create_sampling_params( + 2, stop_token_ids=stop_token_ids), + block_tables={}, + ), + SequenceGroupMetadata( + request_id="test_2", + is_prompt=True, + seq_data={ + next(seq_id_counter): create_sequence_data(), + }, + sampling_params=create_sampling_params( + 0, stop_token_ids=stop_token_ids), + block_tables={}, + ) + ] + } + + if seed == 0: + test_cases = [ + prompt_without_penalization, + prompt_with_penalization, + stop_penalizing_after_min_tokens, + simple_combination, + ] + else: + test_cases = [generate_test_case()] + + def run_test_case(*, + expected_penalization=None, + seq_group_metadata_list=None): + assert expected_penalization, "Invalid test case" + assert seq_group_metadata_list, "Invalid test case" + + batch_size = 0 + prompt_lens = [] + sampling_params_per_seq = [] + for sgm in seq_group_metadata_list: + num_seqs = len(sgm.seq_data) + batch_size += num_seqs + sampling_params = sgm.sampling_params + for seq_id in sgm.seq_data: + prompt_lens.append(sgm.seq_data[seq_id].get_prompt_len()) + sampling_params_per_seq.append(sampling_params) + + _, fake_logits, sampler, model_runner = _prepare_test(batch_size) + sampling_metadata = model_runner._prepare_sample( + seq_group_metadata_list, + prompt_lens=prompt_lens, + subquery_lens=prompt_lens) + # the logits tensor is modified in-place by the sampler + _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata) + + for logits_idx, (should_penalize, sampling_params) in enumerate( + zip(expected_penalization, sampling_params_per_seq)): + + tokens_to_check = [sampling_params.eos_token_id] + if sampling_params.stop_token_ids: + tokens_to_check.extend(sampling_params.stop_token_ids) + tokens_to_check = set(tokens_to_check) + + if should_penalize: + for token_id in tokens_to_check: + assert fake_logits[logits_idx, token_id] == -float( + 'inf' + ), f"Expected token {token_id} for logits row {logits_idx}" + " to be penalized" + # no other tokens should be set to -inf + assert torch.count_nonzero( + fake_logits[logits_idx, :] == -float('inf')) == len( + tokens_to_check + ), f"Expected only 
{len(tokens_to_check)} to be penalized" + else: + # no tokens should be set to -inf + assert torch.count_nonzero( + fake_logits[logits_idx, :] == + -float('inf')) == 0, "No tokens should have been penalized" + + del model_runner + + for test_case in test_cases: + run_test_case(**test_case) + + @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_sampler_mixed(seed: int, device: str): diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f9638d1101906..1984b94024a16 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -282,6 +282,9 @@ def add_request( # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects sampling_params = sampling_params.clone() + # inject the eos token id into the sampling_params to support min_tokens + # processing + sampling_params.eos_token_id = seq.eos_token_id # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, @@ -713,6 +716,21 @@ def _get_stats(self, def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: """Stop the finished sequences.""" + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the minimum number of tokens has been generated yet; + # skip the stop string/token checks if not + if seq.get_output_len() < sampling_params.min_tokens: + return + for stop_str in sampling_params.stop: if seq.output_text.endswith(stop_str): self._finalize_sequence(seq, sampling_params, stop_str) @@ -725,16 +743,6 @@ def _check_stop(self, seq: Sequence, seq.status = SequenceStatus.FINISHED_STOPPED return - # Check if the sequence has reached max_model_len. - if seq.get_len() > self.scheduler_config.max_model_len: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - - # Check if the sequence has reached max_tokens. - if seq.get_output_len() == sampling_params.max_tokens: - seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED - return - # Check if the sequence has generated the EOS token. 
if ((not sampling_params.ignore_eos) and seq.get_last_token_id() == seq.eos_token_id): diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f1fae1f825f97..965313e29f8d4 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -88,6 +88,7 @@ class ChatCompletionRequest(BaseModel): length_penalty: Optional[float] = 1.0 early_stopping: Optional[bool] = False ignore_eos: Optional[bool] = False + min_tokens: Optional[int] = 0 stop_token_ids: Optional[List[int]] = Field(default_factory=list) skip_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True @@ -165,6 +166,7 @@ def logit_bias_logits_processor( stop=self.stop, stop_token_ids=self.stop_token_ids, max_tokens=self.max_tokens, + min_tokens=self.min_tokens, logprobs=self.top_logprobs if self.logprobs else None, prompt_logprobs=self.top_logprobs if self.echo else None, best_of=self.best_of, @@ -224,6 +226,7 @@ class CompletionRequest(BaseModel): early_stopping: Optional[bool] = False stop_token_ids: Optional[List[int]] = Field(default_factory=list) ignore_eos: Optional[bool] = False + min_tokens: Optional[int] = 0 skip_special_tokens: Optional[bool] = True spaces_between_special_tokens: Optional[bool] = True # doc: end-completion-sampling-params @@ -296,6 +299,7 @@ def logit_bias_logits_processor( stop_token_ids=self.stop_token_ids, ignore_eos=self.ignore_eos, max_tokens=self.max_tokens if not echo_without_generation else 1, + min_tokens=self.min_tokens, logprobs=self.logprobs, use_beam_search=self.use_beam_search, early_stopping=self.early_stopping, diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 162d2abb292aa..d07527304962d 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,4 +1,5 @@ """A layer that samples the next tokens from the model's outputs.""" +import itertools from typing import Dict, List, Optional, Tuple import torch @@ -36,6 +37,10 @@ def forward( assert logits is not None _, vocab_size = logits.shape + # Apply min_tokens penalty which sets stop tokens to -inf if min_tokens + # have not been generated yet + logits = _apply_min_tokens_penalty(logits, sampling_metadata) + # Prepare sampling tensors with pinned memory to avoid blocking. 
(sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) = SamplingTensors.from_sampling_metadata( @@ -94,6 +99,42 @@ def _get_bin_counts_and_mask( return bin_counts, mask +def _apply_min_tokens_penalty( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + # list of indices in logits that will be set to -inf + logits_to_penalize = [] + start_idx = 0 + for seq_ids, sampling_params in sampling_metadata.seq_groups: + min_tokens = sampling_params.min_tokens + if min_tokens > 0: + seqs_to_penalize = [] + for i, seq_id in enumerate(seq_ids): + seq_data = sampling_metadata.seq_data[seq_id] + if len(seq_data.output_token_ids) < min_tokens: + seqs_to_penalize.append(i) + + if seqs_to_penalize: + # convert to the index into logits + seqs_to_penalize = [start_idx + i for i in seqs_to_penalize] + # use set() to remove any duplicates + token_ids_to_penalize = set(sampling_params.stop_token_ids + + [sampling_params.eos_token_id]) + # itertools.product pairs each seq index with every token id + logits_to_penalize.extend( + itertools.product(seqs_to_penalize, token_ids_to_penalize)) + + start_idx += len(seq_ids) + + if logits_to_penalize: + # use zip and * to group indices along each dimension + # eg. [ (1,2), (1,3), (5,6) ] -> ( (1,1,5), (2,3,6) ) + logits[tuple(zip(*logits_to_penalize))] = -float("inf") + + return logits + + def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, output_tokens_tensor: torch.Tensor, presence_penalties: torch.Tensor, diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 4aa158878fb96..6f81ee31f84dd 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -79,6 +79,8 @@ class SamplingParams: ignore_eos: Whether to ignore the EOS token and continue generating tokens after the EOS token is generated. max_tokens: Maximum number of tokens to generate per output sequence. + min_tokens: Minimum number of tokens to generate per output sequence + before EOS or stop_token_ids can be generated logprobs: Number of log probabilities to return per output token. 
Note that the implementation follows the OpenAI API: The return result includes the log probabilities on the `logprobs` most likely @@ -113,6 +115,7 @@ def __init__( include_stop_str_in_output: bool = False, ignore_eos: bool = False, max_tokens: Optional[int] = 16, + min_tokens: int = 0, logprobs: Optional[int] = None, prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, @@ -144,6 +147,7 @@ def __init__( self.stop_token_ids = list(stop_token_ids) self.ignore_eos = ignore_eos self.max_tokens = max_tokens + self.min_tokens = min_tokens self.logprobs = logprobs self.prompt_logprobs = prompt_logprobs self.skip_special_tokens = skip_special_tokens @@ -161,6 +165,8 @@ def __init__( self.top_k = -1 self.min_p = 0.0 self._verify_greedy_sampling() + # injected by the engine + self.eos_token_id = None def _verify_args(self) -> None: if self.n < 1: @@ -191,6 +197,13 @@ def _verify_args(self) -> None: if self.max_tokens is not None and self.max_tokens < 1: raise ValueError( f"max_tokens must be at least 1, got {self.max_tokens}.") + if self.min_tokens < 0: + raise ValueError(f"min_tokens must be greater than or equal to 0, " + f"got {self.min_tokens}.") + if self.max_tokens is not None and self.min_tokens > self.max_tokens: + raise ValueError( + f"min_tokens must be less than or equal to " + f"max_tokens={self.max_tokens}, got {self.min_tokens}.") if self.logprobs is not None and self.logprobs < 0: raise ValueError( f"logprobs must be non-negative, got {self.logprobs}.") @@ -272,6 +285,7 @@ def __repr__(self) -> str: f"include_stop_str_in_output={self.include_stop_str_in_output}, " f"ignore_eos={self.ignore_eos}, " f"max_tokens={self.max_tokens}, " + f"min_tokens={self.min_tokens}, " f"logprobs={self.logprobs}, " f"prompt_logprobs={self.prompt_logprobs}, " f"skip_special_tokens={self.skip_special_tokens}, "
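
Addendum (not part of the diff above): a minimal usage sketch of the new min_tokens parameter via vLLM's offline LLM entrypoint. The model name and prompt are placeholders chosen for illustration; note that min_tokens is also accepted by the OpenAI-compatible ChatCompletionRequest and CompletionRequest updated in vllm/entrypoints/openai/protocol.py.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration

# Until at least `min_tokens` tokens have been generated, the sampler masks
# the EOS token and any stop_token_ids to -inf, and _check_stop() skips the
# stop string/token checks. _verify_args() rejects min_tokens < 0 and
# min_tokens > max_tokens.
params = SamplingParams(min_tokens=16, max_tokens=64)

outputs = llm.generate(["Summarize this change in one sentence."], params)
print(outputs[0].outputs[0].text)

And a standalone illustration (toy tensors only, not vLLM internals) of the advanced-indexing trick used in _apply_min_tokens_penalty:

import itertools

import torch

logits = torch.zeros(4, 8)
rows_to_penalize = [1, 3]        # sequences still below min_tokens
token_ids_to_penalize = {0, 5}   # eos_token_id plus de-duplicated stop_token_ids

# itertools.product pairs every penalized row with every stop token id, and
# zip(*...) regroups the (row, col) pairs into one index tuple per dimension
# for a single advanced-indexing assignment.
pairs = list(itertools.product(rows_to_penalize, token_ids_to_penalize))
logits[tuple(zip(*pairs))] = -float("inf")
print(logits)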