From c049b7afb6bd5f224064eb71cfb9c0587ebe185c Mon Sep 17 00:00:00 2001
From: Abraham-Xu
Date: Sun, 13 Aug 2023 09:34:14 +0000
Subject: [PATCH 1/3] Align with huggingface top_k sampling

---
 vllm/model_executor/layers/sampler.py | 27 +++++++++++++--------------
 1 file changed, 13 insertions(+), 14 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index b586c98bd13a9..11895b13e830e 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -71,19 +71,19 @@ def forward(
         # Use in-place division to avoid creating a new tensor.
         logits.div_(t.unsqueeze(dim=1))
 
-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities (before applying top-p and top-k).
-        logprobs = torch.log(probs)
-
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
-        assert len(top_ps) == len(top_ks) == probs.shape[0]
+        assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
         if do_top_p or do_top_k:
-            probs = _apply_top_p_top_k(probs, top_ps, top_ks)
+            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
+
+        # We use float32 for probabilities and log probabilities.
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities (before applying top-p and top-k).
+        logprobs = torch.log(probs)
 
         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)
@@ -244,16 +244,16 @@ def _apply_top_p_top_k(
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
 
     # Apply top-p.
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    probs_sort[top_p_mask] = 0.0
+    probs_sum = probs_sort.softmax(dim=-1).cumsum(dim=-1)
+    top_p_mask = (probs_sum - probs_sort.softmax(dim=-1)) > p.unsqueeze(dim=1)
+    probs_sort[top_p_mask] = -float("Inf")
 
     # Apply top-k.
     # Create a mask for the top-k elements.
     top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
     top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
     top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    probs_sort[top_k_mask] = 0.0
+    probs_sort[top_k_mask] = -float("Inf")
 
     # Re-sort the probabilities.
     probs = torch.gather(probs_sort,
@@ -302,8 +302,7 @@ def _sample_from_prompt(
     # Sample `best_of` tokens for the prompt.
     num_seqs = sampling_params.best_of
     next_token_ids = torch.multinomial(prob,
-                                       num_samples=num_seqs,
-                                       replacement=True)
+                                       num_samples=num_seqs)
     next_token_ids = next_token_ids.tolist()
     return next_token_ids
 

From 034963a56a84b09784be95c473919c59f40a8598 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 15 Aug 2023 22:38:27 +0000
Subject: [PATCH 2/3] fix format

---
 vllm/model_executor/layers/sampler.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 11895b13e830e..ef1f1221fe44c 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -78,7 +78,7 @@ def forward(
         do_top_k = any(k != self.vocab_size for k in top_ks)
         if do_top_p or do_top_k:
             logits = _apply_top_p_top_k(logits, top_ps, top_ks)
-
+
         # We use float32 for probabilities and log probabilities.
         # Compute the probabilities.
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
@@ -301,8 +301,7 @@ def _sample_from_prompt(
     # Random sampling.
     # Sample `best_of` tokens for the prompt.
     num_seqs = sampling_params.best_of
-    next_token_ids = torch.multinomial(prob,
-                                       num_samples=num_seqs)
+    next_token_ids = torch.multinomial(prob, num_samples=num_seqs)
     next_token_ids = next_token_ids.tolist()
     return next_token_ids
 

From 4ced78b66224a4f5ce25982032b0874b9280d360 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 15 Aug 2023 23:42:30 +0000
Subject: [PATCH 3/3] rename probs -> logits & remove one extra softmax

---
 vllm/model_executor/layers/sampler.py | 29 +++++++++++++++------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index ef1f1221fe44c..6a50ee59ea67a 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -235,31 +235,32 @@ def _get_top_p_top_k(
 
 
 def _apply_top_p_top_k(
-    probs: torch.Tensor,
+    logits: torch.Tensor,
     top_ps: List[float],
     top_ks: List[int],
 ) -> torch.Tensor:
-    p = torch.tensor(top_ps, dtype=probs.dtype, device=probs.device)
-    k = torch.tensor(top_ks, dtype=torch.int, device=probs.device)
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
+    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
 
     # Apply top-p.
-    probs_sum = probs_sort.softmax(dim=-1).cumsum(dim=-1)
-    top_p_mask = (probs_sum - probs_sort.softmax(dim=-1)) > p.unsqueeze(dim=1)
-    probs_sort[top_p_mask] = -float("Inf")
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
+    logits_sort[top_p_mask] = -float("inf")
 
     # Apply top-k.
     # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
-    top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
+    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
+    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
     top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    probs_sort[top_k_mask] = -float("Inf")
+    logits_sort[top_k_mask] = -float("inf")
 
     # Re-sort the probabilities.
-    probs = torch.gather(probs_sort,
-                         dim=-1,
-                         index=torch.argsort(probs_idx, dim=-1))
-    return probs
+    logits = torch.gather(logits_sort,
+                          dim=-1,
+                          index=torch.argsort(logits_idx, dim=-1))
+    return logits
 
 
 def _get_topk_logprobs(
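
Note (reviewer sketch, not part of the patch series): after these three patches, top-p and top-k truncation mask the *logits* with -inf before a single softmax, matching the HuggingFace ordering, and prompt sampling calls torch.multinomial without replacement=True. The following is a minimal standalone sketch of that final behavior; the function name, the example logits, and the top_ps/top_ks values are illustrative only and assume nothing beyond torch.

import torch


def apply_top_p_top_k(logits: torch.Tensor, top_ps: list, top_ks: list) -> torch.Tensor:
    # Illustrative re-implementation of the patched _apply_top_p_top_k logic.
    p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device)
    k = torch.tensor(top_ks, dtype=torch.int, device=logits.device)
    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)

    # Top-p: mask the tail once the cumulative probability of the tokens
    # ranked *before* the current one exceeds p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
    logits_sort[top_p_mask] = -float("inf")

    # Top-k: mask every rank at or beyond k for each row.
    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
    top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
    logits_sort[top_k_mask] = -float("inf")

    # Undo the sort so the masked logits line up with the original vocab order.
    return torch.gather(logits_sort, dim=-1,
                        index=torch.argsort(logits_idx, dim=-1))


# Example usage with made-up values: mask, then softmax, then sample.
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
masked = apply_top_p_top_k(logits, top_ps=[0.9], top_ks=[3])
probs = torch.softmax(masked, dim=-1, dtype=torch.float)
next_token = torch.multinomial(probs, num_samples=1)  # no replacement=True, as in the patch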