diff --git a/configs/model/im2im/mae.yaml b/configs/model/im2im/mae.yaml
index f03ce6742..ab1fcac39 100644
--- a/configs/model/im2im/mae.yaml
+++ b/configs/model/im2im/mae.yaml
@@ -15,8 +15,7 @@ backbone:
   encoder_layer: 2
   encoder_head: 1
   decoder_layer: 1
-  mask_ratio: 0.75
-
+  use_crossmae: True
 task_heads: ${kv_to_dict:${model._aux._tasks}}
 optimizer:
diff --git a/cyto_dl/nn/vits/__init__.py b/cyto_dl/nn/vits/__init__.py
index 37eb4dcc8..1b0690a2d 100644
--- a/cyto_dl/nn/vits/__init__.py
+++ b/cyto_dl/nn/vits/__init__.py
@@ -1,2 +1,3 @@
+from .cross_mae import CrossMAE_Decoder
 from .mae import MAE_Decoder, MAE_Encoder, MAE_ViT
-from .seg import Seg_ViT, SupperresDecoder
+from .seg import Seg_ViT, SuperresDecoder
diff --git a/cyto_dl/nn/vits/blocks/__init__.py b/cyto_dl/nn/vits/blocks/__init__.py
new file mode 100644
index 000000000..632981b50
--- /dev/null
+++ b/cyto_dl/nn/vits/blocks/__init__.py
@@ -0,0 +1 @@
+from .cross_attention import CrossAttention, CrossAttentionBlock, CrossSelfBlock, Mlp
diff --git a/cyto_dl/nn/vits/blocks/cross_attention.py b/cyto_dl/nn/vits/blocks/cross_attention.py
new file mode 100644
index 000000000..a9dc47d38
--- /dev/null
+++ b/cyto_dl/nn/vits/blocks/cross_attention.py
@@ -0,0 +1,155 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.layers import DropPath
+from timm.models.vision_transformer import Block
+
+# from https://github.com/TonyLianLong/CrossMAE/blob/main/transformer_utils.py
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class CrossAttention(nn.Module):
+    def __init__(
+        self,
+        encoder_dim,
+        decoder_dim,
+        num_heads=8,
+        qkv_bias=False,
+        qk_scale=None,
+        attn_drop=0.0,
+        proj_drop=0.0,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = decoder_dim // num_heads
+        self.scale = qk_scale or head_dim**-0.5
+        self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias)
+        self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias)
+        self.attn_drop = attn_drop
+        self.proj = nn.Linear(decoder_dim, decoder_dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x, y):
+        """query from decoder (x), key and value from encoder (y)"""
+        B, N, C = x.shape
+        Ny = y.shape[1]
+        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
+        kv = (
+            self.kv(y)
+            .reshape(B, Ny, 2, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        k, v = kv[0], kv[1]
+
+        attn = F.scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            dropout_p=self.attn_drop,
+        )
+        x = attn.transpose(1, 2).reshape(B, N, C)
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class CrossAttentionBlock(nn.Module):
+    def __init__(
+        self,
+        encoder_dim,
+        decoder_dim,
+        num_heads,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+    ):
+        super().__init__()
+        self.norm1 = norm_layer(decoder_dim)
+        self.cross_attn = CrossAttention(
+            encoder_dim,
+            decoder_dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.norm2 = norm_layer(decoder_dim)
+        mlp_hidden_dim = int(decoder_dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
+        )
+
+    def forward(self, x, y):
+        """
+        x: decoder feature; y: encoder feature (after layernorm)
+        """
+        x = x + self.drop_path(self.cross_attn(self.norm1(x), y))
+        x = x + self.drop_path(self.mlp(self.norm2(x)))
+        return x
+
+
+class CrossSelfBlock(nn.Module):
+    def __init__(
+        self,
+        emb_dim,
+        num_heads,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+    ):
+        super().__init__()
+        self.x_attn_block = CrossAttentionBlock(
+            emb_dim,
+            emb_dim,
+            num_heads,
+            mlp_ratio,
+            qkv_bias,
+            qk_scale,
+            drop,
+            attn_drop,
+            drop_path,
+            act_layer,
+            norm_layer,
+        )
+        self.self_attn_block = Block(dim=emb_dim, num_heads=num_heads)
+
+    def forward(self, x, y):
+        """
+        x: decoder feature; y: encoder feature
+        """
+        x = self.x_attn_block(x, y)
+        x = self.self_attn_block(x)
+        return x
diff --git a/cyto_dl/nn/vits/cross_mae.py b/cyto_dl/nn/vits/cross_mae.py
new file mode 100644
index 000000000..05f7e744a
--- /dev/null
+++ b/cyto_dl/nn/vits/cross_mae.py
@@ -0,0 +1,159 @@
+from typing import List, Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+from timm.models.layers import trunc_normal_
+
+from cyto_dl.nn.vits.blocks import CrossAttentionBlock
+
+
+def take_indexes(sequences, indexes):
+    return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))
+
+
+class CrossMAE_Decoder(torch.nn.Module):
+    """Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masked tokens only attend
+    to visible tokens."""
+
+    def __init__(
+        self,
+        num_patches: List[int],
+        spatial_dims: int = 3,
+        base_patch_size: Optional[List[int]] = [4, 8, 8],
+        enc_dim: Optional[int] = 768,
+        emb_dim: Optional[int] = 192,
+        num_layer: Optional[int] = 4,
+        num_head: Optional[int] = 3,
+    ) -> None:
+        """
+        Parameters
+        ----------
+        num_patches: List[int]
+            Number of patches in each dimension
+        base_patch_size: Tuple[int]
+            Size of each patch
+        enc_dim: int
+            Dimension of encoder
+        emb_dim: int
+            Dimension of embedding
+        num_layer: int
+            Number of transformer layers
+        num_head: int
+            Number of heads in transformer
+        """
+        super().__init__()
+
+        self.transformer = torch.nn.ParameterList(
+            [
+                CrossAttentionBlock(
+                    encoder_dim=emb_dim,
+                    decoder_dim=emb_dim,
+                    num_heads=num_head,
+                )
+                for _ in range(num_layer)
+            ]
+        )
+        self.decoder_norm = nn.LayerNorm(emb_dim)
+        self.projection_norm = nn.LayerNorm(emb_dim)
+
+        self.projection = torch.nn.Linear(enc_dim, emb_dim)
+        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
+        self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim))
+
+        self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size)))
+        self.num_patches = torch.as_tensor(num_patches)
+
+        if spatial_dims == 3:
+            self.patch2img = Rearrange(
+                "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
+                n_patch_z=num_patches[0],
+                n_patch_y=num_patches[1],
+                n_patch_x=num_patches[2],
+                patch_size_z=base_patch_size[0],
+                patch_size_y=base_patch_size[1],
+                patch_size_x=base_patch_size[2],
+            )
+        elif spatial_dims == 2:
+            self.patch2img = Rearrange(
+                "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
+                n_patch_y=num_patches[0],
+                n_patch_x=num_patches[1],
+                patch_size_y=base_patch_size[0],
+                patch_size_x=base_patch_size[1],
+            )
+
+        self.init_weight()
+
+    def init_weight(self):
+        trunc_normal_(self.mask_token, std=0.02)
+        trunc_normal_(self.pos_embedding, std=0.02)
+
+    def forward(self, features, forward_indexes, backward_indexes, patch_size):
+        T, B, C = features.shape
+        # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now
+        features = self.projection_norm(self.projection(features))
+
+        # add cls token
+        backward_indexes = torch.cat(
+            [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
+            dim=0,
+        )
+        forward_indexes = torch.cat(
+            [torch.zeros(1, forward_indexes.shape[1]).to(forward_indexes), forward_indexes + 1],
+            dim=0,
+        )
+        # fill in masked regions
+        features = torch.cat(
+            [
+                features,
+                self.mask_token.expand(
+                    backward_indexes.shape[0] - features.shape[0], features.shape[1], -1
+                ),
+            ],
+            dim=0,
+        )
+
+        # unshuffle to original positions for positional embedding so we can do cross attention during decoding
+        features = take_indexes(features, backward_indexes)
+        features = features + self.pos_embedding
+
+        reshuffled = take_indexes(features, forward_indexes)
+        features, masked = reshuffled[:T], reshuffled[T:]
+
+        masked = rearrange(masked, "t b c -> b t c")
+        features = rearrange(features, "t b c -> b t c")
+
+        for transformer in self.transformer:
+            masked = transformer(masked, features)
+
+        # no need to remove cls token, it's a part of the features vector
+        masked = rearrange(masked, "b t c -> t b c")
+
+        # (npatches x npatches x npatches) b (emb dim) -> (npatches * npatches * npatches) b (z y x)
+        masked = self.decoder_norm(masked)
+        patches = self.head(masked)
+
+        # add back in visible/encoded tokens that we don't calculate loss on
+        patches = torch.cat(
+            [torch.zeros((T - 1, B, patches.shape[-1]), requires_grad=False).to(patches), patches],
+            dim=0,
+        )
+        patches = take_indexes(patches, backward_indexes[1:] - 1)
+
+        mask = torch.zeros_like(patches)
+        mask[T - 1 :] = 1
+        mask = take_indexes(mask, backward_indexes[1:] - 1)
+        # patches to image
+        img = self.patch2img(patches)
+        img = torch.nn.functional.interpolate(
+            img, tuple(torch.as_tensor(patch_size) * self.num_patches)
+        )
+
+        mask = self.patch2img(mask)
+        mask = torch.nn.functional.interpolate(
+            mask, tuple(torch.as_tensor(patch_size) * self.num_patches), mode="nearest"
+        )
+        return img, mask
diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py
index d565a6da3..95e4dff7b 100644
--- a/cyto_dl/nn/vits/mae.py
+++ b/cyto_dl/nn/vits/mae.py
@@ -4,11 +4,16 @@
 import numpy as np
 import torch
+import torch.nn as nn
 from einops import rearrange, repeat
 from einops.layers.torch import Rearrange
+from monai.networks.nets import Regressor
 from timm.models.layers import trunc_normal_
 from timm.models.vision_transformer import Block

+from cyto_dl.nn.vits.blocks.attention_autoencoder import AttentionAutoencoder
+from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder
+

 def random_indexes(size: int):
     forward_indexes = np.arange(size)
@@ -21,26 +26,23 @@ def take_indexes(sequences, indexes):
     return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))


-class PatchShuffle(torch.nn.Module):
-    def __init__(self, ratio) -> None:
-        super().__init__()
-        self.ratio = ratio
-
-    def forward(self, patches: torch.Tensor):
-        T, B, C = patches.shape
-        remain_T = int(T * (1 - self.ratio))
+def patch_shuffle(patches: torch.Tensor, ratio):
+    T, B, C = patches.shape
+    remain_T = int(T * (1 - ratio))

-        indexes = [random_indexes(T) for _ in range(B)]
-        forward_indexes = torch.as_tensor(
-            np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long
-        ).to(patches.device)
-        backward_indexes = torch.as_tensor(
-            np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long
-        ).to(patches.device)
-        patches = take_indexes(patches, forward_indexes)
-        patches = patches[:remain_T]
+    indexes = [random_indexes(T) for _ in range(B)]
+    forward_indexes = torch.as_tensor(
+        np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long
+    ).to(patches.device)
+    backward_indexes = torch.as_tensor(
+        np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long
+    ).to(patches.device)
+    # forward indexes: index in image -> shuffled patch
+    # backward indexes: shuffled patch -> index in image
+    patches = take_indexes(patches, forward_indexes)
+    patches = patches[:remain_T]

-        return patches, forward_indexes, backward_indexes
+    return patches, forward_indexes, backward_indexes


 class Patchify(torch.nn.Module):
@@ -59,6 +61,11 @@ def __init__(self, base_patch_size, emb_dim, n_patches, spatial_dims=3):
         self.spatial_dims = spatial_dims
         self.conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d

+        if spatial_dims == 3:
+            self.img2token = Rearrange("b c z y x -> (z y x) b c")
+        elif spatial_dims == 2:
+            self.img2token = Rearrange("b c y x -> (y x) b c")
+
     def resample_weight(self, length):
         return torch.nn.functional.interpolate(self.weight, size=length)

@@ -69,6 +76,7 @@ def forward(self, img):
         tokens = self.conv(img, weight=self.resample_weight(patch_size), stride=patch_size)
         tokens = self.norm(tokens)
         assert np.all(tokens.shape[-self.spatial_dims :] == self.n_patches)
+        tokens = self.img2token(tokens)
         return tokens, patch_size

@@ -81,7 +89,6 @@ def __init__(
         emb_dim: Optional[int] = 192,
         num_layer: Optional[int] = 12,
         num_head: Optional[int] = 3,
-        mask_ratio: Optional[int] = 0.75,
     ) -> None:
         """
         Parameters
@@ -98,13 +105,10 @@
             Number of transformer layers
         num_head: int
             Number of heads in transformer
-        mask_ratio: float
-            Ratio of patches to mask out
         """
         super().__init__()
         self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
         self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches), 1, emb_dim))
-        self.shuffle = PatchShuffle(mask_ratio)
         self.patchify = Patchify(base_patch_size, emb_dim, num_patches, spatial_dims)
         self.transformer = torch.nn.Sequential(
@@ -112,32 +116,26 @@ def __init__(
         )
         self.layer_norm = torch.nn.LayerNorm(emb_dim)

-        if spatial_dims == 3:
-            self.img2token = Rearrange("b c z y x -> (z y x) b c")
-        elif spatial_dims == 2:
-            self.img2token = Rearrange("b c y x -> (y x) b c")
-
         self.init_weight()

     def init_weight(self):
         trunc_normal_(self.cls_token, std=0.02)
         trunc_normal_(self.pos_embedding, std=0.02)

-    def forward(self, img, do_mask=True):
+    def forward(self, img, mask_ratio=0.75):
         patches, patch_size = self.patchify(img)
-        patches = self.img2token(patches)
         patches = patches + self.pos_embedding

         backward_indexes = None
-        if do_mask:
-            patches, _, backward_indexes = self.shuffle(patches)
+        if mask_ratio > 0:
+            patches, forward_indexes, backward_indexes = patch_shuffle(patches, mask_ratio)
         patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
         patches = rearrange(patches, "t b c -> b t c")
         features = self.layer_norm(self.transformer(patches))
         features = rearrange(features, "b t c -> t b c")
-        if do_mask:
-            return features, backward_indexes, patch_size
+        if mask_ratio > 0:
+            return features, forward_indexes, backward_indexes, patch_size
         return features

@@ -147,6 +145,7 @@ def __init__(
         num_patches: List[int],
         spatial_dims: int = 3,
         base_patch_size: Optional[List[int]] = [4, 8, 8],
+        enc_dim: Optional[int] = 768,
         emb_dim: Optional[int] = 192,
         num_layer: Optional[int] = 4,
         num_head: Optional[int] = 3,
@@ -166,14 +165,17 @@
             Number of heads in transformer
         """
         super().__init__()
+        self.projection_norm = nn.LayerNorm(emb_dim)
+        self.projection = torch.nn.Linear(enc_dim, emb_dim)
         self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
         self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim))

         self.transformer = torch.nn.Sequential(
             *[Block(emb_dim, num_head) for _ in range(num_layer)]
         )
-
-        self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size)))
+        out_dim = torch.prod(torch.as_tensor(base_patch_size)).item()
+        self.head_norm = nn.LayerNorm(out_dim)
+        self.head = torch.nn.Linear(emb_dim, out_dim)
         self.num_patches = torch.as_tensor(num_patches)

         if spatial_dims == 3:
@@ -201,8 +203,11 @@ def init_weight(self):
         trunc_normal_(self.mask_token, std=0.02)
         trunc_normal_(self.pos_embedding, std=0.02)

-    def forward(self, features, backward_indexes, patch_size):
+    def forward(self, features, forward_indexes, backward_indexes, patch_size):
         T = features.shape[0]
+        # project from encoder dimension to decoder dimension
+        features = self.projection_norm(self.projection(features))
+
         backward_indexes = torch.cat(
             [torch.zeros(1, backward_indexes.shape[1]).to(backward_indexes), backward_indexes + 1],
             dim=0,
@@ -228,7 +233,7 @@ def forward(self, features, backward_indexes, patch_size):
         features = features[1:]  # remove global feature

         # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x)
-        patches = self.head(features)
+        patches = self.head_norm(self.head(features))
         mask = torch.zeros_like(patches)
         mask[T:] = 1
         mask = take_indexes(mask, backward_indexes[1:] - 1)
@@ -256,7 +261,9 @@ def __init__(
         encoder_head: Optional[int] = 8,
         decoder_layer: Optional[int] = 4,
         decoder_head: Optional[int] = 8,
+        decoder_dim: Optional[int] = 192,
         mask_ratio: Optional[int] = 0.75,
+        use_crossmae: Optional[bool] = False,
     ) -> None:
         """
         Parameters
@@ -293,6 +300,8 @@ def __init__(
             len(base_patch_size) == spatial_dims
         ), "base_patch_size must be of length spatial_dims"

+        self.mask_ratio = mask_ratio
+
         self.encoder = MAE_Encoder(
             num_patches,
             spatial_dims,
@@ -300,13 +309,24 @@ def __init__(
             emb_dim,
             encoder_layer,
             encoder_head,
-            mask_ratio,
         )
-        self.decoder = MAE_Decoder(
-            num_patches, spatial_dims, base_patch_size, emb_dim, decoder_layer, decoder_head
+
+        decoder_class = MAE_Decoder
+        if use_crossmae:
+            decoder_class = CrossMAE_Decoder
+        self.decoder = decoder_class(
+            num_patches=num_patches,
+            spatial_dims=spatial_dims,
+            base_patch_size=base_patch_size,
+            enc_dim=emb_dim,
+            emb_dim=decoder_dim,
+            num_layer=decoder_layer,
+            num_head=decoder_head,
         )

     def forward(self, img):
-        features, backward_indexes, patch_size = self.encoder(img)
-        predicted_img, mask = self.decoder(features, backward_indexes, patch_size)
+        features, forward_indexes, backward_indexes, patch_size = self.encoder(
+            img, self.mask_ratio
+        )
+        predicted_img, mask = self.decoder(features, forward_indexes, backward_indexes, patch_size)
         return predicted_img, mask
diff --git a/cyto_dl/nn/vits/seg.py b/cyto_dl/nn/vits/seg.py
index c3942d2d5..fc6b2ea28 100644
--- a/cyto_dl/nn/vits/seg.py
+++ b/cyto_dl/nn/vits/seg.py
@@ -1,5 +1,6 @@
 from typing import List, Optional, Union

+import numpy as np
 import torch
 from einops.layers.torch import Rearrange
 from monai.networks.blocks import UnetOutBlock, UnetResBlock, UpSample
@@ -7,7 +8,7 @@
 from cyto_dl.nn.vits.mae import MAE_Encoder


-class SupperresDecoder(torch.nn.Module):
+class SuperresDecoder(torch.nn.Module):
     def __init__(
         self,
         spatial_dims: int = 3,
@@ -42,15 +43,11 @@ def __init__(
         super().__init__()

         self.lr_conv = []
-        for i in range(num_layer):
-            if i == 0:
-                num_channels = 1
-            else:
-                num_channels = 16
+        for _ in range(num_layer):
             self.lr_conv.append(
                 UnetResBlock(
                     spatial_dims=spatial_dims,
-                    in_channels=num_channels,
+                    in_channels=n_decoder_filters,
                     out_channels=n_decoder_filters,
                     stride=1,
                     kernel_size=3,
@@ -66,7 +63,7 @@ def __init__(
                     spatial_dims=spatial_dims,
                     in_channels=n_decoder_filters,
                     out_channels=n_decoder_filters,
-                    scale_factor=upsample_factor,
+                    scale_factor=np.array(upsample_factor),
                     mode="nontrainable",
                     interp_mode="trilinear",
                 ),
@@ -87,9 +84,10 @@ def __init__(
                 ),
             )

-        self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size)))
+        self.head = torch.nn.Linear(
+            emb_dim, torch.prod(torch.as_tensor(base_patch_size)) * n_decoder_filters
+        )
         self.num_patches = torch.as_tensor(num_patches)
-
         if spatial_dims == 3:
             self.patch2img = Rearrange(
                 "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
@@ -99,6 +97,7 @@ def __init__(
                 patch_size_z=base_patch_size[0],
                 patch_size_y=base_patch_size[1],
                 patch_size_x=base_patch_size[2],
+                c=n_decoder_filters,
             )
         elif spatial_dims == 2:
             self.patch2img = Rearrange(
@@ -107,13 +106,14 @@ def __init__(
                 n_patch_x=num_patches[1],
                 patch_size_y=base_patch_size[0],
                 patch_size_x=base_patch_size[1],
+                c=n_decoder_filters,
             )

     def forward(self, features):
         # remove global feature
         features = features[1:]

-        # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x)
+        # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (c z y x)
         patches = self.head(features)

         # patches to image
@@ -138,7 +138,6 @@ def __init__(
         decoder_layer: Optional[int] = 3,
         n_decoder_filters: Optional[int] = 16,
         out_channels: Optional[int] = 6,
-        mask_ratio: Optional[int] = 0.75,
         upsample_factor: Optional[List[int]] = [2.6134, 2.5005, 2.5005],
         encoder_ckpt: Optional[str] = None,
         freeze_encoder: Optional[bool] = True,
@@ -190,22 +189,21 @@ def __init__(
             emb_dim=emb_dim,
             num_layer=encoder_layer,
             num_head=encoder_head,
-            mask_ratio=mask_ratio,
         )
         if encoder_ckpt is not None:
             model = torch.load(encoder_ckpt)
             enc_state_dict = {
                 k.replace("backbone.encoder.", ""): v
+                # k.replace("model.encoder.", ""): v
                 for k, v in model["state_dict"].items()
                 if "encoder" in k
             }
-
             self.encoder.load_state_dict(enc_state_dict)
         if freeze_encoder:
             for param in self.encoder.parameters():
                 param.requires_grad = False

-        self.decoder = SupperresDecoder(
+        self.decoder = SuperresDecoder(
             spatial_dims,
             num_patches,
             base_patch_size,
@@ -217,5 +215,5 @@ def __init__(
         )

     def forward(self, img):
-        features = self.encoder(img, do_mask=False)
+        features = self.encoder(img, mask_ratio=0)
         return self.decoder(features)
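
For reference, a minimal usage sketch of the new decoder path, not taken from the patch itself: it assumes MAE_ViT exposes the constructor arguments that appear in the hunks above (num_patches, base_patch_size, emb_dim, decoder_dim, mask_ratio, use_crossmae) and that the input is a single-channel volume whose per-axis size equals num_patches * base_patch_size; the sizes below are illustrative, loosely mirroring configs/model/im2im/mae.yaml.

import torch

from cyto_dl.nn.vits import MAE_ViT

# illustrative sizes: 4 x 4 x 4 patches of size 4 x 8 x 8 -> a 16 x 32 x 32 single-channel volume
model = MAE_ViT(
    spatial_dims=3,
    num_patches=[4, 4, 4],
    base_patch_size=[4, 8, 8],
    emb_dim=768,  # encoder width, forwarded to the decoder as enc_dim
    encoder_layer=2,
    encoder_head=1,
    decoder_layer=1,
    decoder_dim=192,  # decoder width; encoder features are projected from emb_dim to decoder_dim
    mask_ratio=0.75,
    use_crossmae=True,  # selects CrossMAE_Decoder instead of MAE_Decoder
)

img = torch.randn(2, 1, 16, 32, 32)
predicted_img, mask = model(img)
# predicted_img matches the input's spatial size; mask is 1 over the patches that were masked out
print(predicted_img.shape, mask.shape)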