[Hardware][CPU] Add embedding models support for CPU backend #10193

Merged · 7 commits · Nov 11, 2024
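This PR enables embedding (pooling) models on the CPU backend: it adds a dedicated CPUEmbeddingModelRunner, teaches the Torch SDPA backend to handle ENCODER_ONLY attention, and drops BERT's hard requirement on the XFormers backend. As a minimal sketch of what this unlocks, assuming the offline `LLM.encode()` embedding API (the model name and dtype below are illustrative, not part of this change):

# Minimal sketch (not part of this diff): offline embedding inference on a CPU build of vLLM.
from vllm import LLM

llm = LLM(model="BAAI/bge-base-en-v1.5", dtype="bfloat16")  # model and dtype are illustrative

prompts = [
    "vLLM now runs embedding models on the CPU backend.",
    "Pooling is handled by CPUEmbeddingModelRunner.",
]

# encode() returns one EmbeddingRequestOutput per prompt.
outputs = llm.encode(prompts)
for out in outputs:
    vector = out.outputs.embedding  # one embedding vector (list of floats) per prompt
    print(len(vector))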
3 changes: 1 addition & 2 deletions .buildkite/run-cpu-test-ppc64le.sh
@@ -25,8 +25,7 @@ function cpu_tests() {
decord einops librosa peft Pillow sentence-transformers soundfile \
transformers_stream_generator matplotlib datamodel_code_generator
pip install torchvision --index-url https://download.pytorch.org/whl/cpu
# Embedding models are not supported for CPU yet
# pytest -v -s tests/models/embedding/language
pytest -v -s tests/models/embedding/language
pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language/test_models.py
pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
3 changes: 1 addition & 2 deletions .buildkite/run-cpu-test.sh
@@ -32,8 +32,7 @@ function cpu_tests() {
decord einops librosa peft Pillow sentence-transformers soundfile \
transformers_stream_generator matplotlib datamodel_code_generator
pip install torchvision --index-url https://download.pytorch.org/whl/cpu
# Embedding models are not supported for CPU yet
# pytest -v -s tests/models/embedding/language
pytest -v -s tests/models/embedding/language
pytest -v -s tests/models/encoder_decoder/language
pytest -v -s tests/models/decoder_only/language/test_models.py
pytest -v -s tests/models/decoder_only/audio_language -m cpu_model
7 changes: 4 additions & 3 deletions tests/models/embedding/language/test_embedding.py
@@ -4,6 +4,8 @@
"""
import pytest

from vllm.platforms import current_platform

from ..utils import check_embeddings_close

# Model, Guard
@@ -21,15 +23,14 @@
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
monkeypatch,
hf_runner,
vllm_runner,
example_prompts,
model,
dtype: str,
) -> None:
if model in ENCODER_ONLY:
monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
if model not in ENCODER_ONLY and current_platform.is_cpu():
pytest.skip("Skip large embedding models test on CPU.")

# The example_prompts has ending "\n", for example:
# "Write a short story about a robot that dreams for the first time.\n"
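The updated test runs the HF and vLLM runners side by side (now also on CPU) and compares the resulting embeddings with check_embeddings_close from tests/models/embedding/utils.py. A rough sketch of the kind of cosine-similarity check such a helper performs; this is illustrative, not the actual implementation:

import torch
import torch.nn.functional as F

def cosine_check_sketch(ref_embeddings, test_embeddings, tol=1e-2):
    # Illustrative only: each pair of vectors must be nearly parallel.
    for i, (ref, test) in enumerate(zip(ref_embeddings, test_embeddings)):
        sim = F.cosine_similarity(torch.tensor(ref), torch.tensor(test), dim=0)
        assert sim >= 1 - tol, f"prompt {i}: cosine similarity {sim:.4f} too low"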
14 changes: 10 additions & 4 deletions vllm/attention/backends/torch_sdpa.py
@@ -158,7 +158,8 @@ def get_seq_lens(
* Appropriate sequence lengths tensor for key & value
'''

if attn_type == AttentionType.DECODER:
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
seq_lens_q = self.seq_lens
seq_lens_kv = self.seq_lens
elif attn_type == AttentionType.ENCODER:
@@ -189,7 +190,8 @@ def get_attn_bias(
* Appropriate attention bias value given the attention type
'''

if attn_type == AttentionType.DECODER:
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
return self.attn_bias
elif attn_type == AttentionType.ENCODER:
return self.encoder_attn_bias
@@ -215,7 +217,8 @@
encoder/decoder cross-attention
'''

if attn_type == AttentionType.DECODER:
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
self.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
self.encoder_attn_bias = attn_bias
@@ -252,7 +255,8 @@ def get_seq_len_block_table_args(
* Appropriate block tables (or None)
'''

if attn_type == AttentionType.DECODER:
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
# Decoder self-attention
# Choose max_seq_len based on whether we are in prompt_run
return (self.seq_lens_tensor, self.max_decode_seq_len,
@@ -420,6 +424,8 @@ def forward(
"Torch SDPA backend doesn't support prefix decoding.")

if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
# Decoding run.
(
seq_lens_arg,
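The ENCODER_ONLY branches reuse the decoder-style bookkeeping because, for self-attention over a prompt, the query and key/value sequence lengths are identical; the difference is that no causal mask is applied and there is never a decode phase, which the new assert in forward() makes explicit. A minimal sketch of what prefill attention reduces to for one encoder-only sequence (shapes are illustrative, not the backend's actual code path):

import torch
import torch.nn.functional as F

num_heads, seq_len, head_dim = 4, 7, 32  # illustrative sizes
q = torch.randn(num_heads, seq_len, head_dim)
k = torch.randn(num_heads, seq_len, head_dim)
v = torch.randn(num_heads, seq_len, head_dim)

# Encoder-only self-attention: same lengths for Q and K/V, no causal mask,
# so every token attends to every other token in the prompt.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=False)
print(out.shape)  # torch.Size([4, 7, 32])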
6 changes: 0 additions & 6 deletions vllm/model_executor/models/bert.py
@@ -5,7 +5,6 @@
from transformers import BertConfig

from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.backends.xformers import XFormersImpl
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
@@ -216,11 +215,6 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.attn")

if not isinstance(self.attn.impl, XFormersImpl):
raise ValueError(
"Encoder-only models currently require XFORMERS attention "
"backend. Set VLLM_ATTENTION_BACKEND=XFORMERS to use BERT.")

def forward(
self,
hidden_states: torch.Tensor,
122 changes: 122 additions & 0 deletions vllm/worker/cpu_embedding_model_runner.py
@@ -0,0 +1,122 @@
import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch

from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
SequenceGroupMetadata)
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
ModelInputForCPUBuilder)


@dataclasses.dataclass(frozen=True)
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
"""
Used by the CPUEmbeddingModelRunner.
"""
pooling_metadata: Optional["PoolingMetadata"] = None


class CPUEmbeddingModelRunner(
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
ModelInputForCPUWithPoolingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

@torch.inference_mode()
def execute_model(
self,
model_input: ModelInputForCPUWithPoolingMetadata,
kv_caches: List[torch.Tensor],
intermediate_tensors: Optional[IntermediateTensors] = None,
num_steps: int = 1,
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
if num_steps > 1:
raise ValueError(
"CPU worker does not support multi-step execution.")

num_layers = self.model_config.get_num_layers(self.parallel_config)
# use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than by specializing on the value `None`.
# the `dtype` argument does not matter, and we use `float32` as
# a placeholder (it has wide hardware support).
kv_caches = [
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(num_layers)
]

model_executable = self.model
execute_model_kwargs = {
"input_ids":
model_input.input_tokens,
"positions":
model_input.input_positions,
"kv_caches":
kv_caches,
"attn_metadata":
model_input.attn_metadata,
**MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {},
device=self.device),
"intermediate_tensors":
intermediate_tensors,
}

hidden_states = model_executable(**execute_model_kwargs)

return [
self.model.pooler(hidden_states=hidden_states,
pooling_metadata=model_input.pooling_metadata)
]

def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str,
Any]) -> ModelInputForCPUWithPoolingMetadata:
return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict(
tensor_dict,
attn_backend=self.attn_backend,
)

def prepare_model_input(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithPoolingMetadata:
assert seq_group_metadata_list is not None
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
# Prepare PoolingMetadata.
assert model_input.seq_lens is not None
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
model_input.seq_lens)

return dataclasses.replace(model_input,
pooling_metadata=pooling_metadata)

def _prepare_pooling(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
prompt_lens: List[int],
) -> PoolingMetadata:
"""Prepare PoolingMetadata for the sequence group metadata list."""
seq_groups: List[Tuple[List[int], PoolingParams]] = []
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
seq_ids = list(seq_group_metadata.seq_data.keys())
pooling_params = seq_group_metadata.pooling_params
seq_groups.append((seq_ids, pooling_params))

seq_data: Dict[int, SequenceData] = {}
for seq_group_metadata in seq_group_metadata_list:
seq_data.update(seq_group_metadata.seq_data)

pooling_metadata = PoolingMetadata(
seq_groups=seq_groups,
seq_data=seq_data,
prompt_lens=prompt_lens,
)

return pooling_metadata
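The runner passes the flattened hidden states for all prompts, together with the pooling metadata, to the model's pooler, which reduces each prompt to a single vector. A stripped-down sketch of how a pooler can use prompt_lens for that; mean pooling is only an example, and the actual reduction depends on the model's Pooler configuration:

import torch

# Two prompts of 5 and 3 tokens, hidden size 8 (illustrative numbers).
hidden_states = torch.randn(5 + 3, 8)
prompt_lens = [5, 3]

# Split the flattened token dimension per prompt, then pool each chunk.
pooled = [chunk.mean(dim=0) for chunk in hidden_states.split(prompt_lens, dim=0)]
print([p.shape for p in pooled])  # [torch.Size([8]), torch.Size([8])]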
11 changes: 5 additions & 6 deletions vllm/worker/cpu_enc_dec_model_runner.py
@@ -8,7 +8,7 @@
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.cpu_model_runner import (CPUModelRunner,
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase,
ModelInputForCPUBuilder,
ModelInputForCPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
@@ -50,7 +50,8 @@ def from_broadcasted_tensor_dict(
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))


class CPUEncoderDecoderModelRunner(CPUModelRunner):
class CPUEncoderDecoderModelRunner(
CPUModelRunnerBase[EncoderDecoderModelInputForCPU]):
_model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
EncoderDecoderModelInputForCPU)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
@@ -87,10 +88,8 @@ def prepare_model_input(
virtual_engine: int = 0,
finished_requests_ids: Optional[List[str]] = None
) -> EncoderDecoderModelInputForCPU:
model_input = super().prepare_model_input(seq_group_metadata_list,
virtual_engine,
finished_requests_ids)
model_input = cast(EncoderDecoderModelInputForCPU, model_input)
model_input = self._prepare_model_input_tensors(
seq_group_metadata_list, finished_requests_ids)
(
attn_metadata,
encoder_input_tokens_tensor,
57 changes: 35 additions & 22 deletions vllm/worker/cpu_model_runner.py
@@ -2,7 +2,8 @@
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar, Union)

import torch
from torch import nn
@@ -31,6 +32,7 @@

logger = init_logger(__name__)

TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")
_PAD_SLOT_ID = -1


@@ -60,10 +62,10 @@ def as_broadcastable_tensor_dict(

@classmethod
def from_broadcasted_tensor_dict(
cls: Type["ModelInputForCPU"],
cls: Type[TModelInputForCPU],
tensor_dict: Dict[str, Any],
attn_backend: Optional["AttentionBackend"] = None
) -> "ModelInputForCPU":
) -> TModelInputForCPU:
if attn_backend is not None:
tensor_dict = _init_attn_metadata_from_tensor_dict(
attn_backend, tensor_dict)
@@ -255,11 +257,14 @@ def _prepare_prompt(
slot_mapping.append(_PAD_SLOT_ID)
continue

block_number = block_table[i //
self.block_size] # type: ignore
block_offset = i % self.block_size # type: ignore
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)
# For encoder-only models, the block_table is None,
# and there is no need to initialize the slot_mapping.
if block_table is not None:
block_number = block_table[i //
self.block_size] # type: ignore
block_offset = i % self.block_size # type: ignore
slot = block_number * self.block_size + block_offset
slot_mapping.append(slot)

if any(input_mrope_positions):
input_positions = None # type: ignore
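For reference, a small worked example of the slot computation this hunk now guards (numbers are illustrative): encoder-only models allocate no KV cache, so block_table is None and the new check simply skips the slot-mapping entries.

block_size = 16
block_table = [7, 3, 9]  # physical block ids for one sequence (illustrative)

def slot_for(token_idx: int) -> int:
    # Same arithmetic as above: physical block id times block size plus offset.
    block_number = block_table[token_idx // block_size]
    block_offset = token_idx % block_size
    return block_number * block_size + block_offset

print(slot_for(0))   # 7 * 16 + 0 = 112
print(slot_for(17))  # 3 * 16 + 1 = 49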
@@ -402,10 +407,12 @@ def _prepare_decode(
)


class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
ModelInputForCPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
"""
Helper class for shared methods between CPU model runners.
"""
_model_input_cls: Type[TModelInputForCPU]
_builder_cls: Type[ModelInputForCPUBuilder]

def __init__(
self,
@@ -448,20 +455,11 @@ def __init__(
def load_model(self) -> None:
self.model = get_model(vllm_config=self.vllm_config)

def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForCPUWithSamplingMetadata:
return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
tensor_dict,
attn_backend=self.attn_backend,
)

def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
finished_requests_ids: Optional[List[str]] = None
) -> ModelInputForCPUWithSamplingMetadata:
) -> TModelInputForCPU:
"""Helper method to prepare the model input based on a given sequence
group. Prepares metadata needed for the base model forward pass but not
metadata for possible additional steps, e.g., sampling.
@@ -473,6 +471,21 @@

return builder.build() # type: ignore


class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
ModelInputForCPUWithSamplingMetadata)
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder

def make_model_input_from_broadcasted_tensor_dict(
self,
tensor_dict: Dict[str, Any],
) -> ModelInputForCPUWithSamplingMetadata:
return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
tensor_dict,
attn_backend=self.attn_backend,
)

def prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],
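The runner refactor follows a typed-base-class pattern: CPUModelRunnerBase is generic over the model-input dataclass, holds the shared tensor-preparation logic, and each concrete runner (CPUModelRunner, CPUEncoderDecoderModelRunner, CPUEmbeddingModelRunner) binds its own input type via _model_input_cls. A stripped-down sketch of the pattern with simplified names (these are not the actual vLLM classes):

from dataclasses import dataclass
from typing import Generic, List, Optional, Type, TypeVar

@dataclass(frozen=True)
class BaseInput:
    tokens: Optional[List[int]] = None

@dataclass(frozen=True)
class SamplingInput(BaseInput):
    sampling_metadata: Optional[dict] = None

@dataclass(frozen=True)
class PoolingInput(BaseInput):
    pooling_metadata: Optional[dict] = None

TInput = TypeVar("TInput", bound=BaseInput)

class RunnerBase(Generic[TInput]):
    _model_input_cls: Type[TInput]  # each subclass binds its own input dataclass

    def _prepare_model_input_tensors(self, tokens: List[int]) -> TInput:
        # Shared preparation logic returns the subclass-specific input type.
        return self._model_input_cls(tokens=tokens)

class SamplingRunner(RunnerBase[SamplingInput]):
    _model_input_cls = SamplingInput

class PoolingRunner(RunnerBase[PoolingInput]):
    _model_input_cls = PoolingInput

print(type(PoolingRunner()._prepare_model_input_tensors([1, 2, 3])).__name__)  # PoolingInput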