Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

7047 simplify resnet pretrained #7095

Merged
merged 39 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
0613c2f
Simplify resnet pretrained flag
vgrau98 Oct 4, 2023
1005fe8
add tests + typos
vgrau98 Oct 6, 2023
8b75095
add MedicalNet resnet 3D pretrained models support
vgrau98 Oct 6, 2023
00ec022
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 6, 2023
f5e09b1
add optional import
vgrau98 Oct 7, 2023
c7a827b
simplify user pretrained weights loading
vgrau98 Oct 7, 2023
fa60fad
Manage MedicalNet resnet model validation with pretrained flag
vgrau98 Oct 7, 2023
e9ba99d
update resnet tests
vgrau98 Oct 7, 2023
0ddfb04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2023
06ed8b0
update resnet unit tests
vgrau98 Oct 8, 2023
955dcf1
fix incorrect optional import
vgrau98 Oct 8, 2023
ff7f6d3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2023
a707d1c
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 8, 2023
bb6830f
Line shortening
vgrau98 Oct 8, 2023
d21a022
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2023
5ac3627
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 9, 2023
3735733
update resnet tests and deployment files
vgrau98 Oct 9, 2023
8b6782a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2023
801edf6
Merge branch 'dev' into 7047-simplify-resnet-pretrained
wyli Oct 10, 2023
02a360a
[MONAI] code formatting
monai-bot Oct 10, 2023
1993f04
Update utils.py
vgrau98 Oct 10, 2023
e344771
Update utils.py
vgrau98 Oct 10, 2023
0ea682b
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 10, 2023
be516c0
Update resnet.py
vgrau98 Oct 10, 2023
3dd89de
Update utils.py
vgrau98 Oct 10, 2023
b74d48a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 10, 2023
46e7acc
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 11, 2023
6896216
fix lint error
vgrau98 Oct 16, 2023
b30ea73
minor refactos
vgrau98 Oct 16, 2023
9ee9e45
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 16, 2023
2945704
fix lint error
vgrau98 Oct 16, 2023
7a01bb5
fix typo
vgrau98 Oct 16, 2023
1a29bb2
Merge branch 'dev' into 7047-simplify-resnet-pretrained
vgrau98 Oct 16, 2023
8198203
fix mypy error
vgrau98 Oct 17, 2023
2930f97
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 17, 2023
7a0a2c3
fix lint error
vgrau98 Oct 17, 2023
5b279b6
Merge branch 'dev' into 7047-simplify-resnet-pretrained
wyli Oct 18, 2023
741970d
update unit test
wyli Oct 18, 2023
89eda2c
local torch.cuda check
wyli Oct 18, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ opencv-python-headless
onnx>=1.13.0
onnxruntime; python_version <= '3.10'
zarr
huggingface_hub
2 changes: 2 additions & 0 deletions monai/networks/nets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
ResNet,
ResNetBlock,
ResNetBottleneck,
get_medicalnet_pretrained_resnet_args,
get_pretrained_resnet_medicalnet,
resnet10,
resnet18,
resnet34,
Expand Down
134 changes: 123 additions & 11 deletions monai/networks/nets/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@

from __future__ import annotations

import logging
import re
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import Any

import torch
Expand All @@ -21,7 +24,13 @@
from monai.networks.layers.factories import Conv, Norm, Pool
from monai.networks.layers.utils import get_pool_layer
from monai.utils import ensure_tuple_rep
from monai.utils.module import look_up_option
from monai.utils.module import look_up_option, optional_import

hf_hub_download, _ = optional_import("huggingface_hub", name="hf_hub_download")
EntryNotFoundError, _ = optional_import("huggingface_hub.utils._errors", name="EntryNotFoundError")

MEDICALNET_HUGGINGFACE_REPO_BASENAME = "TencentMedicalNet/MedicalNet-Resnet"
MEDICALNET_HUGGINGFACE_FILES_BASENAME = "resnet_"

__all__ = [
"ResNet",
Expand All @@ -36,6 +45,8 @@
"resnet200",
]

logger = logging.getLogger(__name__)


def get_inplanes():
return [64, 128, 256, 512]
Expand Down Expand Up @@ -329,21 +340,54 @@ def _resnet(
block: type[ResNetBlock | ResNetBottleneck],
layers: list[int],
block_inplanes: list[int],
pretrained: bool,
pretrained: bool | str,
progress: bool,
**kwargs: Any,
) -> ResNet:
model: ResNet = ResNet(block, layers, block_inplanes, **kwargs)
if pretrained:
# Author of paper zipped the state_dict on googledrive,
# so would need to download, unzip and read (2.8gb file for a ~150mb state dict).
# Would like to load dict from url but need somewhere to save the state dicts.
raise NotImplementedError(
"Currently not implemented. You need to manually download weights provided by the paper's author"
" and load then to the model with `state_dict`. See https://github.com/Tencent/MedicalNet"
"Please ensure you pass the appropriate `shortcut_type` and `bias_downsample` args. as specified"
"here: https://github.com/Tencent/MedicalNet/tree/18c8bb6cd564eb1b964bffef1f4c2283f1ae6e7b#update20190730"
)
device = "cuda" if torch.cuda.is_available() else "cpu"
if isinstance(pretrained, str):
if Path(pretrained).exists():
logger.info(f"Loading weights from {pretrained}...")
model_state_dict = torch.load(pretrained, map_location=device)
else:
# Throw error
raise FileNotFoundError("The pretrained checkpoint file is not found")
else:
# Also check bias downsample and shortcut.
if kwargs.get("spatial_dims", 3) == 3:
if kwargs.get("n_input_channels", 3) == 1 and kwargs.get("feed_forward", True) is False:
search_res = re.search(r"resnet(\d+)", arch)
if search_res:
resnet_depth = int(search_res.group(1))
else:
raise ValueError("arch argument should be as 'resnet_{resnet_depth}")

# Check model bias_downsample and shortcut_type
bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)
if shortcut_type == kwargs.get("shortcut_type", "B") and (
bool(bias_downsample) == kwargs.get("bias_downsample", False) if bias_downsample != -1 else True
):
# Download the MedicalNet pretrained model
model_state_dict = get_pretrained_resnet_medicalnet(
resnet_depth, device=device, datasets23=True
)
else:
raise NotImplementedError(
f"Please set shortcut_type to {shortcut_type} and bias_downsample to"
f"{bool(bias_downsample) if bias_downsample!=-1 else 'True or False'}"
f"when using pretrained MedicalNet resnet{resnet_depth}"
)
else:
raise NotImplementedError(
"Please set n_input_channels to 1"
"and feed_forward to False in order to use MedicalNet pretrained weights"
)
else:
raise NotImplementedError("MedicalNet pretrained weights are only avalaible for 3D models")
model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
model.load_state_dict(model_state_dict, strict=True)
return model


Expand Down Expand Up @@ -429,3 +473,71 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet("resnet200", ResNetBottleneck, [3, 24, 36, 3], get_inplanes(), pretrained, progress, **kwargs)


def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True):
"""
Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet

Args:
resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200
device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example.
datasets23: if True, get the weights trained on more datasets (23).
Not all depths are available. If not, standard weights are returned.

Returns:
Pretrained state dict

Raises:
huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub
NotImplementedError: if `resnet_depth` is not supported
"""

medicalnet_huggingface_repo_basename = "TencentMedicalNet/MedicalNet-Resnet"
medicalnet_huggingface_files_basename = "resnet_"
supported_depth = [10, 18, 34, 50, 101, 152, 200]

logger.info(
f"Loading MedicalNet pretrained model from https://huggingface.co/{medicalnet_huggingface_repo_basename}{resnet_depth}"
)

if resnet_depth in supported_depth:
filename = (
f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth"
if not datasets23
else f"{medicalnet_huggingface_files_basename}{resnet_depth}_23dataset.pth"
)
try:
pretrained_path = hf_hub_download(
repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename
)
except Exception:
if datasets23:
logger.info(f"{filename} not available for resnet{resnet_depth}")
filename = f"{medicalnet_huggingface_files_basename}{resnet_depth}.pth"
logger.info(f"Trying with {filename}")
pretrained_path = hf_hub_download(
repo_id=f"{medicalnet_huggingface_repo_basename}{resnet_depth}", filename=filename
)
else:
raise EntryNotFoundError(
f"{filename} not found on {medicalnet_huggingface_repo_basename}{resnet_depth}"
) from None
checkpoint = torch.load(pretrained_path, map_location=torch.device(device))
else:
raise NotImplementedError("Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]")
logger.info(f"{filename} downloaded")
return checkpoint.get("state_dict")


def get_medicalnet_pretrained_resnet_args(resnet_depth: int):
"""
Return correct shortcut_type and bias_downsample
for pretrained MedicalNet weights according to resnet depth
"""
# After testing
# False: 10, 50, 101, 152, 200
# Any: 18, 34
bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34
shortcut_type = "A" if resnet_depth in [18, 34] else "B"
return bias_downsample, shortcut_type
1 change: 1 addition & 0 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
onnxreference, _ = optional_import("onnx.reference")
onnxruntime, _ = optional_import("onnxruntime")


__all__ = [
"one_hot",
"predict_segmentation",
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ filelock!=3.12.0 # https://github.com/microsoft/nni/issues/5523
zarr
lpips==0.1.4
nvidia-ml-py
huggingface_hub
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ all =
zarr
lpips==0.1.4
nvidia-ml-py
huggingface_hub
nibabel =
nibabel
ninja =
Expand Down
85 changes: 83 additions & 2 deletions tests/test_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,32 @@

from __future__ import annotations

import copy
import os
import re
import sys
import unittest
from typing import TYPE_CHECKING

import torch
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import ResNet, resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
from monai.networks.nets import (
ResNet,
get_medicalnet_pretrained_resnet_args,
get_pretrained_resnet_medicalnet,
resnet10,
resnet18,
resnet34,
resnet50,
resnet101,
resnet152,
resnet200,
)
from monai.networks.nets.resnet import ResNetBlock
from monai.utils import optional_import
from tests.utils import test_script_save
from tests.utils import equal_state_dict, skip_if_downloading_fails, skip_if_no_cuda, skip_if_quick, test_script_save

if TYPE_CHECKING:
import torchvision
Expand All @@ -30,6 +45,10 @@
else:
torchvision, has_torchvision = optional_import("torchvision")

has_hf_modules = "huggingface_hub" in sys.modules and "huggingface_hub.utils._errors" in sys.modules

# from torchvision.models import ResNet50_Weights, resnet50

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASE_1 = [ # 3D, batch 3, 2 input channel
Expand Down Expand Up @@ -159,9 +178,11 @@
]

TEST_CASES = []
PRETRAINED_TEST_CASES = []
for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
TEST_CASES.append([model, *case])
PRETRAINED_TEST_CASES.append([model, *case])
for case in [TEST_CASE_5, TEST_CASE_5_A, TEST_CASE_6, TEST_CASE_7]:
TEST_CASES.append([ResNet, *case])

Expand All @@ -171,6 +192,16 @@


class TestResNet(unittest.TestCase):
def setUp(self):
self.tmp_ckpt_filename = os.path.join("tests", "monai_unittest_tmp_ckpt.pth")

def tearDown(self):
if os.path.exists(self.tmp_ckpt_filename):
try:
os.remove(self.tmp_ckpt_filename)
except BaseException:
pass

@parameterized.expand(TEST_CASES)
def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
net = model(**input_param).to(device)
Expand All @@ -181,6 +212,56 @@ def test_resnet_shape(self, model, input_param, input_shape, expected_shape):
else:
self.assertTrue(result.shape in expected_shape)

@parameterized.expand(PRETRAINED_TEST_CASES)
@skip_if_quick
@skip_if_no_cuda
def test_resnet_pretrained(self, model, input_param, input_shape, expected_shape):
net = model(**input_param).to(device)
# Save ckpt
torch.save(net.state_dict(), self.tmp_ckpt_filename)

cp_input_param = copy.copy(input_param)
# Custom pretrained weights
cp_input_param["pretrained"] = self.tmp_ckpt_filename
pretrained_net = model(**cp_input_param)
self.assertTrue(equal_state_dict(net.state_dict(), pretrained_net.state_dict()))

if has_hf_modules:
# True flag
cp_input_param["pretrained"] = True
resnet_depth = int(re.search(r"resnet(\d+)", model.__name__).group(1))

bias_downsample, shortcut_type = get_medicalnet_pretrained_resnet_args(resnet_depth)

# With orig. test cases
if (
input_param.get("spatial_dims", 3) == 3
and input_param.get("n_input_channels", 3) == 1
and input_param.get("feed_forward", True) is False
and input_param.get("shortcut_type", "B") == shortcut_type
and (
input_param.get("bias_downsample", True) == bool(bias_downsample) if bias_downsample != -1 else True
)
):
model(**cp_input_param)
else:
with self.assertRaises(NotImplementedError):
model(**cp_input_param)

# forcing MedicalNet pretrained download for 3D tests cases
cp_input_param["n_input_channels"] = 1
cp_input_param["feed_forward"] = False
cp_input_param["shortcut_type"] = shortcut_type
cp_input_param["bias_downsample"] = bool(bias_downsample) if bias_downsample != -1 else True
if cp_input_param.get("spatial_dims", 3) == 3:
with skip_if_downloading_fails():
pretrained_net = model(**cp_input_param).to(device)
medicalnet_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device=device)
medicalnet_state_dict = {
key.replace("module.", ""): value for key, value in medicalnet_state_dict.items()
}
self.assertTrue(equal_state_dict(pretrained_net.state_dict(), medicalnet_state_dict))

@parameterized.expand(TEST_SCRIPT_CASES)
def test_script(self, model, input_param, input_shape, expected_shape):
net = model(**input_param)
Expand Down
17 changes: 17 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,23 @@ def command_line_tests(cmd, copy_env=True):
raise RuntimeError(f"subprocess call error {e.returncode}: {errors}, {output}") from e


def equal_state_dict(st_1, st_2):
"""
Compare 2 torch state dicts.
"""
r = True
for key_st_1, val_st_1 in st_1.items():
if key_st_1 in st_2:
val_st_2 = st_2.get(key_st_1)
if not torch.equal(val_st_1, val_st_2):
r = False
break
else:
r = False
break
return r


TEST_TORCH_TENSORS: tuple = (torch.as_tensor,)
if torch.cuda.is_available():
gpu_tensor: Callable = partial(torch.as_tensor, device="cuda")
Expand Down