Skip to content

Commit

Permalink
[Misc] Sort the list of embedding models (vllm-project#10037)
Browse files Browse the repository at this point in the history
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Loc Huynh <[email protected]>
  • Loading branch information
DarkLight1337 authored and JC1DA committed Nov 11, 2024
1 parent 9c207af commit 2f5095b
Showing 1 changed file with 8 additions and 18 deletions.
26 changes: 8 additions & 18 deletions vllm/model_executor/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,33 +94,23 @@
_EMBEDDING_MODELS = {
# [Text-only]
"BertModel": ("bert", "BertEmbeddingModel"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
"LlamaModel": ("llama", "LlamaEmbeddingModel"),
**{
# Multiple models share the same architecture, so we include them all
k: (mod, arch) for k, (mod, arch) in _TEXT_GENERATION_MODELS.items()
if arch == "LlamaForCausalLM"
},
"MistralModel": ("llama", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": (
"qwen2_cls", "Qwen2ForSequenceClassification"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
"Qwen2ForSequenceClassification": ("qwen2_cls", "Qwen2ForSequenceClassification"), # noqa: E501
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
}

def add_embedding_models(base_models, embedding_models):
with_pooler_method_models = {}
embedding_models_name = embedding_models.keys()
for name, (path, arch) in base_models.items():
if arch in embedding_models_name:
with_pooler_method_models[name] = (path, arch)
return with_pooler_method_models

_EMBEDDING_MODELS = {
**add_embedding_models(_TEXT_GENERATION_MODELS, _EMBEDDING_MODELS),
**_EMBEDDING_MODELS,
}

_MULTIMODAL_MODELS = {
# [Decoder-only]
"Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"),
Expand Down

0 comments on commit 2f5095b

Please sign in to comment.