Add repetition_penalty aligned with huggingface #866

Closed
99 changes: 69 additions & 30 deletions vllm/model_executor/layers/sampler.py
@@ -21,7 +21,7 @@ class Sampler(nn.Module):
1. Discard the hidden states that are not used for sampling (i.e., all
tokens except the final one in each prompt).
2. Compute the logits for the next tokens.
3. Apply presence and frequency penalties.
3. Apply presence, frequency and repetition penalties.
4. Apply temperature scaling.
5. Apply top-p and top-k truncation.
6. Sample the next tokens.
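For orientation, here is a minimal, self-contained sketch of steps 4-6 for a single sequence (an illustrative toy, not the vLLM implementation; the penalty handling of step 3 is covered further down in this diff):

```python
import torch

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0,
                      top_k: int = -1, top_p: float = 1.0) -> int:
    """Toy version of steps 4-6: temperature, top-k/top-p truncation, sampling."""
    if temperature == 0:
        return int(torch.argmax(logits).item())            # zero temperature = greedy
    logits = logits / temperature                          # step 4: temperature scaling
    if top_k > 0:                                          # step 5: top-k truncation
        kth_value = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    if top_p < 1.0:                                        # step 5: top-p truncation
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        sorted_probs[cumulative - sorted_probs > top_p] = 0.0
        probs = torch.zeros_like(probs).scatter(0, sorted_idx, sorted_probs)
        probs = probs / probs.sum()
    return int(torch.multinomial(probs, num_samples=1).item())  # step 6: sample
```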
@@ -54,12 +54,14 @@ def forward(
# Apply presence and frequency penalties.
output_tokens = _get_output_tokens(input_metadata)
assert len(output_tokens) == logits.shape[0]
presence_penalties, frequency_penalties = _get_penalties(
input_metadata)
presence_penalties, frequency_penalties, repetition_penalties = \
_get_penalties(input_metadata)
assert len(presence_penalties) == logits.shape[0]
assert len(frequency_penalties) == logits.shape[0]
logits = _apply_penalties(logits, output_tokens, presence_penalties,
frequency_penalties, self.vocab_size)
assert len(repetition_penalties) == logits.shape[0]
logits = _apply_penalties(input_metadata, logits, output_tokens,
presence_penalties, frequency_penalties,
repetition_penalties, self.vocab_size)

# Apply temperature scaling.
temperatures = _get_temperatures(input_metadata)
@@ -108,19 +110,23 @@ def _get_penalties(
# Collect the presence and frequency penalties.
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
for i, seq_group in enumerate(input_metadata.seq_groups):
seq_ids, sampling_params = seq_group
p = sampling_params.presence_penalty
f = sampling_params.frequency_penalty
r = sampling_params.repetition_penalty
if i < input_metadata.num_prompts:
# A prompt input.
presence_penalties.append(p)
frequency_penalties.append(f)
repetition_penalties.append(r)
else:
# A generation token.
presence_penalties += [p] * len(seq_ids)
frequency_penalties += [f] * len(seq_ids)
return presence_penalties, frequency_penalties
repetition_penalties += [r] * len(seq_ids)
return presence_penalties, frequency_penalties, repetition_penalties


def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:
@@ -143,10 +149,12 @@ def _get_output_tokens(input_metadata: InputMetadata) -> List[List[int]]:


def _apply_penalties(
input_metadata: InputMetadata,
logits: torch.Tensor,
output_tokens: List[List[int]],
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
vocab_size: int,
) -> torch.Tensor:
num_seqs = logits.shape[0]
@@ -162,30 +170,61 @@
indices.append(i)

# Return early if all sequences have zero penalties.

Review comment: this comment is misleading now?

if not indices:
return logits

bin_counts = []
for i in indices:
bin_counts.append(np.bincount(output_tokens[i], minlength=vocab_size))
bin_counts = np.stack(bin_counts, axis=0)
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
device=logits.device)

frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype,
device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype,
device=logits.device)

# We follow the definition in OpenAI API.
# Refer to https://platform.openai.com/docs/api-reference/parameter-details
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
if indices:
bin_counts = []
for i in indices:
bin_counts.append(
np.bincount(output_tokens[i], minlength=vocab_size))
bin_counts = np.stack(bin_counts, axis=0)
bin_counts = torch.from_numpy(bin_counts).to(dtype=logits.dtype,
device=logits.device)

frequency_penalties = [frequency_penalties[i] for i in indices]
frequency_penalties = torch.tensor(frequency_penalties,
dtype=logits.dtype,
device=logits.device)
presence_penalties = [presence_penalties[i] for i in indices]
presence_penalties = torch.tensor(presence_penalties,
dtype=logits.dtype,
device=logits.device)
# We follow the definition in OpenAI API.
# Refer to
# https://platform.openai.com/docs/api-reference/parameter-details
logits[indices] -= frequency_penalties.unsqueeze(dim=1) * bin_counts
presence_mask = (bin_counts > 0.0).to(dtype=logits.dtype)
logits[indices] -= presence_penalties.unsqueeze(dim=1) * presence_mask
else:
@leshanbog Sep 13, 2023

why else? will presence_penalty and repetition_penalty work together?

# repetition penalty aligned with huggingface transformers
for i, seq_group in enumerate(input_metadata.seq_groups):
r = repetition_penalties[i]
if r == 1.0:
continue
seq_ids, _ = seq_group
if i < input_metadata.num_prompts:
# A prompt input.
# NOTE: While the prompt input usually has no output tokens,
# it may have output tokens in the case of recomputation.
seq_id = seq_ids[0]
seq_data = input_metadata.seq_data[seq_id]
token_ids = seq_data.get_token_ids()
token_ids = torch.tensor(token_ids,
dtype=torch.int64,
device=logits.device)
score = torch.gather(logits[i], 0, token_ids)
score = torch.where(score < 0, score * r, score / r)
logits[i].scatter_(0, token_ids, score)
else:
# A generation token.
for seq_id in seq_ids:
seq_data = input_metadata.seq_data[seq_id]
token_ids = seq_data.get_token_ids()
token_ids = torch.tensor(token_ids,
dtype=torch.int64,
device=logits.device)
score = torch.gather(logits[i], 0, token_ids)
score = torch.where(score < 0, score * r, score / r)
logits[i].scatter_(0, token_ids, score)

return logits
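For reference, the new else branch implements the Hugging Face-style repetition penalty; here is a standalone sketch of it on a single sequence's logits (the function name and inputs are illustrative, not part of the PR):

```python
from typing import List

import torch

def apply_repetition_penalty(logits: torch.Tensor,
                             seen_token_ids: List[int],
                             penalty: float) -> torch.Tensor:
    """Penalize every token id already present in the prompt or output."""
    if penalty == 1.0 or not seen_token_ids:
        return logits  # 1.0 disables the penalty, matching the `continue` above
    token_ids = torch.tensor(seen_token_ids, dtype=torch.int64,
                             device=logits.device)
    score = torch.gather(logits, 0, token_ids)
    # Multiply negative logits and divide positive ones so that repeated
    # tokens always become less likely, regardless of the logit's sign.
    score = torch.where(score < 0, score * penalty, score / penalty)
    logits.scatter_(0, token_ids, score)
    return logits

# Example: tokens 1 and 2 were already seen, so their logits are penalized.
print(apply_repetition_penalty(torch.tensor([1.0, -2.0, 3.0, 0.5]), [1, 2], 1.3))
# tensor([ 1.0000, -2.6000,  2.3077,  0.5000])
```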


18 changes: 18 additions & 0 deletions vllm/sampling_params.py
@@ -26,6 +26,9 @@ class SamplingParams:
frequency in the generated text so far. Values > 0 encourage the
model to use new tokens, while values < 0 encourage the model to
repeat tokens.
repetition_penalty: The parameter for repetition penalty. 1.0 means no
penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for
more details.
temperature: Float that controls the randomness of the sampling. Lower
values make the model more deterministic, while higher values make
the model more random. Zero means greedy sampling.
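For reference, the implementation follows Hugging Face transformers, which applies a sign-aware variant of the penalty in the cited paper: the logit x_i of each token i already present in the prompt or output is rescaled as

```math
x_i' =
\begin{cases}
  x_i \cdot r, & x_i < 0 \\
  x_i / r,     & x_i \ge 0
\end{cases}
```

so r > 1.0 makes repeated tokens less likely and r < 1.0 makes them more likely.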
@@ -48,6 +51,7 @@ def __init__(
best_of: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
@@ -61,6 +65,7 @@
self.best_of = best_of if best_of is not None else n
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.repetition_penalty = repetition_penalty
self.temperature = temperature
self.top_p = top_p
self.top_k = top_k
@@ -94,6 +99,18 @@ def _verify_args(self) -> None:
if not -2.0 <= self.frequency_penalty <= 2.0:
raise ValueError("frequency_penalty must be in [-2, 2], got "
f"{self.frequency_penalty}.")
if self.repetition_penalty <= 0.0:
raise ValueError("repetition_penalty must be a strictly positive "
f"float, got {self.repetition_penalty}.")
if self.repetition_penalty != 1.0 and (
abs(self.frequency_penalty) > _SAMPLING_EPS
or abs(self.presence_penalty) > _SAMPLING_EPS):
raise ValueError(
f"repetition_penalty cannot be used with "
f"frequency_penalty and presence_penalty."
f"got repetition_penalty={self.repetition_penalty}, "
f"frequency_penalty={self.frequency_penalty}, "
f"presence_penalty={self.presence_penalty}")
if self.temperature < 0.0:
raise ValueError(
f"temperature must be non-negative, got {self.temperature}.")
@@ -134,6 +151,7 @@ def __repr__(self) -> str:
f"best_of={self.best_of}, "
f"presence_penalty={self.presence_penalty}, "
f"frequency_penalty={self.frequency_penalty}, "
f"repetition_penalty={self.repetition_penalty}, "
f"temperature={self.temperature}, "
f"top_p={self.top_p}, "
f"top_k={self.top_k}, "
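Putting the change together, a hedged end-to-end usage sketch (assuming the standard vLLM offline-inference API; the model name is only an example):

```python
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    repetition_penalty=1.2,  # > 1.0 discourages repeating earlier tokens
)

llm = LLM(model="facebook/opt-125m")
outputs = llm.generate(["The future of AI is"], sampling_params)
print(outputs[0].outputs[0].text)
```

Note that, per the validation added above, repetition_penalty cannot be combined with non-zero presence_penalty or frequency_penalty.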