Add repetition_penalty aligned with huggingface
Abraham-Xu committed Aug 27, 2023
1 parent 791d79d commit e428cdd
Showing 2 changed files with 78 additions and 30 deletions.
99 changes: 69 additions & 30 deletions vllm/model_executor/layers/sampler.py
@@ -21,7 +21,7 @@ class Sampler(nn.Module):
     1. Discard the hidden states that are not used for sampling (i.e., all
         tokens except the final one in each prompt).
     2. Compute the logits for the next tokens.
-    3. Apply presence and frequency penalties.
+    3. Apply presence, frequency and repetition penalties.
     4. Apply temperature scaling.
     5. Apply top-p and top-k truncation.
     6. Sample the next tokens.
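
The ordering above is the crux: penalties act on the raw logits before temperature scaling and truncation. Below is a minimal, self-contained sketch of steps 4-6 for a single sequence; the helper is hypothetical and much simpler than the real forward pass, which batches sequences and reads its settings from InputMetadata.

```python
import torch

def sample_after_penalties(logits: torch.Tensor,
                           temperature: float = 1.0,
                           top_k: int = -1) -> int:
    # Step 4: temperature scaling (guarded so temperature=0 stays finite;
    # the real sampler treats 0 as greedy sampling instead).
    logits = logits / max(temperature, 1e-5)
    # Step 5: top-k truncation (top-p omitted for brevity).
    if top_k > 0:
        kth_best = torch.topk(logits, top_k).values[-1]
        logits = torch.where(logits < kth_best,
                             torch.full_like(logits, float("-inf")), logits)
    # Step 6: sample the next token id.
    probs = torch.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, num_samples=1))
```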
@@ -54,12 +54,14 @@ def forward(
         # Apply presence and frequency penalties.
         output_tokens = _get_output_tokens(input_metadata)
         assert len(output_tokens) == logits.shape[0]
-        presence_penalties, frequency_penalties = _get_penalties(
-            input_metadata)
+        presence_penalties, frequency_penalties, repetition_penalties = \
+            _get_penalties(input_metadata)
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, output_tokens, presence_penalties,
-                                  frequency_penalties, self.vocab_size)
+        assert len(repetition_penalties) == logits.shape[0]
+        logits = _apply_penalties(input_metadata, logits, output_tokens,
+                                  presence_penalties, frequency_penalties,
+                                  repetition_penalties, self.vocab_size)

         # Apply temperature scaling.
         temperatures = _get_temperatures(input_metadata)
@@ -108,19 +110,23 @@ def _get_penalties(
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
+    repetition_penalties: List[float] = []
     for i, seq_group in enumerate(input_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
+        r = sampling_params.repetition_penalty
         if i < input_metadata.num_prompts:
             # A prompt input.
             presence_penalties.append(p)
             frequency_penalties.append(f)
+            repetition_penalties.append(r)
         else:
             # A generation token.
             presence_penalties += [p] * len(seq_ids)
             frequency_penalties += [f] * len(seq_ids)
-    return presence_penalties, frequency_penalties
+            repetition_penalties += [r] * len(seq_ids)
+    return presence_penalties, frequency_penalties, repetition_penalties


 def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
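
To make the bookkeeping in `_get_penalties` concrete: prompt groups contribute one entry per group, decoding groups one entry per candidate sequence, so each returned list has exactly one value per row of `logits`. A toy illustration with plain Python values follows; the tuples stand in for real sequence groups and are invented for this sketch.

```python
# Hypothetical groups: one prompt, then one decoding group with 3 candidates.
seq_groups = [([0], 1.2), ([1, 2, 3], 1.1)]   # (seq_ids, repetition_penalty)
num_prompts = 1

repetition_penalties = []
for i, (seq_ids, r) in enumerate(seq_groups):
    if i < num_prompts:
        repetition_penalties.append(r)                # one logits row for the prompt
    else:
        repetition_penalties += [r] * len(seq_ids)    # one row per candidate
print(repetition_penalties)   # [1.2, 1.1, 1.1, 1.1] -> matches 4 logits rows
```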
@@ -143,10 +149,12 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:


 def _apply_penalties(
+    input_metadata: InputMetadata,
     logits: torch.Tensor,
     output_tokens: List[List[int]],
     presence_penalties: List[float],
     frequency_penalties: List[float],
+    repetition_penalties: List[float],
     vocab_size: int,
 ) -> torch.Tensor:
     num_seqs = logits.shape[0]
Expand All @@ -162,30 +170,61 @@ def _apply_penalties(
indices.append(i)

# Return early if all sequences have zero penalties.
if not indices:
return logits

bin_counts = []
for i in indices:
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
bin_counts = np.stack(bin_counts, axis=0)
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
device=logits.device)

frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype,
device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype,
device=logits.device)

# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
if indices:
bin_counts = []
for i in indices:
bin_counts.append(
np.bincount(output_tokens[i], minlength=vocab_size))
bin_counts = np.stack(bin_counts, axis=0)
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
device=logits.device)

frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype,
device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype,
device=logits.device)
# We follow the definition in OpenAI API.
# Refer to
# https://platform.openai.com/docs/api-reference/parameter-details
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
else:
# repetition penalty aligned with huggingface transformers
for i, seq_group in enumerate(input_metadata.seq_groups):
r = repetition_penalties[i]
if r == 1.0:
continue
seq_ids, _ = seq_group
if i < input_metadata.num_prompts:
# A prompt input.
# NOTE: While the prompt input usually has no output tokens,
# it may have output tokens in the case of recomputation.
seq_id = seq_ids[0]
seq_data = input_metadata.seq_data[seq_id]
token_ids = seq_data.get_token_ids()
token_ids = torch.tensor(token_ids,
dtype=torch.int64,
device=logits.device)
score = torch.gather(logits[i], 0, token_ids)
score = torch.where(score < 0, score * r, score / r)
logits[i].scatter_(0, token_ids, score)
else:
# A generation token.
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
token_ids = seq_data.get_token_ids()
token_ids = torch.tensor(token_ids,
dtype=torch.int64,
device=logits.device)
score = torch.gather(logits[i], 0, token_ids)
score = torch.where(score < 0, score * r, score / r)
logits[i].scatter_(0, token_ids, score)

return logits


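To see what the two branches compute, here is a self-contained sketch of both formulas on toy values; the tensors and the 5-token vocabulary are invented, and this is not the vLLM code itself. The OpenAI-style penalties are additive and scale with how often a token was generated, while the HuggingFace-style repetition penalty divides positive logits by `r` and multiplies negative ones, so `r > 1` always pushes repeated tokens down.

```python
import torch

logits = torch.tensor([2.0, -1.0, 0.5, 0.0, 3.0])   # toy 5-token vocabulary
output_tokens = [0, 0, 4]                            # token 0 twice, token 4 once
frequency_penalty, presence_penalty = 0.5, 0.2
repetition_penalty = 1.3

# OpenAI-style: subtract a count-scaled term and a presence term.
counts = torch.bincount(torch.tensor(output_tokens), minlength=5).to(logits.dtype)
openai = (logits - frequency_penalty * counts
          - presence_penalty * (counts > 0).to(logits.dtype))
# token 0: 2.0 - 0.5*2 - 0.2 = 0.8    token 4: 3.0 - 0.5*1 - 0.2 = 2.3

# HuggingFace-style: rescale only the tokens that already appeared.
token_ids = torch.tensor([0, 4])
score = torch.gather(logits, 0, token_ids)
score = torch.where(score < 0, score * repetition_penalty,
                    score / repetition_penalty)
hf = logits.clone()
hf.scatter_(0, token_ids, score)
# token 0: 2.0 / 1.3 ≈ 1.54           token 4: 3.0 / 1.3 ≈ 2.31
```

Note that in the diff above the HuggingFace-style loop lives in the `else:` branch of `if indices:`, so it only runs for batches where no sequence carries a presence or frequency penalty, and it penalizes every token returned by `seq_data.get_token_ids()` (prompt plus generated tokens), matching the transformers behaviour.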
9 changes: 9 additions & 0 deletions vllm/sampling_params.py
@@ -26,6 +26,9 @@ class SamplingParams:
             frequency in the generated text so far. Values > 0 encourage the
             model to use new tokens, while values < 0 encourage the model to
             repeat tokens.
+        repetition_penalty: The parameter for repetition penalty. 1.0 means no
+            penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for
+            more details.
         temperature: Float that controls the randomness of the sampling. Lower
             values make the model more deterministic, while higher values make
             the model more random. Zero means greedy sampling.
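
After this change, the new knob is set like any other sampling parameter. A brief usage sketch, assuming the usual top-level `from vllm import SamplingParams` export (not part of this diff):

```python
from vllm import SamplingParams

# 1.0 leaves logits untouched; values > 1.0 discourage tokens that already
# appear in the sequence. The value must be strictly positive, whereas the
# presence and frequency penalties are constrained to [-2, 2].
params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    repetition_penalty=1.2,
)
```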
@@ -48,6 +51,7 @@ def __init__(
         best_of: Optional[int] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -61,6 +65,7 @@
         self.best_of = best_of if best_of is not None else n
         self.presence_penalty = presence_penalty
         self.frequency_penalty = frequency_penalty
+        self.repetition_penalty = repetition_penalty
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
@@ -94,6 +99,9 @@ def _verify_args(self) -> None:
         if not -2.0 <= self.frequency_penalty <= 2.0:
             raise ValueError("frequency_penalty must be in [-2, 2], got "
                              f"{self.frequency_penalty}.")
+        if self.repetition_penalty <= 0.0:
+            raise ValueError("repetition_penalty must be a strictly positive "
+                             f"float, got {self.repetition_penalty}.")
         if self.temperature < 0.0:
             raise ValueError(
                 f"temperature must be non-negative, got {self.temperature}.")
@@ -134,6 +142,7 @@ def __repr__(self) -> str:
                 f"best_of={self.best_of}, "
                 f"presence_penalty={self.presence_penalty}, "
                 f"frequency_penalty={self.frequency_penalty}, "
+                f"repetition_penalty={self.repetition_penalty}, "
                 f"temperature={self.temperature}, "
                 f"top_p={self.top_p}, "
                 f"top_k={self.top_k}, "
