Align with huggingface beam search #646

Closed · wants to merge 1 commit into from
221 changes: 194 additions & 27 deletions vllm/core/scheduler.py
@@ -1,6 +1,7 @@
import enum
import time
from typing import Dict, List, Optional, Tuple
import copy
from typing import Union, Dict, List, Optional, Tuple

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import BlockSpaceManager
@@ -273,34 +274,200 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
def update(
self,
seq_outputs: Dict[int, SequenceOutputs],
decode_func,
) -> List[SequenceGroup]:
scheduled: List[SequenceGroup] = []
for seq_group in self.running:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
if seq.seq_id in seq_outputs:
scheduled.append(seq_group)
break
# Update the running sequences and free blocks.
seq_outputs_data = self._decode_seq_outputs(seq_outputs, decode_func)
for i, seq_group in enumerate(self.running):
try:
group_outputs_data = seq_outputs_data[i]
except KeyError:
continue
sampling_params = seq_group.sampling_params
finished = []
pending = []
for _, data in group_outputs_data.items():
stopped, reason = self._stopping_criteria(
data[1], sampling_params)
if stopped:
finished.append([data, reason])
else:
pending.append(data)

if sampling_params.use_beam_search:
# update finished sequences
force_stop = False
length_penalty = sampling_params.length_penalty
if finished:
highest_attainable_score = max([
data[0][1].get_score(length_penalty)
for data in finished
])
for j, ((finished_seq_id, finished_seq_data),
reason) in enumerate(finished):
finished_seqs = sorted(
seq_group.get_seqs(
status=SequenceStatus.FINISHED_STOPPED) +
seq_group.get_seqs(
status=SequenceStatus.FINISHED_LENGTH_CAPPED),
key=lambda x: x.data.get_score(length_penalty))
assert len(finished_seqs) <= sampling_params.n

if len(finished_seqs) < sampling_params.n:
finished_seq = copy.deepcopy(
seq_group.find(finished_seq_id,
status=SequenceStatus.RUNNING))
finished_seq.data = copy.deepcopy(finished_seq_data)
finished_seq.status = reason
seq_group.append_seq(finished_seq)
else:
worst_seq = finished_seqs[0]
worst_score = worst_seq.data.get_score(length_penalty)
curr_score = finished_seq_data.get_score(
length_penalty)

if j == 0 and worst_score >= highest_attainable_score:
force_stop = True
break

if curr_score > worst_score:
worst_seq.data = copy.deepcopy(finished_seq_data)
worst_seq.status = reason

if force_stop:
for seq in seq_group.get_seqs(
status=SequenceStatus.RUNNING):
self.block_manager.free(seq)
seq_group.seqs.pop(seq_group.seqs.index(seq))
continue

# schedule next-beam tasks
pending = pending[:sampling_params.n]
running_ids = [
seq.seq_id for seq in seq_group.get_seqs(
status=SequenceStatus.RUNNING)
]
all_ids = [j + min(running_ids) for j in range(len(pending))]
pending_ids = [p[0] for p in pending]
new_ids = list(set(all_ids) - set(pending_ids))

used_seq_ids = []
for (parent_id, seq_data) in pending:
parent_seq = seq_group.find(parent_id,
status=SequenceStatus.RUNNING)
if parent_id not in used_seq_ids:
parent_seq.append_token_id(
seq_data.get_last_token_id())
parent_seq.data = copy.deepcopy(seq_data)
used_seq_ids.append(parent_id)
else:
new_seq_id = new_ids.pop()
used_seq_ids.append(new_seq_id)
if new_seq_id in running_ids:
new_seq = seq_group.find(
new_seq_id, status=SequenceStatus.RUNNING)
self.block_manager.free(new_seq)
parent_seq.fork(new_seq)
else:
new_seq = copy.deepcopy(parent_seq)
new_seq.seq_id = new_seq_id
seq_group.append_seq(new_seq)
new_seq = seq_group.find(
new_seq_id, status=SequenceStatus.RUNNING)

self.block_manager.fork(parent_seq, new_seq)
new_seq.append_token_id(seq_data.get_last_token_id())
new_seq.data = copy.deepcopy(seq_data)

for unused_id in list(set(running_ids) - set(used_seq_ids)):
try:
unused_seq = seq_group.find(
unused_id, status=SequenceStatus.RUNNING)
self.block_manager.free(unused_seq)
seq_group.seqs.pop(seq_group.seqs.index(unused_seq))
except ValueError:
continue
else:
for parent_id, seq_data in pending:
parent_seq = seq_group.find(parent_id)
parent_seq.append_token_id(seq_data.get_last_token_id())
parent_seq.data = copy.deepcopy(seq_data)

# Update the scheduled sequences and free blocks.
for seq_group in scheduled:
# Process beam search results before processing the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
output = seq_outputs[seq.seq_id]
if seq.seq_id != output.parent_seq_id:
# The sequence is a fork of the parent sequence (beam
# search). Free the current sequence.
self.block_manager.free(seq)
# Fork the parent sequence.
parent_seq = seq_group.find(output.parent_seq_id)
parent_seq.fork(seq)
self.block_manager.fork(parent_seq, seq)

# Process the new tokens.
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Append a new token to the sequence.
output = seq_outputs[seq.seq_id]
seq.append_token_id(output.output_token, output.logprobs)
return scheduled
for (parent_id, seq_data), reason in finished:
parent_seq = seq_group.find(parent_id)
parent_seq.data = copy.deepcopy(seq_data)
self.free_seq(parent_seq, reason)

# Return a shallow copy of the running queue to prevent the queue
# from being modified by the caller.
return self.running.copy()

def _decode_seq_outputs(
self,
seq_outputs: List[Dict[int, SequenceOutputs]],
decode_func,
) -> Dict[int, Dict[int, List[Union[int, SequenceData]]]]:
seq_outputs_data = {}
seq_outputs_ = {}
for seq_output in seq_outputs:
seq_output_ids = [v.parent_seq_id for _, v in seq_output.items()]
for i, seq_group in enumerate(self.running):
seq_group_ids = [
seq.seq_id for seq in seq_group.get_seqs(
status=SequenceStatus.RUNNING)
]
if set(seq_output_ids).issubset(set(seq_group_ids)):
seq_outputs_[i] = seq_output
break
for i, seq_group in enumerate(self.running):
try:
group_outputs = seq_outputs_[i]
except KeyError:
continue
group_outputs_data = {}
for pseudo_seq_id, output in group_outputs.items():
try:
parent_seq = seq_group.find(output.parent_seq_id,
status=SequenceStatus.RUNNING)
except ValueError:
continue
seq_data = copy.deepcopy(parent_seq.data)
seq_data.append_token_id(output.output_token, output.logprobs)
new_token, new_output_text = decode_func(
prev_output_tokens=seq_data.output_tokens,
new_token_id=seq_data.get_last_token_id(),
)
if new_token is not None:
seq_data.output_tokens.append(new_token)
seq_data.output_text = new_output_text
group_outputs_data[pseudo_seq_id] = [
output.parent_seq_id, seq_data
]
seq_outputs_data[i] = group_outputs_data

return seq_outputs_data

def _stopping_criteria(
self, seq_data: SequenceData,
sampling_params) -> List[Union[bool, Optional[SequenceStatus]]]:
"""Check if the given sequence stopped."""
for stop_str in sampling_params.stop:
if seq_data.output_text.endswith(stop_str):
# Truncate the output text so that the stop string is
# not included in the output.
seq_data.output_text = seq_data.output_text[:-len(stop_str)]
return [True, SequenceStatus.FINISHED_STOPPED]

# Check if the sequence has reached max_model_len or max_tokens.
if (seq_data.get_len() >= self.scheduler_config.max_model_len) or \
(seq_data.get_output_len() >= sampling_params.max_tokens):
return [True, SequenceStatus.FINISHED_LENGTH_CAPPED]
# Check if the sequence has generated the EOS token.
if not sampling_params.ignore_eos:
if seq_data.get_last_token_id() == sampling_params.eos_token_id:
return [True, SequenceStatus.FINISHED_STOPPED]

return [False, None]

def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
seq.status = finish_status
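For context, the beam bookkeeping above relies on two ideas from Hugging Face's beam search: a length-penalized score per hypothesis (`SequenceData.get_score(length_penalty)`) and an early-stop test that ends a beam group once no running beam can improve on the finished hypotheses. Below is a minimal sketch of both, assuming `get_score` follows the Hugging Face convention of dividing the cumulative log-probability by `length ** length_penalty`; the helper names are illustrative and not part of this PR.

from typing import List


def hf_beam_score(cumulative_logprob: float, length: int,
                  length_penalty: float) -> float:
    # Hugging Face convention: higher is better. Since cumulative
    # log-probabilities are negative, larger length_penalty values reward
    # longer hypotheses and smaller values reward shorter ones.
    return cumulative_logprob / (length ** length_penalty)


def beam_group_is_done(finished_scores: List[float],
                       best_possible_running_score: float,
                       beam_width: int) -> bool:
    # Analogue of Hugging Face's BeamHypotheses.is_done: once beam_width
    # hypotheses are finished and even the best score a running beam could
    # still reach cannot beat the worst finished one, the group stops early.
    if len(finished_scores) < beam_width:
        return False
    return min(finished_scores) >= best_possible_running_score


# Example: two finished beams score -0.8 and -1.2; the best a running beam
# could still attain is -1.5, so the group is done.
assert beam_group_is_done([-0.8, -1.2], -1.5, beam_width=2)

The `force_stop` branch in the diff plays the same role, except that it compares the worst kept finished sequence against the best score among the sequences that finished in the current step.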
66 changes: 9 additions & 57 deletions vllm/engine/llm_engine.py
@@ -11,7 +11,7 @@
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.sequence import Sequence, SequenceGroup
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
get_tokenizer)
from vllm.utils import Counter
@@ -262,6 +262,7 @@ def add_request(
seqs.append(seq)

# Create the sequence group.
sampling_params.eos_token_id = self.tokenizer.eos_token_id
seq_group = SequenceGroup(request_id, seqs, sampling_params,
arrival_time)

@@ -317,13 +318,14 @@ def step(self) -> List[RequestOutput]:
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
)
# Update the scheduler with the model outputs.
seq_groups = self.scheduler.update(output)
# Update the scheduler with the model outputs; the scheduler now
# decodes the outputs and applies the stopping criteria itself.
seq_groups = self.scheduler.update(
output,
partial(detokenize_incrementally,
tokenizer=self.tokenizer,
skip_special_tokens=True))

# Decode the sequences.
self._decode_sequences(seq_groups)
# Stop the sequences that meet the stopping criteria.
self._stop_sequences(seq_groups)
# Free the finished sequence groups.
self.scheduler.free_finished_seq_groups()

@@ -402,56 +404,6 @@ def _log_system_stats(
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
self.last_logging_time = now

def _decode_sequences(self, seq_groups: List[SequenceGroup]) -> None:
"""Decodes the sequence outputs."""
for seq_group in seq_groups:
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
new_token, new_output_text = detokenize_incrementally(
self.tokenizer,
seq.output_tokens,
seq.get_last_token_id(),
skip_special_tokens=True,
)
if new_token is not None:
seq.output_tokens.append(new_token)
seq.output_text = new_output_text

def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
"""Stop the finished sequences."""
for seq_group in seq_groups:
sampling_params = seq_group.sampling_params
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
# Check if the sequence has generated a stop string.
stopped = False
for stop_str in sampling_params.stop:
if seq.output_text.endswith(stop_str):
# Truncate the output text so that the stop string is
# not included in the output.
seq.output_text = seq.output_text[:-len(stop_str)]
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_STOPPED)
stopped = True
break
if stopped:
continue

# Check if the sequence has reached max_model_len.
if seq.get_len() > self.scheduler_config.max_model_len:
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
continue
# Check if the sequence has reached max_tokens.
if seq.get_output_len() == sampling_params.max_tokens:
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
continue
# Check if the sequence has generated the EOS token.
if not sampling_params.ignore_eos:
if seq.get_last_token_id() == self.tokenizer.eos_token_id:
self.scheduler.free_seq(
seq, SequenceStatus.FINISHED_STOPPED)
continue

def _run_workers(
self,
method: str,
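For context on the engine-side change: `functools.partial` freezes the tokenizer-dependent arguments, so the scheduler can run incremental detokenization and the stopping criteria with only per-step token data (plus the EOS id now carried on `sampling_params`). Below is a minimal sketch of that wiring with a toy stand-in for `detokenize_incrementally`; the real helper lives in `vllm.transformers_utils.tokenizer` and does more work, and this sketch assumes the engine imports `partial` from `functools`.

from functools import partial
from typing import List, Optional, Tuple


def toy_detokenize_incrementally(prev_output_tokens: List[str],
                                 new_token_id: int,
                                 tokenizer,
                                 skip_special_tokens: bool = True
                                 ) -> Tuple[Optional[str], str]:
    # Toy stand-in: map the new id to a token piece and re-join the text.
    if skip_special_tokens and new_token_id in tokenizer.all_special_ids:
        return None, tokenizer.convert_tokens_to_string(prev_output_tokens)
    new_token = tokenizer.convert_ids_to_tokens(new_token_id)
    output_text = tokenizer.convert_tokens_to_string(
        prev_output_tokens + [new_token])
    return new_token, output_text


# Engine side: bind the tokenizer once ...
#   decode_func = partial(toy_detokenize_incrementally,
#                         tokenizer=self.tokenizer,
#                         skip_special_tokens=True)
# ... scheduler side: call it with token data only, as in the diff:
#   new_token, new_output_text = decode_func(
#       prev_output_tokens=seq_data.output_tokens,
#       new_token_id=seq_data.get_last_token_id())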