Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix min_tokens behaviour for multiple eos tokens #5849

Merged
merged 1 commit into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,12 +606,9 @@ def _create_sequence_group_with_sampling(
# Defensive copy of SamplingParams, which are used by the sampler,
# this doesn't deep-copy LogitsProcessor objects
sampling_params = sampling_params.clone()
# Add the eos token id into the sampling_params to support min_tokens
# processing
if seq.eos_token_id is not None:
sampling_params.all_stop_token_ids.add(seq.eos_token_id)

sampling_params.update_from_generation_config(
self.generation_config_fields)
self.generation_config_fields, seq.eos_token_id)

# Create the sequence group.
seq_group = SequenceGroup(
Expand Down
29 changes: 21 additions & 8 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,17 +280,30 @@ def _verify_greedy_sampling(self) -> None:
f"Got {self.best_of}.")

def update_from_generation_config(
self, generation_config: Dict[str, Any]) -> None:
self,
generation_config: Dict[str, Any],
model_eos_token_id: Optional[int] = None) -> None:
"""Update if there are non-default values from generation_config"""

if model_eos_token_id is not None:
# Add the eos token id into the sampling_params to support
# min_tokens processing.
self.all_stop_token_ids.add(model_eos_token_id)

# Update eos_token_id for generation
if (not self.ignore_eos) and (eos_ids :=
generation_config.get("eos_token_id")):
if (eos_ids := generation_config.get("eos_token_id")) is not None:
# it can be either int or list of int
if isinstance(eos_ids, int):
eos_ids = [eos_ids]
original_stop_token_ids = set(self.stop_token_ids)
original_stop_token_ids.update(eos_ids)
self.stop_token_ids = list(original_stop_token_ids)
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
if model_eos_token_id is not None:
# We don't need to include the primary eos_token_id in
# stop_token_ids since it's handled separately for stopping
# purposes.
eos_ids.discard(model_eos_token_id)
if eos_ids:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a nit: Making this check more explicit will help readability imo

self.all_stop_token_ids.update(eos_ids)
if not self.ignore_eos:
eos_ids.update(self.stop_token_ids)
self.stop_token_ids = list(eos_ids)

@cached_property
def sampling_type(self) -> SamplingType:
Expand Down
Loading