forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Model] Support Gemma2 embedding model (vllm-project#9004)
Signed-off-by: Sumit Dubey <[email protected]>
- Loading branch information
Showing
5 changed files
with
99 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
from typing import Iterable, List, Optional, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from vllm.attention import AttentionMetadata | ||
from vllm.model_executor.layers.pooler import Pooler, PoolingType | ||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader | ||
from vllm.model_executor.models.gemma2 import Gemma2Model | ||
from vllm.model_executor.pooling_metadata import PoolingMetadata | ||
from vllm.sequence import IntermediateTensors, PoolerOutput | ||
|
||
|
||
class Gemma2EmbeddingModel(nn.Module): | ||
"""A model that uses Gemma2 with additional embedding functionalities. | ||
This class encapsulates the Gemma2Model and provides an interface for | ||
embedding operations and customized pooling functions. | ||
Attributes: | ||
model: An instance of Gemma2Model used for forward operations. | ||
_pooler: An instance of Pooler used for pooling operations. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
**kwargs, | ||
) -> None: | ||
super().__init__() | ||
self.model = Gemma2Model(**kwargs) | ||
self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) | ||
|
||
def forward( | ||
self, | ||
input_ids: Optional[torch.Tensor], | ||
positions: torch.Tensor, | ||
kv_caches: List[torch.Tensor], | ||
attn_metadata: AttentionMetadata, | ||
intermediate_tensors: Optional[IntermediateTensors] = None, | ||
inputs_embeds: Optional[torch.Tensor] = None, | ||
) -> torch.Tensor: | ||
return self.model.forward(input_ids, positions, kv_caches, | ||
attn_metadata, intermediate_tensors, | ||
inputs_embeds) | ||
|
||
def pooler( | ||
self, | ||
hidden_states: torch.Tensor, | ||
pooling_metadata: PoolingMetadata, | ||
) -> Optional[PoolerOutput]: | ||
return self._pooler(hidden_states, pooling_metadata) | ||
|
||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): | ||
stacked_params_mapping = [ | ||
# (param_name, shard_name, shard_id) | ||
("qkv_proj", "q_proj", "q"), | ||
("qkv_proj", "k_proj", "k"), | ||
("qkv_proj", "v_proj", "v"), | ||
("gate_up_proj", "gate_proj", 0), | ||
("gate_up_proj", "up_proj", 1), | ||
] | ||
params_dict = dict(self.model.named_parameters()) | ||
for name, loaded_weight in weights: | ||
for (param_name, weight_name, shard_id) in stacked_params_mapping: | ||
if weight_name not in name: | ||
continue | ||
name = name.replace(weight_name, param_name) | ||
# Skip loading extra bias for GPTQ models. | ||
if name.endswith(".bias") and name not in params_dict: | ||
continue | ||
param = params_dict[name] | ||
weight_loader = param.weight_loader | ||
weight_loader(param, loaded_weight, shard_id) | ||
break | ||
else: | ||
# Skip loading extra bias for GPTQ models. | ||
if name.endswith(".bias") and name not in params_dict: | ||
continue | ||
param = params_dict[name] | ||
weight_loader = getattr(param, "weight_loader", | ||
default_weight_loader) | ||
weight_loader(param, loaded_weight) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters