diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index b586c98bd13a9..11895b13e830e 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -71,19 +71,19 @@ def forward( # Use in-place division to avoid creating a new tensor. logits.div_(t.unsqueeze(dim=1)) - # We use float32 for probabilities and log probabilities. - # Compute the probabilities. - probs = torch.softmax(logits, dim=-1, dtype=torch.float) - # Compute the log probabilities (before applying top-p and top-k). - logprobs = torch.log(probs) - # Apply top-p and top-k truncation. top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size) - assert len(top_ps) == len(top_ks) == probs.shape[0] + assert len(top_ps) == len(top_ks) == logits.shape[0] do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps) do_top_k = any(k != self.vocab_size for k in top_ks) if do_top_p or do_top_k: - probs = _apply_top_p_top_k(probs, top_ps, top_ks) + logits = _apply_top_p_top_k(logits, top_ps, top_ks) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities (before applying top-p and top-k). + logprobs = torch.log(probs) # Sample the next tokens. return _sample(probs, logprobs, input_metadata) @@ -244,16 +244,16 @@ def _apply_top_p_top_k( probs_sort, probs_idx = probs.sort(dim=-1, descending=True) # Apply top-p. - probs_sum = torch.cumsum(probs_sort, dim=-1) - top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1) - probs_sort[top_p_mask] = 0.0 + probs_sum = probs_sort.softmax(dim=-1).cumsum(dim=-1) + top_p_mask = (probs_sum - probs_sort.softmax(dim=-1)) > p.unsqueeze(dim=1) + probs_sort[top_p_mask] = -float("Inf") # Apply top-k. # Create a mask for the top-k elements. top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device) top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1) top_k_mask = top_k_mask >= k.unsqueeze(dim=1) - probs_sort[top_k_mask] = 0.0 + probs_sort[top_k_mask] = -float("Inf") # Re-sort the probabilities. probs = torch.gather(probs_sort, @@ -302,8 +302,7 @@ def _sample_from_prompt( # Sample `best_of` tokens for the prompt. num_seqs = sampling_params.best_of next_token_ids = torch.multinomial(prob, - num_samples=num_seqs, - replacement=True) + num_samples=num_seqs) next_token_ids = next_token_ids.tolist() return next_token_ids