[mieb] add Eva CLIP models (#1369)
* add Eva CLIP models

* make lint
isaac-chung authored Oct 31, 2024
1 parent cf8ea1f commit 6652e56
Showing 8 changed files with 603 additions and 0 deletions.
198 changes: 198 additions & 0 deletions mteb/models/evaclip_models.py
@@ -0,0 +1,198 @@
from __future__ import annotations

from functools import partial
from typing import Any

import torch
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

from mteb.model_meta import ModelMeta


def evaclip_loader(**kwargs):
try:
import sys
import os

sys.path.insert(0, os.path.join(os.getcwd(), "EVA/EVA-CLIP/rei"))

from eva_clip import create_model_and_transforms, get_tokenizer
except ImportError:
# https://github.com/baaivision/EVA/tree/master/EVA-CLIP#setup
        raise ImportError(
            "Please run `git clone [email protected]:baaivision/EVA.git`, "
            "`pip install ninja`, "
            "`pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers`, and "
            "`git clone https://github.com/NVIDIA/apex && cd apex && pip install -v "
            "--disable-pip-version-check --no-build-isolation --no-cache-dir ./`."
        )

class EvaCLIPWrapper:
def __init__(
self,
model_name: str = "EVA02-CLIP-B-16",
device: str = "cuda" if torch.cuda.is_available() else "cpu",
**kwargs: Any,
):
self.model_name = model_name
self.device = device
pretrained = "eva_clip" # or "/path/to/EVA02_CLIP_B_psz16_s8B.pt"
self.model, _, self.img_preprocess = create_model_and_transforms(
model_name, pretrained, force_custom_clip=True, device=device
)
self.model.eval()
self.tokenizer = get_tokenizer(model_name)

def encode( # type: ignore
self,
sentences: list[str],
*,
batch_size: int = 32,
**kwargs: Any,
):
return self.get_text_embeddings(texts=sentences, batch_size=batch_size)

def get_text_embeddings(self, texts: list[str], batch_size: int = 32):
all_text_embeddings = []

with torch.no_grad(), torch.cuda.amp.autocast():
for i in tqdm(range(0, len(texts), batch_size)):
batch_texts = texts[i : i + batch_size]
inputs = self.tokenizer(batch_texts)
text_outputs = self.model.encode_text(inputs.to(self.device))
all_text_embeddings.append(text_outputs.cpu())

all_text_embeddings = torch.cat(all_text_embeddings, dim=0)
return all_text_embeddings

def get_image_embeddings(
self, images: list[Image.Image] | DataLoader, batch_size: int = 32
):
all_image_embeddings = []
if isinstance(images, DataLoader):
import torchvision.transforms.functional as F

with torch.no_grad(), torch.cuda.amp.autocast():
for batch in tqdm(images):
inputs = torch.vstack(
[
self.img_preprocess(F.to_pil_image(b)).unsqueeze(0)
for b in batch
]
)
image_outputs = self.model.encode_image(inputs.to(self.device))
all_image_embeddings.append(image_outputs.cpu())
else:
with torch.no_grad(), torch.cuda.amp.autocast():
for i in tqdm(range(0, len(images), batch_size)):
batch_images = images[i : i + batch_size]
inputs = torch.vstack(
[self.img_preprocess(b) for b in batch_images]
)
image_outputs = self.model.encode_image(inputs.to(self.device))
all_image_embeddings.append(image_outputs.cpu())

all_image_embeddings = torch.cat(all_image_embeddings, dim=0)
return all_image_embeddings

def calculate_probs(self, text_embeddings, image_embeddings):
text_embeddings = text_embeddings / text_embeddings.norm(
dim=-1, keepdim=True
)
image_embeddings = image_embeddings / image_embeddings.norm(
dim=-1, keepdim=True
)
logits = torch.matmul(image_embeddings, text_embeddings.T)
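            # scale similarities by 100 (approximately CLIP's learned logit scale) before the softmax over texts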
probs = (logits * 100).softmax(dim=-1)
return probs

def get_fused_embeddings(
self,
            texts: list[str] | None = None,
            images: list[Image.Image] | DataLoader | None = None,
            fusion_mode: str = "sum",
batch_size: int = 32,
):
if texts is None and images is None:
raise ValueError("Either texts or images must be provided")

text_embeddings = None
image_embeddings = None

if texts is not None:
text_embeddings = self.get_text_embeddings(texts, batch_size)

if images is not None:
image_embeddings = self.get_image_embeddings(images, batch_size)

if text_embeddings is not None and image_embeddings is not None:
if len(text_embeddings) != len(image_embeddings):
raise ValueError(
"The number of texts and images must have the same length"
)
if fusion_mode == "sum":
fused_embeddings = text_embeddings + image_embeddings
else:
                    # TODO: add other fusion modes
raise ValueError(
f"fusion mode {fusion_mode} hasn't been implemented"
)
return fused_embeddings
elif text_embeddings is not None:
return text_embeddings
elif image_embeddings is not None:
return image_embeddings

return EvaCLIPWrapper(**kwargs)


EVA02_CLIP_B_16 = ModelMeta(
loader=partial(
evaclip_loader,
model_name="EVA02-CLIP-B-16",
),
name="EVA02-CLIP-B-16",
languages=["eng_Latn"],
open_source=True,
revision="11afd202f2ae80869d6cef18b1ec775e79bd8d12",
release_date="2023-04-26",
)

EVA02_CLIP_L_14 = ModelMeta(
loader=partial(
evaclip_loader,
model_name="EVA02-CLIP-L-14",
),
name="EVA02-CLIP-L-14",
languages=["eng_Latn"],
open_source=True,
revision="11afd202f2ae80869d6cef18b1ec775e79bd8d12",
release_date="2023-04-26",
)

EVA02_CLIP_bigE_14 = ModelMeta(
loader=partial(
evaclip_loader,
model_name="EVA02-CLIP-bigE-14",
),
name="EVA02-CLIP-bigE-14",
languages=["eng_Latn"],
open_source=True,
revision="11afd202f2ae80869d6cef18b1ec775e79bd8d12",
release_date="2023-04-26",
)


EVA02_CLIP_bigE_14_plus = ModelMeta(
loader=partial(
evaclip_loader,
model_name="EVA02-CLIP-bigE-14-plus",
),
name="EVA02-CLIP-bigE-14-plus",
languages=["eng_Latn"],
open_source=True,
revision="11afd202f2ae80869d6cef18b1ec775e79bd8d12",
release_date="2023-04-26",
)
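
For orientation, a minimal usage sketch of the wrapper added above (not part of the commit). It assumes the EVA repository has been cloned into the working directory as described in the ImportError message, so that EVA/EVA-CLIP/rei is importable; the image path and captions are illustrative placeholders.

# Sketch only: instantiate the wrapper directly and score two candidate captions
# against one image. "cat.jpg" and the captions are hypothetical placeholders.
from PIL import Image

from mteb.models.evaclip_models import evaclip_loader

model = evaclip_loader(model_name="EVA02-CLIP-B-16")

texts = ["a photo of a cat", "a photo of a dog"]
images = [Image.open("cat.jpg")]

text_emb = model.get_text_embeddings(texts, batch_size=2)
image_emb = model.get_image_embeddings(images, batch_size=1)

# One row of caption probabilities per image (see calculate_probs above).
probs = model.calculate_probs(text_emb, image_emb)
print(probs)
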
2 changes: 2 additions & 0 deletions mteb/models/overview.py
@@ -20,6 +20,7 @@
e5_instruct,
e5_models,
e5_v,
evaclip_models,
google_models,
gritlm_models,
gte_models,
@@ -57,6 +58,7 @@
e5_instruct,
e5_models,
e5_v,
evaclip_models,
google_models,
gritlm_models,
gte_models,
@@ -0,0 +1,28 @@
{
"dataset_revision": "77f3279092a1c1579b2250db8eafed0ad422088c",
"evaluation_time": 36.933613777160645,
"kg_co2_emissions": null,
"mteb_version": "1.14.21",
"scores": {
"test": [
{
"accuracy": 0.7848,
"f1": 0.7815922902217035,
"f1_weighted": 0.7830608860261875,
"hf_subset": "default",
"languages": [
"eng-Latn"
],
"main_score": 0.7848,
"scores_per_experiment": [
{
"accuracy": 0.7848,
"f1": 0.7815922902217035,
"f1_weighted": 0.7830608860261875
}
]
}
]
},
"task_name": "MNIST"
}
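
With evaclip_models imported in overview.py above, the new ModelMeta entries should be discoverable through mteb's model registry, and the MNIST scores above are the kind of result the standard evaluation loop produces. A sketch under those assumptions (that mteb.get_model and mteb.get_tasks resolve the names used here):

import mteb

# Assumes the registry resolves the ModelMeta name added in evaclip_models.py
model = mteb.get_model("EVA02-CLIP-B-16")
# "MNIST" is the task_name reported in the results file above
tasks = mteb.get_tasks(tasks=["MNIST"])
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(model)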
