From fb2ec0f706ecf7498850ebea4eee223c8ba07d70 Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Thu, 8 Aug 2024 22:56:40 -0700 Subject: [PATCH] Configurable drop path rate in contrastive models (#131) * log instead of print * configurable drop path rate * fix docstring --- viscy/light/engine.py | 2 ++ viscy/representation/contrastive.py | 41 ++++++++++++++++++----------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 326639d6..6e658100 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -584,6 +584,7 @@ def __init__( stem_kernel_size: tuple[int, int, int] = (5, 4, 4), embedding_len: int = 256, predict: bool = False, + drop_path_rate: float = 0.2, tracks_path: str = "data/tracks", features_output_path: str = "", projections_output_path: str = "", @@ -615,6 +616,7 @@ def __init__( stem_kernel_size=stem_kernel_size, embedding_len=embedding_len, predict=predict, + drop_path_rate=drop_path_rate, ) self.example_input_array = torch.rand( 1, in_channels, in_stack_depth, *example_input_yx_shape diff --git a/viscy/representation/contrastive.py b/viscy/representation/contrastive.py index 1dc269ce..13be2b15 100644 --- a/viscy/representation/contrastive.py +++ b/viscy/representation/contrastive.py @@ -1,9 +1,13 @@ +import logging + import timm import torch.nn as nn import torch.nn.functional as F from viscy.unet.networks.unext2 import StemDepthtoChannels +_logger = logging.getLogger("lightning.pytorch") + class ContrastiveEncoder(nn.Module): def __init__( @@ -15,33 +19,38 @@ def __init__( embedding_len: int = 256, stem_stride: int = 2, predict: bool = False, + drop_path_rate: float = 0.2, ): + """ContrastiveEncoder network that uses + ConvNext and ResNet backbons from timm. + + :param str backbone: Backbone architecture for the encoder, + defaults to "convnext_tiny" + :param int in_channels: Number of input channels, defaults to 2 + :param int in_stack_depth: Number of input slices in z-stack, defaults to 12 + :param tuple[int, int, int] stem_kernel_size: 3D kernel size for the stem. + Input stack depth must be divisible by the kernel depth, + defaults to (5, 3, 3) + :param int embedding_len: Length of the embedding vector, defaults to 256 + :param int stem_stride: stride of the stem, defaults to 2 + :param bool predict: prediction mode, defaults to False + :param float drop_path_rate: probability that residual connections + are dropped during training, defaults to 0.2 + """ super().__init__() - self.predict = predict self.backbone = backbone - """ - ContrastiveEncoder network that uses ConvNext and ResNet backbons from timm. - - Parameters: - - backbone (str): Backbone architecture for the encoder. Default is "convnext_tiny". - - in_channels (int): Number of input channels. Default is 2. - - in_stack_depth (int): Number of input slices in z-stack. Default is 15. - - stem_kernel_size (tuple[int, int, int]): 3D kernel size for the stem. Input stack depth must be divisible by the kernel depth. Default is (5, 3, 3). - - embedding_len (int): Length of the embedding. Default is 1000. - """ - encoder = timm.create_model( backbone, pretrained=True, features_only=False, - drop_path_rate=0.2, + drop_path_rate=drop_path_rate, num_classes=3 * embedding_len, ) if "convnext" in backbone: - print("Using ConvNext backbone.") + _logger.debug(f"Using ConvNeXt backbone for {type(self).__name__}.") in_channels_encoder = encoder.stem[0].out_channels @@ -58,7 +67,7 @@ def __init__( encoder.head.fc = nn.Identity() elif "resnet" in backbone: - print("Using ResNet backbone.") + _logger.debug(f"Using ResNet backbone for {type(self).__name__}") # Adapt stem and projection head of resnet here. # replace the stem designed for RGB images with a stem designed to handle 3D multi-channel input. @@ -73,7 +82,7 @@ def __init__( encoder.fc = nn.Identity() # Create a new stem that can handle 3D multi-channel input. - print("using stem kernel size", stem_kernel_size) + _logger.debug(f"Stem kernel size: {stem_kernel_size}") self.stem = StemDepthtoChannels( in_channels, in_stack_depth, in_channels_encoder, stem_kernel_size )