
[modeling_utils] torch_dtype/auto floating dtype fixes #17614

Merged: 6 commits into main, Jun 9, 2022

Conversation

stas00 (Contributor) commented Jun 8, 2022

As reported in #17583, not all models have a floating-point dtype for their first parameter, which led to failures like:

$ python -c 'from transformers import AutoModel; AutoModel.from_pretrained("hf-internal-testing/tiny-bert-for-token-classification", torch_dtype="auto")'
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/models/auto/auto_factory.py", line 446, in from_pretrained
    return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/modeling_utils.py", line 2004, in from_pretrained
    dtype_orig = cls._set_default_torch_dtype(torch_dtype)
  File "/mnt/nvme0/code/huggingface/transformers-master/src/transformers/modeling_utils.py", line 980, in _set_default_torch_dtype
    raise ValueError(
ValueError: Can't instantiate BertModel model under dtype=torch.int64 since it is not a floating point dtype
  1. This PR fixes that by searching for the first floating dtype instead (a sketch of the idea follows below).
  2. It also adds a test that failed before this PR.
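
A minimal sketch of the search, assuming a plain state dict; the helper name first_floating_dtype and the example tensors are illustrative, not the actual modeling_utils code:

import torch

def first_floating_dtype(state_dict):
    # Return the dtype of the first floating-point tensor in the checkpoint,
    # skipping integer tensors (such as position-id buffers) that may happen
    # to be stored first.
    for tensor in state_dict.values():
        if tensor.is_floating_point():
            return tensor.dtype
    raise ValueError("checkpoint contains no floating-point tensors")

# An int64 buffer stored first is what tripped up torch_dtype="auto":
state_dict = {
    "embeddings.position_ids": torch.arange(512),                 # int64
    "embeddings.word_embeddings.weight": torch.rand(30522, 128),  # float32
}
assert first_floating_dtype(state_dict) == torch.float32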

Fixes: #17583


Possible additional TODO that wasn't part of the original report

@sgugger, we can sort out the saving side of things here as well - I already added an alternative to get_parameter_dtype => get_parameter_first_float_dtype - but I wanted to check with you on whether we should replace all instances of get_parameter_dtype or only some.

I didn't go ahead with that, since we have a method called dtype which should probably keep calling get_parameter_dtype, and we could add a float_dtype alongside it? Not sure - let's see what you think is the best way to proceed.

HuggingFaceDocBuilderDev commented Jun 8, 2022

The documentation is not available anymore as the PR was closed or merged.

sgugger (Collaborator) left a comment

I wouldn't start throwing errors in those functions, but would return the last dtype in case everything is an int (for the unlikely case we get a quantized model).

Then I'd use this new get_parameter_first_float_dtype instead of the next-parameter hack (for instance, when we set self.config.torch_dtype).

Thanks a lot for working on this!
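
A minimal sketch of that fallback, assuming standard torch.nn parameter iteration; the name follows the get_parameter_first_float_dtype helper mentioned above, but the body is illustrative rather than the merged code:

import torch

def get_parameter_first_float_dtype(module: torch.nn.Module) -> torch.dtype:
    # Prefer the dtype of the first floating-point parameter.
    last_dtype = None
    for param in module.parameters():
        last_dtype = param.dtype
        if param.is_floating_point():
            return param.dtype
    # No floating-point parameter found (e.g. a fully quantized model):
    # fall back to the last dtype seen instead of raising.
    return last_dtype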

Resolved review threads on src/transformers/modeling_utils.py and tests/test_modeling_common.py.
stas00 (Contributor, Author) commented Jun 9, 2022

Probably good to merge now, right?

sgugger (Collaborator) left a comment

Yes, good for me if it's good for you :-)

stas00 merged commit 75343de into main on Jun 9, 2022
stas00 deleted the torch_dtype_auto2 branch on Jun 9, 2022 at 17:18
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
…7614)

* [modeling_utils] torch_dtype/auto fixes

* add test

* apply suggestions

* add missing fallback

* Renaming things

* Use for else

Co-authored-by: Sylvain Gugger <[email protected]>
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jun 16, 2022, with the same commit message as above.