Feature/add sincos pos embed (#416)
* allow learnable and fixed positional embeddings

* add positional-encodings

* describe when to use fixed embeddings

---------

Co-authored-by: Benjamin Morris <[email protected]>
benjijamorris and Benjamin Morris authored Aug 14, 2024
1 parent af537b3 commit 4b94064
Showing 5 changed files with 65 additions and 11 deletions.
10 changes: 7 additions & 3 deletions cyto_dl/nn/vits/blocks/patchify.py
@@ -6,7 +6,7 @@
from einops.layers.torch import Rearrange, Reduce
from timm.models.layers import trunc_normal_

from cyto_dl.nn.vits.utils import take_indexes
from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes


def random_indexes(size: int, device):
@@ -27,6 +27,7 @@ def __init__(
context_pixels: List[int] = [0, 0, 0],
input_channels: int = 1,
tasks: Optional[List[str]] = [],
learnable_pos_embedding: bool = True,
):
"""
Parameters
@@ -45,12 +46,16 @@ def __init__(
Number of input channels
tasks: List[str]
List of tasks to encode
learnable_pos_embedding: bool
If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
"""
super().__init__()
self.n_patches = np.asarray(n_patches)
self.spatial_dims = spatial_dims

self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(n_patches), 1, emb_dim))
self.pos_embedding = get_positional_embedding(
n_patches, emb_dim, learnable=learnable_pos_embedding, use_cls_token=False
)

context_pixels = context_pixels[:spatial_dims]
weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int)
@@ -112,7 +117,6 @@ def __init__(
self._init_weight()

def _init_weight(self):
trunc_normal_(self.pos_embedding, std=0.02)
for task in self.task_embedding:
trunc_normal_(self.task_embedding[task], std=0.02)

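A minimal shape sketch for the Patchify change above (not part of the diff), assuming this commit of cyto_dl is importable; the patch grid and embedding dimension are made-up example values. Patchify passes use_cls_token=False, so no cls-token slot is prepended:

import torch
from cyto_dl.nn.vits.utils import get_positional_embedding

n_patches, emb_dim = (4, 8, 8), 192  # example z/y/x patch grid and embedding size

# Replaces the old torch.nn.Parameter(torch.zeros(np.prod(n_patches), 1, emb_dim))
learnable = get_positional_embedding(n_patches, emb_dim, learnable=True, use_cls_token=False)
fixed = get_positional_embedding(n_patches, emb_dim, learnable=False, use_cls_token=False)

assert learnable.shape == fixed.shape == (4 * 8 * 8, 1, emb_dim)
assert learnable.requires_grad and not fixed.requires_grad  # only the learnable variant is trained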
12 changes: 8 additions & 4 deletions cyto_dl/nn/vits/cross_mae.py
@@ -1,14 +1,13 @@
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_

from cyto_dl.nn.vits.blocks import CrossAttentionBlock
from cyto_dl.nn.vits.utils import take_indexes
from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes


class CrossMAE_Decoder(torch.nn.Module):
@@ -24,6 +23,7 @@ def __init__(
emb_dim: Optional[int] = 192,
num_layer: Optional[int] = 4,
num_head: Optional[int] = 3,
learnable_pos_embedding: Optional[bool] = True,
) -> None:
"""
Parameters
@@ -40,6 +40,8 @@
Number of transformer layers
num_head: int
Number of heads in transformer
learnable_pos_embedding: bool
If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
"""
super().__init__()

@@ -58,7 +60,10 @@ def __init__(

self.projection = torch.nn.Linear(enc_dim, emb_dim)
self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim))

self.pos_embedding = get_positional_embedding(
num_patches, emb_dim, learnable=learnable_pos_embedding
)

self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size)))
self.num_patches = torch.as_tensor(num_patches)
@@ -86,7 +91,6 @@ def __init__(

def init_weight(self):
trunc_normal_(self.mask_token, std=0.02)
trunc_normal_(self.pos_embedding, std=0.02)

def forward(self, features, forward_indexes, backward_indexes):
# HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers
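For context, a hypothetical construction sketch showing the new switch on the decoder; only emb_dim, num_layer, num_head, and learnable_pos_embedding are visible with defaults in this diff, so the remaining argument names and their example values are inferred from how they are used in __init__ and may not match the full signature:

from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder

decoder = CrossMAE_Decoder(
    num_patches=[4, 8, 8],          # example patch grid; name taken from self.num_patches above
    base_patch_size=[4, 8, 8],      # example patch size; name taken from the reconstruction head above
    enc_dim=768,                    # example encoder width; name taken from self.projection above
    emb_dim=192,
    num_layer=4,
    num_head=3,
    learnable_pos_embedding=False,  # fixed sin/cos embeddings, e.g. for brightfield images
)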
15 changes: 12 additions & 3 deletions cyto_dl/nn/vits/mae.py
@@ -12,7 +12,7 @@

from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify
from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder
from cyto_dl.nn.vits.utils import take_indexes
from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes


class MAE_Encoder(torch.nn.Module):
@@ -107,6 +107,7 @@ def __init__(
emb_dim: Optional[int] = 192,
num_layer: Optional[int] = 4,
num_head: Optional[int] = 3,
learnable_pos_embedding: Optional[bool] = True,
) -> None:
"""
Parameters
@@ -123,12 +124,17 @@
Number of transformer layers
num_head: int
Number of heads in transformer
learnable_pos_embedding: bool
If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
"""
super().__init__()
self.projection_norm = nn.LayerNorm(emb_dim)
self.projection = torch.nn.Linear(enc_dim, emb_dim)
self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim))

self.pos_embedding = get_positional_embedding(
num_patches, emb_dim, learnable=learnable_pos_embedding
)

self.transformer = torch.nn.Sequential(
*[Block(emb_dim, num_head) for _ in range(num_layer)]
@@ -161,7 +167,6 @@ def __init__(

def init_weight(self):
trunc_normal_(self.mask_token, std=0.02)
trunc_normal_(self.pos_embedding, std=0.02)

def forward(self, features, forward_indexes, backward_indexes):
# project from encoder dimension to decoder dimension
@@ -221,6 +226,7 @@ def __init__(
context_pixels: Optional[List[int]] = [0, 0, 0],
input_channels: Optional[int] = 1,
features_only: Optional[bool] = False,
learnable_pos_embedding: Optional[bool] = True,
) -> None:
"""
Parameters
@@ -251,6 +257,8 @@ def __init__(
Number of input channels
features_only: bool
Only use encoder to extract features
learnable_pos_embedding: bool
If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
"""
super().__init__()
assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3"
@@ -291,6 +299,7 @@ def __init__(
emb_dim=decoder_dim,
num_layer=decoder_layer,
num_head=decoder_head,
learnable_pos_embedding=learnable_pos_embedding,
)

def forward(self, img):
38 changes: 37 additions & 1 deletion cyto_dl/nn/vits/utils.py
@@ -1,6 +1,42 @@
from typing import Sequence

import numpy as np
import torch
from einops import repeat
from einops import rearrange, repeat
from positional_encodings.torch_encodings import (
PositionalEncoding2D,
PositionalEncoding3D,
)
from timm.models.layers import trunc_normal_


def take_indexes(sequences, indexes):
return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))


def get_positional_embedding(
num_patches: Sequence[int], emb_dim: int, use_cls_token: bool = True, learnable: bool = True
):
"""Generate a positional embedding (with or without a cls token) for a given number of patches
and embedding dimension.
Can be either learnable or fixed.
"""
if learnable:
pe = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + int(use_cls_token), 1, emb_dim))
trunc_normal_(pe, std=0.02)
return pe
else:
test_tensor = torch.ones(1, *num_patches, emb_dim)
if len(num_patches) not in (2, 3):
raise ValueError("Only 2d and 3d positional encodings are supported")
if len(num_patches) == 2:
pe = PositionalEncoding2D(emb_dim)(test_tensor)
pe = rearrange(pe, "b y x c -> (y x) b c")
elif len(num_patches) == 3:
pe = PositionalEncoding3D(emb_dim)(test_tensor)
pe = rearrange(pe, "b z y x c -> (z y x) b c")
if use_cls_token:
cls_token = torch.zeros(1, 1, emb_dim)
pe = torch.cat([cls_token, pe], dim=0)
return torch.nn.Parameter(pe, requires_grad=False)
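As a quick illustration of the helper above (a sketch, not part of the commit): with use_cls_token=True, the path used by the decoders, the fixed variant prepends an all-zeros cls-token row to the deterministic sin/cos grid and freezes the result:

import torch
from cyto_dl.nn.vits.utils import get_positional_embedding

pe = get_positional_embedding((4, 8, 8), 192, use_cls_token=True, learnable=False)

assert pe.shape == (4 * 8 * 8 + 1, 1, 192)  # one extra leading slot for the cls token
assert not pe.requires_grad                 # excluded from optimization
assert torch.all(pe[0] == 0)                # the prepended cls-token row is zeros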
1 change: 1 addition & 0 deletions pyproject.toml
@@ -48,6 +48,7 @@ dependencies = [
"bioio-ome-tiff",
"bioio-tifffile",
"online-stats>=2023",
"positional-encodings>=6.0.3",
]
requires-python = ">=3.9,<3.11"

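The new runtime dependency can be smoke-tested on its own (a sketch, not part of the commit); PositionalEncoding3D returns a sin/cos encoding with the same shape as its input tensor:

import torch
from positional_encodings.torch_encodings import PositionalEncoding3D

enc = PositionalEncoding3D(192)        # 192-channel sin/cos encoding
pe = enc(torch.ones(1, 4, 8, 8, 192))  # (batch, z, y, x, channels)
print(pe.shape)                        # torch.Size([1, 4, 8, 8, 192])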
