Skip to content

Commit

Permalink
Bug fixes + hard negative (#23)
Browse files Browse the repository at this point in the history
## [0.2.0] - 2024-08-29
 
Large refactoring to adress several issues and add features. This release is not backward compatible with previous versions.
The models trained under this version will exhibit degraded performance if used with the previous version of the code and vice versa.

[Branch](#23)
 

### Added
- Added multiple training options for training with hard negatives. This leads to better model performance !
- Added options for restarting training from a checkpoint.

### Changed

- Optionally load ColPali models from pre-initialized backbones of the same shape to remove any stochastic initialization when loading adapters. This fixes [11](#11) and [17](#17).
 
### Fixed
- Set padding side to right in the tokenizer to fix misalignement issue between different query lengths in the same batch. Fixes [12](#12)
- Add 10 extra pad token by default to the query to act as reasoning buffers. This enables the above fix to be made without degrading performance and cleans up the old technique of using <unused> tokens.
  • Loading branch information
ManuelFay authored Aug 29, 2024
1 parent 270954a commit f961263
Show file tree
Hide file tree
Showing 30 changed files with 1,041 additions and 84 deletions.
40 changes: 40 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

# Change Log
All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](http://keepachangelog.com/)
and this project adheres to [Semantic Versioning](http://semver.org/).


## [0.2.0] - 2024-08-29

Large refactoring to adress several issues and add features. This release is not backward compatible with previous versions.
The models trained under this version will exhibit degraded performance if used with the previous version of the code and vice versa.

[Branch](https://github.com/illuin-tech/colpali/pull/23)


### Added
- Added multiple training options for training with hard negatives. This leads to better model performance !
- Added options for restarting training from a checkpoint.

### Changed

- Optionally load ColPali models from pre-initialized backbones of the same shape to remove any stochastic initialization when loading adapters. This fixes [11](https://github.com/illuin-tech/colpali/issues/11) and [17](https://github.com/illuin-tech/colpali/issues/17).

### Fixed
- Set padding side to right in the tokenizer to fix misalignement issue between different query lengths in the same batch. Fixes [12](https://github.com/illuin-tech/colpali/issues/12)
- Add 10 extra pad token by default to the query to act as reasoning buffers. This enables the above fix to be made without degrading performance and cleans up the old technique of using <unused> tokens.

## [0.1.1] - 2024-08-28

Minor patch release to fix packaging issues.

### Fixed

- [Branch](https://github.com/illuin-tech/colpali/commit/bd55e88c7af7069dde943f00665181fb94631cdd
Fix .gitignore to include all necessary files in the package.

## [0.1.0] - 2024-08-28

Initial code release corresponding to the paper.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
[[Blog Post]](https://huggingface.co/blog/manu/colpali)

> [!TIP]
> If you want to try the pre-trained ColPali on your own documents, you should use the [`vidore-benchmark`](https://github.com/illuin-tech/vidore-benchmark) repository. It comes with a Python package and a CLI tool for convenient evaluation.
> If you want to try the pre-trained ColPali on your own documents, you can use the [`vidore-benchmark`](https://github.com/illuin-tech/vidore-benchmark) repository. It comes with a Python package and a CLI tool for convenient evaluation. You can also use code provided in the model cards on the hub.
## Associated Paper

Expand All @@ -36,6 +36,11 @@ To keep a lightweight repository, only the essential packages were installed. In
pip install "colpali-engine[train]"
```


> [!TIP]
> For ColPali versions above v1.0, make sure to install the `colpali-engine` package from source or with a version above v0.2.0.

## Usage

The `scripts/` directory contains scripts to run training and inference.
Expand Down
42 changes: 36 additions & 6 deletions colpali_engine/dataset/custom_collator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@ def __init__(
processor: ProcessorMixin = None,
tokenizer: PreTrainedTokenizer = None,
max_length: int = 2048,
add_suffix: bool = False,
add_suffix: bool = True,
):
self.processor = processor
self.tokenizer = tokenizer
self.image_token_id = None
self.max_length = max_length
self.suffix = ""
if add_suffix:
self.suffix = "\n" * 10

if tokenizer is None and processor is None:
raise ValueError("Either processor or tokenizer should be provided.")
Expand All @@ -32,6 +30,17 @@ def __init__(
if self.tokenizer and self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token

if self.processor.__class__.__name__ == "PaliGemmaProcessor":
if self.processor.tokenizer.padding_side != "right":
print("Setting padding side to right")
self.processor.tokenizer.padding_side = "right"

if add_suffix:
if self.tokenizer:
self.suffix = self.tokenizer.pad_token * 10
else:
self.suffix = self.processor.tokenizer.pad_token * 10

def __call__(self, examples):
if self.processor is None:
return self.forward_text(examples)
Expand Down Expand Up @@ -76,14 +85,14 @@ def forward_vision_idefics(self, examples):

text_query = None
if example["query"] is not None:
query = example["query"]
query = example["query"] + self.suffix
messages_query = [
{
"role": "user",
"content": [
{
"type": "text",
"text": f"Question: {query}<end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance><end_of_utterance>",
"text": f"Question: {query}",
},
],
},
Expand Down Expand Up @@ -133,6 +142,7 @@ def forward_vision_pali(self, examples):
texts_doc = []
texts_query = []
images = []
neg_images = []
for example in examples:

if example["image"] is None:
Expand All @@ -142,11 +152,17 @@ def forward_vision_pali(self, examples):
images.append(image)
texts_doc.append("Describe the image.")

if "neg_image" in example and example["neg_image"] is not None:
neg_image = example["neg_image"].convert("RGB")
neg_images.append(neg_image)

if example["query"] is None:
texts_query.append(None)
else:
query = example["query"]
query = f"Question: {query}<unused0><unused0><unused0><unused0><unused0>"
query = f"Question: {query}"
# add pad tokens
query += self.suffix
texts_query.append(query)

batch_doc = self.processor(
Expand All @@ -157,11 +173,22 @@ def forward_vision_pali(self, examples):
max_length=self.max_length + self.processor.image_seq_length,
)

batch_neg_doc = None
if len(neg_images) > 0:
batch_neg_doc = self.processor(
text=texts_doc,
images=neg_images,
return_tensors="pt",
padding="longest",
max_length=self.max_length + self.processor.image_seq_length,
)

batch_query = None
# check if some but not all queries are None
if all([t is None for t in texts_query]):
print("All queries are None. Returning None for all queries.")
elif any([t is None for t in texts_query]):
# if it's the first query that is not None but the rest are None, then it's hard negatives
raise ValueError("Some queries are None. This collator does not support None queries yet.")
else:
batch_query = self.processor(
Expand All @@ -181,6 +208,9 @@ def forward_vision_pali(self, examples):
if batch_query is not None:
batch_query = {f"query_{k}": v for k, v in batch_query.items()}
batch_doc.update(batch_query)
if batch_neg_doc is not None:
batch_neg_doc = {f"neg_doc_{k}": v for k, v in batch_neg_doc.items()}
batch_doc.update(batch_neg_doc)

return batch_doc

Expand Down
46 changes: 46 additions & 0 deletions colpali_engine/dataset/hard_neg_collator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from random import randint

from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer, ProcessorMixin

from .custom_collator import CustomCollator


class HardNegCollator(CustomCollator):
def __init__(
self,
processor: ProcessorMixin = None,
tokenizer: PreTrainedTokenizer = None,
max_length: int = 2048,
add_suffix: bool = True,
image_dataset: Dataset = None,
):
super().__init__(processor, tokenizer, max_length, add_suffix)
self.image_dataset = image_dataset
assert self.image_dataset is not None, "image_dataset must be provided"

def get_image_from_image_dataset(self, image_idx):
return self.image_dataset[int(image_idx)]["image"]

def __call__(self, examples):
# assert len(examples) == 1, "HardNegCollator only supports a single example at at time"

tmp_examples = examples
examples = []
for example in tmp_examples:
pos_image = self.get_image_from_image_dataset(example["gold_index"])
pos_query = example["query"]
# randomly sample a negative image amongst the top 10
neg_image = self.get_image_from_image_dataset(example["negs"][randint(0, 9)])
examples += [{"image": pos_image, "query": pos_query, "neg_image": neg_image}]

# reorder examples
if self.processor is None:
return self.forward_text(examples)
if self.processor.__class__.__name__ == "Idefics2Processor":
return self.forward_vision_idefics(examples)
if self.processor.__class__.__name__ == "PaliGemmaProcessor":
return self.forward_vision_pali(examples)
if self.processor.__class__.__name__ == "SiglipProcessor":
return self.forward_vision_siglip(examples)
raise ValueError("Processor not supported")
44 changes: 44 additions & 0 deletions colpali_engine/dataset/hard_neg_collator_docmatix_ir.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from datasets import Dataset, DatasetDict
from transformers import PreTrainedTokenizer, ProcessorMixin

from .custom_collator import CustomCollator


class HardNegCollator(CustomCollator):
def __init__(
self,
processor: ProcessorMixin = None,
tokenizer: PreTrainedTokenizer = None,
max_length: int = 2048,
add_suffix: bool = True,
image_dataset: Dataset = None,
):
super().__init__(processor, tokenizer, max_length, add_suffix)
self.image_dataset = image_dataset
assert self.image_dataset is not None, "image_dataset must be provided"

def get_image_from_docid(self, docid):
example_idx, image_idx = docid.split("_")
target_image = self.image_dataset[int(example_idx)]["images"][int(image_idx)]
return target_image

def __call__(self, examples):
tmp_examples = examples
examples = []
for example in tmp_examples:
pos_image = self.get_image_from_docid(example["positive_passages"][0]["docid"])
pos_query = example["query"]
neg_images_ids = [doc["docid"] for doc in example["negative_passages"][:1]]
neg_images = [self.get_image_from_docid(docid) for docid in neg_images_ids]

examples += [{"image": pos_image, "query": pos_query, "neg_image": neg_images[0]}]

if self.processor is None:
return self.forward_text(examples)
if self.processor.__class__.__name__ == "Idefics2Processor":
return self.forward_vision_idefics(examples)
if self.processor.__class__.__name__ == "PaliGemmaProcessor":
return self.forward_vision_pali(examples)
if self.processor.__class__.__name__ == "SiglipProcessor":
return self.forward_vision_siglip(examples)
raise ValueError("Processor not supported")
65 changes: 65 additions & 0 deletions colpali_engine/loss/colbert_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,40 @@ def forward(self, query_embeddings, doc_embeddings):
return loss


class ColbertPairwiseNegativeCELoss(torch.nn.Module):
def __init__(self, in_batch_term=False):
super().__init__()
self.ce_loss = CrossEntropyLoss()
self.in_batch_term = in_batch_term

def forward(self, query_embeddings, doc_embeddings, neg_doc_embeddings):
"""
query_embeddings: (batch_size, num_query_tokens, dim)
doc_embeddings: (batch_size, num_doc_tokens, dim)
neg_doc_embeddings: (batch_size, num_neg_doc_tokens, dim)
"""

# Compute the ColBERT scores
pos_scores = torch.einsum("bnd,bsd->bns", query_embeddings, doc_embeddings).max(dim=2)[0].sum(dim=1)
neg_scores = torch.einsum("bnd,bsd->bns", query_embeddings, neg_doc_embeddings).max(dim=2)[0].sum(dim=1)

loss = F.softplus(neg_scores - pos_scores).mean()

if self.in_batch_term:
scores = (
torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
) # (batch_size, batch_size)

# Positive scores are the diagonal of the scores matrix.
pos_scores = scores.diagonal() # (batch_size,)
neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6 # (batch_size, batch_size)
neg_scores = neg_scores.max(dim=1)[0] # (batch_size,)

loss += F.softplus(neg_scores - pos_scores).mean()

return loss / 2


class BiPairwiseCELoss(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -120,3 +154,34 @@ def forward(self, query_embeddings, doc_embeddings):
loss = F.softplus(neg_scores - pos_scores).mean()

return loss


class BiPairwiseNegativeCELoss(torch.nn.Module):
def __init__(self, in_batch_term=False):
super().__init__()
self.ce_loss = CrossEntropyLoss()
self.in_batch_term = in_batch_term

def forward(self, query_embeddings, doc_embeddings, neg_doc_embeddings):
"""
query_embeddings: (batch_size, dim)
doc_embeddings: (batch_size, dim)
neg_doc_embeddings: (batch_size, dim)
"""

# Compute the ColBERT scores
pos_scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings).diagonal()
neg_scores = torch.einsum("bd,cd->bc", query_embeddings, neg_doc_embeddings).diagonal()

loss = F.softplus(neg_scores - pos_scores).mean()

if self.in_batch_term:
scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
# Positive scores are the diagonal of the scores matrix.
pos_scores = scores.diagonal() # (batch_size,)
neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6 # (batch_size, batch_size)
neg_scores = neg_scores.max(dim=1)[0] # (batch_size,)

loss += F.softplus(neg_scores - pos_scores).mean()

return loss / 2
Loading

0 comments on commit f961263

Please sign in to comment.