vision_retriever.py
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from colpali_engine.trainer.eval_utils import CustomRetrievalEvaluator
from datasets import Dataset

logger = logging.getLogger(__name__)


class VisionRetriever(ABC):
    """
    Abstract class for vision retrievers used in the ViDoRe benchmark.
    """

    @abstractmethod
    def __init__(self, **kwargs):
        """
        Initialize the VisionRetriever.
        """
        pass

    @property
    @abstractmethod
    def use_visual_embedding(self) -> bool:
        """
        The child class should implement the `use_visual_embedding` property:
        - True if the retriever uses native visual embeddings (e.g. JINA-Clip, ColPali).
        - False if the retriever uses text embeddings and possibly VLM-generated captions (e.g. BM25).
        """
        pass

    @abstractmethod
    def forward_queries(
        self,
        queries: Any,
        batch_size: int,
        **kwargs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Preprocess and forward pass the queries through the model.

        NOTE: This method can either:
        - return a single tensor where the first dimension corresponds to the number of queries.
        - return a list of tensors where each tensor corresponds to a query.
        """
        pass
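    # Illustrative note: the list-of-tensors return shape exists for multi-vector
    # models (e.g. ColPali), whose number of embedding vectors can vary from one
    # input to another, so the outputs cannot always be stacked into one tensor.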
    @abstractmethod
    def forward_passages(
        self,
        passages: Any,
        batch_size: int,
        **kwargs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """
        Preprocess and forward pass the passages through the model. A passage can be a text chunk (e.g. BM25) or
        an image of a document page (e.g. ColPali).

        NOTE: This method can either:
        - return a single tensor where the first dimension corresponds to the number of passages.
        - return a list of tensors where each tensor corresponds to a passage.
        """
        pass

    @abstractmethod
    def get_scores(
        self,
        query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: Optional[int] = None,
    ) -> torch.Tensor:
        """
        Get the scores between queries and passages.

        Inputs:
        - query_embeddings: torch.Tensor (n_queries, emb_dim_query) or List[torch.Tensor] (emb_dim_query)
        - passage_embeddings: torch.Tensor (n_passages, emb_dim_doc) or List[torch.Tensor] (emb_dim_doc)
        - batch_size: Optional[int]

        Output:
        - scores: torch.Tensor (n_queries, n_passages)
        """
        pass
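    # Illustrative note: a dense bi-encoder would typically implement `get_scores`
    # as a single matrix product (query_embeddings @ passage_embeddings.T), while a
    # late-interaction retriever such as ColPali computes a MaxSim score over the
    # per-token embeddings of each query/passage pair.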
    def get_relevant_docs_results(
        self,
        ds: Dataset,
        queries: List[str],
        scores: torch.Tensor,
        **kwargs,
    ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, float]]]:
        """
        Get the relevant passages and the results from the scores.

        NOTE: Override this method if the retriever has a different output format.

        Outputs:
        - relevant_docs: Dict[str, Dict[str, int]] with shape:
            {
                "query_0": {"doc_0": 1},
                "query_1": {"doc_1": 1},
                ...
            }
        - results: Dict[str, Dict[str, float]] with shape:
            {
                "query_0": {"doc_i": 19.125, "doc_1": 18.75, ...},
                "query_1": {"doc_j": 17.25, "doc_1": 16.75, ...},
                ...
            }
        """
        relevant_docs = {}
        results = {}

        queries2filename = {query: image_filename for query, image_filename in zip(ds["query"], ds["image_filename"])}
        passages2filename = {docidx: image_filename for docidx, image_filename in enumerate(ds["image_filename"])}

        for query, score_per_query in zip(queries, scores):
            relevant_docs[query] = {queries2filename[query]: 1}

            for docidx, score in enumerate(score_per_query):
                filename = passages2filename[docidx]
                score_passage = float(score.item())

                if query in results:
                    results[query][filename] = max(results[query].get(filename, 0), score_passage)
                else:
                    results[query] = {filename: score_passage}

        return relevant_docs, results
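    # Illustrative note: for a toy dataset where ds["query"] = ["q0", "q1"] and
    # ds["image_filename"] = ["page_0.png", "page_1.png"], the method above marks
    # "page_0.png" as the relevant document for "q0" (and "page_1.png" for "q1"),
    # while `results` stores the score of every page for every query, keyed by
    # filename and de-duplicated with a max over repeated filenames.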
    def compute_metrics(
        self,
        relevant_docs: Any,
        results: Any,
        **kwargs,
    ):
        """
        Compute the MTEB metrics.

        NOTE: Override this method if the retriever has a different evaluation metric.
        """
        mteb_evaluator = CustomRetrievalEvaluator()

        ndcg, _map, recall, precision, naucs = mteb_evaluator.evaluate(
            relevant_docs,
            results,
            mteb_evaluator.k_values,
            ignore_identical_ids=kwargs.get("ignore_identical_ids", True),
        )

        mrr = mteb_evaluator.evaluate_custom(relevant_docs, results, mteb_evaluator.k_values, "mrr")

        scores = {
            **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
            **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
            **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
            **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
            **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr[0].items()},
            **{f"naucs_at_{k.split('@')[1]}": v for (k, v) in naucs.items()},
        }

        return scores
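

# ---------------------------------------------------------------------------
# Illustrative sketch: a minimal, hypothetical concrete subclass showing how
# the abstract interface above can be filled in. `ToyTextRetriever` and its
# `_embed` helper are not part of the benchmark; they use a trivial
# bag-of-characters featurizer in place of a real vision/text model, but they
# exercise every abstract method and run end-to-end.
# ---------------------------------------------------------------------------
class ToyTextRetriever(VisionRetriever):
    """Hypothetical retriever used only to illustrate the interface."""

    def __init__(self, emb_dim: int = 256, **kwargs):
        self.emb_dim = emb_dim

    @property
    def use_visual_embedding(self) -> bool:
        # Text-only toy model, so no native visual embeddings.
        return False

    def _embed(self, texts: List[str]) -> torch.Tensor:
        # Bag-of-characters embedding: count character codes modulo `emb_dim`,
        # then L2-normalize so the dot product behaves like a cosine score.
        embeddings = torch.zeros(len(texts), self.emb_dim)
        for i, text in enumerate(texts):
            for char in text:
                embeddings[i, ord(char) % self.emb_dim] += 1.0
        return torch.nn.functional.normalize(embeddings, dim=-1)

    def forward_queries(self, queries: Any, batch_size: int, **kwargs) -> torch.Tensor:
        # One embedding per query, returned as a single (n_queries, emb_dim) tensor.
        return self._embed(list(queries))

    def forward_passages(self, passages: Any, batch_size: int, **kwargs) -> torch.Tensor:
        # Passages are assumed to be text chunks here; a visual retriever would
        # receive page images instead.
        return self._embed(list(passages))

    def get_scores(
        self,
        query_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        passage_embeddings: Union[torch.Tensor, List[torch.Tensor]],
        batch_size: Optional[int] = None,
    ) -> torch.Tensor:
        # Dense dot-product scoring, returning a (n_queries, n_passages) tensor.
        if isinstance(query_embeddings, list):
            query_embeddings = torch.stack(query_embeddings)
        if isinstance(passage_embeddings, list):
            passage_embeddings = torch.stack(passage_embeddings)
        return query_embeddings @ passage_embeddings.T


# Example usage of the hypothetical subclass:
#   retriever = ToyTextRetriever()
#   q_emb = retriever.forward_queries(["what is colpali?"], batch_size=1)
#   p_emb = retriever.forward_passages(["ColPali is a visual retriever."], batch_size=1)
#   scores = retriever.get_scores(q_emb, p_emb)  # shape (1, 1)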