fix backward pass for cached losses #3114
Conversation
@@ -230,14 +230,15 @@ def embed_minibatch_iter(
    def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor]]) -> Tensor:
        """Generalized function to calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
        loss = self.calculate_loss(reps, reps_guided)
Suggested change:
-        loss = self.calculate_loss(reps, reps_guided)
+        loss = self.calculate_loss(reps, reps_guided, with_backward=True)
I believe this is missing for CGIST.
My initial tests with CMNRL and CMNRL + Matryoshka are promising, and I can indeed also reproduce the high memory usage. Matryoshka on top of CMNRL only seems to add about 20%-25% training time, which seems fine. CMNRL does seem a decent bit slower than pure MNRL (more than the 10-20% that I thought it was), but based on your changes in this PR and the previous one, that shouldn't be related to your work at all.
Nice! I assume that 0292b9b means that we perform evaluations more efficiently? I think this is ready to go.
Yes, with 0292b9b the loss can now also be called with torch.no_grad().
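As a rough illustration of how one entry point can serve both paths (all names here are hypothetical, not the actual sentence-transformers API): when the loss is evaluated under torch.no_grad(), no autograd graph exists, so the gradient-caching backward step can simply be skipped.

```python
import torch


def calculate_loss(reps: torch.Tensor, loss_fn, with_backward: bool = False) -> torch.Tensor:
    """Hypothetical sketch: one loss entry point for training and evaluation.

    Under torch.no_grad() the loss tensor does not require grad, so the
    backward pass is skipped automatically and the loss is just returned.
    """
    loss = loss_fn(reps)
    if with_backward and loss.requires_grad:
        loss.backward()
        loss = loss.detach()
    return loss
```

Checking `loss.requires_grad` rather than the `with_backward` flag alone is what makes the same code path safe inside a `torch.no_grad()` evaluation loop.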
Thanks a bunch for having another look at this @Marcel256 🤗
Here is a draft fix for the backward pass in the cached losses, while still maintaining compatibility with the Matryoshka loss. The problem is a bit trickier than I originally thought, because we need to detach the tensors before doing the minibatch loss computation.
If anyone has suggestions for improvements, let me know 😀
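The detach-before-minibatch-loss pattern described above can be sketched roughly as follows. This is a simplified illustration of the general gradient-caching idea, not the actual implementation; all function names are hypothetical. The key step is detaching the embeddings before the loss computation, so that the first backward pass stops at the embeddings and only fills their .grad fields, which are then replayed through the encoder one minibatch at a time.

```python
import torch


def cached_loss_step(embed_fn, batches, loss_fn):
    """Hypothetical sketch of a cached-loss training step.

    1. Embed each minibatch without building the full autograd graph.
    2. Detach the embeddings and re-enable grad, so the loss backward
       stops at the embeddings and caches gradients w.r.t. them.
    3. Re-embed each minibatch WITH grad and backprop the cached
       gradient through the encoder, one minibatch at a time.
    """
    # 1) Cheap forward pass: no graph kept for the encoder.
    with torch.no_grad():
        reps = [embed_fn(b) for b in batches]

    # 2) Detach so the upcoming backward cannot reach the encoder,
    #    then require grad so .grad gets populated on the embeddings.
    reps = [r.detach().requires_grad_(True) for r in reps]
    loss = loss_fn(torch.cat(reps))
    loss.backward()
    cached_grads = [r.grad for r in reps]

    # 3) Second forward pass per minibatch, now with a graph; feed the
    #    cached gradient in as the upstream gradient.
    for batch, grad in zip(batches, cached_grads):
        rep = embed_fn(batch)
        rep.backward(gradient=grad)

    return loss.detach()
```

Because the loss in step 2 is computed on detached copies, its backward never touches the encoder parameters; all encoder gradients come from step 3, which only ever holds one minibatch's graph in memory at a time.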