diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 77bcc1ed..ffc827ed 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -263,7 +263,7 @@ class HCSDataModule(LightningDataModule): by default 0.8 :param int batch_size: batch size, defaults to 16 :param int num_workers: number of data-loading workers, defaults to 8 - :param Literal["2D", "2.1D", "2.2D", "2.5D", "3D"] architecture: U-Net architecture, + :param Literal["2D", "UNeXt2", "2.5D", "3D"] architecture: U-Net architecture, defaults to "2.5D" :param tuple[int, int] yx_patch_size: patch size in (Y, X), defaults to (256, 256) @@ -288,7 +288,7 @@ def __init__( split_ratio: float = 0.8, batch_size: int = 16, num_workers: int = 8, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D", + architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "2.5D", yx_patch_size: tuple[int, int] = (256, 256), normalizations: list[MapTransform] = [], augmentations: list[MapTransform] = [], @@ -301,7 +301,7 @@ def __init__( self.target_channel = _ensure_channel_list(target_channel) self.batch_size = batch_size self.num_workers = num_workers - self.target_2d = False if architecture in ["2.2D", "3D", "fcmae"] else True + self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True self.z_window_size = z_window_size self.split_ratio = split_ratio self.yx_patch_size = yx_patch_size diff --git a/viscy/light/engine.py b/viscy/light/engine.py index ac15c208..33da3552 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -29,8 +29,8 @@ from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d from viscy.unet.networks.fcmae import FullyConvolutionalMAE from viscy.unet.networks.Unet2D import Unet2d -from viscy.unet.networks.Unet22D import Unet22d from viscy.unet.networks.Unet25D import Unet25d +from viscy.unet.networks.unext2 import UNeXt2 try: from cellpose.models import CellposeModel @@ -40,7 +40,7 @@ _UNET_ARCHITECTURE = { "2D": Unet2d, - "2.2D": Unet22d, + "UNeXt2": UNeXt2, "2.5D": Unet25d, "fcmae": FullyConvolutionalMAE, } @@ -117,7 +117,7 @@ class VSUNet(LightningModule): def __init__( self, - architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"], + architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"], model_config: dict = {}, loss_function: Union[nn.Module, MixedLoss] = None, lr: float = 1e-3, diff --git a/viscy/scripts/count_flops.py b/viscy/scripts/count_flops.py index 352da9c7..206d744f 100644 --- a/viscy/scripts/count_flops.py +++ b/viscy/scripts/count_flops.py @@ -6,7 +6,7 @@ # %% model = VSUNet( - architecture="2.2D", + architecture="UNeXt2", model_config={ "in_channels": 1, "out_channels": 2, diff --git a/viscy/scripts/network_diagram.py b/viscy/scripts/network_diagram.py index bcc1714f..c99f97c8 100644 --- a/viscy/scripts/network_diagram.py +++ b/viscy/scripts/network_diagram.py @@ -46,7 +46,7 @@ # %% # 3D->2D model = VSUNet( - architecture="2.2D", + architecture="UNext2", model_config={ "in_channels": 2, "out_channels": 3, @@ -61,7 +61,7 @@ model_graph = draw_graph( model, model.example_input_array, - graph_name="2.2D UNet", + graph_name="UNext2", roll=True, depth=3, ) @@ -69,9 +69,9 @@ model_graph.visual_graph # %% -# 2.1D UNet with upsampling in Z. +# 3D->3D model = VSUNet( - architecture="2.2D", + architecture="UNext2", model_config={ "in_channels": 1, "out_channels": 2, @@ -85,13 +85,12 @@ model_graph = draw_graph( model, model.example_input_array, - graph_name="2.2D UNet", + graph_name="UNext2", roll=True, depth=3, ) -graph22d = model_graph.visual_graph -graph22d +model_graph.visual_graph # %% If you want to save the graphs as SVG files: # model_graph.visual_graph.render(format="svg") diff --git a/viscy/scripts/visualize_features.py b/viscy/scripts/visualize_features.py index 1de5033d..c331bf46 100644 --- a/viscy/scripts/visualize_features.py +++ b/viscy/scripts/visualize_features.py @@ -38,7 +38,7 @@ # load model model = VSUNet.load_from_checkpoint( "model.ckpt", - architecture="2.2D", + architecture="UNeXt2", model_config={ "in_channels": 1, "out_channels": 2, diff --git a/viscy/unet/networks/fcmae.py b/viscy/unet/networks/fcmae.py index 6c2f6f45..d63b65a7 100644 --- a/viscy/unet/networks/fcmae.py +++ b/viscy/unet/networks/fcmae.py @@ -20,7 +20,7 @@ ) from torch import BoolTensor, Size, Tensor, nn -from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder +from viscy.unet.networks.unext2 import PixelToVoxelHead, UNeXt2Decoder def _init_weights(module: nn.Module) -> None: @@ -430,7 +430,7 @@ def __init__( decoder_channels[-1] = ( out_channels * in_stack_depth * stem_kernel_size[-1] ** 2 ) - self.decoder = Unet2dDecoder( + self.decoder = UNeXt2Decoder( decoder_channels, norm_name="instance", mode="pixelshuffle", diff --git a/viscy/unet/networks/Unet22D.py b/viscy/unet/networks/unext2.py similarity index 97% rename from viscy/unet/networks/Unet22D.py rename to viscy/unet/networks/unext2.py index a215ff0e..a695c06a 100644 --- a/viscy/unet/networks/Unet22D.py +++ b/viscy/unet/networks/unext2.py @@ -64,8 +64,8 @@ def _get_convnext_stage( return stage -class Conv22dStem(nn.Module): - """Stem for 2.1D networks.""" +class UNeXt2Stem(nn.Module): + """Stem for UNeXt2 networks.""" def __init__( self, @@ -91,7 +91,7 @@ def forward(self, x: Tensor): return x.reshape(b, c * d, h, w) -class Unet2dUpStage(nn.Module): +class UNeXt2UpStage(nn.Module): def __init__( self, in_channels: int, @@ -214,7 +214,7 @@ def forward(self, x: Tensor) -> Tensor: return x -class Unet2dDecoder(nn.Module): +class UNeXt2Decoder(nn.Module): def __init__( self, num_channels: list[int], @@ -228,7 +228,7 @@ def __init__( self.decoder_stages = nn.ModuleList([]) stages = len(num_channels) - 1 for i in range(stages): - stage = Unet2dUpStage( + stage = UNeXt2UpStage( in_channels=num_channels[i], skip_channels=num_channels[i] // 2, out_channels=num_channels[i + 1], @@ -249,7 +249,7 @@ def forward(self, features: Sequence[Tensor]) -> Tensor: return feat -class Unet22d(nn.Module): +class UNeXt2(nn.Module): def __init__( self, in_channels: int = 1, @@ -285,7 +285,7 @@ def __init__( # replace first convolution layer with a projection tokenizer multi_scale_encoder.stem_0 = nn.Identity() self.encoder_stages = multi_scale_encoder - self.stem = Conv22dStem( + self.stem = UNeXt2Stem( in_channels, num_channels[0], stem_kernel_size, in_stack_depth ) decoder_channels = num_channels @@ -293,7 +293,7 @@ def __init__( decoder_channels[-1] = ( (out_stack_depth + 2) * out_channels * 2**2 * head_expansion_ratio ) - self.decoder = Unet2dDecoder( + self.decoder = UNeXt2Decoder( decoder_channels, norm_name=decoder_norm_layer, mode=decoder_mode,