Skip to content

Commit

Permalink
fixed torch 0.4.0 , "RuntimeError: Expected object of type torch.cuda… (
Browse files Browse the repository at this point in the history
facebookresearch#393)

Summary:
….LongTensor but found type torch.cuda.FloatTensor for argument facebookresearch#3 'index' " error

in the torch.__version__ == 0.4.0 ,
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
will return a float dtype Tensor, when exec the "line 321: fairseq/fairseq/models/fconv.py " will throw a RuntimeError
Pull Request resolved: facebookresearch#393

Differential Revision: D13276496

Pulled By: myleott

fbshipit-source-id: e7986246fbe2c79fff61bcab0e5bec9dd63e0afd
  • Loading branch information
linkerr authored and facebook-github-bot committed Nov 30, 2018
1 parent 7bbe528 commit 9dd8724
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion fairseq/sequence_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def _generate(self, encoder_input, beam_size=None, maxlen=None, prefix_tokens=No
# compute the encoder output for each beam
encoder_out = model.encoder(**encoder_input)
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(src_tokens.device)
new_order = new_order.to(src_tokens.device).long()
encoder_out = model.encoder.reorder_encoder_out(encoder_out, new_order)
encoder_outs.append(encoder_out)

Expand Down

0 comments on commit 9dd8724

Please sign in to comment.