diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 487f5a3d2a441..2ca86a4653cf4 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -433,12 +433,9 @@ def _apply_top_k_top_p( logits_sort.masked_fill_(top_p_mask, -float("inf")) # Re-sort the probabilities. - src = torch.arange(logits_idx.shape[-1], - device=logits_idx.device).expand_as(logits_idx) - logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1, - index=logits_idx, - src=src) - logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv) + logits = torch.empty_like(logits_sort).scatter_(dim=-1, + index=logits_idx, + src=logits_sort) return logits