Align with huggingface top_k sampling
Abraham-Xu committed Aug 13, 2023
Parent: e06f504 · Commit: c049b7a
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions vllm/model_executor/layers/sampler.py
@@ -71,19 +71,19 @@ def forward(
         # Use in-place division to avoid creating a new tensor.
         logits.div_(t.unsqueeze(dim=1))
 
-        # We use float32 for probabilities and log probabilities.
-        # Compute the probabilities.
-        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # Compute the log probabilities (before applying top-p and top-k).
-        logprobs = torch.log(probs)
-
         # Apply top-p and top-k truncation.
         top_ps, top_ks = _get_top_p_top_k(input_metadata, self.vocab_size)
-        assert len(top_ps) == len(top_ks) == probs.shape[0]
+        assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)
         if do_top_p or do_top_k:
-            probs = _apply_top_p_top_k(probs, top_ps, top_ks)
+            logits = _apply_top_p_top_k(logits, top_ps, top_ks)
 
+        # We use float32 for probabilities and log probabilities.
+        # Compute the probabilities.
+        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
+        # Compute the log probabilities (before applying top-p and top-k).
+        logprobs = torch.log(probs)
+
         # Sample the next tokens.
         return _sample(probs, logprobs, input_metadata)
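
The substance of this hunk: top-p/top-k truncation now operates on the raw logits, masking rejected entries to -inf, and the softmax runs afterwards so the surviving tokens renormalize automatically. This matches how Hugging Face transformers filters logits before sampling (its TopKLogitsWarper/TopPLogitsWarper use a -inf filter value). A minimal standalone sketch of the mask-then-softmax pattern, with illustrative values rather than repository code:

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
top_k = 2

# Keep only the top-k logits; everything else becomes -inf.
kth_value = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
masked_logits = logits.masked_fill(logits < kth_value, -float("inf"))

# Softmax after masking: -inf entries get probability 0.0 and the
# survivors automatically renormalize to sum to 1.
probs = torch.softmax(masked_logits, dim=-1, dtype=torch.float)
print(probs)  # tensor([[0.7311, 0.2689, 0.0000, 0.0000]])

(torch.multinomial tolerates the unnormalized weights the old zeroing approach produced, but masking before the softmax keeps probs, and anything derived from them, consistent with the truncated distribution.)
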
@@ -244,16 +244,16 @@ def _apply_top_p_top_k(
     probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
 
     # Apply top-p.
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
-    probs_sort[top_p_mask] = 0.0
+    probs_sum = probs_sort.softmax(dim=-1).cumsum(dim=-1)
+    top_p_mask = (probs_sum - probs_sort.softmax(dim=-1)) > p.unsqueeze(dim=1)
+    probs_sort[top_p_mask] = -float("Inf")
 
     # Apply top-k.
     # Create a mask for the top-k elements.
     top_k_mask = torch.arange(probs_idx.shape[-1], device=probs_idx.device)
     top_k_mask = top_k_mask.expand(probs_idx.shape[0], -1)
     top_k_mask = top_k_mask >= k.unsqueeze(dim=1)
-    probs_sort[top_k_mask] = 0.0
+    probs_sort[top_k_mask] = -float("Inf")
 
     # Re-sort the probabilities.
     probs = torch.gather(probs_sort,
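
To see what the rewritten top-p branch computes, here is an illustrative single-row walk-through (assumed values, not repository code). Subtracting the current element from the cumulative sum makes the cutoff exclusive: a token is dropped only if the tokens before it already cover the top-p mass, so the highest-probability token always survives:

import torch

logits_sort = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # sorted descending
p = torch.tensor([0.7])

probs_sort = logits_sort.softmax(dim=-1)  # ~[0.610, 0.224, 0.136, 0.030]
probs_sum = probs_sort.cumsum(dim=-1)     # ~[0.610, 0.834, 0.970, 1.000]

# Exclusive cutoff: mask a token only when the mass before it exceeds p.
top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1)
logits_sort = logits_sort.masked_fill(top_p_mask, -float("inf"))
print(logits_sort)  # tensor([[2., 1., -inf, -inf]])

One observation on the new lines themselves: probs_sort.softmax(dim=-1) is evaluated twice; hoisting it into a local variable would save a redundant kernel without changing the result.
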
@@ -302,8 +302,7 @@ def _sample_from_prompt(
         # Sample `best_of` tokens for the prompt.
         num_seqs = sampling_params.best_of
         next_token_ids = torch.multinomial(prob,
-                                            num_samples=num_seqs,
-                                            replacement=True)
+                                            num_samples=num_seqs)
         next_token_ids = next_token_ids.tolist()
     return next_token_ids
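
One behavioral consequence of dropping replacement=True: torch.multinomial defaults to replacement=False, so the best_of tokens drawn for a prompt are now all distinct. A quick illustration with a made-up distribution:

import torch

prob = torch.tensor([0.5, 0.3, 0.15, 0.05])  # hypothetical 4-token distribution
next_token_ids = torch.multinomial(prob, num_samples=3)
# Without replacement, the 3 sampled ids are guaranteed to be distinct.
assert len(set(next_token_ids.tolist())) == 3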

