Add GISTEmbedLoss #2535

Merged · 16 commits merged into UKPLab:master on Mar 15, 2024
Conversation

avsolatorio
Contributor

This is the implementation of the GISTEmbed loss detailed in https://arxiv.org/abs/2402.16829. The implementation supports both anchor-positive pair inputs and triplets.
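
For illustration, a minimal usage sketch (the model names and training data here are placeholders, not taken from the PR):

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("distilroberta-base")   # model being trained
guide = SentenceTransformer("all-MiniLM-L6-v2")     # small frozen guide model

# Anchor-positive pairs; adding a third text per example would make them triplets.
train_examples = [
    InputExample(texts=["What is the capital of France?", "Paris is the capital of France."]),
    InputExample(texts=["How do plants make food?", "Plants produce food via photosynthesis."]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.GISTEmbedLoss(model=model, guide=guide)

model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1)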

avsolatorio and others added 16 commits March 9, 2024 20:31
Also call the guide model with no_grad for better memory efficiency.
@tomaarsen
Collaborator

Hello!

Apologies for simply pushing some commits; I should have asked beforehand.
I've slightly updated the loss function in a few ways in 00c6d8b:

  1. Add a must_retokenize flag that is set automatically based on the tokenizers' vocabularies and maximum sequence lengths. If the tokenizers differ, the forward pass recomputes the sentence features for the guide by decoding and retokenizing. This is quite unfortunate, as it introduces a slight unnecessary overhead and may even strip tokens that normally wouldn't be stripped (if the model has a smaller max token length than the guide), but there's no convenient way to fix that without affecting the SentenceTransformer training code too heavily.
  2. Throw an error if the models are not both transformers-based. ST also supports e.g. word2vec models, but those won't work here. This is no big deal at all, as most people use transformers-based models.
  3. Use torch.no_grad() for the guide inference; this should help somewhat with memory, as gradients won't have to be stored for that inference (see the sketch after this list).
  4. Add GISTEmbedLoss to the documentation in various places.
  5. Add a training_nli_v3.py training script that uses GISTEmbedLoss.
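
A minimal sketch of the ideas in items 1 and 3, assuming a transformers-based model and guide (the helper name and structure are illustrative, not the merged code; device placement is omitted for brevity):

import torch

def guide_embeddings(model, guide, sentence_features, must_retokenize):
    """Embed a batch with the (frozen) guide model."""
    if must_retokenize:
        # The batch was tokenized for `model`; decode it back to text and
        # re-tokenize it with the guide's own tokenizer.
        texts = model.tokenizer.batch_decode(
            sentence_features["input_ids"], skip_special_tokens=True
        )
        sentence_features = guide.tokenize(texts)
    with torch.no_grad():  # the guide is not trained, so skip gradient bookkeeping
        return guide(sentence_features)["sentence_embedding"]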

I think this is just about ready to be merged. I'm happy for you to give this another once-over, if you'd like. I'm quite excited to see this work implemented 🎉

Some experiments that I ran (note: I used the new WIP refactor, which is why I have W&B logs):

  • Spearman correlation based on cosine similarity on the STS Benchmark dev set, with distil-RoBERTa-base finetuned on AllNLI using MultipleNegativesRankingLoss vs. GISTEmbedLoss with all-MiniLM-L6-v2 as the guide model:
    (image: W&B evaluation chart)
  • Spearman correlation based on cosine similarity on the STS Benchmark dev set, with MPNet-base finetuned on AllNLI using MultipleNegativesRankingLoss vs. GISTEmbedLoss with all-MiniLM-L6-v2 as the guide model:
    (image: W&B evaluation chart)

Note: all-MiniLM-L6-v2 was not trained on the STS Benchmark development set (nor the training set, I believe), and it itself scores ~82.03 on the test set (and probably around 84 on the dev set). In other words: the improved performance is not because we're doing knowledge distillation from some model that was trained on my development set.

  • Tom Aarsen

@avsolatorio
Contributor Author

avsolatorio commented Mar 12, 2024

Hello @tomaarsen , thank you so much for doing all this work!

I didn't want to change too much, but I totally missed the part about handling the different tokenizers. Good catch there!

In my original implementation, I passed the raw text as part of the features so I could use the .encode method directly on it, which also covers the inference without the gradients.
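
Roughly, that original approach would look like this sketch (hypothetical; raw_texts stands for the untokenized input sentences and is not a variable from the PR):

# Embed the raw texts with the guide's high-level encode(), which already
# runs without gradients and uses the guide's own tokenizer.
guide_embeddings = guide.encode(raw_texts, convert_to_tensor=True)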

Indeed, I think your approach of decoding and retokenizing strikes a reasonable balance given the size of the change needed. I'd be very happy if this gets merged! 🤗

@tomaarsen
Collaborator

In my original implementation, I passed the raw text as part of the features so I could use the .encode method directly on it, which also covers the inference without the gradients.

That's indeed a more convenient approach, though I'm fairly satisfied with the one that we've got here, too. I'll have another look at this tomorrow & prepare to merge it!

  • Tom Aarsen

@tomaarsen merged commit 465d4f0 into UKPLab:master on Mar 15, 2024
9 checks passed
@tomaarsen
Collaborator

Thanks a bunch for this! I'm excited to see whether this will be picked up by the community; I hope so.

  • Tom Aarsen

@litvinovsky

@avsolatorio @tomaarsen guys, could you please help me understand whether it makes sense to use the model we are training as the guide model for the loss function? For example, I am fine-tuning distiluse-base-multilingual-cased-v2 like this:

model = SentenceTransformer("distiluse-base-multilingual-cased-v2")

and then I am using GISTEmbedLoss like this:

loss = losses.GISTEmbedLoss(model=model, guide=model, temperature=0.01)

I've tested it, and it seems to work much better than when I use a separate copy of the same distiluse-base-multilingual-cased-v2 model as the guide (i.e. not the one being trained).
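
For reference, the variant being compared against would look something like this sketch (loading a second, frozen copy of the same checkpoint as the guide):

guide = SentenceTransformer("distiluse-base-multilingual-cased-v2")  # separate frozen copy
loss = losses.GISTEmbedLoss(model=model, guide=guide, temperature=0.01)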

What are your thoughts on this? Could this be the right approach? Thanks!

@tomaarsen
Collaborator

That's a very interesting approach! I haven't tried that myself. Out of curiosity, based on what metric does it work much better than using an "unchanging" version of the base model?
If it's just based on the loss being lower, then it's possible that the model is learning to ignore difficult samples in the loss. This might give better losses, but worse performance on actual benchmarks.
If it also works better on benchmarks like STSb during training, then this could be quite promising.

I'm also curious if @avsolatorio experimented with this.

  • Tom Aarsen

@avsolatorio
Contributor Author

I'm also curious if @avsolatorio experimented with this.

@tomaarsen , I have not actually tried this, but this approach reminds me of something you mentioned to me last time 🤔: https://sbert.net/docs/package_reference/sentence_transformer/losses.html#onlinecontrastiveloss .

Since the guide model updates along with the model being trained, I think there could be some overfitting happening, as you mentioned. So the loss may look better while the actual performance does not improve.

I am keen on understanding how the performance was measured, @litvinovsky . That could give us a hint on what is happening here. 🙂

@litvinovsky

@avsolatorio @tomaarsen thank you for your response. That's very helpful.

I came to the idea that it increases performance because of two things:

  1. With any batch size (say 8), the training loss goes down more smoothly and with less dispersion, i.e. there are no spikes in the training loss, and as a result it decreases faster and more smoothly. In comparison, when using a separate model as the guide (i.e. not the model I am training), the first 5-10 batches show a lot of training-loss spikes, which can be quite high compared to using the model itself as the guide.

  2. The actual performance on my task increased, which I found during manual post-training testing: some anchor-positive pairs were not trained properly until I used the model being trained as the guide. In my opinion this happens because my search logic is rather tricky, like this:

anchor, positive
[QUERY] buy, sell blablabla1
[QUERY] sell, buy blablabla1

and because of the batching mechanism and the negative-sample selection using the guide model, it skips this negative pair:

[QUERY] buy, buy blablabla1

which I would like not to be skipped. It is skipped because this negative ("buy", "buy blablabla1") is closer than the positive ("buy", "sell blablabla1"), yet I still expect this negative pair to be used, and that doesn't happen when I use a separate model as the guide. If I use the model being trained as the guide, it helps resolve this issue, so the negative example above gets applied properly. Why does it help? Probably because the weights have already changed, so "buy" / "sell blablabla1" is closer than "buy" / "buy blablabla1".

I'm not sure how this approach affects standard benchmarks, though.

@avsolatorio
Contributor Author

This is fascinating! One idea I had was to use a separate model as the guide for the first epoch (or first N steps), create a checkpoint of the model being trained, use that checkpoint as the guide model for the next epoch, and repeat. Unfortunately, I didn't have the time to test this. But your explanation suggests that this could work!
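
A rough sketch of that idea, assuming the model.fit training API (the training data, epoch count, and checkpoint paths below are illustrative placeholders):

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

model = SentenceTransformer("distiluse-base-multilingual-cased-v2")
guide = SentenceTransformer("all-MiniLM-L6-v2")  # static guide for the first epoch only

train_examples = [  # placeholder anchor-positive pairs
    InputExample(texts=["[QUERY] buy", "sell blablabla1"]),
    InputExample(texts=["[QUERY] sell", "buy blablabla1"]),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)

for epoch in range(3):
    loss = losses.GISTEmbedLoss(model=model, guide=guide, temperature=0.01)
    model.fit(train_objectives=[(train_dataloader, loss)], epochs=1)
    # Snapshot the trained model; the frozen snapshot becomes the next epoch's guide.
    snapshot_path = f"checkpoints/gist-guide-epoch-{epoch}"
    model.save(snapshot_path)
    guide = SentenceTransformer(snapshot_path)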

And your point that the model learning from the data itself could result in outperforming a static guide model sounds quite sensible to me! 😀
