From 20587d8ede0ab0339f81b9fdf48eb4c024382fa0 Mon Sep 17 00:00:00 2001
From: Xin Yang <105740670+xyang16@users.noreply.github.com>
Date: Fri, 4 Oct 2024 23:57:05 -0700
Subject: [PATCH] [Model] Support Gemma2 embedding model (#9004)

Signed-off-by: Sumit Dubey
---
 tests/conftest.py                              |  1 +
 .../embedding/language/test_embedding.py       | 11 ++-
 vllm/model_executor/models/gemma2.py           |  7 +-
 .../model_executor/models/gemma2_embedding.py  | 82 +++++++++++++++++++
 vllm/model_executor/models/registry.py         |  1 +
 5 files changed, 99 insertions(+), 3 deletions(-)
 create mode 100644 vllm/model_executor/models/gemma2_embedding.py

diff --git a/tests/conftest.py b/tests/conftest.py
index b1833fdae5347..177b8a0640278 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -277,6 +277,7 @@ def __init__(
                 SentenceTransformer(
                     model_name,
                     device="cpu",
+                    trust_remote_code=True,
                 ).to(dtype=torch_dtype))
         else:
             model_kwargs = model_kwargs if model_kwargs is not None else {}
diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py
index 6556998b68a74..be316c6e12da1 100644
--- a/tests/models/embedding/language/test_embedding.py
+++ b/tests/models/embedding/language/test_embedding.py
@@ -1,6 +1,6 @@
 """Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
 
-Run `pytest tests/models/test_llama_embedding.py`.
+Run `pytest tests/models/embedding/language/test_embedding.py`.
 """
 import pytest
 import torch
@@ -8,6 +8,7 @@
 
 MODELS = [
     "intfloat/e5-mistral-7b-instruct",
+    "BAAI/bge-multilingual-gemma2",
 ]
 
 
@@ -28,6 +29,14 @@ def test_models(
    model: str,
    dtype: str,
 ) -> None:
+    # Every prompt in example_prompts ends with "\n", for example:
+    # "Write a short story about a robot that dreams for the first time.\n"
+    # sentence_transformers strips the input texts, see:
+    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
+    # This makes the input_ids differ between hf_model and vllm_model,
+    # so we strip the input texts here to avoid a spurious test failure.
+ example_prompts = [str(s).strip() for s in example_prompts] + with hf_runner(model, dtype=dtype, is_embedding_model=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 9fddaac3a0837..ddeaa0fbfc276 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -278,11 +278,14 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: - hidden_states = self.embed_tokens(input_ids) + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) hidden_states *= self.normalizer - residual = None else: assert intermediate_tensors is not None diff --git a/vllm/model_executor/models/gemma2_embedding.py b/vllm/model_executor/models/gemma2_embedding.py new file mode 100644 index 0000000000000..1bcdaea93410f --- /dev/null +++ b/vllm/model_executor/models/gemma2_embedding.py @@ -0,0 +1,82 @@ +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.attention import AttentionMetadata +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.gemma2 import Gemma2Model +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + + +class Gemma2EmbeddingModel(nn.Module): + """A model that uses Gemma2 with additional embedding functionalities. + + This class encapsulates the Gemma2Model and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of Gemma2Model used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.model = Gemma2Model(**kwargs) + self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model.forward(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors, + inputs_embeds) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.model.named_parameters()) + for name, loaded_weight in weights: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index a72b9e8909db2..ccb0e155ff4aa 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -83,6 +83,7 @@ _EMBEDDING_MODELS = { "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), + "Gemma2Model": ("gemma2_embedding", "Gemma2EmbeddingModel"), } _MULTIMODAL_MODELS = {
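With this change, any checkpoint whose config.json lists the "Gemma2Model" architecture (such as BAAI/bge-multilingual-gemma2) is resolved to Gemma2EmbeddingModel through the _EMBEDDING_MODELS registry and served as an embedding model. The sketch below is a minimal offline usage example in the style of vLLM's examples/offline_inference_embedding.py as of this commit; the prompts are illustrative, and enforce_eager=True is an assumption carried over from that example rather than a requirement of this patch.

    from vllm import LLM

    prompts = [
        "The capital of France is Paris.",
        "Gemma2 checkpoints can now produce sentence embeddings in vLLM.",
    ]

    # The embedding architecture is detected from the checkpoint's config.json
    # ("Gemma2Model" -> Gemma2EmbeddingModel), so no extra flag is needed.
    # NOTE: enforce_eager=True is copied from the upstream example, not required.
    model = LLM(model="BAAI/bge-multilingual-gemma2", enforce_eager=True)

    # encode() runs the pooler defined in gemma2_embedding.py
    # (PoolingType.LAST with normalize=True) instead of sampling tokens.
    outputs = model.encode(prompts)
    for output in outputs:
        # output.outputs.embedding is a list of floats whose length is the
        # model's hidden size (3584 for this Gemma2-9B-based checkpoint).
        print(len(output.outputs.embedding))

As the test comment above notes, vLLM does not strip inputs the way sentence_transformers does, so trailing whitespace in a prompt can change the tokenization and therefore the resulting embedding.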