diff --git a/cyto_dl/nn/vits/blocks/patchify.py b/cyto_dl/nn/vits/blocks/patchify.py
index ee5a2c56..98d61592 100644
--- a/cyto_dl/nn/vits/blocks/patchify.py
+++ b/cyto_dl/nn/vits/blocks/patchify.py
@@ -6,7 +6,7 @@
 from einops.layers.torch import Rearrange, Reduce
 from timm.models.layers import trunc_normal_
 
-from cyto_dl.nn.vits.utils import take_indexes
+from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes
 
 
 def random_indexes(size: int, device):
@@ -27,6 +27,7 @@ def __init__(
         context_pixels: List[int] = [0, 0, 0],
         input_channels: int = 1,
         tasks: Optional[List[str]] = [],
+        learnable_pos_embedding: bool = True,
     ):
         """
         Parameters
@@ -45,12 +46,16 @@
             Number of input channels
         tasks: List[str]
            List of tasks to encode
+        learnable_pos_embedding: bool
+            If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
         """
         super().__init__()
         self.n_patches = np.asarray(n_patches)
         self.spatial_dims = spatial_dims
 
-        self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(n_patches), 1, emb_dim))
+        self.pos_embedding = get_positional_embedding(
+            n_patches, emb_dim, learnable=learnable_pos_embedding, use_cls_token=False
+        )
 
         context_pixels = context_pixels[:spatial_dims]
         weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int)
@@ -112,7 +117,6 @@
         self._init_weight()
 
     def _init_weight(self):
-        trunc_normal_(self.pos_embedding, std=0.02)
         for task in self.task_embedding:
             trunc_normal_(self.task_embedding[task], std=0.02)
 
diff --git a/cyto_dl/nn/vits/cross_mae.py b/cyto_dl/nn/vits/cross_mae.py
index e8bde3f4..3981de6e 100644
--- a/cyto_dl/nn/vits/cross_mae.py
+++ b/cyto_dl/nn/vits/cross_mae.py
@@ -1,6 +1,5 @@
 from typing import List, Optional
 
-import numpy as np
 import torch
 import torch.nn as nn
 from einops import rearrange
@@ -8,7 +7,7 @@
 from timm.models.layers import trunc_normal_
 
 from cyto_dl.nn.vits.blocks import CrossAttentionBlock
-from cyto_dl.nn.vits.utils import take_indexes
+from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes
 
 
 class CrossMAE_Decoder(torch.nn.Module):
@@ -24,6 +23,7 @@ def __init__(
         emb_dim: Optional[int] = 192,
         num_layer: Optional[int] = 4,
         num_head: Optional[int] = 3,
+        learnable_pos_embedding: Optional[bool] = True,
     ) -> None:
         """
         Parameters
@@ -40,6 +40,8 @@
             Number of transformer layers
         num_head: int
             Number of heads in transformer
+        learnable_pos_embedding: bool
+            If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images.
""" super().__init__() @@ -58,7 +60,10 @@ def __init__( self.projection = torch.nn.Linear(enc_dim, emb_dim) self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) - self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim)) + + self.pos_embedding = get_positional_embedding( + num_patches, emb_dim, learnable=learnable_pos_embedding + ) self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size))) self.num_patches = torch.as_tensor(num_patches) @@ -86,7 +91,6 @@ def __init__( def init_weight(self): trunc_normal_(self.mask_token, std=0.02) - trunc_normal_(self.pos_embedding, std=0.02) def forward(self, features, forward_indexes, backward_indexes): # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py index c05a6b31..1617bf68 100644 --- a/cyto_dl/nn/vits/mae.py +++ b/cyto_dl/nn/vits/mae.py @@ -12,7 +12,7 @@ from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder -from cyto_dl.nn.vits.utils import take_indexes +from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes class MAE_Encoder(torch.nn.Module): @@ -107,6 +107,7 @@ def __init__( emb_dim: Optional[int] = 192, num_layer: Optional[int] = 4, num_head: Optional[int] = 3, + learnable_pos_embedding: Optional[bool] = True, ) -> None: """ Parameters @@ -123,12 +124,17 @@ def __init__( Number of transformer layers num_head: int Number of heads in transformer + learnable_pos_embedding: bool + If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. """ super().__init__() self.projection_norm = nn.LayerNorm(emb_dim) self.projection = torch.nn.Linear(enc_dim, emb_dim) self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) - self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim)) + + self.pos_embedding = get_positional_embedding( + num_patches, emb_dim, learnable=learnable_pos_embedding + ) self.transformer = torch.nn.Sequential( *[Block(emb_dim, num_head) for _ in range(num_layer)] @@ -161,7 +167,6 @@ def __init__( def init_weight(self): trunc_normal_(self.mask_token, std=0.02) - trunc_normal_(self.pos_embedding, std=0.02) def forward(self, features, forward_indexes, backward_indexes): # project from encoder dimension to decoder dimension @@ -221,6 +226,7 @@ def __init__( context_pixels: Optional[List[int]] = [0, 0, 0], input_channels: Optional[int] = 1, features_only: Optional[bool] = False, + learnable_pos_embedding: Optional[bool] = True, ) -> None: """ Parameters @@ -251,6 +257,8 @@ def __init__( Number of input channels features_only: bool Only use encoder to extract features + learnable_pos_embedding: bool + If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
""" super().__init__() assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" @@ -291,6 +299,7 @@ def __init__( emb_dim=decoder_dim, num_layer=decoder_layer, num_head=decoder_head, + learnable_pos_embedding=learnable_pos_embedding, ) def forward(self, img): diff --git a/cyto_dl/nn/vits/utils.py b/cyto_dl/nn/vits/utils.py index 61263ccd..b918fb73 100644 --- a/cyto_dl/nn/vits/utils.py +++ b/cyto_dl/nn/vits/utils.py @@ -1,6 +1,42 @@ +from typing import Sequence + +import numpy as np import torch -from einops import repeat +from einops import rearrange, repeat +from positional_encodings.torch_encodings import ( + PositionalEncoding2D, + PositionalEncoding3D, +) +from timm.models.layers import trunc_normal_ def take_indexes(sequences, indexes): return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1])) + + +def get_positional_embedding( + num_patches: Sequence[int], emb_dim: int, use_cls_token: bool = True, learnable: bool = True +): + """Generate a positional embedding (with or without a cls token) for a given number of patches + and embedding dimension. + + Can be either learnable or fixed. + """ + if learnable: + pe = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + int(use_cls_token), 1, emb_dim)) + trunc_normal_(pe, std=0.02) + return pe + else: + test_tensor = torch.ones(1, *num_patches, emb_dim) + if len(num_patches) not in (2, 3): + raise ValueError("Only 2d and 3d positional encodings are supported") + if len(num_patches) == 2: + pe = PositionalEncoding2D(emb_dim)(test_tensor) + pe = rearrange(pe, "b y x c -> (y x) b c") + elif len(num_patches) == 3: + pe = PositionalEncoding3D(emb_dim)(test_tensor) + pe = rearrange(pe, "b z y x c -> (z y x) b c") + if use_cls_token: + cls_token = torch.zeros(1, 1, emb_dim) + pe = torch.cat([cls_token, pe], dim=0) + return torch.nn.Parameter(pe, requires_grad=False) diff --git a/pyproject.toml b/pyproject.toml index ba4716eb..192e3e84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ dependencies = [ "bioio-ome-tiff", "bioio-tifffile", "online-stats>=2023", + "positional-encodings>=6.0.3", ] requires-python = ">=3.9,<3.11"