
Configurable drop path rate in contrastive models (#131)
* log instead of print

* configurable drop path rate

* fix docstring
ziw-liu authored and edyoshikun committed Aug 15, 2024
1 parent b9bd1f6 commit fb2ec0f
Showing 2 changed files with 27 additions and 16 deletions.
viscy/light/engine.py (2 additions, 0 deletions)
@@ -584,6 +584,7 @@ def __init__(
         stem_kernel_size: tuple[int, int, int] = (5, 4, 4),
         embedding_len: int = 256,
         predict: bool = False,
+        drop_path_rate: float = 0.2,
         tracks_path: str = "data/tracks",
         features_output_path: str = "",
         projections_output_path: str = "",
@@ -615,6 +616,7 @@ def __init__(
             stem_kernel_size=stem_kernel_size,
             embedding_len=embedding_len,
             predict=predict,
+            drop_path_rate=drop_path_rate,
         )
         self.example_input_array = torch.rand(
             1, in_channels, in_stack_depth, *example_input_yx_shape
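With the new keyword in place, the drop path rate can be chosen at construction time instead of being hard-coded. A minimal usage sketch, assuming the surrounding class in viscy/light/engine.py is importable as ContrastiveModule (the class name is an assumption; only the keyword arguments visible in the diff above are known to exist):

    # Hypothetical sketch: ContrastiveModule is an assumed name for the
    # LightningModule defined in viscy/light/engine.py.
    from viscy.light.engine import ContrastiveModule

    model = ContrastiveModule(
        stem_kernel_size=(5, 4, 4),
        embedding_len=256,
        predict=False,
        drop_path_rate=0.0,  # newly configurable; 0.0 disables stochastic depth
    )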
viscy/representation/contrastive.py (25 additions, 16 deletions)
@@ -1,9 +1,13 @@
+import logging
+
 import timm
 import torch.nn as nn
 import torch.nn.functional as F

 from viscy.unet.networks.unext2 import StemDepthtoChannels

+_logger = logging.getLogger("lightning.pytorch")
+

 class ContrastiveEncoder(nn.Module):
     def __init__(
@@ -15,33 +19,38 @@ def __init__(
         embedding_len: int = 256,
         stem_stride: int = 2,
         predict: bool = False,
+        drop_path_rate: float = 0.2,
     ):
+        """ContrastiveEncoder network that uses
+        ConvNeXt and ResNet backbones from timm.
+        :param str backbone: Backbone architecture for the encoder,
+            defaults to "convnext_tiny"
+        :param int in_channels: Number of input channels, defaults to 2
+        :param int in_stack_depth: Number of input slices in z-stack, defaults to 12
+        :param tuple[int, int, int] stem_kernel_size: 3D kernel size for the stem.
+            Input stack depth must be divisible by the kernel depth,
+            defaults to (5, 3, 3)
+        :param int embedding_len: Length of the embedding vector, defaults to 256
+        :param int stem_stride: stride of the stem, defaults to 2
+        :param bool predict: prediction mode, defaults to False
+        :param float drop_path_rate: probability that residual connections
+            are dropped during training, defaults to 0.2
+        """
         super().__init__()

         self.predict = predict
         self.backbone = backbone

-        """
-        ContrastiveEncoder network that uses ConvNext and ResNet backbons from timm.
-        Parameters:
-        - backbone (str): Backbone architecture for the encoder. Default is "convnext_tiny".
-        - in_channels (int): Number of input channels. Default is 2.
-        - in_stack_depth (int): Number of input slices in z-stack. Default is 15.
-        - stem_kernel_size (tuple[int, int, int]): 3D kernel size for the stem. Input stack depth must be divisible by the kernel depth. Default is (5, 3, 3).
-        - embedding_len (int): Length of the embedding. Default is 1000.
-        """
-
         encoder = timm.create_model(
             backbone,
             pretrained=True,
             features_only=False,
-            drop_path_rate=0.2,
+            drop_path_rate=drop_path_rate,
             num_classes=3 * embedding_len,
         )

         if "convnext" in backbone:
-            print("Using ConvNext backbone.")
+            _logger.debug(f"Using ConvNeXt backbone for {type(self).__name__}.")

             in_channels_encoder = encoder.stem[0].out_channels

@@ -58,7 +67,7 @@ def __init__(
             encoder.head.fc = nn.Identity()

         elif "resnet" in backbone:
-            print("Using ResNet backbone.")
+            _logger.debug(f"Using ResNet backbone for {type(self).__name__}")
             # Adapt stem and projection head of resnet here.
             # replace the stem designed for RGB images with a stem designed to handle 3D multi-channel input.
@@ -73,7 +82,7 @@ def __init__(
             encoder.fc = nn.Identity()

         # Create a new stem that can handle 3D multi-channel input.
-        print("using stem kernel size", stem_kernel_size)
+        _logger.debug(f"Stem kernel size: {stem_kernel_size}")
         self.stem = StemDepthtoChannels(
             in_channels, in_stack_depth, in_channels_encoder, stem_kernel_size
         )
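Because the prints were replaced with calls to the "lightning.pytorch" logger, these messages are emitted at DEBUG level and stay silent by default. A minimal sketch (standard-library logging only) of raising the logger's verbosity so they appear:

    import logging

    # The backbone and stem messages are logged at DEBUG level on the
    # "lightning.pytorch" logger; raise its level to surface them.
    logging.basicConfig(format="%(name)s %(levelname)s: %(message)s")
    logging.getLogger("lightning.pytorch").setLevel(logging.DEBUG)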

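For reference, drop_path_rate is timm's stochastic-depth setting: during training, each residual branch is skipped (replaced by the identity) with some probability, typically scaled over the depth of the network, which regularizes deep backbones; at inference nothing is dropped. A sketch of the call this commit makes configurable, with illustrative values:

    import timm

    # Stochastic depth ("drop path"): residual branches are randomly
    # skipped during training; drop_path_rate sets the maximum rate.
    encoder = timm.create_model(
        "convnext_tiny",
        pretrained=False,       # illustrative; the commit uses pretrained=True
        drop_path_rate=0.2,
        num_classes=768,        # illustrative; the commit uses 3 * embedding_len
    )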