RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu) #38

Open
Paperbackpeople opened this issue Sep 19, 2023 · 1 comment


Paperbackpeople commented Sep 19, 2023

When I ran the model on ImageNet, the error below occurred.
How can I fix this tensor device problem? Any help from the experts would be appreciated, thanks!

Model unexpected keys:
['transformer.log_alpha', 'transformer.log_1_min_alpha', 'transformer.log_cumprod_alpha', 'transformer.log_1_min_cumprod_alpha']
Evaluate EMA model
/home/zvwang/miniconda3/envs/myenv/lib/python3.7/site-packages/torch/nn/functional.py:1967: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
Traceback (most recent call last):
  File "test_imagenet.py", line 5, in <module>
    VQ_Diffusion_model.inference_generate_sample_with_class(407, truncation_rate=0.86, save_root="RESULT", batch_size=4)
  File "/home/zvwang/VQ-Diffusion/inference_VQ_Diffusion.py", line 93, in inference_generate_sample_with_class
    sample_type="top"+str(truncation_rate)+'r',
  File "/home/zvwang/miniconda3/envs/myenv/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/zvwang/VQ-Diffusion/image_synthesis/modeling/models/conditional_dalle.py", line 184, in generate_content
    content = self.content_codec.decode(trans_out['content_token']) #(8,1024)->(8,3,256,256)
  File "/home/zvwang/VQ-Diffusion/image_synthesis/modeling/codecs/image_codec/taming_gumbel_vqvae.py", line 202, in decode
    img_seq=self.quantize_to_full[img_seq].type_as(img_seq)
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)


Paperbackpeople commented Sep 19, 2023

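I got it!
Change the `decode` method in taming_gumbel_vqvae.py (around line 200) so that `self.quantize_to_full` is moved to the same device as `img_seq` before it is indexed: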

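```python
def decode(self, img_seq):
    if self.quantize_number != 0:
        # Move the lookup table to the device of the indices before indexing;
        # a CPU table indexed with CUDA indices raises the RuntimeError above.
        self.quantize_to_full = self.quantize_to_full.to(img_seq.device)
        img_seq = self.quantize_to_full[img_seq].type_as(img_seq)
    b, n = img_seq.shape
    # Reshape the flat token sequence into a square grid (h = w = sqrt(n)).
    img_seq = rearrange(img_seq, 'b (h w) -> b h w', h=int(math.sqrt(n)))

    x_rec = self.dec(img_seq)
    x_rec = self.postprocess(x_rec)
    return x_rec
```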
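For anyone hitting the same error elsewhere, here is a minimal, self-contained sketch of the idea, not the repository's actual code. It assumes the lookup table is a plain tensor attribute that stays on the CPU when the model is moved to the GPU (which is what the error message suggests). `CodeRemapper` and the toy mapping are hypothetical names made up for illustration; `register_buffer` and `Tensor.to` are standard PyTorch calls. Registering the table as a buffer lets `.to(device)` / `.cuda()` move it together with the rest of the module, and the extra `.to(img_seq.device)` is only a defensive fallback.

```python
import torch
import torch.nn as nn


class CodeRemapper(nn.Module):
    """Hypothetical stand-in for the codec's index remapping step."""

    def __init__(self, mapping):
        super().__init__()
        # A registered buffer follows the module when .to(device) or .cuda()
        # is called, unlike a plain tensor attribute.
        self.register_buffer(
            "quantize_to_full", torch.as_tensor(mapping, dtype=torch.long)
        )

    def forward(self, img_seq):
        # Defensive fallback: make sure table and indices share a device.
        table = self.quantize_to_full.to(img_seq.device)
        return table[img_seq].type_as(img_seq)


# Runs on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
remapper = CodeRemapper([5, 7, 9, 11]).to(device)
tokens = torch.tensor([[0, 3], [2, 1]], device=device)
remapped = remapper(tokens)
print(remapped.cpu().tolist())  # [[5, 11], [9, 7]]
```

Compared with reassigning `self.quantize_to_full` inside `decode`, the buffer approach moves the table once instead of checking on every call, but it would require editing the codec's constructor, so the one-line fix above is probably the least invasive option.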
