
Question about training RDM with main_rdm.py #18

Open
wyyy04 opened this issue Dec 19, 2023 · 2 comments

Comments

wyyy04 commented Dec 19, 2023

Thanks for your excellent work! I am running into a problem while training the RDM and hope you can help.

When training the RDM, in ddpm.py the get_input function (line 564) takes the input x of shape (32, 256, 256, 3) and, after feature extraction, produces a tensor of shape (32, 197, 768), where 32 is the batch size. An error then occurs at line 578, "rep = self.pretrained_encoder.head(rep)", with the following traceback:

  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/Diffusion_FSAR/rcg-main/rdm/models/diffusion/ddpm.py", line 578, in get_input
    rep = self.pretrained_encoder.head(rep)
  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/container.py", line 217, in forward
    input = module(input)
  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 171, in forward
    return F.batch_norm(
  File "/home/user/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/functional.py", line 2450, in batch_norm
    return torch.batch_norm(
RuntimeError: running_mean should contain 197 elements not 4096

This appears to be a dimension mismatch between the extracted representation and the layers in self.pretrained_encoder.head.
I am unsure of the cause and would appreciate your clarification and support. Thank you!
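The mismatch can be reproduced in isolation. The sketch below uses illustrative shapes and names, not code from the repo: if forward_features returns the full ViT token sequence (B, 197, 768) rather than a pooled (B, 768) feature, a BatchNorm1d layer inside the projection head reads dim 1 (197) as its channel count and fails against its stored running statistics.

```python
import torch
import torch.nn as nn

# Illustrative stand-in for a BatchNorm layer inside encoder.head;
# 4096 is the width it was built for.
head_bn = nn.BatchNorm1d(4096)
head_bn.eval()  # use the stored running stats, as at inference time

# Full ViT token sequence: CLS token + 196 patch tokens per image.
tokens = torch.randn(32, 197, 768)

msg = ""
try:
    head_bn(tokens)
except RuntimeError as e:
    msg = str(e)  # same kind of running_mean size mismatch as in the traceback
print(msg)
```

With a pooled feature of shape (32, 4096) the layer would run fine; the extra token dimension is what trips the channel-count check.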

LTH14 (Owner) commented Dec 19, 2023

Thanks for your interest! Please make sure your timm version is 0.3.2, as later versions use a different forward_features implementation. Please check issue #9 for a similar problem and its solution.
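A small fail-fast guard can turn the opaque shape error into an actionable message. This helper is my own sketch, not part of the repo; it only assumes the pin stated above (timm 0.3.2, whose ViT forward_features returns a pooled feature rather than the full token sequence).

```python
from importlib import metadata


def check_timm_version(required: str = "0.3.2") -> str:
    """Raise a clear error if the installed timm does not match the pin."""
    try:
        installed = metadata.version("timm")
    except metadata.PackageNotFoundError:
        raise RuntimeError("timm is not installed; run: pip install timm==0.3.2")
    if installed != required:
        raise RuntimeError(
            f"found timm {installed}, but {required} is required "
            f"(pip install timm=={required}); later versions change the "
            f"output of ViT forward_features"
        )
    return installed
```

Calling check_timm_version() once at startup fails immediately with the fix spelled out, instead of surfacing a running_mean size error deep inside BatchNorm.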

wyyy04 (Author) commented Dec 20, 2023

Thank you very much for your guidance. I have resolved the issue by installing timm 0.3.2. Best wishes!
