Fix Gradient Accumulation issue #34191
base: main
Conversation
Thanks for coming forward with this fix so quickly. There is probably not much I can help with, but I took a look and added some comments.
```python
if loss_type is None:
    raise ValueError(
        "We could not determine which loss function to use "
        f"based on the class name. Make sure you add `{self.__class__.__name__}` to the `LOSS_MAPPING`"
```
I think users could be really confused when they read this message. They don't know what and where LOSS_MAPPING is and they don't know what value they should add there.
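To make the concern concrete: a message that names the user-facing knobs, rather than an internal constant, is easier to act on. The helper below is purely illustrative (neither its name nor its wording is from the PR); only `loss_type`, `config.loss_type`, and `LOSS_MAPPING` come from the diff above.

```python
def raise_unknown_loss(class_name, known_suffixes):
    # Hypothetical helper (not in the PR): tells the user the two concrete
    # fixes they can apply, instead of pointing at an internal mapping.
    raise ValueError(
        f"Could not determine a loss function for `{class_name}`. "
        f"Either set `config.loss_type` to one of {sorted(known_suffixes)}, "
        "or rename your class so it ends with one of those suffixes."
    )
```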
Yep, will update this.
```python
    )
if loss_type not in LOSS_MAPPING and getattr(self.config, "loss_type", None) is not None:
    raise ValueError(
        f"`loss_type={loss_type}` was set in the config but it is unrecognised"
```
Similar issue with potential confusion.
src/transformers/loss_utils.py
Outdated
```python
LOSS_MAPPING = {
    "ForCausalLM": DefaultCrossEntropyLoss,
```
Just wondering aloud: Instead of matching based on class name, could we do a mapping from class to loss, and then do something like:
```python
for key, val in LOSS_MAPPING.items():
    if isinstance(self, key):
        loss = val
        break
else:  # no break
    # raise error
```
I assume the matching exists for custom classes that are out there in the wild. If the name is a more reliable predictor than inheritance, or if I'm misunderstanding, please disregard my comment.
Will look into improving this, but the loop looks super slow.
This would only be run once because of the LRU cache, right?
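To illustrate the point being made here: with `functools.lru_cache`, the lookup body executes once per distinct class name and every later call is served from the cache. This is a minimal sketch, not the PR's code; the function name, the counter, and the suffix check are all made up for demonstration.

```python
from functools import lru_cache

calls = {"n": 0}  # counts how often the function body actually runs

@lru_cache(maxsize=None)
def loss_for(class_name: str):
    calls["n"] += 1
    # Illustrative lookup: match on a class-name suffix.
    return "cross_entropy" if class_name.endswith("ForCausalLM") else None

loss_for("LlamaForCausalLM")
loss_for("LlamaForCausalLM")  # served from the cache; the body does not run again
```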
Also, we get the class name from the class itself and want to have good defaults instead of matching against the full name!
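A sketch of what suffix-based matching buys over full-name matching: a custom subclass whose name merely ends in a known suffix still picks up a sensible default loss. The class names and the string stand-in for the loss are illustrative assumptions, not the PR's implementation.

```python
# Illustrative stand-in for the real mapping; values would be loss callables.
LOSS_MAPPING = {"ForCausalLM": "fixed_cross_entropy"}

class LlamaForCausalLM:
    @property
    def loss_function(self):
        name = self.__class__.__name__
        # Match on a known suffix of the class name, not the full name, so a
        # custom subclass like `MyLlamaForCausalLM` still gets a default loss.
        for suffix, loss in LOSS_MAPPING.items():
            if name.endswith(suffix):
                return loss
        return None

class MyLlamaForCausalLM(LlamaForCausalLM):
    pass  # different full name, same suffix, same default loss
```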
Co-authored-by: Kashif Rasul <[email protected]>
Very good! I'll ask Daniel if he's down to review; it would be very useful to have his opinion. Thanks!
What does this PR do?
First draft
End goal is to make it easy for anyone to:
TODO: