forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Hardware][CPU] Add embedding models support for CPU backend (vllm-pr…
…oject#10193) Signed-off-by: Isotr0py <[email protected]>
- Loading branch information
Showing
9 changed files
with
185 additions
and
52 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import dataclasses | ||
from typing import Any, Dict, List, Optional, Tuple, Type, Union | ||
|
||
import torch | ||
|
||
from vllm.model_executor.pooling_metadata import PoolingMetadata | ||
from vllm.multimodal import MultiModalKwargs | ||
from vllm.pooling_params import PoolingParams | ||
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData, | ||
SequenceGroupMetadata) | ||
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU, | ||
ModelInputForCPUBuilder) | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU): | ||
""" | ||
Used by the CPUEmbeddingModelRunner. | ||
""" | ||
pooling_metadata: Optional["PoolingMetadata"] = None | ||
|
||
|
||
class CPUEmbeddingModelRunner( | ||
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]): | ||
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = ( | ||
ModelInputForCPUWithPoolingMetadata) | ||
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder | ||
|
||
@torch.inference_mode() | ||
def execute_model( | ||
self, | ||
model_input: ModelInputForCPUWithPoolingMetadata, | ||
kv_caches: List[torch.Tensor], | ||
intermediate_tensors: Optional[IntermediateTensors] = None, | ||
num_steps: int = 1, | ||
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]: | ||
if num_steps > 1: | ||
raise ValueError( | ||
"CPU worker does not support multi-step execution.") | ||
|
||
num_layers = self.model_config.get_num_layers(self.parallel_config) | ||
# use an empty tensor instead of `None`` to force Dynamo to pass | ||
# it by reference, rather by specializing on the value ``None``. | ||
# the `dtype` argument does not matter, and we use `float32` as | ||
# a placeholder (it has wide hardware support). | ||
kv_caches = [ | ||
torch.tensor([], dtype=torch.float32, device=self.device) | ||
for _ in range(num_layers) | ||
] | ||
|
||
model_executable = self.model | ||
execute_model_kwargs = { | ||
"input_ids": | ||
model_input.input_tokens, | ||
"positions": | ||
model_input.input_positions, | ||
"kv_caches": | ||
kv_caches, | ||
"attn_metadata": | ||
model_input.attn_metadata, | ||
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, | ||
device=self.device), | ||
"intermediate_tensors": | ||
intermediate_tensors, | ||
} | ||
|
||
hidden_states = model_executable(**execute_model_kwargs) | ||
|
||
return [ | ||
self.model.pooler(hidden_states=hidden_states, | ||
pooling_metadata=model_input.pooling_metadata) | ||
] | ||
|
||
def make_model_input_from_broadcasted_tensor_dict( | ||
self, | ||
tensor_dict: Dict[str, | ||
Any]) -> ModelInputForCPUWithPoolingMetadata: | ||
return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict( | ||
tensor_dict, | ||
attn_backend=self.attn_backend, | ||
) | ||
|
||
def prepare_model_input( | ||
self, | ||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], | ||
virtual_engine: int = 0, | ||
finished_requests_ids: Optional[List[str]] = None | ||
) -> ModelInputForCPUWithPoolingMetadata: | ||
assert seq_group_metadata_list is not None | ||
model_input = self._prepare_model_input_tensors( | ||
seq_group_metadata_list, finished_requests_ids) | ||
# Prepare PoolingMetadata. | ||
assert model_input.seq_lens is not None | ||
pooling_metadata = self._prepare_pooling(seq_group_metadata_list, | ||
model_input.seq_lens) | ||
|
||
return dataclasses.replace(model_input, | ||
pooling_metadata=pooling_metadata) | ||
|
||
def _prepare_pooling( | ||
self, | ||
seq_group_metadata_list: List[SequenceGroupMetadata], | ||
prompt_lens: List[int], | ||
) -> PoolingMetadata: | ||
"""Prepare PoolingMetadata for the sequence group metadata list.""" | ||
seq_groups: List[Tuple[List[int], PoolingParams]] = [] | ||
for i, seq_group_metadata in enumerate(seq_group_metadata_list): | ||
seq_ids = list(seq_group_metadata.seq_data.keys()) | ||
pooling_params = seq_group_metadata.pooling_params | ||
seq_groups.append((seq_ids, pooling_params)) | ||
|
||
seq_data: Dict[int, SequenceData] = {} | ||
for seq_group_metadata in seq_group_metadata_list: | ||
seq_data.update(seq_group_metadata.seq_data) | ||
|
||
pooling_metadata = PoolingMetadata( | ||
seq_groups=seq_groups, | ||
seq_data=seq_data, | ||
prompt_lens=prompt_lens, | ||
) | ||
|
||
return pooling_metadata |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.