-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add crossmae * multich segmentation projection * add cross attention block * add crossmae decoder * add crossmae to init * remove patchify * use crossmae as default * remove note --------- Co-authored-by: Benjamin Morris <[email protected]>
- Loading branch information
1 parent
4860c8b
commit 7536cd8
Showing
7 changed files
with
394 additions
and
61 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .cross_mae import CrossMAE_Decoder | ||
from .mae import MAE_Decoder, MAE_Encoder, MAE_ViT | ||
from .seg import Seg_ViT, SupperresDecoder | ||
from .seg import Seg_ViT, SuperresDecoder |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .cross_attention import CrossAttention, CrossAttentionBlock, CrossSelfBlock, Mlp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from timm.layers import DropPath | ||
from timm.models.vision_transformer import Block | ||
|
||
# from https://github.com/TonyLianLong/CrossMAE/blob/main/transformer_utils.py | ||
|
||
|
||
class Mlp(nn.Module): | ||
def __init__( | ||
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 | ||
): | ||
super().__init__() | ||
out_features = out_features or in_features | ||
hidden_features = hidden_features or in_features | ||
self.fc1 = nn.Linear(in_features, hidden_features) | ||
self.act = act_layer() | ||
self.fc2 = nn.Linear(hidden_features, out_features) | ||
self.drop = nn.Dropout(drop) | ||
|
||
def forward(self, x): | ||
x = self.fc1(x) | ||
x = self.act(x) | ||
x = self.drop(x) | ||
x = self.fc2(x) | ||
x = self.drop(x) | ||
return x | ||
|
||
|
||
class CrossAttention(nn.Module): | ||
def __init__( | ||
self, | ||
encoder_dim, | ||
decoder_dim, | ||
num_heads=8, | ||
qkv_bias=False, | ||
qk_scale=None, | ||
attn_drop=0.0, | ||
proj_drop=0.0, | ||
): | ||
super().__init__() | ||
self.num_heads = num_heads | ||
head_dim = decoder_dim // num_heads | ||
self.scale = qk_scale or head_dim**-0.5 | ||
self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias) | ||
self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias) | ||
self.attn_drop = attn_drop | ||
self.proj = nn.Linear(decoder_dim, decoder_dim) | ||
self.proj_drop = nn.Dropout(proj_drop) | ||
|
||
def forward(self, x, y): | ||
"""query from decoder (x), key and value from encoder (y)""" | ||
B, N, C = x.shape | ||
Ny = y.shape[1] | ||
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) | ||
kv = ( | ||
self.kv(y) | ||
.reshape(B, Ny, 2, self.num_heads, C // self.num_heads) | ||
.permute(2, 0, 3, 1, 4) | ||
) | ||
k, v = kv[0], kv[1] | ||
|
||
attn = F.scaled_dot_product_attention( | ||
q, | ||
k, | ||
v, | ||
dropout_p=self.attn_drop, | ||
) | ||
x = attn.transpose(1, 2).reshape(B, N, C) | ||
|
||
x = self.proj(x) | ||
x = self.proj_drop(x) | ||
return x | ||
|
||
|
||
class CrossAttentionBlock(nn.Module): | ||
def __init__( | ||
self, | ||
encoder_dim, | ||
decoder_dim, | ||
num_heads, | ||
mlp_ratio=4.0, | ||
qkv_bias=False, | ||
qk_scale=None, | ||
drop=0.0, | ||
attn_drop=0.0, | ||
drop_path=0.0, | ||
act_layer=nn.GELU, | ||
norm_layer=nn.LayerNorm, | ||
): | ||
super().__init__() | ||
self.norm1 = norm_layer(decoder_dim) | ||
self.cross_attn = CrossAttention( | ||
encoder_dim, | ||
decoder_dim, | ||
num_heads=num_heads, | ||
qkv_bias=qkv_bias, | ||
qk_scale=qk_scale, | ||
attn_drop=attn_drop, | ||
proj_drop=drop, | ||
) | ||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here | ||
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() | ||
self.norm2 = norm_layer(decoder_dim) | ||
mlp_hidden_dim = int(decoder_dim * mlp_ratio) | ||
self.mlp = Mlp( | ||
in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop | ||
) | ||
|
||
def forward(self, x, y): | ||
""" | ||
x: decoder feature; y: encoder feature (after layernorm) | ||
""" | ||
x = x + self.drop_path(self.cross_attn(self.norm1(x), y)) | ||
x = x + self.drop_path(self.mlp(self.norm2(x))) | ||
return x | ||
|
||
|
||
class CrossSelfBlock(nn.Module): | ||
def __init__( | ||
self, | ||
emb_dim, | ||
num_heads, | ||
mlp_ratio=4.0, | ||
qkv_bias=False, | ||
qk_scale=None, | ||
drop=0.0, | ||
attn_drop=0.0, | ||
drop_path=0.0, | ||
act_layer=nn.GELU, | ||
norm_layer=nn.LayerNorm, | ||
): | ||
super().__init__() | ||
self.x_attn_block = CrossAttentionBlock( | ||
emb_dim, | ||
emb_dim, | ||
num_heads, | ||
mlp_ratio, | ||
qkv_bias, | ||
qk_scale, | ||
drop, | ||
attn_drop, | ||
drop_path, | ||
act_layer, | ||
norm_layer, | ||
) | ||
self.self_attn_block = Block(dim=emb_dim, num_heads=num_heads) | ||
|
||
def forward(self, x, y): | ||
""" | ||
x: decoder feature; y: encoder feature | ||
""" | ||
x = self.x_attn_block(x, y) | ||
x = self.self_attn_block(x) | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,159 @@ | ||
from typing import List, Optional | ||
|
||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
from einops import rearrange, repeat | ||
from einops.layers.torch import Rearrange | ||
from timm.models.layers import trunc_normal_ | ||
|
||
from cyto_dl.nn.vits.blocks import CrossAttentionBlock | ||
|
||
|
||
def take_indexes(sequences, indexes): | ||
return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1])) | ||
|
||
|
||
class CrossMAE_Decoder(torch.nn.Module): | ||
"""Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masekd tokens only attend | ||
to visible tokens.""" | ||
|
||
def __init__( | ||
self, | ||
num_patches: List[int], | ||
spatial_dims: int = 3, | ||
base_patch_size: Optional[List[int]] = [4, 8, 8], | ||
enc_dim: Optional[int] = 768, | ||
emb_dim: Optional[int] = 192, | ||
num_layer: Optional[int] = 4, | ||
num_head: Optional[int] = 3, | ||
) -> None: | ||
""" | ||
Parameters | ||
---------- | ||
num_patches: List[int] | ||
Number of patches in each dimension | ||
base_patch_size: Tuple[int] | ||
Size of each patch | ||
enc_dim: int | ||
Dimension of encoder | ||
emb_dim: int | ||
Dimension of embedding | ||
num_layer: int | ||
Number of transformer layers | ||
num_head: int | ||
Number of heads in transformer | ||
""" | ||
super().__init__() | ||
|
||
self.transformer = torch.nn.ParameterList( | ||
[ | ||
CrossAttentionBlock( | ||
encoder_dim=emb_dim, | ||
decoder_dim=emb_dim, | ||
num_heads=num_head, | ||
) | ||
for _ in range(num_layer) | ||
] | ||
) | ||
self.decoder_norm = nn.LayerNorm(emb_dim) | ||
self.projection_norm = nn.LayerNorm(emb_dim) | ||
|
||
self.projection = torch.nn.Linear(enc_dim, emb_dim) | ||
self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) | ||
self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim)) | ||
|
||
self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size))) | ||
self.num_patches = torch.as_tensor(num_patches) | ||
|
||
if spatial_dims == 3: | ||
self.patch2img = Rearrange( | ||
"(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", | ||
n_patch_z=num_patches[0], | ||
n_patch_y=num_patches[1], | ||
n_patch_x=num_patches[2], | ||
patch_size_z=base_patch_size[0], | ||
patch_size_y=base_patch_size[1], | ||
patch_size_x=base_patch_size[2], | ||
) | ||
elif spatial_dims == 2: | ||
self.patch2img = Rearrange( | ||
"(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", | ||
n_patch_y=num_patches[0], | ||
n_patch_x=num_patches[1], | ||
patch_size_y=base_patch_size[0], | ||
patch_size_x=base_patch_size[1], | ||
) | ||
|
||
self.init_weight() | ||
|
||
def init_weight(self): | ||
trunc_normal_(self.mask_token, std=0.02) | ||
trunc_normal_(self.pos_embedding, std=0.02) | ||
|
||
def forward(self, features, forward_indexes, backward_indexes, patch_size): | ||
T, B, C = features.shape | ||
# we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now | ||
features = self.projection_norm(self.projection(features)) | ||
|
||
# add cls token | ||
backward_indexes = torch.cat( | ||
[torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1], | ||
dim=0, | ||
) | ||
forward_indexes = torch.cat( | ||
[torch.zeros(1, forward_indexes.shape[1]).to(forward_indexes), forward_indexes + 1], | ||
dim=0, | ||
) | ||
# fill in masked regions | ||
features = torch.cat( | ||
[ | ||
features, | ||
self.mask_token.expand( | ||
backward_indexes.shape[0] - features.shape[0], features.shape[1], -1 | ||
), | ||
], | ||
dim=0, | ||
) | ||
|
||
# unshuffle to original positions for positional embedding so we can do cross attention during decoding | ||
features = take_indexes(features, backward_indexes) | ||
features = features + self.pos_embedding | ||
|
||
reshuffled = take_indexes(features, forward_indexes) | ||
features, masked = reshuffled[:T], reshuffled[T:] | ||
|
||
masked = rearrange(masked, "t b c -> b t c") | ||
features = rearrange(features, "t b c -> b t c") | ||
|
||
for transformer in self.transformer: | ||
masked = transformer(masked, features) | ||
|
||
# noneed to remove cls token, it's a part of the features vector | ||
masked = rearrange(masked, "b t c -> t b c") | ||
|
||
# (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x) | ||
masked = self.decoder_norm(masked) | ||
patches = self.head(masked) | ||
|
||
# add back in visible/encoded tokens that we don't calculate loss on | ||
patches = torch.cat( | ||
[torch.zeros((T - 1, B, patches.shape[-1]), requires_grad=False).to(patches), patches], | ||
dim=0, | ||
) | ||
patches = take_indexes(patches, backward_indexes[1:] - 1) | ||
|
||
mask = torch.zeros_like(patches) | ||
mask[T - 1 :] = 1 | ||
mask = take_indexes(mask, backward_indexes[1:] - 1) | ||
# patches to image | ||
img = self.patch2img(patches) | ||
img = torch.nn.functional.interpolate( | ||
img, tuple(torch.as_tensor(patch_size) * self.num_patches) | ||
) | ||
|
||
mask = self.patch2img(mask) | ||
mask = torch.nn.functional.interpolate( | ||
mask, tuple(torch.as_tensor(patch_size) * self.num_patches), mode="nearest" | ||
) | ||
return img, mask |
Oops, something went wrong.