
Getting "RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor, but got 1D cuda:0 Long tensor" #240

Closed
Jeriousman opened this issue Mar 17, 2022 · 8 comments · Fixed by #260
Labels
bug Something isn't working

Comments

@Jeriousman

Describe the bug(问题描述)
history = model.fit(x, y, batch_size=256, epochs=20, verbose=1, validation_split=0.4, shuffle=True)
When I call model.fit for the DIEN model using your default example run_dien.py, it works when I set the device to cpu, but with cuda I get the error below.

cuda ready...
0it [00:00, ?it/s]cuda:0
Train on 4 samples, validate on 0 samples, 2 steps per epoch

Traceback (most recent call last):

  File "<ipython-input-1-e985ce1c0aa2>", line 69, in <module>
    history = model.fit(x, y, batch_size=2, epochs=10, verbose=1, validation_split=0, shuffle=False)

  File "/home/hojun/anaconda3/envs/ai/lib/python3.6/site-packages/deepctr_torch/models/basemodel.py", line 244, in fit
    y_pred = model(x).squeeze()

  File "/home/hojun/anaconda3/envs/ai/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)

  File "/home/hojun/anaconda3/envs/ai/lib/python3.6/site-packages/deepctr_torch/models/dien.py", line 92, in forward
    masked_interest, aux_loss = self.interest_extractor(keys_emb, keys_length, neg_keys_emb)

  File "/home/hojun/anaconda3/envs/ai/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)

  File "/home/hojun/anaconda3/envs/ai/lib/python3.6/site-packages/deepctr_torch/models/dien.py", line 221, in forward
    enforce_sorted=False)

  File "/home/hojun/anaconda3/envs/ai/lib/python3.6/site-packages/torch/nn/utils/rnn.py", line 244, in pack_padded_sequence
    _VF._pack_padded_sequence(input, lengths, batch_first)

RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor, but got 1D cuda:0 Long tensor

So I tried lengths.cpu() and lengths.to('cpu'), but neither of them solved the problem. Can you please provide a solution?

Operating environment(运行环境):

  • python version 3.6
  • torch version 1.7.1
  • deepctr-torch version 0.2.7
@zanshuxun
Collaborator

zanshuxun commented Apr 4, 2022

In newer versions of PyTorch, the lengths parameter of torch.nn.utils.rnn.pack_padded_sequence has changed: it must now be a 1-D CPU int64 tensor. (Details can be found in pytorch/pytorch#43227.)

(two screenshots of the pack_padded_sequence documentation illustrating the change)
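The requirement can be demonstrated in isolation. This is a minimal sketch (assuming PyTorch >= 1.7, with made-up shapes): keeping lengths on CPU before packing avoids the RuntimeError that moving it to CUDA would trigger.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

seqs = torch.randn(3, 5, 4)        # (batch, max_len, features), zero-padded
lengths = torch.tensor([5, 3, 2])  # must stay a 1-D int64 CPU tensor

# On newer PyTorch, passing a CUDA `lengths` tensor raises the RuntimeError,
# so move (or keep) it on CPU before packing:
packed = pack_padded_sequence(seqs, lengths.cpu(),
                              batch_first=True, enforce_sorted=False)
```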

@zanshuxun zanshuxun added the bug Something isn't working label Apr 4, 2022
@Jeriousman
Author

Obviously I tried that. But as I said, none of those worked; I had to downgrade all the way to torch 1.4.0 to get it running.

@zanshuxun
Collaborator

Obviously I tried that. But as I said, none of those worked; I had to downgrade all the way to torch 1.4.0 to get it running.

Where did you use .cpu()? Did the device of the tensor actually change after you called .cpu()?

@Jeriousman
Author

Jeriousman commented Apr 6, 2022

Yes, I did, as I mentioned above:

So I tried lengths.cpu() and lengths.to('cpu'), but neither of them solved the problem.

The lengths argument is exactly the tensor I moved to CPU, as mruberry and ngimel suggested in that same issue. That was also the first page I found when I was trying to fix the problem.

@zanshuxun
Collaborator

  1. Where did you use .cpu()?

Could you tell me the corresponding line number in the code? For example:

packed_keys = pack_padded_sequence(masked_keys, lengths=masked_keys_length, batch_first=True,
enforce_sorted=False)

Did you call .cpu() on masked_keys_length here?

or other places like

packed_keys = pack_padded_sequence(keys, lengths=keys_length, batch_first=True, enforce_sorted=False)

or

packed_interests = pack_padded_sequence(interests, lengths=keys_length, batch_first=True,

  1. Did the device of the tensor change after you use .cpu()?

Could you print the device of the tensor before and after your .cpu() call, to check whether it actually worked? If it did, you should not see the error "RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor, but got 1D cuda:0 Long tensor".
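One thing worth checking with such a print: .cpu() returns a new tensor and does not move the original in place, so calling it without re-assigning the result has no effect. A small sketch (device-agnostic, so it also runs on a CPU-only machine):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
lengths = torch.tensor([5, 3, 2], device=device)

lengths.cpu()          # returns a NEW tensor; `lengths` itself is unchanged
moved = lengths.cpu()  # the result must be captured to get the CPU copy

print(lengths.device, moved.device)
```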

@Jeriousman
Author

Jeriousman commented Apr 7, 2022

Hello. I added .cpu() to all the pack_padded_sequence calls, e.g. masked_keys_length.cpu(). When I did this, the tensor was converted to a CPU tensor, but the error was still there. For me, only downgrading the torch version worked. It is strange, though; that was the whole point of my question: it became a CPU tensor, but it still didn't work. Is it working on your side?

@zanshuxun zanshuxun pinned this issue Jun 20, 2022
@zanshuxun
Collaborator

@Jeriousman I added .cpu() to all the pack_padded_sequence(...) calls in dien.py, and it works. Maybe you missed something. Could you paste the traceback and your dien.py file?
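For reference, the fix amounts to moving only the lengths argument to CPU at each call site quoted above. This is a hedged, self-contained sketch (the tensor shapes are stand-ins; in the real model they come from DIEN's embedding layers and may live on CUDA):

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Hypothetical stand-ins for the tensors inside DIEN's interest extractor.
masked_keys = torch.randn(4, 6, 8)               # (batch, max_len, emb_dim)
masked_keys_length = torch.tensor([6, 4, 3, 1])  # may be on CUDA in the model

# The patched call: only `lengths` is moved to CPU; the sequence data
# itself can stay on whatever device the model runs on.
packed_keys = pack_padded_sequence(masked_keys,
                                   lengths=masked_keys_length.cpu(),
                                   batch_first=True, enforce_sorted=False)
```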

shenweichen pushed a commit that referenced this issue Aug 15, 2022
1. Add multi-task models: SharedBottom, ESMM, MMOE, PLE
2. Bugfix:
#240
#232
shenweichen added a commit that referenced this issue Oct 21, 2022
* add multitask models

1. Add multi-task models: SharedBottom, ESMM, MMOE, PLE
2. Bugfix:
#240
#232

* support python 3.9/3.10 (#259)
* fix: variable name typo (#257)
Co-authored-by: Jason Zan <[email protected]>
Co-authored-by: Yi-Xuan Xu <[email protected]>
@shenweichen shenweichen unpinned this issue Oct 22, 2022
@umanniyaz

Hi, can anyone tell me how to handle this same error on torch==1.8.0?
