
Fix Gradient Accumulation issue #34191

Open · wants to merge 49 commits into main
Conversation

ArthurZucker (Collaborator) commented Oct 16, 2024

What does this PR do?

First draft

End goal is to make it easy for anyone to:

  • change the loss for their model (see the sketch below the TODO list)
  • contribute a new loss for a model (e.g. vision models, EnCodec, etc.)
  • allow passing arbitrary kwargs to the loss, for interfacing

TODO:

  • Fix Deformable DETR loss computation
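
To make the end goal concrete, here is a minimal sketch of how a class-name-keyed loss registry could let users swap in their own loss. This is illustrative only, not the exact code in this PR; `default_cross_entropy_loss` and `my_custom_loss` are hypothetical names.

import torch.nn.functional as F

def default_cross_entropy_loss(logits, labels, vocab_size, **kwargs):
    # Shift so that tokens < n predict token n, then flatten and take cross-entropy.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)

# Losses are keyed by head-class suffix rather than by concrete class.
LOSS_MAPPING = {
    "ForCausalLM": default_cross_entropy_loss,
}

def my_custom_loss(logits, labels, vocab_size, **kwargs):
    # Hypothetical user-contributed loss; any extra keyword arguments simply
    # flow through **kwargs without changing the model code.
    shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab_size)
    shift_labels = labels[..., 1:].contiguous().view(-1)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100, label_smoothing=0.1)

# Swapping the loss for every causal LM head is then a one-line registration.
LOSS_MAPPING["ForCausalLM"] = my_custom_loss

Keying on the class-name suffix keeps the lookup independent of any particular model's inheritance tree, which is what makes per-model overrides and newly contributed losses cheap.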

BenjaminBossan (Member) left a comment:
Thanks for coming forward with this fix so quickly. There is probably not much I can help with, but I took a look and added some comments.

if loss_type is None:
    raise ValueError(
        "We could not determine which loss function to use "
        f"based on the class name. Make sure you add `{self.__class__.__name__}` to the `LOSS_MAPPING`"

Member:
I think users could be really confused when they read this message. They don't know what and where LOSS_MAPPING is and they don't know what value they should add there.

Collaborator (Author):
Yep will update this

    )
if loss_type not in LOSS_MAPPING and getattr(self.config, "loss_type", None) is not None:
    raise ValueError(
        f"`loss_type={loss_type}` was set in the config but it is unrecognised"

Member:
Similar issue with potential confusion.
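
For illustration, a more self-explanatory message could spell out both remedies directly. This is a sketch only, not the PR's final wording:

raise ValueError(
    f"Could not automatically infer a loss function for `{self.__class__.__name__}`. "
    "Either set `config.loss_type` to one of the supported types (the keys of "
    "`LOSS_MAPPING`), or make sure the class name ends with a supported head "
    "suffix such as `ForCausalLM`."
)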



LOSS_MAPPING = {
    "ForCausalLM": DefaultCrossEntropyLoss,

Member:
Just wondering aloud: Instead of matching based on class name, could we do a mapping from class to loss, and then do something like:

for key, val in LOSS_MAPPING.items():
    if isinstance(self, key):
        loss = val
        break
else:  # no break: no matching class was found
    raise ValueError(f"No loss registered for {type(self).__name__}")

I assume the matching exists for custom classes that are out there in the wild. If name is a more reliable predictor than inheritance or if I'm misunderstanding, please disregard my comment.

Collaborator (Author):
Will look into improving, but this looks super slow

Member:
This would only be run once because of the LRU cache, right?

Collaborator (Author):
Also, we get the class name from the class itself and want to have good defaults instead of matching against the full name!
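
For context, here is a hedged sketch of the kind of cached, class-name-suffix lookup being described; the helper name `_get_loss_function` and the exact cache placement are assumptions rather than the code this PR ships.

from functools import lru_cache

@lru_cache(maxsize=None)
def _get_loss_function(class_name: str):
    # "LlamaForCausalLM" ends with the registered suffix "ForCausalLM", so any
    # model with that head type falls back to a sensible default loss.
    for suffix, loss_fn in LOSS_MAPPING.items():
        if class_name.endswith(suffix):
            return loss_fn
    raise ValueError(
        "We could not determine which loss function to use based on the class name. "
        f"Make sure `{class_name}` matches a suffix in `LOSS_MAPPING` or set `config.loss_type`."
    )

# Because the lookup is keyed on the string name, the scan above runs at most
# once per class thanks to the LRU cache:
# loss_fn = _get_loss_function(self.__class__.__name__)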

LysandreJik (Member) left a comment:
Very good! I'll ask Daniel if he's down to review; it would be very useful to have his opinion. Thanks
