[Model] Adding Support for Qwen2VL as an Embedding Model. Using MrLight/dse-qwen2-2b-mrl-v1 #9944

Merged: 10 commits, Nov 13, 2024
7 changes: 7 additions & 0 deletions examples/template_dse_qwen2_vl.jinja
@@ -0,0 +1,7 @@
{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{% raw %}<|im_start|>system
You are a helpful assistant.<|im_end|>
{% endraw %}{% endif %}<|im_start|>{{ message['role'] }}{% raw %}
{% endraw %}{% if message['content'] is string %}{{ message['content'] }}<|im_end|>{% raw %}
{% endraw %}{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>{% raw %}
{% endraw %}{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant{% raw %}
{% endraw %}{% endif %}<|endoftext|>
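For reference, a minimal offline sketch of how this template and the new embedding support could be exercised end to end. This is a hypothetical example, not part of the PR: it assumes the `LLM(task="embedding")` / `encode()` API available in vLLM around this release, and reuses the model name and settings from the test below. The template above mirrors the processor's default chat template and appends `<|endoftext|>`, whose hidden state serves as the embedding.

from PIL import Image
from transformers import AutoProcessor
from vllm import LLM

MODEL = "MrLight/dse-qwen2-2b-mrl-v1"

# Build the prompt with the processor's chat template and append <|endoftext|>,
# the token whose hidden state is pooled into the embedding.
processor = AutoProcessor.from_pretrained(MODEL)
image = Image.new("RGB", (56, 56))
messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": image},
        {"type": "text", "text": "What is shown in this image?"},
    ],
}]
prompt = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"

llm = LLM(model=MODEL, task="embedding", enforce_eager=True,
          max_model_len=8192)
(output, ) = llm.encode({
    "prompt": prompt,
    "multi_modal_data": {"image": image},
})
embedding = output.outputs.embedding  # 1536-dim, L2-normalized last-token vector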
228 changes: 228 additions & 0 deletions tests/models/embedding/vision_language/test_dse_qwen2_vl.py
@@ -0,0 +1,228 @@
import os
from typing import List, Type

import pytest
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
from ....utils import large_gpu_test
from ..utils import check_embeddings_close

HF_TEXT_PROMPTS = [
# T -> X
(
"Query: Find me an everyday image that matches the given caption: The label of the object is stop sign", # noqa: E501,
Image.new("RGB", (56, 56))),
# T -> X
("Query: Retrieve an image of this caption: cherry blossom",
Image.new("RGB", (56, 56))),
]

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
"stop_sign":
"What is shown in this image?",
"cherry_blossom":
"What is shown in this image?"
})

MODELS = ["MrLight/dse-qwen2-2b-mrl-v1"]


class QwenVLEncoder:

def __init__(self):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
attn = "flash_attention_2" if self.device == "cuda" else None

os.environ["TOKENIZERS_PARALLELISM"] = "true"
self.processor = AutoProcessor.from_pretrained(MODELS[0])
self.model = Qwen2VLForConditionalGeneration.from_pretrained(
MODELS[0], attn_implementation=attn,
torch_dtype=torch.bfloat16).to(self.device).eval()
self.processor.tokenizer.padding_side = "left"
self.model.padding_side = "left"
self.base_embed_dim = 1536

def _get_embedding(self, last_hidden_state: torch.Tensor,
dimension: int) -> torch.Tensor:
reps = last_hidden_state[:, -1]
reps = torch.nn.functional.normalize(reps[0, :dimension], p=2, dim=-1)
return reps

def embed(self, inp: dict, embed_dim: int = 1536) -> torch.Tensor:
"""
inp: dict
{
"dtype": "image",
"image": PIL.Image,
}
or
{
"dtype": "text",
"question": (str) the question to embed,
}
embed_dim: int
Will slice embeddings like emb[:embed_dim]
"""
if inp["dtype"] == "image":
messages = [[{
"role":
"user",
"content": [{
"type": "image",
"image": inp["image"]
}, {
"type": "text",
"text": "What is shown in this image?"
}]
}]]
else:
messages = [[{
"role":
"user",
"content": [
{
"type": "image",
"image": Image.new("RGB", (28, 28)),
"resized_height": 1,
"resized_width": 1
},  # dummy 1x1 image so text-only queries share the image processing path
{
"type": "text",
"text": f"{inp['question']}"
},
]
}]]
image_inputs, _ = process_vision_info(messages)

texts = [
self.processor.apply_chat_template(
msg, tokenize=False, add_generation_prompt=True) +
"<|endoftext|>" for msg in messages
]
inputs = self.processor(text=texts,
images=image_inputs,
padding="longest",
return_tensors="pt").to(self.device)
inputs = self.model.prepare_inputs_for_generation(**inputs,
use_cache=False)

with torch.no_grad():
output = self.model(**inputs,
return_dict=True,
output_hidden_states=True)

embeddings = self._get_embedding(output.hidden_states[-1], embed_dim)
return embeddings


def _run_test(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
input_texts: List[str],
input_images: PromptImageInput,
model: str,
*,
dtype: str,
) -> None:
# NOTE: take care of the order. run vLLM first, and then run HF.
# vLLM needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
processor = AutoProcessor.from_pretrained(MODELS[0])
with vllm_runner(model,
task="embedding",
dtype=dtype,
enforce_eager=True,
max_model_len=8192) as vllm_model:
texts = [
processor.apply_chat_template([{
"role":
"user",
"content": [
{
"type": "image",
"image": Image.new("RGB", (28, 28)),
"resized_height": 1,
"resized_width": 1
},
{
"type": "text",
"text": text
},
]
}],
tokenize=False,
add_generation_prompt=True) +
"<|endoftext|>" for text in input_texts
]
vllm_outputs = vllm_model.encode(texts, images=input_images)

hf_model = QwenVLEncoder()
hf_outputs = []
for text, image in zip(input_texts, input_images):
if text.startswith("Query:"):
inp = {"dtype": "text", "question": text}
else:
inp = {"dtype": "image", "image": image}
hf_outputs.append(hf_model.embed(inp).tolist())
check_embeddings_close(
embeddings_0_lst=hf_outputs,
embeddings_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_models_text(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [(text, image_placeholder)
for text, image_placeholder in HF_TEXT_PROMPTS]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]

_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images, # type: ignore
model,
dtype=dtype,
)


@large_gpu_test(min_gb=48)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_models_image(
hf_runner,
vllm_runner,
image_assets,
model: str,
dtype: str,
) -> None:
input_texts_images = [
(text, asset.pil_image)
for text, asset in zip(HF_IMAGE_PROMPTS, image_assets)
]
input_texts = [text for text, _ in input_texts_images]
input_images = [image for _, image in input_texts_images]

_run_test(
hf_runner,
vllm_runner,
input_texts,
input_images,
model,
dtype=dtype,
)
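As a quick illustration of what the reference encoder above computes (a hypothetical snippet, not part of the test): both text and image embeddings come from the last-token hidden state, are L2-normalized, and can be truncated to a smaller MRL dimension via `embed_dim`, so relevance scoring reduces to a dot product. The image path below is a placeholder.

# Hypothetical usage of QwenVLEncoder from the test above (illustration only).
encoder = QwenVLEncoder()
query_emb = encoder.embed(
    {"dtype": "text",
     "question": "Query: Retrieve an image of this caption: cherry blossom"},
    embed_dim=512)  # MRL: slice to 512 dims, then re-normalize
image_emb = encoder.embed(
    {"dtype": "image", "image": Image.open("cherry_blossom.jpg")},  # placeholder path
    embed_dim=512)
score = torch.dot(query_emb, image_emb).item()  # cosine similarity of unit vectors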
21 changes: 18 additions & 3 deletions vllm/model_executor/models/qwen2_vl.py
@@ -41,7 +41,7 @@

from vllm.attention import AttentionMetadata
from vllm.attention.selector import _Backend
from vllm.config import CacheConfig, MultiModalConfig
from vllm.config import CacheConfig, MultiModalConfig, PoolerConfig
from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, InputContext,
@@ -52,17 +52,19 @@
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.pooler import Pooler, PoolingType
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
MultiModalInputs)
from vllm.multimodal.base import MultiModalData
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor

@@ -934,7 +936,8 @@ def __init__(self,
config: Qwen2VLConfig,
multimodal_config: MultiModalConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
quant_config: Optional[QuantizationConfig] = None,
pooler_config: Optional[PoolerConfig] = None) -> None:
super().__init__()

assert not cache_config.enable_prefix_caching, \
@@ -968,6 +971,11 @@ def __init__(self,

self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self._pooler = Pooler.from_config_with_defaults(
pooler_config,
pooling_type=PoolingType.LAST,
normalize=True,
softmax=False)
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
@@ -1158,6 +1166,13 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def pooler(
self,
hidden_states: torch.Tensor,
pooling_metadata: PoolingMetadata,
) -> Optional[PoolerOutput]:
return self._pooler(hidden_states, pooling_metadata)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
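In plain PyTorch terms, the pooler added to qwen2_vl.py above (PoolingType.LAST with normalize=True) takes the hidden state of the final prompt token, the appended <|endoftext|>, and L2-normalizes it, matching `_get_embedding` in the HF reference test. A conceptual sketch, not vLLM's actual Pooler implementation:

import torch
import torch.nn.functional as F

def last_token_pool(hidden_states: torch.Tensor) -> torch.Tensor:
    """Conceptual sketch of PoolingType.LAST + normalize=True for one sequence.

    hidden_states: [seq_len, hidden_size] hidden states of a single prompt.
    Returns the L2-normalized hidden state of the last token.
    """
    last = hidden_states[-1]               # hidden state of <|endoftext|>
    return F.normalize(last, p=2, dim=-1)  # unit-norm embedding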
1 change: 1 addition & 0 deletions vllm/model_executor/models/registry.py
@@ -106,6 +106,7 @@
# [Multimodal]
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501,
}

def add_embedding_models(base_models, embedding_models):