From 947b794146aeae41ea17cbbe8bf9e53abc9c3f53 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 22 Sep 2023 17:48:04 -0700 Subject: [PATCH] [Sampler] Vectorized sampling (simplified) (#1048) Co-authored-by: Antoni Baum --- tests/samplers/test_sampler.py | 184 ++++++++++ vllm/model_executor/layers/sampler.py | 461 ++++++++++++++++---------- vllm/sampling_params.py | 16 + 3 files changed, 481 insertions(+), 180 deletions(-) create mode 100644 tests/samplers/test_sampler.py diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py new file mode 100644 index 0000000000000..a5f55d50fbb76 --- /dev/null +++ b/tests/samplers/test_sampler.py @@ -0,0 +1,184 @@ +import pytest +import random +from typing import Tuple +from unittest.mock import patch + +import torch + +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.utils import set_random_seed +from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.worker.worker import Worker + + +class MockLogitsSampler(Sampler): + + def __init__(self, vocab_size: int, fake_logits: torch.Tensor): + super().__init__(vocab_size=vocab_size) + self.fake_logits = fake_logits + + def forward(self, *args, **kwargs): + with patch("vllm.model_executor.layers.sampler._prune_hidden_states", + lambda x, y: x): + with patch("vllm.model_executor.layers.sampler._get_logits", + lambda *args, **kwargs: self.fake_logits): + return super().forward(*args, **kwargs) + + +def _prepare_test( + batch_size: int +) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, Worker]: + vocab_size = 32000 + input_tensor = torch.rand((batch_size, 1024), + device="cuda", + dtype=torch.float16) + fake_logits = torch.full((batch_size, vocab_size), + 1e-2, + device=input_tensor.device, + dtype=input_tensor.dtype) + sampler = MockLogitsSampler(32000, fake_logits) + worker = Worker(None, None, None) + worker.block_size = 16 + return input_tensor, fake_logits, sampler, worker + + +RANDOM_SEEDS = list(range(128)) + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_sampler_all_greedy(seed: int): + set_random_seed(seed) + batch_size = random.randint(1, 256) + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) + + seq_group_metadata_list = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=SamplingParams(temperature=0, ), + block_tables={0: [1]}, + )) + + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) + sampler_output = sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata) + expected = torch.argmax(fake_logits, dim=-1) + for i, sequence_output in enumerate(sampler_output): + for nth_output in sequence_output: + assert nth_output.output_token == expected[i].item() + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_sampler_all_random(seed: int): + set_random_seed(seed) + batch_size = random.randint(1, 256) + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) + + for i in range(batch_size): + fake_logits[i, i] = 1e2 + + seq_group_metadata_list = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=SamplingParams( + temperature=1.0, + n=random.randint(1, 10), + ), + block_tables={0: [1]}, + )) + + _, _, input_metadata = 
worker._prepare_inputs(seq_group_metadata_list) + sampler_output = sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata) + for i, sequence_output in enumerate(sampler_output): + for nth_output in sequence_output: + assert nth_output.output_token == i + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_sampler_all_beam(seed: int): + set_random_seed(seed) + batch_size = random.randint(1, 256) + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) + + seq_group_metadata_list = [] + for i in range(batch_size): + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=SamplingParams( + temperature=0, + best_of=2, + use_beam_search=True, + ), + block_tables={0: [1]}, + )) + + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) + sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata) + # no assertion here as I am not sure how to determine whether + # the outputs are expected - in other words, this just tests + # whether there are no exceptions in the sampler + # when handling an all-beam search case. + + +@pytest.mark.parametrize("seed", RANDOM_SEEDS) +def test_sampler_mixed(seed: int): + set_random_seed(seed) + batch_size = random.randint(1, 256) + input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size) + + seq_group_metadata_list = [] + expected_tokens = [] + for i in range(batch_size): + n = 1 + sampling_type = random.randint(0, 2) + if sampling_type == 0: + sampling_params = SamplingParams(temperature=0) + elif sampling_type == 1: + n = random.randint(1, 10) + sampling_params = SamplingParams( + temperature=random.random() + 0.1, + top_p=min(random.random() + 0.1, 1), + top_k=random.randint(0, 10) or -1, + n=n, + presence_penalty=random.randint(0, 1), + ) + else: + sampling_params = SamplingParams(temperature=0, + use_beam_search=True, + best_of=2) + for idx in range(n): + fake_logits[i, i + idx] = 1e2 + expected_tokens.append(i + idx) + seq_group_metadata_list.append( + SequenceGroupMetadata( + request_id=f"test_{i}", + is_prompt=True, + seq_data={0: SequenceData([1, 2, 3])}, + sampling_params=sampling_params, + block_tables={0: [1]}, + )) + + _, _, input_metadata = worker._prepare_inputs(seq_group_metadata_list) + sampler_output = sampler(embedding=None, + hidden_states=input_tensor, + input_metadata=input_metadata) + for i, sequence_output in enumerate(sampler_output): + if seq_group_metadata_list[i].sampling_params.use_beam_search: + continue + for nth_output in sequence_output: + assert nth_output.output_token in expected_tokens diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 013b44060d1a3..5179b8d94ae1f 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,15 +1,14 @@ """A layer that samples the next tokens from the model's outputs.""" -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Optional, Tuple -import numpy as np import torch import torch.nn as nn from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.parallel_utils.tensor_parallel import ( gather_from_tensor_model_parallel_region) -from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput, SequenceOutputs +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import SamplerOutput, 
SequenceData, SequenceOutputs _SAMPLING_EPS = 1e-5 @@ -44,12 +43,8 @@ def forward( hidden_states = _prune_hidden_states(hidden_states, input_metadata) # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = gather_from_tensor_model_parallel_region(logits) - # Remove paddings in vocab (if any). - logits = logits[:, :self.vocab_size] + logits = _get_logits(hidden_states, embedding, embedding_bias, + self.vocab_size) # Apply presence and frequency penalties. output_tokens = _get_output_tokens(input_metadata) @@ -59,7 +54,7 @@ def forward( assert len(presence_penalties) == logits.shape[0] assert len(frequency_penalties) == logits.shape[0] logits = _apply_penalties(logits, output_tokens, presence_penalties, - frequency_penalties, self.vocab_size) + frequency_penalties) # Apply temperature scaling. temperatures = _get_temperatures(input_metadata) @@ -90,19 +85,47 @@ def forward( return _sample(probs, logprobs, input_metadata) +def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor], + vocab_size: int) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = gather_from_tensor_model_parallel_region(logits) + # Remove paddings in vocab (if any). + logits = logits[:, :vocab_size] + return logits + + def _prune_hidden_states( hidden_states: torch.Tensor, input_metadata: InputMetadata, ) -> torch.Tensor: + last_token_indices = {t: [] for t in SamplingType} start_idx = 0 - last_token_indicies: List[int] = [] - for prompt_len in input_metadata.prompt_lens: - last_token_indicies.append(start_idx + prompt_len - 1) - start_idx += prompt_len - last_token_indicies.extend( - range(start_idx, start_idx + input_metadata.num_generation_tokens)) - return hidden_states.index_select( - 0, torch.tensor(last_token_indicies, device=hidden_states.device)) + for i, seq_group in enumerate(input_metadata.seq_groups): + seq_ids, sampling_params = seq_group + sampling_type = sampling_params.sampling_type + if i < input_metadata.num_prompts: + assert len(seq_ids) == 1, "Prompt input should have only one seq." + prompt_len = input_metadata.prompt_lens[i] + last_token_indices[sampling_type].append(start_idx + prompt_len - + 1) + start_idx += prompt_len + else: + num_seqs = len(seq_ids) + last_token_indices[sampling_type].extend( + range(start_idx, start_idx + num_seqs)) + start_idx += num_seqs + + all_last_token_indices = [] + for sampling_type in SamplingType: + all_last_token_indices.extend(last_token_indices[sampling_type]) + all_last_token_indices = torch.tensor(all_last_token_indices, + dtype=torch.long, + device=hidden_states.device) + return hidden_states.index_select(0, all_last_token_indices) def _get_penalties( @@ -149,11 +172,8 @@ def _apply_penalties( output_tokens: List[List[int]], presence_penalties: List[float], frequency_penalties: List[float], - vocab_size: int, ) -> torch.Tensor: - num_seqs = logits.shape[0] - # Collect the indices of sequences that have non-zero penalties. - indices = [] + num_seqs, vocab_size = logits.shape for i in range(num_seqs): if not output_tokens[i]: continue @@ -161,33 +181,40 @@ def _apply_penalties( f = frequency_penalties[i] if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS: continue - indices.append(i) - - # Return early if all sequences have zero penalties. 
- if not indices: + break + else: + # Return early if all sequences have zero penalties. return logits - bin_counts = [] - for i in indices: - bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size)) - bin_counts = np.stack(bin_counts, axis=0) - bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype, - device=logits.device) + max_output_len = max(len(tokens) for tokens in output_tokens) + padded_output_tokens = [ + tokens + [vocab_size] * (max_output_len - len(tokens)) + for tokens in output_tokens + ] + output_tokens_tensor = torch.tensor(padded_output_tokens, + dtype=torch.long, + device=logits.device) + + # Compute the bin counts for the output tokens. + # vocab_size + 1 for padding. + bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=logits.device) + bin_counts.scatter_add_(1, output_tokens_tensor, + torch.ones_like(output_tokens_tensor)) + bin_counts = bin_counts[:, :vocab_size] # Remove the padding bin. - frequency_penalties = [frequency_penalties[i] for i in indices] frequency_penalties = torch.tensor(frequency_penalties, dtype=logits.dtype, device=logits.device) - presence_penalties = [presence_penalties[i] for i in indices] presence_penalties = torch.tensor(presence_penalties, dtype=logits.dtype, device=logits.device) # We follow the definition in OpenAI API. # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts - presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype) - logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask + logits -= frequency_penalties.unsqueeze(dim=1) * bin_counts + logits -= presence_penalties.unsqueeze(dim=1) * (bin_counts > 0) return logits @@ -268,95 +295,154 @@ def _apply_top_p_top_k( def _get_topk_logprobs( logprobs: torch.Tensor, num_logprobs: Optional[int], -) -> Dict[int, float]: +) -> List[Dict[int, float]]: + num_seqs = logprobs.size(0) if num_logprobs is None or num_logprobs == 0: - return {} - - topk_logprobs, topk_ids = torch.topk(logprobs, num_logprobs) - if num_logprobs == 1: - topk_logprobs = [topk_logprobs.item()] - topk_ids = [topk_ids.item()] - else: - topk_logprobs = topk_logprobs.tolist() - topk_ids = topk_ids.tolist() - - token_to_logprob: Dict[int, float] = {} - for token_id, logprob in zip(topk_ids, topk_logprobs): - token_to_logprob[token_id] = logprob - return token_to_logprob + return [{} for _ in range(num_seqs)] + + all_topk_logprobs, all_topk_ids = torch.topk(logprobs, + num_logprobs, + dim=-1) + all_topk_logprobs = all_topk_logprobs.cpu() + all_topk_ids = all_topk_ids.cpu() + all_token_to_logprob = [] + for topk_logprobs, topk_ids in zip(all_topk_logprobs, all_topk_ids): + token_to_logprob: Dict[int, float] = {} + for token_id, logprob in zip(topk_ids, topk_logprobs): + token_to_logprob[token_id.item()] = logprob.item() + all_token_to_logprob.append(token_to_logprob) + return all_token_to_logprob + + +def _build_sequence_outputs( + parent_ids: List[int], + next_token_ids: List[int], + selected_token_logprobs: torch.Tensor, + parent_seq_ids: List[int], + parent_logprobs: torch.Tensor, + num_output_logprobs: Optional[int], +) -> List[SequenceOutputs]: + # Get top-k log probabilities for the next tokens. 
+ next_logprobs = _get_topk_logprobs(parent_logprobs, num_output_logprobs) + seq_outputs: List[SequenceOutputs] = [] + for parent_id, next_token_id, token_logprob in zip( + parent_ids, next_token_ids, selected_token_logprobs): + output_logprobs = next_logprobs[parent_id].copy() + output_logprobs[next_token_id] = token_logprob + seq_outputs.append( + SequenceOutputs(parent_seq_ids[parent_id], next_token_id, + output_logprobs)) + return seq_outputs -def _sample_from_prompt( - prob: torch.Tensor, - sampling_params: SamplingParams, -) -> List[int]: - if sampling_params.use_beam_search: - # Beam search. - beam_width = sampling_params.best_of - # Sample 2 * beam_width candidates to make sure that with high - # probability we can get `beam_width` candidates in addition to - # the finished sequences for the next iteration. See - # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563 - # for details. See also HF reference: - # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065 - _, next_token_ids = torch.topk(prob, 2 * beam_width) - next_token_ids = next_token_ids.tolist() - elif sampling_params.temperature < _SAMPLING_EPS: - # Greedy sampling. - assert sampling_params.best_of == 1 - next_token_id = torch.argmax(prob) - next_token_ids = [next_token_id.item()] - else: - # Random sampling. - # Sample `best_of` tokens for the prompt. - num_seqs = sampling_params.best_of - next_token_ids = torch.multinomial(prob, - num_samples=num_seqs, - replacement=True) - next_token_ids = next_token_ids.tolist() - return next_token_ids - - -def _sample_from_generation_tokens( - seq_ids: List[int], +def _greedy_sample( + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + logprobs: torch.Tensor, +) -> List[Tuple[List[int], List[int]]]: + samples = torch.argmax(logprobs, dim=-1).cpu() + sample_idx = 0 + results = [] + for seq_group in selected_seq_groups: + seq_ids, _ = seq_group + num_parent_seqs = len(seq_ids) + assert num_parent_seqs == 1, ( + "Greedy sampling should have only one seq.") + parent_ids = list(range(num_parent_seqs)) + next_token_ids = [samples[sample_idx].item()] + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + assert sample_idx == logprobs.size(0) + return results + + +def _random_sample( + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + is_prompts: List[bool], probs: torch.Tensor, +) -> List[Tuple[List[int], List[int]]]: + # Find the maximum best_of value of the prompt phase requests. + max_best_of = 1 + for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): + if is_prompt: + seq_ids, sampling_params = seq_group + max_best_of = max(max_best_of, sampling_params.best_of) + random_samples = torch.multinomial(probs, + num_samples=max_best_of, + replacement=True).cpu() + sample_idx = 0 + results = [] + for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): + seq_ids, sampling_params = seq_group + num_parent_seqs = len(seq_ids) + if is_prompt: + # Prompt phase. + assert num_parent_seqs == 1, ( + "Prompt input should have only one seq.") + parent_ids = [0] * sampling_params.best_of + next_token_ids = random_samples[ + sample_idx, :sampling_params.best_of].tolist() + else: + # Generation phase. 
+ parent_ids = list(range(num_parent_seqs)) + next_token_ids = random_samples[sample_idx:sample_idx + + num_parent_seqs, 0].tolist() + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + assert sample_idx == probs.size(0) + return results + + +def _beam_search_sample( + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + is_prompts: List[bool], + seq_data: Dict[int, SequenceData], logprobs: torch.Tensor, - seq_logprobs: List[float], - sampling_params: SamplingParams, -) -> Tuple[List[int], List[int]]: - # NOTE(woosuk): sampling_params.best_of can be greater than - # len(seq_ids) because some sequences in the group might have - # been already terminated. - if sampling_params.use_beam_search: - # Beam search. - # Add cumulative logprobs for the sequences in the group. - seq_logprobs = torch.tensor(seq_logprobs, - dtype=torch.float, - device=logprobs.device) - logprobs = logprobs + seq_logprobs.unsqueeze(dim=1) - - vocab_size = logprobs.size(-1) - beam_width = len(seq_ids) - _, topk_ids = torch.topk(logprobs.flatten(), 2 * beam_width) - topk_ids = topk_ids.tolist() - seq_idx = [i // vocab_size for i in topk_ids] - parent_seq_ids = [seq_ids[i] for i in seq_idx] - next_token_ids = [i % vocab_size for i in topk_ids] - elif sampling_params.temperature < _SAMPLING_EPS: - # Greedy sampling. - assert len(seq_ids) == 1 - next_token_id = torch.argmax(probs, dim=-1) - next_token_ids = [int(next_token_id.item())] - parent_seq_ids = seq_ids - else: - # Random sampling. - # Sample 1 token for each sequence in the group. - next_token_ids = torch.multinomial(probs, - num_samples=1, - replacement=True) - next_token_ids = next_token_ids.squeeze(dim=-1).tolist() - parent_seq_ids = seq_ids - return parent_seq_ids, next_token_ids +) -> List[Tuple[List[int], List[int]]]: + # We sample 2 * beam_width candidates to make sure that with high + # probability we can get `beam_width` candidates in addition to + # the finished sequences for the next iteration. See + # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563 + # for details. See also HF reference: + # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065 + # + # Note: Beam search is not vectorized, so its speed can be slower than + # other sampling methods. + sample_idx = 0 + results = [] + for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): + seq_ids, sampling_params = seq_group + num_parent_seqs = len(seq_ids) + beam_width = sampling_params.best_of + seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] + if is_prompt: + # Prompt phase. + assert num_parent_seqs == 1, ( + "Prompt input should have only one seq.") + parent_ids = [0] * (2 * beam_width) + _, next_token_ids = torch.topk(seq_group_logprobs[0], + 2 * beam_width) + next_token_ids = next_token_ids.tolist() + else: + # Generation phase. 
+ cumulative_logprobs = [ + seq_data[seq_id].cumulative_logprob for seq_id in seq_ids + ] + cumulative_logprobs = torch.tensor( + cumulative_logprobs, + dtype=torch.float, + device=seq_group_logprobs.device) + seq_group_logprobs = (seq_group_logprobs + + cumulative_logprobs.unsqueeze(dim=1)) + _, topk_ids = torch.topk(seq_group_logprobs.flatten(), + 2 * beam_width) + topk_ids = topk_ids.tolist() + vocab_size = seq_group_logprobs.size(-1) + parent_ids = [i // vocab_size for i in topk_ids] + next_token_ids = [i % vocab_size for i in topk_ids] + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + assert sample_idx == logprobs.size(0) + return results def _sample( @@ -364,65 +450,80 @@ def _sample( logprobs: torch.Tensor, input_metadata: InputMetadata, ) -> SamplerOutput: - seq_outputs: SamplerOutput = [] - - # TODO(woosuk): Optimize. - idx = 0 + categorized_seq_group_ids = {t: [] for t in SamplingType} + category_num_tokens = {t: 0 for t in SamplingType} for i, seq_group in enumerate(input_metadata.seq_groups): - seq_group_outputs: List[SequenceOutputs] = [] seq_ids, sampling_params = seq_group - if i < input_metadata.num_prompts: - # Generate the next tokens for a prompt input. - assert len(seq_ids) == 1, "Prompt input should have only one seq." - parent_seq_id = seq_ids[0] - prob = probs[idx] - logprob = logprobs[idx] - idx += 1 - - # Sample the next tokens. - next_token_ids = _sample_from_prompt(prob, sampling_params) - # Get top-k log probabilities for the next tokens. - next_logprobs = _get_topk_logprobs(logprob, - sampling_params.logprobs) - - # Build the output. - for next_token_id in next_token_ids: - output_logprobs = next_logprobs.copy() - output_logprobs[next_token_id] = logprob[next_token_id].item() - seq_group_outputs.append( - SequenceOutputs(parent_seq_id, next_token_id, - output_logprobs)) + sampling_type = sampling_params.sampling_type + categorized_seq_group_ids[sampling_type].append(i) + num_seqs = len(seq_ids) + category_num_tokens[sampling_type] += num_seqs + + seq_outputs_dict: Dict[int, List[SequenceOutputs]] = {} + category_start_idx = 0 + for sampling_type in SamplingType: + seq_group_ids = categorized_seq_group_ids[sampling_type] + seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids] + is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids] + num_tokens = category_num_tokens[sampling_type] + if num_tokens == 0: + continue + category_logprobs = logprobs[category_start_idx:category_start_idx + + num_tokens] + category_probs = probs[category_start_idx:category_start_idx + + num_tokens] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, category_logprobs) + elif sampling_type == SamplingType.RANDOM: + sample_results = _random_sample(seq_groups, is_prompts, + category_probs) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, is_prompts, + input_metadata.seq_data, + category_logprobs) else: - # Generate the next tokens for generation tokens. 
+ raise ValueError(f"Unsupported sampling type: {sampling_type}") + + # Batched query for logprobs of selected token + batched_logprobs_query_seq_indices: List[int] = [] + batched_logprobs_query_token_indices: List[int] = [] + sample_idx = 0 + for seq_group_id, seq_group, sample_result in zip( + seq_group_ids, seq_groups, sample_results): + seq_ids, sampling_params = seq_group + next_token_ids, parent_ids = sample_result num_parent_seqs = len(seq_ids) - prob = probs[idx:idx + num_parent_seqs] - logprob = logprobs[idx:idx + num_parent_seqs] - idx += num_parent_seqs - - # Sample the next tokens. - seq_logprobs = [ - input_metadata.seq_data[seq_id].cumulative_logprob - for seq_id in seq_ids - ] - parent_seq_ids, next_token_ids = _sample_from_generation_tokens( - seq_ids, prob, logprob, seq_logprobs, sampling_params) - - # Get top-k log probabilities for the next tokens. - next_logprobs: Dict[int, Dict[int, float]] = {} - for j, seq_id in enumerate(seq_ids): - next_logprobs[seq_id] = _get_topk_logprobs( - logprob[j], sampling_params.logprobs) - - # Build the output. - for parent_seq_id, next_token_id in zip(parent_seq_ids, - next_token_ids): - j = seq_ids.index(parent_seq_id) - output_logprobs = next_logprobs[parent_seq_id].copy() - output_logprobs[next_token_id] = logprob[j, - next_token_id].item() - seq_group_outputs.append( - SequenceOutputs(parent_seq_id, next_token_id, - output_logprobs)) - seq_outputs.append(seq_group_outputs) - - return seq_outputs + batched_logprobs_query_seq_indices.extend( + [sample_idx + parent_id for parent_id in parent_ids]) + batched_logprobs_query_token_indices.extend(next_token_ids) + sample_idx += num_parent_seqs + assert sample_idx == num_tokens + batched_logprobs_query_result = category_logprobs[[ + batched_logprobs_query_seq_indices, + batched_logprobs_query_token_indices + ]].tolist() + + # Build the sequence outputs. + sample_idx = 0 + result_idx = 0 + for seq_group_id, seq_group, sample_result in zip( + seq_group_ids, seq_groups, sample_results): + seq_ids, sampling_params = seq_group + next_token_ids, parent_ids = sample_result + num_results = len(next_token_ids) + num_parent_seqs = len(seq_ids) + parent_logprobs = category_logprobs[sample_idx:sample_idx + + num_parent_seqs] + selected_token_logprobs = batched_logprobs_query_result[ + result_idx:result_idx + num_results] + seq_output = _build_sequence_outputs(parent_ids, next_token_ids, + selected_token_logprobs, + seq_ids, parent_logprobs, + sampling_params.logprobs) + seq_outputs_dict[seq_group_id] = seq_output + sample_idx += num_parent_seqs + result_idx += num_results + assert sample_idx == num_tokens + category_start_idx += num_tokens + + return [seq_outputs_dict[i] for i in range(len(input_metadata.seq_groups))] diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 85fb503aa3d87..53bd743fce9da 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,9 +1,17 @@ """Sampling parameters for text generation.""" +from enum import IntEnum +from functools import cached_property from typing import List, Optional, Union _SAMPLING_EPS = 1e-5 +class SamplingType(IntEnum): + GREEDY = 0 + RANDOM = 1 + BEAM = 2 + + class SamplingParams: """Sampling parameters for text generation. 
@@ -166,6 +174,14 @@ def _verify_greedy_sampling(self) -> None: if self.top_k != -1: raise ValueError("top_k must be -1 when using greedy sampling.") + @cached_property + def sampling_type(self) -> SamplingType: + if self.use_beam_search: + return SamplingType.BEAM + if self.temperature < _SAMPLING_EPS: + return SamplingType.GREEDY + return SamplingType.RANDOM + def __repr__(self) -> str: return (f"SamplingParams(n={self.n}, " f"best_of={self.best_of}, "
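A minimal standalone sketch (separate from the patch; the helper name apply_penalties_sketch and the toy inputs below are hypothetical) of the vectorized presence/frequency penalty that the new _apply_penalties implements: output-token ids are padded to a rectangle, a per-sequence histogram is built with one scatter_add_ into a (num_seqs, vocab_size + 1) buffer whose extra column absorbs the padding, and both penalties are then applied to the whole batch without a Python loop over sequences.

import torch

def apply_penalties_sketch(logits, output_tokens, presence_penalties,
                           frequency_penalties):
    # logits: (num_seqs, vocab_size); output_tokens: one token-id list per seq.
    num_seqs, vocab_size = logits.shape
    max_len = max(len(tokens) for tokens in output_tokens)
    if max_len == 0:
        return logits
    # Pad with `vocab_size` so padded slots fall into an extra bin that is
    # discarded after counting.
    padded = [tokens + [vocab_size] * (max_len - len(tokens))
              for tokens in output_tokens]
    token_ids = torch.tensor(padded, dtype=torch.long, device=logits.device)
    bin_counts = torch.zeros((num_seqs, vocab_size + 1),
                             dtype=torch.long, device=logits.device)
    bin_counts.scatter_add_(1, token_ids, torch.ones_like(token_ids))
    bin_counts = bin_counts[:, :vocab_size]  # drop the padding bin
    freq = torch.tensor(frequency_penalties, dtype=logits.dtype,
                        device=logits.device)
    pres = torch.tensor(presence_penalties, dtype=logits.dtype,
                        device=logits.device)
    # Frequency penalty scales with the count; presence penalty is a flat
    # deduction for any token that has appeared at least once.
    logits -= freq.unsqueeze(1) * bin_counts
    logits -= pres.unsqueeze(1) * (bin_counts > 0)
    return logits

For example, apply_penalties_sketch(torch.zeros(2, 5), [[1, 1, 3], []],
[0.5, 0.0], [0.1, 0.0]) lowers sequence 0's logit for token 1 by
2 * 0.1 + 0.5 and for token 3 by 0.1 + 0.5, while sequence 1 is left
unchanged, matching the OpenAI-style penalty definition referenced in the
patch.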