
Simplify existing resnet API for pretrained flag #7047

Closed
surajpaib opened this issue Sep 25, 2023 · 2 comments · Fixed by #7095


@surajpaib
Contributor

Is your feature request related to a problem? Please describe.
The pretrained flag in the resnet50 function currently only accepts a boolean value. If it's set to True, the function throws an error pointing the user to download MedicalNet weights. Then, the user has to download these and manually load the state_dict into the model.

Describe the solution you'd like
This process could be greatly simplified by also allowing the `pretrained` flag to take a `str` value pointing to a checkpoint path, and then loading the `state_dict` from that path automatically. For example:

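```python
from pathlib import Path

import torch
from monai.networks.nets import resnet50 as resnet50_monai

# n_input_channels, widen_factor, conv1_t_stride, feed_forward, bias_downsample,
# device, pretrained (a path string) and logger are defined by the caller
model = resnet50_monai(pretrained=False, n_input_channels=n_input_channels, widen_factor=widen_factor, conv1_t_stride=conv1_t_stride, feed_forward=feed_forward, bias_downsample=bias_downsample)
model = model.to(device)
if pretrained:
    if not Path(pretrained).exists():
        raise FileNotFoundError(f"Pretrained weights not found at {pretrained}")
    logger.info(f"Loading weights from {pretrained}...")
    checkpoint = torch.load(pretrained, map_location=device)
    # Use the nested "state_dict" entry if present, otherwise treat the checkpoint as a bare state dict
    model_state_dict = checkpoint.get("state_dict", checkpoint)
    # Drop the "module." prefix that nn.DataParallel adds to parameter names
    model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
    model.load_state_dict(model_state_dict, strict=True)
```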
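One refinement to the unwrapping step: `key.replace("module.", "")` removes that substring anywhere in a key, not only the `nn.DataParallel` prefix at the start. A minimal sketch of a stricter helper, assuming a checkpoint is either a bare state dict or a dict with a `"state_dict"` entry. The name `extract_state_dict` is hypothetical and not a MONAI API.

```python
from typing import Any, Dict, Mapping


def extract_state_dict(checkpoint: Mapping[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Hypothetical helper: unwrap a checkpoint and strip a leading key prefix."""
    # Use the nested "state_dict" entry if present, otherwise treat the
    # checkpoint itself as the state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    # Only remove the prefix at the start of a key; str.replace would also
    # rewrite later occurrences such as "encoder.module.weight".
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Plain Python values stand in for tensors here
checkpoint = {"state_dict": {"module.conv1.weight": 1, "module.fc.bias": 2}}
print(extract_state_dict(checkpoint))  # → {'conv1.weight': 1, 'fc.bias': 2}
```

With a real model this would be used as `model.load_state_dict(extract_state_dict(torch.load(path, map_location=device)), strict=True)`.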

This would make the MedicalNet weights easier to load and could open up this API to loading other models trained with MONAI as well. If we don't want to rely on duck typing (a single argument accepting both a `bool` and a `str`), we could add a separate flag to the function instead.
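A duck-typed `pretrained` argument needs a small dispatch step that tells `False`, `True` and a path apart. The sketch below is illustrative only: `resolve_pretrained` and its return tags are hypothetical, and mapping `True` to "download the default MedicalNet weights" is an assumption for this example.

```python
from pathlib import Path
from typing import Optional, Tuple, Union


def resolve_pretrained(pretrained: Union[bool, str, Path]) -> Tuple[str, Optional[Path]]:
    """Hypothetical dispatcher for a `pretrained` argument that is a bool or a path.

    Returns ("none", None), ("download", None) or ("path", <Path to checkpoint>).
    """
    # Check bool first so True/False are never interpreted as paths.
    if isinstance(pretrained, bool):
        return ("download", None) if pretrained else ("none", None)
    if isinstance(pretrained, (str, Path)):
        path = Path(pretrained)
        if not path.exists():
            raise FileNotFoundError(f"Pretrained weights not found at {path}")
        return "path", path
    raise TypeError(f"pretrained must be a bool, str or Path, got {type(pretrained).__name__}")


print(resolve_pretrained(False))  # → ('none', None)
print(resolve_pretrained(True))  # → ('download', None)
```

The explicit `isinstance` checks are the cost of duck typing here; a separate path argument would remove the need to inspect the type.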

Additional context

I understand that MONAI is moving in a different direction with loading bundles etc., but since this `pretrained` flag still exists, it might be useful to simplify it.

@surajpaib changed the title from "Simplify existing Resn" to "Simplify existing resnet API for pretrained flag" on Sep 25, 2023
vgrau98 added a commit to vgrau98/MONAI that referenced this issue Oct 6, 2023
Fixes: Project-MONAI#7047

The original behaviour did not support a `True` pretrained flag.

Signed-off-by: vgrau98 <[email protected]>
@vgrau98
Contributor

vgrau98 commented Oct 6, 2023

Using gdown, it is possible to download the pretrained models developed by https://github.com/Tencent/MedicalNet and published on Google Drive: https://drive.google.com/u/0/uc?id=13tnSvXY7oDIEloNFiGTsjUIYfS3g3BfG&export=download.
However, there is only a single link covering all the pretrained models plus other data (about 2 GB).
Would it be OK, when the flag is set to True, to download the weights from that Drive link even though the unneeded extra data is downloaded too?

@vgrau98
Contributor

vgrau98 commented Oct 6, 2023

It turns out the pretrained models are also available on Hugging Face: https://huggingface.co/TencentMedicalNet
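Downloading from the Hugging Face organisation means mapping a ResNet depth to a repository and a file. A sketch using `huggingface_hub.hf_hub_download`: the repository names (`TencentMedicalNet/MedicalNet-Resnet{depth}`), the file names (`resnet_{depth}.pth`, `resnet_{depth}_23dataset.pth`), the depth list and both helper names are assumptions to verify against the Hub, not confirmed here.

```python
from typing import Tuple

# ASSUMPTION: depths released by MedicalNet (ResNet-10 through ResNet-200).
MEDICALNET_DEPTHS = (10, 18, 34, 50, 101, 152, 200)


def medicalnet_hf_location(resnet_depth: int, datasets23: bool = True) -> Tuple[str, str]:
    """Hypothetical helper returning (repo_id, filename) for a MedicalNet checkpoint.

    ASSUMPTION: the naming scheme below mirrors the TencentMedicalNet Hub
    repositories; check the exact names before relying on it.
    """
    if resnet_depth not in MEDICALNET_DEPTHS:
        raise ValueError(f"No MedicalNet weights for ResNet-{resnet_depth}")
    repo_id = f"TencentMedicalNet/MedicalNet-Resnet{resnet_depth}"
    suffix = "_23dataset" if datasets23 else ""
    return repo_id, f"resnet_{resnet_depth}{suffix}.pth"


def download_medicalnet_weights(resnet_depth: int, datasets23: bool = True) -> str:
    """Download (or reuse from the local cache) a checkpoint and return its local path."""
    # Imported lazily so the naming helper works without huggingface_hub installed.
    from huggingface_hub import hf_hub_download

    repo_id, filename = medicalnet_hf_location(resnet_depth, datasets23)
    return hf_hub_download(repo_id=repo_id, filename=filename)


print(medicalnet_hf_location(50))
# → ('TencentMedicalNet/MedicalNet-Resnet50', 'resnet_50_23dataset.pth')
```

Fetching one file per depth avoids downloading the whole ~2 GB archive that the single Google Drive link requires.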

wyli pushed a commit that referenced this issue Oct 18, 2023
Fixes #7047

### Description

ResNet did not support the `True` value for its `pretrained` flag (it was not implemented).
Two behaviours are now implemented:
- When `pretrained` is `True`, download weights from https://huggingface.co/TencentMedicalNet
- When `pretrained` is a string, load weights from the path defined by the string

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: vgrau98 <[email protected]>