From bb4df37ab17e2e1fc136fd3a28ce3178506db654 Mon Sep 17 00:00:00 2001 From: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com> Date: Tue, 4 Apr 2023 05:38:26 +0800 Subject: [PATCH] 6268 enhance hovernet load pretrained function (#6269) Fixes #6268 . ### Description This PR enhances Hovernet's load pretrained function. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Yiheng Wang --- monai/networks/nets/hovernet.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/monai/networks/nets/hovernet.py b/monai/networks/nets/hovernet.py index 323e107fd7..3ec1cea37e 100644 --- a/monai/networks/nets/hovernet.py +++ b/monai/networks/nets/hovernet.py @@ -443,6 +443,8 @@ class HoVerNet(nn.Module): adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this value should be `False`. If using the pretrained weights that follow torchvision's standard resnet50 format, this value should be `True`. + pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True. + It is used to extract the expected state dict. freeze_encoder: whether to freeze the encoder of the network. """ @@ -461,6 +463,7 @@ def __init__( dropout_prob: float = 0.0, pretrained_url: str | None = None, adapt_standard_resnet: bool = False, + pretrained_state_dict_key: str | None = None, freeze_encoder: bool = False, ) -> None: super().__init__() @@ -566,7 +569,7 @@ def __init__( if pretrained_url is not None: if adapt_standard_resnet: - weights = _remap_standard_resnet_model(pretrained_url) + weights = _remap_standard_resnet_model(pretrained_url, state_dict_key=pretrained_state_dict_key) else: weights = _remap_preact_resnet_model(pretrained_url) _load_pretrained_encoder(self, weights) @@ -609,6 +612,12 @@ def _load_pretrained_encoder(model: nn.Module, state_dict: OrderedDict | dict): model_dict.update(state_dict) model.load_state_dict(model_dict) + if len(state_dict.keys()) == 0: + warnings.warn( + "no key will be updated. Please check if 'pretrained_url' or `pretrained_state_dict_key` is correct." + ) + else: + print(f"{len(state_dict)} out of {len(model_dict)} keys are updated with pretrained weights.") def _remap_preact_resnet_model(model_url: str): @@ -619,7 +628,9 @@ def _remap_preact_resnet_model(model_url: str): # download the pretrained weights into torch hub's default dir weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth") download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False) - state_dict = torch.load(weights_dir, map_location=None)["desc"] + state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[ + "desc" + ] for key in list(state_dict.keys()): new_key = None if pattern_conv0.match(key): @@ -639,7 +650,7 @@ def _remap_preact_resnet_model(model_url: str): return state_dict -def _remap_standard_resnet_model(model_url: str): +def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = None): pattern_conv0 = re.compile(r"^conv1\.(.+)$") pattern_bn1 = re.compile(r"^bn1\.(.+)$") pattern_block = re.compile(r"^layer(\d+)\.(\d+)\.(.+)$") @@ -652,7 +663,9 @@ def _remap_standard_resnet_model(model_url: str): # download the pretrained weights into torch hub's default dir weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth") download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False) - state_dict = torch.load(weights_dir, map_location=None) + state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu")) + if state_dict_key is not None: + state_dict = state_dict[state_dict_key] for key in list(state_dict.keys()): new_key = None