-
Notifications
You must be signed in to change notification settings - Fork 27.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[modeling_utils] torch_dtype/auto floating dtype fixes #17614
Conversation
The documentation is not available anymore as the PR was closed or merged. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wouldn't start throwing errors in those function but would return the last dtype in case everything is an int (for the unlikely case we get a quantized model).
Then I'd use this new get_parameter_first_float_dtype
instead of the next parameter hack (for instance when we set the self.config.torch_dtype
.
Thanks a lot for working on this!
Probably good to merge now, right? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
YEs, good for me if it's good for you :-)
…7614) * [modeling_utils] torch_dtype/auto fixes * add test * apply suggestions * add missing fallback * Renaming things * Use for else Co-authored-by: Sylvain Gugger <[email protected]>
…7614) * [modeling_utils] torch_dtype/auto fixes * add test * apply suggestions * add missing fallback * Renaming things * Use for else Co-authored-by: Sylvain Gugger <[email protected]>
As reported in #17583 not all model's have their first param of floating dtype, which lead to failures like:
Fixes: #17583
Possible additional TODO that wasn't part of the original report
@sgugger, we can sort out the saving side of things here as well - I already added an alternative
get_parameter_dtype
=>get_parameter_first_float_dtype
- but I wanted to check in with you if we replace all instances ofget_parameter_dtype
or only some.I didn't go ahead with doing that since we have a method called
dtype
which probably should callget_parameter_dtype
and addfloat_dtype
? Not sure - let's see what you think is the best way to proceed.