Fix exceptions and docstrings (#143)
HangJung97 authored Nov 15, 2023
1 parent 4903d9a commit 46361af
Showing 5 changed files with 63 additions and 86 deletions.
ascent/models/components/decoders/unet_decoder.py (5 changes: 4 additions, 1 deletion)
@@ -132,10 +132,13 @@ def compute_pixels_in_output_feature_map(self, input_size: tuple[int, ...]) -> int:
"""Compute total number of pixels/voxels in the output feature map after convolutions.
Args:
    input_size: Size of the input image of the encoder.
    input_size: Size of the input image. (H, W(, D))
Returns:
    Number of pixels/voxels in the output feature map after convolution.
Raises:
    ValueError: If length of `input_size` is not equal to `dim`.
"""
if not len(input_size) == len(self.encoder_strides[0]):
    raise ValueError(
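
The `Raises:` entry added here documents the guard that already exists in the method body. A minimal standalone sketch of that behaviour (illustrative only; the real method lives on the decoder and uses its configured encoder strides and output channels rather than arguments):

import numpy as np

def pixels_in_output_feature_map(
    input_size: tuple[int, ...], strides: tuple[int, ...], out_channels: int
) -> int:
    # Reject shapes that still carry a batch or channel dimension.
    if not len(input_size) == len(strides):
        raise ValueError(
            "`input_size` should be (H, W(, D)) without channel or batch dimension!"
        )
    return int(np.prod([out_channels, *input_size], dtype=np.int64))

print(pixels_in_output_feature_map((640, 512), strides=(2, 2), out_channels=16))  # 5242880
# pixels_in_output_feature_map((1, 640, 512), strides=(2, 2), out_channels=16) raises ValueError
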
ascent/models/components/encoders/convnext.py (12 changes: 11 additions, 1 deletion)
@@ -87,6 +87,7 @@ def __init__(
Raises:
    ValueError: When `len(kernels)` is not equal to `num_stages`.
    ValueError: When `len(strides)` is not equal to `num_stages`.
    ValueError: When `len(expansion_rate)` is not equal to `num_stages`.
    ValueError: When `len(n_conv_per_stage)` is not equal to `num_stages`.
    ValueError: When `len(num_features_per_stage)` is not equal to `num_stages`.
"""
@@ -95,6 +96,9 @@ def __init__(
if isinstance(stem_kernel, int):
    stem_kernel = (stem_kernel,) * dim

if isinstance(kernels, int):
    kernels = (kernels,) * num_stages

if isinstance(strides, int):
    strides = (strides,) * num_stages

@@ -113,6 +117,9 @@ def __init__(
if not len(strides) == num_stages:
    raise ValueError(f"len(strides) must be equal to num_stages: {num_stages}")

if not len(expansion_rate) == num_stages:
    raise ValueError(f"len(expansion_rate) must be equal to num_stages: {num_stages}")

if not len(num_conv_per_stage) == num_stages:
    raise ValueError(f"len(num_conv_per_stage) must be equal to num_stages: {num_stages}")

@@ -250,10 +257,13 @@ def compute_pixels_in_output_feature_map(self, input_size: tuple[int, ...]) -> int:
"""Compute total number of pixels/voxels in the output feature map after convolutions.
Args:
    input_size: Size of the input image.
    input_size: Size of the input image. (H, W(, D))
Returns:
    Number of pixels/voxels in the output feature map after convolution.
Raises:
    ValueError: If length of `input_size` is not equal to `dim`.
"""
if not len(input_size) == len(self.strides[0]):
    raise ValueError(
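
The new handling of `kernels` and `expansion_rate` follows the normalize-then-validate pattern already used for `strides`: a scalar is broadcast to one value per stage, after which every per-stage argument must have exactly `num_stages` entries. A hedged sketch of that pattern in isolation (the helper name is illustrative, not part of the encoder):

def normalize_per_stage(value, num_stages: int, name: str) -> tuple:
    # Broadcast a single int to all stages, then validate the length.
    if isinstance(value, int):
        value = (value,) * num_stages
    if not len(value) == num_stages:
        raise ValueError(f"len({name}) must be equal to num_stages: {num_stages}")
    return tuple(value)

print(normalize_per_stage(3, num_stages=4, name="kernels"))  # (3, 3, 3, 3)
print(normalize_per_stage((4, 4, 4, 4), num_stages=4, name="expansion_rate"))  # passes unchanged
# normalize_per_stage((4, 4, 4), num_stages=4, name="expansion_rate") raises ValueError
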
ascent/models/components/encoders/unet_encoder.py (5 changes: 4 additions, 1 deletion)
@@ -198,10 +198,13 @@ def compute_pixels_in_output_feature_map(self, input_size: tuple[int, ...]) -> int:
"""Compute total number of pixels/voxels in the output feature map after convolutions.
Args:
    input_size: Size of the input image.
    input_size: Size of the input image. (H, W(, D))
Returns:
    Number of pixels/voxels in the output feature map after convolution.
Raises:
    ValueError: If length of `input_size` is not equal to `dim`.
"""
if not len(input_size) == len(self.strides[0]):
    raise ValueError(
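
Merging the removed and added lines above gives the encoder's final docstring, reassembled here for readability (blank lines between sections are assumed; the method body is unchanged):

def compute_pixels_in_output_feature_map(self, input_size: tuple[int, ...]) -> int:
    """Compute total number of pixels/voxels in the output feature map after convolutions.

    Args:
        input_size: Size of the input image. (H, W(, D))

    Returns:
        Number of pixels/voxels in the output feature map after convolution.

    Raises:
        ValueError: If length of `input_size` is not equal to `dim`.
    """
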
ascent/models/components/utils/blocks.py (112 changes: 34 additions, 78 deletions)
@@ -3,7 +3,6 @@
import numpy as np
import torch
from einops.layers.torch import Rearrange
from omegaconf.listconfig import ListConfig
from torch import Tensor, nn
from torchvision.ops import StochasticDepth

@@ -67,15 +66,15 @@ def __init__(
if not num_conv > 0:
    raise ValueError("`num_conv` must be strictly greater than 0!")

if isinstance(out_channels, (tuple, list, ListConfig)):
    if not len(out_channels) == num_conv:
        raise ValueError(
            f"`out_channels` {out_channels} must be an integer or a tuple/list of length "
            f"{num_conv}!"
        )
if isinstance(out_channels, int):
    out_channels = (out_channels,) * num_conv

if not len(out_channels) == num_conv:
    raise ValueError(
        f"`out_channels` {out_channels} must be an integer or a tuple/list of length "
        f"{num_conv}!"
    )

# Store stride, out_channels, and num_conv for computing total number of pixels/voxels in
# the output feature map
if isinstance(stride, int):
@@ -192,15 +191,15 @@ def __init__(
if not num_conv > 1:
    raise ValueError("`num_conv` must be greater than 1!")

if isinstance(out_channels, (tuple, list, ListConfig)):
    if not len(out_channels) == num_conv:
        raise ValueError(
            f"`out_channels` {out_channels} must be an integer or a tuple/list of length "
            f"{num_conv}!"
        )
if isinstance(out_channels, int):
    out_channels = (out_channels,) * num_conv

if not len(out_channels) == num_conv:
    raise ValueError(
        f"`out_channels` {out_channels} must be an integer or a tuple/list of length "
        f"{num_conv}!"
    )

# Store stride, out_channels, and num_convs for computing total number of pixels/voxels in
# the output feature map
if isinstance(stride, int):
@@ -281,10 +280,13 @@ def compute_pixels_in_output_feature_map(self, input_size: tuple[int, ...]) -> int:
"""Compute total number of pixels/voxels in the output feature map after convolutions.
Args:
    input_size: Size of the input image.
    input_size: Size of the input image. (H, W(, D))
Returns:
    Number of pixels/voxels in the output feature map after convolution.
Raises:
    ValueError: If length of `input_size` is not equal to `dim`.
"""
if not len(input_size) == len(self.stride):
    raise ValueError(
@@ -455,10 +457,13 @@ def compute_pixels_in_output_feature_map(self, input_size: tuple[int, ...]) -> int:
"""Compute total number of pixels/voxels in the output feature map after convolutions.
Args:
    input_size: Size of the input image.
    input_size: Size of the input image. (H, W(, D))
Returns:
    Number of pixels/voxels in the output feature map after convolution.
Raises:
    ValueError: If length of `input_size` is not equal to `dim`.
"""
if not len(input_size) == len(self.stride):
    raise ValueError(
@@ -528,15 +533,15 @@ def __init__(
if not num_conv > 0:
    raise ValueError("`num_conv` must be strictly greater than 0!")

if isinstance(out_channels, (tuple, list, ListConfig)):
    if not len(out_channels) == num_conv:
        raise ValueError(
            f"`out_channels` {out_channels} must be an integer or a tuple/list of length "
            f"{num_conv}!"
        )
if isinstance(out_channels, int):
    out_channels = (out_channels,) * num_conv

if not len(out_channels) == num_conv:
    raise ValueError(
        f"`out_channels` {out_channels} must be an integer or a tuple/list of length "
        f"{num_conv}!"
    )

# Store stride, in_channels, out_channels, and num_conv for computing total number of
# pixels/voxels in the output feature map
if isinstance(stride, int):
@@ -605,10 +610,13 @@ def compute_pixels_in_output_feature_map(self, input_size: tuple[int, ...]) -> int:
"""Compute total number of pixels/voxels in the output feature map after convolutions.
Args:
    input_size: Size of the input image.
    input_size: Size of the input image. (H, W(, D))
Returns:
    Number of pixels/voxels in the output feature map after convolution.
Raises:
    ValueError: If length of `input_size` is not equal to `dim`.
"""
if not len(input_size) == len(self.stride):
    raise ValueError(
@@ -669,69 +677,17 @@ def compute_pixels_in_output_feature_map(self, input_size: tuple[int, ...]) -> int:
"""Compute total number of pixels/voxels in the output feature map after convolutions.
Args:
    input_size: Size of the input image.
    input_size: Size of the input image. (H, W(, D))
Returns:
    Number of pixels/voxels in the output feature map after convolution.
Raises:
    ValueError: If length of `input_size` is not equal to `dim`.
"""
if not len(input_size) == len(self.stride):
    raise ValueError(
        "`input_size` should be (H, W(, D)) without channel or batch dimension!"
    )

return np.prod([self.out_channels, *input_size], dtype=np.int64)


if __name__ == "__main__":
    from torchinfo import summary

    kernels = 3
    strides = 2
    num_conv = 4
    # conv = ResidBlock(
    #     num_conv=num_conv,
    #     in_channels=1,
    #     out_channels=[16, 32, 64, 128],
    #     kernel_size=kernels,
    #     stride=strides,
    #     dim=2,
    #     norm_layer="instance",
    #     activation="leakyrelu",
    #     conv_kwargs=None,
    #     norm_kwargs=None,
    #     activation_kwargs={"inplace": True},
    #     conv_bias=True,
    #     drop_block=False,
    # )
    # conv = ConvNeXtBlock(
    #     in_channels=16,
    #     kernel_size=7,
    #     stride=1,
    #     dim=2,
    #     norm_layer="layer",
    #     activation="gelu",
    #     conv_bias=True,
    #     drop_block=False,
    #     stochastic_depth_p=0.0,
    #     layer_scale_init_value=1e-6,
    # )
    conv = OutputBlock(
        in_channels=16,
        out_channels=1,
        dim=2,
        conv_bias=True,
        conv_kwargs=None,
    )
    print(conv)
    dummy_input = torch.rand((2, 16, 640, 512))
    out = conv(dummy_input)
    print(conv.compute_pixels_in_output_feature_map((640, 512)))
    print(
        summary(
            conv,
            input_size=(2, 16, 640, 512),
            device="cpu",
            depth=4,
            col_names=("input_size", "output_size", "num_params"),
        )
    )
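
The `blocks.py` changes above are also what let the `omegaconf` import go: instead of length-checking `out_channels` only when it is a tuple, list, or `ListConfig`, the constructors now broadcast ints first and then length-check whatever sequence remains, so any sequence type (including an OmegaConf `ListConfig`) is accepted without importing it. A small standalone sketch of the new behaviour (the helper is illustrative, not a function in the module):

def normalize_out_channels(out_channels, num_conv: int):
    # Post-commit behaviour: broadcast ints, then length-check everything else.
    if isinstance(out_channels, int):
        out_channels = (out_channels,) * num_conv
    if not len(out_channels) == num_conv:
        raise ValueError(
            f"`out_channels` {out_channels} must be an integer or a tuple/list of length "
            f"{num_conv}!"
        )
    return out_channels

print(normalize_out_channels(32, num_conv=3))            # (32, 32, 32)
print(normalize_out_channels([16, 32, 64], num_conv=3))  # any length-3 sequence passes
# normalize_out_channels([16, 32], num_conv=3) raises ValueError
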
ascent/models/components/utils/layers.py (15 changes: 10 additions, 5 deletions)
@@ -2,7 +2,6 @@

import numpy as np
from monai.utils import has_option
from omegaconf.listconfig import ListConfig
from torch import Tensor, nn

from ascent.models.components.utils.normalization import LayerNorm
@@ -334,7 +333,7 @@ def __init__(
if drop_kwargs is None:
    drop_kwargs = {}

if not isinstance(kernel_size, (tuple, list, ListConfig)):
if isinstance(kernel_size, int):
    kernel_size = (kernel_size,) * dim

if not len(kernel_size) == dim:
@@ -343,7 +342,7 @@
        f"`kernel_size` {kernel_size} must be an integer or a tuple/list of length "
        f"{dim}!"
    )

if not isinstance(stride, (tuple, list, ListConfig)):
if isinstance(stride, int):
    stride = (stride,) * dim
if not len(stride) == dim:
@@ -383,10 +382,13 @@ def compute_pixels_in_output_feature_map(self, input_size: tuple[int, ...]) -> int:
"""Compute total number of pixels/voxels in the output feature map after convolution.
Args:
    input_size: Size of the input image.
    input_size: Size of the input image. (H, W(, D))
Returns:
    Number of pixels/voxels in the output feature map after convolution.
Raises:
    ValueError: If length of `input_size` is not equal to `dim`.
"""
if not len(input_size) == len(self.stride):
    raise ValueError(
@@ -462,10 +464,13 @@ def compute_pixels_in_output_feature_map(self, input_size: tuple[int, ...]) -> int:
"""Compute total number of pixels/voxels in the output feature map after convolution.
Args:
    input_size: Size of the input image.
    input_size: Size of the input image. (H, W(, D))
Returns:
    Number of pixels/voxels in the output feature map after convolution.
Raises:
    ValueError: If length of `input_size` is not equal to `dim`.
"""
if not len(input_size) == len(self.stride):
    raise ValueError(
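
In `layers.py` the check is inverted the same way for `kernel_size` and `stride`: previously anything that was not a tuple/list/`ListConfig` was silently broadcast across the spatial dimensions; now only ints are broadcast, and every other value falls through to the existing `dim`-length validation. A hedged sketch of the resulting behaviour (illustrative helper, not the layer's real constructor):

def normalize_spatial_arg(value, dim: int, name: str) -> tuple:
    # Post-commit behaviour: only ints are broadcast to every spatial dimension.
    if isinstance(value, int):
        value = (value,) * dim
    if not len(value) == dim:
        raise ValueError(f"`{name}` {value} must be an int or a tuple/list of length {dim}!")
    return tuple(value)

print(normalize_spatial_arg(3, dim=2, name="kernel_size"))          # (3, 3)
print(normalize_spatial_arg((3, 1, 1), dim=3, name="kernel_size"))  # (3, 1, 1)
# A float such as 3.0 would previously have been broadcast; here it reaches len() and raises a TypeError.
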
