Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add sincos pos embed #416

Merged
merged 3 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions cyto_dl/nn/vits/blocks/patchify.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from einops.layers.torch import Rearrange, Reduce
from timm.models.layers import trunc_normal_

from cyto_dl.nn.vits.utils import take_indexes
from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes


def random_indexes(size: int, device):
Expand All @@ -27,6 +27,7 @@ def __init__(
context_pixels: List[int] = [0, 0, 0],
input_channels: int = 1,
tasks: Optional[List[str]] = [],
learnable_pos_embedding: bool = True,
):
"""
Parameters
Expand All @@ -45,12 +46,16 @@ def __init__(
Number of input channels
tasks: List[str]
List of tasks to encode
learnable_pos_embedding: bool
If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings
"""
super().__init__()
self.n_patches = np.asarray(n_patches)
self.spatial_dims = spatial_dims

self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(n_patches), 1, emb_dim))
self.pos_embedding = get_positional_embedding(
n_patches, emb_dim, learnable=learnable_pos_embedding, use_cls_token=False
)

context_pixels = context_pixels[:spatial_dims]
weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int)
Expand Down Expand Up @@ -112,7 +117,6 @@ def __init__(
self._init_weight()

def _init_weight(self):
trunc_normal_(self.pos_embedding, std=0.02)
for task in self.task_embedding:
trunc_normal_(self.task_embedding[task], std=0.02)

Expand Down
12 changes: 8 additions & 4 deletions cyto_dl/nn/vits/cross_mae.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_

from cyto_dl.nn.vits.blocks import CrossAttentionBlock
from cyto_dl.nn.vits.utils import take_indexes
from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes


class CrossMAE_Decoder(torch.nn.Module):
Expand All @@ -24,6 +23,7 @@ def __init__(
emb_dim: Optional[int] = 192,
num_layer: Optional[int] = 4,
num_head: Optional[int] = 3,
learnable_pos_embedding: Optional[bool] = True,
) -> None:
"""
Parameters
Expand All @@ -40,6 +40,8 @@ def __init__(
Number of transformer layers
num_head: int
Number of heads in transformer
learnable_pos_embedding: bool
If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used
"""
super().__init__()

Expand All @@ -58,7 +60,10 @@ def __init__(

self.projection = torch.nn.Linear(enc_dim, emb_dim)
self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim))

self.pos_embedding = get_positional_embedding(
num_patches, emb_dim, learnable=learnable_pos_embedding
)

self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size)))
self.num_patches = torch.as_tensor(num_patches)
Expand Down Expand Up @@ -86,7 +91,6 @@ def __init__(

def init_weight(self):
trunc_normal_(self.mask_token, std=0.02)
trunc_normal_(self.pos_embedding, std=0.02)

def forward(self, features, forward_indexes, backward_indexes):
# HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers
Expand Down
15 changes: 12 additions & 3 deletions cyto_dl/nn/vits/mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify
from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder
from cyto_dl.nn.vits.utils import take_indexes
from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes


class MAE_Encoder(torch.nn.Module):
Expand Down Expand Up @@ -107,6 +107,7 @@ def __init__(
emb_dim: Optional[int] = 192,
num_layer: Optional[int] = 4,
num_head: Optional[int] = 3,
learnable_pos_embedding: Optional[bool] = True,
) -> None:
"""
Parameters
Expand All @@ -123,12 +124,17 @@ def __init__(
Number of transformer layers
num_head: int
Number of heads in transformer
learnable_pos_embedding: bool
If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings
"""
super().__init__()
self.projection_norm = nn.LayerNorm(emb_dim)
self.projection = torch.nn.Linear(enc_dim, emb_dim)
self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim))

self.pos_embedding = get_positional_embedding(
num_patches, emb_dim, learnable=learnable_pos_embedding
)

self.transformer = torch.nn.Sequential(
*[Block(emb_dim, num_head) for _ in range(num_layer)]
Expand Down Expand Up @@ -161,7 +167,6 @@ def __init__(

def init_weight(self):
trunc_normal_(self.mask_token, std=0.02)
trunc_normal_(self.pos_embedding, std=0.02)

def forward(self, features, forward_indexes, backward_indexes):
# project from encoder dimension to decoder dimension
Expand Down Expand Up @@ -221,6 +226,7 @@ def __init__(
context_pixels: Optional[List[int]] = [0, 0, 0],
input_channels: Optional[int] = 1,
features_only: Optional[bool] = False,
learnable_pos_embedding: Optional[bool] = True,
) -> None:
"""
Parameters
Expand Down Expand Up @@ -251,6 +257,8 @@ def __init__(
Number of input channels
features_only: bool
Only use encoder to extract features
learnable_pos_embedding: bool
If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings
"""
super().__init__()
assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3"
Expand Down Expand Up @@ -291,6 +299,7 @@ def __init__(
emb_dim=decoder_dim,
num_layer=decoder_layer,
num_head=decoder_head,
learnable_pos_embedding=learnable_pos_embedding,
)

def forward(self, img):
Expand Down
33 changes: 32 additions & 1 deletion cyto_dl/nn/vits/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
from typing import Sequence

import numpy as np
import torch
from einops import repeat
from einops import rearrange, repeat
from positional_encodings.torch_encodings import (
PositionalEncoding2D,
PositionalEncoding3D,
)
from timm.models.layers import trunc_normal_


def take_indexes(sequences, indexes):
return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))


def get_positional_embedding(
num_patches: Sequence[int], emb_dim: int, use_cls_token: bool = True, learnable: bool = True
):
if learnable:
pe = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + int(use_cls_token), 1, emb_dim))
trunc_normal_(pe, std=0.02)
return pe
else:
test_tensor = torch.ones(1, *num_patches, emb_dim)
if len(num_patches) not in (2, 3):
raise ValueError("Only 2d and 3d positional encodings are supported")
if len(num_patches) == 2:
pe = PositionalEncoding2D(emb_dim)(test_tensor)
pe = rearrange(pe, "b y x c -> (y x) b c")
elif len(num_patches) == 3:
pe = PositionalEncoding3D(emb_dim)(test_tensor)
pe = rearrange(pe, "b z y x c -> (z y x) b c")
if use_cls_token:
cls_token = torch.zeros(1, 1, emb_dim)
pe = torch.cat([cls_token, pe], dim=0)
return torch.nn.Parameter(pe, requires_grad=False)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
"bioio-ome-tiff",
"bioio-tifffile",
"online-stats>=2023",
"positional-encodings>=6.0.3",
]
requires-python = ">=3.9,<3.11"

Expand Down
Loading