diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 2126fafb2323b..0244919152cad 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -10,6 +10,8 @@ from vllm.worker.embedding_model_runner import ( ModelInputForGPUWithPoolingMetadata) from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +from vllm.worker.multi_step_model_runner import ( + MutableModelInputForGPUWithMultiStepMetadata) class MockAttentionBackend(AttentionBackend): @@ -154,3 +156,82 @@ def test_embedding_model_runner_input(): None) == getattr(attn_metadata, field.name, None) # Pooling metadata is not broadcast. assert received_model_input.pooling_metadata is None + + +def test_multi_step_model_runner_input(): + sampling_metadata = SamplingMetadata( + ["seq_group"], + "selected_token_indices", + "categorized_sample_indices", + "num_prompts", + ) + attn_metadata = AttentionMetadata( + num_prefills=1, + num_prefill_tokens=2, + num_decode_tokens=3, + slot_mapping=torch.zeros(1), + ) + frozen_model_input = ModelInputForGPUWithSamplingMetadata( + input_tokens=torch.ones(10), + input_positions=torch.ones(10), + sampling_metadata=sampling_metadata, + attn_metadata=attn_metadata) + + model_input = MutableModelInputForGPUWithMultiStepMetadata( + frozen_model_input=frozen_model_input, + is_last_step=True, + is_first_multi_step=False, + current_step=4, + last_sampled_token_ids=torch.ones((10, 1)), + is_multi_step=True, + num_queries=8, + num_seqs=5, + outputs=[], + ) + + assert isinstance(model_input, + MutableModelInputForGPUWithMultiStepMetadata) + + # Test round trip serialization. + tensor_dict = model_input.as_broadcastable_tensor_dict() + attn_backend = MockAttentionBackend() + received_model_input = (MutableModelInputForGPUWithMultiStepMetadata. + from_broadcasted_tensor_dict( + tensor_dict, attn_backend=attn_backend)) + + receieved_frozen_input = received_model_input.frozen_model_input + + # Check that received copy has correct values. + assert isinstance(received_model_input, + MutableModelInputForGPUWithMultiStepMetadata) + assert receieved_frozen_input.input_tokens is not None + assert (receieved_frozen_input.input_tokens == + frozen_model_input.input_tokens).all() + assert receieved_frozen_input.input_positions is not None + assert (receieved_frozen_input.input_positions == + frozen_model_input.input_positions).all() + assert receieved_frozen_input.multi_modal_kwargs is None + assert (frozen_model_input.multi_modal_kwargs == + frozen_model_input.multi_modal_kwargs) + assert receieved_frozen_input.lora_requests is None + assert (receieved_frozen_input.lora_requests == + frozen_model_input.lora_requests) + assert receieved_frozen_input.lora_mapping is None + assert ( + receieved_frozen_input.lora_mapping == frozen_model_input.lora_mapping) + for field in dataclasses.fields(AttentionMetadata): + assert getattr(receieved_frozen_input.attn_metadata, field.name, + None) == getattr(attn_metadata, field.name, None) + # For sampling metadata, only selected_token_indices is copied. 
+ assert (receieved_frozen_input.sampling_metadata.selected_token_indices == + sampling_metadata.selected_token_indices) + assert receieved_frozen_input.sampling_metadata.seq_groups is None + + # check non frozen fields + assert received_model_input.is_last_step == model_input.is_last_step + assert (received_model_input.is_first_multi_step == + model_input.is_first_multi_step) + assert received_model_input.current_step == model_input.current_step + assert (received_model_input.last_sampled_token_ids == + model_input.last_sampled_token_ids).all() + assert received_model_input.is_multi_step == model_input.is_multi_step diff --git a/vllm/config.py b/vllm/config.py index 4207466cfc5c0..347c55f04ab0c 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -850,7 +850,8 @@ def __init__(self, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, embedding_mode: Optional[bool] = False, - preemption_mode: Optional[str] = None) -> None: + preemption_mode: Optional[str] = None, + max_forward_calls_per_step: int = 1) -> None: if max_num_batched_tokens is not None: self.max_num_batched_tokens = max_num_batched_tokens else: @@ -879,6 +880,7 @@ def __init__(self, self.chunked_prefill_enabled = enable_chunked_prefill self.embedding_mode = embedding_mode self.preemption_mode = preemption_mode + self.max_forward_calls_per_step = max_forward_calls_per_step self._verify_args() def _verify_args(self) -> None: @@ -904,6 +906,16 @@ def _verify_args(self) -> None: f"({self.num_lookahead_slots}) must be greater than or " "equal to 0.") + if self.max_forward_calls_per_step < 1: + raise ValueError( + "max_forward_calls_per_step " + f"({self.max_forward_calls_per_step}) must be greater than or " + "equal to 1.") + + @property + def is_multi_step(self) -> bool: + return self.max_forward_calls_per_step > 1 + class DeviceConfig: device: Optional[torch.device] diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index b16850c7eb9f8..2d1bee77bade7 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -805,6 +805,9 @@ def _schedule_prefills( curr_loras.add(lora_int_id) waiting_queue.popleft() self._allocate_and_set_running(seq_group) + seq_group.init_multi_step( + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill=True)) seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) @@ -1108,6 +1111,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: computed_block_nums=common_computed_block_nums, encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, + state=seq_group.state, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but @@ -1184,6 +1188,7 @@ def _append_slots( slots. 
""" num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) + seq_group.init_multi_step(num_lookahead_slots=num_lookahead_slots) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): cows = self.block_manager.append_slots(seq, num_lookahead_slots) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 73698511fdbb7..ec6014d839af4 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -93,6 +93,7 @@ class EngineArgs: lora_dtype: str = 'auto' max_cpu_loras: Optional[int] = None device: str = 'auto' + max_forward_calls_per_step: int = 1 ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 @@ -506,6 +507,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "tpu", "xpu" ], help='Device type for vLLM execution.') + parser.add_argument('--max-forward-calls-per-step', + type=int, + default=1, + help='Maximum number of forward calls per step.') parser.add_argument( '--scheduler-delay-factor', @@ -820,18 +825,29 @@ def create_engine_config(self, ) -> EngineConfig: disable_logprobs=self.disable_logprobs_during_spec_decoding, ) + if (speculative_config is not None + and self.max_forward_calls_per_step > 1): + raise ValueError("Speculative decoding is not supported with " + "multi-step (--max_forward_calls_per_step > 1)") + # make sure num_lookahead_slots is set the higher value depending on + # if we are using speculative decoding or multi-step + num_lookahead_slots = max(self.num_lookahead_slots, + self.max_forward_calls_per_step - 1) + num_lookahead_slots = num_lookahead_slots \ + if speculative_config is None \ + else speculative_config.num_lookahead_slots + scheduler_config = SchedulerConfig( max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, max_model_len=model_config.max_model_len, use_v2_block_manager=self.use_v2_block_manager, - num_lookahead_slots=(self.num_lookahead_slots - if speculative_config is None else - speculative_config.num_lookahead_slots), + num_lookahead_slots=num_lookahead_slots, delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, embedding_mode=model_config.embedding_mode, preemption_mode=self.preemption_mode, + max_forward_calls_per_step=self.max_forward_calls_per_step, ) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 809eb6de9f173..e21b60bac0c45 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -3,9 +3,11 @@ from functools import partial from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) +from dataclasses import dataclass from transformers import PreTrainedTokenizer from typing_extensions import assert_never +import torch import vllm.envs as envs from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig, @@ -27,7 +29,9 @@ from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import (ExecuteModelRequest, SamplerOutput, + SequenceGroupMetadata) + from vllm.usage.usage_lib import UsageContext logger = init_logger(__name__) @@ -248,9 +252,24 @@ def has_new_requests(self): return not self._new_requests.empty() +@dataclass +class SchedulerOutputState: + """Caches the scheduler 
outputs for a virtual engine. Used for Multi-Step""" + last_output: Optional[SamplerOutput] = None + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None + + class _AsyncLLMEngine(LLMEngine): """Extension of LLMEngine to add async methods.""" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pipeline_parallel_size = \ + self.parallel_config.pipeline_parallel_size + self.cached_scheduler_outputs = [SchedulerOutputState() + ] * pipeline_parallel_size + async def step_async( self, virtual_engine: int ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: @@ -263,13 +282,41 @@ async def step_async( and updates the scheduler with the model outputs. Finally, it decodes the sequences and returns the newly generated results. """ - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - virtual_engine].schedule() + # these are cached outputs from previous iterations. None if on first + # iteration + cached_outputs = self.cached_scheduler_outputs[virtual_engine] + seq_group_metadata_list = cached_outputs.seq_group_metadata_list + scheduler_outputs = cached_outputs.scheduler_outputs + # skip the scheduler if there are any remaining steps in the seq groups. + # This ensures that the scheduler is only called again when the current + # batch has completed. + if not self._has_remaining_steps(seq_group_metadata_list): + seq_group_metadata_list, scheduler_outputs = self.scheduler[ + virtual_engine].schedule() + + if (self.scheduler_config.is_multi_step + and scheduler_outputs.num_lookahead_slots > 0): + # cache the scheduler outputs for the next iteration if we have + # lookahead slots + self._cache_scheduler_outputs_for_multi_step( + virtual_engine, seq_group_metadata_list, scheduler_outputs) + + assert seq_group_metadata_list is not None + assert scheduler_outputs is not None if not scheduler_outputs.is_empty(): - # Execute the model. finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() + + # check if we have a cached last_output from the previous iteration + # for PP this is probably the best way to pass the sampled_token_ids + # as a broadcast across stages will cause one virtual engine's stage + # to block another VE. + # None if not multi-step or is first iteration + last_sampled_token_ids = \ + self._get_cached_sampled_token_ids_for_multi_step( + virtual_engine) + execute_model_req = ExecuteModelRequest( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, @@ -278,15 +325,33 @@ async def step_async( virtual_engine=virtual_engine, num_lookahead_slots=scheduler_outputs.num_lookahead_slots, running_queue_size=scheduler_outputs.running_queue_size, - finished_requests_ids=finished_requests_ids) + finished_requests_ids=finished_requests_ids, + last_sampled_token_ids=last_sampled_token_ids) + # Execute the model. output = await self.model_executor.execute_model_async( execute_model_req) + # we need to do this here so that last step's sampled_token_ids can + # be passed to the next iteration for PP. + if self.scheduler_config.is_multi_step: + self._cache_output_for_multi_step(virtual_engine, output) else: output = [] - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + # Finish the current step for all the sequence groups. 
+ if self.scheduler_config.is_multi_step: + for seq_group in seq_group_metadata_list: + seq_group.finish_step() + + if not self._has_remaining_steps(seq_group_metadata_list): + # clear the cache if we have finished all the steps + if self.scheduler_config.is_multi_step: + self.cached_scheduler_outputs[ + virtual_engine] = SchedulerOutputState() + request_outputs = self._process_model_outputs( + output, scheduler_outputs.scheduled_seq_groups, + scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) + else: + request_outputs = [] # Log stats. self.do_log_stats(scheduler_outputs, output) @@ -296,6 +361,66 @@ async def step_async( return request_outputs + def _has_remaining_steps( + self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] + ) -> bool: + if not self.scheduler_config.is_multi_step: + return False + + if seq_group_metadata_list is None: + return False + # TODO(will) this is a sanity check for nowto make sure that all the + # seqs are on the same steps. Eventually we will want to do some sort of + # dynamic scheduling when doing multi-step decoding. + if len(seq_group_metadata_list) == 0: + return False + steps_remaining = [ + seq_group.state.remaining_steps + for seq_group in seq_group_metadata_list + ] + if steps_remaining.count(steps_remaining[0]) != len(steps_remaining): + raise AssertionError(("All running sequence groups should " + "have the same remaining steps.")) + + if any(seq_group.state.remaining_steps > 0 + for seq_group in seq_group_metadata_list): + return True + return False + + def _cache_scheduler_outputs_for_multi_step( + self, virtual_engine: int, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + scheduler_outputs: SchedulerOutputs) -> None: + self.cached_scheduler_outputs[ + virtual_engine].seq_group_metadata_list = seq_group_metadata_list + self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \ + scheduler_outputs + self.cached_scheduler_outputs[virtual_engine].last_output = None + + def _get_cached_sampled_token_ids_for_multi_step( + self, virtual_engine: int) -> Optional[torch.Tensor]: + cached_last_output = self.cached_scheduler_outputs[ + virtual_engine].last_output + if (self.scheduler_config.is_multi_step + and self.parallel_config.pipeline_parallel_size > 1 + and cached_last_output is not None + and cached_last_output.sampled_token_ids_numpy is not None): + return torch.from_numpy(cached_last_output.sampled_token_ids_numpy) + return None + + def _cache_output_for_multi_step( + self, virtual_engine: int, + output: List[Optional[SamplerOutput]]) -> None: + if (self.parallel_config.pipeline_parallel_size > 1): + if len(output) > 0 and output[0] is not None: + last_output = output[-1] + assert last_output is not None + assert last_output.sampled_token_ids_numpy is not None + assert last_output.sampled_token_ids is None + assert last_output.sampled_token_probs is None + self.cached_scheduler_outputs[ + virtual_engine].last_output = last_output + async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" await self.model_executor.stop_remote_worker_execution_loop_async() diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 57b9e2b33b982..38d8fd91b6d4f 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -70,13 +70,19 @@ def _get_create_worker_kwargs( distributed_init_method: Optional[str] = None) -> Dict: worker_kwargs = self._get_worker_kwargs(local_rank, rank, distributed_init_method) - if 
self.speculative_config is None: - worker_kwargs.update(worker_module_name="vllm.worker.worker", - worker_class_name="Worker") - else: + + if self.scheduler_config.is_multi_step: + worker_kwargs.update( + worker_module_name="vllm.worker.multi_step_worker", + worker_class_name="MultiStepWorker") + elif self.speculative_config: worker_kwargs.update( worker_module_name="vllm.spec_decode.spec_decode_worker", worker_class_name="create_spec_worker") + else: + worker_kwargs.update(worker_module_name="vllm.worker.worker", + worker_class_name="Worker") + return worker_kwargs def _create_worker(self, diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 4a6825c01fcf8..fae8307428119 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -79,6 +79,9 @@ def _get_worker_wrapper_args(self) -> Dict[str, Any]: if self.speculative_config is not None: worker_module_name = "vllm.spec_decode.spec_decode_worker" worker_class_name = "create_spec_worker" + elif self.scheduler_config.is_multi_step: + worker_module_name = "vllm.worker.multi_step_worker" + worker_class_name = "MultiStepWorker" else: worker_module_name = "vllm.worker.worker" worker_class_name = "Worker" diff --git a/vllm/sequence.py b/vllm/sequence.py index 7349bc6f13bd6..b9b2cfedfd267 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -9,6 +9,7 @@ Union, cast) import torch +import numpy from vllm.inputs.parse import is_valid_encoder_decoder_llm_inputs from vllm.lora.request import LoRARequest @@ -489,6 +490,17 @@ def __repr__(self) -> str: f"num_blocks={self.n_blocks}, ") +@dataclass +class SequenceGroupState: + """Mutable state tied to a specific sequence group""" + + # for multi-step decoding + num_lookahead_slots: int = 0 + num_steps: int = 1 + remaining_steps: int = 0 + current_step: int = 0 + + class SequenceGroup: """A group of sequences that are generated from the same prompt. @@ -534,6 +546,7 @@ def __init__( time_in_queue=None) self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None + self.state = SequenceGroupState() self.embeddings = embeddings self.pooling_params = pooling_params self.prompt_adapter_request = prompt_adapter_request @@ -588,6 +601,12 @@ def prompt_adapter_num_virtual_tokens(self) -> int: return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ if self.prompt_adapter_request else 0 + def init_multi_step(self, num_lookahead_slots: int) -> None: + self.state.num_lookahead_slots = num_lookahead_slots + self.state.num_steps = num_lookahead_slots + 1 + self.state.remaining_steps = num_lookahead_slots + 1 + self.state.current_step = 0 + def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error. @@ -756,6 +775,7 @@ class SequenceGroupMetadata: lora_request: LoRA request. computed_block_nums: The block numbers that are already computed, used in prefix caching. + state: Internal state tied to this sequence group. multi_modal_data: Multi modal data. encoder_seq_data: Optional sequence data for encoder prompt (SequenceGroup.encoder_seq). 
Should be None @@ -781,6 +801,7 @@ def __init__( token_chunk_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, computed_block_nums: Optional[List[int]] = None, + state: Optional[SequenceGroupState] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, @@ -796,6 +817,7 @@ def __init__( self.prompt_adapter_request = prompt_adapter_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data + self.state = SequenceGroupState() if state is None else state self.encoder_seq_data = encoder_seq_data self.cross_block_table = cross_block_table self._token_chunk_size = token_chunk_size @@ -834,6 +856,12 @@ def token_chunk_size(self) -> int: assert self._token_chunk_size is not None return self._token_chunk_size + def finish_step(self) -> None: + assert self.state.current_step < self.state.num_steps + self.state.current_step += 1 + self.state.remaining_steps -= 1 + assert self.state.remaining_steps >= 0 + class SequenceOutput: """The model output associated with a sequence. @@ -971,6 +999,8 @@ class SamplerOutput: # On-device tensor containing the sampled token ids. sampled_token_ids: Optional[torch.Tensor] = None + # sampled_token_ids_numpy: Optional[List[int]] = None + sampled_token_ids_numpy: Optional[numpy.ndarray] = None # Spec decode metrics populated by workers. spec_decode_worker_metrics: Optional["SpecDecodeWorkerMetrics"] = None @@ -1112,6 +1142,28 @@ class ExecuteModelRequest: num_steps: int = 1 # Finished request ids since last step. finished_requests_ids: List[str] = field(default_factory=list) + # The last sampled token ids for multi step decoding. + last_sampled_token_ids: Optional[torch.Tensor] = None + + @property + def is_first_multi_step(self) -> bool: + # TODO(will) make this be able to handle batches with variable number of steps + assert len(self.seq_group_metadata_list) > 0 + first_seq_group = self.seq_group_metadata_list[0] + return first_seq_group.state.current_step == 0 + + @property + def is_last_step(self) -> bool: + # TODO(will) make this be able to handle batches with variable number of steps + assert len(self.seq_group_metadata_list) > 0 + first_seq_group = self.seq_group_metadata_list[0] + return first_seq_group.state.remaining_steps == 1 + + @property + def current_step(self) -> int: + # TODO(will) make this be able to handle batches with variable number of steps + assert len(self.seq_group_metadata_list) > 0 + return self.seq_group_metadata_list[0].state.current_step def clone( self, seq_group_metadata_list: List[SequenceGroupMetadata] @@ -1127,4 +1179,6 @@ def clone( running_queue_size=self.running_queue_size, previous_hidden_states=self.previous_hidden_states, num_steps=self.num_steps, - finished_requests_ids=self.finished_requests_ids) + finished_requests_ids=self.finished_requests_ids, + last_sampled_token_ids=self.last_sampled_token_ids.clone() + if self.last_sampled_token_ids is not None else None) diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 46ac16b504bf4..90c39407d7266 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -14,7 +14,7 @@ from vllm.attention.backends.abstract import AttentionBackend from vllm.model_executor import SamplingMetadata -T = TypeVar('T', bound="ModelRunnerInputBase") +T = TypeVar('T', bound="BroadcastableModelInput") def _add_attn_metadata_broadcastable_dict( @@ -81,18 +81,26 @@ def 
_add_sampling_metadata_broadcastable_dict( sampling_metadata.selected_token_indices) -@dataclasses.dataclass(frozen=True) -class ModelRunnerInputBase(ABC): - """Local inputs to each worker's model runner. May contain - device-specific data. Different worker backends may have different methods - of converting from the global ExecuteModelRequest produced by the LLM - engine to the worker-local ModelRunnerInputBase objects. - - Model runners that support multi-GPU execution should define a - ModelRunnerInputBase subclass, add their required fields, and specify how to - serialize/deserialize a ModelInput for broadcast between workers. +def _init_frozen_model_input_from_tensor_dict( + frozen_model_input_cls: Type["ModelRunnerInputBase"], + tensor_dict: Dict[str, Any]) -> Dict[str, Any]: """ + Helper method to initialize a frozen ModelInput based on broadcastable + """ + valid_tensor_kwargs = {} + for field in dataclasses.fields(frozen_model_input_cls): + val = tensor_dict.pop(field.name, None) + if val is not None: + valid_tensor_kwargs[field.name] = val + + frozen_model_input = frozen_model_input_cls(**valid_tensor_kwargs) + tensor_dict["frozen_model_input"] = frozen_model_input + return tensor_dict + +class BroadcastableModelInput(ABC): + + @abstractmethod def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: """ Extract broadcastable fields. Override for fields that require some @@ -109,11 +117,25 @@ def from_broadcasted_tensor_dict( ) -> T: """ Pop fields from the given tensor_dict and populate a new instance of - ModelRunnerInputBase. + BroadcastableModelInput. """ raise NotImplementedError +@dataclasses.dataclass(frozen=True) +class ModelRunnerInputBase(BroadcastableModelInput): + """Local inputs to each worker's model runner. May contain + device-specific data. Different worker backends may have different methods + of converting from the global ExecuteModelRequest produced by the LLM + engine to the worker-local ModelRunnerInputBase objects. + + Model runners that support multi-GPU execution should define a + ModelRunnerInputBase subclass, add their required fields, and specify how to + serialize/deserialize a ModelInput for broadcast between workers. + """ + pass + + class ModelRunnerInputBuilderBase(ABC, Generic[T]): """A builder to create ModelRunnerInputBase objects. 
""" diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py new file mode 100644 index 0000000000000..76a8221ea1c37 --- /dev/null +++ b/vllm/worker/multi_step_model_runner.py @@ -0,0 +1,493 @@ +import dataclasses +from dataclasses import dataclass, field +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Union) +try: + from vllm.attention.backends.flash_attn import FlashAttentionMetadata +except ModuleNotFoundError: + # vllm_flash_attn is not installed, use the identical ROCm FA metadata + from vllm.attention.backends.rocm_flash_attn import ( + ROCmFlashAttentionMetadata as FlashAttentionMetadata) + +from ..model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.worker.model_runner_base import ( + BroadcastableModelInput, _init_frozen_model_input_from_tensor_dict, + _init_attn_metadata_from_tensor_dict, + _init_sampling_metadata_from_tensor_dict) +from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, + GPUModelRunnerBase) +from vllm.logger import init_logger +from vllm.distributed import get_pp_group +from vllm.sequence import (IntermediateTensors, SamplerOutput, + SequenceGroupMetadata, SequenceOutput, + CompletionSequenceGroupOutput, Logprob) +from vllm import _custom_ops as ops + +import torch + +if TYPE_CHECKING: + from vllm.attention.backends.abstract import AttentionBackend + +logger = init_logger(__name__) + + +@dataclass +class ModelOutput: + """The output of a single model forward pass. + + The sampler_output_ready_event is set when the tensors in + sampler_output are ready (the model+sampler forward pass has + completed). We use the event to synchronize the GPU->CPU transfer, + which we want to only run when the data has been written to the + GPU tensors. Until the event is ready, the tensors in sampler_output + will have garbage data. + + There are two scenarios: + 1. The output tensors are ready and we can pythonize them immediately. + 2. The output tensors are not ready and we need to wait for the event to be ready. + """ + sampler_output: SamplerOutput + sampler_output_ready_event: torch.cuda.Event + sampled_token_ids: Optional[torch.Tensor] = None + pythonized: bool = False + + def pythonize( + self, + input_metadata: "MutableModelInputForGPUWithMultiStepMetadata", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output. Blocking.""" + if not self.pythonized: + self._pythonize_sampler_output_wait_on_event( + input_metadata, copy_stream, pinned_sampled_token_buffer) + self.pythonized = True + + def maybe_pythonize( + self, + input_metadata: "MutableModelInputForGPUWithMultiStepMetadata", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + """Pythonize the output if ready, else return None. 
Non-blocking.""" + if not self.pythonized: + self.pythonized = self._pythonize_sampler_output_if_event_ready( + input_metadata, copy_stream, pinned_sampled_token_buffer) + + def _pythonize_sampler_output_wait_on_event( + self, + input_metadata: "MutableModelInputForGPUWithMultiStepMetadata", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> None: + self.sampler_output_ready_event.synchronize() + with torch.cuda.stream(copy_stream): + _pythonize_sampler_output(input_metadata, self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids) + + def _pythonize_sampler_output_if_event_ready( + self, + input_metadata: "MutableModelInputForGPUWithMultiStepMetadata", + copy_stream: torch.cuda.Stream, + pinned_sampled_token_buffer: torch.Tensor) -> bool: + if self.sampler_output_ready_event.query(): + with torch.cuda.stream(copy_stream): + _pythonize_sampler_output(input_metadata, self.sampler_output, + pinned_sampled_token_buffer, + self.sampled_token_ids) + return True + return False + + +@dataclass(frozen=False) +class MutableModelInputForGPUWithMultiStepMetadata(BroadcastableModelInput): + # actual frozen model input dataclass passed to _base_model_runner + frozen_model_input: Optional[ModelInputForGPUWithSamplingMetadata] = None + # list of model outputs for each step, may not be all pythonized + outputs: List[ModelOutput] = field(default_factory=list) + # used to pass sampled token ids from the last step to the current step for + # TP workers. Used to append to end of outputs and used by advance_step + last_sampled_token_ids: Optional[torch.Tensor] = None + current_step: int = 0 + is_multi_step: bool = True + is_last_step: bool = False + is_first_multi_step: bool = False + step_cuda_events: List[torch.cuda.Event] = field( + default_factory=lambda: [torch.cuda.Event(blocking=True)] * 2) + num_seqs: int = -1 + num_queries: int = -1 + + def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: + assert self.frozen_model_input is not None + tensor_dict = self.frozen_model_input.as_broadcastable_tensor_dict() + new_tensor_dict = { + 'last_sampled_token_ids': self.last_sampled_token_ids, + 'current_step': self.current_step, + 'is_multi_step': self.is_multi_step, + 'is_last_step': self.is_last_step, + 'is_first_multi_step': self.is_first_multi_step, + 'num_seqs': self.num_seqs, + 'num_queries': self.num_queries, + } + tensor_dict.update(new_tensor_dict) + return tensor_dict + + @classmethod + def from_broadcasted_tensor_dict( + cls, + tensor_dict: Dict[str, Any], + attn_backend: Optional["AttentionBackend"] = None, + ) -> "MutableModelInputForGPUWithMultiStepMetadata": + tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict) + if attn_backend is not None: + tensor_dict = _init_attn_metadata_from_tensor_dict( + attn_backend, tensor_dict) + tensor_dict = _init_frozen_model_input_from_tensor_dict( + ModelInputForGPUWithSamplingMetadata, tensor_dict) + + return cls(**tensor_dict) + + def record_step_event(self, current_stream: torch.cuda.Stream): + self.step_cuda_events[self.current_step % + 2] = torch.cuda.Event(blocking=True) + self.step_cuda_events[self.current_step % 2].record(current_stream) + + def wait_previous_step(self): + self.step_cuda_events[(self.current_step + 1) % 2].wait() + + def add_sampler_output(self, + sampler_output: SamplerOutput, + sampled_token_ids: Optional[torch.Tensor] = None): + self.outputs.append( + ModelOutput(sampler_output=sampler_output, + sampler_output_ready_event=None, + sampled_token_ids=sampled_token_ids, 
+ pythonized=False)) + + +# MutableModelInputForGPUWithMultiStepMetadata is not subclass of +# ModelInputForGPU but it wraps the actual input dataclass and adds multi-step +# metadata +# mypy: disable-error-code=type-var +class MultiStepModelRunnerBase( + GPUModelRunnerBase[MutableModelInputForGPUWithMultiStepMetadata]): + # mypy: enable-error-code=type-var + + def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): + super().__init__(*args, **kwargs) + + # uses the base model runner to execute the model and wraps it with + # multi-step logic + self._base_model_runner: GPUModelRunnerBase = base_model_runner + + self.is_multi_step = self.scheduler_config.is_multi_step + # used to copy tensors from GPU to CPU asynchronously + self._copy_stream = torch.cuda.Stream() + self.pinned_sampled_token_ids: Optional[torch.Tensor] = None + + def load_model(self) -> None: + return self._base_model_runner.load_model() + + def save_sharded_state( + self, + path: str, + pattern: Optional[str] = None, + max_size: Optional[int] = None, + ) -> None: + return self._base_model_runner.save_sharded_state( + path, pattern, max_size) + + def save_tensorized_model(self, + tensorizer_config: TensorizerConfig) -> None: + return self._base_model_runner.save_tensorized_model(tensorizer_config) + + def profile_run(self) -> None: + return self._base_model_runner.profile_run() + + def remove_all_loras(self): + return self._base_model_runner.remove_all_loras() + + def capture_model(self, kv_caches: List[List]) -> None: + return self._base_model_runner.capture_model(kv_caches) + + @property + def vocab_size(self) -> int: + return self._base_model_runner.vocab_size + + +class MultiStepModelRunner(MultiStepModelRunnerBase): + + def make_model_input_from_broadcasted_tensor_dict( + self, tensor_dict: Dict[str, Any] + ) -> MutableModelInputForGPUWithMultiStepMetadata: + model_input = MutableModelInputForGPUWithMultiStepMetadata.from_broadcasted_tensor_dict( + tensor_dict, + attn_backend=self.attn_backend, + ) + return model_input + + def prepare_model_input( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + virtual_engine: int = 0, + finished_requests_ids: Optional[List[str]] = None + ) -> MutableModelInputForGPUWithMultiStepMetadata: + frozen_model_input = self._base_model_runner.prepare_model_input( + seq_group_metadata_list, virtual_engine, finished_requests_ids) + + model_input = MutableModelInputForGPUWithMultiStepMetadata( + frozen_model_input=frozen_model_input, + num_seqs=len(frozen_model_input.seq_lens), + num_queries=len(frozen_model_input.query_lens), + ) + return model_input + + @torch.inference_mode() + def execute_model( + self, + model_input: MutableModelInputForGPUWithMultiStepMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: + """ + Execute the model for a single step and update multi-step + metadata + """ + assert num_steps == 1, "MultiStepModelRunner only supports num_steps=1" + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + + # path for warm up runs + if not model_input.is_multi_step: + return self._base_model_runner.execute_model( + frozen_model_input, kv_caches, intermediate_tensors, num_steps) + + # make sure we skip the sampler on the lask rank and only pythonize + # if CPU is ahead. 
+ if self.is_driver_worker and get_pp_group().is_last_rank: + + if self.pinned_sampled_token_ids is None: + self.pinned_sampled_token_ids = torch.zeros( + (self.scheduler_config.max_num_seqs, 1), + dtype=torch.long, + device="cpu", + pin_memory=True) + + self._base_model_runner.model.sampler.include_gpu_probs_tensor = True + if frozen_model_input.sampling_metadata: + frozen_model_input.sampling_metadata.skip_sampler_cpu_output = True + # TODO(will) Will need to benchmark and look at torch profiler for + # the exact location we should do this. If the CPU is very ahead, it + # does not matter if we call this before executable or after, as the + # CPU will block anyways. + for model_output in model_input.outputs: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + # some pre-execute model logic for multi-step: + # - if it's the first step, we need to reset the sampling tensors + # - if it's not the first step, we need to advance the step using the + # appended sampler output from last iteration + # - also maybe pythonize if CPU is ahead of GPU + + # explicitly block on the previous step's forward to make sure we + # don't clobber any GPU tensors still in use + current_stream = torch.cuda.current_stream() + if model_input.is_first_multi_step: + # TODO(will) Need to double check that this is not possible due to + # changing batch sizes, will remove afterwards and potentially leave + # comment for future optimization + if frozen_model_input.sampling_metadata: + frozen_model_input.sampling_metadata.reuse_sampling_tensors = False + else: + # This is not needed for flashattn backend, but for other attn + # backends such as flashinfer that performs we may need to + # synchronize any CPU operations that might clobber enqueued + # forwards. (prevents CPU from running too far ahead if needed) + model_input.wait_previous_step() + model_input = self._advance_step( + model_input, model_input.outputs[-1].sampler_output) + # TODO(will) Need to double check that this is not possible due to + # changing batch sizes, will remove afterwards and potentially leave + # comment for future optimization + if frozen_model_input.sampling_metadata: + frozen_model_input.sampling_metadata.reuse_sampling_tensors = False + + # Execute the model + output = self._base_model_runner.execute_model(frozen_model_input, + kv_caches, + intermediate_tensors, + num_steps=1) + + # record the event for the current step so that the next step can sync + model_input.record_step_event(current_stream) + + if get_pp_group().is_last_rank and self.is_driver_worker: + assert len( + output + ) == 1, "MultiStepModelRunner requires single-step base_models" + + # event for the pythonization so that we only pythonize if the + # tensors are ready. 
May be able to be combined with the step event + output_ready_event = torch.cuda.Event() + output_ready_event.record(current_stream) + if self.parallel_config.pipeline_parallel_size > 1: + output[0].sampled_token_ids_numpy = output[ + 0].sampled_token_ids.numpy(force=True) + model_input.outputs.append( + ModelOutput(output[0], output_ready_event, + output[0].sampled_token_ids, False)) + # make sure we dont try to serialize any GPU tensors + output[0].sampled_token_ids = None + output[0].sampled_token_probs = None + output[0].logprobs = None + + model_input.current_step += 1 + + if not get_pp_group().is_last_rank: + # Should be IntermediateTensors + assert isinstance(output, IntermediateTensors) + return output + if not self.is_driver_worker: + return [] + + # Pythonize the output and block if needed since it is the last step + if model_input.is_last_step: + outputs = [] + for output in model_input.outputs: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + outputs.append(output.sampler_output) + return outputs + + # should be [SamplerOutput] + return output + + def _update_flash_attn_metadata(self, attn_metadata, num_seqs, + num_queries): + assert isinstance(attn_metadata, FlashAttentionMetadata) + + if num_seqs != num_queries: + assert num_seqs > num_queries + assert attn_metadata.use_cuda_graph + + assert attn_metadata.num_prefills == 0 + assert attn_metadata.num_prefill_tokens == 0 + assert attn_metadata.num_decode_tokens == num_seqs + assert attn_metadata.slot_mapping.shape == (num_seqs, ) + + assert len(attn_metadata.seq_lens) == num_seqs + assert attn_metadata.seq_lens_tensor.shape == (num_seqs, ) + assert attn_metadata.max_query_len == 1 + assert attn_metadata.max_prefill_seq_len == 0 + assert attn_metadata.max_decode_seq_len == max(attn_metadata.seq_lens) + + assert attn_metadata.query_start_loc.shape == (num_queries + 1, ) + assert attn_metadata.seq_start_loc.shape == (num_seqs + 1, ) + + assert attn_metadata.context_lens_tensor.shape == (num_queries, ) + + assert attn_metadata.block_tables.shape[0] == num_seqs + + # Update query lengths. 
Note that we update only queries and not seqs, + # since tensors may be padded due to captured cuda graph batch size + for i in range(num_queries): + attn_metadata.seq_lens[i] += 1 + attn_metadata.max_decode_seq_len = max(attn_metadata.seq_lens) + + def _update_sampling_metadata(self, sampling_metadata, num_seqs, + num_queries): + + assert sampling_metadata.num_prompts == 0 + assert len(sampling_metadata.seq_groups) == num_queries + assert sampling_metadata.selected_token_indices.shape == ( + num_queries, ) + # assert sampling_metadata.categorized_sample_indices == TODO: Add if needed # noqa: E501 + + # Verify that all sequences are decodes + for i in range(num_queries): + seq_group = sampling_metadata.seq_groups[i] + + assert seq_group.is_prompt is False # No prompt + assert seq_group.prompt_logprob_indices == [] # No prompt + assert seq_group.sample_indices == [i] # Simple + assert seq_group.seq_len is None # Decode + assert seq_group.query_len is None # Decode + + def _advance_step( + self, model_input: MutableModelInputForGPUWithMultiStepMetadata, + out: SamplerOutput + ) -> MutableModelInputForGPUWithMultiStepMetadata: + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None + assert frozen_model_input.attn_metadata is not None + + num_seqs = model_input.num_seqs + num_queries = model_input.num_queries + assert num_seqs > 0 + assert num_queries > 0 + assert num_seqs >= num_queries + + attn_metadata = frozen_model_input.attn_metadata + assert isinstance(attn_metadata, FlashAttentionMetadata) + self._update_flash_attn_metadata(attn_metadata, num_seqs, num_queries) + + # Update GPU tensors + ops.advance_step( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=self.block_size, + input_tokens=frozen_model_input.input_tokens, + sampled_token_ids=model_input.outputs[-1].sampled_token_ids, + input_positions=frozen_model_input.input_positions, + seq_lens=attn_metadata.seq_lens_tensor, + slot_mapping=attn_metadata.slot_mapping, + block_tables=attn_metadata.block_tables) + + if frozen_model_input.seq_lens is not None: + for i in range(num_queries): + frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i] + + return model_input + + +def _pythonize_sampler_output( + model_input: MutableModelInputForGPUWithMultiStepMetadata, + output: SamplerOutput, pinned_sampled_token_buffer: torch.Tensor, + sampled_token_ids: Optional[torch.Tensor]) -> SamplerOutput: + """ This function is only called when the output tensors are ready. 
+ See ModelOutput + """ + + assert sampled_token_ids is not None + assert model_input.frozen_model_input is not None + + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input.sampling_metadata is not None + # samples generation should have been skipped + assert not output.outputs + + pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries] + + # CPU GPU sync + pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False) + + # this will not block as the tensors are already on CPU + samples_list = pinned_buffer.tolist() + + sampling_metadata = frozen_model_input.sampling_metadata + + for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, + samples_list): + seq_ids = seq_group.seq_ids + next_token_ids = sample_result + parent_ids = [0] + seq_outputs: List[SequenceOutput] = [] + assert len(seq_group.sampling_params.logits_processors) == 0, ( + "Logits Processors are not supported in multi-step decoding") + for parent_id, next_token_id in zip(parent_ids, next_token_ids): + # TODO(will): support logprobs + # Hard coded logprob + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + {next_token_id: Logprob(logprob=42)})) + output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None)) + assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py new file mode 100644 index 0000000000000..0b1342347bc5a --- /dev/null +++ b/vllm/worker/multi_step_worker.py @@ -0,0 +1,154 @@ +from vllm.worker.worker import Worker +from dataclasses import dataclass +from vllm.worker.worker import WorkerInput +from vllm.worker.model_runner_base import BroadcastableModelInput +from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.distributed import broadcast_tensor_dict, get_pp_group +from typing import Tuple, Optional, List +from dataclasses import field + +from vllm.worker.multi_step_model_runner import ( + MutableModelInputForGPUWithMultiStepMetadata) + + +@dataclass +class MultiStepState: + worker_input: WorkerInput + model_input: MutableModelInputForGPUWithMultiStepMetadata + + +class MultiStepWorker(Worker): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + pipeline_parallel_size = self.parallel_config.pipeline_parallel_size + self.multi_step_states: List[ + Optional[MultiStepState]] = [None] * pipeline_parallel_size + self.temp_output = None + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput]: + """ + Get the driver input and broadcast it to other workers. 
+ """ + assert self.is_driver_worker + virtual_engine = execute_model_req.virtual_engine + is_first_multi_step = execute_model_req.is_first_multi_step + if is_first_multi_step: + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: MutableModelInputForGPUWithMultiStepMetadata = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + else: + multi_step_state = self.multi_step_states[virtual_engine] + worker_input = multi_step_state.worker_input + model_input = multi_step_state.model_input + + model_input.is_first_multi_step = is_first_multi_step + model_input.is_last_step = execute_model_req.is_last_step + + # we broadcast the last sampled token ids to all TP workers so they can + # update their model input metadata inplace. + if not is_first_multi_step: + if get_pp_group().is_last_rank: + assert model_input.outputs[ + -1].sampler_output.sampled_token_ids is None + assert model_input.outputs[-1].sampled_token_ids is not None + model_input.last_sampled_token_ids = model_input.outputs[ + -1].sampled_token_ids + # free sampled token ids from the previous step if it has been + # pythonized. Cannot free the last sampled token ids because + # we need it for GPU advance_step. + for output in model_input.outputs[:-1]: + if output.pythonized: + output.sampled_token_ids = None + else: + # otherwise we need to get the cached sampled token ids from the + # execute_model_req + assert execute_model_req.last_sampled_token_ids is not None + model_input.last_sampled_token_ids = execute_model_req.last_sampled_token_ids.cuda( + ) + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + # free sampled token ids from the previous step. + # TODO(will) we could reuse the sampled token ids tensor from + # the previous step instead. + for output in model_input.outputs[:-1]: + output.sampled_token_ids = None + assert model_input.outputs[-1].sampled_token_ids is not None + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + + return model_input, worker_input + + def prepare_input( + self, + execute_model_req: Optional[ExecuteModelRequest] = None, + ) -> Optional[Tuple[MutableModelInputForGPUWithMultiStepMetadata, + WorkerInput]]: + """ + Depending on the current state of the request and multi step worker, + this method may skip the normal _prepare_model_input and + _prepare_worker_input methods and instead used cached values. + """ + if self.is_driver_worker: + if execute_model_req is None: + if self.do_metadata_broadcast: + # This signals that there's no more requests to process for + # now. All workers are running infinite loop with + # broadcast_tensor_dict, and it stops the loop when the + # driver broadcasts an empty input. Send an empty input to + # notify all other workers to stop their execution loop. 
+ broadcast_tensor_dict({}, src=0) + return None + + virtual_engine = execute_model_req.virtual_engine + model_input, worker_input = self._get_driver_input_and_broadcast( + execute_model_req) + assert isinstance(model_input, + MutableModelInputForGPUWithMultiStepMetadata) + if execute_model_req.is_first_multi_step: + # cache the worker input and model input for the next steps + self.multi_step_states[virtual_engine] = MultiStepState( + worker_input=worker_input, model_input=model_input) + # if TP workers + else: + broadcast_data = self._get_worker_input_from_broadcast() + # if the driver has sent an empty input, we should stop the worker + # loop + if broadcast_data is None: + return None + model_input, worker_input = broadcast_data + assert isinstance(model_input, + MutableModelInputForGPUWithMultiStepMetadata) + virtual_engine = worker_input.virtual_engine + if model_input.is_first_multi_step: + pass + # cache the worker input and model input for the next steps + # TODO(will) see below + else: + # TODO(will) possible to also use the cached worker input and + # model input this can be done if we want to optimize the + # broadcast to only send the last sampled token ids for + # non-first multi steps + + assert isinstance( + model_input, MutableModelInputForGPUWithMultiStepMetadata) + # we need to update the last sampled token ids in the model input + # for the workers so that they can run inplace advance_step + model_input.add_sampler_output( + SamplerOutput(outputs=[], sampled_token_ids=None), + model_input.last_sampled_token_ids) + + assert model_input is not None + assert worker_input is not None + return model_input, worker_input diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 90b844bf42139..d419b58bbb2a7 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -23,6 +23,7 @@ from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner +from vllm.worker.multi_step_model_runner import MultiStepModelRunner from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput @@ -107,6 +108,26 @@ def __init__( observability_config=observability_config, **speculative_args, ) + + # for multi-step model, wrap the model runner with MultiStepModelRunner + if self.scheduler_config.is_multi_step: + base_model_runner = self.model_runner + self.model_runner = MultiStepModelRunner( + base_model_runner, + model_config, + parallel_config, + scheduler_config, + device_config, + cache_config, + load_config=load_config, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=is_driver_worker, + prompt_adapter_config=prompt_adapter_config, + multimodal_config=multimodal_config, + **speculative_args, + ) + # Uninitialized cache engine. Will be initialized by # initialize_cache. self.cache_engine: List[CacheEngine] @@ -264,6 +285,7 @@ def kv_cache(self) -> Optional[List[List[torch.Tensor]]]: def prepare_worker_input( self, execute_model_req: ExecuteModelRequest) -> WorkerInput: virtual_engine = execute_model_req.virtual_engine + num_steps = execute_model_req.num_steps num_seq_groups = len(execute_model_req.seq_group_metadata_list) # `blocks_to_swap_in` and `blocks_to_swap_out` are cpu tensors. # they contain parameters to launch cudamemcpyasync. 
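
For orientation, a minimal sketch (not part of the patch) of the per-sequence-group step accounting that `SequenceGroupState` in vllm/sequence.py provides, which the `num_steps` plumbing in the worker hunks below mirrors on the worker side; the concrete lookahead value is a placeholder.

```python
# Illustrative sketch only: mirrors SequenceGroup.init_multi_step() and
# SequenceGroupMetadata.finish_step() from the vllm/sequence.py hunk above.
from vllm.sequence import SequenceGroupState

state = SequenceGroupState()

# init_multi_step(num_lookahead_slots=7): 7 lookahead slots -> 8 forward calls.
state.num_lookahead_slots = 7
state.num_steps = 7 + 1
state.remaining_steps = 7 + 1
state.current_step = 0

# finish_step() runs once per forward call of the multi-step batch; the engine
# only schedules a new batch once remaining_steps reaches 0.
for _ in range(state.num_steps):
    assert state.current_step < state.num_steps
    state.current_step += 1
    state.remaining_steps -= 1

assert state.remaining_steps == 0
```
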
@@ -286,6 +308,7 @@ def prepare_worker_input( blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, virtual_engine=virtual_engine, + num_steps=num_steps, ) @torch.inference_mode() diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 20db3dad1caab..6cfea94e56ab4 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -16,7 +16,9 @@ SamplerOutput) from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) -from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase +from vllm.worker.model_runner_base import (ModelRunnerBase, + ModelRunnerInputBase, + BroadcastableModelInput) logger = init_logger(__name__) @@ -129,6 +131,7 @@ class WorkerInput: blocks_to_swap_out: Optional[torch.Tensor] = None blocks_to_copy: Optional[torch.Tensor] = None virtual_engine: int = 0 + num_steps: int = 1 @classmethod def from_broadcasted_tensor_dict( @@ -145,6 +148,7 @@ def from_broadcasted_tensor_dict( blocks_to_swap_out=tensor_dict.pop("blocks_to_swap_out"), blocks_to_copy=tensor_dict.pop("blocks_to_copy"), virtual_engine=tensor_dict["virtual_engine"], + num_steps=tensor_dict.pop("num_steps"), ) def as_broadcastable_tensor_dict( @@ -158,6 +162,7 @@ def as_broadcastable_tensor_dict( "blocks_to_swap_out": self.blocks_to_swap_out, "blocks_to_copy": self.blocks_to_copy, "virtual_engine": self.virtual_engine, + "num_steps": self.num_steps, } return tensor_dict @@ -216,51 +221,77 @@ def execute_worker(self, worker_input: WorkerInput) -> None: """ raise NotImplementedError - def execute_model( + def _get_worker_input_from_broadcast( + self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: + """ + Get the worker input from the broadcasted tensor dict. + """ + assert self.do_metadata_broadcast + assert not self.is_driver_worker + broadcast_data = broadcast_tensor_dict(src=0) + if not broadcast_data: + return None + + worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data) + model_input = ( + self.model_runner.make_model_input_from_broadcasted_tensor_dict( + broadcast_data)) + + return model_input, worker_input + + def _get_driver_input_and_broadcast( + self, execute_model_req: ExecuteModelRequest + ) -> Tuple[BroadcastableModelInput, WorkerInput]: + """ + Get the driver input and broadcast it to other workers. + """ + assert self.is_driver_worker + worker_input: WorkerInput = self.prepare_worker_input( + execute_model_req=execute_model_req) + model_input: ModelRunnerInputBase = ( + self.model_runner.prepare_model_input( + execute_model_req.seq_group_metadata_list, + execute_model_req.virtual_engine, + execute_model_req.finished_requests_ids)) + + if self.do_metadata_broadcast: + broadcast_data = worker_input.as_broadcastable_tensor_dict() + broadcast_data.update(model_input.as_broadcastable_tensor_dict()) + broadcast_tensor_dict(broadcast_data, src=0) + + return model_input, worker_input + + def prepare_input( self, execute_model_req: Optional[ExecuteModelRequest] = None - ) -> Optional[List[SamplerOutput]]: - """Executes at least one model step on the given sequences, unless no - sequences are provided.""" - start_time = time.perf_counter() + ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]: + """ + Prepare the inputs to ModelRunner and Worker. + """ if self.is_driver_worker: if execute_model_req is None: if self.do_metadata_broadcast: - # This signals that there's no more requests to process for - # now. 
All workers are running infinite loop with - # broadcast_tensor_dict, and it stops the loop when the - # driver broadcasts an empty input. Send an empty input to - # notify all other workers to stop their execution loop. broadcast_tensor_dict({}, src=0) return None - worker_input: WorkerInput = self.prepare_worker_input( - execute_model_req=execute_model_req) - model_input: ModelRunnerInputBase = ( - self.model_runner.prepare_model_input( - execute_model_req.seq_group_metadata_list, - execute_model_req.virtual_engine, - execute_model_req.finished_requests_ids)) - num_steps = execute_model_req.num_steps - - if self.do_metadata_broadcast: - broadcast_data = worker_input.as_broadcastable_tensor_dict() - broadcast_data.update( - model_input.as_broadcastable_tensor_dict()) - broadcast_data["num_steps"] = num_steps - broadcast_tensor_dict(broadcast_data, src=0) + return self._get_driver_input_and_broadcast(execute_model_req) else: - assert self.do_metadata_broadcast - broadcast_data = broadcast_tensor_dict(src=0) - if not broadcast_data: - return None + return self._get_worker_input_from_broadcast() + + def execute_model( + self, + execute_model_req: Optional[ExecuteModelRequest] = None + ) -> Optional[List[SamplerOutput]]: + """Executes at least one model step on the given sequences, unless no + sequences are provided.""" + start_time = time.perf_counter() + + inputs = self.prepare_input(execute_model_req) + if inputs is None: + return None - num_steps = broadcast_data.pop("num_steps") - worker_input = WorkerInput.from_broadcasted_tensor_dict( - broadcast_data) - model_input = ( - self.model_runner. - make_model_input_from_broadcasted_tensor_dict(broadcast_data)) + model_input, worker_input = inputs + num_steps = worker_input.num_steps self.execute_worker(worker_input)
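
To tie the pieces together, a hedged end-to-end sketch of enabling multi-step decoding through the new engine argument; the model name and the step count are placeholders, not an example taken from the patch itself.

```python
# Hypothetical usage sketch of the max_forward_calls_per_step knob added in
# this PR; the model name and the value 8 are placeholders.
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

engine_args = AsyncEngineArgs(
    model="facebook/opt-125m",       # placeholder model
    max_forward_calls_per_step=8,    # >1 turns on SchedulerConfig.is_multi_step
)
# create_engine_config() bumps num_lookahead_slots to at least
# max_forward_calls_per_step - 1 and rejects speculative decoding when the
# flag is above 1.
engine = AsyncLLMEngine.from_engine_args(engine_args)
```

The equivalent command-line form is the `--max-forward-calls-per-step 8` flag registered in the arg_utils.py hunk above.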