Fix 3D to 2D prediction with UNeXt2 model (#80)
* shuffle head

* use shuffle head for 2.1D

* use the same head for 2d and 3d output

* update script

* rename 2.1D to 2.2D

* fix reference

* remove unused module
ziw-liu authored and edyoshikun committed Jun 12, 2024
1 parent f954554 commit eeb8ce2
Showing 4 changed files with 25 additions and 41 deletions.
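For orientation, here is a minimal sketch (not part of the commit) of the 3D-to-2D configuration this fix enables; it mirrors the updated network_diagram.py script shown below, and the specific channel counts and kernel sizes are illustrative rather than prescribed.

from viscy.light.engine import VSUNet

# Before this fix, the engine overwrote out_stack_depth with in_stack_depth for
# the "2.2D" architecture, so a model predicting a single 2D slice from a 3D
# stack could not be configured. With the fix, the requested output depth is
# honored.
model = VSUNet(
    architecture="2.2D",
    model_config={
        "in_channels": 2,
        "out_channels": 3,
        "in_stack_depth": 5,
        "out_stack_depth": 1,  # 3D input stack, 2D prediction
        "backbone": "convnextv2_tiny",
        "decoder_mode": "pixelshuffle",
        "stem_kernel_size": (5, 4, 4),
    },
)
assert model.model.out_stack_depth == 1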
8 changes: 2 additions & 6 deletions viscy/light/engine.py
@@ -29,7 +29,7 @@
 from viscy.evaluation.evaluation_metrics import mean_average_precision, ms_ssim_25d
 from viscy.unet.networks.fcmae import FullyConvolutionalMAE
 from viscy.unet.networks.Unet2D import Unet2d
-from viscy.unet.networks.Unet21D import Unet21d
+from viscy.unet.networks.Unet22D import Unet22d
 from viscy.unet.networks.Unet25D import Unet25d

 try:
@@ -40,9 +40,7 @@

 _UNET_ARCHITECTURE = {
     "2D": Unet2d,
-    "2.1D": Unet21d,
-    # same class with out_stack_depth > 1
-    "2.2D": Unet21d,
+    "2.2D": Unet22d,
     "2.5D": Unet25d,
     "fcmae": FullyConvolutionalMAE,
 }
@@ -139,8 +137,6 @@ def __init__(
             raise ValueError(
                 f"Architecture {architecture} not in {_UNET_ARCHITECTURE.keys()}"
             )
-        if architecture == "2.2D":
-            model_config["out_stack_depth"] = model_config["in_stack_depth"]
         self.model = net_class(**model_config)
         # TODO: handle num_outputs in metrics
         # self.out_channels = self.model.terminal_block.out_filters
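In effect (a sketch, assuming `_UNET_ARCHITECTURE` stays importable from `viscy.light.engine`): the "2.2D" key now resolves directly to `Unet22d`, and `model_config` is forwarded to the network verbatim, without the removed `out_stack_depth` override.

from viscy.light.engine import _UNET_ARCHITECTURE
from viscy.unet.networks.Unet22D import Unet22d

# "2.2D" maps straight to the renamed network class; the engine no longer
# rewrites model_config["out_stack_depth"] before net_class(**model_config).
assert _UNET_ARCHITECTURE["2.2D"] is Unet22d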
17 changes: 9 additions & 8 deletions viscy/scripts/network_diagram.py
@@ -44,29 +44,30 @@
 graph25d

 # %%
-# 2.1D UNet without upsampling in Z.
+# 3D->2D
 model = VSUNet(
-    architecture="2.1D",
+    architecture="2.2D",
     model_config={
         "in_channels": 2,
-        "out_channels": 1,
-        "in_stack_depth": 9,
+        "out_channels": 3,
+        "in_stack_depth": 5,
+        "out_stack_depth": 1,
         "backbone": "convnextv2_tiny",
-        "stem_kernel_size": (3, 1, 1),
+        "decoder_mode": "pixelshuffle",
+        "stem_kernel_size": (5, 4, 4),
     },
 )

 model_graph = draw_graph(
     model,
     model.example_input_array,
-    graph_name="2.1D UNet",
+    graph_name="2.2D UNet",
     roll=True,
     depth=3,
 )

-graph21d = model_graph.visual_graph
-graph21d
+model_graph.visual_graph

 # %%
 # 2.1D UNet with upsampling in Z.
 model = VSUNet(
39 changes: 13 additions & 26 deletions viscy/unet/networks/Unet21D.py → viscy/unet/networks/Unet22D.py
@@ -64,7 +64,7 @@ def _get_convnext_stage(
     return stage


-class Conv21dStem(nn.Module):
+class Conv22dStem(nn.Module):
     """Stem for 2.1D networks."""

     def __init__(
@@ -249,13 +249,13 @@ def forward(self, features: Sequence[Tensor]) -> Tensor:
         return feat


-class Unet21d(nn.Module):
+class Unet22d(nn.Module):
     def __init__(
         self,
         in_channels: int = 1,
         out_channels: int = 1,
         in_stack_depth: int = 5,
-        out_stack_depth: int = 1,
+        out_stack_depth: int = None,
         backbone: str = "convnextv2_tiny",
         pretrained: bool = False,
         stem_kernel_size: tuple[int, int, int] = (5, 4, 4),
@@ -273,18 +273,8 @@ def __init__(
                 f"Input stack depth {in_stack_depth} is not divisible "
                 f"by stem kernel depth {stem_kernel_size[0]}."
             )
-        if not (in_stack_depth == out_stack_depth or out_stack_depth == 1):
-            raise ValueError(
-                "`out_stack_depth` must be either 1 or "
-                f"the same as `input_stack_depth` ({in_stack_depth}), "
-                f"but got {out_stack_depth}."
-            )
-        if not (in_stack_depth == out_stack_depth or out_stack_depth == 1):
-            raise ValueError(
-                "`out_stack_depth` must be either 1 or "
-                f"the same as `input_stack_depth` ({in_stack_depth}), "
-                f"but got {out_stack_depth}."
-            )
+        if out_stack_depth is None:
+            out_stack_depth = in_stack_depth
         multi_scale_encoder = timm.create_model(
             backbone,
             pretrained=pretrained,
@@ -295,7 +285,7 @@ def __init__(
         # replace first convolution layer with a projection tokenizer
         multi_scale_encoder.stem_0 = nn.Identity()
         self.encoder_stages = multi_scale_encoder
-        self.stem = Conv21dStem(
+        self.stem = Conv22dStem(
             in_channels, num_channels[0], stem_kernel_size, in_stack_depth
         )
         decoder_channels = num_channels
@@ -311,16 +301,13 @@ def __init__(
             strides=[2] * (len(num_channels) - 1) + [stem_kernel_size[-1]],
             upsample_pre_conv="default" if decoder_upsample_pre_conv else None,
         )
-        if out_stack_depth == 1:
-            self.head = UnsqueezeHead()
-        else:
-            self.head = PixelToVoxelHead(
-                decoder_channels[-1],
-                out_channels,
-                out_stack_depth,
-                head_expansion_ratio,
-                pool=head_pool,
-            )
+        self.head = PixelToVoxelHead(
+            decoder_channels[-1],
+            out_channels,
+            out_stack_depth,
+            head_expansion_ratio,
+            pool=head_pool,
+        )
         self.out_stack_depth = out_stack_depth

     @property
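A short sketch of the resulting constructor behavior, using only arguments visible in this diff (values illustrative):

from viscy.unet.networks.Unet22D import Unet22d

# out_stack_depth now defaults to None and is resolved to in_stack_depth,
# so omitting it keeps a volumetric (5-slice) output...
net_3d = Unet22d(in_channels=2, out_channels=3, in_stack_depth=5)
print(net_3d.out_stack_depth)  # 5

# ...while out_stack_depth=1 builds the same PixelToVoxelHead for a 2D output,
# replacing the old UnsqueezeHead special case.
net_2d = Unet22d(in_channels=2, out_channels=3, in_stack_depth=5, out_stack_depth=1)
print(net_2d.out_stack_depth)  # 1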
2 changes: 1 addition & 1 deletion viscy/unet/networks/fcmae.py
@@ -18,7 +18,7 @@
 )
 from torch import BoolTensor, Size, Tensor, nn

-from viscy.unet.networks.Unet21D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead
+from viscy.unet.networks.Unet22D import PixelToVoxelHead, Unet2dDecoder, UnsqueezeHead


 def _init_weights(module: nn.Module) -> None:
