diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index b586c98bd13a9..6a50ee59ea67a 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -71,19 +71,19 @@ def forward( # Use in-place division to avoid creating a new tensor. logits.div_(t.unsqueeze(dim=1)) - # We use float32 for probabilities and log probabilities. - # Compute the probabilities. - probs = torch.softmax(logits, dim=-1, dtype=torch.float) - # Compute the log probabilities (before applying top-p and top-k). - logprobs = torch.log(probs) - # Apply top-p and top-k truncation. top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size) - assert len(top_ps) == len(top_ks) == probs.shape[0] + assert len(top_ps) == len(top_ks) == logits.shape[0] do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps) do_top_k = any(k != self.vocab_size for k in top_ks) if do_top_p or do_top_k: - probs = _apply_top_p_top_k(probs, top_ps, top_ks) + logits = _apply_top_p_top_k(logits, top_ps, top_ks) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities (before applying top-p and top-k). + logprobs = torch.log(probs) # Sample the next tokens. return _sample(probs, logprobs, input_metadata) @@ -235,31 +235,32 @@ def _get_top_p_top_k( def _apply_top_p_top_k( - probs: torch.Tensor, + logits: torch.Tensor, top_ps: List[float], top_ks: List[int], ) -> torch.Tensor: - p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device) - k = torch.tensor(top_ks, dtype=torch.int, device=probs.device) - probs_sort, probs_idx = probs.sort(dim=-1, descending=True) + p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) + k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) + logits_sort, logits_idx = logits.sort(dim=-1, descending=True) # Apply top-p. - probs_sum = torch.cumsum(probs_sort, dim=-1) + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1) - probs_sort[top_p_mask] = 0.0 + logits_sort[top_p_mask] = -float("inf") # Apply top-k. # Create a mask for the top-k elements. - top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) - top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1) + top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device) + top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1) top_k_mask = top_k_mask >= k.unsqueeze(dim=1) - probs_sort[top_k_mask] = 0.0 + logits_sort[top_k_mask] = -float("inf") # Re-sort the probabilities. - probs = torch.gather(probs_sort, - dim=-1, - index=torch.argsort(probs_idx, dim=-1)) - return probs + logits = torch.gather(logits_sort, + dim=-1, + index=torch.argsort(logits_idx, dim=-1)) + return logits def _get_topk_logprobs( @@ -301,9 +302,7 @@ def _sample_from_prompt( # Random sampling. # Sample `best_of` tokens for the prompt. num_seqs = sampling_params.best_of - next_token_ids = torch.multinomial(prob, - num_samples=num_seqs, - replacement=True) + next_token_ids = torch.multinomial(prob, num_samples=num_seqs) next_token_ids = next_token_ids.tolist() return next_token_ids