Skip to content

Commit

Permalink
Merge branch 'dev' into 5852-callable-partial
Browse files Browse the repository at this point in the history
  • Loading branch information
wyli authored Jan 14, 2023
2 parents 4093707 + 6cb4ced commit 88babdd
Showing 1 changed file with 51 additions and 7 deletions.
58 changes: 51 additions & 7 deletions monai/networks/nets/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,16 +296,27 @@ class DenseNet121(DenseNet):

def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 24, 16),
pretrained: bool = False,
progress: bool = True,
**kwargs,
) -> None:
super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
init_features=init_features,
growth_rate=growth_rate,
block_config=block_config,
**kwargs,
)
if pretrained:
if kwargs["spatial_dims"] > 2:
if spatial_dims > 2:
raise NotImplementedError(
"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not"
"provide pretrained models for more than two spatial dimensions."
Expand All @@ -318,16 +329,27 @@ class DenseNet169(DenseNet):

def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 32, 32),
pretrained: bool = False,
progress: bool = True,
**kwargs,
) -> None:
super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
init_features=init_features,
growth_rate=growth_rate,
block_config=block_config,
**kwargs,
)
if pretrained:
if kwargs["spatial_dims"] > 2:
if spatial_dims > 2:
raise NotImplementedError(
"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not"
"provide pretrained models for more than two spatial dimensions."
Expand All @@ -340,16 +362,27 @@ class DenseNet201(DenseNet):

def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 48, 32),
pretrained: bool = False,
progress: bool = True,
**kwargs,
) -> None:
super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
init_features=init_features,
growth_rate=growth_rate,
block_config=block_config,
**kwargs,
)
if pretrained:
if kwargs["spatial_dims"] > 2:
if spatial_dims > 2:
raise NotImplementedError(
"Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does not"
"provide pretrained models for more than two spatial dimensions."
Expand All @@ -362,14 +395,25 @@ class DenseNet264(DenseNet):

def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
init_features: int = 64,
growth_rate: int = 32,
block_config: Sequence[int] = (6, 12, 64, 48),
pretrained: bool = False,
progress: bool = True,
**kwargs,
) -> None:
super().__init__(init_features=init_features, growth_rate=growth_rate, block_config=block_config, **kwargs)
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
init_features=init_features,
growth_rate=growth_rate,
block_config=block_config,
**kwargs,
)
if pretrained:
raise NotImplementedError("Currently PyTorch Hub does not provide densenet264 pretrained models.")

Expand Down

0 comments on commit 88babdd

Please sign in to comment.