More Woosuk review comments
alexm-neuralmagic committed Aug 26, 2024
1 parent 4614f4c commit a5cad38
Showing 6 changed files with 41 additions and 35 deletions.
4 changes: 1 addition & 3 deletions examples/offline_inference.py
@@ -11,9 +11,7 @@
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(model="facebook/opt-125m",
num_scheduler_steps=8,
use_v2_block_manager=True)
llm = LLM(model="facebook/opt-125m")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
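For reference, a minimal sketch of how the trimmed example reads after this change; the prompts list here is assumed from the unchanged top of examples/offline_inference.py and may differ slightly.

from vllm import LLM, SamplingParams

# Assumed from the unchanged part of the example file.
prompts = [
    "Hello, my name is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM with defaults only; num_scheduler_steps and
# use_v2_block_manager are no longer set in this example.
llm = LLM(model="facebook/opt-125m")

# Generate texts from the prompts.
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(f"Prompt: {output.prompt!r}, Generated: {output.outputs[0].text!r}")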
2 changes: 1 addition & 1 deletion tests/core/test_chunked_prefill_scheduler.py
@@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int):


def schedule_and_update_computed_tokens(scheduler):
metas, out, _, _ = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
2 changes: 1 addition & 1 deletion tests/core/utils.py
@@ -199,7 +199,7 @@ def append_new_token(out, token_id: int):


def schedule_and_update_computed_tokens(scheduler):
metas, out, _, _ = scheduler.schedule()
metas, out, _ = scheduler.schedule()
for s, meta in zip(out.scheduled_seq_groups, metas):
s.seq_group.update_num_computed_tokens(meta.token_chunk_size)
return metas, out
15 changes: 7 additions & 8 deletions vllm/core/scheduler.py
@@ -387,6 +387,10 @@ def __init__(
self._scheduled_seq_group_cache.append(
PyObjectCache(scheduled_seq_group_builder))

# For the async postprocessor, the extra decode run cannot be done
# when the request reaches max_model_len. In this case, the request
# will be stopped during the schedule() call and added to this stop
# list for processing and deallocation in free_finished_seq_groups().
self._async_stopped: List[SequenceGroup] = []

@property
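To illustrate the stop-list flow described in the new comment, here is a toy sketch with simplified stand-ins (not vLLM's actual scheduler): a request that reaches max_model_len during schedule() is parked in _async_stopped and later released by free_finished_seq_groups().

from typing import List


class ToySeqGroup:
    """Minimal stand-in for a SequenceGroup; illustrative only."""

    def __init__(self, request_id: str, num_tokens: int) -> None:
        self.request_id = request_id
        self.num_tokens = num_tokens


class ToyScheduler:
    def __init__(self, max_model_len: int) -> None:
        self.max_model_len = max_model_len
        self.running: List[ToySeqGroup] = []
        # Requests stopped during schedule() because the async postprocessor
        # cannot run the extra decode step once max_model_len is reached.
        self._async_stopped: List[ToySeqGroup] = []

    def schedule(self) -> List[ToySeqGroup]:
        scheduled: List[ToySeqGroup] = []
        for sg in self.running:
            if sg.num_tokens >= self.max_model_len:
                self._async_stopped.append(sg)  # park it for later cleanup
            else:
                scheduled.append(sg)
        self.running = scheduled
        return scheduled

    def free_finished_seq_groups(self) -> None:
        # Drain the stop list and release the parked requests.
        for sg in self._async_stopped:
            print(f"freeing {sg.request_id}")
        self._async_stopped.clear()


sched = ToyScheduler(max_model_len=8)
sched.running = [ToySeqGroup("a", 8), ToySeqGroup("b", 3)]
sched.schedule()
sched.free_finished_seq_groups()  # prints "freeing a"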
@@ -1088,9 +1092,8 @@ def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool:
return no_beam_search

def schedule(
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[Tuple[
ScheduledSequenceGroup, SequenceGroupMetadata]], bool]:
self
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
@@ -1107,9 +1110,6 @@ def schedule(
self.use_async_output_proc
and not self.scheduler_config.is_multi_step)

# Create list of scheduled request ids
scheduled_ids: List[Tuple[ScheduledSequenceGroup,
SequenceGroupMetadata]] = []
# Create input data structures.
seq_group_metadata_list: List[SequenceGroupMetadata] = []
for i, scheduled_seq_group in enumerate(
@@ -1216,7 +1216,6 @@ def schedule(
allow_async_output_proc = self._allow_async_output_proc(
seq_group)

scheduled_ids.append((scheduled_seq_group, seq_group_metadata))
# Now that the batch has been created, we can assume all blocks in the
# batch will have been computed before the next scheduling invocation.
# This is because the engine assumes that a failure in model execution
@@ -1242,7 +1241,7 @@ def schedule(
self.cache_id = self.next_cache_id

# Return results
return (seq_group_metadata_list, scheduler_outputs, scheduled_ids,
return (seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)

def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
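Taken together with the test updates above, the caller-side contract is now a three-tuple; the pairing that scheduled_ids used to carry is recovered positionally from scheduler_outputs.scheduled_seq_groups. A minimal sketch, assuming scheduler is a configured vllm.core.scheduler.Scheduler such as the tests build:

# Before this commit:
#   metas, out, scheduled_ids, allow_async_output_proc = scheduler.schedule()
# After this commit:
metas, out, allow_async_output_proc = scheduler.schedule()

# Metadata and scheduled groups line up by position, as the updated
# test helpers do:
for s, meta in zip(out.scheduled_seq_groups, metas):
    s.seq_group.update_num_computed_tokens(meta.token_chunk_size)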
8 changes: 3 additions & 5 deletions vllm/engine/async_llm_engine.py
@@ -277,14 +277,13 @@ async def step_async(
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
scheduled_ids = cached_outputs.scheduled_ids
allow_async_output_proc = cached_outputs.allow_async_output_proc

# skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
(seq_group_metadata_list, scheduler_outputs, scheduled_ids,
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()

@@ -300,11 +299,10 @@
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs,
scheduled_ids, allow_async_output_proc)
allow_async_output_proc)

assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
assert scheduled_ids is not None

if self.scheduler_config.is_multi_step:
assert not allow_async_output_proc
@@ -363,7 +361,7 @@ async def step_async(

# Cache results in engine
self.output_queue.append(
(output, scheduled_ids, scheduler_outputs))
(output, seq_group_metadata_list, scheduler_outputs))

if (len(output) > 0) and allow_async_output_proc:
assert len(
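The cached queue entry changes shape accordingly: the engine now stores the seq_group_metadata_list (rather than a scheduled_ids pairing list) next to the sampler outputs and scheduler outputs. A toy sketch of the hand-off, with Any placeholders standing in for SamplerOutput, SequenceGroupMetadata and SchedulerOutputs:

from collections import deque
from typing import Any, Deque, List, Tuple

output_queue: Deque[Tuple[List[Any], List[Any], Any]] = deque()


def cache_step_results(output: List[Any],
                       seq_group_metadata_list: List[Any],
                       scheduler_outputs: Any) -> None:
    # Producer side (step / step_async): queue the step's results for
    # async postprocessing.
    output_queue.append((output, seq_group_metadata_list, scheduler_outputs))


def pop_step_results() -> Tuple[List[Any], List[Any], Any]:
    # Consumer side (_process_model_outputs): drain one cached step.
    return output_queue.popleft()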
45 changes: 28 additions & 17 deletions vllm/engine/llm_engine.py
@@ -84,8 +84,6 @@ class SchedulerOutputState:
"""Caches the scheduler outputs for a virtual engine. Used for Multi-Step"""
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
scheduler_outputs: Optional[SchedulerOutputs] = None
scheduled_ids: Optional[List[Tuple[ScheduledSequenceGroup,
SequenceGroupMetadata]]] = None
allow_async_output_proc: bool = False
last_output: Optional[SamplerOutput] = None

@@ -193,6 +191,8 @@ def __init__(
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
input_registry: InputRegistry = INPUT_REGISTRY,
# To improve performance, only the final request outputs may be required.
# If this is set to True, then no intermediate outputs will be returned.
step_return_finished_only=False,
) -> None:
logger.info(
@@ -408,8 +408,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:

# Async output processing pointers
self.output_queue: Deque[Tuple[List[SamplerOutput],
List[Tuple[ScheduledSequenceGroup,
SequenceGroupMetadata]],
List[SequenceGroupMetadata],
SchedulerOutputs]] = deque()
self.request_outputs: List[Union[RequestOutput,
EmbeddingRequestOutput]] = []
@@ -1218,6 +1217,15 @@ def _process_sequence_group_outputs(
def _process_model_outputs(self, is_async, clear_outputs=True) -> None:
"""Apply the model output to the sequences in the scheduled seq groups.
is_async: Indicates whether this postprocessor runs in
parallel with the GPU forward pass and is processing
tokens from the previous step. If this is True, then
no tokens need to be appended here, since that has
already been done externally (before the next
schedule() call).
clear_outputs: Sometimes existing outputs need to be combined
with the outputs of this call. This happens when the
postprocessor is drained at the final stage (e.g. when
sequences are finished).
Returns RequestOutputs that can be returned to the client.
"""
now = time.time()
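A small sketch of the clear_outputs behaviour described in the docstring (illustrative only; request_outputs stands in for self.request_outputs, and the exact engine logic may differ):

from typing import Any, List


def collect_request_outputs(request_outputs: List[Any],
                            new_outputs: List[Any],
                            clear_outputs: bool = True) -> List[Any]:
    # clear_outputs=False keeps previously accumulated outputs so that this
    # call's results are combined with them, as in the final draining pass.
    if clear_outputs:
        request_outputs.clear()
    request_outputs.extend(new_outputs)
    return request_outputs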
@@ -1228,20 +1236,26 @@ def _process_model_outputs(self, is_async, clear_outputs=True) -> None:
if len(self.output_queue) == 0:
return None

(outputs, scheduled_ids,
(outputs, seq_group_metadata_list,
scheduler_outputs) = self.output_queue.popleft()

# Sanity check
assert len(seq_group_metadata_list) == len(
scheduler_outputs.scheduled_seq_groups)

# Organize outputs by [step][sequence group] instead of
# [sequence group][step].
if len(outputs) > 1:
outputs_by_sequence_group = create_output_by_sequence_group(
outputs, num_seq_groups=len(scheduled_ids))
outputs, num_seq_groups=len(seq_group_metadata_list))
else:
outputs_by_sequence_group = outputs

output = [None]
finished_before: List[int] = []
for i, (scheduled_seq_group,
seq_group_meta) in enumerate(scheduled_ids):
for i, seq_group_meta in enumerate(seq_group_metadata_list):
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

seq_group = scheduled_seq_group.seq_group

if seq_group.is_finished():
@@ -1288,7 +1302,9 @@ def _process_model_outputs(self, is_async, clear_outputs=True) -> None:
scheduler.free_finished_seq_groups()

# Create the outputs.
for i, (scheduled_seq_group, _) in enumerate(scheduled_ids):
for i, _ in enumerate(seq_group_metadata_list):
scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i]

if i in finished_before:
continue # Avoids double processing

@@ -1402,14 +1418,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
cached_outputs = self.cached_scheduler_outputs[0]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
scheduled_ids = cached_outputs.scheduled_ids
allow_async_output_proc = cached_outputs.allow_async_output_proc

# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
if not self._has_remaining_steps(seq_group_metadata_list):
(seq_group_metadata_list, scheduler_outputs, scheduled_ids,
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc) = self.scheduler[0].schedule()

if not allow_async_output_proc and len(self.output_queue) > 0:
@@ -1421,11 +1436,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(
0, seq_group_metadata_list, scheduler_outputs,
scheduled_ids, allow_async_output_proc)
allow_async_output_proc)

assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
assert scheduled_ids is not None

if self.scheduler_config.is_multi_step:
assert not allow_async_output_proc
@@ -1483,7 +1497,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
# Add results to the output_queue
# (for async or non-async postprocessing)
self.output_queue.append(
(output, scheduled_ids, scheduler_outputs))
(output, seq_group_metadata_list, scheduler_outputs))

if (len(output) > 0) and allow_async_output_proc:
assert len(output) == 1, ("Multi step decoding does not work "
@@ -1544,14 +1558,11 @@ def _cache_scheduler_outputs_for_multi_step(
self, virtual_engine: int,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
scheduler_outputs: SchedulerOutputs,
scheduled_ids: Optional[List[Tuple[ScheduledSequenceGroup,
SequenceGroupMetadata]]],
allow_async_output_proc: bool) -> None:
co = self.cached_scheduler_outputs[virtual_engine]

co.seq_group_metadata_list = seq_group_metadata_list
co.scheduler_outputs = scheduler_outputs
co.scheduled_ids = scheduled_ids
co.allow_async_output_proc = allow_async_output_proc
co.last_output = None

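On the consumer side, the refactor pairs entries positionally and guards this with the new length assert; a simplified sketch of the two passes over the batch (plain lists, not the engine's real types or logic):

from typing import Any, List


def process_cached_step(seq_group_metadata_list: List[Any],
                        scheduled_seq_groups: List[Any]) -> List[Any]:
    # With scheduled_ids removed, the pairing is positional, so both lists
    # must line up one-to-one (mirrors the new sanity check).
    assert len(seq_group_metadata_list) == len(scheduled_seq_groups)

    finished_before: List[int] = []  # indices finalized in the first pass
    request_outputs: List[Any] = []

    # First pass: apply outputs / stop checks per scheduled group.
    for i, seq_group_meta in enumerate(seq_group_metadata_list):
        scheduled_seq_group = scheduled_seq_groups[i]
        _ = (seq_group_meta, scheduled_seq_group)  # token appending happens here

    # Second pass: create the outputs, skipping groups already handled.
    for i, _ in enumerate(seq_group_metadata_list):
        if i in finished_before:
            continue  # avoids double processing
        request_outputs.append(scheduled_seq_groups[i])

    return request_outputs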
