diff --git a/viscy/light/engine.py b/viscy/light/engine.py
index f3886147..08b85319 100644
--- a/viscy/light/engine.py
+++ b/viscy/light/engine.py
@@ -29,7 +29,7 @@
 from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d
 from viscy.unet.networks.fcmae import FullyConvolutionalMAE
 from viscy.unet.networks.Unet2D import Unet2d
-from viscy.unet.networks.Unet21D import Unet21d
+from viscy.unet.networks.Unet22D import Unet22d
 from viscy.unet.networks.Unet25D import Unet25d
 
 try:
@@ -40,9 +40,7 @@
 
 _UNET_ARCHITECTURE = {
     "2D": Unet2d,
-    "2.1D": Unet21d,
-    # same class with out_stack_depth > 1
-    "2.2D": Unet21d,
+    "2.2D": Unet22d,
     "2.5D": Unet25d,
     "fcmae": FullyConvolutionalMAE,
 }
@@ -139,8 +137,6 @@ def __init__(
             raise ValueError(
                 f"Architecture {architecture} not in {_UNET_ARCHITECTURE.keys()}"
            )
-        if architecture == "2.2D":
-            model_config["out_stack_depth"] = model_config["in_stack_depth"]
         self.model = net_class(**model_config)
         # TODO: handle num_outputs in metrics
         # self.out_channels = self.model.terminal_block.out_filters
diff --git a/viscy/scripts/network_diagram.py b/viscy/scripts/network_diagram.py
index 419d69a4..dc436cdf 100644
--- a/viscy/scripts/network_diagram.py
+++ b/viscy/scripts/network_diagram.py
@@ -44,29 +44,30 @@
 graph25d
 
 # %%
-# 2.1D UNet without upsampling in Z.
+# 3D->2D
 model = VSUNet(
-    architecture="2.1D",
+    architecture="2.2D",
     model_config={
         "in_channels": 2,
-        "out_channels": 1,
-        "in_stack_depth": 9,
+        "out_channels": 3,
+        "in_stack_depth": 5,
+        "out_stack_depth": 1,
         "backbone": "convnextv2_tiny",
-        "stem_kernel_size": (3, 1, 1),
         "decoder_mode": "pixelshuffle",
+        "stem_kernel_size": (5, 4, 4),
     },
 )
 
 model_graph = draw_graph(
     model,
     model.example_input_array,
-    graph_name="2.1D UNet",
+    graph_name="2.2D UNet",
     roll=True,
     depth=3,
 )
 
-graph21d = model_graph.visual_graph
-graph21d
+model_graph.visual_graph
+
 # %%
 # 2.1D UNet with upsampling in Z.
 model = VSUNet(
diff --git a/viscy/unet/networks/Unet21D.py b/viscy/unet/networks/Unet22D.py
similarity index 90%
rename from viscy/unet/networks/Unet21D.py
rename to viscy/unet/networks/Unet22D.py
index c4320240..a215ff0e 100644
--- a/viscy/unet/networks/Unet21D.py
+++ b/viscy/unet/networks/Unet22D.py
@@ -64,7 +64,7 @@ def _get_convnext_stage(
     return stage
 
 
-class Conv21dStem(nn.Module):
+class Conv22dStem(nn.Module):
     """Stem for 2.1D networks."""
 
     def __init__(
@@ -249,13 +249,13 @@ def forward(self, features: Sequence[Tensor]) -> Tensor:
         return feat
 
 
-class Unet21d(nn.Module):
+class Unet22d(nn.Module):
     def __init__(
         self,
         in_channels: int = 1,
         out_channels: int = 1,
         in_stack_depth: int = 5,
-        out_stack_depth: int = 1,
+        out_stack_depth: int = None,
         backbone: str = "convnextv2_tiny",
         pretrained: bool = False,
         stem_kernel_size: tuple[int, int, int] = (5, 4, 4),
@@ -273,18 +273,8 @@ def __init__(
                 f"Input stack depth {in_stack_depth} is not divisible "
                 f"by stem kernel depth {stem_kernel_size[0]}."
             )
-        if not (in_stack_depth == out_stack_depth or out_stack_depth == 1):
-            raise ValueError(
-                "`out_stack_depth` must be either 1 or "
-                f"the same as `input_stack_depth` ({in_stack_depth}), "
-                f"but got {out_stack_depth}."
-            )
-        if not (in_stack_depth == out_stack_depth or out_stack_depth == 1):
-            raise ValueError(
-                "`out_stack_depth` must be either 1 or "
-                f"the same as `input_stack_depth` ({in_stack_depth}), "
-                f"but got {out_stack_depth}."
-            )
+        if out_stack_depth is None:
+            out_stack_depth = in_stack_depth
         multi_scale_encoder = timm.create_model(
             backbone,
             pretrained=pretrained,
@@ -295,7 +285,7 @@ def __init__(
         # replace first convolution layer with a projection tokenizer
         multi_scale_encoder.stem_0 = nn.Identity()
         self.encoder_stages = multi_scale_encoder
-        self.stem = Conv21dStem(
+        self.stem = Conv22dStem(
             in_channels, num_channels[0], stem_kernel_size, in_stack_depth
         )
         decoder_channels = num_channels
@@ -311,16 +301,13 @@ def __init__(
             strides=[2] * (len(num_channels) - 1) + [stem_kernel_size[-1]],
             upsample_pre_conv="default" if decoder_upsample_pre_conv else None,
         )
-        if out_stack_depth == 1:
-            self.head = UnsqueezeHead()
-        else:
-            self.head = PixelToVoxelHead(
-                decoder_channels[-1],
-                out_channels,
-                out_stack_depth,
-                head_expansion_ratio,
-                pool=head_pool,
-            )
+        self.head = PixelToVoxelHead(
+            decoder_channels[-1],
+            out_channels,
+            out_stack_depth,
+            head_expansion_ratio,
+            pool=head_pool,
+        )
         self.out_stack_depth = out_stack_depth
 
     @property
diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py
index 97771365..51345906 100644
--- a/viscy/unet/networks/fcmae.py
+++ b/viscy/unet/networks/fcmae.py
@@ -18,7 +18,7 @@
 )
 from torch import BoolTensor, Size, Tensor, nn
 
-from viscy.unet.networks.Unet21D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead
+from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead
 
 
 def _init_weights(module: nn.Module) -> None:
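
A minimal usage sketch of the renamed architecture, mirroring the VSUNet configuration in viscy/scripts/network_diagram.py above; with this change, "out_stack_depth" defaults to "in_stack_depth" when omitted, so the key is only needed for 3D->2D prediction:

    from viscy.light.engine import VSUNet

    # 3D -> 2D virtual staining: 5-slice input, single-slice output.
    model = VSUNet(
        architecture="2.2D",
        model_config={
            "in_channels": 2,
            "out_channels": 3,
            "in_stack_depth": 5,
            "out_stack_depth": 1,  # omit to predict the full input depth (3D -> 3D)
            "backbone": "convnextv2_tiny",
            "decoder_mode": "pixelshuffle",
            "stem_kernel_size": (5, 4, 4),
        },
    )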