Commit
* wip vlm2vec model
* making i2t classification work with Caltech101
* test vlm2vec on other task types
* move peft into class
1 parent 8065568, commit 2011aa1
Showing 6 changed files with 363 additions and 1 deletion.
@@ -0,0 +1,302 @@
from __future__ import annotations

import logging
from functools import partial
from typing import Any, Literal

import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor

from mteb.model_meta import ModelMeta

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

EncodeTypes = Literal["query", "passage"]


class VLM2VecWrapper:
    """Adapted from https://github.com/TIGER-AI-Lab/VLM2Vec/blob/main/src/model.py"""

    def __init__(
        self,
        model_name: str = "TIGER-Lab/VLM2Vec-LoRA",
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        **kwargs,
    ):
        try:
            import flash_attn  # noqa
            from peft import LoraConfig, PeftModel  # noqa
        except ImportError:
            logger.warning(
                "VLM2Vec models were trained with flash attention enabled. For optimal performance, please install the `flash_attn` package with `pip install flash-attn --no-build-isolation`."
            )

        self.pooling = "last"
        self.normalize = True
        self.temperature = 1.0
        self.hidden_size = 4096
        self.device = device

        # Loading the base model
        base_model_name = "microsoft/Phi-3.5-vision-instruct"
        config = AutoConfig.from_pretrained(base_model_name, trust_remote_code=True)
        config.use_cache = False
        config.padding_side = "right"

        checkpoint_path = model_name if model_name else base_model_name
        base_model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            config=config,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        base_model.padding_side = "right"

        # Building the model on top of the base
        if "LoRA" in model_name:
            lora_config = LoraConfig.from_pretrained(checkpoint_path)
            lora_model = PeftModel.from_pretrained(
                base_model, checkpoint_path, config=lora_config
            )
            lora_model = lora_model.merge_and_unload()
            model = lora_model
        else:
            model = base_model

        model.eval()
        model.to(device)
        self.mdl = model

        self.processor = AutoProcessor.from_pretrained(
            base_model_name,
            trust_remote_code=True,
            num_crops=4,
        )

    def encode(
        self,
        sentences: list[str],
        *,
        prompt_name: str = None,
        **kwargs: Any,  # noqa
    ):
        return self.get_text_embeddings(texts=sentences)

    def encode_input(self, input):
        hidden_states = self.mdl(**input, return_dict=True, output_hidden_states=True)
        hidden_states = hidden_states.hidden_states[-1]
        pooled_output = self._pooling(hidden_states, input["attention_mask"])
        return pooled_output

    def _pooling(self, last_hidden_state, attention_mask):
        if self.pooling == "last":
            sequence_lengths = attention_mask.sum(dim=1) - 1
            batch_size = last_hidden_state.shape[0]
            reps = last_hidden_state[
                torch.arange(batch_size, device=last_hidden_state.device),
                sequence_lengths,
            ]
        else:
            raise NotImplementedError
        if self.normalize:
            reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    # reference: https://github.com/TIGER-AI-Lab/VLM2Vec/blob/main/src/collator.py
    def get_image_embeddings(
        self, images: list[Image.Image] | DataLoader, batch_size: int = 32
    ):
        text = "<|image_1|> Represent the given image."
        all_image_embeddings = []
        if isinstance(images, DataLoader):
            import torchvision.transforms.functional as F

            with torch.no_grad():
                for batch in tqdm(images):
                    input_ids, pixel_values, image_sizes = [], [], []
                    for b in batch:
                        inputs = self.processor(
                            text,
                            [F.to_pil_image(b.to("cpu"))],
                            return_tensors="pt",
                            max_length=256,
                            truncation=True,
                        )
                        inputs = {k: v.to(self.device) for k, v in inputs.items()}
                        input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
                        pixel_values.append(inputs["pixel_values"])
                        image_sizes.append(inputs["image_sizes"])

                    input_ids = torch._C._nn.pad_sequence(
                        input_ids,
                        batch_first=True,
                        padding_value=self.processor.tokenizer.pad_token_id,
                    ).squeeze(2)
                    attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)

                    pixel_values = torch.cat(pixel_values, dim=0)
                    image_sizes = torch.cat(image_sizes, dim=0)
                    inputs = {
                        "input_ids": input_ids,
                        "attention_mask": attention_mask,
                        "pixel_values": pixel_values,
                        "image_sizes": image_sizes,
                    }

                    image_outputs = self.encode_input(inputs)
                    all_image_embeddings.append(image_outputs.cpu())

        else:
            with torch.no_grad():
                for i in tqdm(range(0, len(images), batch_size)):
                    batch_images = images[i : i + batch_size]
                    input_ids, pixel_values, image_sizes = [], [], []
                    for b in batch_images:
                        inputs = self.processor(
                            text,
                            [b],
                            return_tensors="pt",
                            max_length=256,
                            truncation=True,
                        )
                        inputs = {k: v.to(self.device) for k, v in inputs.items()}
                        input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))
                        pixel_values.append(inputs["pixel_values"])
                        image_sizes.append(inputs["image_sizes"])

                    input_ids = torch._C._nn.pad_sequence(
                        input_ids,
                        batch_first=True,
                        padding_value=self.processor.tokenizer.pad_token_id,
                    ).squeeze(2)
                    attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)

                    pixel_values = torch.cat(pixel_values, dim=0)
                    image_sizes = torch.cat(image_sizes, dim=0)
                    inputs = {
                        "input_ids": input_ids,
                        "attention_mask": attention_mask,
                        "pixel_values": pixel_values,
                        "image_sizes": image_sizes,
                    }

                    image_outputs = self.encode_input(inputs)
                    all_image_embeddings.append(image_outputs.cpu())

        all_image_embeddings = torch.cat(all_image_embeddings, dim=0)
        return all_image_embeddings

    def get_text_embeddings(self, texts: list[str], batch_size: int = 32):
        all_text_embeddings = []

        with torch.no_grad():
            for i in tqdm(range(0, len(texts), batch_size)):
                input_ids = []
                batch_texts = texts[i : i + batch_size]
                for text in batch_texts:
                    inputs = self.processor(
                        text,
                        None,
                        return_tensors="pt",
                        max_length=256,
                        truncation=True,
                    )
                    inputs = {k: v.to(self.device) for k, v in inputs.items()}
                    input_ids.append(inputs["input_ids"].squeeze(0).unsqueeze(1))

                input_ids = torch._C._nn.pad_sequence(
                    input_ids,
                    batch_first=True,
                    padding_value=self.processor.tokenizer.pad_token_id,
                ).squeeze(2)
                attention_mask = input_ids.ne(self.processor.tokenizer.pad_token_id)
                inputs = {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                }

                text_outputs = self.encode_input(inputs)
                all_text_embeddings.append(text_outputs.cpu())

        all_text_embeddings = torch.cat(all_text_embeddings, dim=0)
        return all_text_embeddings

    def get_fused_embeddings(
        self,
        texts: list[str] = None,
        images: list[Image.Image] | DataLoader = None,
        fusion_mode="sum",
        batch_size: int = 32,
    ):
        if texts is None and images is None:
            raise ValueError("Either texts or images must be provided")

        text_embeddings = None
        image_embeddings = None

        if texts is not None:
            text_embeddings = self.get_text_embeddings(texts, batch_size)

        if images is not None:
            image_embeddings = self.get_image_embeddings(images, batch_size)

        if text_embeddings is not None and image_embeddings is not None:
            if len(text_embeddings) != len(image_embeddings):
                raise ValueError(
                    "The number of texts and images must have the same length"
                )
            texts = iter(texts)
            all_fused_embeddings = []
            if isinstance(images, DataLoader):
                import torchvision.transforms.functional as F

                for batch in images:
                    for b in batch:
                        text = next(texts)
                        inputs = self.processor(
                            f"<|image_1|> Represent the given image with the following question: {text}",
                            [F.to_pil_image(b.to("cpu"))],
                        )
                        inputs = {
                            key: value.to(self.device) for key, value in inputs.items()
                        }
                        outputs = self.encode_input(inputs)
                        all_fused_embeddings.append(outputs.cpu())

            fused_embeddings = torch.cat(all_fused_embeddings, dim=0)

            return fused_embeddings
        elif text_embeddings is not None:
            return text_embeddings
        elif image_embeddings is not None:
            return image_embeddings


vlm2vec_lora = ModelMeta(
    loader=partial(
        VLM2VecWrapper,
        model_name="TIGER-Lab/VLM2Vec-LoRA",
    ),
    name="TIGER-Lab/VLM2Vec-LoRA",
    languages=["eng_Latn"],
    open_source=True,
    revision="7403b6327958071c1e33c822c7453adadccc7298",
    release_date="2024-10-08",
)

vlm2vec_full = ModelMeta(
    loader=partial(
        VLM2VecWrapper,
        model_name="TIGER-Lab/VLM2Vec-Full",
    ),
    name="TIGER-Lab/VLM2Vec-Full",
    languages=["eng_Latn"],
    open_source=True,
    revision="e9afa98002097ac2471827ba23ea1f2ddd229480",
    release_date="2024-10-08",
)
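
A minimal usage sketch of the wrapper added above, for orientation only: it assumes a CUDA device with enough memory for Phi-3.5-vision, that `peft` and `flash_attn` are installed, and that the new module is importable from a path like `mteb.models.vlm2vec_models` (the module's file name is not visible in this view); the image paths are placeholders.

from PIL import Image

# Assumed import path for the module added in this commit.
from mteb.models.vlm2vec_models import vlm2vec_lora

# ModelMeta.loader is a partial over VLM2VecWrapper, so calling it builds the model
# (downloads the Phi-3.5-vision base and the LoRA checkpoint on first use).
model = vlm2vec_lora.loader()

# Text embeddings, the code path exercised by the STS12 run recorded below.
text_emb = model.get_text_embeddings(["a photo of a cat", "a photo of a dog"])

# Image embeddings, the code path exercised by the Caltech101 i2t classification run below.
images = [Image.open("cat.jpg"), Image.open("dog.jpg")]  # placeholder image files
image_emb = model.get_image_embeddings(images, batch_size=2)

# Embeddings are L2-normalised in _pooling, so a dot product is a cosine similarity.
scores = text_emb.float() @ image_emb.float().T
print(scores)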
...lts-mieb/TIGER-Lab__VLM2Vec-LoRA/7403b6327958071c1e33c822c7453adadccc7298/Caltech101.json
28 changes: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
{
  "dataset_revision": "851374102055782c84f89b1b4e9d128a6568847b",
  "evaluation_time": 1317.9743084907532,
  "kg_co2_emissions": null,
  "mteb_version": "1.14.21",
  "scores": {
    "test": [
      {
        "accuracy": 0.9301446416831032,
        "f1": 0.8863632422649081,
        "f1_weighted": 0.9270094006117223,
        "hf_subset": "default",
        "languages": [
          "eng-Latn"
        ],
        "main_score": 0.9301446416831032,
        "scores_per_experiment": [
          {
            "accuracy": 0.9301446416831032,
            "f1": 0.8863632422649081,
            "f1_weighted": 0.9270094006117223
          }
        ]
      }
    ]
  },
  "task_name": "Caltech101"
}
results-mieb/TIGER-Lab__VLM2Vec-LoRA/7403b6327958071c1e33c822c7453adadccc7298/STS12.json
26 changes: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
{
  "dataset_revision": "a0d554a64d88156834ff5ae9920b964011b16384",
  "evaluation_time": 33.679136514663696,
  "kg_co2_emissions": null,
  "mteb_version": "1.14.21",
  "scores": {
    "test": [
      {
        "cosine_pearson": 0.6128856131150828,
        "cosine_spearman": 0.5375376750091784,
        "euclidean_pearson": 0.5866571133163221,
        "euclidean_spearman": 0.5376001641683719,
        "hf_subset": "default",
        "languages": [
          "eng-Latn"
        ],
        "main_score": 0.5375376750091784,
        "manhattan_pearson": 0.5912422177023093,
        "manhattan_spearman": 0.5413588869937086,
        "pearson": 0.6128856131150828,
        "spearman": 0.5375376750091784
      }
    ]
  },
  "task_name": "STS12"
}
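
For reference, a hedged sketch of how result files like the two above are commonly produced with the mteb runner; the task names come from the JSON files, but the exact invocation, the output-folder layout, and the availability of the Caltech101 image task on the mieb branch are assumptions rather than something shown in this commit.

import mteb

# Assumed import path for the wrapper module added in this commit.
from mteb.models.vlm2vec_models import vlm2vec_lora

model = vlm2vec_lora.loader()

# STS12 is a standard text STS task; Caltech101 is assumed to be exposed by the mieb branch.
tasks = mteb.get_tasks(tasks=["STS12", "Caltech101"])
evaluation = mteb.MTEB(tasks=tasks)

# mteb writes one JSON file per task; the paths in this commit suggest the model name and
# revision are appended under the chosen output folder.
evaluation.run(model, output_folder="results-mieb")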
...lts-mieb/TIGER-Lab__VLM2Vec-LoRA/7403b6327958071c1e33c822c7453adadccc7298/model_meta.json
1 change: 1 addition & 0 deletions
@@ -0,0 +1 @@
{"name": "TIGER-Lab/VLM2Vec-LoRA", "revision": "7403b6327958071c1e33c822c7453adadccc7298", "release_date": "2024-10-08", "languages": ["eng_Latn"], "n_parameters": null, "memory_usage": null, "max_tokens": null, "embed_dim": null, "license": null, "open_source": true, "similarity_fn_name": null, "framework": [], "loader": "VLM2VecWrapper"}