From 6410635ddba5e1cd380a78b6c3c6a3e94e1c01fd Mon Sep 17 00:00:00 2001 From: Swapnil Parekh Date: Sun, 12 May 2024 20:47:47 -0400 Subject: [PATCH] [CORE] Improvement in ranks code (#4718) --- vllm/model_executor/layers/sampler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c8bab46c83eca..a84f562909d50 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -681,7 +681,9 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: """ vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), indices] - return (x > vals[:, None]).long().sum(1).add_(1) + result = (x > vals[:, None]) + del vals + return result.sum(1).add_(1) def _get_logprobs(