Skip to content

Commit

Permalink
6268 enhance hovernet load pretrained function (#6269)
Browse files Browse the repository at this point in the history
Fixes #6268.

### Description

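This PR enhances HoVerNet's pretrained-weight loading: it adds a `pretrained_state_dict_key` argument for extracting a nested state dict from the downloaded checkpoint, maps the weights to the CPU when CUDA is not available, and reports how many keys were updated (or warns when none were).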

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Yiheng Wang <[email protected]>
  • Loading branch information
yiheng-wang-nv authored Apr 3, 2023
1 parent 129c097 commit bb4df37
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions monai/networks/nets/hovernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ class HoVerNet(nn.Module):
adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this
value should be `False`. If using the pretrained weights that follow torchvision's standard resnet50 format,
this value should be `True`.
pretrained_state_dict_key: this arg is used when `pretrained_url` is provided and `adapt_standard_resnet` is True.
It is used to extract the expected state dict.
freeze_encoder: whether to freeze the encoder of the network.
"""

Expand All @@ -461,6 +463,7 @@ def __init__(
dropout_prob: float = 0.0,
pretrained_url: str | None = None,
adapt_standard_resnet: bool = False,
pretrained_state_dict_key: str | None = None,
freeze_encoder: bool = False,
) -> None:
super().__init__()
Expand Down Expand Up @@ -566,7 +569,7 @@ def __init__(

if pretrained_url is not None:
if adapt_standard_resnet:
weights = _remap_standard_resnet_model(pretrained_url)
weights = _remap_standard_resnet_model(pretrained_url, state_dict_key=pretrained_state_dict_key)
else:
weights = _remap_preact_resnet_model(pretrained_url)
_load_pretrained_encoder(self, weights)
Expand Down Expand Up @@ -609,6 +612,12 @@ def _load_pretrained_encoder(model: nn.Module, state_dict: OrderedDict | dict):

model_dict.update(state_dict)
model.load_state_dict(model_dict)
if len(state_dict.keys()) == 0:
warnings.warn(
"no key will be updated. Please check if 'pretrained_url' or `pretrained_state_dict_key` is correct."
)
else:
print(f"{len(state_dict)} out of {len(model_dict)} keys are updated with pretrained weights.")


def _remap_preact_resnet_model(model_url: str):
Expand All @@ -619,7 +628,9 @@ def _remap_preact_resnet_model(model_url: str):
# download the pretrained weights into torch hub's default dir
weights_dir = os.path.join(torch.hub.get_dir(), "preact-resnet50.pth")
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
state_dict = torch.load(weights_dir, map_location=None)["desc"]
state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))[
"desc"
]
for key in list(state_dict.keys()):
new_key = None
if pattern_conv0.match(key):
Expand All @@ -639,7 +650,7 @@ def _remap_preact_resnet_model(model_url: str):
return state_dict


def _remap_standard_resnet_model(model_url: str):
def _remap_standard_resnet_model(model_url: str, state_dict_key: str | None = None):
pattern_conv0 = re.compile(r"^conv1\.(.+)$")
pattern_bn1 = re.compile(r"^bn1\.(.+)$")
pattern_block = re.compile(r"^layer(\d+)\.(\d+)\.(.+)$")
Expand All @@ -652,7 +663,9 @@ def _remap_standard_resnet_model(model_url: str):
# download the pretrained weights into torch hub's default dir
weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
state_dict = torch.load(weights_dir, map_location=None)
state_dict = torch.load(weights_dir, map_location=None if torch.cuda.is_available() else torch.device("cpu"))
if state_dict_key is not None:
state_dict = state_dict[state_dict_key]

for key in list(state_dict.keys()):
new_key = None
Expand Down

0 comments on commit bb4df37

Please sign in to comment.