VIT Decoder updates #339

Merged
merged 8 commits on Feb 29, 2024
3 changes: 1 addition & 2 deletions configs/model/im2im/mae.yaml
@@ -15,8 +15,7 @@ backbone:
 encoder_layer: 2
 encoder_head: 1
 decoder_layer: 1
-mask_ratio: 0.75
-
+use_crossmae: True
 task_heads: ${kv_to_dict:${model._aux._tasks}}

 optimizer:
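The config change above drops the explicit `mask_ratio` override and enables the new decoder via `use_crossmae`. As a rough illustration (not part of this diff, and assuming both decoder classes accept the same keyword arguments), such a flag would typically just select which decoder class the MAE backbone instantiates:

# Hypothetical sketch -- the actual backbone wiring lives in cyto_dl.nn.vits.mae
# and is not shown in this PR.
from cyto_dl.nn.vits import CrossMAE_Decoder, MAE_Decoder


def build_decoder(num_patches, use_crossmae: bool = False, **decoder_kwargs):
    """Pick the decoder implementation based on the `use_crossmae` config flag."""
    decoder_cls = CrossMAE_Decoder if use_crossmae else MAE_Decoder
    # decoder_kwargs (emb_dim, num_layer, num_head, ...) are assumed to be
    # shared by both decoder classes.
    return decoder_cls(num_patches=num_patches, **decoder_kwargs)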
3 changes: 2 additions & 1 deletion cyto_dl/nn/vits/__init__.py
@@ -1,2 +1,3 @@
+from .cross_mae import CrossMAE_Decoder
 from .mae import MAE_Decoder, MAE_Encoder, MAE_ViT
-from .seg import Seg_ViT, SupperresDecoder
+from .seg import Seg_ViT, SuperresDecoder
1 change: 1 addition & 0 deletions cyto_dl/nn/vits/blocks/__init__.py
@@ -0,0 +1 @@
from .cross_attention import CrossAttention, CrossAttentionBlock, CrossSelfBlock, Mlp
155 changes: 155 additions & 0 deletions cyto_dl/nn/vits/blocks/cross_attention.py
@@ -0,0 +1,155 @@
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath
from timm.models.vision_transformer import Block

# from https://github.com/TonyLianLong/CrossMAE/blob/main/transformer_utils.py


class Mlp(nn.Module):
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x


class CrossAttention(nn.Module):
def __init__(
self,
encoder_dim,
decoder_dim,
num_heads=8,
qkv_bias=False,
qk_scale=None,
attn_drop=0.0,
proj_drop=0.0,
):
super().__init__()
self.num_heads = num_heads
head_dim = decoder_dim // num_heads
self.scale = qk_scale or head_dim**-0.5
self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias)
self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias)
self.attn_drop = attn_drop
self.proj = nn.Linear(decoder_dim, decoder_dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, x, y):
"""query from decoder (x), key and value from encoder (y)"""
B, N, C = x.shape
Ny = y.shape[1]
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
kv = (
self.kv(y)
.reshape(B, Ny, 2, self.num_heads, C // self.num_heads)
.permute(2, 0, 3, 1, 4)
)
k, v = kv[0], kv[1]

attn = F.scaled_dot_product_attention(
q,
k,
v,
dropout_p=self.attn_drop,
)
x = attn.transpose(1, 2).reshape(B, N, C)

x = self.proj(x)
x = self.proj_drop(x)
return x


class CrossAttentionBlock(nn.Module):
def __init__(
self,
encoder_dim,
decoder_dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(decoder_dim)
self.cross_attn = CrossAttention(
encoder_dim,
decoder_dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(decoder_dim)
mlp_hidden_dim = int(decoder_dim * mlp_ratio)
self.mlp = Mlp(
in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
)

def forward(self, x, y):
"""
x: decoder feature; y: encoder feature (after layernorm)
"""
x = x + self.drop_path(self.cross_attn(self.norm1(x), y))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x


class CrossSelfBlock(nn.Module):
def __init__(
self,
emb_dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.x_attn_block = CrossAttentionBlock(
emb_dim,
emb_dim,
num_heads,
mlp_ratio,
qkv_bias,
qk_scale,
drop,
attn_drop,
drop_path,
act_layer,
norm_layer,
)
self.self_attn_block = Block(dim=emb_dim, num_heads=num_heads)

def forward(self, x, y):
"""
x: decoder feature; y: encoder feature
"""
x = self.x_attn_block(x, y)
x = self.self_attn_block(x)
return x
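Since the query projection runs at `decoder_dim` while keys and values are projected from `encoder_dim`, the two token streams fed to `CrossAttentionBlock` may differ in both width and length. A small shape sketch (sizes are illustrative, not taken from any config):

import torch

from cyto_dl.nn.vits.blocks import CrossAttentionBlock

B, n_dec, n_enc = 2, 16, 49          # batch, decoder (masked) tokens, encoder (visible) tokens
enc_dim, dec_dim = 768, 192

block = CrossAttentionBlock(encoder_dim=enc_dim, decoder_dim=dec_dim, num_heads=4)
x = torch.randn(B, n_dec, dec_dim)   # queries come from the decoder stream
y = torch.randn(B, n_enc, enc_dim)   # keys/values come from the encoder stream
out = block(x, y)                    # -> (B, n_dec, dec_dim), same shape as x

In `CrossMAE_Decoder` below, both streams are first projected to `emb_dim`, so `encoder_dim == decoder_dim` there; `CrossSelfBlock` likewise requires a single shared `emb_dim` because it chains the cross-attention block with a standard self-attention `Block`.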
159 changes: 159 additions & 0 deletions cyto_dl/nn/vits/cross_mae.py
@@ -0,0 +1,159 @@
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from timm.models.layers import trunc_normal_

from cyto_dl.nn.vits.blocks import CrossAttentionBlock


def take_indexes(sequences, indexes):
return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))


class CrossMAE_Decoder(torch.nn.Module):
"""Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masekd tokens only attend
to visible tokens."""

def __init__(
self,
num_patches: List[int],
spatial_dims: int = 3,
base_patch_size: Optional[List[int]] = [4, 8, 8],
enc_dim: Optional[int] = 768,
emb_dim: Optional[int] = 192,
num_layer: Optional[int] = 4,
num_head: Optional[int] = 3,
) -> None:
"""
Parameters
----------
num_patches: List[int]
Number of patches in each dimension
spatial_dims: int
Number of spatial dimensions (2 or 3)
base_patch_size: List[int]
Size of each patch
enc_dim: int
Dimension of encoder
emb_dim: int
Dimension of embedding
num_layer: int
Number of transformer layers
num_head: int
Number of heads in transformer
"""
super().__init__()

self.transformer = torch.nn.ParameterList(
[
CrossAttentionBlock(
encoder_dim=emb_dim,
decoder_dim=emb_dim,
num_heads=num_head,
)
for _ in range(num_layer)
]
)
self.decoder_norm = nn.LayerNorm(emb_dim)
self.projection_norm = nn.LayerNorm(emb_dim)

self.projection = torch.nn.Linear(enc_dim, emb_dim)
self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim))

self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size)))
self.num_patches = torch.as_tensor(num_patches)

if spatial_dims == 3:
self.patch2img = Rearrange(
"(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
n_patch_z=num_patches[0],
n_patch_y=num_patches[1],
n_patch_x=num_patches[2],
patch_size_z=base_patch_size[0],
patch_size_y=base_patch_size[1],
patch_size_x=base_patch_size[2],
)
elif spatial_dims == 2:
self.patch2img = Rearrange(
"(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
n_patch_y=num_patches[0],
n_patch_x=num_patches[1],
patch_size_y=base_patch_size[0],
patch_size_x=base_patch_size[1],
)

self.init_weight()

def init_weight(self):
trunc_normal_(self.mask_token, std=0.02)
trunc_normal_(self.pos_embedding, std=0.02)

def forward(self, features, forward_indexes, backward_indexes, patch_size):
T, B, C = features.shape
# we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now
features = self.projection_norm(self.projection(features))

# add cls token
backward_indexes = torch.cat(
[torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
dim=0,
)
forward_indexes = torch.cat(
[torch.zeros(1, forward_indexes.shape[1]).to(forward_indexes), forward_indexes + 1],
dim=0,
)
# fill in masked regions
features = torch.cat(
[
features,
self.mask_token.expand(
backward_indexes.shape[0] - features.shape[0], features.shape[1], -1
),
],
dim=0,
)

# unshuffle to original positions for positional embedding so we can do cross attention during decoding
features = take_indexes(features, backward_indexes)
features = features + self.pos_embedding

reshuffled = take_indexes(features, forward_indexes)
features, masked = reshuffled[:T], reshuffled[T:]

masked = rearrange(masked, "t b c -> b t c")
features = rearrange(features, "t b c -> b t c")

for transformer in self.transformer:
masked = transformer(masked, features)

# no need to remove the cls token, it's part of the features vector
masked = rearrange(masked, "b t c -> t b c")

# (n_patch_z n_patch_y n_patch_x) b emb_dim -> (n_patch_z n_patch_y n_patch_x) b (z y x)
masked = self.decoder_norm(masked)
patches = self.head(masked)

# add back in visible/encoded tokens that we don't calculate loss on
patches = torch.cat(
[torch.zeros((T - 1, B, patches.shape[-1]), requires_grad=False).to(patches), patches],
dim=0,
)
patches = take_indexes(patches, backward_indexes[1:] - 1)

mask = torch.zeros_like(patches)
mask[T - 1 :] = 1
mask = take_indexes(mask, backward_indexes[1:] - 1)
# patches to image
img = self.patch2img(patches)
img = torch.nn.functional.interpolate(
img, tuple(torch.as_tensor(patch_size) * self.num_patches)
)

mask = self.patch2img(mask)
mask = torch.nn.functional.interpolate(
mask, tuple(torch.as_tensor(patch_size) * self.num_patches), mode="nearest"
)
return img, mask
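For orientation, `CrossMAE_Decoder.forward` expects `features` of shape `(n_visible + 1, B, enc_dim)` (visible tokens plus the cls token) and shuffle/unshuffle index tensors of shape `(n_patches, B)` produced by the encoder. The sketch below fabricates those indexes directly, using a per-sample permutation and its argsort inverse; that convention is an assumption about `MAE_Encoder`, which normally supplies them:

# Hypothetical usage sketch; sizes are illustrative and the index convention is assumed.
import torch

from cyto_dl.nn.vits import CrossMAE_Decoder

num_patches = [2, 4, 4]                        # z, y, x -> 32 patch tokens
decoder = CrossMAE_Decoder(
    num_patches=num_patches,
    spatial_dims=3,
    base_patch_size=[4, 8, 8],
    enc_dim=768,
    emb_dim=192,
    num_layer=2,
    num_head=4,
)

B, n_visible, n_tokens = 2, 8, 2 * 4 * 4
forward_indexes = torch.stack([torch.randperm(n_tokens) for _ in range(B)], dim=1)  # (n_tokens, B)
backward_indexes = torch.argsort(forward_indexes, dim=0)                            # inverse permutation

features = torch.randn(n_visible + 1, B, 768)  # visible tokens + cls token, (T, B, enc_dim)
img, mask = decoder(features, forward_indexes, backward_indexes, patch_size=[4, 8, 8])
# img: (B, 1, 8, 32, 32) reconstruction; mask is 1 where patches were predicted by the decoder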