From aac5e9336c3fc14febd9de626c2e623c4afa7e05 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 5 Jun 2024 21:22:32 +0800 Subject: [PATCH 1/3] rename file --- viscy/light/engine.py | 2 +- viscy/unet/networks/fcmae.py | 2 +- viscy/unet/networks/{Unet22D.py => unext2.py} | 0 3 files changed, 2 insertions(+), 2 deletions(-) rename viscy/unet/networks/{Unet22D.py => unext2.py} (100%) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 08b85319..56fac136 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -29,7 +29,7 @@ from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d from viscy.unet.networks.fcmae import FullyConvolutionalMAE from viscy.unet.networks.Unet2D import Unet2d -from viscy.unet.networks.Unet22D import Unet22d +from viscy.unet.networks.unext2 import Unet22d from viscy.unet.networks.Unet25D import Unet25d try: diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 51345906..14d7d876 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -18,7 +18,7 @@ ) from torch import BoolTensor, Size, Tensor, nn -from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead +from viscy.unet.networks.unext2 import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead def _init_weights(module: nn.Module) -> None: diff --git a/viscy/unet/networks/Unet22D.py b/viscy/unet/networks/unext2.py similarity index 100% rename from viscy/unet/networks/Unet22D.py rename to viscy/unet/networks/unext2.py From a3bb561a3a34750f7cb8f4b521cdda4ba9f459ce Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Wed, 5 Jun 2024 21:32:09 +0800 Subject: [PATCH 2/3] rename the architecture --- viscy/data/hcs.py | 6 +++--- viscy/light/engine.py | 6 +++--- viscy/scripts/count_flops.py | 2 +- viscy/scripts/network_diagram.py | 13 ++++++------- viscy/scripts/visualize_features.py | 2 +- viscy/unet/networks/fcmae.py | 4 ++-- viscy/unet/networks/unext2.py | 16 ++++++++-------- 7 files changed, 24 insertions(+), 25 deletions(-) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index f33b6121..60a2dac2 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -265,7 +265,7 @@ class HCSDataModule(LightningDataModule): by default 0.8 :param int batch_size: batch size, defaults to 16 :param int num_workers: number of data-loading workers, defaults to 8 - :param Literal["2D", "2.1D", "2.2D", "2.5D", "3D"] architecture: U-Net architecture, + :param Literal["2D", "UNeXt2", "2.5D", "3D"] architecture: U-Net architecture, defaults to "2.5D" :param tuple[int, int] yx_patch_size: patch size in (Y, X), defaults to (256, 256) @@ -290,7 +290,7 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", + architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], @@ -303,7 +303,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = False if architecture in ["2.2D", "3D", "fcmae"] else True + self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 56fac136..ecbb6753 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -29,8 +29,8 @@ from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d from viscy.unet.networks.fcmae import FullyConvolutionalMAE from viscy.unet.networks.Unet2D import Unet2d -from viscy.unet.networks.unext2 import Unet22d from viscy.unet.networks.Unet25D import Unet25d +from viscy.unet.networks.unext2 import UNeXt2 try: from cellpose.models import CellposeModel @@ -40,7 +40,7 @@ _UNET_ARCHITECTURE = { "2D": Unet2d, - "2.2D": Unet22d, + "UNeXt2": UNeXt2, "2.5D": Unet25d, "fcmae": FullyConvolutionalMAE, } @@ -117,7 +117,7 @@ class VSUNet(LightningModule): def __init__( self, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"], + architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"], model_config: dict = {}, loss_function: Union[nn.Module, MixedLoss] = None, lr: float = 1e-3, diff --git a/viscy/scripts/count_flops.py b/viscy/scripts/count_flops.py index 352da9c7..206d744f 100644 --- a/viscy/scripts/count_flops.py +++ b/viscy/scripts/count_flops.py @@ -6,7 +6,7 @@ # %% model = VSUNet( - architecture="2.2D", + architecture="UNeXt2", model_config={ "in_channels": 1, "out_channels": 2, diff --git a/viscy/scripts/network_diagram.py b/viscy/scripts/network_diagram.py index dc436cdf..6c6a8026 100644 --- a/viscy/scripts/network_diagram.py +++ b/viscy/scripts/network_diagram.py @@ -46,7 +46,7 @@ # %% # 3D->2D model = VSUNet( - architecture="2.2D", + architecture="UNext2", model_config={ "in_channels": 2, "out_channels": 3, @@ -61,7 +61,7 @@ model_graph = draw_graph( model, model.example_input_array, - graph_name="2.2D UNet", + graph_name="UNext2", roll=True, depth=3, ) @@ -69,9 +69,9 @@ model_graph.visual_graph # %% -# 2.1D UNet with upsampling in Z. +# 3D->3D model = VSUNet( - architecture="2.2D", + architecture="UNext2", model_config={ "in_channels": 1, "out_channels": 2, @@ -85,12 +85,11 @@ model_graph = draw_graph( model, model.example_input_array, - graph_name="2.2D UNet", + graph_name="UNext2", roll=True, depth=3, ) -graph22d = model_graph.visual_graph -graph22d +model_graph.visual_graph # %% If you want to save the graphs as SVG files: # model_graph.visual_graph.render(format="svg") diff --git a/viscy/scripts/visualize_features.py b/viscy/scripts/visualize_features.py index 1de5033d..c331bf46 100644 --- a/viscy/scripts/visualize_features.py +++ b/viscy/scripts/visualize_features.py @@ -38,7 +38,7 @@ # load model model = VSUNet.load_from_checkpoint( "model.ckpt", - architecture="2.2D", + architecture="UNeXt2", model_config={ "in_channels": 1, "out_channels": 2, diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 14d7d876..954843b5 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -18,7 +18,7 @@ ) from torch import BoolTensor, Size, Tensor, nn -from viscy.unet.networks.unext2 import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead +from viscy.unet.networks.unext2 import PixelToVoxelHead, UNeXt2Decoder, UnsqueezeHead def _init_weights(module: nn.Module) -> None: @@ -390,7 +390,7 @@ def __init__( decoder_channels[-1] = ( (in_stack_depth + 2) * in_channels * 2**2 * head_expansion_ratio ) - self.decoder = Unet2dDecoder( + self.decoder = UNeXt2Decoder( decoder_channels, norm_name="instance", mode="pixelshuffle", diff --git a/viscy/unet/networks/unext2.py b/viscy/unet/networks/unext2.py index a215ff0e..a695c06a 100644 --- a/viscy/unet/networks/unext2.py +++ b/viscy/unet/networks/unext2.py @@ -64,8 +64,8 @@ def _get_convnext_stage( return stage -class Conv22dStem(nn.Module): - """Stem for 2.1D networks.""" +class UNeXt2Stem(nn.Module): + """Stem for UNeXt2 networks.""" def __init__( self, @@ -91,7 +91,7 @@ def forward(self, x: Tensor): return x.reshape(b, c * d, h, w) -class Unet2dUpStage(nn.Module): +class UNeXt2UpStage(nn.Module): def __init__( self, in_channels: int, @@ -214,7 +214,7 @@ def forward(self, x: Tensor) -> Tensor: return x -class Unet2dDecoder(nn.Module): +class UNeXt2Decoder(nn.Module): def __init__( self, num_channels: list[int], @@ -228,7 +228,7 @@ def __init__( self.decoder_stages = nn.ModuleList([]) stages = len(num_channels) - 1 for i in range(stages): - stage = Unet2dUpStage( + stage = UNeXt2UpStage( in_channels=num_channels[i], skip_channels=num_channels[i] // 2, out_channels=num_channels[i + 1], @@ -249,7 +249,7 @@ def forward(self, features: Sequence[Tensor]) -> Tensor: return feat -class Unet22d(nn.Module): +class UNeXt2(nn.Module): def __init__( self, in_channels: int = 1, @@ -285,7 +285,7 @@ def __init__( # replace first convolution layer with a projection tokenizer multi_scale_encoder.stem_0 = nn.Identity() self.encoder_stages = multi_scale_encoder - self.stem = Conv22dStem( + self.stem = UNeXt2Stem( in_channels, num_channels[0], stem_kernel_size, in_stack_depth ) decoder_channels = num_channels @@ -293,7 +293,7 @@ def __init__( decoder_channels[-1] = ( (out_stack_depth + 2) * out_channels * 2**2 * head_expansion_ratio ) - self.decoder = Unet2dDecoder( + self.decoder = UNeXt2Decoder( decoder_channels, norm_name=decoder_norm_layer, mode=decoder_mode, From 3fafff0f145e12f408def447214ce43abc1bc0de Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Tue, 11 Jun 2024 12:12:10 -0700 Subject: [PATCH 3/3] fix merge --- viscy/unet/networks/fcmae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 752ae32c..d63b65a7 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -20,7 +20,7 @@ ) from torch import BoolTensor, Size, Tensor, nn -from viscy.unet.networks.unext2 import PixelToVoxelHead, UNeXt2Decoder, UnsqueezeHead +from viscy.unet.networks.unext2 import PixelToVoxelHead, UNeXt2Decoder def _init_weights(module: nn.Module) -> None: