Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

For model long-t5-tglobal-x, fix 'float' object cannot be interpreted as an integer #17777

Merged
merged 1 commit into from
Jun 20, 2022

Conversation

bjascob
Copy link
Contributor

@bjascob bjascob commented Jun 19, 2022

On line 180, torch.tensor(-1.0, dtype=global_block_ids.dtype) gives the error TypeError: 'float' object cannot be interpreted as an integer . This is because the dtype here is int64. For dtype=int64, this needs to simply be -1.

This impacts the long-t5-tglogbal-x model. It does not impact the long-t5-local-x version which does not appear to call this line in the code.

The torch version where I see this is 1.11.0+cu113. I'm not certain if older, or non-gpu versions of torch allowed this but 1.11.0+cu113 does not.

Note that torch does not complain when casting an int to a float so it should be safe to change this to -1 even if there are occasions where global_block_ids.dtype is a float.

What does this PR do?

Fixes # (no issue # created).
There is a simple error in the code where torch fails when trying to create a constant int64 tensor using -1.0 instead of -1.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

This model is new. I would suggest someone from the original upload team review this. Here are the first 3 in the file history..
@stancld, @PhungVanDuy, @sgugger

On line 180, `torch.tensor(-1.0, xxx)` gives the error "TypeError: 'float' object cannot be interpreted as an integer" 
This is because the dtype here is `int64`.  For `dtype=int64`, this needs to simply be `-1`.  
This impacts the long-t5-tglogbal-x model.  It does not impact the long-t5-local-x version which does not appear to call this line.
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Jun 19, 2022

The documentation is not available anymore as the PR was closed or merged.

@patrickvonplaten
Copy link
Contributor

patrickvonplaten commented Jun 20, 2022

@patil-suraj could you maybe take a look here? :-)

Also cc @stancld in case you're interested and have an idea what the problem could be

Copy link
Contributor

@patil-suraj patil-suraj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried this torch==1.11.0+cu113 but it does not give any error. However it does throw a DeprecationWarning

DeprecationWarning: an integer is required (got type float).  Implicit conversion to integers using __int__ is deprecated, and may be removed in a future version of Python.
    _global_block_ids_lower_bound = torch.tensor(-1.0, dtype=global_block_ids.dtype, device=global_block_ids.device)

And you are right that it should actually be an int, so this looks good to me. I just ran slow tests with and they pass. Good for you @patrickvonplaten ?

@bjascob
Copy link
Contributor Author

bjascob commented Jun 20, 2022

Interesting. Looks like this is a change in python, not torch. UBT 22.04 uses Python 3.10.4 and this is fully broken for that version.

@patrickvonplaten
Copy link
Contributor

Great looks good to me than as well!

@patrickvonplaten patrickvonplaten merged commit da27c4b into huggingface:main Jun 20, 2022
younesbelkada pushed a commit to younesbelkada/transformers that referenced this pull request Jun 25, 2022
On line 180, `torch.tensor(-1.0, xxx)` gives the error "TypeError: 'float' object cannot be interpreted as an integer" 
This is because the dtype here is `int64`.  For `dtype=int64`, this needs to simply be `-1`.  
This impacts the long-t5-tglogbal-x model.  It does not impact the long-t5-local-x version which does not appear to call this line.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants