From af6d97483e978bca830607327f288e88e30d81f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=9B=8F=E4=B8=80?= Date: Fri, 20 Sep 2024 02:28:25 +0800 Subject: [PATCH] [Core] simplify logits resort in _apply_top_k_top_p (#8619) --- vllm/model_executor/layers/sampler.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 487f5a3d2a441..2ca86a4653cf4 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -433,12 +433,9 @@ def _apply_top_k_top_p( logits_sort.masked_fill_(top_p_mask, -float("inf")) # Re-sort the probabilities. - src = torch.arange(logits_idx.shape[-1], - device=logits_idx.device).expand_as(logits_idx) - logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1, - index=logits_idx, - src=src) - logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv) + logits = torch.empty_like(logits_sort).scatter_(dim=-1, + index=logits_idx, + src=logits_sort) return logits