Skip to content

Commit

Permalink
Update mednext implementations
Browse files Browse the repository at this point in the history
Signed-off-by: Suraj Pai <[email protected]>
  • Loading branch information
surajpaib committed Sep 26, 2024
1 parent 93e782f commit e146aaf
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 94 deletions.
147 changes: 54 additions & 93 deletions monai/networks/nets/mednext.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,128 +266,89 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:


# Define the MedNeXt variants as reported in 10.48550/arXiv.2303.09975
class MedNeXtSmall(MedNeXt):
"""MedNeXt Small (S) configuration"""
def create_mednext(
variant: str,
spatial_dims: int = 3,
in_channels: int = 1,
out_channels: int = 2,
kernel_size: int = 3,
deep_supervision: bool = False,
) -> MedNeXt:
"""
Factory method to create MedNeXt variants.
def __init__(
self,
spatial_dims: int = 3,
in_channels: int = 1,
out_channels: int = 2,
kernel_size: int = 3,
deep_supervision: bool = False,
):
super().__init__(
spatial_dims=spatial_dims,
init_filters=32,
in_channels=in_channels,
out_channels=out_channels,
Args:
variant (str): The MedNeXt variant to create ('S', 'B', 'M', or 'L').
spatial_dims (int): Number of spatial dimensions. Defaults to 3.
in_channels (int): Number of input channels. Defaults to 1.
out_channels (int): Number of output channels. Defaults to 2.
kernel_size (int): Kernel size for convolutions. Defaults to 3.
deep_supervision (bool): Whether to use deep supervision. Defaults to False.
Returns:
MedNeXt: The specified MedNeXt variant.
Raises:
ValueError: If an invalid variant is specified.
"""
common_args = {
"spatial_dims": spatial_dims,
"in_channels": in_channels,
"out_channels": out_channels,
"kernel_size": kernel_size,
"deep_supervision": deep_supervision,
"use_residual_connection": True,
"norm_type": "group",
"grn": False,
"init_filters": 32,
}

if variant.upper() == "S":
return MedNeXt(
encoder_expansion_ratio=2,
decoder_expansion_ratio=2,
bottleneck_expansion_ratio=2,
kernel_size=kernel_size,
deep_supervision=deep_supervision,
use_residual_connection=True,
blocks_down=(2, 2, 2, 2),
blocks_bottleneck=2,
blocks_up=(2, 2, 2, 2),
norm_type="group",
grn=False,
**common_args,
)


class MedNeXtBase(MedNeXt):
"""MedNeXt Base (B) configuration"""

def __init__(
self,
spatial_dims: int = 3,
in_channels: int = 1,
out_channels: int = 2,
kernel_size: int = 3,
deep_supervision: bool = False,
):
super().__init__(
spatial_dims=spatial_dims,
init_filters=32,
in_channels=in_channels,
out_channels=out_channels,
elif variant.upper() == "B":
return MedNeXt(
encoder_expansion_ratio=(2, 3, 4, 4),
decoder_expansion_ratio=(4, 4, 3, 2),
bottleneck_expansion_ratio=4,
kernel_size=kernel_size,
deep_supervision=deep_supervision,
use_residual_connection=True,
blocks_down=(2, 2, 2, 2),
blocks_bottleneck=2,
blocks_up=(2, 2, 2, 2),
norm_type="group",
grn=False,
**common_args,
)


class MedNeXtMedium(MedNeXt):
"""MedNeXt Medium (M)"""

def __init__(
self,
spatial_dims: int = 3,
in_channels: int = 1,
out_channels: int = 2,
kernel_size: int = 3,
deep_supervision: bool = False,
):
super().__init__(
spatial_dims=spatial_dims,
init_filters=32,
in_channels=in_channels,
out_channels=out_channels,
elif variant.upper() == "M":
return MedNeXt(
encoder_expansion_ratio=(2, 3, 4, 4),
decoder_expansion_ratio=(4, 4, 3, 2),
bottleneck_expansion_ratio=4,
kernel_size=kernel_size,
deep_supervision=deep_supervision,
use_residual_connection=True,
blocks_down=(3, 4, 4, 4),
blocks_bottleneck=4,
blocks_up=(4, 4, 4, 3),
norm_type="group",
grn=False,
**common_args,
)


class MedNeXtLarge(MedNeXt):
"""MedNeXt Large (L)"""

def __init__(
self,
spatial_dims: int = 3,
in_channels: int = 1,
out_channels: int = 2,
kernel_size: int = 3,
deep_supervision: bool = False,
):
super().__init__(
spatial_dims=spatial_dims,
init_filters=32,
in_channels=in_channels,
out_channels=out_channels,
elif variant.upper() == "L":
return MedNeXt(
encoder_expansion_ratio=(3, 4, 8, 8),
decoder_expansion_ratio=(8, 8, 4, 3),
bottleneck_expansion_ratio=8,
kernel_size=kernel_size,
deep_supervision=deep_supervision,
use_residual_connection=True,
blocks_down=(3, 4, 8, 8),
blocks_bottleneck=8,
blocks_up=(8, 8, 4, 3),
norm_type="group",
grn=False,
**common_args,
)
else:
raise ValueError(f"Invalid MedNeXt variant: {variant}")


MedNext = MedNeXt
MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall
MedNextB = MedNeXtB = MedNextBase = MedNeXtBase
MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium
MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge
MedNextS = MedNeXtS = MedNextSmall = MedNeXtSmall = lambda **kwargs: create_mednext("S", **kwargs)
MedNextB = MedNeXtB = MedNextBase = MedNeXtBase = lambda **kwargs: create_mednext("B", **kwargs)
MedNextM = MedNeXtM = MedNextMedium = MedNeXtMedium = lambda **kwargs: create_mednext("M", **kwargs)
MedNextL = MedNeXtL = MedNextLarge = MedNeXtLarge = lambda **kwargs: create_mednext("L", **kwargs)
29 changes: 28 additions & 1 deletion tests/test_mednext.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from parameterized import parameterized

from monai.networks import eval_mode
from monai.networks.nets import MedNeXt
from monai.networks.nets import MedNeXt, MedNeXtL, MedNeXtM, MedNeXtS
from tests.utils import SkipIfBeforePyTorchVersion, test_script_save

device = "cuda" if torch.cuda.is_available() else "cpu"
Expand Down Expand Up @@ -55,6 +55,18 @@
]
TEST_CASE_MEDNEXT_2.append(test_case)

TEST_CASE_MEDNEXT_VARIANTS = []
for model in [MedNeXtS, MedNeXtM, MedNeXtL]:
for spatial_dims in range(2, 4):
for out_channels in [1, 2]:
test_case = [
model,
{"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels},
(2, 1, *([16] * spatial_dims)),
(2, out_channels, *([16] * spatial_dims)),
]
TEST_CASE_MEDNEXT_VARIANTS.append(test_case)


class TestMedNeXt(unittest.TestCase):

Expand Down Expand Up @@ -91,6 +103,21 @@ def test_ill_arg(self):
with self.assertRaises(AssertionError):
MedNeXt(spatial_dims=4)

@parameterized.expand(TEST_CASE_MEDNEXT_VARIANTS)
def test_mednext_variants(self, model, input_param, input_shape, expected_shape):
net = model(**input_param).to(device)

net.train()
result = net(torch.randn(input_shape).to(device))
assert isinstance(result, torch.Tensor)
self.assertEqual(result.shape, expected_shape, msg=str(input_param))

net.eval()
with torch.no_grad():
result = net(torch.randn(input_shape).to(device))
assert isinstance(result, torch.Tensor)
self.assertEqual(result.shape, expected_shape, msg=str(input_param))


if __name__ == "__main__":
unittest.main()

0 comments on commit e146aaf

Please sign in to comment.