Skip to content

Commit

Permalink
draft projection head per Update the projection head (normalization a…
Browse files Browse the repository at this point in the history
…nd size). #139
  • Loading branch information
mattersoflight authored and ziw-liu committed Aug 31, 2024
1 parent 6e7d61f commit 5f76d61
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions viscy/representation/contrastive.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def __init__(
in_channels: int = 2,
in_stack_depth: int = 12,
stem_kernel_size: tuple[int, int, int] = (5, 4, 4),
embedding_len: int = 256,
embedding_len: int = 1024,
stem_stride: int = 2,
predict: bool = False,
drop_path_rate: float = 0.2,
Expand All @@ -31,7 +31,7 @@ def __init__(
:param tuple[int, int, int] stem_kernel_size: 3D kernel size for the stem.
Input stack depth must be divisible by the kernel depth,
defaults to (5, 3, 3)
:param int embedding_len: Length of the embedding vector, defaults to 256
:param int embedding_len: Length of the embedding vector, defaults to 1024
:param int stem_stride: stride of the stem, defaults to 2
:param bool predict: prediction mode, defaults to False
:param float drop_path_rate: probability that residual connections
Expand All @@ -46,7 +46,7 @@ def __init__(
pretrained=True,
features_only=False,
drop_path_rate=drop_path_rate,
num_classes=3 * embedding_len,
num_classes=embedding_len,
)

if "convnext" in backbone:
Expand All @@ -59,9 +59,11 @@ def __init__(

# Save projection head separately and erase the projection head contained within the encoder.
projection = nn.Sequential(
nn.Linear(encoder.head.fc.in_features, 3 * embedding_len),
nn.Linear(encoder.head.fc.in_features, embedding_len),
nn.BatchNorm1d(embedding_len),
nn.ReLU(inplace=True),
nn.Linear(3 * embedding_len, embedding_len),
nn.BatchNorm1d(embedding_len),
nn.Linear(embedding_len, embedding_len / 8),
)

encoder.head.fc = nn.Identity()
Expand All @@ -75,9 +77,11 @@ def __init__(
encoder.conv1 = nn.Identity()

projection = nn.Sequential(
nn.Linear(encoder.fc.in_features, 3 * embedding_len),
nn.Linear(encoder.fc.in_features, embedding_len),
nn.BatchNorm1d(embedding_len),
nn.ReLU(inplace=True),
nn.Linear(3 * embedding_len, embedding_len),
nn.BatchNorm1d(embedding_len),
nn.Linear(embedding_len, embedding_len / 8),
)
encoder.fc = nn.Identity()

Expand Down

0 comments on commit 5f76d61

Please sign in to comment.