diff --git a/colpali_engine/loss/colbert_loss.py b/colpali_engine/loss/colbert_loss.py
index c0dad1eb..655237f3 100644
--- a/colpali_engine/loss/colbert_loss.py
+++ b/colpali_engine/loss/colbert_loss.py
@@ -154,3 +154,41 @@ def forward(self, query_embeddings, doc_embeddings):
         loss = F.softplus(neg_scores - pos_scores).mean()
 
         return loss
+
+
+class BiPairwiseNegativeCELoss(torch.nn.Module):
+    """Pairwise softplus ranking loss for single-vector (bi-encoder) embeddings,
+    with one explicit hard negative per query and an optional in-batch term.
+    """
+
+    def __init__(self, in_batch_term=False):
+        super().__init__()
+        self.ce_loss = CrossEntropyLoss()  # unused here; kept for parity with sibling losses
+        self.in_batch_term = in_batch_term
+
+    def forward(self, query_embeddings, doc_embeddings, neg_doc_embeddings):
+        """
+        query_embeddings: (batch_size, dim)
+        doc_embeddings: (batch_size, dim)
+        neg_doc_embeddings: (batch_size, dim)
+        """
+        # Score each query against ITS OWN positive / negative document
+        # ("bd,bd->b"), not the full cross-batch matrix: inputs are
+        # single-vector embeddings with negatives aligned per row.
+        pos_scores = torch.einsum("bd,bd->b", query_embeddings, doc_embeddings)
+        neg_scores = torch.einsum("bd,bd->b", query_embeddings, neg_doc_embeddings)
+        loss = F.softplus(neg_scores - pos_scores).mean()
+
+        if self.in_batch_term:
+            # Single-vector embeddings: plain dot-product score matrix (batch_size, batch_size).
+            scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
+
+            # Positive scores are the diagonal of the scores matrix.
+            pos_scores = scores.diagonal()  # (batch_size,)
+            neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6  # mask self-matches
+            neg_scores = neg_scores.max(dim=1)[0]  # hardest in-batch negative per query
+
+            # Average the explicit-negative term and the in-batch term.
+            loss = (loss + F.softplus(neg_scores - pos_scores).mean()) / 2
+
+        return loss
diff --git a/scripts/configs/pali/train_bipali_pairwise_model.yaml b/scripts/configs/pali/train_bipali_pairwise_model.yaml
new file mode 100644
index 00000000..9bb85442
--- /dev/null
+++ b/scripts/configs/pali/train_bipali_pairwise_model.yaml
@@ -0,0 +1,41 @@
+config:
+  (): colpali_engine.utils.train_colpali_engine_models.ColModelTrainingConfig
+  output_dir: !path ../../../models/right_pad/train_bipali_pairwise
+  processor:
+    () : colpali_engine.utils.wrapper.AutoProcessorWrapper
+    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+    max_length: 50
+  model:
+    (): colpali_engine.utils.wrapper.AutoColModelWrapper
+    pretrained_model_name_or_path: "./models/paligemma-3b-mix-448"
+    training_objective: "biencoder_mean"
+    # attn_implementation: "eager"
+    torch_dtype: !ext torch.bfloat16
+#    device_map: "auto"
+#    quantization_config:
+#      (): transformers.BitsAndBytesConfig
+#      load_in_4bit: true
+#      bnb_4bit_quant_type: "nf4"
+#      bnb_4bit_compute_dtype: "bfloat16"
+#      bnb_4bit_use_double_quant: true
+
+  dataset_loading_func: !ext colpali_engine.utils.dataset_transformation.load_train_set
+  eval_dataset_loader: !import ../data/test_data.yaml
+
+  max_length: 50
+  run_eval: true
+  add_suffix: true
+  loss_func:
+    (): colpali_engine.loss.colbert_loss.BiPairwiseCELoss
+  tr_args: !import ../tr_args/default_tr_args.yaml
+  peft_config:
+    (): peft.LoraConfig
+    r: 32
+    lora_alpha: 32
+    lora_dropout: 0.1
+    init_lora_weights: "gaussian"
+    bias: "none"
+    task_type: "FEATURE_EXTRACTION"
+    target_modules: '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$)'
+    # target_modules:
+    #   '(.*(language_model).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$'