multi-step scheduling
SolitaryThinker committed Aug 9, 2024
1 parent 5c6c54d commit 78c4ff8
Showing 13 changed files with 1,090 additions and 65 deletions.
81 changes: 81 additions & 0 deletions tests/worker/test_model_input.py
@@ -10,6 +10,8 @@
 from vllm.worker.embedding_model_runner import (
     ModelInputForGPUWithPoolingMetadata)
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+from vllm.worker.multi_step_model_runner import (
+    MutableModelInputForGPUWithMultiStepMetadata)


 class MockAttentionBackend(AttentionBackend):
@@ -154,3 +156,82 @@ def test_embedding_model_runner_input():
                        None) == getattr(attn_metadata, field.name, None)
     # Pooling metadata is not broadcast.
     assert received_model_input.pooling_metadata is None


+def test_multi_step_model_runner_input():
+    sampling_metadata = SamplingMetadata(
+        ["seq_group"],
+        "selected_token_indices",
+        "categorized_sample_indices",
+        "num_prompts",
+    )
+    attn_metadata = AttentionMetadata(
+        num_prefills=1,
+        num_prefill_tokens=2,
+        num_decode_tokens=3,
+        slot_mapping=torch.zeros(1),
+    )
+    frozen_model_input = ModelInputForGPUWithSamplingMetadata(
+        input_tokens=torch.ones(10),
+        input_positions=torch.ones(10),
+        sampling_metadata=sampling_metadata,
+        attn_metadata=attn_metadata)

+    model_input = MutableModelInputForGPUWithMultiStepMetadata(
+        frozen_model_input=frozen_model_input,
+        is_last_step=True,
+        is_first_multi_step=False,
+        current_step=4,
+        last_sampled_token_ids=torch.ones((10, 1)),
+        is_multi_step=True,
+        num_queries=8,
+        num_seqs=5,
+        outputs=[],
+    )

+    assert isinstance(model_input,
+                      MutableModelInputForGPUWithMultiStepMetadata)

+    # Test round-trip serialization.
+    tensor_dict = model_input.as_broadcastable_tensor_dict()
+    attn_backend = MockAttentionBackend()
+    received_model_input = (MutableModelInputForGPUWithMultiStepMetadata.
+                            from_broadcasted_tensor_dict(
+                                tensor_dict, attn_backend=attn_backend))

+    received_frozen_input = received_model_input.frozen_model_input

+    # Check that the received copy has correct values.
+    assert isinstance(received_model_input,
+                      MutableModelInputForGPUWithMultiStepMetadata)
+    assert received_frozen_input.input_tokens is not None
+    assert (received_frozen_input.input_tokens ==
+            frozen_model_input.input_tokens).all()
+    assert received_frozen_input.input_positions is not None
+    assert (received_frozen_input.input_positions ==
+            frozen_model_input.input_positions).all()
+    assert received_frozen_input.multi_modal_kwargs is None
+    assert (received_frozen_input.multi_modal_kwargs ==
+            frozen_model_input.multi_modal_kwargs)
+    assert received_frozen_input.lora_requests is None
+    assert (received_frozen_input.lora_requests ==
+            frozen_model_input.lora_requests)
+    assert received_frozen_input.lora_mapping is None
+    assert (
+        received_frozen_input.lora_mapping == frozen_model_input.lora_mapping)
+    for field in dataclasses.fields(AttentionMetadata):
+        assert getattr(received_frozen_input.attn_metadata, field.name,
+                       None) == getattr(attn_metadata, field.name, None)
+    # For sampling metadata, only selected_token_indices is copied.
+    assert (received_frozen_input.sampling_metadata.selected_token_indices ==
+            sampling_metadata.selected_token_indices)
+    assert received_frozen_input.sampling_metadata.seq_groups is None

+    # Check that the non-frozen multi-step fields match.
+    assert received_model_input.is_last_step == model_input.is_last_step
+    assert (received_model_input.is_first_multi_step ==
+            model_input.is_first_multi_step)
+    assert received_model_input.current_step == model_input.current_step
+    assert (received_model_input.last_sampled_token_ids ==
+            model_input.last_sampled_token_ids).all()
+    assert received_model_input.is_multi_step == model_input.is_multi_step
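
The new round-trip test can be exercised on its own; a small sketch (not part of the commit) using pytest's programmatic entry point, with the node id taken from the file and test function added above:

import pytest

pytest.main(["-q", "tests/worker/test_model_input.py"
             "::test_multi_step_model_runner_input"])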
14 changes: 13 additions & 1 deletion vllm/config.py
@@ -850,7 +850,8 @@ def __init__(self,
                  delay_factor: float = 0.0,
                  enable_chunked_prefill: bool = False,
                  embedding_mode: Optional[bool] = False,
-                 preemption_mode: Optional[str] = None) -> None:
+                 preemption_mode: Optional[str] = None,
+                 max_forward_calls_per_step: int = 1) -> None:
         if max_num_batched_tokens is not None:
             self.max_num_batched_tokens = max_num_batched_tokens
         else:
@@ -879,6 +880,7 @@ def __init__(self,
         self.chunked_prefill_enabled = enable_chunked_prefill
         self.embedding_mode = embedding_mode
         self.preemption_mode = preemption_mode
+        self.max_forward_calls_per_step = max_forward_calls_per_step
         self._verify_args()

     def _verify_args(self) -> None:
@@ -904,6 +906,16 @@ def _verify_args(self) -> None:
                 f"({self.num_lookahead_slots}) must be greater than or "
                 "equal to 0.")

+        if self.max_forward_calls_per_step < 1:
+            raise ValueError(
+                "max_forward_calls_per_step "
+                f"({self.max_forward_calls_per_step}) must be greater than or "
+                "equal to 1.")

+    @property
+    def is_multi_step(self) -> bool:
+        return self.max_forward_calls_per_step > 1


 class DeviceConfig:
     device: Optional[torch.device]
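A minimal sketch (not part of the commit) of how the new knob surfaces on SchedulerConfig: max_forward_calls_per_step defaults to 1, and is_multi_step simply reports whether more than one forward call is scheduled per engine step. The max_num_seqs and max_model_len values below are illustrative only.

from vllm.config import SchedulerConfig

single_step = SchedulerConfig(max_num_batched_tokens=None,
                              max_num_seqs=256,
                              max_model_len=4096)
multi_step = SchedulerConfig(max_num_batched_tokens=None,
                             max_num_seqs=256,
                             max_model_len=4096,
                             max_forward_calls_per_step=8)

assert not single_step.is_multi_step  # default of 1 keeps single-step behavior
assert multi_step.is_multi_step       # any value > 1 enables multi-step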
5 changes: 5 additions & 0 deletions vllm/core/scheduler.py
@@ -805,6 +805,9 @@ def _schedule_prefills(
                 curr_loras.add(lora_int_id)
             waiting_queue.popleft()
             self._allocate_and_set_running(seq_group)
+            seq_group.init_multi_step(
+                num_lookahead_slots=self._get_num_lookahead_slots(
+                    is_prefill=True))
             seq_groups.append(
                 ScheduledSequenceGroup(seq_group=seq_group,
                                        token_chunk_size=num_new_tokens))
@@ -1108,6 +1111,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
                 computed_block_nums=common_computed_block_nums,
                 encoder_seq_data=encoder_seq_data,
                 cross_block_table=cross_block_table,
+                state=seq_group.state,
                 # `multi_modal_data` will only be present for the 1st comm
                 # between engine and worker.
                 # the subsequent comms can still use delta, but
@@ -1184,6 +1188,7 @@ def _append_slots(
                 slots.
         """
         num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False)
+        seq_group.init_multi_step(num_lookahead_slots=num_lookahead_slots)

         for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
             cows = self.block_manager.append_slots(seq, num_lookahead_slots)
22 changes: 19 additions & 3 deletions vllm/engine/arg_utils.py
@@ -93,6 +93,7 @@ class EngineArgs:
     lora_dtype: str = 'auto'
     max_cpu_loras: Optional[int] = None
     device: str = 'auto'
+    max_forward_calls_per_step: int = 1
     ray_workers_use_nsight: bool = False
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = 0
@@ -506,6 +507,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                 "tpu", "xpu"
                             ],
                             help='Device type for vLLM execution.')
+        parser.add_argument('--max-forward-calls-per-step',
+                            type=int,
+                            default=1,
+                            help='Maximum number of forward calls per step.')

         parser.add_argument(
             '--scheduler-delay-factor',
@@ -820,18 +825,29 @@ def create_engine_config(self, ) -> EngineConfig:
             disable_logprobs=self.disable_logprobs_during_spec_decoding,
         )

+        if (speculative_config is not None
+                and self.max_forward_calls_per_step > 1):
+            raise ValueError("Speculative decoding is not supported with "
+                             "multi-step (--max_forward_calls_per_step > 1)")
+        # Make sure num_lookahead_slots is set to the higher value, depending
+        # on whether we are using speculative decoding or multi-step.
+        num_lookahead_slots = max(self.num_lookahead_slots,
+                                  self.max_forward_calls_per_step - 1)
+        num_lookahead_slots = num_lookahead_slots \
+            if speculative_config is None \
+            else speculative_config.num_lookahead_slots

         scheduler_config = SchedulerConfig(
             max_num_batched_tokens=self.max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
             use_v2_block_manager=self.use_v2_block_manager,
-            num_lookahead_slots=(self.num_lookahead_slots
-                                 if speculative_config is None else
-                                 speculative_config.num_lookahead_slots),
+            num_lookahead_slots=num_lookahead_slots,
             delay_factor=self.scheduler_delay_factor,
             enable_chunked_prefill=self.enable_chunked_prefill,
             embedding_mode=model_config.embedding_mode,
             preemption_mode=self.preemption_mode,
+            max_forward_calls_per_step=self.max_forward_calls_per_step,
         )
         lora_config = LoRAConfig(
             max_lora_rank=self.max_lora_rank,
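A hedged end-to-end sketch (not part of the commit) of how the new flag flows from EngineArgs into the scheduler config. The model name is illustrative, accessing scheduler_config on the returned EngineConfig is assumed from the surrounding vLLM code rather than shown in this diff, and building the config requires the model's Hugging Face config to be reachable.

from vllm.engine.arg_utils import EngineArgs

args = EngineArgs(model="facebook/opt-125m", max_forward_calls_per_step=8)
engine_config = args.create_engine_config()

# With no speculative decoding configured, num_lookahead_slots is raised to
# max_forward_calls_per_step - 1 (7 here) so the block manager reserves room
# for the extra decode steps.
assert engine_config.scheduler_config.is_multi_step
assert engine_config.scheduler_config.num_lookahead_slots == 7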