VIT Decoder updates (#339)
* add crossmae

* multich segmentation projection

* add cross attention block

* add crossmae decoder

* add crossmae to init

* remove patchify

* use crossmae as default

* remove note

---------

Co-authored-by: Benjamin Morris <[email protected]>
benjijamorris and Benjamin Morris authored Feb 29, 2024
1 parent 4860c8b commit 7536cd8
Showing 7 changed files with 394 additions and 61 deletions.
3 changes: 1 addition & 2 deletions configs/model/im2im/mae.yaml
@@ -15,8 +15,7 @@ backbone:
encoder_layer: 2
encoder_head: 1
decoder_layer: 1
mask_ratio: 0.75

use_crossmae: True
task_heads: ${kv_to_dict:${model._aux._tasks}}

optimizer:
3 changes: 2 additions & 1 deletion cyto_dl/nn/vits/__init__.py
@@ -1,2 +1,3 @@
from .cross_mae import CrossMAE_Decoder
from .mae import MAE_Decoder, MAE_Encoder, MAE_ViT
from .seg import Seg_ViT, SupperresDecoder
from .seg import Seg_ViT, SuperresDecoder
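
Since the config change above sets use_crossmae: True and this __init__.py now exports both decoders, the flag presumably just selects which decoder class the MAE backbone builds. A minimal, hypothetical sketch of that selection (illustrative only, not the actual cyto_dl wiring; it assumes both decoders accept the same keyword arguments):

# Hypothetical sketch: selecting a decoder from a `use_crossmae` flag.
# Not the cyto_dl implementation; assumes both decoders share keyword arguments.
from cyto_dl.nn.vits import CrossMAE_Decoder, MAE_Decoder


def build_decoder(use_crossmae: bool, **decoder_kwargs):
    decoder_cls = CrossMAE_Decoder if use_crossmae else MAE_Decoder
    return decoder_cls(**decoder_kwargs)


# e.g. build_decoder(True, num_patches=[2, 2, 2]) returns a CrossMAE_Decoder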
1 change: 1 addition & 0 deletions cyto_dl/nn/vits/blocks/__init__.py
@@ -0,0 +1 @@
from .cross_attention import CrossAttention, CrossAttentionBlock, CrossSelfBlock, Mlp
155 changes: 155 additions & 0 deletions cyto_dl/nn/vits/blocks/cross_attention.py
@@ -0,0 +1,155 @@
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath
from timm.models.vision_transformer import Block

# from https://github.com/TonyLianLong/CrossMAE/blob/main/transformer_utils.py


class Mlp(nn.Module):
    def __init__(
        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class CrossAttention(nn.Module):
    def __init__(
        self,
        encoder_dim,
        decoder_dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = decoder_dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias)
        self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias)
        self.attn_drop = attn_drop
        self.proj = nn.Linear(decoder_dim, decoder_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, y):
        """query from decoder (x), key and value from encoder (y)"""
        B, N, C = x.shape
        Ny = y.shape[1]
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        kv = (
            self.kv(y)
            .reshape(B, Ny, 2, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv[0], kv[1]

        attn = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.attn_drop,
        )
        x = attn.transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class CrossAttentionBlock(nn.Module):
    def __init__(
        self,
        encoder_dim,
        decoder_dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(decoder_dim)
        self.cross_attn = CrossAttention(
            encoder_dim,
            decoder_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(decoder_dim)
        mlp_hidden_dim = int(decoder_dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
        )

    def forward(self, x, y):
        """
        x: decoder feature; y: encoder feature (after layernorm)
        """
        x = x + self.drop_path(self.cross_attn(self.norm1(x), y))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


class CrossSelfBlock(nn.Module):
    def __init__(
        self,
        emb_dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.x_attn_block = CrossAttentionBlock(
            emb_dim,
            emb_dim,
            num_heads,
            mlp_ratio,
            qkv_bias,
            qk_scale,
            drop,
            attn_drop,
            drop_path,
            act_layer,
            norm_layer,
        )
        self.self_attn_block = Block(dim=emb_dim, num_heads=num_heads)

    def forward(self, x, y):
        """
        x: decoder feature; y: encoder feature
        """
        x = self.x_attn_block(x, y)
        x = self.self_attn_block(x)
        return x
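
A quick, self-contained shape check makes the cross-attention wiring above concrete; the batch size, token counts, and dimensions below are arbitrary illustrative values, not anything used by cyto_dl:

# Illustrative shape check for CrossAttentionBlock (arbitrary example sizes).
import torch

from cyto_dl.nn.vits.blocks import CrossAttentionBlock

block = CrossAttentionBlock(encoder_dim=768, decoder_dim=192, num_heads=3)
queries = torch.randn(2, 48, 192)  # decoder tokens: (batch, n_masked, decoder_dim)
context = torch.randn(2, 17, 768)  # encoder tokens: (batch, n_visible + cls, encoder_dim)
out = block(queries, context)      # queries attend to context, then pass through the Mlp
print(out.shape)                   # torch.Size([2, 48, 192])

The output keeps the decoder token count and decoder_dim, since only the key/value projection reads the encoder dimension.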
159 changes: 159 additions & 0 deletions cyto_dl/nn/vits/cross_mae.py
@@ -0,0 +1,159 @@
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_

from cyto_dl.nn.vits.blocks import CrossAttentionBlock


def take_indexes(sequences, indexes):
    return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))


class CrossMAE_Decoder(torch.nn.Module):
    """Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masked tokens only attend
    to visible tokens."""

    def __init__(
        self,
        num_patches: List[int],
        spatial_dims: int = 3,
        base_patch_size: Optional[List[int]] = [4, 8, 8],
        enc_dim: Optional[int] = 768,
        emb_dim: Optional[int] = 192,
        num_layer: Optional[int] = 4,
        num_head: Optional[int] = 3,
    ) -> None:
        """
        Parameters
        ----------
        num_patches: List[int]
            Number of patches in each dimension
        spatial_dims: int
            Number of spatial dimensions (2 or 3)
        base_patch_size: Tuple[int]
            Size of each patch
        enc_dim: int
            Dimension of encoder
        emb_dim: int
            Dimension of embedding
        num_layer: int
            Number of transformer layers
        num_head: int
            Number of heads in transformer
        """
        super().__init__()

        self.transformer = torch.nn.ParameterList(
            [
                CrossAttentionBlock(
                    encoder_dim=emb_dim,
                    decoder_dim=emb_dim,
                    num_heads=num_head,
                )
                for _ in range(num_layer)
            ]
        )
        self.decoder_norm = nn.LayerNorm(emb_dim)
        self.projection_norm = nn.LayerNorm(emb_dim)

        self.projection = torch.nn.Linear(enc_dim, emb_dim)
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
        self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim))

        self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size)))
        self.num_patches = torch.as_tensor(num_patches)

        if spatial_dims == 3:
            self.patch2img = Rearrange(
                "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
                n_patch_z=num_patches[0],
                n_patch_y=num_patches[1],
                n_patch_x=num_patches[2],
                patch_size_z=base_patch_size[0],
                patch_size_y=base_patch_size[1],
                patch_size_x=base_patch_size[2],
            )
        elif spatial_dims == 2:
            self.patch2img = Rearrange(
                "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
                n_patch_y=num_patches[0],
                n_patch_x=num_patches[1],
                patch_size_y=base_patch_size[0],
                patch_size_x=base_patch_size[1],
            )

        self.init_weight()

    def init_weight(self):
        trunc_normal_(self.mask_token, std=0.02)
        trunc_normal_(self.pos_embedding, std=0.02)

    def forward(self, features, forward_indexes, backward_indexes, patch_size):
        T, B, C = features.shape
        # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now
        features = self.projection_norm(self.projection(features))

        # add cls token
        backward_indexes = torch.cat(
            [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
            dim=0,
        )
        forward_indexes = torch.cat(
            [torch.zeros(1, forward_indexes.shape[1]).to(forward_indexes), forward_indexes + 1],
            dim=0,
        )
        # fill in masked regions
        features = torch.cat(
            [
                features,
                self.mask_token.expand(
                    backward_indexes.shape[0] - features.shape[0], features.shape[1], -1
                ),
            ],
            dim=0,
        )

        # unshuffle to original positions for positional embedding so we can do cross attention during decoding
        features = take_indexes(features, backward_indexes)
        features = features + self.pos_embedding

        reshuffled = take_indexes(features, forward_indexes)
        features, masked = reshuffled[:T], reshuffled[T:]

        masked = rearrange(masked, "t b c -> b t c")
        features = rearrange(features, "t b c -> b t c")

        for transformer in self.transformer:
            masked = transformer(masked, features)

        # no need to remove cls token, it's part of the features vector
        masked = rearrange(masked, "b t c -> t b c")

        # (npatches x npatches x npatches) b (emb dim) -> (npatches x npatches x npatches) b (z y x)
        masked = self.decoder_norm(masked)
        patches = self.head(masked)

        # add back in visible/encoded tokens that we don't calculate loss on
        patches = torch.cat(
            [torch.zeros((T - 1, B, patches.shape[-1]), requires_grad=False).to(patches), patches],
            dim=0,
        )
        patches = take_indexes(patches, backward_indexes[1:] - 1)

        mask = torch.zeros_like(patches)
        mask[T - 1 :] = 1
        mask = take_indexes(mask, backward_indexes[1:] - 1)
        # patches to image
        img = self.patch2img(patches)
        img = torch.nn.functional.interpolate(
            img, tuple(torch.as_tensor(patch_size) * self.num_patches)
        )

        mask = self.patch2img(mask)
        mask = torch.nn.functional.interpolate(
            mask, tuple(torch.as_tensor(patch_size) * self.num_patches), mode="nearest"
        )
        return img, mask
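
As a sanity check, the decoder above can be exercised end to end with toy inputs; the patch grid, masking split, and dimensions below are arbitrary assumptions for illustration, not cyto_dl defaults:

# Illustrative smoke test for CrossMAE_Decoder with toy shapes (assumed values).
import torch

from cyto_dl.nn.vits import CrossMAE_Decoder

num_patches = [2, 2, 2]        # 8 patches total
base_patch_size = [4, 8, 8]
batch, n_visible = 1, 2        # pretend the encoder kept 2 visible patches

decoder = CrossMAE_Decoder(
    num_patches=num_patches,
    spatial_dims=3,
    base_patch_size=base_patch_size,
    enc_dim=768,
    emb_dim=192,
    num_layer=1,
    num_head=3,
)

# encoder output is token-first: (visible tokens + cls, batch, enc_dim)
features = torch.randn(n_visible + 1, batch, 768)
# forward_indexes shuffles patches (visible first); backward_indexes is its inverse
forward_indexes = torch.stack([torch.randperm(8) for _ in range(batch)], dim=1)
backward_indexes = torch.argsort(forward_indexes, dim=0)

img, mask = decoder(features, forward_indexes, backward_indexes, base_patch_size)
print(img.shape, mask.shape)  # expected: both torch.Size([1, 1, 8, 16, 16])

Only the masked tokens are decoded through the cross-attention blocks; visible positions are filled with zeros, and the returned mask marks which voxels were reconstructed.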