simplify patchify (#344)
Co-authored-by: Benjamin Morris <[email protected]>
benjijamorris and Benjamin Morris authored Mar 6, 2024
1 parent 0c2b702 commit 1c11e5b
Showing 3 changed files with 154 additions and 101 deletions.
21 changes: 6 additions & 15 deletions cyto_dl/nn/vits/cross_mae.py
@@ -11,11 +11,13 @@


def take_indexes(sequences, indexes):
return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))
return torch.gather(
sequences, 0, repeat(indexes.to(sequences.device), "t b -> t b c", c=sequences.shape[-1])
)


class CrossMAE_Decoder(torch.nn.Module):
"""Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masekd tokens only attend
"""Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masked tokens only attend
to visible tokens."""

def __init__(
@@ -91,7 +93,7 @@ def init_weight(self):
trunc_normal_(self.mask_token, std=0.02)
trunc_normal_(self.pos_embedding, std=0.02)

def forward(self, features, forward_indexes, backward_indexes, patch_size):
def forward(self, features, forward_indexes, backward_indexes):
T, B, C = features.shape
# we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now
features = self.projection_norm(self.projection(features))
@@ -142,18 +144,7 @@ def forward(self, features, forward_indexes, backward_indexes, patch_size):
dim=0,
)
patches = take_indexes(patches, backward_indexes[1:] - 1)

mask = torch.zeros_like(patches)
mask[T - 1 :] = 1
mask = take_indexes(mask, backward_indexes[1:] - 1)
# patches to image
img = self.patch2img(patches)
img = torch.nn.functional.interpolate(
img, tuple(torch.as_tensor(patch_size) * self.num_patches)
)

mask = self.patch2img(mask)
mask = torch.nn.functional.interpolate(
mask, tuple(torch.as_tensor(patch_size) * self.num_patches), mode="nearest"
)
return img, mask
return img
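
For context (not part of the diff): a minimal sketch of the updated take_indexes behaviour, assuming take_indexes is importable from cyto_dl.nn.vits.cross_mae and using hypothetical tensor shapes. The index tensor is moved to the sequence tensor's device before gathering.

import torch
from cyto_dl.nn.vits.cross_mae import take_indexes

tokens = torch.randn(5, 2, 4)  # (tokens, batch, channels)
# one random permutation of the 5 tokens per batch element
order = torch.stack([torch.randperm(5) for _ in range(2)], dim=-1)  # (tokens, batch)
reordered = take_indexes(tokens, order)  # order is moved to tokens' device internally
print(reordered.shape)  # torch.Size([5, 2, 4])
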
204 changes: 134 additions & 70 deletions cyto_dl/nn/vits/mae.py
@@ -7,7 +7,6 @@
import torch.nn as nn
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from monai.networks.nets import Regressor
from timm.models.layers import trunc_normal_
from timm.models.vision_transformer import Block

@@ -22,61 +21,128 @@ def random_indexes(size: int):


def take_indexes(sequences, indexes):
return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1]))


def patch_shuffle(patches: torch.Tensor, ratio):
T, B, C = patches.shape
remain_T = int(T * (1 - ratio))

indexes = [random_indexes(T) for _ in range(B)]
forward_indexes = torch.as_tensor(
np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long
).to(patches.device)
backward_indexes = torch.as_tensor(
np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long
).to(patches.device)
# forward indexes : index in image -> shuffledpatch
# backward indexes : shuffled patch -> index in image
patches = take_indexes(patches, forward_indexes)
patches = patches[:remain_T]

return patches, forward_indexes, backward_indexes
return torch.gather(
sequences, 0, repeat(indexes.to(sequences.device), "t b -> t b c", c=sequences.shape[-1])
)


class Patchify(torch.nn.Module):
# based on https://github.com/google-research/big_vision/blob/main/big_vision/models/proj/flexi/vit.py
"""Class for flexibly turning image into sequence of patches.
"""Class for converting images to a masked sequence of patches with positional embeddings."""

Convolutional weights are resized to match the `base_patch_size`.
"""

def __init__(self, base_patch_size, emb_dim, n_patches, spatial_dims=3):
def __init__(
self,
patch_size: List[int],
emb_dim: int,
n_patches: List[int],
spatial_dims: int = 3,
context_pixels: List[int] = [0, 0, 0],
input_channels: int = 1,
):
"""
Parameters
----------
patch_size: List[int]
Size of each patch
emb_dim: int
Dimension of encoder
n_patches: List[int]
Number of patches in each spatial dimension
spatial_dims: int
Number of spatial dimensions
context_pixels: List[int]
Number of extra pixels around each patch to include in convolutional embedding to encoder dimension.
input_channels: int
Number of input channels
"""
super().__init__()
self.n_patches = np.asarray(n_patches)
self.weight = torch.nn.Parameter(torch.zeros(emb_dim, 1, *base_patch_size))
self.norm = torch.nn.LayerNorm([emb_dim, *n_patches[:spatial_dims]])
self.emb_dim = emb_dim
self.spatial_dims = spatial_dims
self.conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d

self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(n_patches), 1, emb_dim))

context_pixels = context_pixels[:spatial_dims]
weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int)

if spatial_dims == 3:
self.conv = nn.Conv3d(
in_channels=input_channels,
out_channels=emb_dim,
kernel_size=weight_size,
stride=patch_size,
padding=context_pixels,
)
self.img2token = Rearrange("b c z y x -> (z y x) b c")
self.patch2img = Rearrange(
"(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x",
n_patch_z=n_patches[0],
n_patch_y=n_patches[1],
n_patch_x=n_patches[2],
)
elif spatial_dims == 2:
self.conv = nn.Conv2d(
in_channels=input_channels,
out_channels=emb_dim,
kernel_size=weight_size,
stride=patch_size,
padding=context_pixels,
)
self.img2token = Rearrange("b c y x -> (y x) b c")
self.patch2img = Rearrange(
"(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x",
n_patch_y=n_patches[0],
n_patch_x=n_patches[1],
)

def resample_weight(self, length):
return torch.nn.functional.interpolate(self.weight, size=length)
self._init_weight()

def forward(self, img):
patch_size = (
(np.asarray(img.shape[-self.spatial_dims :]) / self.n_patches).astype(int).tolist()
def _init_weight(self):
trunc_normal_(self.pos_embedding, std=0.02)

def get_mask(self, img, n_visible_patches, num_patches):
B = img.shape[0]

indexes = [random_indexes(num_patches) for _ in range(B)]
# forward indexes : index in image -> shuffled patch
forward_indexes = torch.as_tensor(
np.stack([i[0] for i in indexes], axis=-1), dtype=torch.long
)
# backward indexes : shuffled patch -> index in image
backward_indexes = torch.as_tensor(
np.stack([i[1] for i in indexes], axis=-1), dtype=torch.long
)

mask = torch.zeros(num_patches, B, 1)
# visible patches are first
mask[:n_visible_patches] = 1
mask = take_indexes(mask, backward_indexes)
mask = self.patch2img(mask)
# one pixel per masked patch, interpolate to size of input image
mask = torch.nn.functional.interpolate(
mask, img.shape[-self.spatial_dims :], mode="nearest"
)
tokens = self.conv(img, weight=self.resample_weight(patch_size), stride=patch_size)
tokens = self.norm(tokens)
assert np.all(tokens.shape[-self.spatial_dims :] == self.n_patches)

return mask.to(img), forward_indexes, backward_indexes

def forward(self, img, mask_ratio):
# generate mask
num_patches = np.prod(self.n_patches)
n_visible_patches = int(num_patches * (1 - mask_ratio))
mask, forward_indexes, backward_indexes = self.get_mask(
img, n_visible_patches, num_patches
)
# generate patches
tokens = self.conv(img * mask)
tokens = self.img2token(tokens)
return tokens, patch_size
# add position embedding
tokens = tokens + self.pos_embedding
if mask_ratio > 0:
# extract visible patches
tokens = take_indexes(tokens, forward_indexes)[:n_visible_patches]

# mask is used above to mask out patches, we need to invert it for loss calculation
mask = (1 - mask).bool()

return tokens, mask, forward_indexes, backward_indexes
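
A hedged usage sketch of the new Patchify interface (not part of the diff); the import path and tensor shapes are assumptions, chosen so that n_patches times patch_size matches the input size.

import torch
from cyto_dl.nn.vits.mae import Patchify

patchify = Patchify(
    patch_size=[4, 4, 4],    # size of each patch in voxels
    emb_dim=192,
    n_patches=[8, 8, 8],     # implies a 32x32x32 input
    spatial_dims=3,
    context_pixels=[0, 0, 0],
    input_channels=1,
)

img = torch.randn(2, 1, 32, 32, 32)  # (batch, channels, z, y, x)
tokens, mask, forward_idx, backward_idx = patchify(img, mask_ratio=0.75)
# tokens: (n_visible_patches, batch, emb_dim) = (128, 2, 192), position embeddings already added
# mask:   (batch, 1, 32, 32, 32) bool, True where patches were masked out
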


class MAE_Encoder(torch.nn.Module):
@@ -88,6 +154,8 @@ def __init__(
emb_dim: Optional[int] = 192,
num_layer: Optional[int] = 12,
num_head: Optional[int] = 3,
context_pixels: Optional[List[int]] = [0, 0, 0],
input_channels: Optional[int] = 1,
) -> None:
"""
Parameters
@@ -104,11 +172,16 @@ def __init__(
Number of transformer layers
num_head: int
Number of heads in transformer
context_pixels: List[int]
Number of extra pixels around each patch to include in convolutional embedding to encoder dimension.
input_channels: int
Number of input channels
"""
super().__init__()
self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches), 1, emb_dim))
self.patchify = Patchify(base_patch_size, emb_dim, num_patches, spatial_dims)
self.patchify = Patchify(
base_patch_size, emb_dim, num_patches, spatial_dims, context_pixels, input_channels
)

self.transformer = torch.nn.Sequential(
*[Block(emb_dim, num_head) for _ in range(num_layer)]
@@ -119,22 +192,15 @@ def __init__(

def init_weight(self):
trunc_normal_(self.cls_token, std=0.02)
trunc_normal_(self.pos_embedding, std=0.02)

def forward(self, img, mask_ratio=0.75):
patches, patch_size = self.patchify(img)
patches = patches + self.pos_embedding

backward_indexes = None
if mask_ratio > 0:
patches, forward_indexes, backward_indexes = patch_shuffle(patches, mask_ratio)

patches, mask, forward_indexes, backward_indexes = self.patchify(img, mask_ratio)
patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0)
patches = rearrange(patches, "t b c -> b t c")
features = self.layer_norm(self.transformer(patches))
features = rearrange(features, "b t c -> t b c")
if mask_ratio > 0:
return features, forward_indexes, backward_indexes, patch_size
return features, mask, forward_indexes, backward_indexes
return features
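
A hedged sketch of the updated encoder contract (not part of the diff); the keyword arguments are assumptions based on the parameters visible in this file.

import torch
from cyto_dl.nn.vits.mae import MAE_Encoder

encoder = MAE_Encoder(
    num_patches=[8, 8, 8],
    spatial_dims=3,
    base_patch_size=[4, 4, 4],
    emb_dim=192,
    num_layer=12,
    num_head=3,
    context_pixels=[0, 0, 0],
    input_channels=1,
)
img = torch.randn(2, 1, 32, 32, 32)

# masked pretraining: Patchify now handles shuffling, masking, and position embeddings
features, mask, forward_idx, backward_idx = encoder(img, mask_ratio=0.75)
# features: (1 + n_visible_patches, batch, emb_dim), cls token first

# no masking: only the features are returned
features = encoder(img, mask_ratio=0)
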


@@ -156,8 +222,10 @@ def __init__(
Number of patches in each dimension
base_patch_size: Tuple[int]
Size of each patch
enc_dim: int
Dimension of encoder
emb_dim: int
Dimension of embedding
Dimension of decoder
num_layer: int
Number of transformer layers
num_head: int
@@ -202,8 +270,7 @@ def init_weight(self):
trunc_normal_(self.mask_token, std=0.02)
trunc_normal_(self.pos_embedding, std=0.02)

def forward(self, features, forward_indexes, backward_indexes, patch_size):
T = features.shape[0]
def forward(self, features, forward_indexes, backward_indexes):
# project from encoder dimension to decoder dimension
features = self.projection_norm(self.projection(features))

@@ -233,20 +300,10 @@ def forward(self, features, forward_indexes, backward_indexes, patch_size):

# (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x)
patches = self.head_norm(self.head(features))
mask = torch.zeros_like(patches)
mask[T:] = 1
mask = take_indexes(mask, backward_indexes[1:] - 1)

# patches to image
img = self.patch2img(patches)
img = torch.nn.functional.interpolate(
img, tuple(torch.as_tensor(patch_size) * self.num_patches)
)

mask = self.patch2img(mask)
mask = torch.nn.functional.interpolate(
mask, tuple(torch.as_tensor(patch_size) * self.num_patches), mode="nearest"
)
return img, mask
return img


class MAE_ViT(torch.nn.Module):
@@ -263,6 +320,8 @@ def __init__(
decoder_dim: Optional[int] = 192,
mask_ratio: Optional[int] = 0.75,
use_crossmae: Optional[bool] = False,
context_pixels: Optional[List[int]] = [0, 0, 0],
input_channels: Optional[int] = 1,
) -> None:
"""
Parameters
@@ -285,6 +344,11 @@ def __init__(
Number of decoder heads
mask_ratio: float
Ratio of patches to mask out
use_crossmae: bool
Use CrossMAE-style decoder
context_pixels: List[int]
Number of extra pixels around each patch to include in convolutional embedding to encoder dimension.
input_channels: int
"""
super().__init__()
assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3"
@@ -308,6 +372,8 @@ def __init__(
emb_dim,
encoder_layer,
encoder_head,
context_pixels,
input_channels,
)

decoder_class = MAE_Decoder
@@ -324,8 +390,6 @@ def __init__(
)

def forward(self, img):
features, forward_indexes, backward_indexes, patch_size = self.encoder(
img, self.mask_ratio
)
predicted_img, mask = self.decoder(features, forward_indexes, backward_indexes, patch_size)
features, mask, forward_indexes, backward_indexes = self.encoder(img, self.mask_ratio)
predicted_img = self.decoder(features, forward_indexes, backward_indexes)
return predicted_img, mask
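
A hedged end-to-end sketch of the simplified interface (not part of the diff); the constructor arguments and the masked reconstruction loss are illustrative assumptions, not the project's training code.

import torch
from cyto_dl.nn.vits.mae import MAE_ViT

model = MAE_ViT(
    spatial_dims=3,
    num_patches=[8, 8, 8],
    base_patch_size=[4, 4, 4],
    emb_dim=192,
    encoder_layer=12,
    encoder_head=3,
    decoder_layer=4,
    decoder_head=3,
    decoder_dim=192,
    mask_ratio=0.75,
    use_crossmae=False,
    context_pixels=[0, 0, 0],
    input_channels=1,
)

img = torch.randn(2, 1, 32, 32, 32)
predicted_img, mask = model(img)
# mask is True over the regions the encoder never saw, so a reconstruction loss
# can be restricted to the masked patches:
loss = torch.nn.functional.mse_loss(predicted_img[mask], img[mask])
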