From 3d1b227482b361d862bb6221cf87c653babd55eb Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Tue, 12 Nov 2024 20:53:13 -0800 Subject: [PATCH] [V1] Support VLMs with fine-grained scheduling (#9871) Signed-off-by: Woosuk Kwon Co-authored-by: Roger Wang --- vllm/model_executor/models/gpt2.py | 11 +- vllm/model_executor/models/llama.py | 7 +- vllm/model_executor/models/llava.py | 46 +++--- vllm/model_executor/models/opt.py | 7 +- vllm/model_executor/models/phi3v.py | 63 +++++--- vllm/model_executor/models/qwen2.py | 7 +- vllm/v1/core/encoder_cache_manager.py | 48 ++++++ vllm/v1/core/scheduler.py | 205 +++++++++++++++++++++++--- vllm/v1/engine/core.py | 10 ++ vllm/v1/engine/mm_input_mapper.py | 39 +++++ vllm/v1/request.py | 41 +++++- vllm/v1/worker/gpu_model_runner.py | 154 ++++++++++++++++--- 12 files changed, 542 insertions(+), 96 deletions(-) create mode 100644 vllm/v1/core/encoder_cache_manager.py create mode 100644 vllm/v1/engine/mm_input_mapper.py diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index fcff7ec2e01eb..adf2a7a51f737 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -216,9 +216,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor], ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - inputs_embeds = self.wte(input_ids) + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds else: @@ -263,6 +265,9 @@ def __init__( self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.transformer.wte(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -270,9 +275,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2472128976d88..8aed0fead18f9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -538,6 +538,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): normalize=False, softmax=False) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -545,9 +548,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return model_output def compute_logits( diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index ca963fa1c52ea..af712bf8f9506 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -17,6 +17,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import NestedTensors from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of @@ -448,6 +449,25 @@ def _process_image_input(self, image_features = self._process_image_pixels(image_input) return self.multi_modal_projector(image_features) + def process_mm_inputs(self, **kwargs): + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + vision_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if vision_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) + return inputs_embeds + def forward( self, input_ids: torch.Tensor, @@ -455,6 +475,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object, ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LLaVA-1.5. @@ -494,24 +515,13 @@ def forward( """ if intermediate_tensors is not None: inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) - else: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None + elif inputs_embeds is None: + vision_embeddings = self.process_mm_inputs(**kwargs) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 58b6107eba347..997fe642439e6 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -360,6 +360,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -367,9 +370,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4b5dc944bce4b..de03d28638cda 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -39,6 +39,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import NestedTensors, PlaceholderRange from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.utils import is_list_of @@ -500,15 +501,20 @@ def input_processor_for_phi3v(ctx: InputContext, # TODO: Move this to utils or integrate with clip. new_token_ids: List[int] = [] + placeholder_ranges: List[PlaceholderRange] = [] placeholder_idx = 0 while merged_token_ids: token_id = merged_token_ids.pop(0) if token_id == _IMAGE_TOKEN_ID: - new_token_ids.extend( - repeat_and_pad_token( - _IMAGE_TOKEN_ID, - repeat_count=image_feature_size[placeholder_idx], - )) + replacement_ids = repeat_and_pad_token( + _IMAGE_TOKEN_ID, + repeat_count=image_feature_size[placeholder_idx], + ) + placeholder_ranges.append({ + "offset": len(new_token_ids), + "length": len(replacement_ids) + }) + new_token_ids.extend(replacement_ids) placeholder_idx += 1 else: new_token_ids.append(token_id) @@ -516,7 +522,8 @@ def input_processor_for_phi3v(ctx: InputContext, # NOTE: Create a defensive copy of the original inputs return token_inputs(prompt_token_ids=new_token_ids, prompt=new_prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + multi_modal_placeholders={"image": placeholder_ranges}) @MULTIMODAL_REGISTRY.register_image_input_mapper() @@ -669,32 +676,42 @@ def _process_image_input( return image_embeds + def process_mm_inputs(self, **kwargs): + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + vision_embeddings = self._process_image_input(image_input) + return vision_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + vision_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + if vision_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) + return inputs_embeds + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, **kwargs: object): if intermediate_tensors is not None: inputs_embeds = None - else: - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.image_token_id) - else: - inputs_embeds = self.language_model.model.embed_tokens( - input_ids) - - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent - # for `torch.compile` integration - input_ids = None + elif inputs_embeds is None: + vision_embeddings = self.process_mm_inputs(**kwargs) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + inputs_embeds = self.get_input_embeddings(input_ids, + vision_embeddings) + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 2195ce49aa9a7..b623c576bb673 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -441,6 +441,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.model.get_input_embeddings(input_ids) + def forward( self, input_ids: torch.Tensor, @@ -448,9 +451,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + inputs_embeds) return hidden_states def compute_logits( diff --git a/vllm/v1/core/encoder_cache_manager.py b/vllm/v1/core/encoder_cache_manager.py new file mode 100644 index 0000000000000..845bd5ea05e3c --- /dev/null +++ b/vllm/v1/core/encoder_cache_manager.py @@ -0,0 +1,48 @@ +from typing import Dict, List, Set, Tuple + +from vllm.v1.request import Request + + +class EncoderCacheManager: + + def __init__(self, cache_size: int): + self.cache_size = cache_size + self.num_free_slots = cache_size + # req_id -> cached input ids + self.cached: Dict[str, Set[int]] = {} + # List of [req_id, input_id] + self.freed: List[Tuple[str, int]] = [] + + def has_cache(self, request: Request, input_id: int) -> bool: + req_id = request.request_id + return req_id in self.cached and input_id in self.cached[req_id] + + def can_allocate(self, request: Request, input_id: int) -> bool: + num_tokens = request.get_num_encoder_tokens(input_id) + return num_tokens <= self.num_free_slots + + def allocate(self, request: Request, input_id: int) -> None: + req_id = request.request_id + if req_id not in self.cached: + self.cached[req_id] = set() + self.cached[req_id].add(input_id) + self.num_free_slots -= request.get_num_encoder_tokens(input_id) + + def get_cached_input_ids(self, request: Request) -> Set[int]: + return self.cached.get(request.request_id, set()) + + def free(self, request: Request, input_id: int) -> None: + req_id = request.request_id + if req_id not in self.cached: + return + + self.cached[req_id].discard(input_id) + if len(self.cached[req_id]) == 0: + del self.cached[req_id] + self.num_free_slots += request.get_num_encoder_tokens(input_id) + self.freed.append((req_id, input_id)) + + def get_freed_ids(self) -> List[Tuple[str, int]]: + freed = self.freed + self.freed = [] + return freed diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index ee860e792281d..ba50a9786d805 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -1,16 +1,21 @@ from collections import deque from dataclasses import dataclass -from typing import Deque, Dict, Iterable, List, Optional, Set, Union +from typing import (TYPE_CHECKING, Deque, Dict, Iterable, List, Optional, Set, + Tuple, Union) from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.logger import init_logger -from vllm.multimodal import MultiModalDataDict from vllm.sampling_params import SamplingParams +from vllm.v1.core.encoder_cache_manager import EncoderCacheManager from vllm.v1.core.kv_cache_manager import KVCacheManager from vllm.v1.engine import EngineCoreOutput from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.request import Request, RequestStatus +if TYPE_CHECKING: + from vllm.multimodal import MultiModalKwargs + from vllm.multimodal.base import PlaceholderRange + logger = init_logger(__name__) @@ -61,12 +66,20 @@ def __init__( # Request id -> RunningRequestData self.running_reqs_data: Dict[str, RunningRequestData] = {} - def schedule(self) -> "SchedulerOutput": - scheduled_new_reqs: List[Request] = [] - scheduled_resumed_reqs: List[Request] = [] - scheduled_running_reqs: List[Request] = [] - preempted_reqs: List[Request] = [] + # Encoder-related. + # NOTE(woosuk): Here, "encoder" includes the vision encoder (and + # projector if needed). Currently, we assume that the encoder also + # has the Transformer architecture (e.g., ViT). + # FIXME(woosuk): Below are placeholder values. We need to calculate the + # actual values from the configurations. + self.max_num_encoder_input_tokens = 2048 + # NOTE(woosuk): For the models without encoder (e.g., text-only models), + # the encoder cache will not be initialized and used, regardless of + # the cache size. This is because the memory space for the encoder cache + # is preallocated in the profiling run. + self.encoder_cache_manager = EncoderCacheManager(cache_size=2048) + def schedule(self) -> "SchedulerOutput": # NOTE(woosuk) on the scheduling algorithm: # There's no "decoding phase" nor "prefill phase" in the scheduler. # Each request just has the num_computed_tokens and num_tokens, @@ -74,23 +87,45 @@ def schedule(self) -> "SchedulerOutput": # At each step, the scheduler tries to assign tokens to the requests # so that each request's num_computed_tokens can catch up its # num_tokens. This is general enough to cover chunked prefills, - # prefix caching, and the "jump forward" optimization in the future. + # prefix caching, and the "jump decoding" optimization in the future. + + scheduled_new_reqs: List[Request] = [] + scheduled_resumed_reqs: List[Request] = [] + scheduled_running_reqs: List[Request] = [] + preempted_reqs: List[Request] = [] req_to_new_block_ids: Dict[str, List[int]] = {} num_scheduled_tokens: Dict[str, int] = {} token_budget = self.max_num_scheduled_tokens + # Encoder-related. + scheduled_encoder_inputs: Dict[str, List[int]] = {} + encoder_budget = self.max_num_encoder_input_tokens # First, schedule the RUNNING requests. + # NOTE(woosuk): At most 1 request in the RUNNING queue is allowed to be + # in the "partial" state, where the request has some tokens computed + # but not all. The constraint is due to the persistent batch in the + # V1 model runner. + # TODO(woosuk): Remove this constraint after refactoring model runner. + has_partial_request = False req_index = 0 while req_index < len(self.running): - if token_budget == 0: - break - + # Only the last request in the RUNNING queue can be "partial". + assert not has_partial_request + assert token_budget > 0 request = self.running[req_index] num_new_tokens = request.num_tokens - request.num_computed_tokens num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 + # Schedule encoder inputs. + encoder_inputs_to_schedule, num_new_tokens, new_encoder_budget = ( + self._try_schedule_encoder_inputs(request, + request.num_computed_tokens, + num_new_tokens, + encoder_budget)) + assert num_new_tokens > 0 + while True: new_blocks = self.kv_cache_manager.append_slots( request, num_new_tokens) @@ -106,22 +141,40 @@ def schedule(self) -> "SchedulerOutput": preempted_reqs.append(preempted_req) if preempted_req == request: # No more request to preempt. + can_schedule = False break else: # The request can be scheduled. - scheduled_running_reqs.append(request) - - req_to_new_block_ids[request.request_id] = [ - b.block_id for b in new_blocks - ] - num_scheduled_tokens[request.request_id] = num_new_tokens - token_budget -= num_new_tokens - req_index += 1 + can_schedule = True break + if not can_schedule: + break + + # Schedule the request. + scheduled_running_reqs.append(request) + req_to_new_block_ids[request.request_id] = [ + b.block_id for b in new_blocks + ] + num_scheduled_tokens[request.request_id] = num_new_tokens + token_budget -= num_new_tokens + req_index += 1 + has_partial_request = (request.num_computed_tokens + num_new_tokens + < request.num_tokens) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget # Next, schedule the WAITING requests. if not preempted_reqs: while self.waiting: + if has_partial_request: + break if len(self.running) == self.max_num_running_reqs: break if token_budget == 0: @@ -149,12 +202,21 @@ def schedule(self) -> "SchedulerOutput": computed_blocks.pop() num_new_tokens = min(num_new_tokens, token_budget) assert num_new_tokens > 0 + + # Schedule encoder inputs. + (encoder_inputs_to_schedule, num_new_tokens, + new_encoder_budget) = self._try_schedule_encoder_inputs( + request, num_computed_tokens, num_new_tokens, + encoder_budget) + if num_new_tokens == 0: + # The request cannot be scheduled. + break + new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, computed_blocks) if new_blocks is None: # The request cannot be scheduled. break - request.num_computed_tokens = num_computed_tokens self.waiting.popleft() self.running.append(request) @@ -172,6 +234,18 @@ def schedule(self) -> "SchedulerOutput": num_scheduled_tokens[request.request_id] = num_new_tokens token_budget -= num_new_tokens request.status = RequestStatus.RUNNING + request.num_computed_tokens = num_computed_tokens + has_partial_request = (num_computed_tokens + num_new_tokens < + request.num_tokens) + + # Encoder-related. + if encoder_inputs_to_schedule: + scheduled_encoder_inputs[request.request_id] = ( + encoder_inputs_to_schedule) + # Allocate the encoder cache. + for i in encoder_inputs_to_schedule: + self.encoder_cache_manager.allocate(request, i) + encoder_budget = new_encoder_budget # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) @@ -205,12 +279,14 @@ def schedule(self) -> "SchedulerOutput": scheduled_running_reqs=running_reqs_data, num_scheduled_tokens=num_scheduled_tokens, total_num_scheduled_tokens=total_num_scheduled_tokens, + scheduled_encoder_inputs=scheduled_encoder_inputs, preempted_req_ids=preempted_req_ids, # finished_req_ids is an existing state in the scheduler, # instead of being newly scheduled in this step. # It contains the request IDs that are finished in between # the previous and the current steps. finished_req_ids=self.finished_req_ids, + free_encoder_input_ids=self.encoder_cache_manager.get_freed_ids(), ) self.finished_req_ids = set() @@ -234,6 +310,72 @@ def _make_running_request_data( self.running_reqs_data[request.request_id] = req_data return req_data + def _try_schedule_encoder_inputs( + self, + request: Request, + num_computed_tokens: int, + num_new_tokens: int, + encoder_budget: int, + ) -> Tuple[List[int], int, int]: + """ + Determine which encoder inputs need to be scheduled in the current step, + and update `num_new_tokens` and encoder token budget accordingly. + + An encoder input will be scheduled if: + - Its output tokens overlap with the range of tokens being computed + in this step, i.e., + [num_computed_tokens, num_computed_tokens + num_new_tokens). + - It is not already computed and stored in the encoder cache. + - There is sufficient encoder token budget to process it. + - The encoder cache has space to store it. + + If an encoder input cannot be scheduled due to cache or budget + limitations, the method adjusts `num_new_tokens` to schedule only the + decoder tokens up to just before the unschedulable encoder input. + """ + if not request.has_encoder_inputs(): + return [], num_new_tokens, encoder_budget + + encoder_inputs_to_schedule: List[int] = [] + mm_positions = request.mm_positions + assert mm_positions is not None + assert len(mm_positions) > 0 + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info["offset"] + num_encoder_tokens = pos_info["length"] + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, num_computed_tokens + num_new_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_new_tokens: + # The encoder input is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder input is already computed and stored + # in the decoder's KV cache. + continue + + if self.encoder_cache_manager.has_cache(request, i): + # The encoder input is already computed and cached. + continue + if not self.encoder_cache_manager.can_allocate(request, i): + # The encoder cache is full. We can only schedule the decoder + # tokens just before the encoder input. + num_new_tokens = start_pos - num_computed_tokens + break + if num_encoder_tokens > encoder_budget: + # The encoder budget is exhausted. We can only schedule the + # decoder tokens up until the encoder input. + # NOTE(woosuk): We assume that the encoder tokens should be + # processed altogether, as the encoder usually uses + # bidirectional attention. + num_new_tokens = start_pos - num_computed_tokens + break + + encoder_budget -= num_encoder_tokens + encoder_inputs_to_schedule.append(i) + return encoder_inputs_to_schedule, num_new_tokens, encoder_budget + def update_from_output( self, scheduler_output: "SchedulerOutput", @@ -251,6 +393,17 @@ def update_from_output( # the request generates output tokens. Otherwise, we ignore the # sampler output for the request. assert request.num_computed_tokens <= request.num_tokens + + cached_encoder_input_ids = ( + self.encoder_cache_manager.get_cached_input_ids(request)) + for input_id in list(cached_encoder_input_ids): + start_pos = request.mm_positions[input_id]["offset"] + num_tokens = request.mm_positions[input_id]["length"] + if start_pos + num_tokens <= request.num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + self.encoder_cache_manager.free(request, input_id) + if request.num_computed_tokens == request.num_tokens: req_index = model_runner_output.req_id_to_index[req_id] # NOTE(woosuk): Currently, we assume that each request @@ -355,7 +508,8 @@ class NewRequestData: req_id: str prompt_token_ids: List[int] prompt: Optional[str] - multi_modal_data: Optional[MultiModalDataDict] + mm_inputs: List["MultiModalKwargs"] + mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams block_ids: List[int] num_computed_tokens: int @@ -369,9 +523,10 @@ def from_request( ) -> "NewRequestData": return cls( req_id=request.request_id, - prompt_token_ids=request.inputs["prompt_token_ids"], - prompt=request.inputs.get("prompt"), - multi_modal_data=request.inputs.get("multi_modal_data"), + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt, + mm_inputs=request.mm_inputs, + mm_positions=request.mm_positions, sampling_params=request.sampling_params, block_ids=block_ids, num_computed_tokens=num_computed_tokens, @@ -429,6 +584,8 @@ class SchedulerOutput: num_scheduled_tokens: Dict[str, int] total_num_scheduled_tokens: int + scheduled_encoder_inputs: Dict[str, List[int]] preempted_req_ids: Set[str] finished_req_ids: Set[str] + free_encoder_input_ids: List[Tuple[str, int]] diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 428483bdb29cb..35ed131d50de9 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -17,6 +17,7 @@ from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, EngineCoreRequest, EngineCoreRequestType) +from vllm.v1.engine.mm_input_mapper import MMInputMapper from vllm.v1.executor.gpu_executor import GPUExecutor from vllm.v1.request import Request, RequestStatus from vllm.v1.serial_utils import PickleEncoder @@ -65,6 +66,9 @@ def __init__( vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks + # Set up multimodal input mapper (e.g., convert PIL images to tensors). + self.mm_input_mapper = MMInputMapper(vllm_config.model_config) + # Setup scheduler. self.scheduler = Scheduler(vllm_config.scheduler_config, vllm_config.cache_config, @@ -93,6 +97,12 @@ def add_request(self, request: EngineCoreRequest): """Add request to the scheduler.""" req = Request.from_engine_core_request(request) + # FIXME(woosuk): The input mapping (e.g., PIL images to tensors) may + # take 10-50 ms, which can cause a spike in the latency. We should + # consider moving this to a separate thread. + if req.mm_data: + req.mm_inputs = self.mm_input_mapper.process_inputs( + req.mm_data, req.mm_processor_kwargs) self.scheduler.add_request(req) def abort_requests(self, request_ids: List[str]): diff --git a/vllm/v1/engine/mm_input_mapper.py b/vllm/v1/engine/mm_input_mapper.py new file mode 100644 index 0000000000000..594c973678235 --- /dev/null +++ b/vllm/v1/engine/mm_input_mapper.py @@ -0,0 +1,39 @@ +from typing import Any, Dict, List, Optional + +from vllm.config import ModelConfig +from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, + MultiModalKwargs, MultiModalRegistry) + + +class MMInputMapper: + + def __init__( + self, + model_config: ModelConfig, + mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, + ): + self.mm_registry = mm_registry + self.multi_modal_input_mapper = mm_registry.create_input_mapper( + model_config) + self.mm_registry.init_mm_limits_per_prompt(model_config) + + def process_inputs( + self, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Optional[Dict[str, Any]], + ) -> List[MultiModalKwargs]: + image_inputs = mm_data["image"] + if not isinstance(image_inputs, list): + image_inputs = [image_inputs] + + # Process each image input separately so that later we can schedule + # them in a fine-grained manner. + mm_inputs: List[MultiModalKwargs] = [] + num_images = len(image_inputs) + for i in range(num_images): + mm_input = self.multi_modal_input_mapper( + {"image": [image_inputs[i]]}, + mm_processor_kwargs=mm_processor_kwargs, + ) + mm_inputs.append(mm_input) + return mm_inputs diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 00e5aea92a8df..f35cf738c89bf 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -3,6 +3,7 @@ from vllm.inputs.data import DecoderOnlyInputs from vllm.lora.request import LoRARequest +from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams from vllm.sequence import RequestMetrics from vllm.v1.engine import EngineCoreRequest @@ -47,14 +48,30 @@ def __init__( self._all_token_ids: List[int] = self.prompt_token_ids.copy() self.num_computed_tokens = 0 + # Raw multimodal data before the mm input mapper (e.g., PIL images). + self.mm_data = inputs.get("multi_modal_data") + self.mm_processor_kwargs = inputs.get("mm_processor_kwargs") + mm_positions = inputs.get("multi_modal_placeholders") + if mm_positions: + # FIXME(woosuk): Support other modalities. + self.mm_positions = mm_positions.get("image", []) + else: + self.mm_positions = [] + # Output of the mm input mapper (e.g., image tensors). + self.mm_inputs: List[MultiModalKwargs] = [] + @classmethod def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": - return cls( request_id=request.request_id, - inputs=DecoderOnlyInputs(type="token", - prompt_token_ids=request.prompt_token_ids, - prompt=request.prompt), + inputs=DecoderOnlyInputs( + type="token", + prompt_token_ids=request.prompt_token_ids, + prompt=request.prompt, + multi_modal_data=request.mm_data, + multi_modal_placeholders=request.mm_placeholders, + mm_processor_kwargs=request.mm_processor_kwargs, + ), sampling_params=request.sampling_params, eos_token_id=request.eos_token_id, arrival_time=request.arrival_time, @@ -96,9 +113,21 @@ def is_finished(self) -> bool: def get_finished_reason(self) -> Union[str, None]: return RequestStatus.get_finished_reason(self.status) + def has_encoder_inputs(self) -> bool: + return self.mm_data is not None + + @property + def num_encoder_inputs(self) -> int: + return len(self.mm_positions) + + def get_num_encoder_tokens(self, input_id: int) -> int: + assert input_id < len(self.mm_positions) + num_tokens = self.mm_positions[input_id]["length"] + return num_tokens + class RequestStatus(enum.IntEnum): - """Status of a sequence.""" + """Status of a request.""" WAITING = 0 RUNNING = 1 PREEMPTED = 2 @@ -119,7 +148,7 @@ def get_finished_reason(status: "RequestStatus") -> Union[str, None]: # Mapping of finished statuses to their finish reasons. -# NOTE: The ignored sequences are the sequences whose prompt lengths +# NOTE: The ignored requests are the requests whose prompt lengths # are longer than the model's length cap. Therefore, the stop # reason should also be "length" as in OpenAI API. _FINISHED_REASON_MAP = { diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index db676e2819bf4..81480786a09e1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,7 +1,7 @@ import os import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import numpy as np import torch @@ -14,9 +14,10 @@ from vllm.compilation.levels import CompilationLevel from vllm.config import VllmConfig from vllm.forward_context import set_forward_context +from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model -from vllm.multimodal import MultiModalDataDict +from vllm.multimodal import MultiModalKwargs from vllm.plugins import set_compilation_config from vllm.sampling_params import SamplingParams, SamplingType from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv, @@ -27,6 +28,7 @@ from vllm.v1.sample.metadata import SamplingMetadata if TYPE_CHECKING: + from vllm.multimodal.base import PlaceholderRange from vllm.v1.core.scheduler import SchedulerOutput logger = init_logger(__name__) @@ -37,8 +39,8 @@ class GPUModelRunner: def __init__( self, vllm_config: VllmConfig, + input_registry: InputRegistry = INPUT_REGISTRY, ): - # TODO: use ModelRunnerBase.__init__(self, vllm_config=vllm_config) self.vllm_config = vllm_config self.model_config = vllm_config.model_config self.cache_config = vllm_config.cache_config @@ -75,10 +77,16 @@ def __init__( parallel_config) self.num_kv_heads = model_config.get_num_kv_heads(parallel_config) self.head_size = model_config.get_head_size() + self.hidden_size = model_config.get_hidden_size() + + # Multi-modal data support + self.input_registry = input_registry # Lazy initialization # self.model: nn.Module # Set after load_model self.kv_caches: List[torch.Tensor] = [] + # req_id -> (input_id -> encoder_output) + self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {} # Request states. self.requests: Dict[str, CachedRequestState] = {} @@ -96,18 +104,28 @@ def __init__( and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)] - self.input_ids = torch.zeros(self.max_num_tokens, - dtype=torch.int32, - device=self.device) self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) + self.inputs_embeds = torch.zeros( + (self.max_num_tokens, self.hidden_size), + dtype=self.dtype, + device=self.device) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: # Remove stopped requests from the cached states. # Keep the states of the pre-empted requests. for req_id in scheduler_output.finished_req_ids: self.requests.pop(req_id, None) + self.encoder_cache.pop(req_id, None) + + # Free the cached encoder outputs. + for req_id, input_id in scheduler_output.free_encoder_input_ids: + encoder_outputs = self.encoder_cache.get(req_id) + if encoder_outputs is not None: + encoder_outputs.pop(input_id, None) + if not encoder_outputs: + self.encoder_cache.pop(req_id, None) # Remove the requests from the persistent batch. stopped_req_ids = set().union( @@ -156,7 +174,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: req_id=req_id, prompt_token_ids=req_data.prompt_token_ids, prompt=req_data.prompt, - multi_modal_data=req_data.multi_modal_data, + mm_inputs=req_data.mm_inputs, + mm_positions=req_data.mm_positions, sampling_params=sampling_params, generator=generator, block_ids=req_data.block_ids, @@ -285,11 +304,9 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): seq_start_loc_np[0] = 0 np.cumsum(seq_lens, out=seq_start_loc_np[1:]) - self.input_ids[:total_num_scheduled_tokens].copy_(input_ids, - non_blocking=True) + input_ids = input_ids.to(self.device, non_blocking=True) self.positions[:total_num_scheduled_tokens].copy_(positions, non_blocking=True) - query_start_loc = query_start_loc.to(self.device, non_blocking=True) seq_start_loc = seq_start_loc.to(self.device, non_blocking=True) slot_mapping = slot_mapping.to(self.device, non_blocking=True).long() @@ -308,7 +325,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # token from the partial request. # TODO: Support prompt logprobs. logits_indices = query_start_loc[1:] - 1 - return attn_metadata, logits_indices + return input_ids, attn_metadata, logits_indices def _prepare_sampling( self, @@ -325,13 +342,91 @@ def _prepare_sampling( sampling_metadata = self.input_batch.make_sampling_metadata(skip_copy) return sampling_metadata + def _execute_encoder(self, scheduler_output: "SchedulerOutput"): + scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs + if not scheduled_encoder_inputs: + return + + # Batch the multi-modal inputs. + mm_inputs: List[MultiModalKwargs] = [] + req_input_ids: List[Tuple[int, int]] = [] + for req_id, encoder_input_ids in scheduled_encoder_inputs.items(): + req_state = self.requests[req_id] + for input_id in encoder_input_ids: + mm_inputs.append(req_state.mm_inputs[input_id]) + req_input_ids.append((req_id, input_id)) + batched_mm_inputs = MultiModalKwargs.batch(mm_inputs) + batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs, + device=self.device) + + # Run the encoder. + # `encoder_outputs` is either of the following: + # 1. A tensor of shape [num_images, feature_size, hidden_size] + # in case when feature_size is fixed across all images. + # 2. A list (length: num_images) of tensors, each of shape + # [feature_size, hidden_size] in case when the feature size is + # dynamic depending on input images. + encoder_outputs = self.model.process_mm_inputs(**batched_mm_inputs) + + # Cache the encoder outputs. + for (req_id, input_id), output in zip(req_input_ids, encoder_outputs): + if req_id not in self.encoder_cache: + self.encoder_cache[req_id] = {} + self.encoder_cache[req_id][input_id] = output + + def _gather_encoder_outputs( + self, + scheduler_output: "SchedulerOutput", + ) -> List[torch.Tensor]: + encoder_outputs: List[torch.Tensor] = [] + num_reqs = self.input_batch.num_reqs + for req_id in self.input_batch.req_ids[:num_reqs]: + num_scheduled_tokens = scheduler_output.num_scheduled_tokens[ + req_id] + req_state = self.requests[req_id] + num_computed_tokens = req_state.num_computed_tokens + mm_positions = req_state.mm_positions + for i, pos_info in enumerate(mm_positions): + start_pos = pos_info["offset"] + num_encoder_tokens = pos_info["length"] + + # The encoder output is needed if the two ranges overlap: + # [num_computed_tokens, + # num_computed_tokens + num_scheduled_tokens) and + # [start_pos, start_pos + num_encoder_tokens) + if start_pos >= num_computed_tokens + num_scheduled_tokens: + # The encoder output is not needed in this step. + break + if start_pos + num_encoder_tokens <= num_computed_tokens: + # The encoder output is already processed and stored + # in the decoder's KV cache. + continue + + start_idx = max(num_computed_tokens - start_pos, 0) + end_idx = min( + num_computed_tokens - start_pos + num_scheduled_tokens, + num_encoder_tokens) + assert start_idx < end_idx + assert req_id in self.encoder_cache + assert i in self.encoder_cache[req_id] + encoder_output = self.encoder_cache[req_id][i] + encoder_outputs.append(encoder_output[start_idx:end_idx]) + return encoder_outputs + @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", ) -> ModelRunnerOutput: self._update_states(scheduler_output) - attn_metadata, logits_indices = self._prepare_inputs(scheduler_output) + + # Run the encoder. + self._execute_encoder(scheduler_output) + encoder_outputs = self._gather_encoder_outputs(scheduler_output) + + # Prepare the decoder inputs. + input_ids, attn_metadata, logits_indices = self._prepare_inputs( + scheduler_output) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if (self.use_cuda_graph and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]): @@ -343,12 +438,26 @@ def execute_model( # Eager mode. num_input_tokens = num_scheduled_tokens + # Get the inputs embeds. + if encoder_outputs: + inputs_embeds = self.model.get_input_embeddings( + input_ids, encoder_outputs) + else: + inputs_embeds = self.model.get_input_embeddings(input_ids) + # NOTE(woosuk): To unify token ids and soft tokens (vision embeddings), + # always use embeddings (rather than token ids) as input to the model. + # TODO(woosuk): Avoid the copy. Optimize. + self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds) + + # Run the decoder. + # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata): hidden_states = self.model( - input_ids=self.input_ids[:num_input_tokens], + input_ids=None, positions=self.positions[:num_input_tokens], kv_caches=self.kv_caches, attn_metadata=None, + inputs_embeds=self.inputs_embeds[:num_input_tokens], ) hidden_states = hidden_states[:num_scheduled_tokens] hidden_states = hidden_states[logits_indices] @@ -440,13 +549,16 @@ def _dummy_run(self, model: nn.Module, num_tokens: int) -> None: with set_forward_context(None): # noqa: SIM117 with set_compile_context(self.cudagraph_batch_sizes): # Trigger compilation for general shape. - model(self.input_ids, - self.positions, - dummy_kv_caches, - attn_metadata=None) + model(input_ids=None, + positions=self.positions, + kv_caches=dummy_kv_caches, + attn_metadata=None, + inputs_embeds=self.inputs_embeds) @torch.inference_mode() def profile_run(self) -> None: + # TODO(woosuk): Profile the max memory usage of the encoder and + # the encoder cache. self._dummy_run(self.model, self.max_num_tokens) torch.cuda.synchronize() @@ -468,10 +580,11 @@ def capture_model(self) -> None: # can reuse the memory pool allocated for the large shapes. for num_tokens in reversed(self.cudagraph_batch_sizes): self.model( - self.input_ids[:num_tokens], - self.positions[:num_tokens], + input_ids=None, + positions=self.positions[:num_tokens], kv_caches=self.kv_caches, attn_metadata=None, + inputs_embeds=self.inputs_embeds[:num_tokens], ) end_time = time.perf_counter() @@ -506,7 +619,8 @@ class CachedRequestState: req_id: str prompt_token_ids: List[int] prompt: Optional[str] - multi_modal_data: Optional["MultiModalDataDict"] + mm_inputs: List[MultiModalKwargs] + mm_positions: List["PlaceholderRange"] sampling_params: SamplingParams generator: Optional[torch.Generator]