Skip to content

Commit

Permalink
Rename UNeXt2 (#84)
Browse files Browse the repository at this point in the history
* rename file

* rename the architecture

* fix merge
  • Loading branch information
ziw-liu authored and edyoshikun committed Jun 18, 2024
1 parent 75db771 commit 6282f0d
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 25 deletions.
6 changes: 3 additions & 3 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ class HCSDataModule(LightningDataModule):
by default 0.8
:param int batch_size: batch size, defaults to 16
:param int num_workers: number of data-loading workers, defaults to 8
:param Literal["2D", "2.1D", "2.2D", "2.5D", "3D"] architecture: U-Net architecture,
:param Literal["2D", "UNeXt2", "2.5D", "3D"] architecture: U-Net architecture,
defaults to "2.5D"
:param tuple[int, int] yx_patch_size: patch size in (Y, X),
defaults to (256, 256)
Expand All @@ -288,7 +288,7 @@ def __init__(
split_ratio: float = 0.8,
batch_size: int = 16,
num_workers: int = 8,
architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D",
architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"] = "2.5D",
yx_patch_size: tuple[int, int] = (256, 256),
normalizations: list[MapTransform] = [],
augmentations: list[MapTransform] = [],
Expand All @@ -302,7 +302,7 @@ def __init__(
self.target_channel = _ensure_channel_list(target_channel)
self.batch_size = batch_size
self.num_workers = num_workers
self.target_2d = False if architecture in ["2.2D", "3D", "fcmae"] else True
self.target_2d = False if architecture in ["UNeXt2", "3D", "fcmae"] else True
self.z_window_size = z_window_size
self.split_ratio = split_ratio
self.yx_patch_size = yx_patch_size
Expand Down
6 changes: 3 additions & 3 deletions viscy/light/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d
from viscy.unet.networks.fcmae import FullyConvolutionalMAE
from viscy.unet.networks.Unet2D import Unet2d
from viscy.unet.networks.Unet22D import Unet22d
from viscy.unet.networks.Unet25D import Unet25d
from viscy.unet.networks.unext2 import UNeXt2


try:
Expand All @@ -41,7 +41,7 @@

_UNET_ARCHITECTURE = {
"2D": Unet2d,
"2.2D": Unet22d,
"UNeXt2": UNeXt2,
"2.5D": Unet25d,
"fcmae": FullyConvolutionalMAE,
}
Expand Down Expand Up @@ -122,7 +122,7 @@ class VSUNet(LightningModule):

def __init__(
self,
architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"],
architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"],
model_config: dict = {},
loss_function: Union[nn.Module, MixedLoss] = None,
lr: float = 1e-3,
Expand Down
2 changes: 1 addition & 1 deletion viscy/scripts/count_flops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# %%
model = VSUNet(
architecture="2.2D",
architecture="UNeXt2",
model_config={
"in_channels": 1,
"out_channels": 2,
Expand Down
13 changes: 6 additions & 7 deletions viscy/scripts/network_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
# %%
# 3D->2D
model = VSUNet(
architecture="2.2D",
architecture="UNext2",
model_config={
"in_channels": 2,
"out_channels": 3,
Expand All @@ -61,17 +61,17 @@
model_graph = draw_graph(
model,
model.example_input_array,
graph_name="2.2D UNet",
graph_name="UNext2",
roll=True,
depth=3,
)

model_graph.visual_graph

# %%
# 2.1D UNet with upsampling in Z.
# 3D->3D
model = VSUNet(
architecture="2.2D",
architecture="UNext2",
model_config={
"in_channels": 1,
"out_channels": 2,
Expand All @@ -85,13 +85,12 @@
model_graph = draw_graph(
model,
model.example_input_array,
graph_name="2.2D UNet",
graph_name="UNext2",
roll=True,
depth=3,
)

graph22d = model_graph.visual_graph
graph22d
model_graph.visual_graph
# %% If you want to save the graphs as SVG files:
# model_graph.visual_graph.render(format="svg")

Expand Down
2 changes: 1 addition & 1 deletion viscy/scripts/visualize_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
# load model
model = VSUNet.load_from_checkpoint(
"model.ckpt",
architecture="2.2D",
architecture="UNeXt2",
model_config={
"in_channels": 1,
"out_channels": 2,
Expand Down
4 changes: 2 additions & 2 deletions viscy/unet/networks/fcmae.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from torch import BoolTensor, Size, Tensor, nn

from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder
from viscy.unet.networks.unext2 import PixelToVoxelHead, UNeXt2Decoder


def _init_weights(module: nn.Module) -> None:
Expand Down Expand Up @@ -430,7 +430,7 @@ def __init__(
decoder_channels[-1] = (
out_channels * in_stack_depth * stem_kernel_size[-1] ** 2
)
self.decoder = Unet2dDecoder(
self.decoder = UNeXt2Decoder(
decoder_channels,
norm_name="instance",
mode="pixelshuffle",
Expand Down
16 changes: 8 additions & 8 deletions viscy/unet/networks/Unet22D.py → viscy/unet/networks/unext2.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def _get_convnext_stage(
return stage


class Conv22dStem(nn.Module):
"""Stem for 2.1D networks."""
class UNeXt2Stem(nn.Module):
"""Stem for UNeXt2 networks."""

def __init__(
self,
Expand All @@ -91,7 +91,7 @@ def forward(self, x: Tensor):
return x.reshape(b, c * d, h, w)


class Unet2dUpStage(nn.Module):
class UNeXt2UpStage(nn.Module):
def __init__(
self,
in_channels: int,
Expand Down Expand Up @@ -214,7 +214,7 @@ def forward(self, x: Tensor) -> Tensor:
return x


class Unet2dDecoder(nn.Module):
class UNeXt2Decoder(nn.Module):
def __init__(
self,
num_channels: list[int],
Expand All @@ -228,7 +228,7 @@ def __init__(
self.decoder_stages = nn.ModuleList([])
stages = len(num_channels) - 1
for i in range(stages):
stage = Unet2dUpStage(
stage = UNeXt2UpStage(
in_channels=num_channels[i],
skip_channels=num_channels[i] // 2,
out_channels=num_channels[i + 1],
Expand All @@ -249,7 +249,7 @@ def forward(self, features: Sequence[Tensor]) -> Tensor:
return feat


class Unet22d(nn.Module):
class UNeXt2(nn.Module):
def __init__(
self,
in_channels: int = 1,
Expand Down Expand Up @@ -285,15 +285,15 @@ def __init__(
# replace first convolution layer with a projection tokenizer
multi_scale_encoder.stem_0 = nn.Identity()
self.encoder_stages = multi_scale_encoder
self.stem = Conv22dStem(
self.stem = UNeXt2Stem(
in_channels, num_channels[0], stem_kernel_size, in_stack_depth
)
decoder_channels = num_channels
decoder_channels.reverse()
decoder_channels[-1] = (
(out_stack_depth + 2) * out_channels * 2**2 * head_expansion_ratio
)
self.decoder = Unet2dDecoder(
self.decoder = UNeXt2Decoder(
decoder_channels,
norm_name=decoder_norm_layer,
mode=decoder_mode,
Expand Down

0 comments on commit 6282f0d

Please sign in to comment.