From 6ed231d26ef9e92a8e184ea7bcd4263c111fec28 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Thu, 2 May 2024 10:02:27 -0700 Subject: [PATCH 01/27] =?UTF-8?q?Bump=20version:=200.1.5=20=E2=86=92=200.1?= =?UTF-8?q?.6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- cyto_dl/__init__.py | 2 +- pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 00d5d5893..fc8cae8e7 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.1.5 +current_version = 0.1.6 tag = True commit = True diff --git a/cyto_dl/__init__.py b/cyto_dl/__init__.py index fcec3b830..4bac3e158 100644 --- a/cyto_dl/__init__.py +++ b/cyto_dl/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.1.5" +__version__ = "0.1.6" # silence bio packages warnings diff --git a/pyproject.toml b/pyproject.toml index 8862e84d0..7456c7177 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "pdm.pep517.api" [project] name = "cyto-dl" -version = "0.1.5" +version = "0.1.6" description = """\ Collection of representation learning models, techniques, callbacks, utils, \ used to create latent variable models of cell shape, morphology and \ From 2c0c275faba09033ffd02d663732d8fe7d3142f0 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 20 May 2024 09:55:51 -0700 Subject: [PATCH 02/27] add hiera --- .../nn/vits/blocks/masked_unit_attention.py | 109 ++++++ cyto_dl/nn/vits/blocks/patchify_hiera.py | 109 ++++++ cyto_dl/nn/vits/hiera_mae.py | 329 ++++++++++++++++++ 3 files changed, 547 insertions(+) create mode 100644 cyto_dl/nn/vits/blocks/masked_unit_attention.py create mode 100644 cyto_dl/nn/vits/blocks/patchify_hiera.py create mode 100644 cyto_dl/nn/vits/hiera_mae.py diff --git a/cyto_dl/nn/vits/blocks/masked_unit_attention.py b/cyto_dl/nn/vits/blocks/masked_unit_attention.py new file mode 100644 index 000000000..768cc7bbd --- /dev/null +++ b/cyto_dl/nn/vits/blocks/masked_unit_attention.py @@ -0,0 +1,109 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import reduce, rearrange + +from timm.models.layers import DropPath, Mlp +from typing import List +from einops.layers.torch import Reduce + + +class MaskUnitAttention(torch.nn.Module): + def __init__( + self, + dim, + dim_out, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + q_stride= [1,1,1], + patches_per_mask_unit=[2,12,12], + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim_out // num_heads + self.scale = qk_scale or self.head_dim**-0.5 + self.qkv = nn.Linear(dim, dim_out*3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim_out, dim_out) + self.proj_drop = nn.Dropout(proj_drop) + self.dim_out= dim_out + self.q_stride=np.array(q_stride) + self.pooled_patches_per_mask_unit = (np.array(patches_per_mask_unit)/self.q_stride).astype(int) + + def forward(self, x): + # project and split into q,k,v embeddings + qkv = rearrange(self.qkv(x), 'batch num_mask_units tokens_per_mask_unit (head_dim num_heads qkv) -> qkv batch num_mask_units num_heads tokens_per_mask_unit head_dim', head_dim = self.head_dim, qkv=3, num_heads =self.num_heads) + + q, k, v = qkv[0], qkv[1], qkv[2] + + if np.any(self.q_stride>1): + # within a mask unit, tokens are spatially ordered + # perform spatial 2x2x2 max pooling over tokens + q = reduce(q, 'b n h (n_patches_z q_stride_z 
n_patches_y q_stride_y n_patches_x q_stride_x) c ->b n h (n_patches_z n_patches_y n_patches_x) c', reduction='max', q_stride_z=self.q_stride[0], q_stride_y = self.q_stride[1], q_stride_x = self.q_stride[2] ,n_patches_z = self.pooled_patches_per_mask_unit[0], n_patches_y= self.pooled_patches_per_mask_unit[1], n_patches_x=self.pooled_patches_per_mask_unit[2]) + + attn = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop, + ) + # combine heads into single channel dimension + x = rearrange(attn, 'b mask_units n_heads t c -> b mask_units t (n_heads c)',n_heads = self.num_heads) + + x = self.proj(x) + x = self.proj_drop(x) + return x + +class HieraBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + heads: int, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + act_layer: nn.Module = nn.GELU, + q_stride: List[int]= [1,1,1], + patches_per_mask_unit:List[int]=[2,12,12], + ): + super().__init__() + + self.dim = dim + self.dim_out = dim_out + self.q_stride = q_stride + + self.norm1 = norm_layer(dim) + + do_pool = np.any(np.array(q_stride)>1) or dim != dim_out + + self.attn = MaskUnitAttention(dim, dim_out, num_heads=heads, q_stride=q_stride, patches_per_mask_unit=patches_per_mask_unit) + + self.norm2 = norm_layer(dim_out) + self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer) + + self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() + + # max pooling by q stride within a mask unit + skip_connection_pooling = Reduce('b n (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_z n_patches_y n_patches_x) c', reduction='mean', q_stride_z=self.q_stride[0], q_stride_y = self.q_stride[1], q_stride_x = self.q_stride[2] ,n_patches_z = self.attn.pooled_patches_per_mask_unit[0], n_patches_y= self.attn.pooled_patches_per_mask_unit[1], n_patches_x=self.attn. 
pooled_patches_per_mask_unit[2]) + + self.proj = torch.nn.Sequential(skip_connection_pooling, nn.Linear(dim, dim_out)) if do_pool else torch.nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + ''' + x: batch x mask units x tokens x emb_dim + ''' + # Attention + Q Pooling + x_norm = self.norm1(x) + # change dimension and subsample within mask unit for skip connection + x = self.proj(x_norm) + + x = x + self.drop_path(self.attn(x_norm)) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + \ No newline at end of file diff --git a/cyto_dl/nn/vits/blocks/patchify_hiera.py b/cyto_dl/nn/vits/blocks/patchify_hiera.py new file mode 100644 index 000000000..9bf4d888a --- /dev/null +++ b/cyto_dl/nn/vits/blocks/patchify_hiera.py @@ -0,0 +1,109 @@ +from typing import List + +import numpy as np +import torch +import torch.nn as nn +from einops.layers.torch import Rearrange +from einops import repeat + +from cyto_dl.nn.vits.utils import take_indexes + +def random_indexes(size: int, device): + forward_indexes = torch.randperm(size, device=device, dtype=torch.long) + backward_indexes = torch.argsort(forward_indexes) + return forward_indexes, backward_indexes + +def take_indexes_mask(sequences, indexes): + ''' + sequences: batch x mask units x patches x emb_dim + indexes: mask_units x batch + ''' + # always gather across tokens dimension + return torch.gather( + sequences, 1, repeat(indexes.to(sequences.device), "mu b -> b mu p c", b= sequences.shape[0], c=sequences.shape[-1], mu = sequences.shape[1], p=sequences.shape[2]) + ) + + +class PatchifyHiera(torch.nn.Module): + """Class for converting images to a masked sequence of patches with positional embeddings.""" + + def __init__( + self, + patch_size: List[int], + n_patches: List[int], + mask_ratio: float = 0.8, + num_mask_units: List[int] = [8,8,8], + emb_dim: int = 64, + spatial_dims: int= 3, + context_pixels: List[int] = [0, 0, 0], + ): + super().__init__() + self.spatial_dims = spatial_dims + self.mask_ratio = mask_ratio + self.total_n_mask_units = np.prod(num_mask_units) + patches_per_mask_unit = np.prod(n_patches) // self.total_n_mask_units + self.pos_embedding = torch.nn.Parameter(torch.zeros(1, self.total_n_mask_units, patches_per_mask_unit, emb_dim)) + + self.num_mask_units = num_mask_units + self.num_selected_mask_units = int(self.total_n_mask_units*(1-mask_ratio)) + + # mu -> mask unit + self.mask2img = Rearrange("(n_mu_z n_mu_y n_mu_x) b c -> b c n_mu_z n_mu_y n_mu_x ", n_mu_z=num_mask_units[0], n_mu_y=num_mask_units[1], n_mu_x=num_mask_units[2]) + + self.img2mask_units = Rearrange('b c (n_mu_z z) (n_mu_y y) (n_mu_x x) -> b (n_mu_z n_mu_y n_mu_x) (z y x) c ', n_mu_z = num_mask_units[0], n_mu_y= num_mask_units[1], n_mu_x=num_mask_units[2]) + + + context_pixels = context_pixels[:spatial_dims] + weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int) + self.conv = nn.Conv3d( + in_channels=1, + out_channels=emb_dim, + kernel_size=weight_size, + stride=patch_size, + padding=context_pixels, + ) + + def get_mask(self, img): + B = img.shape[0] + indexes = [random_indexes(self.total_n_mask_units, device=img.device) for _ in range(B)] + # forward indexes : index in image -> shuffledpatch + forward_indexes = torch.stack([i[0] for i in indexes], axis=-1) + # backward indexes : shuffled patch -> index in image + backward_indexes = torch.stack([i[1] for i in indexes], axis=-1) + + mask = torch.zeros(self.total_n_mask_units, B, 1, device=img.device, dtype=torch.uint8) + # visible patches 
are first + mask[:self.num_selected_mask_units] = 1 + mask = take_indexes(mask, backward_indexes) + mask = self.mask2img(mask) + # one pixel per masked patch, interpolate to size of input image + mask = torch.nn.functional.interpolate( + mask, img.shape[-self.spatial_dims :], mode="nearest" + ) + return mask, forward_indexes, backward_indexes + + def forward(self, img): + """" + takes in BCZYX image + returns B x num_selected_mask_units x patches_per_mask_unit x emb_dim + """ + mask = torch.ones_like(img) + forward_indexes, backward_indexes = None, None + if self.mask_ratio > 0: + mask, forward_indexes, backward_indexes = self.get_mask(img) + tokens = self.conv(img * mask) + # break into batch x mask units x patches permask unit x emb_dim + tokens = self.img2mask_units(tokens) + + tokens = tokens + self.pos_embedding + if self.mask_ratio > 0: + tokens = take_indexes_mask(tokens, forward_indexes)[:, :self.num_selected_mask_units] + mask = (1 - mask).bool() + return tokens, mask, forward_indexes, backward_indexes + + + + + + + diff --git a/cyto_dl/nn/vits/hiera_mae.py b/cyto_dl/nn/vits/hiera_mae.py new file mode 100644 index 000000000..1e594a026 --- /dev/null +++ b/cyto_dl/nn/vits/hiera_mae.py @@ -0,0 +1,329 @@ +# modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 + +from typing import List, Optional, Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional +from einops import rearrange +from einops.layers.torch import Rearrange +from timm.models.vision_transformer import Block +from monai.networks.blocks import UnetOutBlock, UnetResBlock, UpSample + + +from cyto_dl.nn.vits.blocks.masked_unit_attention import HieraBlock + +from cyto_dl.nn.vits.blocks.patchify_hiera import PatchifyHiera +from cyto_dl.nn.vits.mae import MAE_Decoder +from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder + + + +class SpatialMerger(nn.Module): + def __init__(self, downsample_factor, in_dim, out_dim): + super().__init__() + self.downsample_factor = downsample_factor + conv = nn.Conv3d( + in_channels=in_dim, + out_channels=out_dim, + kernel_size=downsample_factor, + stride=downsample_factor, + padding=0, + bias=False, + ) + + tokens2img = Rearrange( + "b n_mu (z y x) c -> (b n_mu) c z y x", z=downsample_factor[0], y=downsample_factor[1], x=downsample_factor[2] + ) + self.model = nn.Sequential( + tokens2img, + conv + ) + + def forward(self, x): + b, n_mu, _, _ = x.shape + x = self.model(x) + x = rearrange(x, "(b n_mu) c z y x -> b n_mu (z y x) c", b=b, n_mu=n_mu) + return x + +class HieraEncoder(torch.nn.Module): + def __init__( + self, + num_patches: List[int], + num_mask_units: List[int], + architecture: List[Dict], + emb_dim: int = 64, + spatial_dims: int = 3, + patch_size: List[int] = (16, 16, 16), + mask_ratio: Optional[float] = 0.75, + context_pixels: Optional[List[int]] = [0, 0, 0], + ) -> None: + """ + Parameters + ---------- + num_patches: List[int] + Number of patches in each dimension + spatial_dims: int + Number of spatial dimensions + base_patch_size: List[int] + Size of each patch + emb_dim: int + Dimension of embedding + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + context_pixels: List[int] + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. 
+ input_channels: int + Number of input channels + weight_intermediates: bool + Whether to output linear combination of intermediate layers as final output like CrossMAE + """ + super().__init__() + # TODO decide how to deal with class token. Do we add it as an extra patch token per mask unit? Leaving out for now... + self.patchify = PatchifyHiera( + patch_size, num_patches, mask_ratio, num_mask_units, emb_dim, spatial_dims, context_pixels + ) + + patches_per_mask_unit = np.array(num_patches) // np.array(num_mask_units) + self.final_dim = emb_dim * (2**len(architecture)) + + self.save_block_idxs = [] + self.spatial_mergers = torch.nn.ParameterDict({}) + transformer = [] + num_blocks = 0 + for stage_num, stage in enumerate(architecture): + # use mask unit attention until first layer that uses self attention + if stage.get('self_attention', False): + break + print(f"Stage: {stage_num}") + for block in range(stage["repeat"]): + is_last = block == stage["repeat"] - 1 + # do spatial pooling within mask unit on last block of stage + q_stride = stage['q_stride'] if is_last else [1] * spatial_dims + + # double embedding dimension in last block of stage + dim_in = emb_dim * (2**stage_num) + dim_out = dim_in if not is_last else dim_in * 2 + print(f"\tBlock {block}:\t\tdim_in: {dim_in}, dim_out: {dim_out}, num_heads: {stage['num_heads']}, q_stride: {q_stride}, patches_per_mask_unit: {patches_per_mask_unit}") + transformer.append( + HieraBlock( + dim=dim_in, + dim_out=dim_out, + heads=stage['num_heads'], + q_stride = q_stride, + patches_per_mask_unit = patches_per_mask_unit, + ) + ) + if is_last: + # save the block before the spatial pooling unless it's the final stage + save_block = num_blocks -1 if stage_num < len(architecture) - 1 else num_blocks + self.save_block_idxs.append(save_block) + + # create a spatial merger for combining tokens pre-downsampling, last stage doesn't need merging since it has expected num channels, spatial shape + self.spatial_mergers[f'block_{save_block}'] = SpatialMerger(patches_per_mask_unit, dim_in, self.final_dim) if stage_num < len(architecture) - 1 else torch.nn.Identity() + + # at end of each layer, patches per mask unit is reduced as we pool spatially + patches_per_mask_unit = patches_per_mask_unit // np.array(stage['q_stride']) + num_blocks += 1 + self.mask_unit_transformer = torch.nn.Sequential(*transformer) + + self.self_attention_transformer = torch.nn.Sequential( + *[Block(self.final_dim, stage['num_heads']) for _ in range(stage['repeat'])] + ) + + self.layer_norm = torch.nn.LayerNorm(self.final_dim) + + def forward(self, img): + patches, mask, forward_indexes, backward_indexes = self.patchify(img) + + # mask unit attention + mask_unit_embeddings = 0.0 + for i, block in enumerate(self.mask_unit_transformer): + patches = block(patches) + if i in self.save_block_idxs: + mask_unit_embeddings += self.spatial_mergers[f'block_{i}'](patches) + + # combine mask units and tokens for full self attention transformer + mask_unit_embeddings = rearrange(mask_unit_embeddings, "b n_mu t c -> b (n_mu t) c") + mask_unit_embeddings = self.self_attention_transformer(mask_unit_embeddings) + mask_unit_embeddings= self.layer_norm(mask_unit_embeddings) + + return mask_unit_embeddings, mask, forward_indexes, backward_indexes + + +class HieraMAE(torch.nn.Module): + def __init__( + self, + architecture, + spatial_dims: int = 3, + num_patches: Optional[List[int]] = [2, 32, 32], + num_mask_units: Optional[List[int]] = [2, 12, 12], + patch_size: Optional[List[int]] = [16, 16, 16], + 
emb_dim: Optional[int] = 64, + decoder_layer: Optional[int] = 4, + decoder_head: Optional[int] = 8, + decoder_dim: Optional[int] = 192, + mask_ratio: Optional[int] = 0.75, + context_pixels: Optional[List[int]] = [0, 0, 0], + use_crossmae: Optional[bool] = False, + ) -> None: + """ + Parameters + ---------- + """ + super().__init__() + assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" + + if isinstance(num_patches, int): + num_patches = [num_patches] * spatial_dims + if isinstance(patch_size, int): + patch_size = [patch_size] * spatial_dims + + assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" + assert ( + len(patch_size) == spatial_dims + ), "patch_size must be of length spatial_dims" + + self.mask_ratio = mask_ratio + + self.encoder = HieraEncoder( + num_patches=num_patches, + num_mask_units=num_mask_units, + architecture=architecture, + emb_dim=emb_dim, + spatial_dims=spatial_dims, + patch_size=patch_size, + mask_ratio=mask_ratio, + context_pixels=context_pixels, + ) + # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size + mask_unit_size = (np.array(num_patches) * np.array(patch_size))/np.array(num_mask_units) + + decoder_class = MAE_Decoder + if use_crossmae: + decoder_class = CrossMAE_Decoder + + self.decoder = decoder_class( + num_patches=num_mask_units, + spatial_dims=spatial_dims, + base_patch_size=mask_unit_size.astype(int), + enc_dim=self.encoder.final_dim, + emb_dim=decoder_dim, + num_layer=decoder_layer, + num_head=decoder_head, + has_cls_token=False + ) + + def forward(self, img): + features, mask, forward_indexes, backward_indexes = self.encoder(img) + features = rearrange(features, "b t c -> t b c") + predicted_img = self.decoder(features, forward_indexes, backward_indexes) + return predicted_img, mask + + +class HieraSeg(torch.nn.Module): + def __init__( + self, + encoder_ckpt, + architecture, + spatial_dims: int = 3, + n_out_channels: int = 6, + num_patches: Optional[List[int]] = [2, 32, 32], + num_mask_units: Optional[List[int]] = [2, 12, 12], + patch_size: Optional[List[int]] = [16, 16, 16], + emb_dim: Optional[int] = 64, + context_pixels: Optional[List[int]] = [0, 0, 0], + ) -> None: + """ + Parameters + ---------- + """ + super().__init__() + assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" + + if isinstance(num_patches, int): + num_patches = [num_patches] * spatial_dims + if isinstance(patch_size, int): + patch_size = [patch_size] * spatial_dims + + assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" + assert ( + len(patch_size) == spatial_dims + ), "patch_size must be of length spatial_dims" + + + self.encoder = HieraEncoder( + num_patches=num_patches, + num_mask_units=num_mask_units, + architecture=architecture, + emb_dim=emb_dim, + spatial_dims=spatial_dims, + patch_size=patch_size, + mask_ratio=0, + context_pixels=context_pixels, + ) + model = torch.load(encoder_ckpt, map_location="cuda:0") + enc_state_dict = { + k.replace("backbone.encoder.", ""): v + for k, v in model["state_dict"].items() + if "encoder" in k and "intermediate" not in k + } + self.encoder.load_state_dict(enc_state_dict, strict=False) + for name, param in self.encoder.named_parameters(): + # allow different weighting of internal activations for finetuning + param.requires_grad = False + + # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size + mask_unit_size = ((np.array(num_patches) * 
np.array(patch_size))/np.array(num_mask_units)).astype(int) + + project_dim = np.prod(mask_unit_size)*16 + head = torch.nn.Linear(self.encoder.final_dim, project_dim) + norm = torch.nn.LayerNorm(project_dim) + patch2img = Rearrange( + "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + n_patch_z=num_mask_units[0], + n_patch_y=num_mask_units[1], + n_patch_x=num_mask_units[2], + patch_size_z=mask_unit_size[0], + patch_size_y=mask_unit_size[1], + patch_size_x=mask_unit_size[2], + ) + self.patch2img = torch.nn.Sequential(head, norm, patch2img) + + self.upsample = torch.nn.Sequential( + *[ + UpSample( + spatial_dims=spatial_dims, + in_channels=16, + out_channels=16, + scale_factor=[2.6134, 2.5005, 2.5005], + mode="nontrainable", + interp_mode="trilinear", + ), + UnetResBlock( + spatial_dims=spatial_dims, + in_channels=16, + out_channels=16, + stride=1, + kernel_size=3, + norm_name="INSTANCE", + dropout=0, + ), + UnetOutBlock( + spatial_dims=spatial_dims, + in_channels=16, + out_channels=n_out_channels, + dropout=0, + ), + ] + ) + + def forward(self, img): + features, _, _, _ = self.encoder(img) + features = rearrange(features, "b t c -> t b c") + predicted_img = self.patch2img(features) + predicted_img = self.upsample(predicted_img) + return predicted_img \ No newline at end of file From 50646aec759da0bf300852f9a1d1afe68f68f7e1 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Wed, 22 May 2024 13:59:56 -0700 Subject: [PATCH 03/27] start of mask2former --- cyto_dl/nn/vits/hiera_mae.py | 244 +++++++++++++++++++++++++++++++++-- 1 file changed, 231 insertions(+), 13 deletions(-) diff --git a/cyto_dl/nn/vits/hiera_mae.py b/cyto_dl/nn/vits/hiera_mae.py index 1e594a026..5766a620f 100644 --- a/cyto_dl/nn/vits/hiera_mae.py +++ b/cyto_dl/nn/vits/hiera_mae.py @@ -58,6 +58,7 @@ def __init__( patch_size: List[int] = (16, 16, 16), mask_ratio: Optional[float] = 0.75, context_pixels: Optional[List[int]] = [0, 0, 0], + save_layers: Optional[bool] = True, ) -> None: """ Parameters @@ -82,6 +83,7 @@ def __init__( Whether to output linear combination of intermediate layers as final output like CrossMAE """ super().__init__() + self.save_layers= save_layers # TODO decide how to deal with class token. Do we add it as an extra patch token per mask unit? Leaving out for now... 
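# Hypothetical example of the `architecture` argument consumed by the stage loop below.
# The keys mirror how each stage dict is read in this constructor (`repeat`, `num_heads`,
# `q_stride` for mask-unit-attention stages, `self_attention` to switch to global attention);
# the numeric values are placeholders for illustration, not a validated configuration.
example_architecture = [
    {"repeat": 2, "num_heads": 1, "q_stride": [1, 2, 2]},   # mask unit attention stage
    {"repeat": 2, "num_heads": 2, "q_stride": [1, 2, 2]},   # mask unit attention stage
    {"repeat": 4, "num_heads": 4, "self_attention": True},  # global self-attention stage
]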
self.patchify = PatchifyHiera( patch_size, num_patches, mask_ratio, num_mask_units, emb_dim, spatial_dims, context_pixels @@ -141,17 +143,20 @@ def forward(self, img): # mask unit attention mask_unit_embeddings = 0.0 + save_layers = [] for i, block in enumerate(self.mask_unit_transformer): patches = block(patches) if i in self.save_block_idxs: mask_unit_embeddings += self.spatial_mergers[f'block_{i}'](patches) + if self.save_layers: + save_layers.append(patches) # combine mask units and tokens for full self attention transformer mask_unit_embeddings = rearrange(mask_unit_embeddings, "b n_mu t c -> b (n_mu t) c") mask_unit_embeddings = self.self_attention_transformer(mask_unit_embeddings) mask_unit_embeddings= self.layer_norm(mask_unit_embeddings) - return mask_unit_embeddings, mask, forward_indexes, backward_indexes + return mask_unit_embeddings, mask, forward_indexes, backward_indexes, save_layers class HieraMAE(torch.nn.Module): @@ -218,7 +223,7 @@ def __init__( ) def forward(self, img): - features, mask, forward_indexes, backward_indexes = self.encoder(img) + features, mask, forward_indexes, backward_indexes, save_layers = self.encoder(img) features = rearrange(features, "b t c -> t b c") predicted_img = self.decoder(features, forward_indexes, backward_indexes) return predicted_img, mask @@ -264,17 +269,18 @@ def __init__( patch_size=patch_size, mask_ratio=0, context_pixels=context_pixels, + save_layers=True ) - model = torch.load(encoder_ckpt, map_location="cuda:0") - enc_state_dict = { - k.replace("backbone.encoder.", ""): v - for k, v in model["state_dict"].items() - if "encoder" in k and "intermediate" not in k - } - self.encoder.load_state_dict(enc_state_dict, strict=False) - for name, param in self.encoder.named_parameters(): - # allow different weighting of internal activations for finetuning - param.requires_grad = False + # model = torch.load(encoder_ckpt, map_location="cuda:0") + # enc_state_dict = { + # k.replace("backbone.encoder.", ""): v + # for k, v in model["state_dict"].items() + # if "encoder" in k and "intermediate" not in k + # } + # self.encoder.load_state_dict(enc_state_dict, strict=False) + # for name, param in self.encoder.named_parameters(): + # # allow different weighting of internal activations for finetuning + # param.requires_grad = False # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size mask_unit_size = ((np.array(num_patches) * np.array(patch_size))/np.array(num_mask_units)).astype(int) @@ -322,8 +328,220 @@ def __init__( ) def forward(self, img): - features, _, _, _ = self.encoder(img) + breakpoint() + features, _, _, _, save_layers = self.encoder(img) features = rearrange(features, "b t c -> t b c") predicted_img = self.patch2img(features) predicted_img = self.upsample(predicted_img) + return predicted_img + + + + +class Mlp(nn.Module): + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class CrossAttention(nn.Module): + def __init__( + self, + encoder_dim, + decoder_dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + 
attn_drop=0.0, + proj_drop=0.0, + ): + super().__init__() + self.num_heads = num_heads + head_dim = decoder_dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias) + self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(decoder_dim, decoder_dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, y, mask): + B, N, C = x.shape + Ny = y.shape[1] + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + kv = ( + self.kv(y) + .reshape(B, Ny, 2, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + k, v = kv[0], kv[1] + + attn = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop, + attn_mask=mask>0.5 if mask is not None else None, + ) + x = attn.transpose(1, 2).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x +from timm.models.vision_transformer import Attention + +class Mask2FormerBlock(nn.Module): + def __init__( + self, + encoder_dim, + decoder_dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.norm1 = norm_layer(decoder_dim) + + self.self_attn_block = Attention( + dim=decoder_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + self.cross_attn = CrossAttention( + encoder_dim, + decoder_dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.norm2 = norm_layer(decoder_dim) + mlp_hidden_dim = int(decoder_dim * mlp_ratio) + self.mlp = Mlp( + in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop + ) + + def forward(self, x, y, mask): + """ + x: query features, y: image features, mask: previous mask prediction + """ + x = self.norm1(x + self.cross_attn(x, y, mask)) + x = x + self.self_attn_block(x) + x = x + self.mlp(self.norm2(x)) + return x + +class HieraMask2Former(torch.nn.Module): + def __init__( + self, + encoder_ckpt, + architecture, + spatial_dims: int = 3, + num_queries: int = 50, + num_patches: Optional[List[int]] = [2, 32, 32], + num_mask_units: Optional[List[int]] = [2, 12, 12], + patch_size: Optional[List[int]] = [16, 16, 16], + emb_dim: Optional[int] = 64, + context_pixels: Optional[List[int]] = [0, 0, 0], + ) -> None: + """ + Parameters + ---------- + """ + super().__init__() + assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" + + if isinstance(num_patches, int): + num_patches = [num_patches] * spatial_dims + if isinstance(patch_size, int): + patch_size = [patch_size] * spatial_dims + + assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" + assert ( + len(patch_size) == spatial_dims + ), "patch_size must be of length spatial_dims" + self.num_mask_units = num_mask_units + + self.encoder = HieraEncoder( + num_patches=num_patches, + num_mask_units=num_mask_units, + architecture=architecture, + emb_dim=emb_dim, + spatial_dims=spatial_dims, + patch_size=patch_size, + mask_ratio=0, + context_pixels=context_pixels, + ) + model = torch.load(encoder_ckpt, map_location="cuda:0") + enc_state_dict = { + k.replace("backbone.encoder.", ""): v + for k, v in model["state_dict"].items() + if "encoder" in k and "intermediate" 
not in k + } + self.encoder.load_state_dict(enc_state_dict, strict=False) + for name, param in self.encoder.named_parameters(): + param.requires_grad = False + + # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size + mask_unit_size = ((np.array(num_patches) * np.array(patch_size))/np.array(num_mask_units)).astype(int) + + self.instance_queries = torch.nn.Parameter(torch.zeros(1, num_queries, self.encoder.final_dim)) + self.instance_queries_pos_emb = torch.nn.Parameter(torch.zeros(1, num_queries, self.encoder.final_dim)) + + q_strides = [np.array(stage['q_stride']) for stage in architecture if stage.get('q_stride', False)] + patches_per_mask_unit = [np.array(num_patches) // np.array(num_mask_units)] + for qs in q_strides: + patches_per_mask_unit.append(patches_per_mask_unit[-1] // qs) + patches_per_mask_unit.reverse() + patches_per_mask_unit[0] = np.array([1,1,1]) + self.patches_per_mask_unit = patches_per_mask_unit + + # TODO each block should have a different embedding dimension + # self.transformer = torch.nn.ModuleList([Mask2FormerBlock() for _ in range(len(patches_per_mask_unit))]) + + + + def forward(self, img): + breakpoint() + #features are b x t x c + features, _, _, _, save_layers = self.encoder(img) + save_layers.append(features.unsqueeze(2)) + save_layers.reverse() + # start with lowest resolution + mask = None + for layer, ppmu in zip(save_layers, self.patches_per_mask_unit): + # TODO add positional embedding and scale embedding to image features + # do we even need to rearrange to an image here if we're just doing cross attention against it? + img_features = rearrange(layer, 'b (n_mu_z n_mu_y n_mu_x) (patches_per_mu_z patches_per_mu_y patches_per_mu_x) c -> b c (n_mu_z patches_per_mu_z) (n_mu_y patches_per_mu_y) (n_mu_x patches_per_mu_x)', n_mu_z=self.num_mask_units[0], n_mu_y=self.num_mask_units[1], n_mu_x=self.num_mask_units[2], patches_per_mu_z=ppmu[0], patches_per_mu_y=ppmu[1], patches_per_mu_x=ppmu[2]) + + + # mask = self.transformer(img_features, mask, self.instance_queries, self.instance_queries_pos_emb) + + return predicted_img \ No newline at end of file From 22ddadcb7e52e1b4560edfe3dfef8e296adc7671 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Fri, 24 May 2024 12:50:22 -0700 Subject: [PATCH 04/27] first take at transfomer --- cyto_dl/nn/vits/hiera_mae.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/cyto_dl/nn/vits/hiera_mae.py b/cyto_dl/nn/vits/hiera_mae.py index 5766a620f..b8c039b22 100644 --- a/cyto_dl/nn/vits/hiera_mae.py +++ b/cyto_dl/nn/vits/hiera_mae.py @@ -84,7 +84,6 @@ def __init__( """ super().__init__() self.save_layers= save_layers - # TODO decide how to deal with class token. Do we add it as an extra patch token per mask unit? Leaving out for now... 
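# Rough sketch of how the stage loop below evolves token geometry. The starting shape and the
# q_stride values are assumed example numbers (not taken from a specific config in this repo);
# the update rules are the ones applied in the last block of each mask-unit-attention stage.
import numpy as np  # numpy is already imported at the top of this module

patches_per_mask_unit = np.array([2, 12, 12])  # in the constructor: num_patches // num_mask_units
dim = 64                                       # emb_dim
for q_stride in ([1, 2, 2], [1, 2, 2]):        # one entry per mask-unit-attention stage
    dim *= 2                                   # embedding dim doubles in the stage's last block
    patches_per_mask_unit = patches_per_mask_unit // np.array(q_stride)
# patches_per_mask_unit: [2, 12, 12] -> [2, 6, 6] -> [2, 3, 3]; dim: 64 -> 128 -> 256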
self.patchify = PatchifyHiera( patch_size, num_patches, mask_ratio, num_mask_units, emb_dim, spatial_dims, context_pixels ) @@ -93,6 +92,7 @@ def __init__( self.final_dim = emb_dim * (2**len(architecture)) self.save_block_idxs = [] + self.save_block_dims = [] self.spatial_mergers = torch.nn.ParameterDict({}) transformer = [] num_blocks = 0 @@ -123,6 +123,7 @@ def __init__( # save the block before the spatial pooling unless it's the final stage save_block = num_blocks -1 if stage_num < len(architecture) - 1 else num_blocks self.save_block_idxs.append(save_block) + self.save_block_dims.append(dim_out) # create a spatial merger for combining tokens pre-downsampling, last stage doesn't need merging since it has expected num channels, spatial shape self.spatial_mergers[f'block_{save_block}'] = SpatialMerger(patches_per_mask_unit, dim_in, self.final_dim) if stage_num < len(architecture) - 1 else torch.nn.Identity() @@ -411,6 +412,7 @@ def __init__( encoder_dim, decoder_dim, num_heads, + num_patches, mlp_ratio=4.0, qkv_bias=False, qk_scale=None, @@ -423,6 +425,9 @@ def __init__( super().__init__() self.norm1 = norm_layer(decoder_dim) + # TODO add positional embedding and scale embedding to image features + self.scale_positional_embedding = nn.Parameter(torch.zeros(1, num_patches, encoder_dim)) + self.self_attn_block = Attention( dim=decoder_dim, num_heads=num_heads, @@ -441,7 +446,6 @@ def __init__( attn_drop=attn_drop, proj_drop=drop, ) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.norm2 = norm_layer(decoder_dim) mlp_hidden_dim = int(decoder_dim * mlp_ratio) self.mlp = Mlp( @@ -523,7 +527,7 @@ def __init__( self.patches_per_mask_unit = patches_per_mask_unit # TODO each block should have a different embedding dimension - # self.transformer = torch.nn.ModuleList([Mask2FormerBlock() for _ in range(len(patches_per_mask_unit))]) + self.transformer = torch.nn.ModuleList([Mask2FormerBlock(encoder_dim = self.encoder.save_block_dims[i], decoder_dim = 128, num_neads = 4, num_patches = patches_per_mask_unit[i] * num_mask_units) for i in range(len(patches_per_mask_unit))]) @@ -534,14 +538,21 @@ def forward(self, img): save_layers.append(features.unsqueeze(2)) save_layers.reverse() # start with lowest resolution + # first mask should be prediction from query features alone mask = None for layer, ppmu in zip(save_layers, self.patches_per_mask_unit): - # TODO add positional embedding and scale embedding to image features - # do we even need to rearrange to an image here if we're just doing cross attention against it? 
- img_features = rearrange(layer, 'b (n_mu_z n_mu_y n_mu_x) (patches_per_mu_z patches_per_mu_y patches_per_mu_x) c -> b c (n_mu_z patches_per_mu_z) (n_mu_y patches_per_mu_y) (n_mu_x patches_per_mu_x)', n_mu_z=self.num_mask_units[0], n_mu_y=self.num_mask_units[1], n_mu_x=self.num_mask_units[2], patches_per_mu_z=ppmu[0], patches_per_mu_y=ppmu[1], patches_per_mu_x=ppmu[2]) + layer = rearrange(layer, 'b n_mu mu_dims c -> b (n_mu mu_dims) c') + layer = self.transformer(layer, mask, self.instance_queries, self.instance_queries_pos_emb) - # mask = self.transformer(img_features, mask, self.instance_queries, self.instance_queries_pos_emb) + # cross attention provides one mask per query + # self attention refines mask + # repeat for each block + + # rearrange to mask TODO make this account for havingn_queries masks + img_features = rearrange(layer, 'b (n_mu_z n_mu_y n_mu_x) (patches_per_mu_z patches_per_mu_y patches_per_mu_x) c -> b c (n_mu_z patches_per_mu_z) (n_mu_y patches_per_mu_y) (n_mu_x patches_per_mu_x)', n_mu_z=self.num_mask_units[0], n_mu_y=self.num_mask_units[1], n_mu_x=self.num_mask_units[2], patches_per_mu_z=ppmu[0], patches_per_mu_y=ppmu[1], patches_per_mu_x=ppmu[2]) + # upsample to next resolution + mask = F.interpolate(mask, scale_factor=ppmu, mode='nearest') return predicted_img \ No newline at end of file From 693c514890227f43621a20b3900266b65634e623 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Thu, 30 May 2024 11:40:38 -0700 Subject: [PATCH 05/27] fix dimensionality, now updating instance queries instead of mask --- cyto_dl/nn/vits/hiera_mae.py | 50 ++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 17 deletions(-) diff --git a/cyto_dl/nn/vits/hiera_mae.py b/cyto_dl/nn/vits/hiera_mae.py index b8c039b22..349ca01c5 100644 --- a/cyto_dl/nn/vits/hiera_mae.py +++ b/cyto_dl/nn/vits/hiera_mae.py @@ -18,6 +18,7 @@ from cyto_dl.nn.vits.mae import MAE_Decoder from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder +import torch.nn.functional as F class SpatialMerger(nn.Module): @@ -132,6 +133,8 @@ def __init__( patches_per_mask_unit = patches_per_mask_unit // np.array(stage['q_stride']) num_blocks += 1 self.mask_unit_transformer = torch.nn.Sequential(*transformer) + self.save_block_dims.append(self.final_dim) + self.save_block_dims.reverse() self.self_attention_transformer = torch.nn.Sequential( *[Block(self.final_dim, stage['num_heads']) for _ in range(stage['repeat'])] @@ -375,19 +378,22 @@ def __init__( self.num_heads = num_heads head_dim = decoder_dim // num_heads self.scale = qk_scale or head_dim**-0.5 - self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias) + self.q = nn.Linear(encoder_dim, decoder_dim, bias=qkv_bias) self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias) self.attn_drop = attn_drop self.proj = nn.Linear(decoder_dim, decoder_dim) self.proj_drop = nn.Dropout(proj_drop) + self.decoder_dim = decoder_dim + def forward(self, x, y, mask): + """ x: queries y: values, mask: mask for attention""" B, N, C = x.shape Ny = y.shape[1] - q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + q = self.q(x).reshape(B, N, self.num_heads, self.decoder_dim // self.num_heads).permute(0, 2, 1, 3) kv = ( self.kv(y) - .reshape(B, Ny, 2, self.num_heads, C // self.num_heads) + .reshape(B, Ny, 2, self.num_heads, self.decoder_dim // self.num_heads) .permute(2, 0, 3, 1, 4) ) k, v = kv[0], kv[1] @@ -399,7 +405,7 @@ def forward(self, x, y, mask): dropout_p=self.attn_drop, attn_mask=mask>0.5 if mask is not None else None, ) - x 
= attn.transpose(1, 2).reshape(B, N, C) + x = attn.transpose(1, 2).reshape(B, N, self.decoder_dim) x = self.proj(x) x = self.proj_drop(x) @@ -426,13 +432,12 @@ def __init__( self.norm1 = norm_layer(decoder_dim) # TODO add positional embedding and scale embedding to image features - self.scale_positional_embedding = nn.Parameter(torch.zeros(1, num_patches, encoder_dim)) + self.scale_positional_embedding = nn.Parameter(torch.zeros(1, np.prod(num_patches), encoder_dim)) self.self_attn_block = Attention( dim=decoder_dim, num_heads=num_heads, qkv_bias=qkv_bias, - qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, ) @@ -452,14 +457,15 @@ def __init__( in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop ) - def forward(self, x, y, mask): + def forward(self, instance_queries, image_feats, mask): """ x: query features, y: image features, mask: previous mask prediction """ - x = self.norm1(x + self.cross_attn(x, y, mask)) - x = x + self.self_attn_block(x) - x = x + self.mlp(self.norm2(x)) - return x + breakpoint() + instance_queries = self.norm1(instance_queries + self.cross_attn(instance_queries, image_feats, mask)) + instance_queries = instance_queries + self.self_attn_block(instance_queries) + instance_queries = instance_queries + self.mlp(self.norm2(instance_queries)) + return instance_queries class HieraMask2Former(torch.nn.Module): def __init__( @@ -515,7 +521,13 @@ def __init__( # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size mask_unit_size = ((np.array(num_patches) * np.array(patch_size))/np.array(num_mask_units)).astype(int) - self.instance_queries = torch.nn.Parameter(torch.zeros(1, num_queries, self.encoder.final_dim)) + self.instance_queries = torch.nn.Parameter(torch.zeros(1, num_queries, emb_dim)) + #unclear if we need a separate positional embedding for instance queries? + #s. Object query features are only used as the initial + # input to the Transformer decoder and are updated through + # decoder layers; whereas query positional embeddings are + # added to query features in every Transformer decoder layer + # when computing the attention weights. 
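# How these two query tensors interact downstream (see the forward pass later in this series):
# the query features `instance_queries` are updated from decoder layer to decoder layer, while
# this positional embedding is re-added before every block,
#     instance_queries = self.transformer[i](instance_queries + self.instance_queries_pos_emb, layer, mask)
# which is the behaviour described in the Mask2Former excerpt quoted above.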
self.instance_queries_pos_emb = torch.nn.Parameter(torch.zeros(1, num_queries, self.encoder.final_dim)) q_strides = [np.array(stage['q_stride']) for stage in architecture if stage.get('q_stride', False)] @@ -527,23 +539,25 @@ def __init__( self.patches_per_mask_unit = patches_per_mask_unit # TODO each block should have a different embedding dimension - self.transformer = torch.nn.ModuleList([Mask2FormerBlock(encoder_dim = self.encoder.save_block_dims[i], decoder_dim = 128, num_neads = 4, num_patches = patches_per_mask_unit[i] * num_mask_units) for i in range(len(patches_per_mask_unit))]) + self.transformer = torch.nn.ModuleList([Mask2FormerBlock(encoder_dim = self.encoder.save_block_dims[i], decoder_dim = emb_dim, num_heads = 4, num_patches = np.prod(patches_per_mask_unit[i] * num_mask_units)) for i in range(len(patches_per_mask_unit))]) def forward(self, img): - breakpoint() #features are b x t x c features, _, _, _, save_layers = self.encoder(img) save_layers.append(features.unsqueeze(2)) save_layers.reverse() # start with lowest resolution # first mask should be prediction from query features alone - mask = None - for layer, ppmu in zip(save_layers, self.patches_per_mask_unit): + mask = None# should create a mask here using raw instance queries + instance_queries = self.instance_queries + breakpoint() + + for i, (layer, ppmu) in enumerate(zip(save_layers, self.patches_per_mask_unit)): layer = rearrange(layer, 'b n_mu mu_dims c -> b (n_mu mu_dims) c') - layer = self.transformer(layer, mask, self.instance_queries, self.instance_queries_pos_emb) + instance_queries = self.transformer[i](instance_queries, layer, mask) # cross attention provides one mask per query # self attention refines mask @@ -555,4 +569,6 @@ def forward(self, img): # upsample to next resolution mask = F.interpolate(mask, scale_factor=ppmu, mode='nearest') + # loss is calculated on each intermediate mask as well + return predicted_img \ No newline at end of file From 913416a9a4f957326f9d4da879270db95e578cc3 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Fri, 31 May 2024 09:40:40 -0700 Subject: [PATCH 06/27] give instance queries own dim --- cyto_dl/nn/vits/hiera_mae.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/cyto_dl/nn/vits/hiera_mae.py b/cyto_dl/nn/vits/hiera_mae.py index 349ca01c5..47468ebb3 100644 --- a/cyto_dl/nn/vits/hiera_mae.py +++ b/cyto_dl/nn/vits/hiera_mae.py @@ -378,7 +378,7 @@ def __init__( self.num_heads = num_heads head_dim = decoder_dim // num_heads self.scale = qk_scale or head_dim**-0.5 - self.q = nn.Linear(encoder_dim, decoder_dim, bias=qkv_bias) + self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias) self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias) self.attn_drop = attn_drop self.proj = nn.Linear(decoder_dim, decoder_dim) @@ -479,6 +479,7 @@ def __init__( patch_size: Optional[List[int]] = [16, 16, 16], emb_dim: Optional[int] = 64, context_pixels: Optional[List[int]] = [0, 0, 0], + mask2former_dim: Optional[int] = 128, ) -> None: """ Parameters @@ -521,14 +522,14 @@ def __init__( # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size mask_unit_size = ((np.array(num_patches) * np.array(patch_size))/np.array(num_mask_units)).astype(int) - self.instance_queries = torch.nn.Parameter(torch.zeros(1, num_queries, emb_dim)) + self.instance_queries = torch.nn.Parameter(torch.zeros(1, num_queries, mask2former_dim)) #unclear if we need a separate positional embedding for instance queries? 
#s. Object query features are only used as the initial # input to the Transformer decoder and are updated through # decoder layers; whereas query positional embeddings are # added to query features in every Transformer decoder layer # when computing the attention weights. - self.instance_queries_pos_emb = torch.nn.Parameter(torch.zeros(1, num_queries, self.encoder.final_dim)) + self.instance_queries_pos_emb = torch.nn.Parameter(torch.zeros(1, num_queries, mask2former_dim)) q_strides = [np.array(stage['q_stride']) for stage in architecture if stage.get('q_stride', False)] patches_per_mask_unit = [np.array(num_patches) // np.array(num_mask_units)] @@ -539,7 +540,7 @@ def __init__( self.patches_per_mask_unit = patches_per_mask_unit # TODO each block should have a different embedding dimension - self.transformer = torch.nn.ModuleList([Mask2FormerBlock(encoder_dim = self.encoder.save_block_dims[i], decoder_dim = emb_dim, num_heads = 4, num_patches = np.prod(patches_per_mask_unit[i] * num_mask_units)) for i in range(len(patches_per_mask_unit))]) + self.transformer = torch.nn.ModuleList([Mask2FormerBlock(encoder_dim = self.encoder.save_block_dims[i], decoder_dim = mask2former_dim, num_heads = 4, num_patches = np.prod(patches_per_mask_unit[i] * num_mask_units)) for i in range(len(patches_per_mask_unit))]) @@ -569,6 +570,8 @@ def forward(self, img): # upsample to next resolution mask = F.interpolate(mask, scale_factor=ppmu, mode='nearest') + # output is instance_queries * output_features + # loss is calculated on each intermediate mask as well return predicted_img \ No newline at end of file From a561aa4c42933dfd08801cdd995449c79df2aa77 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 3 Jun 2024 12:24:18 -0700 Subject: [PATCH 07/27] add mask creation --- cyto_dl/nn/vits/hiera_mae.py | 60 +++++++++++++++++++++++------------- 1 file changed, 38 insertions(+), 22 deletions(-) diff --git a/cyto_dl/nn/vits/hiera_mae.py b/cyto_dl/nn/vits/hiera_mae.py index 47468ebb3..ba96abdaa 100644 --- a/cyto_dl/nn/vits/hiera_mae.py +++ b/cyto_dl/nn/vits/hiera_mae.py @@ -19,6 +19,7 @@ from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder import torch.nn.functional as F +from einops import repeat class SpatialMerger(nn.Module): @@ -124,7 +125,7 @@ def __init__( # save the block before the spatial pooling unless it's the final stage save_block = num_blocks -1 if stage_num < len(architecture) - 1 else num_blocks self.save_block_idxs.append(save_block) - self.save_block_dims.append(dim_out) + self.save_block_dims.append(dim_in) # create a spatial merger for combining tokens pre-downsampling, last stage doesn't need merging since it has expected num channels, spatial shape self.spatial_mergers[f'block_{save_block}'] = SpatialMerger(patches_per_mask_unit, dim_in, self.final_dim) if stage_num < len(architecture) - 1 else torch.nn.Identity() @@ -332,7 +333,6 @@ def __init__( ) def forward(self, img): - breakpoint() features, _, _, _, save_layers = self.encoder(img) features = rearrange(features, "b t c -> t b c") predicted_img = self.patch2img(features) @@ -403,7 +403,7 @@ def forward(self, x, y, mask): k, v, dropout_p=self.attn_drop, - attn_mask=mask>0.5 if mask is not None else None, + attn_mask=mask>0 if mask is not None else None, ) x = attn.transpose(1, 2).reshape(B, N, self.decoder_dim) @@ -413,6 +413,11 @@ def forward(self, x, y, mask): from timm.models.vision_transformer import Attention class Mask2FormerBlock(nn.Module): + """ + cross attention provides one mask per query + self attention refines mask + 
repeat for each block + """ def __init__( self, encoder_dim, @@ -431,7 +436,6 @@ def __init__( super().__init__() self.norm1 = norm_layer(decoder_dim) - # TODO add positional embedding and scale embedding to image features self.scale_positional_embedding = nn.Parameter(torch.zeros(1, np.prod(num_patches), encoder_dim)) self.self_attn_block = Attention( @@ -461,7 +465,8 @@ def forward(self, instance_queries, image_feats, mask): """ x: query features, y: image features, mask: previous mask prediction """ - breakpoint() + image_feats = image_feats + self.scale_positional_embedding + instance_queries = self.norm1(instance_queries + self.cross_attn(instance_queries, image_feats, mask)) instance_queries = instance_queries + self.self_attn_block(instance_queries) instance_queries = instance_queries + self.mlp(self.norm2(instance_queries)) @@ -518,12 +523,10 @@ def __init__( self.encoder.load_state_dict(enc_state_dict, strict=False) for name, param in self.encoder.named_parameters(): param.requires_grad = False - # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size mask_unit_size = ((np.array(num_patches) * np.array(patch_size))/np.array(num_mask_units)).astype(int) self.instance_queries = torch.nn.Parameter(torch.zeros(1, num_queries, mask2former_dim)) - #unclear if we need a separate positional embedding for instance queries? #s. Object query features are only used as the initial # input to the Transformer decoder and are updated through # decoder layers; whereas query positional embeddings are @@ -539,9 +542,19 @@ def __init__( patches_per_mask_unit[0] = np.array([1,1,1]) self.patches_per_mask_unit = patches_per_mask_unit - # TODO each block should have a different embedding dimension self.transformer = torch.nn.ModuleList([Mask2FormerBlock(encoder_dim = self.encoder.save_block_dims[i], decoder_dim = mask2former_dim, num_heads = 4, num_patches = np.prod(patches_per_mask_unit[i] * num_mask_units)) for i in range(len(patches_per_mask_unit))]) + self.image_feature_projector = torch.nn.ModuleList([torch.nn.Linear(self.encoder.save_block_dims[i], mask2former_dim) for i in range(len(patches_per_mask_unit))]) + + + def get_mask(self,i, image_features, instance_queries): + # this is a token mask, not a spatial mask + # 1.change channels for image feature to match mask_dim + # 2. 
multiply query features by mask + image_features= self.image_feature_projector[i](image_features) + mask = torch.matmul(instance_queries, image_features.transpose(1,2)) + return mask + def forward(self, img): @@ -554,24 +567,27 @@ def forward(self, img): mask = None# should create a mask here using raw instance queries instance_queries = self.instance_queries breakpoint() - + # TODO check ppmu is in the correct order for i, (layer, ppmu) in enumerate(zip(save_layers, self.patches_per_mask_unit)): + print(f"Layer {i}: {layer.shape}") layer = rearrange(layer, 'b n_mu mu_dims c -> b (n_mu mu_dims) c') - instance_queries = self.transformer[i](instance_queries, layer, mask) - - # cross attention provides one mask per query - # self attention refines mask - # repeat for each block - - # rearrange to mask TODO make this account for havingn_queries masks - img_features = rearrange(layer, 'b (n_mu_z n_mu_y n_mu_x) (patches_per_mu_z patches_per_mu_y patches_per_mu_x) c -> b c (n_mu_z patches_per_mu_z) (n_mu_y patches_per_mu_y) (n_mu_x patches_per_mu_x)', n_mu_z=self.num_mask_units[0], n_mu_y=self.num_mask_units[1], n_mu_x=self.num_mask_units[2], patches_per_mu_z=ppmu[0], patches_per_mu_y=ppmu[1], patches_per_mu_x=ppmu[2]) + upsample = [1,1,1] + if mask is not None: + # upsample mask to match current resolution + upsample = ppmu//self.patches_per_mask_unit[i-1] + mask = repeat(mask, 'b n_queries n_mu_z n_mu_y n_mu_x -> b n_queries (n_mu_z mu_z) (n_mu_y mu_y) (n_mu_x mu_x)', mu_z=upsample[0], mu_y=upsample[1], mu_x=upsample[2]) + # convert mask from image to tokens + mask = rearrange(mask,'b n_queries n_mu_z n_mu_y n_mu_x -> b n_queries (n_mu_z n_mu_y n_mu_x)') + instance_queries = self.transformer[i](instance_queries + self.instance_queries_pos_emb, layer, mask) + print(f"Instance queries {i}: {instance_queries.shape}") # upsample to next resolution - mask = F.interpolate(mask, scale_factor=ppmu, mode='nearest') - - # output is instance_queries * output_features - # loss is calculated on each intermediate mask as well + mask = self.get_mask(i, layer, instance_queries) + mask = rearrange(mask, 'b n_queries (n_mu_z mu_z n_mu_y mu_y n_mu_x mu_x) -> b n_queries (n_mu_z mu_z) (n_mu_y mu_y) (n_mu_x mu_x)', n_mu_z = self.num_mask_units[0], n_mu_y = self.num_mask_units[1], n_mu_x = self.num_mask_units[1], mu_z = upsample[0], mu_y=upsample[1], mu_x = upsample[2]) + + print(f"Mask {i}: {mask.shape}") + print() - return predicted_img \ No newline at end of file + return mask \ No newline at end of file From 9f6591c16993b7cc6efe2f06cad456add891565b Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Wed, 7 Aug 2024 11:44:48 -0700 Subject: [PATCH 08/27] remove experimental code --- .../nn/vits/blocks/masked_unit_attention.py | 90 +++- cyto_dl/nn/vits/blocks/patchify.py | 8 +- cyto_dl/nn/vits/blocks/patchify_hiera.py | 87 +-- cyto_dl/nn/vits/hiera_mae.py | 502 ++++-------------- cyto_dl/nn/vits/utils.py | 6 + 5 files changed, 222 insertions(+), 471 deletions(-) diff --git a/cyto_dl/nn/vits/blocks/masked_unit_attention.py b/cyto_dl/nn/vits/blocks/masked_unit_attention.py index 768cc7bbd..92a76ceec 100644 --- a/cyto_dl/nn/vits/blocks/masked_unit_attention.py +++ b/cyto_dl/nn/vits/blocks/masked_unit_attention.py @@ -1,12 +1,13 @@ +# inspired by https://github.com/facebookresearch/hiera/tree/main +from typing import List + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -import numpy as np -from einops import reduce, rearrange - -from timm.models.layers import DropPath, Mlp 
-from typing import List +from einops import rearrange, reduce from einops.layers.torch import Reduce +from timm.models.layers import DropPath, Mlp class MaskUnitAttention(torch.nn.Module): @@ -19,31 +20,49 @@ def __init__( qk_scale=None, attn_drop=0.0, proj_drop=0.0, - q_stride= [1,1,1], - patches_per_mask_unit=[2,12,12], + q_stride=[1, 1, 1], + patches_per_mask_unit=[2, 12, 12], ): super().__init__() self.num_heads = num_heads self.head_dim = dim_out // num_heads self.scale = qk_scale or self.head_dim**-0.5 - self.qkv = nn.Linear(dim, dim_out*3, bias=qkv_bias) + self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias) self.attn_drop = attn_drop self.proj = nn.Linear(dim_out, dim_out) self.proj_drop = nn.Dropout(proj_drop) - self.dim_out= dim_out - self.q_stride=np.array(q_stride) - self.pooled_patches_per_mask_unit = (np.array(patches_per_mask_unit)/self.q_stride).astype(int) + self.dim_out = dim_out + self.q_stride = np.array(q_stride) + self.pooled_patches_per_mask_unit = ( + np.array(patches_per_mask_unit) / self.q_stride + ).astype(int) def forward(self, x): # project and split into q,k,v embeddings - qkv = rearrange(self.qkv(x), 'batch num_mask_units tokens_per_mask_unit (head_dim num_heads qkv) -> qkv batch num_mask_units num_heads tokens_per_mask_unit head_dim', head_dim = self.head_dim, qkv=3, num_heads =self.num_heads) + qkv = rearrange( + self.qkv(x), + "batch num_mask_units tokens_per_mask_unit (head_dim num_heads qkv) -> qkv batch num_mask_units num_heads tokens_per_mask_unit head_dim", + head_dim=self.head_dim, + qkv=3, + num_heads=self.num_heads, + ) q, k, v = qkv[0], qkv[1], qkv[2] - if np.any(self.q_stride>1): + if np.any(self.q_stride > 1): # within a mask unit, tokens are spatially ordered # perform spatial 2x2x2 max pooling over tokens - q = reduce(q, 'b n h (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c ->b n h (n_patches_z n_patches_y n_patches_x) c', reduction='max', q_stride_z=self.q_stride[0], q_stride_y = self.q_stride[1], q_stride_x = self.q_stride[2] ,n_patches_z = self.pooled_patches_per_mask_unit[0], n_patches_y= self.pooled_patches_per_mask_unit[1], n_patches_x=self.pooled_patches_per_mask_unit[2]) + q = reduce( + q, + "b n h (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c ->b n h (n_patches_z n_patches_y n_patches_x) c", + reduction="max", + q_stride_z=self.q_stride[0], + q_stride_y=self.q_stride[1], + q_stride_x=self.q_stride[2], + n_patches_z=self.pooled_patches_per_mask_unit[0], + n_patches_y=self.pooled_patches_per_mask_unit[1], + n_patches_x=self.pooled_patches_per_mask_unit[2], + ) attn = F.scaled_dot_product_attention( q, @@ -52,12 +71,15 @@ def forward(self, x): dropout_p=self.attn_drop, ) # combine heads into single channel dimension - x = rearrange(attn, 'b mask_units n_heads t c -> b mask_units t (n_heads c)',n_heads = self.num_heads) + x = rearrange( + attn, "b mask_units n_heads t c -> b mask_units t (n_heads c)", n_heads=self.num_heads + ) x = self.proj(x) x = self.proj_drop(x) return x + class HieraBlock(nn.Module): def __init__( self, @@ -68,8 +90,8 @@ def __init__( drop_path: float = 0.0, norm_layer: nn.Module = nn.LayerNorm, act_layer: nn.Module = nn.GELU, - q_stride: List[int]= [1,1,1], - patches_per_mask_unit:List[int]=[2,12,12], + q_stride: List[int] = [1, 1, 1], + patches_per_mask_unit: List[int] = [2, 12, 12], ): super().__init__() @@ -79,9 +101,15 @@ def __init__( self.norm1 = norm_layer(dim) - do_pool = np.any(np.array(q_stride)>1) or dim != dim_out + do_pool = 
np.any(np.array(q_stride) > 1) or dim != dim_out - self.attn = MaskUnitAttention(dim, dim_out, num_heads=heads, q_stride=q_stride, patches_per_mask_unit=patches_per_mask_unit) + self.attn = MaskUnitAttention( + dim, + dim_out, + num_heads=heads, + q_stride=q_stride, + patches_per_mask_unit=patches_per_mask_unit, + ) self.norm2 = norm_layer(dim_out) self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer) @@ -89,21 +117,33 @@ def __init__( self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() # max pooling by q stride within a mask unit - skip_connection_pooling = Reduce('b n (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_z n_patches_y n_patches_x) c', reduction='mean', q_stride_z=self.q_stride[0], q_stride_y = self.q_stride[1], q_stride_x = self.q_stride[2] ,n_patches_z = self.attn.pooled_patches_per_mask_unit[0], n_patches_y= self.attn.pooled_patches_per_mask_unit[1], n_patches_x=self.attn. pooled_patches_per_mask_unit[2]) + skip_connection_pooling = Reduce( + "b n (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_z n_patches_y n_patches_x) c", + reduction="mean", + q_stride_z=self.q_stride[0], + q_stride_y=self.q_stride[1], + q_stride_x=self.q_stride[2], + n_patches_z=self.attn.pooled_patches_per_mask_unit[0], + n_patches_y=self.attn.pooled_patches_per_mask_unit[1], + n_patches_x=self.attn.pooled_patches_per_mask_unit[2], + ) - self.proj = torch.nn.Sequential(skip_connection_pooling, nn.Linear(dim, dim_out)) if do_pool else torch.nn.Identity() + self.proj = ( + torch.nn.Sequential(skip_connection_pooling, nn.Linear(dim, dim_out)) + if do_pool + else torch.nn.Identity() + ) def forward(self, x: torch.Tensor) -> torch.Tensor: - ''' + """ x: batch x mask units x tokens x emb_dim - ''' + """ # Attention + Q Pooling x_norm = self.norm1(x) # change dimension and subsample within mask unit for skip connection x = self.proj(x_norm) - + x = x + self.drop_path(self.attn(x_norm)) # MLP x = x + self.drop_path(self.mlp(self.norm2(x))) return x - \ No newline at end of file diff --git a/cyto_dl/nn/vits/blocks/patchify.py b/cyto_dl/nn/vits/blocks/patchify.py index ee5a2c563..d029e152f 100644 --- a/cyto_dl/nn/vits/blocks/patchify.py +++ b/cyto_dl/nn/vits/blocks/patchify.py @@ -6,13 +6,7 @@ from einops.layers.torch import Rearrange, Reduce from timm.models.layers import trunc_normal_ -from cyto_dl.nn.vits.utils import take_indexes - - -def random_indexes(size: int, device): - forward_indexes = torch.randperm(size, device=device, dtype=torch.long) - backward_indexes = torch.argsort(forward_indexes) - return forward_indexes, backward_indexes +from cyto_dl.nn.vits.utils import random_indexes, take_indexes class Patchify(torch.nn.Module): diff --git a/cyto_dl/nn/vits/blocks/patchify_hiera.py b/cyto_dl/nn/vits/blocks/patchify_hiera.py index 9bf4d888a..3acc3bcb7 100644 --- a/cyto_dl/nn/vits/blocks/patchify_hiera.py +++ b/cyto_dl/nn/vits/blocks/patchify_hiera.py @@ -3,24 +3,29 @@ import numpy as np import torch import torch.nn as nn -from einops.layers.torch import Rearrange from einops import repeat +from einops.layers.torch import Rearrange -from cyto_dl.nn.vits.utils import take_indexes +from cyto_dl.nn.vits.utils import random_indexes, take_indexes -def random_indexes(size: int, device): - forward_indexes = torch.randperm(size, device=device, dtype=torch.long) - backward_indexes = torch.argsort(forward_indexes) - return forward_indexes, backward_indexes def take_indexes_mask(sequences, 
indexes): - ''' + """ sequences: batch x mask units x patches x emb_dim indexes: mask_units x batch - ''' + """ # always gather across tokens dimension return torch.gather( - sequences, 1, repeat(indexes.to(sequences.device), "mu b -> b mu p c", b= sequences.shape[0], c=sequences.shape[-1], mu = sequences.shape[1], p=sequences.shape[2]) + sequences, + 1, + repeat( + indexes.to(sequences.device), + "mu b -> b mu p c", + b=sequences.shape[0], + c=sequences.shape[-1], + mu=sequences.shape[1], + p=sequences.shape[2], + ), ) @@ -28,30 +33,59 @@ class PatchifyHiera(torch.nn.Module): """Class for converting images to a masked sequence of patches with positional embeddings.""" def __init__( - self, + self, patch_size: List[int], n_patches: List[int], - mask_ratio: float = 0.8, - num_mask_units: List[int] = [8,8,8], + mask_ratio: float = 0.8, + num_mask_units: List[int] = [8, 8, 8], emb_dim: int = 64, - spatial_dims: int= 3, + spatial_dims: int = 3, context_pixels: List[int] = [0, 0, 0], ): + """ + Parameters + ---------- + patch_size: List[int] + Size of each patch in pix (ZYX order for 3D, YX order for 2D) + n_patches: List[int] + Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D) + mask_ratio: float + Fraction of mask units to remove + num_mask_units: List[int] + Number of mask units in each spatial dimension (Z, Y, X) + emb_dim: int + Dimension of encoder + spatial_dims: int + Number of spatial dimensions + context_pixels: List[int] + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension + """ super().__init__() self.spatial_dims = spatial_dims self.mask_ratio = mask_ratio self.total_n_mask_units = np.prod(num_mask_units) patches_per_mask_unit = np.prod(n_patches) // self.total_n_mask_units - self.pos_embedding = torch.nn.Parameter(torch.zeros(1, self.total_n_mask_units, patches_per_mask_unit, emb_dim)) + self.pos_embedding = torch.nn.Parameter( + torch.zeros(1, self.total_n_mask_units, patches_per_mask_unit, emb_dim) + ) self.num_mask_units = num_mask_units - self.num_selected_mask_units = int(self.total_n_mask_units*(1-mask_ratio)) + self.num_selected_mask_units = int(self.total_n_mask_units * (1 - mask_ratio)) # mu -> mask unit - self.mask2img = Rearrange("(n_mu_z n_mu_y n_mu_x) b c -> b c n_mu_z n_mu_y n_mu_x ", n_mu_z=num_mask_units[0], n_mu_y=num_mask_units[1], n_mu_x=num_mask_units[2]) - - self.img2mask_units = Rearrange('b c (n_mu_z z) (n_mu_y y) (n_mu_x x) -> b (n_mu_z n_mu_y n_mu_x) (z y x) c ', n_mu_z = num_mask_units[0], n_mu_y= num_mask_units[1], n_mu_x=num_mask_units[2]) + self.mask2img = Rearrange( + "(n_mu_z n_mu_y n_mu_x) b c -> b c n_mu_z n_mu_y n_mu_x ", + n_mu_z=num_mask_units[0], + n_mu_y=num_mask_units[1], + n_mu_x=num_mask_units[2], + ) + self.img2mask_units = Rearrange( + "b c (n_mu_z z) (n_mu_y y) (n_mu_x x) -> b (n_mu_z n_mu_y n_mu_x) (z y x) c ", + n_mu_z=num_mask_units[0], + n_mu_y=num_mask_units[1], + n_mu_x=num_mask_units[2], + ) context_pixels = context_pixels[:spatial_dims] weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int) @@ -73,7 +107,7 @@ def get_mask(self, img): mask = torch.zeros(self.total_n_mask_units, B, 1, device=img.device, dtype=torch.uint8) # visible patches are first - mask[:self.num_selected_mask_units] = 1 + mask[: self.num_selected_mask_units] = 1 mask = take_indexes(mask, backward_indexes) mask = self.mask2img(mask) # one pixel per masked patch, interpolate to size of input image @@ -83,10 +117,8 @@ def get_mask(self, img): return 
mask, forward_indexes, backward_indexes def forward(self, img): - """" - takes in BCZYX image - returns B x num_selected_mask_units x patches_per_mask_unit x emb_dim - """ + """" takes in BCZYX image returns B x num_selected_mask_units x patches_per_mask_unit x + emb_dim.""" mask = torch.ones_like(img) forward_indexes, backward_indexes = None, None if self.mask_ratio > 0: @@ -97,13 +129,6 @@ def forward(self, img): tokens = tokens + self.pos_embedding if self.mask_ratio > 0: - tokens = take_indexes_mask(tokens, forward_indexes)[:, :self.num_selected_mask_units] + tokens = take_indexes_mask(tokens, forward_indexes)[:, : self.num_selected_mask_units] mask = (1 - mask).bool() return tokens, mask, forward_indexes, backward_indexes - - - - - - - diff --git a/cyto_dl/nn/vits/hiera_mae.py b/cyto_dl/nn/vits/hiera_mae.py index ba96abdaa..b9807f2e9 100644 --- a/cyto_dl/nn/vits/hiera_mae.py +++ b/cyto_dl/nn/vits/hiera_mae.py @@ -1,6 +1,6 @@ -# modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 +# inspired by https://github.com/facebookresearch/hiera -from typing import List, Optional, Dict +from typing import Dict, List, Optional import numpy as np import torch @@ -9,17 +9,11 @@ from einops import rearrange from einops.layers.torch import Rearrange from timm.models.vision_transformer import Block -from monai.networks.blocks import UnetOutBlock, UnetResBlock, UpSample - from cyto_dl.nn.vits.blocks.masked_unit_attention import HieraBlock - from cyto_dl.nn.vits.blocks.patchify_hiera import PatchifyHiera -from cyto_dl.nn.vits.mae import MAE_Decoder from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder - -import torch.nn.functional as F -from einops import repeat +from cyto_dl.nn.vits.mae import MAE_Decoder class SpatialMerger(nn.Module): @@ -36,19 +30,20 @@ def __init__(self, downsample_factor, in_dim, out_dim): ) tokens2img = Rearrange( - "b n_mu (z y x) c -> (b n_mu) c z y x", z=downsample_factor[0], y=downsample_factor[1], x=downsample_factor[2] + "b n_mu (z y x) c -> (b n_mu) c z y x", + z=downsample_factor[0], + y=downsample_factor[1], + x=downsample_factor[2], ) - self.model = nn.Sequential( - tokens2img, - conv - ) - + self.model = nn.Sequential(tokens2img, conv) + def forward(self, x): b, n_mu, _, _ = x.shape x = self.model(x) x = rearrange(x, "(b n_mu) c z y x -> b n_mu (z y x) c", b=b, n_mu=n_mu) return x + class HieraEncoder(torch.nn.Module): def __init__( self, @@ -67,31 +62,45 @@ def __init__( ---------- num_patches: List[int] Number of patches in each dimension + num_mask_units: List[int] + Number of mask units in each dimension + architecture: List[Dict] + List of dictionaries specifying the architecture of the transformer. Each dictionary should have the following keys: + - repeat: int + Number of times to repeat the block + - num_heads: int + Number of heads in the multihead attention + - q_stride: List[int] + Stride for the query in each spatial dimension + - self_attention: bool + Whether to use self attention or mask unit attention + emb_dim: int + Dimension of embedding spatial_dims: int Number of spatial dimensions - base_patch_size: List[int] + patch_size: List[int] Size of each patch - emb_dim: int - Dimension of embedding - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer + mask_ratio: float + Fraction of mask units to remove context_pixels: List[int] Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. 
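To make the context_pixels behaviour described above concrete, here is a small stand-alone sketch (illustrative numbers only, not this module's API): enlarging the kernel by 2 * context_pixels while keeping stride = patch_size and padding = context_pixels still yields exactly one embedding per patch, but each embedding sees a margin of surrounding pixels.

import torch
import torch.nn as nn

patch_size, context = 4, 2
conv = nn.Conv3d(
    in_channels=1,
    out_channels=16,
    kernel_size=patch_size + 2 * context,  # 8: the patch plus a 2-pixel margin on each side
    stride=patch_size,                     # one step per patch
    padding=context,                       # keeps the token grid aligned with the patch grid
)
img = torch.randn(1, 1, 16, 16, 16)        # 4 patches along each axis
print(conv(img).shape)                     # torch.Size([1, 16, 4, 4, 4]) -> one token per patch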
- input_channels: int - Number of input channels - weight_intermediates: bool - Whether to output linear combination of intermediate layers as final output like CrossMAE + save_layers: bool + Whether to save the intermediate layer outputs """ super().__init__() - self.save_layers= save_layers + self.save_layers = save_layers self.patchify = PatchifyHiera( - patch_size, num_patches, mask_ratio, num_mask_units, emb_dim, spatial_dims, context_pixels + patch_size, + num_patches, + mask_ratio, + num_mask_units, + emb_dim, + spatial_dims, + context_pixels, ) patches_per_mask_unit = np.array(num_patches) // np.array(num_mask_units) - self.final_dim = emb_dim * (2**len(architecture)) + self.final_dim = emb_dim * (2 ** len(architecture)) self.save_block_idxs = [] self.save_block_dims = [] @@ -100,46 +109,54 @@ def __init__( num_blocks = 0 for stage_num, stage in enumerate(architecture): # use mask unit attention until first layer that uses self attention - if stage.get('self_attention', False): + if stage.get("self_attention", False): break print(f"Stage: {stage_num}") for block in range(stage["repeat"]): is_last = block == stage["repeat"] - 1 # do spatial pooling within mask unit on last block of stage - q_stride = stage['q_stride'] if is_last else [1] * spatial_dims + q_stride = stage["q_stride"] if is_last else [1] * spatial_dims # double embedding dimension in last block of stage dim_in = emb_dim * (2**stage_num) dim_out = dim_in if not is_last else dim_in * 2 - print(f"\tBlock {block}:\t\tdim_in: {dim_in}, dim_out: {dim_out}, num_heads: {stage['num_heads']}, q_stride: {q_stride}, patches_per_mask_unit: {patches_per_mask_unit}") + print( + f"\tBlock {block}:\t\tdim_in: {dim_in}, dim_out: {dim_out}, num_heads: {stage['num_heads']}, q_stride: {q_stride}, patches_per_mask_unit: {patches_per_mask_unit}" + ) transformer.append( HieraBlock( dim=dim_in, dim_out=dim_out, - heads=stage['num_heads'], - q_stride = q_stride, - patches_per_mask_unit = patches_per_mask_unit, + heads=stage["num_heads"], + q_stride=q_stride, + patches_per_mask_unit=patches_per_mask_unit, ) ) if is_last: # save the block before the spatial pooling unless it's the final stage - save_block = num_blocks -1 if stage_num < len(architecture) - 1 else num_blocks + save_block = ( + num_blocks - 1 if stage_num < len(architecture) - 1 else num_blocks + ) self.save_block_idxs.append(save_block) self.save_block_dims.append(dim_in) - + # create a spatial merger for combining tokens pre-downsampling, last stage doesn't need merging since it has expected num channels, spatial shape - self.spatial_mergers[f'block_{save_block}'] = SpatialMerger(patches_per_mask_unit, dim_in, self.final_dim) if stage_num < len(architecture) - 1 else torch.nn.Identity() + self.spatial_mergers[f"block_{save_block}"] = ( + SpatialMerger(patches_per_mask_unit, dim_in, self.final_dim) + if stage_num < len(architecture) - 1 + else torch.nn.Identity() + ) # at end of each layer, patches per mask unit is reduced as we pool spatially - patches_per_mask_unit = patches_per_mask_unit // np.array(stage['q_stride']) + patches_per_mask_unit = patches_per_mask_unit // np.array(stage["q_stride"]) num_blocks += 1 self.mask_unit_transformer = torch.nn.Sequential(*transformer) self.save_block_dims.append(self.final_dim) self.save_block_dims.reverse() self.self_attention_transformer = torch.nn.Sequential( - *[Block(self.final_dim, stage['num_heads']) for _ in range(stage['repeat'])] - ) + *[Block(self.final_dim, stage["num_heads"]) for _ in range(stage["repeat"])] + ) 
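As a quick illustration of the stage bookkeeping above, the following sketch (a hypothetical architecture list in the schema this encoder expects, not a configuration from this repository) traces how the embedding dimension doubles and the patches per mask unit shrink by q_stride at the end of each mask-unit-attention stage:

import numpy as np

architecture = [
    {"repeat": 2, "q_stride": [1, 2, 2], "num_heads": 2},
    {"repeat": 2, "q_stride": [2, 2, 2], "num_heads": 4},
    {"repeat": 2, "num_heads": 8, "self_attention": True},  # full self attention, no pooling
]
emb_dim = 16
patches_per_mask_unit = np.array([4, 8, 8])

for stage_num, stage in enumerate(architecture):
    if stage.get("self_attention", False):
        break
    dim_in = emb_dim * (2 ** stage_num)
    patches_per_mask_unit = patches_per_mask_unit // np.array(stage["q_stride"])
    print(f"stage {stage_num}: {dim_in} -> {dim_in * 2} channels, patches per mask unit {patches_per_mask_unit}")
# stage 0: 16 -> 32 channels, patches per mask unit [4 4 4]
# stage 1: 32 -> 64 channels, patches per mask unit [2 2 2]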
self.layer_norm = torch.nn.LayerNorm(self.final_dim) @@ -152,14 +169,14 @@ def forward(self, img): for i, block in enumerate(self.mask_unit_transformer): patches = block(patches) if i in self.save_block_idxs: - mask_unit_embeddings += self.spatial_mergers[f'block_{i}'](patches) + mask_unit_embeddings += self.spatial_mergers[f"block_{i}"](patches) if self.save_layers: save_layers.append(patches) # combine mask units and tokens for full self attention transformer mask_unit_embeddings = rearrange(mask_unit_embeddings, "b n_mu t c -> b (n_mu t) c") mask_unit_embeddings = self.self_attention_transformer(mask_unit_embeddings) - mask_unit_embeddings= self.layer_norm(mask_unit_embeddings) + mask_unit_embeddings = self.layer_norm(mask_unit_embeddings) return mask_unit_embeddings, mask, forward_indexes, backward_indexes, save_layers @@ -167,7 +184,7 @@ def forward(self, img): class HieraMAE(torch.nn.Module): def __init__( self, - architecture, + architecture: List[Dict], spatial_dims: int = 3, num_patches: Optional[List[int]] = [2, 32, 32], num_mask_units: Optional[List[int]] = [2, 12, 12], @@ -183,6 +200,36 @@ def __init__( """ Parameters ---------- + architecture: List[Dict] + List of dictionaries specifying the architecture of the transformer. Each dictionary should have the following keys: + - repeat: int + Number of times to repeat the block + - num_heads: int + Number of heads in the multihead attention + - q_stride: List[int] + Stride for the query in each spatial dimension + - self_attention: bool + Whether to use self attention or mask unit attention + spatial_dims: int + Number of spatial dimensions + num_patches: List[int] + Number of patches in each dimension + num_mask_units: List[int] + Number of mask units in each dimension + patch_size: List[int] + Size of each patch + emb_dim: int + Dimension of embedding + decoder_layer: int + Number of layers in the decoder + decoder_head: int + Number of heads in the decoder + decoder_dim: int + Dimension of the decoder + mask_ratio: float + Fraction of mask units to remove + context_pixels: List[int] + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. 
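Since the decoder treats each mask unit as a "patch", the geometry used below can be checked with a few lines of arithmetic. This sketch uses the illustrative values from the example Hydra config added later in this series; nothing here depends on the repository code.

import numpy as np

num_patches = np.array([8, 8, 8])
patch_size = np.array([2, 2, 2])
num_mask_units = np.array([4, 4, 4])

img_shape = num_patches * patch_size                   # [16 16 16] pixels per crop
mask_unit_size = img_shape / num_mask_units            # [4. 4. 4.] pixels per mask unit
patches_per_mask_unit = num_patches // num_mask_units  # [2 2 2] patches per mask unit
print(img_shape, mask_unit_size, patches_per_mask_unit)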
""" super().__init__() assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" @@ -193,9 +240,7 @@ def __init__( patch_size = [patch_size] * spatial_dims assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" - assert ( - len(patch_size) == spatial_dims - ), "patch_size must be of length spatial_dims" + assert len(patch_size) == spatial_dims, "patch_size must be of length spatial_dims" self.mask_ratio = mask_ratio @@ -210,7 +255,7 @@ def __init__( context_pixels=context_pixels, ) # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size - mask_unit_size = (np.array(num_patches) * np.array(patch_size))/np.array(num_mask_units) + mask_unit_size = (np.array(num_patches) * np.array(patch_size)) / np.array(num_mask_units) decoder_class = MAE_Decoder if use_crossmae: @@ -224,7 +269,7 @@ def __init__( emb_dim=decoder_dim, num_layer=decoder_layer, num_head=decoder_head, - has_cls_token=False + has_cls_token=False, ) def forward(self, img): @@ -232,362 +277,3 @@ def forward(self, img): features = rearrange(features, "b t c -> t b c") predicted_img = self.decoder(features, forward_indexes, backward_indexes) return predicted_img, mask - - -class HieraSeg(torch.nn.Module): - def __init__( - self, - encoder_ckpt, - architecture, - spatial_dims: int = 3, - n_out_channels: int = 6, - num_patches: Optional[List[int]] = [2, 32, 32], - num_mask_units: Optional[List[int]] = [2, 12, 12], - patch_size: Optional[List[int]] = [16, 16, 16], - emb_dim: Optional[int] = 64, - context_pixels: Optional[List[int]] = [0, 0, 0], - ) -> None: - """ - Parameters - ---------- - """ - super().__init__() - assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" - - if isinstance(num_patches, int): - num_patches = [num_patches] * spatial_dims - if isinstance(patch_size, int): - patch_size = [patch_size] * spatial_dims - - assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" - assert ( - len(patch_size) == spatial_dims - ), "patch_size must be of length spatial_dims" - - - self.encoder = HieraEncoder( - num_patches=num_patches, - num_mask_units=num_mask_units, - architecture=architecture, - emb_dim=emb_dim, - spatial_dims=spatial_dims, - patch_size=patch_size, - mask_ratio=0, - context_pixels=context_pixels, - save_layers=True - ) - # model = torch.load(encoder_ckpt, map_location="cuda:0") - # enc_state_dict = { - # k.replace("backbone.encoder.", ""): v - # for k, v in model["state_dict"].items() - # if "encoder" in k and "intermediate" not in k - # } - # self.encoder.load_state_dict(enc_state_dict, strict=False) - # for name, param in self.encoder.named_parameters(): - # # allow different weighting of internal activations for finetuning - # param.requires_grad = False - - # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size - mask_unit_size = ((np.array(num_patches) * np.array(patch_size))/np.array(num_mask_units)).astype(int) - - project_dim = np.prod(mask_unit_size)*16 - head = torch.nn.Linear(self.encoder.final_dim, project_dim) - norm = torch.nn.LayerNorm(project_dim) - patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_z=num_mask_units[0], - n_patch_y=num_mask_units[1], - n_patch_x=num_mask_units[2], - patch_size_z=mask_unit_size[0], - patch_size_y=mask_unit_size[1], - 
patch_size_x=mask_unit_size[2], - ) - self.patch2img = torch.nn.Sequential(head, norm, patch2img) - - self.upsample = torch.nn.Sequential( - *[ - UpSample( - spatial_dims=spatial_dims, - in_channels=16, - out_channels=16, - scale_factor=[2.6134, 2.5005, 2.5005], - mode="nontrainable", - interp_mode="trilinear", - ), - UnetResBlock( - spatial_dims=spatial_dims, - in_channels=16, - out_channels=16, - stride=1, - kernel_size=3, - norm_name="INSTANCE", - dropout=0, - ), - UnetOutBlock( - spatial_dims=spatial_dims, - in_channels=16, - out_channels=n_out_channels, - dropout=0, - ), - ] - ) - - def forward(self, img): - features, _, _, _, save_layers = self.encoder(img) - features = rearrange(features, "b t c -> t b c") - predicted_img = self.patch2img(features) - predicted_img = self.upsample(predicted_img) - return predicted_img - - - - -class Mlp(nn.Module): - def __init__( - self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x - - -class CrossAttention(nn.Module): - def __init__( - self, - encoder_dim, - decoder_dim, - num_heads=8, - qkv_bias=False, - qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, - ): - super().__init__() - self.num_heads = num_heads - head_dim = decoder_dim // num_heads - self.scale = qk_scale or head_dim**-0.5 - self.q = nn.Linear(decoder_dim, decoder_dim, bias=qkv_bias) - self.kv = nn.Linear(encoder_dim, decoder_dim * 2, bias=qkv_bias) - self.attn_drop = attn_drop - self.proj = nn.Linear(decoder_dim, decoder_dim) - self.proj_drop = nn.Dropout(proj_drop) - - self.decoder_dim = decoder_dim - - def forward(self, x, y, mask): - """ x: queries y: values, mask: mask for attention""" - B, N, C = x.shape - Ny = y.shape[1] - q = self.q(x).reshape(B, N, self.num_heads, self.decoder_dim // self.num_heads).permute(0, 2, 1, 3) - kv = ( - self.kv(y) - .reshape(B, Ny, 2, self.num_heads, self.decoder_dim // self.num_heads) - .permute(2, 0, 3, 1, 4) - ) - k, v = kv[0], kv[1] - - attn = F.scaled_dot_product_attention( - q, - k, - v, - dropout_p=self.attn_drop, - attn_mask=mask>0 if mask is not None else None, - ) - x = attn.transpose(1, 2).reshape(B, N, self.decoder_dim) - - x = self.proj(x) - x = self.proj_drop(x) - return x -from timm.models.vision_transformer import Attention - -class Mask2FormerBlock(nn.Module): - """ - cross attention provides one mask per query - self attention refines mask - repeat for each block - """ - def __init__( - self, - encoder_dim, - decoder_dim, - num_heads, - num_patches, - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - ): - super().__init__() - self.norm1 = norm_layer(decoder_dim) - - self.scale_positional_embedding = nn.Parameter(torch.zeros(1, np.prod(num_patches), encoder_dim)) - - self.self_attn_block = Attention( - dim=decoder_dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - attn_drop=attn_drop, - proj_drop=drop, - ) - - self.cross_attn = CrossAttention( - encoder_dim, - decoder_dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - ) - self.norm2 = 
norm_layer(decoder_dim) - mlp_hidden_dim = int(decoder_dim * mlp_ratio) - self.mlp = Mlp( - in_features=decoder_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop - ) - - def forward(self, instance_queries, image_feats, mask): - """ - x: query features, y: image features, mask: previous mask prediction - """ - image_feats = image_feats + self.scale_positional_embedding - - instance_queries = self.norm1(instance_queries + self.cross_attn(instance_queries, image_feats, mask)) - instance_queries = instance_queries + self.self_attn_block(instance_queries) - instance_queries = instance_queries + self.mlp(self.norm2(instance_queries)) - return instance_queries - -class HieraMask2Former(torch.nn.Module): - def __init__( - self, - encoder_ckpt, - architecture, - spatial_dims: int = 3, - num_queries: int = 50, - num_patches: Optional[List[int]] = [2, 32, 32], - num_mask_units: Optional[List[int]] = [2, 12, 12], - patch_size: Optional[List[int]] = [16, 16, 16], - emb_dim: Optional[int] = 64, - context_pixels: Optional[List[int]] = [0, 0, 0], - mask2former_dim: Optional[int] = 128, - ) -> None: - """ - Parameters - ---------- - """ - super().__init__() - assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" - - if isinstance(num_patches, int): - num_patches = [num_patches] * spatial_dims - if isinstance(patch_size, int): - patch_size = [patch_size] * spatial_dims - - assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" - assert ( - len(patch_size) == spatial_dims - ), "patch_size must be of length spatial_dims" - self.num_mask_units = num_mask_units - - self.encoder = HieraEncoder( - num_patches=num_patches, - num_mask_units=num_mask_units, - architecture=architecture, - emb_dim=emb_dim, - spatial_dims=spatial_dims, - patch_size=patch_size, - mask_ratio=0, - context_pixels=context_pixels, - ) - model = torch.load(encoder_ckpt, map_location="cuda:0") - enc_state_dict = { - k.replace("backbone.encoder.", ""): v - for k, v in model["state_dict"].items() - if "encoder" in k and "intermediate" not in k - } - self.encoder.load_state_dict(enc_state_dict, strict=False) - for name, param in self.encoder.named_parameters(): - param.requires_grad = False - # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size - mask_unit_size = ((np.array(num_patches) * np.array(patch_size))/np.array(num_mask_units)).astype(int) - - self.instance_queries = torch.nn.Parameter(torch.zeros(1, num_queries, mask2former_dim)) - #s. Object query features are only used as the initial - # input to the Transformer decoder and are updated through - # decoder layers; whereas query positional embeddings are - # added to query features in every Transformer decoder layer - # when computing the attention weights. 
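The comment above describes the standard Mask2Former/DETR-style treatment of queries. A toy illustration of that pattern (generic PyTorch, not this repository's API): the query features are what the decoder layers update, while the query positional embedding is added back in at every layer before attention is computed.

import torch
import torch.nn as nn

num_queries, dim = 50, 128
queries = torch.zeros(1, num_queries, dim)      # object query features, updated layer to layer
query_pos = torch.randn(1, num_queries, dim)    # query positional embedding, re-added each layer
layers = nn.ModuleList(
    [nn.MultiheadAttention(dim, num_heads=4, batch_first=True) for _ in range(3)]
)

for layer in layers:
    # positional embedding participates in the attention weights (q and k) but not the values
    attended, _ = layer(queries + query_pos, queries + query_pos, queries)
    queries = queries + attended                # only the features accumulate updates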
- self.instance_queries_pos_emb = torch.nn.Parameter(torch.zeros(1, num_queries, mask2former_dim)) - - q_strides = [np.array(stage['q_stride']) for stage in architecture if stage.get('q_stride', False)] - patches_per_mask_unit = [np.array(num_patches) // np.array(num_mask_units)] - for qs in q_strides: - patches_per_mask_unit.append(patches_per_mask_unit[-1] // qs) - patches_per_mask_unit.reverse() - patches_per_mask_unit[0] = np.array([1,1,1]) - self.patches_per_mask_unit = patches_per_mask_unit - - self.transformer = torch.nn.ModuleList([Mask2FormerBlock(encoder_dim = self.encoder.save_block_dims[i], decoder_dim = mask2former_dim, num_heads = 4, num_patches = np.prod(patches_per_mask_unit[i] * num_mask_units)) for i in range(len(patches_per_mask_unit))]) - - self.image_feature_projector = torch.nn.ModuleList([torch.nn.Linear(self.encoder.save_block_dims[i], mask2former_dim) for i in range(len(patches_per_mask_unit))]) - - - def get_mask(self,i, image_features, instance_queries): - # this is a token mask, not a spatial mask - # 1.change channels for image feature to match mask_dim - # 2. multiply query features by mask - image_features= self.image_feature_projector[i](image_features) - mask = torch.matmul(instance_queries, image_features.transpose(1,2)) - return mask - - - - def forward(self, img): - #features are b x t x c - features, _, _, _, save_layers = self.encoder(img) - save_layers.append(features.unsqueeze(2)) - save_layers.reverse() - # start with lowest resolution - # first mask should be prediction from query features alone - mask = None# should create a mask here using raw instance queries - instance_queries = self.instance_queries - breakpoint() - # TODO check ppmu is in the correct order - for i, (layer, ppmu) in enumerate(zip(save_layers, self.patches_per_mask_unit)): - print(f"Layer {i}: {layer.shape}") - layer = rearrange(layer, 'b n_mu mu_dims c -> b (n_mu mu_dims) c') - - upsample = [1,1,1] - if mask is not None: - # upsample mask to match current resolution - upsample = ppmu//self.patches_per_mask_unit[i-1] - mask = repeat(mask, 'b n_queries n_mu_z n_mu_y n_mu_x -> b n_queries (n_mu_z mu_z) (n_mu_y mu_y) (n_mu_x mu_x)', mu_z=upsample[0], mu_y=upsample[1], mu_x=upsample[2]) - # convert mask from image to tokens - mask = rearrange(mask,'b n_queries n_mu_z n_mu_y n_mu_x -> b n_queries (n_mu_z n_mu_y n_mu_x)') - - instance_queries = self.transformer[i](instance_queries + self.instance_queries_pos_emb, layer, mask) - print(f"Instance queries {i}: {instance_queries.shape}") - # upsample to next resolution - # loss is calculated on each intermediate mask as well - mask = self.get_mask(i, layer, instance_queries) - mask = rearrange(mask, 'b n_queries (n_mu_z mu_z n_mu_y mu_y n_mu_x mu_x) -> b n_queries (n_mu_z mu_z) (n_mu_y mu_y) (n_mu_x mu_x)', n_mu_z = self.num_mask_units[0], n_mu_y = self.num_mask_units[1], n_mu_x = self.num_mask_units[1], mu_z = upsample[0], mu_y=upsample[1], mu_x = upsample[2]) - - print(f"Mask {i}: {mask.shape}") - print() - - return mask \ No newline at end of file diff --git a/cyto_dl/nn/vits/utils.py b/cyto_dl/nn/vits/utils.py index 61263ccd0..a5595a1a7 100644 --- a/cyto_dl/nn/vits/utils.py +++ b/cyto_dl/nn/vits/utils.py @@ -4,3 +4,9 @@ def take_indexes(sequences, indexes): return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1])) + + +def random_indexes(size: int, device): + forward_indexes = torch.randperm(size, device=device, dtype=torch.long) + backward_indexes = torch.argsort(forward_indexes) + return 
forward_indexes, backward_indexes From a773f3b2364aefa207a4eb69dde1ffc798aa4ae0 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Wed, 7 Aug 2024 16:32:15 -0700 Subject: [PATCH 09/27] update to base patchify --- cyto_dl/nn/vits/blocks/patchify/__init__.py | 2 + .../patchify_base.py} | 96 +++++++++---- .../nn/vits/blocks/patchify/patchify_hiera.py | 100 +++++++++++++ cyto_dl/nn/vits/blocks/patchify_hiera.py | 134 ------------------ 4 files changed, 168 insertions(+), 164 deletions(-) create mode 100644 cyto_dl/nn/vits/blocks/patchify/__init__.py rename cyto_dl/nn/vits/blocks/{patchify.py => patchify/patchify_base.py} (71%) create mode 100644 cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py delete mode 100644 cyto_dl/nn/vits/blocks/patchify_hiera.py diff --git a/cyto_dl/nn/vits/blocks/patchify/__init__.py b/cyto_dl/nn/vits/blocks/patchify/__init__.py new file mode 100644 index 000000000..0d52db981 --- /dev/null +++ b/cyto_dl/nn/vits/blocks/patchify/__init__.py @@ -0,0 +1,2 @@ +from .patchify_base import Patchify +from .patchify_hiera import PatchifyHiera \ No newline at end of file diff --git a/cyto_dl/nn/vits/blocks/patchify.py b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py similarity index 71% rename from cyto_dl/nn/vits/blocks/patchify.py rename to cyto_dl/nn/vits/blocks/patchify/patchify_base.py index d029e152f..c9eef0e73 100644 --- a/cyto_dl/nn/vits/blocks/patchify.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py @@ -42,23 +42,58 @@ def __init__( """ super().__init__() self.n_patches = np.asarray(n_patches) + + if spatial_dims not in (2, 3): + raise ValueError("Only 2D and 3D images are supported") self.spatial_dims = spatial_dims self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(n_patches), 1, emb_dim)) - context_pixels = context_pixels[:spatial_dims] + self.patch2img = self.create_patch2img(n_patches, patch_size) + self.img2token = self.create_img2token() + self.conv = self.create_conv(input_channels, emb_dim, patch_size, context_pixels) + + self.task_embedding = torch.nn.ParameterDict( + {task: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for task in tasks} + ) + self._init_weight() + + def create_conv(self, input_channels, emb_dim, patch_size, context_pixels): + context_pixels = context_pixels[:self.spatial_dims] weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int) - if spatial_dims == 3: - self.conv = nn.Conv3d( + if self.spatial_dims == 3: + return nn.Conv3d( in_channels=input_channels, out_channels=emb_dim, kernel_size=weight_size, stride=patch_size, padding=context_pixels, ) - self.img2token = Rearrange("b c z y x -> (z y x) b c") - self.patch2img = torch.nn.Sequential( + elif self.spatial_dims == 2: + return nn.Conv2d( + in_channels=input_channels, + out_channels=emb_dim, + kernel_size=weight_size, + stride=patch_size, + padding=context_pixels, + ) + + def create_img2token(self): + """ + Rearranges the image tensor to a sequence of patches + """ + if self.spatial_dims == 3: + return Rearrange("b c z y x -> (z y x) b c") + elif self.spatial_dims == 2: + return Rearrange("b c y x -> (y x) b c") + + def create_patch2img(self, n_patches, patch_size): + """ + Converts boolean array of whether to keep index of each patch to an image-shaped mask of same size as input image + """ + if self.spatial_dims == 3: + return torch.nn.Sequential( *[ Rearrange( "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x", @@ -75,17 +110,8 @@ def __init__( ), ] ) - - elif spatial_dims == 2: - self.conv = nn.Conv2d( - 
in_channels=input_channels, - out_channels=emb_dim, - kernel_size=weight_size, - stride=patch_size, - padding=context_pixels, - ) - self.img2token = Rearrange("b c y x -> (y x) b c") - self.patch2img = torch.nn.Sequential( + elif self.spatial_dims == 2: + return torch.nn.Sequential( *[ Rearrange( "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x", @@ -100,17 +126,20 @@ def __init__( ), ] ) - self.task_embedding = torch.nn.ParameterDict( - {task: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for task in tasks} - ) - self._init_weight() + def _init_weight(self): trunc_normal_(self.pos_embedding, std=0.02) for task in self.task_embedding: trunc_normal_(self.task_embedding[task], std=0.02) - def get_mask(self, img, n_visible_patches, num_patches): + def get_mask_args(self, mask_ratio): + num_patches = np.prod(self.n_patches) + n_visible_patches = int(num_patches * (1 - mask_ratio)) + return n_visible_patches, num_patches + + def get_mask(self, img, mask_ratio, n_visible_patches, num_patches): + B = img.shape[0] indexes = [random_indexes(num_patches, img.device) for _ in range(B)] @@ -127,23 +156,30 @@ def get_mask(self, img, n_visible_patches, num_patches): mask = self.patch2img(mask) return mask, forward_indexes, backward_indexes - + + def extract_visible_tokens(self, tokens, forward_indexes, n_visible_patches): + return take_indexes(tokens, forward_indexes)[:n_visible_patches] + def forward(self, img, mask_ratio, task=None): # generate mask - num_patches = np.prod(self.n_patches) - n_visible_patches = int(num_patches * (1 - mask_ratio)) - mask, forward_indexes, backward_indexes = self.get_mask( - img, n_visible_patches, num_patches - ) + mask = torch.ones_like(img) + forward_indexes, backward_indexes = None, None + if mask_ratio > 0: + n_visible_patches, num_patches = self.get_mask_args(mask_ratio) + mask, forward_indexes, backward_indexes = self.get_mask(img, mask_ratio, n_visible_patches, num_patches) + # generate patches tokens = self.conv(img * mask) tokens = self.img2token(tokens) + # add position embedding tokens = tokens + self.pos_embedding + + # extract visible patches if mask_ratio > 0: - # extract visible patches - tokens = take_indexes(tokens, forward_indexes)[:n_visible_patches] - + tokens = self.extract_visible_tokens(tokens, forward_indexes, n_visible_patches) + + # add task embedding if task in self.task_embedding: tokens = tokens + self.task_embedding[task] diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py new file mode 100644 index 000000000..201315b27 --- /dev/null +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py @@ -0,0 +1,100 @@ +from typing import List, Optional + +import numpy as np +import torch +from einops import repeat +from einops.layers.torch import Rearrange + +from cyto_dl.nn.vits.blocks.patchify.patchify_base import Patchify + + +def take_indexes_mask(sequences, indexes): + """ + sequences: batch x mask units x patches x emb_dim + indexes: mask_units x batch + """ + # always gather across tokens dimension + return torch.gather( + sequences, + 1, + repeat( + indexes, + "mu b -> b mu p c", + b=sequences.shape[0], + c=sequences.shape[-1], + mu=sequences.shape[1], + p=sequences.shape[2], + ), + ) + + +class PatchifyHiera(Patchify): + """Class for converting images to a masked sequence of patches with positional embeddings.""" + + def __init__( + self, + patch_size: List[int], + emb_dim: int = 64, + n_patches: List[int], + spatial_dims: int = 3, + context_pixels: List[int] = [0, 0, 0], + 
input_channels: int = 1, + tasks: Optional[List[str]] = [], + mask_units_per_dim: List[int] = [8, 8, 8], + ): + """ + patch_size: List[int] + Size of each patch in pix (ZYX order for 3D, YX order for 2D) + emb_dim: int + Dimension of encoder + n_patches: List[int] + Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D) + spatial_dims: int + Number of spatial dimensions + context_pixels: List[int] + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. + input_channels: int + Number of input channels + tasks: List[str] + List of tasks to encode + mask_units_per_dim: List[int] + Number of mask units in each spatial dimension (ZYX order for 3D, YX order for 2D) + """ + super().__init__(patch_size, emb_dim, n_patches, spatial_dims, context_pixels, input_channels, tasks) + + self.total_n_mask_units = np.prod(mask_units_per_dim) + patches_per_mask_unit = n_patches // self.total_n_mask_units + self.pos_embedding = torch.nn.Parameter( + torch.zeros(1, self.total_n_mask_units, np.prod(patches_per_mask_unit), emb_dim) + ) + + # redefine this to work with mask units instead of patches + self.img2token = self.create_img2token(mask_units_per_dim) + + mask_unit_size_pix = patches_per_mask_unit * patch_size + self.patch2img = self.create_patch2img(mask_units_per_dim, mask_unit_size_pix) + + + def create_img2token(self, mask_units_per_dim): + if self.spatial_dims == 3: + return Rearrange( + "b c (n_mu_z z) (n_mu_y y) (n_mu_x x) -> b (n_mu_z n_mu_y n_mu_x) (z y x) c ", + n_mu_z=mask_units_per_dim[0], + n_mu_y=mask_units_per_dim[1], + n_mu_x=mask_units_per_dim[2], + ) + elif self.spatial_dims == 2: + return Rearrange( + "b c (n_mu_y y) (n_mu_x x) -> b (n_mu_y n_mu_x) (y x) c ", + n_mu_y=mask_units_per_dim[1], + n_mu_x=mask_units_per_dim[2], + ) + + # in hiera, the level of masking is at the mask unit, not the patch level + def get_mask_args(self, mask_ratio): + n_visible_patches = int(total_n_mask_units * (1 - mask_ratio)) + return n_visible_patches, self.total_n_mask_units + + def extract_visible_tokens(self, tokens, forward_indexes, n_visible_patches): + return take_indexes_mask(tokens, forward_indexes)[:, :n_visible_patches] + diff --git a/cyto_dl/nn/vits/blocks/patchify_hiera.py b/cyto_dl/nn/vits/blocks/patchify_hiera.py deleted file mode 100644 index 3acc3bcb7..000000000 --- a/cyto_dl/nn/vits/blocks/patchify_hiera.py +++ /dev/null @@ -1,134 +0,0 @@ -from typing import List - -import numpy as np -import torch -import torch.nn as nn -from einops import repeat -from einops.layers.torch import Rearrange - -from cyto_dl.nn.vits.utils import random_indexes, take_indexes - - -def take_indexes_mask(sequences, indexes): - """ - sequences: batch x mask units x patches x emb_dim - indexes: mask_units x batch - """ - # always gather across tokens dimension - return torch.gather( - sequences, - 1, - repeat( - indexes.to(sequences.device), - "mu b -> b mu p c", - b=sequences.shape[0], - c=sequences.shape[-1], - mu=sequences.shape[1], - p=sequences.shape[2], - ), - ) - - -class PatchifyHiera(torch.nn.Module): - """Class for converting images to a masked sequence of patches with positional embeddings.""" - - def __init__( - self, - patch_size: List[int], - n_patches: List[int], - mask_ratio: float = 0.8, - num_mask_units: List[int] = [8, 8, 8], - emb_dim: int = 64, - spatial_dims: int = 3, - context_pixels: List[int] = [0, 0, 0], - ): - """ - Parameters - ---------- - patch_size: List[int] - Size of each patch in pix (ZYX order for 3D, YX order 
for 2D) - n_patches: List[int] - Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D) - mask_ratio: float - Fraction of mask units to remove - num_mask_units: List[int] - Number of mask units in each spatial dimension (Z, Y, X) - emb_dim: int - Dimension of encoder - spatial_dims: int - Number of spatial dimensions - context_pixels: List[int] - Number of extra pixels around each patch to include in convolutional embedding to encoder dimension - """ - super().__init__() - self.spatial_dims = spatial_dims - self.mask_ratio = mask_ratio - self.total_n_mask_units = np.prod(num_mask_units) - patches_per_mask_unit = np.prod(n_patches) // self.total_n_mask_units - self.pos_embedding = torch.nn.Parameter( - torch.zeros(1, self.total_n_mask_units, patches_per_mask_unit, emb_dim) - ) - - self.num_mask_units = num_mask_units - self.num_selected_mask_units = int(self.total_n_mask_units * (1 - mask_ratio)) - - # mu -> mask unit - self.mask2img = Rearrange( - "(n_mu_z n_mu_y n_mu_x) b c -> b c n_mu_z n_mu_y n_mu_x ", - n_mu_z=num_mask_units[0], - n_mu_y=num_mask_units[1], - n_mu_x=num_mask_units[2], - ) - - self.img2mask_units = Rearrange( - "b c (n_mu_z z) (n_mu_y y) (n_mu_x x) -> b (n_mu_z n_mu_y n_mu_x) (z y x) c ", - n_mu_z=num_mask_units[0], - n_mu_y=num_mask_units[1], - n_mu_x=num_mask_units[2], - ) - - context_pixels = context_pixels[:spatial_dims] - weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int) - self.conv = nn.Conv3d( - in_channels=1, - out_channels=emb_dim, - kernel_size=weight_size, - stride=patch_size, - padding=context_pixels, - ) - - def get_mask(self, img): - B = img.shape[0] - indexes = [random_indexes(self.total_n_mask_units, device=img.device) for _ in range(B)] - # forward indexes : index in image -> shuffledpatch - forward_indexes = torch.stack([i[0] for i in indexes], axis=-1) - # backward indexes : shuffled patch -> index in image - backward_indexes = torch.stack([i[1] for i in indexes], axis=-1) - - mask = torch.zeros(self.total_n_mask_units, B, 1, device=img.device, dtype=torch.uint8) - # visible patches are first - mask[: self.num_selected_mask_units] = 1 - mask = take_indexes(mask, backward_indexes) - mask = self.mask2img(mask) - # one pixel per masked patch, interpolate to size of input image - mask = torch.nn.functional.interpolate( - mask, img.shape[-self.spatial_dims :], mode="nearest" - ) - return mask, forward_indexes, backward_indexes - - def forward(self, img): - """" takes in BCZYX image returns B x num_selected_mask_units x patches_per_mask_unit x - emb_dim.""" - mask = torch.ones_like(img) - forward_indexes, backward_indexes = None, None - if self.mask_ratio > 0: - mask, forward_indexes, backward_indexes = self.get_mask(img) - tokens = self.conv(img * mask) - # break into batch x mask units x patches permask unit x emb_dim - tokens = self.img2mask_units(tokens) - - tokens = tokens + self.pos_embedding - if self.mask_ratio > 0: - tokens = take_indexes_mask(tokens, forward_indexes)[:, : self.num_selected_mask_units] - mask = (1 - mask).bool() - return tokens, mask, forward_indexes, backward_indexes From 23f4b943af4c3e672686419439ee38834ada30f2 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Thu, 15 Aug 2024 16:58:59 -0700 Subject: [PATCH 10/27] wip --- .../nn/vits/blocks/patchify/patchify_base.py | 6 +- .../nn/vits/blocks/patchify/patchify_hiera.py | 17 ++-- cyto_dl/nn/vits/cross_mae.py | 79 +++---------------- cyto_dl/nn/vits/hiera_mae.py | 8 +- cyto_dl/nn/vits/mae.py | 54 ++++++++----- 5 
files changed, 64 insertions(+), 100 deletions(-) diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_base.py b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py index 36da975a4..97ee25bdf 100644 --- a/cyto_dl/nn/vits/blocks/patchify/patchify_base.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py @@ -54,7 +54,6 @@ def __init__( ) self.patch2img = self.create_patch2img(n_patches, patch_size) - self.img2token = self.create_img2token() self.conv = self.create_conv(input_channels, emb_dim, patch_size, context_pixels) self.task_embedding = torch.nn.ParameterDict( @@ -141,8 +140,7 @@ def get_mask_args(self, mask_ratio): n_visible_patches = int(num_patches * (1 - mask_ratio)) return n_visible_patches, num_patches - def get_mask(self, img, mask_ratio, n_visible_patches, num_patches): - + def get_mask(self, img, n_visible_patches, num_patches): B = img.shape[0] indexes = [random_indexes(num_patches, img.device) for _ in range(B)] @@ -169,7 +167,7 @@ def forward(self, img, mask_ratio, task=None): forward_indexes, backward_indexes = None, None if mask_ratio > 0: n_visible_patches, num_patches = self.get_mask_args(mask_ratio) - mask, forward_indexes, backward_indexes = self.get_mask(img, mask_ratio, n_visible_patches, num_patches) + mask, forward_indexes, backward_indexes = self.get_mask(img, n_visible_patches, num_patches) # generate patches tokens = self.conv(img * mask) diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py index 201315b27..d61905608 100644 --- a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py @@ -4,6 +4,7 @@ import torch from einops import repeat from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_ from cyto_dl.nn.vits.blocks.patchify.patchify_base import Patchify @@ -34,8 +35,8 @@ class PatchifyHiera(Patchify): def __init__( self, patch_size: List[int], - emb_dim: int = 64, n_patches: List[int], + emb_dim: int = 64, spatial_dims: int = 3, context_pixels: List[int] = [0, 0, 0], input_channels: int = 1, @@ -45,10 +46,10 @@ def __init__( """ patch_size: List[int] Size of each patch in pix (ZYX order for 3D, YX order for 2D) - emb_dim: int - Dimension of encoder n_patches: List[int] Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D) + emb_dim: int + Dimension of encoder spatial_dims: int Number of spatial dimensions context_pixels: List[int] @@ -74,8 +75,14 @@ def __init__( mask_unit_size_pix = patches_per_mask_unit * patch_size self.patch2img = self.create_patch2img(mask_units_per_dim, mask_unit_size_pix) + self._init_weight() + + def _init_weight(self): + trunc_normal_(self.pos_embedding, std=0.02) - def create_img2token(self, mask_units_per_dim): + def create_img2token(self, mask_units_per_dim=None): + if mask_units_per_dim is None: + return if self.spatial_dims == 3: return Rearrange( "b c (n_mu_z z) (n_mu_y y) (n_mu_x x) -> b (n_mu_z n_mu_y n_mu_x) (z y x) c ", @@ -92,7 +99,7 @@ def create_img2token(self, mask_units_per_dim): # in hiera, the level of masking is at the mask unit, not the patch level def get_mask_args(self, mask_ratio): - n_visible_patches = int(total_n_mask_units * (1 - mask_ratio)) + n_visible_patches = int(self.total_n_mask_units * (1 - mask_ratio)) return n_visible_patches, self.total_n_mask_units def extract_visible_tokens(self, tokens, forward_indexes, n_visible_patches): diff --git a/cyto_dl/nn/vits/cross_mae.py b/cyto_dl/nn/vits/cross_mae.py index 3981de6e7..3dc7e68f2 
100644 --- a/cyto_dl/nn/vits/cross_mae.py +++ b/cyto_dl/nn/vits/cross_mae.py @@ -6,11 +6,12 @@ from einops.layers.torch import Rearrange from timm.models.layers import trunc_normal_ +from cyto_dl.nn.vits.mae import MAE_Decoder from cyto_dl.nn.vits.blocks import CrossAttentionBlock from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes -class CrossMAE_Decoder(torch.nn.Module): +class CrossMAE_Decoder(MAE_decoder): """Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masked tokens only attend to visible tokens.""" @@ -23,6 +24,7 @@ def __init__( emb_dim: Optional[int] = 192, num_layer: Optional[int] = 4, num_head: Optional[int] = 3, + has_cls_token: Optional[bool] = True, learnable_pos_embedding: Optional[bool] = True, ) -> None: """ @@ -40,10 +42,12 @@ def __init__( Number of transformer layers num_head: int Number of heads in transformer + has_cls_token: bool + Whether encoder features have a cls token learnable_pos_embedding: bool If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images. """ - super().__init__() + super().__init__(num_patches, spatial_dims, base_patch_size, enc_dim, emb_dim, num_layer, num_head, has_cls_token, learnable_pos_embedding) self.transformer = torch.nn.ParameterList( [ @@ -55,42 +59,6 @@ def __init__( for _ in range(num_layer) ] ) - self.decoder_norm = nn.LayerNorm(emb_dim) - self.projection_norm = nn.LayerNorm(emb_dim) - - self.projection = torch.nn.Linear(enc_dim, emb_dim) - self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) - - self.pos_embedding = get_positional_embedding( - num_patches, emb_dim, learnable=learnable_pos_embedding - ) - - self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size))) - self.num_patches = torch.as_tensor(num_patches) - - if spatial_dims == 3: - self.patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_z=num_patches[0], - n_patch_y=num_patches[1], - n_patch_x=num_patches[2], - patch_size_z=base_patch_size[0], - patch_size_y=base_patch_size[1], - patch_size_x=base_patch_size[2], - ) - elif spatial_dims == 2: - self.patch2img = Rearrange( - "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_y=num_patches[0], - n_patch_x=num_patches[1], - patch_size_y=base_patch_size[0], - patch_size_x=base_patch_size[1], - ) - - self.init_weight() - - def init_weight(self): - trunc_normal_(self.mask_token, std=0.02) def forward(self, features, forward_indexes, backward_indexes): # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers @@ -100,35 +68,10 @@ def forward(self, features, forward_indexes, backward_indexes): # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now features = self.projection_norm(self.projection(features)) - # add cls token - backward_indexes = torch.cat( - [ - torch.zeros( - 1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long - ), - backward_indexes + 1, - ], - dim=0, - ) - forward_indexes = torch.cat( - [ - torch.zeros( - 1, forward_indexes.shape[1], device=forward_indexes.device, dtype=torch.long - ), - forward_indexes + 1, - ], - dim=0, - ) - # fill in 
masked regions - features = torch.cat( - [ - features, - self.mask_token.expand( - backward_indexes.shape[0] - features.shape[0], features.shape[1], -1 - ), - ], - dim=0, - ) + backward_indexes = self.adjust_indices_for_cls(backward_indexes) + forward_indexes = self.adjust_indices_for_cls(forward_indexes) + + features = self.add_mask_tokens(features, backward_indexes) # unshuffle to original positions for positional embedding so we can do cross attention during decoding features = take_indexes(features, backward_indexes) @@ -164,7 +107,7 @@ def forward(self, features, forward_indexes, backward_indexes): ], dim=0, ) - patches = take_indexes(patches, backward_indexes[1:] - 1) + patches = take_indexes(patches, backward_indexes[1:] - 1) if self.has_cls_token else take_indexes(patches, backward_indexes) # patches to image img = self.patch2img(patches) return img diff --git a/cyto_dl/nn/vits/hiera_mae.py b/cyto_dl/nn/vits/hiera_mae.py index b9807f2e9..3c8c0001d 100644 --- a/cyto_dl/nn/vits/hiera_mae.py +++ b/cyto_dl/nn/vits/hiera_mae.py @@ -11,7 +11,7 @@ from timm.models.vision_transformer import Block from cyto_dl.nn.vits.blocks.masked_unit_attention import HieraBlock -from cyto_dl.nn.vits.blocks.patchify_hiera import PatchifyHiera +from cyto_dl.nn.vits.blocks.patchify import PatchifyHiera from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder from cyto_dl.nn.vits.mae import MAE_Decoder @@ -88,15 +88,15 @@ def __init__( Whether to save the intermediate layer outputs """ super().__init__() + self.mask_ratio = mask_ratio self.save_layers = save_layers self.patchify = PatchifyHiera( patch_size, num_patches, - mask_ratio, - num_mask_units, emb_dim, spatial_dims, context_pixels, + mask_units_per_dim=num_mask_units, ) patches_per_mask_unit = np.array(num_patches) // np.array(num_mask_units) @@ -161,7 +161,7 @@ def __init__( self.layer_norm = torch.nn.LayerNorm(self.final_dim) def forward(self, img): - patches, mask, forward_indexes, backward_indexes = self.patchify(img) + patches, mask, forward_indexes, backward_indexes = self.patchify(img, self.mask_ratio) # mask unit attention mask_unit_embeddings = 0.0 diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py index 1617bf687..92d94eb56 100644 --- a/cyto_dl/nn/vits/mae.py +++ b/cyto_dl/nn/vits/mae.py @@ -2,7 +2,6 @@ from typing import List, Optional -import numpy as np import torch import torch.nn as nn from einops import rearrange @@ -107,6 +106,7 @@ def __init__( emb_dim: Optional[int] = 192, num_layer: Optional[int] = 4, num_head: Optional[int] = 3, + has_cls_token: Optional[bool] = False, learnable_pos_embedding: Optional[bool] = True, ) -> None: """ @@ -124,10 +124,14 @@ def __init__( Number of transformer layers num_head: int Number of heads in transformer + has_cls_token: bool + Whether encoder features have a cls token learnable_pos_embedding: bool If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
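A small worked example of the cls-token index adjustment implemented below (toy shapes, tokens x batch): when a cls token is prepended to the sequence, index 0 is reserved for it in every batch element and all patch indexes shift by one.

import torch

backward_indexes = torch.tensor([[2], [0], [1]])  # 3 patches, batch of 1
with_cls = torch.cat(
    [torch.zeros(1, 1, dtype=torch.long), backward_indexes + 1], dim=0
)
print(with_cls.squeeze(-1).tolist())              # [0, 3, 1, 2]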
""" super().__init__() + self.has_cls_token = has_cls_token + self.projection_norm = nn.LayerNorm(emb_dim) self.projection = torch.nn.Linear(enc_dim, emb_dim) self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) @@ -140,7 +144,7 @@ def __init__( *[Block(emb_dim, num_head) for _ in range(num_layer)] ) out_dim = torch.prod(torch.as_tensor(base_patch_size)).item() - self.head_norm = nn.LayerNorm(out_dim) + self.decoder_norm = nn.LayerNorm(emb_dim) self.head = torch.nn.Linear(emb_dim, out_dim) self.num_patches = torch.as_tensor(num_patches) @@ -168,21 +172,22 @@ def __init__( def init_weight(self): trunc_normal_(self.mask_token, std=0.02) - def forward(self, features, forward_indexes, backward_indexes): - # project from encoder dimension to decoder dimension - features = self.projection_norm(self.projection(features)) - - backward_indexes = torch.cat( - [ - torch.zeros( - 1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long - ), - backward_indexes + 1, - ], - dim=0, - ) - # fill in masked regions - features = torch.cat( + def adjust_indices_for_cls(self, indexes): + if self.has_cls_token: + return torch.cat( + [ + torch.zeros( + 1, indexes.shape[1], device=indexes.device, dtype=torch.long + ), + indexes + 1, + ], + dim=0, + ) + return indexes + + def add_mask_tokens(self, features, backward_indexes): + # fill in deleted masked regions with mask token + return torch.cat( [ features, self.mask_token.expand( @@ -191,6 +196,15 @@ def forward(self, features, forward_indexes, backward_indexes): ], dim=0, ) + + def forward(self, features, forward_indexes, backward_indexes): + # project from encoder dimension to decoder dimension + features = self.projection_norm(self.projection(features)) + + backward_indexes = self.adjust_indices_for_cls(backward_indexes) + + features = self.add_mask_tokens(features, backward_indexes) + # unshuffle to original positions features = take_indexes(features, backward_indexes) features = features + self.pos_embedding @@ -199,10 +213,12 @@ def forward(self, features, forward_indexes, backward_indexes): features = rearrange(features, "t b c -> b t c") features = self.transformer(features) features = rearrange(features, "b t c -> t b c") - features = features[1:] # remove global feature + + if self.has_cls_token: + features = features[1:] # remove global feature # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x) - patches = self.head_norm(self.head(features)) + patches = self.head(self.decoder_norm(features)) # patches to image img = self.patch2img(patches) From 357166f47c155279a1956e316ed054868c858a4f Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Fri, 16 Aug 2024 16:16:24 -0700 Subject: [PATCH 11/27] update configs --- configs/experiment/im2im/hiera.yaml | 55 +++++++++++++++ configs/model/im2im/hiera.yaml | 69 +++++++++++++++++++ configs/model/im2im/mae.yaml | 6 +- .../model/im2im/vit_segmentation_decoder.yaml | 4 +- 4 files changed, 129 insertions(+), 5 deletions(-) create mode 100644 configs/experiment/im2im/hiera.yaml create mode 100644 configs/model/im2im/hiera.yaml diff --git a/configs/experiment/im2im/hiera.yaml b/configs/experiment/im2im/hiera.yaml new file mode 100644 index 000000000..2363a947f --- /dev/null +++ b/configs/experiment/im2im/hiera.yaml @@ -0,0 +1,55 @@ +# @package _global_ +# to execute this experiment run: +# python train.py experiment=example +defaults: + - override /data: im2im/mae.yaml + - override /model: im2im/hiera.yaml + - override /callbacks: default.yaml + - 
override /trainer: gpu.yaml + - override /logger: csv.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["dev"] +seed: 12345 + +experiment_name: YOUR_EXP_NAME +run_name: YOUR_RUN_NAME + +# only source_col is needed for masked autoencoder +source_col: raw +spatial_dims: 3 +raw_im_channels: 1 + +trainer: + max_epochs: 100 + gradient_clip_val: 10 + +data: + path: ${paths.data_dir}/example_experiment_data/segmentation + cache_dir: ${paths.data_dir}/example_experiment_data/cache + batch_size: 128 + num_workers: 8 + subsample: + train: 10000 + _aux: + # 2D + # patch_shape: [16, 16] + # 3D + patch_shape: [16, 16, 16] + +callbacks: + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # save_every_n_epochs: ${model.save_images_every_n_epochs} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] diff --git a/configs/model/im2im/hiera.yaml b/configs/model/im2im/hiera.yaml new file mode 100644 index 000000000..901bbe02c --- /dev/null +++ b/configs/model/im2im/hiera.yaml @@ -0,0 +1,69 @@ +_target_: cyto_dl.models.im2im.MultiTaskIm2Im + +save_images_every_n_epochs: 1 +save_dir: ${paths.output_dir} + +x_key: ${source_col} + +backbone: + _target_: cyto_dl.nn.vits.mae.HieraMAE + spatial_dims: 3 + patch_size: [2, 2, 2] # patch_size* num_patches should be your patch shape + num_patches: [8, 8, 8] # patch_size * num_patches = img_shape + num_mask_units: [4, 4, 4] #img_shape / num_mask_units = size of each mask unit in pixels, num_patches/num_mask_units = number of patches permask unit + emb_dim: 2 + architecture: + # mask_unit_attention blocks - attention is only done within a mask unit and not across mask units + # the total amount of q_stride across the architecture must be less than the number of patches per mask unit + - repeat: 1 + q_stride: [1,1,1] + num_heads: 1 + - repeat: 1 + q_stride: [2,2,2] + num_heads: 4 + # self attention transformer - attention is done across all patches, irrespective of which mask unit they're in + - repeat: 2 + num_heads: 8 + self_attention: True + decoder_layer: 1 + decoder_dim: 16 + mask_ratio: 0.66666666666 + context_pixels: [4,4,4] + use_crossmae: True + +task_heads: ${kv_to_dict:${model._aux._tasks}} + +optimizer: + generator: + _partial_: True + _target_: torch.optim.AdamW + weight_decay: 0.05 + +lr_scheduler: + generator: + _partial_: True + _target_: torch.optim.lr_scheduler.OneCycleLR + max_lr: 0.0001 + epochs: ${trainer.max_epochs} + steps_per_epoch: 1 + pct_start: 0.1 + +inference_args: + sw_batch_size: 1 + roi_size: ${data._aux.patch_shape} + overlap: 0 + progress: True + mode: "gaussian" + +_aux: + _tasks: + - - ${source_col} + - _target_: cyto_dl.nn.head.mae_head.MAEHead + loss: + postprocess: + input: + _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel + rescale_dtype: numpy.uint8 + prediction: + _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel + rescale_dtype: numpy.uint8 diff --git a/configs/model/im2im/mae.yaml b/configs/model/im2im/mae.yaml index 615f8c3c4..b1172f655 100644 --- a/configs/model/im2im/mae.yaml +++ b/configs/model/im2im/mae.yaml @@ -6,10 +6,10 @@ save_dir: ${paths.output_dir} x_key: ${source_col} backbone: - _target_: cyto_dl.nn.vits.MAE_ViT + _target_: 
cyto_dl.nn.vits.MAE spatial_dims: ${spatial_dims} - # base_patch_size* num_patches should be your patch shape - base_patch_size: 2 + # patch_size* num_patches should be your patch shape + patch_size: 2 num_patches: 8 emb_dim: 16 encoder_layer: 2 diff --git a/configs/model/im2im/vit_segmentation_decoder.yaml b/configs/model/im2im/vit_segmentation_decoder.yaml index 6c1ad137b..a0c72bd69 100644 --- a/configs/model/im2im/vit_segmentation_decoder.yaml +++ b/configs/model/im2im/vit_segmentation_decoder.yaml @@ -8,8 +8,8 @@ x_key: ${source_col} backbone: _target_: cyto_dl.nn.vits.Seg_ViT spatial_dims: ${spatial_dims} - # base_patch_size* num_patches should be your patch shape - base_patch_size: 2 + # patch_size* num_patches should be your patch shape + patch_size: 2 num_patches: 8 emb_dim: 16 encoder_layer: 2 From a12d8cd09d6e5a5f33e851fea2949198d58e9961 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Fri, 16 Aug 2024 16:16:50 -0700 Subject: [PATCH 12/27] update patchify --- cyto_dl/nn/vits/__init__.py | 5 +- cyto_dl/nn/vits/blocks/patchify/__init__.py | 4 +- cyto_dl/nn/vits/blocks/patchify/patchify.py | 49 +++++++ .../nn/vits/blocks/patchify/patchify_base.py | 74 +++++------ .../nn/vits/blocks/patchify/patchify_conv.py | 122 ++++++++++++++++++ .../nn/vits/blocks/patchify/patchify_hiera.py | 49 ++++--- 6 files changed, 246 insertions(+), 57 deletions(-) create mode 100644 cyto_dl/nn/vits/blocks/patchify/patchify.py create mode 100644 cyto_dl/nn/vits/blocks/patchify/patchify_conv.py diff --git a/cyto_dl/nn/vits/__init__.py b/cyto_dl/nn/vits/__init__.py index ee5076c52..5207f6f9d 100644 --- a/cyto_dl/nn/vits/__init__.py +++ b/cyto_dl/nn/vits/__init__.py @@ -1,4 +1,5 @@ -from .cross_mae import CrossMAE_Decoder -from .mae import MAE_Decoder, MAE_Encoder, MAE_ViT +from .decoder import CrossMAE_Decoder, MAE_Decoder +from .encoder import HieraEncoder, MAE_Encoder +from .mae import MAE, HieraMAE from .seg import Seg_ViT, SuperresDecoder from .utils import take_indexes diff --git a/cyto_dl/nn/vits/blocks/patchify/__init__.py b/cyto_dl/nn/vits/blocks/patchify/__init__.py index 0d52db981..02b79746c 100644 --- a/cyto_dl/nn/vits/blocks/patchify/__init__.py +++ b/cyto_dl/nn/vits/blocks/patchify/__init__.py @@ -1,2 +1,2 @@ -from .patchify_base import Patchify -from .patchify_hiera import PatchifyHiera \ No newline at end of file +from .patchify import Patchify +from .patchify_hiera import PatchifyHiera diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify.py b/cyto_dl/nn/vits/blocks/patchify/patchify.py new file mode 100644 index 000000000..3f6c21ac6 --- /dev/null +++ b/cyto_dl/nn/vits/blocks/patchify/patchify.py @@ -0,0 +1,49 @@ +from cyto_dl.nn.vits.blocks.patchify.patchify_base import PatchifyBase +from typing import List, Optional +from einops.layers.torch import Rearrange +import numpy as np +from cyto_dl.nn.vits.utils import take_indexes + + +class Patchify(PatchifyBase): + """Class for converting images to a masked sequence of patches with positional embeddings.""" + def __init__( + self, + patch_size: List[int], + emb_dim: int, + n_patches: List[int], + spatial_dims: int = 3, + context_pixels: List[int] = [0, 0, 0], + input_channels: int = 1, + tasks: Optional[List[str]] = [], + learnable_pos_embedding: bool = True, + ): + super().__init__( + patch_size=patch_size, + emb_dim=emb_dim, + n_patches=n_patches, + spatial_dims=spatial_dims, + context_pixels=context_pixels, + input_channels=input_channels, + tasks=tasks, + learnable_pos_embedding=learnable_pos_embedding, + ) + + @property + def 
img2token(self): + return self.create_img2token() + + def get_mask_args(self, mask_ratio): + num_patches = np.prod(self.n_patches) + n_visible_patches = int(num_patches * (1 - mask_ratio)) + return n_visible_patches, num_patches + + def create_img2token(self): + """Rearranges the image tensor to a sequence of patches.""" + if self.spatial_dims == 3: + return Rearrange("b c z y x -> (z y x) b c") + elif self.spatial_dims == 2: + return Rearrange("b c y x -> (y x) b c") + + def extract_visible_tokens(self, tokens, forward_indexes, n_visible_patches): + return take_indexes(tokens, forward_indexes)[:n_visible_patches] \ No newline at end of file diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_base.py b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py index 97ee25bdf..8d2b99ff9 100644 --- a/cyto_dl/nn/vits/blocks/patchify/patchify_base.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import List, Optional import numpy as np @@ -6,9 +7,10 @@ from einops.layers.torch import Rearrange, Reduce from timm.models.layers import trunc_normal_ -from cyto_dl.nn.vits.utils import random_indexes, take_indexes, get_positional_embedding +from cyto_dl.nn.vits.utils import get_positional_embedding, random_indexes, take_indexes -class Patchify(torch.nn.Module): + +class PatchifyBase(torch.nn.Module, ABC): """Class for converting images to a masked sequence of patches with positional embeddings.""" def __init__( @@ -43,17 +45,17 @@ def __init__( If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. """ super().__init__() - self.n_patches = np.asarray(n_patches) if spatial_dims not in (2, 3): raise ValueError("Only 2D and 3D images are supported") self.spatial_dims = spatial_dims + self.n_patches = np.asarray(n_patches) self.pos_embedding = get_positional_embedding( n_patches, emb_dim, learnable=learnable_pos_embedding, use_cls_token=False ) - self.patch2img = self.create_patch2img(n_patches, patch_size) + self.patch2img = self.create_patch2img(n_patches, patch_size) self.conv = self.create_conv(input_channels, emb_dim, patch_size, context_pixels) self.task_embedding = torch.nn.ParameterDict( @@ -61,8 +63,25 @@ def __init__( ) self._init_weight() + def _init_weight(self): + for task in self.task_embedding: + trunc_normal_(self.task_embedding[task], std=0.02) + + @property + @abstractmethod + def img2token(self): + pass + + @abstractmethod + def get_mask_args(self): + pass + + @abstractmethod + def extract_visible_tokens(self): + pass + def create_conv(self, input_channels, emb_dim, patch_size, context_pixels): - context_pixels = context_pixels[:self.spatial_dims] + context_pixels = context_pixels[: self.spatial_dims] weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int) if self.spatial_dims == 3: @@ -82,28 +101,20 @@ def create_conv(self, input_channels, emb_dim, patch_size, context_pixels): padding=context_pixels, ) - def create_img2token(self): - """ - Rearranges the image tensor to a sequence of patches - """ - if self.spatial_dims == 3: - return Rearrange("b c z y x -> (z y x) b c") - elif self.spatial_dims == 2: - return Rearrange("b c y x -> (y x) b c") - def create_patch2img(self, n_patches, patch_size): - """ - Converts boolean array of whether to keep index of each patch to an image-shaped mask of same size as input image - """ + """Converts boolean array of whether to 
keep index of each patch to an image-shaped mask of + same size as input image.""" if self.spatial_dims == 3: return torch.nn.Sequential( *[ + # rearrange tokens to image Rearrange( "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x", n_patch_z=n_patches[0], n_patch_y=n_patches[1], n_patch_x=n_patches[2], ), + # nearest neighbor resize image to match input image size Reduce( "b c n_patch_z n_patch_y n_patch_x -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", reduction="repeat", @@ -122,7 +133,7 @@ def create_patch2img(self, n_patches, patch_size): n_patch_x=n_patches[1], ), Reduce( - "b c n_patch_y n_patch_x -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + "b c n_patch_y n_patch_x -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", reduction="repeat", patch_size_y=patch_size[0], patch_size_x=patch_size[1], @@ -130,16 +141,6 @@ def create_patch2img(self, n_patches, patch_size): ] ) - - def _init_weight(self): - for task in self.task_embedding: - trunc_normal_(self.task_embedding[task], std=0.02) - - def get_mask_args(self, mask_ratio): - num_patches = np.prod(self.n_patches) - n_visible_patches = int(num_patches * (1 - mask_ratio)) - return n_visible_patches, num_patches - def get_mask(self, img, n_visible_patches, num_patches): B = img.shape[0] @@ -157,29 +158,28 @@ def get_mask(self, img, n_visible_patches, num_patches): mask = self.patch2img(mask) return mask, forward_indexes, backward_indexes - - def extract_visible_tokens(self, tokens, forward_indexes, n_visible_patches): - return take_indexes(tokens, forward_indexes)[:n_visible_patches] - + def forward(self, img, mask_ratio, task=None): # generate mask mask = torch.ones_like(img) forward_indexes, backward_indexes = None, None if mask_ratio > 0: n_visible_patches, num_patches = self.get_mask_args(mask_ratio) - mask, forward_indexes, backward_indexes = self.get_mask(img, n_visible_patches, num_patches) + mask, forward_indexes, backward_indexes = self.get_mask( + img, n_visible_patches, num_patches + ) # generate patches tokens = self.conv(img * mask) tokens = self.img2token(tokens) - + # add position embedding tokens = tokens + self.pos_embedding - - # extract visible patches + + # extract visible patches if mask_ratio > 0: tokens = self.extract_visible_tokens(tokens, forward_indexes, n_visible_patches) - + # add task embedding if task in self.task_embedding: tokens = tokens + self.task_embedding[task] diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_conv.py b/cyto_dl/nn/vits/blocks/patchify/patchify_conv.py new file mode 100644 index 000000000..685f65c62 --- /dev/null +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_conv.py @@ -0,0 +1,122 @@ +from monai.networks.nets import Regressor + + +class PatchifyConv(torch.nn.Module): + """Class for converting images to a masked sequence of patches with positional embeddings.""" + + def __init__( + self, + patch_size: List[int], + emb_dim: int, + n_patches: List[int], + spatial_dims: int = 3, + input_channels: int = 1, + tasks: Optional[List[str]] = [], + ): + """ + Parameters + ---------- + patch_size: List[int] + Size of each patch + emb_dim: int + Dimension of encoder + n_patches: List[int] + Number of patches in each spatial dimension + spatial_dims: int + Number of spatial dimensions + context_pixels: List[int] + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. 
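# Illustrative sketch (not the cyto_dl implementation) of the forward/backward index
# bookkeeping that get_mask and extract_visible_tokens rely on. The real helpers are
# random_indexes / take_indexes in cyto_dl.nn.vits.utils; the toy_* names below are
# stand-ins, and a single permutation is shared across the batch for brevity.
import torch


def toy_random_indexes(n_tokens: int, device=None):
    # forward: original patch position -> position in the shuffled sequence
    forward = torch.randperm(n_tokens, device=device)
    # backward: shuffled position -> original patch position
    backward = torch.argsort(forward)
    return forward, backward


def toy_take_indexes(sequences: torch.Tensor, indexes: torch.Tensor) -> torch.Tensor:
    # gather along the token dimension of a (tokens, batch, channels) tensor
    return torch.gather(sequences, 0, indexes.unsqueeze(-1).expand(-1, -1, sequences.shape[-1]))


if __name__ == "__main__":
    n_tokens, batch, channels, mask_ratio = 8, 2, 4, 0.75
    tokens = torch.arange(n_tokens, dtype=torch.float32).view(n_tokens, 1, 1).expand(-1, batch, channels)
    forward, backward = toy_random_indexes(n_tokens)
    forward = forward.unsqueeze(-1).expand(-1, batch)
    backward = backward.unsqueeze(-1).expand(-1, batch)

    # the encoder only sees the first n_visible tokens of the shuffled sequence
    n_visible = int(n_tokens * (1 - mask_ratio))
    visible = toy_take_indexes(tokens, forward)[:n_visible]
    assert visible.shape == (n_visible, batch, channels)

    # applying backward indexes to the shuffled sequence restores the original order
    shuffled = toy_take_indexes(tokens, forward)
    restored = toy_take_indexes(shuffled, backward)
    assert torch.equal(restored, tokens)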
+ input_channels: int + Number of input channels + tasks: List[str] + List of tasks to encode + """ + super().__init__() + self.n_patches = np.asarray(n_patches) + self.spatial_dims = spatial_dims + + self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(n_patches), 1, emb_dim)) + + self.conv = Regressor( + in_shape=patch_size, + out_shape=emb_dim, + channels=[16, 64, 256, 512], + strides=[2, 2, 2, 1], + kernel_size=3, + ) + + if spatial_dims == 3: + self.img2token = Rearrange("b c z y x -> (z y x) b c") + self.patch2img = Rearrange( + "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x", + n_patch_z=n_patches[0], + n_patch_y=n_patches[1], + n_patch_x=n_patches[2], + ) + elif spatial_dims == 2: + self.img2token = Rearrange("b c y x -> (y x) b c") + self.patch2img = Rearrange( + "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x", + n_patch_y=n_patches[0], + n_patch_x=n_patches[1], + ) + + self.task_embedding = torch.nn.ParameterDict( + {task: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for task in tasks} + ) + self._init_weight() + + def _init_weight(self): + trunc_normal_(self.pos_embedding, std=0.02) + for task in self.task_embedding: + trunc_normal_(self.task_embedding[task], std=0.02) + + def get_mask(self, img, n_visible_patches, num_patches): + B = img.shape[0] + indexes = [random_indexes(num_patches, device=img.device) for _ in range(B)] + # forward indexes : index in image -> shuffledpatch + forward_indexes = torch.stack([i[0] for i in indexes], axis=-1) + # backward indexes : shuffled patch -> index in image + backward_indexes = torch.stack([i[1] for i in indexes], axis=-1) + + mask = torch.zeros(num_patches, B, 1, device=img.device, dtype=torch.uint8) + # visible patches are first + mask[:n_visible_patches] = 1 + mask = take_indexes(mask, backward_indexes) + mask = self.patch2img(mask) + # one pixel per masked patch, interpolate to size of input image + mask = torch.nn.functional.interpolate( + mask, img.shape[-self.spatial_dims :], mode="nearest" + ) + return mask, forward_indexes, backward_indexes + + def forward(self, img, mask_ratio=0.75, n_visible_patches=None, task=None): + mask = torch.ones_like(img) + forward_indexes, backward_indexes = None, None + if mask_ratio > 0: + # generate mask + num_patches = np.prod(self.n_patches) + n_visible_patches = n_visible_patches or int(num_patches * (1 - mask_ratio)) + mask, forward_indexes, backward_indexes = self.get_mask( + img, n_visible_patches, num_patches + ) + # generate patches + tokens = self.conv(img * mask) + tokens = self.img2token(tokens) + + pos_emb = rearrange(self.pos_embedding, "t b c -> b c t") + pos_emb = torch.nn.functional.interpolate(pos_emb, tokens.shape[0], mode="linear") + pos_emb = rearrange(pos_emb, "b c t -> t b c") + + # add position embedding + tokens = tokens + pos_emb + if mask_ratio > 0: + # extract visible patches + tokens = take_indexes(tokens, forward_indexes)[:n_visible_patches] + + if task in self.task_embedding: + tokens = tokens + self.task_embedding[task] + + # mask is used above to mask out patches, we need to invert it for loss calculation + mask = (1 - mask).bool() + return tokens, mask, forward_indexes, backward_indexes diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py index d61905608..0d0a5e467 100644 --- a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py @@ -6,7 +6,7 @@ from einops.layers.torch import Rearrange from timm.models.layers 
import trunc_normal_ -from cyto_dl.nn.vits.blocks.patchify.patchify_base import Patchify +from cyto_dl.nn.vits.blocks.patchify.patchify_base import PatchifyBase def take_indexes_mask(sequences, indexes): @@ -29,7 +29,7 @@ def take_indexes_mask(sequences, indexes): ) -class PatchifyHiera(Patchify): +class PatchifyHiera(PatchifyBase): """Class for converting images to a masked sequence of patches with positional embeddings.""" def __init__( @@ -61,28 +61,51 @@ def __init__( mask_units_per_dim: List[int] Number of mask units in each spatial dimension (ZYX order for 3D, YX order for 2D) """ - super().__init__(patch_size, emb_dim, n_patches, spatial_dims, context_pixels, input_channels, tasks) + super().__init__( + patch_size, + emb_dim, + n_patches, + spatial_dims, + context_pixels, + input_channels, + tasks, + True, + ) self.total_n_mask_units = np.prod(mask_units_per_dim) - patches_per_mask_unit = n_patches // self.total_n_mask_units + # img shape = patch_size * n_patches + # mask_unit_size = img shape / mask_units_per_dim + mask_unit_size_pix = ( + (np.array(patch_size) * np.array(n_patches)) / np.array(mask_units_per_dim) + ).astype(int) + + patches_per_mask_unit = mask_unit_size_pix // patch_size self.pos_embedding = torch.nn.Parameter( torch.zeros(1, self.total_n_mask_units, np.prod(patches_per_mask_unit), emb_dim) ) - # redefine this to work with mask units instead of patches - self.img2token = self.create_img2token(mask_units_per_dim) + self.mask_units_per_dim = mask_units_per_dim - mask_unit_size_pix = patches_per_mask_unit * patch_size self.patch2img = self.create_patch2img(mask_units_per_dim, mask_unit_size_pix) self._init_weight() def _init_weight(self): trunc_normal_(self.pos_embedding, std=0.02) + for task in self.task_embedding: + trunc_normal_(self.task_embedding[task], std=0.02) + + @property + def img2token(self): + # redefine this to work with mask units instead of patches + return self.create_img2token(self.mask_units_per_dim) + + # in hiera, the level of masking is at the mask unit, not the patch level + def get_mask_args(self, mask_ratio): + n_visible_patches = int(self.total_n_mask_units * (1 - mask_ratio)) + return n_visible_patches, self.total_n_mask_units - def create_img2token(self, mask_units_per_dim=None): - if mask_units_per_dim is None: - return + def create_img2token(self, mask_units_per_dim): if self.spatial_dims == 3: return Rearrange( "b c (n_mu_z z) (n_mu_y y) (n_mu_x x) -> b (n_mu_z n_mu_y n_mu_x) (z y x) c ", @@ -97,11 +120,5 @@ def create_img2token(self, mask_units_per_dim=None): n_mu_x=mask_units_per_dim[2], ) - # in hiera, the level of masking is at the mask unit, not the patch level - def get_mask_args(self, mask_ratio): - n_visible_patches = int(self.total_n_mask_units * (1 - mask_ratio)) - return n_visible_patches, self.total_n_mask_units - def extract_visible_tokens(self, tokens, forward_indexes, n_visible_patches): return take_indexes_mask(tokens, forward_indexes)[:, :n_visible_patches] - From f652138aa924d6f7adc4b9594c98417eff63c9c1 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Fri, 16 Aug 2024 16:18:43 -0700 Subject: [PATCH 13/27] rearrange encoder/decoder/mae --- cyto_dl/nn/vits/cross_mae.py | 113 ------ cyto_dl/nn/vits/decoder.py | 250 ++++++++++++ cyto_dl/nn/vits/{hiera_mae.py => encoder.py} | 209 +++++----- cyto_dl/nn/vits/mae.py | 402 ++++++++----------- 4 files changed, 515 insertions(+), 459 deletions(-) delete mode 100644 cyto_dl/nn/vits/cross_mae.py create mode 100644 cyto_dl/nn/vits/decoder.py rename cyto_dl/nn/vits/{hiera_mae.py 
=> encoder.py} (66%) diff --git a/cyto_dl/nn/vits/cross_mae.py b/cyto_dl/nn/vits/cross_mae.py deleted file mode 100644 index 3dc7e68f2..000000000 --- a/cyto_dl/nn/vits/cross_mae.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import List, Optional - -import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_ - -from cyto_dl.nn.vits.mae import MAE_Decoder -from cyto_dl.nn.vits.blocks import CrossAttentionBlock -from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes - - -class CrossMAE_Decoder(MAE_decoder): - """Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masked tokens only attend - to visible tokens.""" - - def __init__( - self, - num_patches: List[int], - spatial_dims: int = 3, - base_patch_size: Optional[List[int]] = [4, 8, 8], - enc_dim: Optional[int] = 768, - emb_dim: Optional[int] = 192, - num_layer: Optional[int] = 4, - num_head: Optional[int] = 3, - has_cls_token: Optional[bool] = True, - learnable_pos_embedding: Optional[bool] = True, - ) -> None: - """ - Parameters - ---------- - num_patches: List[int] - Number of patches in each dimension - base_patch_size: Tuple[int] - Size of each patch - enc_dim: int - Dimension of encoder - emb_dim: int - Dimension of embedding - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer - has_cls_token: bool - Whether encoder features have a cls token - learnable_pos_embedding: bool - If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images. - """ - super().__init__(num_patches, spatial_dims, base_patch_size, enc_dim, emb_dim, num_layer, num_head, has_cls_token, learnable_pos_embedding) - - self.transformer = torch.nn.ParameterList( - [ - CrossAttentionBlock( - encoder_dim=emb_dim, - decoder_dim=emb_dim, - num_heads=num_head, - ) - for _ in range(num_layer) - ] - ) - - def forward(self, features, forward_indexes, backward_indexes): - # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers - # features can be n t b c (if intermediate feature weighter used) or t b c if not - features = features[0] if len(features.shape) == 4 else features - T, B, C = features.shape - # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now - features = self.projection_norm(self.projection(features)) - - backward_indexes = self.adjust_indices_for_cls(backward_indexes) - forward_indexes = self.adjust_indices_for_cls(forward_indexes) - - features = self.add_mask_tokens(features, backward_indexes) - - # unshuffle to original positions for positional embedding so we can do cross attention during decoding - features = take_indexes(features, backward_indexes) - features = features + self.pos_embedding - - # reshuffle to shuffled positions for cross attention - features = take_indexes(features, forward_indexes) - features, masked = features[:T], features[T:] - - masked = rearrange(masked, "t b c -> b t c") - features = rearrange(features, "t b c -> b t c") - - for transformer in self.transformer: - masked = transformer(masked, features) - - # noneed to remove cls token, it's a part of the features vector - masked = rearrange(masked, "b t c -> t b c") - - # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x) 
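# Illustrative sketch of the patch-to-image reshape described in the comment above:
# decoded tokens of shape (n_patches_total, batch, voxels_per_patch) are folded back
# into an image. The Rearrange pattern matches the one built in MAE_Decoder; the
# concrete numbers below are hypothetical example values.
import torch
from einops.layers.torch import Rearrange

num_patches = (2, 4, 4)  # patches per axis (Z, Y, X)
patch_size = (4, 8, 8)   # voxels per patch (Z, Y, X)

patch2img = Rearrange(
    "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) "
    "-> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
    n_patch_z=num_patches[0],
    n_patch_y=num_patches[1],
    n_patch_x=num_patches[2],
    patch_size_z=patch_size[0],
    patch_size_y=patch_size[1],
    patch_size_x=patch_size[2],
)

tokens = torch.randn(2 * 4 * 4, 3, 1 * 4 * 8 * 8)  # (tokens, batch, c * z * y * x)
img = patch2img(tokens)
assert img.shape == (3, 1, 8, 32, 32)  # batch, channel, Z, Y, X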
- masked = self.decoder_norm(masked) - patches = self.head(masked) - - # add back in visible/encoded tokens that we don't calculate loss on - patches = torch.cat( - [ - torch.zeros( - (T - 1, B, patches.shape[-1]), - requires_grad=False, - device=patches.device, - dtype=patches.dtype, - ), - patches, - ], - dim=0, - ) - patches = take_indexes(patches, backward_indexes[1:] - 1) if self.has_cls_token else take_indexes(patches, backward_indexes) - # patches to image - img = self.patch2img(patches) - return img diff --git a/cyto_dl/nn/vits/decoder.py b/cyto_dl/nn/vits/decoder.py new file mode 100644 index 000000000..bf98de76c --- /dev/null +++ b/cyto_dl/nn/vits/decoder.py @@ -0,0 +1,250 @@ +# modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 + +from typing import List, Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_ +from timm.models.vision_transformer import Block + +from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes + +from typing import List, Optional + +import torch +from einops import rearrange + +from cyto_dl.nn.vits.blocks import CrossAttentionBlock +from cyto_dl.nn.vits.utils import take_indexes + + +class MAE_Decoder(torch.nn.Module): + def __init__( + self, + num_patches: List[int], + spatial_dims: int = 3, + patch_size: Optional[List[int]] = [4, 8, 8], + enc_dim: Optional[int] = 768, + emb_dim: Optional[int] = 192, + num_layer: Optional[int] = 4, + num_head: Optional[int] = 3, + has_cls_token: Optional[bool] = False, + learnable_pos_embedding: Optional[bool] = True, + ) -> None: + """ + Parameters + ---------- + num_patches: List[int] + Number of patches in each dimension + patch_size: Tuple[int] + Size of each patch + enc_dim: int + Dimension of encoder + emb_dim: int + Dimension of decoder + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + has_cls_token: bool + Whether encoder features have a cls token + learnable_pos_embedding: bool + If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
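# Illustrative sketch (assumed toy shapes, not the cyto_dl code) of the two index-handling
# steps the decoder performs before unshuffling: (1) when the encoder prepends a cls token,
# a zero row is prepended and every patch index is shifted by +1 so the cls token stays at
# position 0; (2) mask tokens are appended so the sequence length again equals the full
# number of patches.
import torch

num_patches, n_visible, batch, emb_dim = 6, 2, 1, 4

# backward_indexes: shuffled position -> original patch position, shape (tokens, batch)
backward_indexes = torch.randperm(num_patches).unsqueeze(-1).expand(-1, batch)

# step 1: account for a cls token at position 0
with_cls = torch.cat(
    [torch.zeros(1, batch, dtype=torch.long), backward_indexes + 1], dim=0
)
assert with_cls.shape == (num_patches + 1, batch)

# step 2: pad encoder output (cls + visible tokens) with a learned mask token
features = torch.randn(1 + n_visible, batch, emb_dim)  # cls + visible tokens
mask_token = torch.zeros(1, 1, emb_dim)                # a learned parameter in practice
n_mask = with_cls.shape[0] - features.shape[0]
features = torch.cat([features, mask_token.expand(n_mask, batch, -1)], dim=0)
assert features.shape == (num_patches + 1, batch, emb_dim)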
+ """ + super().__init__() + self.has_cls_token = has_cls_token + + self.projection_norm = nn.LayerNorm(emb_dim) + self.projection = torch.nn.Linear(enc_dim, emb_dim) + self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) + + self.pos_embedding = get_positional_embedding( + num_patches, emb_dim, use_cls_token=has_cls_token, learnable=learnable_pos_embedding + ) + + self.transformer = torch.nn.Sequential( + *[Block(emb_dim, num_head) for _ in range(num_layer)] + ) + out_dim = torch.prod(torch.as_tensor(patch_size)).item() + self.decoder_norm = nn.LayerNorm(emb_dim) + self.head = torch.nn.Linear(emb_dim, out_dim) + self.num_patches = torch.as_tensor(num_patches) + + if spatial_dims == 3: + self.patch2img = Rearrange( + "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + n_patch_z=num_patches[0], + n_patch_y=num_patches[1], + n_patch_x=num_patches[2], + patch_size_z=patch_size[0], + patch_size_y=patch_size[1], + patch_size_x=patch_size[2], + ) + elif spatial_dims == 2: + self.patch2img = Rearrange( + "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + n_patch_y=num_patches[0], + n_patch_x=num_patches[1], + patch_size_y=patch_size[0], + patch_size_x=patch_size[1], + ) + + self.init_weight() + + def init_weight(self): + trunc_normal_(self.mask_token, std=0.02) + + def adjust_indices_for_cls(self, indexes): + if self.has_cls_token: + # add all zeros to indices - this keeps the class token as the first index in the + # tensor. We also have to add 1 to all the indices to account for the zeros we added + return torch.cat( + [ + torch.zeros( + 1, indexes.shape[1], device=indexes.device, dtype=torch.long + ), + indexes + 1, + ], + dim=0, + ) + return indexes + + def add_mask_tokens(self, features, backward_indexes): + # fill in deleted masked regions with mask token + num_visible_tokens, B, _ = features.shape + # total number of tokens - number of visible tokens + num_mask_tokens = backward_indexes.shape[0] - num_visible_tokens + mask_tokens = repeat(self.mask_token, "1 1 c -> t b c", t=num_mask_tokens, b=B) + return torch.cat([features, mask_tokens],dim=0) + + def forward(self, features, forward_indexes, backward_indexes): + # project from encoder dimension to decoder dimension + features = self.projection_norm(self.projection(features)) + + backward_indexes = self.adjust_indices_for_cls(backward_indexes) + + features = self.add_mask_tokens(features, backward_indexes) + + # unshuffle to original positions + features = take_indexes(features, backward_indexes) + features = features + self.pos_embedding + + # decode + features = rearrange(features, "t b c -> b t c") + features = self.transformer(features) + features = rearrange(features, "b t c -> t b c") + + if self.has_cls_token: + features = features[1:] # remove global feature + + # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x) + patches = self.head(self.decoder_norm(features)) + + # patches to image + img = self.patch2img(patches) + return img + + +class CrossMAE_Decoder(MAE_Decoder): + """Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masked tokens only attend + to visible tokens.""" + + def __init__( + self, + num_patches: List[int], + spatial_dims: int = 3, + patch_size: Optional[List[int]] = [4, 8, 8], + enc_dim: Optional[int] = 768, + emb_dim: Optional[int] = 192, + num_layer: Optional[int] = 4, 
+ num_head: Optional[int] = 3, + has_cls_token: Optional[bool] = True, + learnable_pos_embedding: Optional[bool] = True, + ) -> None: + """ + Parameters + ---------- + num_patches: List[int] + Number of patches in each dimension + patch_size: Tuple[int] + Size of each patch + enc_dim: int + Dimension of encoder + emb_dim: int + Dimension of embedding + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + has_cls_token: bool + Whether encoder features have a cls token + learnable_pos_embedding: bool + If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images. + """ + super().__init__(num_patches, spatial_dims, patch_size, enc_dim, emb_dim, num_layer, num_head, has_cls_token, learnable_pos_embedding) + + self.transformer = torch.nn.ParameterList( + [ + CrossAttentionBlock( + encoder_dim=emb_dim, + decoder_dim=emb_dim, + num_heads=num_head, + ) + for _ in range(num_layer) + ] + ) + + def forward(self, features, forward_indexes, backward_indexes): + # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 1 layer + # features can be n t b c (if intermediate feature weighter used) or t b c if not + features = features[0] if len(features.shape) == 4 else features + T, B, C = features.shape + # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now + features = self.projection_norm(self.projection(features)) + + backward_indexes = self.adjust_indices_for_cls(backward_indexes) + forward_indexes = self.adjust_indices_for_cls(forward_indexes) + + features = self.add_mask_tokens(features, backward_indexes) + + # unshuffle to original positions for positional embedding so we can do cross attention during decoding + features = take_indexes(features, backward_indexes) + features = features + self.pos_embedding + + # reshuffle to shuffled positions for cross attention + features = take_indexes(features, forward_indexes) + features, masked = features[:T], features[T:] + + masked = rearrange(masked, "t b c -> b t c") + features = rearrange(features, "t b c -> b t c") + + for transformer in self.transformer: + masked = transformer(masked, features) + + # noneed to remove cls token, it's a part of the features vector + masked = rearrange(masked, "b t c -> t b c") + + # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x) + masked = self.decoder_norm(masked) + patches = self.head(masked) + + # add back in visible/encoded tokens that we don't calculate loss on + patches = torch.cat( + [ + torch.zeros( + # T-1 accounts for cls token + (T - self.has_cls_token, B, patches.shape[-1]), + requires_grad=False, + device=patches.device, + dtype=patches.dtype, + ), + patches, + ], + dim=0, + ) + patches = take_indexes(patches, backward_indexes[1:] - 1) if self.has_cls_token else take_indexes(patches, backward_indexes) + # patches to image + img = self.patch2img(patches) + return img diff --git a/cyto_dl/nn/vits/hiera_mae.py b/cyto_dl/nn/vits/encoder.py similarity index 66% rename from cyto_dl/nn/vits/hiera_mae.py rename to cyto_dl/nn/vits/encoder.py index 3c8c0001d..77e05cc7b 100644 --- a/cyto_dl/nn/vits/hiera_mae.py +++ b/cyto_dl/nn/vits/encoder.py @@ -1,3 +1,4 @@ +# modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 # inspired by https://github.com/facebookresearch/hiera 
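# Minimal sketch of the CrossMAE idea: masked-position queries attend to the visible
# (encoded) tokens instead of running full self-attention over every token. Here
# torch.nn.MultiheadAttention stands in for cyto_dl's CrossAttentionBlock, so the exact
# layer internals are an assumption; only the query vs. key/value split is the point.
import torch
import torch.nn as nn

emb_dim, num_heads, batch = 192, 4, 2
n_visible, n_masked = 16, 48

cross_attn = nn.MultiheadAttention(emb_dim, num_heads, batch_first=True)

visible_tokens = torch.randn(batch, n_visible, emb_dim)  # encoder output (+ pos emb)
masked_queries = torch.randn(batch, n_masked, emb_dim)   # mask token + pos emb per masked patch

# queries come only from masked positions; keys/values only from visible tokens,
# so cost scales with n_masked * n_visible rather than (n_visible + n_masked) ** 2
decoded, _ = cross_attn(query=masked_queries, key=visible_tokens, value=visible_tokens)
assert decoded.shape == (batch, n_masked, emb_dim)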
from typing import Dict, List, Optional @@ -8,15 +9,100 @@ import torch.nn.functional from einops import rearrange from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_ from timm.models.vision_transformer import Block from cyto_dl.nn.vits.blocks.masked_unit_attention import HieraBlock from cyto_dl.nn.vits.blocks.patchify import PatchifyHiera -from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder -from cyto_dl.nn.vits.mae import MAE_Decoder +from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify + + +class MAE_Encoder(torch.nn.Module): + def __init__( + self, + num_patches: List[int], + spatial_dims: int = 3, + patch_size: List[int] = (16, 16, 16), + emb_dim: Optional[int] = 192, + num_layer: Optional[int] = 12, + num_head: Optional[int] = 3, + context_pixels: Optional[List[int]] = [0, 0, 0], + input_channels: Optional[int] = 1, + n_intermediate_weights: Optional[int] = -1, + ) -> None: + """ + Parameters + ---------- + num_patches: List[int] + Number of patches in each dimension + spatial_dims: int + Number of spatial dimensions + patch_size: List[int] + Size of each patch + emb_dim: int + Dimension of embedding + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + context_pixels: List[int] + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. + input_channels: int + Number of input channels + weight_intermediates: bool + Whether to output linear combination of intermediate layers as final output like CrossMAE + """ + super().__init__() + self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) + self.patchify = Patchify( + patch_size, emb_dim, num_patches, spatial_dims, context_pixels, input_channels + ) + weight_intermediates = n_intermediate_weights > 0 + if weight_intermediates: + self.transformer = torch.nn.ModuleList( + [Block(emb_dim, num_head) for _ in range(num_layer)] + ) + else: + self.transformer = torch.nn.Sequential( + *[Block(emb_dim, num_head) for _ in range(num_layer)] + ) + + self.layer_norm = torch.nn.LayerNorm(emb_dim) + + self.intermediate_weighter = ( + IntermediateWeigher(num_layer, emb_dim, n_intermediate_weights) + if weight_intermediates + else None + ) + self.init_weight() + + def init_weight(self): + trunc_normal_(self.cls_token, std=0.02) + + def forward(self, img, mask_ratio=0.75): + patches, mask, forward_indexes, backward_indexes = self.patchify(img, mask_ratio) + patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0) + patches = rearrange(patches, "t b c -> b t c") + + if self.intermediate_weighter is not None: + intermediates = [patches] + for block in self.transformer: + patches = block(patches) + intermediates.append(patches) + features = self.layer_norm(self.intermediate_weighter(intermediates)) + features = rearrange(features, "n b t c -> n t b c") + else: + features = self.layer_norm(self.transformer(patches)) + features = rearrange(features, "b t c -> t b c") + if mask_ratio > 0: + return features, mask, forward_indexes, backward_indexes + return features class SpatialMerger(nn.Module): + """ + Class for converting multi-resolution Hiera features to the same (lowest) spatial resolution via convolution + """ def __init__(self, downsample_factor, in_dim, out_dim): super().__init__() self.downsample_factor = downsample_factor @@ -53,9 +139,8 @@ def __init__( emb_dim: int = 64, spatial_dims: int = 3, patch_size: List[int] = (16, 16, 16), - mask_ratio: Optional[float] = 0.75, 
context_pixels: Optional[List[int]] = [0, 0, 0], - save_layers: Optional[bool] = True, + save_layers: Optional[bool] = False, ) -> None: """ Parameters @@ -80,15 +165,12 @@ def __init__( Number of spatial dimensions patch_size: List[int] Size of each patch - mask_ratio: float - Fraction of mask units to remove context_pixels: List[int] Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. save_layers: bool Whether to save the intermediate layer outputs """ super().__init__() - self.mask_ratio = mask_ratio self.save_layers = save_layers self.patchify = PatchifyHiera( patch_size, @@ -100,6 +182,15 @@ def __init__( ) patches_per_mask_unit = np.array(num_patches) // np.array(num_mask_units) + + total_downsampling_per_axis = np.prod([block.get("q_stride", [1] * spatial_dims) for block in architecture],axis=0) + + assert np.all(patches_per_mask_unit - total_downsampling_per_axis >= 0), f"Number of mask units must be greater than the total downsampling ratio, got {patches_per_mask_unit} patches per mask unit and {total_downsampling_per_axis} total downsampling ratio. Please adjust your q_stride or increase the number of patches per mask unit." + assert np.all( + patches_per_mask_unit % total_downsampling_per_axis == 0 + ), f"Number of mask units must be divisible by the total downsampling ratio, got {patches_per_mask_unit} patches per mask unit and {total_downsampling_per_axis} total downsampling ratio. Please adjust your q_stride" + + self.final_dim = emb_dim * (2 ** len(architecture)) self.save_block_idxs = [] @@ -160,8 +251,8 @@ def __init__( self.layer_norm = torch.nn.LayerNorm(self.final_dim) - def forward(self, img): - patches, mask, forward_indexes, backward_indexes = self.patchify(img, self.mask_ratio) + def forward(self, img, mask_ratio): + patches, mask, forward_indexes, backward_indexes = self.patchify(img, mask_ratio) # mask unit attention mask_unit_embeddings = 0.0 @@ -177,103 +268,7 @@ def forward(self, img): mask_unit_embeddings = rearrange(mask_unit_embeddings, "b n_mu t c -> b (n_mu t) c") mask_unit_embeddings = self.self_attention_transformer(mask_unit_embeddings) mask_unit_embeddings = self.layer_norm(mask_unit_embeddings) + mask_unit_embeddings = rearrange(mask_unit_embeddings, 'b t c -> t b c') - return mask_unit_embeddings, mask, forward_indexes, backward_indexes, save_layers - - -class HieraMAE(torch.nn.Module): - def __init__( - self, - architecture: List[Dict], - spatial_dims: int = 3, - num_patches: Optional[List[int]] = [2, 32, 32], - num_mask_units: Optional[List[int]] = [2, 12, 12], - patch_size: Optional[List[int]] = [16, 16, 16], - emb_dim: Optional[int] = 64, - decoder_layer: Optional[int] = 4, - decoder_head: Optional[int] = 8, - decoder_dim: Optional[int] = 192, - mask_ratio: Optional[int] = 0.75, - context_pixels: Optional[List[int]] = [0, 0, 0], - use_crossmae: Optional[bool] = False, - ) -> None: - """ - Parameters - ---------- - architecture: List[Dict] - List of dictionaries specifying the architecture of the transformer. 
Each dictionary should have the following keys: - - repeat: int - Number of times to repeat the block - - num_heads: int - Number of heads in the multihead attention - - q_stride: List[int] - Stride for the query in each spatial dimension - - self_attention: bool - Whether to use self attention or mask unit attention - spatial_dims: int - Number of spatial dimensions - num_patches: List[int] - Number of patches in each dimension - num_mask_units: List[int] - Number of mask units in each dimension - patch_size: List[int] - Size of each patch - emb_dim: int - Dimension of embedding - decoder_layer: int - Number of layers in the decoder - decoder_head: int - Number of heads in the decoder - decoder_dim: int - Dimension of the decoder - mask_ratio: float - Fraction of mask units to remove - context_pixels: List[int] - Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. - """ - super().__init__() - assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" - - if isinstance(num_patches, int): - num_patches = [num_patches] * spatial_dims - if isinstance(patch_size, int): - patch_size = [patch_size] * spatial_dims - - assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" - assert len(patch_size) == spatial_dims, "patch_size must be of length spatial_dims" - - self.mask_ratio = mask_ratio - - self.encoder = HieraEncoder( - num_patches=num_patches, - num_mask_units=num_mask_units, - architecture=architecture, - emb_dim=emb_dim, - spatial_dims=spatial_dims, - patch_size=patch_size, - mask_ratio=mask_ratio, - context_pixels=context_pixels, - ) - # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size - mask_unit_size = (np.array(num_patches) * np.array(patch_size)) / np.array(num_mask_units) - - decoder_class = MAE_Decoder - if use_crossmae: - decoder_class = CrossMAE_Decoder - - self.decoder = decoder_class( - num_patches=num_mask_units, - spatial_dims=spatial_dims, - base_patch_size=mask_unit_size.astype(int), - enc_dim=self.encoder.final_dim, - emb_dim=decoder_dim, - num_layer=decoder_layer, - num_head=decoder_head, - has_cls_token=False, - ) + return mask_unit_embeddings, mask, forward_indexes, backward_indexes #, save_layers - def forward(self, img): - features, mask, forward_indexes, backward_indexes, save_layers = self.encoder(img) - features = rearrange(features, "b t c -> t b c") - predicted_img = self.decoder(features, forward_indexes, backward_indexes) - return predicted_img, mask diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py index 92d94eb56..56953acbf 100644 --- a/cyto_dl/nn/vits/mae.py +++ b/cyto_dl/nn/vits/mae.py @@ -1,236 +1,65 @@ # modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 -from typing import List, Optional +from abc import ABC, abstractmethod +from typing import Dict, List, Optional +import numpy as np import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_ -from timm.models.vision_transformer import Block -from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify -from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder -from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes +from cyto_dl.nn.vits.decoder import CrossMAE_Decoder, MAE_Decoder +from cyto_dl.nn.vits.encoder import HieraEncoder, MAE_Encoder -class MAE_Encoder(torch.nn.Module): - def __init__( - self, - num_patches: 
List[int], - spatial_dims: int = 3, - base_patch_size: List[int] = (16, 16, 16), - emb_dim: Optional[int] = 192, - num_layer: Optional[int] = 12, - num_head: Optional[int] = 3, - context_pixels: Optional[List[int]] = [0, 0, 0], - input_channels: Optional[int] = 1, - n_intermediate_weights: Optional[int] = -1, - ) -> None: - """ - Parameters - ---------- - num_patches: List[int] - Number of patches in each dimension - spatial_dims: int - Number of spatial dimensions - base_patch_size: List[int] - Size of each patch - emb_dim: int - Dimension of embedding - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer - context_pixels: List[int] - Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. - input_channels: int - Number of input channels - weight_intermediates: bool - Whether to output linear combination of intermediate layers as final output like CrossMAE - """ +class MAE_Base(torch.nn.Module, ABC): + def __init__(self, spatial_dims, num_patches, patch_size, mask_ratio, features_only): super().__init__() - self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) - self.patchify = Patchify( - base_patch_size, emb_dim, num_patches, spatial_dims, context_pixels, input_channels - ) - weight_intermediates = n_intermediate_weights > 0 - if weight_intermediates: - self.transformer = torch.nn.ModuleList( - [Block(emb_dim, num_head) for _ in range(num_layer)] - ) - else: - self.transformer = torch.nn.Sequential( - *[Block(emb_dim, num_head) for _ in range(num_layer)] - ) - - self.layer_norm = torch.nn.LayerNorm(emb_dim) - - self.intermediate_weighter = ( - IntermediateWeigher(num_layer, emb_dim, n_intermediate_weights) - if weight_intermediates - else None - ) - self.init_weight() - - def init_weight(self): - trunc_normal_(self.cls_token, std=0.02) - - def forward(self, img, mask_ratio=0.75): - patches, mask, forward_indexes, backward_indexes = self.patchify(img, mask_ratio) - patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0) - patches = rearrange(patches, "t b c -> b t c") - - if self.intermediate_weighter is not None: - intermediates = [patches] - for block in self.transformer: - patches = block(patches) - intermediates.append(patches) - features = self.layer_norm(self.intermediate_weighter(intermediates)) - features = rearrange(features, "n b t c -> n t b c") - else: - features = self.layer_norm(self.transformer(patches)) - features = rearrange(features, "b t c -> t b c") - if mask_ratio > 0: - return features, mask, forward_indexes, backward_indexes - return features - - -class MAE_Decoder(torch.nn.Module): - def __init__( - self, - num_patches: List[int], - spatial_dims: int = 3, - base_patch_size: Optional[List[int]] = [4, 8, 8], - enc_dim: Optional[int] = 768, - emb_dim: Optional[int] = 192, - num_layer: Optional[int] = 4, - num_head: Optional[int] = 3, - has_cls_token: Optional[bool] = False, - learnable_pos_embedding: Optional[bool] = True, - ) -> None: - """ - Parameters - ---------- - num_patches: List[int] - Number of patches in each dimension - base_patch_size: Tuple[int] - Size of each patch - enc_dim: int - Dimension of encoder - emb_dim: int - Dimension of decoder - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer - has_cls_token: bool - Whether encoder features have a cls token - learnable_pos_embedding: bool - If True, learnable positional embeddings are used. 
If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. - """ - super().__init__() - self.has_cls_token = has_cls_token - - self.projection_norm = nn.LayerNorm(emb_dim) - self.projection = torch.nn.Linear(enc_dim, emb_dim) - self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) - - self.pos_embedding = get_positional_embedding( - num_patches, emb_dim, learnable=learnable_pos_embedding - ) - - self.transformer = torch.nn.Sequential( - *[Block(emb_dim, num_head) for _ in range(num_layer)] - ) - out_dim = torch.prod(torch.as_tensor(base_patch_size)).item() - self.decoder_norm = nn.LayerNorm(emb_dim) - self.head = torch.nn.Linear(emb_dim, out_dim) - self.num_patches = torch.as_tensor(num_patches) - - if spatial_dims == 3: - self.patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_z=num_patches[0], - n_patch_y=num_patches[1], - n_patch_x=num_patches[2], - patch_size_z=base_patch_size[0], - patch_size_y=base_patch_size[1], - patch_size_x=base_patch_size[2], - ) - elif spatial_dims == 2: - self.patch2img = Rearrange( - "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_y=num_patches[0], - n_patch_x=num_patches[1], - patch_size_y=base_patch_size[0], - patch_size_x=base_patch_size[1], - ) - - self.init_weight() - - def init_weight(self): - trunc_normal_(self.mask_token, std=0.02) + assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" - def adjust_indices_for_cls(self, indexes): - if self.has_cls_token: - return torch.cat( - [ - torch.zeros( - 1, indexes.shape[1], device=indexes.device, dtype=torch.long - ), - indexes + 1, - ], - dim=0, - ) - return indexes - - def add_mask_tokens(self, features, backward_indexes): - # fill in deleted masked regions with mask token - return torch.cat( - [ - features, - self.mask_token.expand( - backward_indexes.shape[0] - features.shape[0], features.shape[1], -1 - ), - ], - dim=0, - ) + if isinstance(num_patches, int): + num_patches = [num_patches] * spatial_dims + if isinstance(patch_size, int): + patch_size = [patch_size] * spatial_dims - def forward(self, features, forward_indexes, backward_indexes): - # project from encoder dimension to decoder dimension - features = self.projection_norm(self.projection(features)) + assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" + assert len(patch_size) == spatial_dims, "patch_size must be of length spatial_dims" - backward_indexes = self.adjust_indices_for_cls(backward_indexes) + self.spatial_dims = spatial_dims + self.num_patches = num_patches + self.patch_size = patch_size + self.mask_ratio = mask_ratio + self.features_only = features_only - features = self.add_mask_tokens(features, backward_indexes) + # encoder and decoder must be defined in subclasses + @property + @abstractmethod + def encoder(self): + pass - # unshuffle to original positions - features = take_indexes(features, backward_indexes) - features = features + self.pos_embedding + @property + @abstractmethod + def decoder(self): + pass - # decode - features = rearrange(features, "t b c -> b t c") - features = self.transformer(features) - features = rearrange(features, "b t c -> t b c") - - if self.has_cls_token: - features = features[1:] # remove global feature + def init_encoder(self): + raise NotImplementedError - # 
(npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x) - patches = self.head(self.decoder_norm(features)) + def init_decoder(self): + raise NotImplementedError - # patches to image - img = self.patch2img(patches) - return img + def forward(self, img): + features, mask, forward_indexes, backward_indexes = self.encoder(img, self.mask_ratio) + if self.features_only: + return features + predicted_img = self.decoder(features, forward_indexes, backward_indexes) + return predicted_img, mask -class MAE_ViT(torch.nn.Module): +class MAE(MAE_Base): def __init__( self, spatial_dims: int = 3, num_patches: Optional[List[int]] = [2, 32, 32], - base_patch_size: Optional[List[int]] = [16, 16, 16], + patch_size: Optional[List[int]] = [16, 16, 16], emb_dim: Optional[int] = 768, encoder_layer: Optional[int] = 12, encoder_head: Optional[int] = 8, @@ -251,7 +80,7 @@ def __init__( Number of spatial dimensions num_patches: List[int] Number of patches in each dimension (ZYX order) - base_patch_size: List[int] + patch_size: List[int] Size of each patch (ZYX order) emb_dim: int Dimension of encoder embedding @@ -276,26 +105,18 @@ def __init__( learnable_pos_embedding: bool If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. """ - super().__init__() - assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" - - if isinstance(num_patches, int): - num_patches = [num_patches] * spatial_dims - if isinstance(base_patch_size, int): - base_patch_size = [base_patch_size] * spatial_dims - - assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" - assert ( - len(base_patch_size) == spatial_dims - ), "base_patch_size must be of length spatial_dims" - - self.mask_ratio = mask_ratio - self.features_only = features_only + super().__init__( + spatial_dims=spatial_dims, + num_patches=num_patches, + patch_size=patch_size, + mask_ratio=mask_ratio, + features_only=features_only, + ) - self.encoder = MAE_Encoder( - num_patches, + self._encoder = MAE_Encoder( + self.num_patches, spatial_dims, - base_patch_size, + self.patch_size, emb_dim, encoder_layer, encoder_head, @@ -307,10 +128,10 @@ def __init__( decoder_class = MAE_Decoder if use_crossmae: decoder_class = CrossMAE_Decoder - self.decoder = decoder_class( - num_patches=num_patches, + self._decoder = decoder_class( + num_patches=self.num_patches, spatial_dims=spatial_dims, - base_patch_size=base_patch_size, + patch_size=self.patch_size, enc_dim=emb_dim, emb_dim=decoder_dim, num_layer=decoder_layer, @@ -318,9 +139,112 @@ def __init__( learnable_pos_embedding=learnable_pos_embedding, ) - def forward(self, img): - features, mask, forward_indexes, backward_indexes = self.encoder(img, self.mask_ratio) - if self.features_only: - return features - predicted_img = self.decoder(features, forward_indexes, backward_indexes) - return predicted_img, mask + @property + def encoder(self): + return self._encoder + + @property + def decoder(self): + return self._decoder + + +class HieraMAE(MAE_Base): + def __init__( + self, + architecture: List[Dict], + spatial_dims: int = 3, + num_patches: Optional[List[int]] = [2, 32, 32], + num_mask_units: Optional[List[int]] = [2, 12, 12], + patch_size: Optional[List[int]] = [16, 16, 16], + emb_dim: Optional[int] = 64, + decoder_layer: Optional[int] = 4, + decoder_head: Optional[int] = 8, + decoder_dim: Optional[int] = 192, + mask_ratio: Optional[int] = 0.75, + 
use_crossmae: Optional[bool] = False, + context_pixels: Optional[List[int]] = [0, 0, 0], + input_channels: Optional[int] = 1, + features_only: Optional[bool] = False, + ) -> None: + """ + Parameters + ---------- + architecture: List[Dict] + List of dictionaries specifying the architecture of the transformer. Each dictionary should have the following keys: + - repeat: int + Number of times to repeat the block + - num_heads: int + Number of heads in the multihead attention + - q_stride: List[int] + Stride for the query in each spatial dimension + - self_attention: bool + Whether to use self attention or mask unit attention + spatial_dims: int + Number of spatial dimensions + num_patches: List[int] + Number of patches in each dimension + num_mask_units: List[int] + Number of mask units in each dimension + patch_size: List[int] + Size of each patch + emb_dim: int + Dimension of embedding + decoder_layer: int + Number of layers in the decoder + decoder_head: int + Number of heads in the decoder + decoder_dim: int + Dimension of the decoder + mask_ratio: float + Fraction of mask units to remove + use_crossmae: bool + Use CrossMAE-style decoder instead of MAE decoder + context_pixels: List[int] + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. + input_channels: int + Number of input channels + features_only: bool + Only use encoder to extract features + """ + super().__init__( + spatial_dims=spatial_dims, + num_patches=num_patches, + patch_size=patch_size, + mask_ratio=mask_ratio, + features_only=features_only, + ) + + self._encoder = HieraEncoder( + num_patches=self.num_patches, + num_mask_units=num_mask_units, + architecture=architecture, + emb_dim=emb_dim, + spatial_dims=spatial_dims, + patch_size=self.patch_size, + context_pixels=context_pixels, + ) + # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size + mask_unit_size = (np.array(num_patches) * np.array(patch_size)) / np.array(num_mask_units) + + decoder_class = MAE_Decoder + if use_crossmae: + decoder_class = CrossMAE_Decoder + + self._decoder = decoder_class( + num_patches=num_mask_units, + spatial_dims=spatial_dims, + patch_size=mask_unit_size.astype(int), + enc_dim=self.encoder.final_dim, + emb_dim=decoder_dim, + num_layer=decoder_layer, + num_head=decoder_head, + has_cls_token=False, + ) + + @property + def encoder(self): + return self._encoder + + @property + def decoder(self): + return self._decoder From 7ee08e3c3b67c7ae6b3f8bf81faef813032d73b9 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 19 Aug 2024 13:24:07 -0700 Subject: [PATCH 14/27] add 2d hiera --- .../nn/vits/blocks/patchify/patchify_hiera.py | 23 ++-- cyto_dl/nn/vits/encoder.py | 103 ++++++++++++------ cyto_dl/nn/vits/mae.py | 51 +++++---- cyto_dl/nn/vits/utils.py | 8 ++ 4 files changed, 116 insertions(+), 69 deletions(-) diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py index 0d0a5e467..8712de088 100644 --- a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py @@ -62,19 +62,18 @@ def __init__( Number of mask units in each spatial dimension (ZYX order for 3D, YX order for 2D) """ super().__init__( - patch_size, - emb_dim, - n_patches, - spatial_dims, - context_pixels, - input_channels, - tasks, - True, + patch_size=patch_size, + emb_dim=emb_dim, + n_patches=n_patches, + spatial_dims=spatial_dims, + context_pixels=context_pixels, + 
input_channels=input_channels, + tasks=tasks, + learnable_pos_embedding=True, ) self.total_n_mask_units = np.prod(mask_units_per_dim) - # img shape = patch_size * n_patches - # mask_unit_size = img shape / mask_units_per_dim + # mask_unit_size is the img shape / mask_units_per_dim, img_shape is size per patch * n_patches mask_unit_size_pix = ( (np.array(patch_size) * np.array(n_patches)) / np.array(mask_units_per_dim) ).astype(int) @@ -116,8 +115,8 @@ def create_img2token(self, mask_units_per_dim): elif self.spatial_dims == 2: return Rearrange( "b c (n_mu_y y) (n_mu_x x) -> b (n_mu_y n_mu_x) (y x) c ", - n_mu_y=mask_units_per_dim[1], - n_mu_x=mask_units_per_dim[2], + n_mu_y=mask_units_per_dim[0], + n_mu_x=mask_units_per_dim[1], ) def extract_visible_tokens(self, tokens, forward_indexes, n_visible_patches): diff --git a/cyto_dl/nn/vits/encoder.py b/cyto_dl/nn/vits/encoder.py index 77e05cc7b..11653a4eb 100644 --- a/cyto_dl/nn/vits/encoder.py +++ b/cyto_dl/nn/vits/encoder.py @@ -1,7 +1,7 @@ # modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 # inspired by https://github.com/facebookresearch/hiera -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import numpy as np import torch @@ -12,9 +12,10 @@ from timm.models.layers import trunc_normal_ from timm.models.vision_transformer import Block +from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify from cyto_dl.nn.vits.blocks.masked_unit_attention import HieraBlock from cyto_dl.nn.vits.blocks.patchify import PatchifyHiera -from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify +from cyto_dl.nn.vits.utils import validate_spatial_dims class MAE_Encoder(torch.nn.Module): @@ -49,10 +50,14 @@ def __init__( Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. 
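# Worked example of the mask-unit geometry computed in PatchifyHiera, using the values
# from configs/model/im2im/hiera.yaml above (patch_size [2,2,2], num_patches [8,8,8],
# num_mask_units [4,4,4]). A standalone sketch, not library code.
import numpy as np

patch_size = np.array([2, 2, 2])         # pixels per patch (ZYX)
n_patches = np.array([8, 8, 8])          # patches per axis -> 16x16x16 pixel image
mask_units_per_dim = np.array([4, 4, 4])

img_shape = patch_size * n_patches                                  # [16 16 16]
mask_unit_size_pix = (img_shape / mask_units_per_dim).astype(int)   # [4 4 4] pixels
patches_per_mask_unit = mask_unit_size_pix // patch_size            # [2 2 2] patches
total_mask_units = int(np.prod(mask_units_per_dim))                 # 64 mask units

# masking happens per mask unit, not per patch: with mask_ratio ~ 2/3,
# roughly 21 of the 64 mask units remain visible to the encoder
n_visible_units = int(total_mask_units * (1 - 2 / 3))
print(img_shape, mask_unit_size_pix, patches_per_mask_unit, total_mask_units, n_visible_units)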
input_channels: int Number of input channels - weight_intermediates: bool - Whether to output linear combination of intermediate layers as final output like CrossMAE + n_intermediate_weights: bool + Whether to use intermediate weights for weighted sum of intermediate layers """ super().__init__() + num_patches, patch_size, context_pixels = validate_spatial_dims( + spatial_dims, [num_patches, patch_size, context_pixels] + ) + self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) self.patchify = Patchify( patch_size, emb_dim, num_patches, spatial_dims, context_pixels, input_channels @@ -100,13 +105,18 @@ def forward(self, img, mask_ratio=0.75): class SpatialMerger(nn.Module): - """ - Class for converting multi-resolution Hiera features to the same (lowest) spatial resolution via convolution - """ - def __init__(self, downsample_factor, in_dim, out_dim): + """Class for converting multi-resolution Hiera features to the same (lowest) spatial resolution + via convolution.""" + + def __init__( + self, downsample_factor: List[int], in_dim: int, out_dim: int, spatial_dims: int = 3 + ): super().__init__() - self.downsample_factor = downsample_factor - conv = nn.Conv3d( + downsample_factor = validate_spatial_dims(spatial_dims, [downsample_factor])[0] + + self.spatial_dims = spatial_dims + conv_fn = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + conv = conv_fn( in_channels=in_dim, out_channels=out_dim, kernel_size=downsample_factor, @@ -114,27 +124,36 @@ def __init__(self, downsample_factor, in_dim, out_dim): padding=0, bias=False, ) - - tokens2img = Rearrange( - "b n_mu (z y x) c -> (b n_mu) c z y x", - z=downsample_factor[0], - y=downsample_factor[1], - x=downsample_factor[2], - ) + if spatial_dims == 3: + tokens2img = Rearrange( + "b n_mu (z y x) c -> (b n_mu) c z y x", + z=downsample_factor[0], + y=downsample_factor[1], + x=downsample_factor[2], + ) + else: + tokens2img = Rearrange( + "b n_mu (y x) c -> (b n_mu) c y x", + y=downsample_factor[0], + x=downsample_factor[1], + ) self.model = nn.Sequential(tokens2img, conv) def forward(self, x): b, n_mu, _, _ = x.shape x = self.model(x) - x = rearrange(x, "(b n_mu) c z y x -> b n_mu (z y x) c", b=b, n_mu=n_mu) + if self.spatial_dims == 3: + x = rearrange(x, "(b n_mu) c z y x -> b n_mu (z y x) c", b=b, n_mu=n_mu) + else: + x = rearrange(x, "(b n_mu) c y x -> b n_mu (y x) c", b=b, n_mu=n_mu) return x class HieraEncoder(torch.nn.Module): def __init__( self, - num_patches: List[int], - num_mask_units: List[int], + num_patches: Union[int, List[int]], + num_mask_units: Union[int, List[int]], architecture: List[Dict], emb_dim: int = 64, spatial_dims: int = 3, @@ -145,17 +164,17 @@ def __init__( """ Parameters ---------- - num_patches: List[int] - Number of patches in each dimension - num_mask_units: List[int] - Number of mask units in each dimension + num_patches: int, List[int] + Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same. + num_mask_units: int, List[int] + Number of mask units in each dimension. If a single int is provided, the number of mask units in each dimension will be the same. architecture: List[Dict] List of dictionaries specifying the architecture of the transformer. 
Each dictionary should have the following keys: - repeat: int Number of times to repeat the block - num_heads: int Number of heads in the multihead attention - - q_stride: List[int] + - q_stride: int, List[int] Stride for the query in each spatial dimension - self_attention: bool Whether to use self attention or mask unit attention @@ -171,6 +190,16 @@ def __init__( Whether to save the intermediate layer outputs """ super().__init__() + num_patches, num_mask_units, patch_size, context_pixels = validate_spatial_dims( + spatial_dims, [num_patches, num_mask_units, patch_size, context_pixels] + ) + # make sure q stride shape matches spatial dims + for i in range(len(architecture)): + if "q_stride" in architecture[i]: + architecture[i]["q_stride"] = validate_spatial_dims( + spatial_dims, [architecture[i]["q_stride"]] + )[0] + self.save_layers = save_layers self.patchify = PatchifyHiera( patch_size, @@ -183,14 +212,17 @@ def __init__( patches_per_mask_unit = np.array(num_patches) // np.array(num_mask_units) - total_downsampling_per_axis = np.prod([block.get("q_stride", [1] * spatial_dims) for block in architecture],axis=0) - - assert np.all(patches_per_mask_unit - total_downsampling_per_axis >= 0), f"Number of mask units must be greater than the total downsampling ratio, got {patches_per_mask_unit} patches per mask unit and {total_downsampling_per_axis} total downsampling ratio. Please adjust your q_stride or increase the number of patches per mask unit." + total_downsampling_per_axis = np.prod( + [block.get("q_stride", [1] * spatial_dims) for block in architecture], axis=0 + ) + + assert np.all( + patches_per_mask_unit - total_downsampling_per_axis >= 0 + ), f"Number of mask units must be greater than the total downsampling ratio, got {patches_per_mask_unit} patches per mask unit and {total_downsampling_per_axis} total downsampling ratio. Please adjust your q_stride or increase the number of patches per mask unit." assert np.all( patches_per_mask_unit % total_downsampling_per_axis == 0 ), f"Number of mask units must be divisible by the total downsampling ratio, got {patches_per_mask_unit} patches per mask unit and {total_downsampling_per_axis} total downsampling ratio. 
Please adjust your q_stride" - self.final_dim = emb_dim * (2 ** len(architecture)) self.save_block_idxs = [] @@ -219,6 +251,7 @@ def __init__( dim=dim_in, dim_out=dim_out, heads=stage["num_heads"], + spatial_dims=spatial_dims, q_stride=q_stride, patches_per_mask_unit=patches_per_mask_unit, ) @@ -233,7 +266,12 @@ def __init__( # create a spatial merger for combining tokens pre-downsampling, last stage doesn't need merging since it has expected num channels, spatial shape self.spatial_mergers[f"block_{save_block}"] = ( - SpatialMerger(patches_per_mask_unit, dim_in, self.final_dim) + SpatialMerger( + patches_per_mask_unit, + dim_in, + self.final_dim, + spatial_dims=spatial_dims, + ) if stage_num < len(architecture) - 1 else torch.nn.Identity() ) @@ -268,7 +306,6 @@ def forward(self, img, mask_ratio): mask_unit_embeddings = rearrange(mask_unit_embeddings, "b n_mu t c -> b (n_mu t) c") mask_unit_embeddings = self.self_attention_transformer(mask_unit_embeddings) mask_unit_embeddings = self.layer_norm(mask_unit_embeddings) - mask_unit_embeddings = rearrange(mask_unit_embeddings, 'b t c -> t b c') - - return mask_unit_embeddings, mask, forward_indexes, backward_indexes #, save_layers + mask_unit_embeddings = rearrange(mask_unit_embeddings, "b t c -> t b c") + return mask_unit_embeddings, mask, forward_indexes, backward_indexes # , save_layers diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py index 56953acbf..b7d601760 100644 --- a/cyto_dl/nn/vits/mae.py +++ b/cyto_dl/nn/vits/mae.py @@ -1,33 +1,31 @@ # modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import numpy as np import torch from cyto_dl.nn.vits.decoder import CrossMAE_Decoder, MAE_Decoder from cyto_dl.nn.vits.encoder import HieraEncoder, MAE_Encoder +from cyto_dl.nn.vits.utils import validate_spatial_dims class MAE_Base(torch.nn.Module, ABC): - def __init__(self, spatial_dims, num_patches, patch_size, mask_ratio, features_only): + def __init__( + self, spatial_dims, num_patches, patch_size, mask_ratio, features_only, context_pixels + ): super().__init__() - assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" - - if isinstance(num_patches, int): - num_patches = [num_patches] * spatial_dims - if isinstance(patch_size, int): - patch_size = [patch_size] * spatial_dims - - assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" - assert len(patch_size) == spatial_dims, "patch_size must be of length spatial_dims" + num_patches, patch_size, context_pixels = validate_spatial_dims( + spatial_dims, [num_patches, patch_size, context_pixels] + ) self.spatial_dims = spatial_dims self.num_patches = num_patches self.patch_size = patch_size self.mask_ratio = mask_ratio self.features_only = features_only + self.context_pixels = context_pixels # encoder and decoder must be defined in subclasses @property @@ -111,6 +109,7 @@ def __init__( patch_size=patch_size, mask_ratio=mask_ratio, features_only=features_only, + context_pixels=context_pixels, ) self._encoder = MAE_Encoder( @@ -120,7 +119,7 @@ def __init__( emb_dim, encoder_layer, encoder_head, - context_pixels, + self.context_pixels, input_channels, n_intermediate_weights=-1 if not use_crossmae else decoder_layer, ) @@ -153,9 +152,9 @@ def __init__( self, architecture: List[Dict], spatial_dims: int = 3, - num_patches: Optional[List[int]] = [2, 32, 32], - num_mask_units: Optional[List[int]] = [2, 12, 12], - 
patch_size: Optional[List[int]] = [16, 16, 16], + num_patches: Optional[Union[int, List[int]]] = [2, 32, 32], + num_mask_units: Optional[Union[int, List[int]]] = [2, 12, 12], + patch_size: Optional[Union[int, List[int]]] = [16, 16, 16], emb_dim: Optional[int] = 64, decoder_layer: Optional[int] = 4, decoder_head: Optional[int] = 8, @@ -181,12 +180,12 @@ def __init__( Whether to use self attention or mask unit attention spatial_dims: int Number of spatial dimensions - num_patches: List[int] - Number of patches in each dimension - num_mask_units: List[int] - Number of mask units in each dimension - patch_size: List[int] - Size of each patch + num_patches: int, List[int] + Number of patches in each dimension (Z)YX order. If int, the same number of patches is used in each dimension. + num_mask_units: int, List[int] + Number of mask units in each dimension (Z)YX order. If int, the same number of mask units is used in each dimension. + patch_size: int, List[int] + Size of each patch (Z)YX order. If int, the same patch size is used in each dimension. emb_dim: int Dimension of embedding decoder_layer: int @@ -212,19 +211,23 @@ def __init__( patch_size=patch_size, mask_ratio=mask_ratio, features_only=features_only, + context_pixels=context_pixels, ) + num_mask_units = validate_spatial_dims(self.spatial_dims, [num_mask_units])[0] self._encoder = HieraEncoder( num_patches=self.num_patches, num_mask_units=num_mask_units, architecture=architecture, emb_dim=emb_dim, - spatial_dims=spatial_dims, + spatial_dims=self.spatial_dims, patch_size=self.patch_size, - context_pixels=context_pixels, + context_pixels=self.context_pixels, ) # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size - mask_unit_size = (np.array(num_patches) * np.array(patch_size)) / np.array(num_mask_units) + mask_unit_size = (np.array(self.num_patches) * np.array(self.patch_size)) / np.array( + num_mask_units + ) decoder_class = MAE_Decoder if use_crossmae: diff --git a/cyto_dl/nn/vits/utils.py b/cyto_dl/nn/vits/utils.py index d6a8b246d..caeb20237 100644 --- a/cyto_dl/nn/vits/utils.py +++ b/cyto_dl/nn/vits/utils.py @@ -3,6 +3,7 @@ import numpy as np import torch from einops import rearrange, repeat +from monai.utils.misc import ensure_tuple_rep from positional_encodings.torch_encodings import ( PositionalEncoding2D, PositionalEncoding3D, @@ -13,11 +14,13 @@ def take_indexes(sequences, indexes): return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1])) + def random_indexes(size: int, device): forward_indexes = torch.randperm(size, device=device, dtype=torch.long) backward_indexes = torch.argsort(forward_indexes) return forward_indexes, backward_indexes + def get_positional_embedding( num_patches: Sequence[int], emb_dim: int, use_cls_token: bool = True, learnable: bool = True ): @@ -44,3 +47,8 @@ def get_positional_embedding( cls_token = torch.zeros(1, 1, emb_dim) pe = torch.cat([cls_token, pe], dim=0) return torch.nn.Parameter(pe, requires_grad=False) + + +def validate_spatial_dims(spatial_dims, tuples): + assert spatial_dims in (2, 3), "spatial_dims must be 2 or 3" + return [ensure_tuple_rep(t, spatial_dims) for t in tuples] From 71b80969fd9e352c6a78084ddf8f3105a8d67cd7 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 19 Aug 2024 13:24:39 -0700 Subject: [PATCH 15/27] 2d masked unit attention --- .../nn/vits/blocks/masked_unit_attention.py | 127 +++++++++++++++--- 1 file changed, 105 insertions(+), 22 deletions(-) diff --git 
a/cyto_dl/nn/vits/blocks/masked_unit_attention.py b/cyto_dl/nn/vits/blocks/masked_unit_attention.py index 92a76ceec..9b61d18e3 100644 --- a/cyto_dl/nn/vits/blocks/masked_unit_attention.py +++ b/cyto_dl/nn/vits/blocks/masked_unit_attention.py @@ -9,12 +9,15 @@ from einops.layers.torch import Reduce from timm.models.layers import DropPath, Mlp +from cyto_dl.nn.vits.utils import validate_spatial_dims + class MaskUnitAttention(torch.nn.Module): def __init__( self, dim, dim_out, + spatial_dims: int = 3, num_heads=8, qkv_bias=False, qk_scale=None, @@ -23,7 +26,36 @@ def __init__( q_stride=[1, 1, 1], patches_per_mask_unit=[2, 12, 12], ): + """ + Parameters + ---------- + dim : int + Dimension of the input features. + dim_out : int + Dimension of the output features. + spatial_dims : int, optional + Number of spatial dimensions, by default 3. + num_heads : int, optional + Number of attention heads, by default 8. + qkv_bias : bool, optional + If True, add a learnable bias to query, key, value, by default False. + qk_scale : float, optional + Override default qk scale of head_dim ** -0.5 if set, by default None. + attn_drop : float, optional + Dropout rate for attention, by default 0.0. + proj_drop : float, optional + Dropout rate for projection, by default 0.0. + q_stride : List[int], optional + Stride for query, by default [1, 1, 1]. + patches_per_mask_unit : List[int], optional + Number of patches per mask unit, by default [2, 12, 12]. + """ super().__init__() + q_stride, patches_per_mask_unit = validate_spatial_dims( + spatial_dims, [q_stride, patches_per_mask_unit] + ) + + self.spatial_dims = spatial_dims self.num_heads = num_heads self.head_dim = dim_out // num_heads self.scale = qk_scale or self.head_dim**-0.5 @@ -41,7 +73,7 @@ def forward(self, x): # project and split into q,k,v embeddings qkv = rearrange( self.qkv(x), - "batch num_mask_units tokens_per_mask_unit (head_dim num_heads qkv) -> qkv batch num_mask_units num_heads tokens_per_mask_unit head_dim", + "batch num_mask_units tokens_per_mask_unit (head_dim num_heads qkv) -> qkv batch num_mask_units num_heads tokens_per_mask_unit head_dim", head_dim=self.head_dim, qkv=3, num_heads=self.num_heads, @@ -52,17 +84,28 @@ def forward(self, x): if np.any(self.q_stride > 1): # within a mask unit, tokens are spatially ordered # perform spatial 2x2x2 max pooling over tokens - q = reduce( - q, - "b n h (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c ->b n h (n_patches_z n_patches_y n_patches_x) c", - reduction="max", - q_stride_z=self.q_stride[0], - q_stride_y=self.q_stride[1], - q_stride_x=self.q_stride[2], - n_patches_z=self.pooled_patches_per_mask_unit[0], - n_patches_y=self.pooled_patches_per_mask_unit[1], - n_patches_x=self.pooled_patches_per_mask_unit[2], - ) + if self.spatial_dims == 3: + q = reduce( + q, + "b n h (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c ->b n h (n_patches_z n_patches_y n_patches_x) c", + reduction="max", + q_stride_z=self.q_stride[0], + q_stride_y=self.q_stride[1], + q_stride_x=self.q_stride[2], + n_patches_z=self.pooled_patches_per_mask_unit[0], + n_patches_y=self.pooled_patches_per_mask_unit[1], + n_patches_x=self.pooled_patches_per_mask_unit[2], + ) + elif self.spatial_dims == 2: + q = reduce( + q, + "b n h (n_patches_y q_stride_y n_patches_x q_stride_x) c ->b n h (n_patches_y n_patches_x) c", + reduction="max", + q_stride_y=self.q_stride[0], + q_stride_x=self.q_stride[1], + n_patches_y=self.pooled_patches_per_mask_unit[0], + 
n_patches_x=self.pooled_patches_per_mask_unit[1], + ) attn = F.scaled_dot_product_attention( q, @@ -86,6 +129,7 @@ def __init__( dim: int, dim_out: int, heads: int, + spatial_dims: int = 3, mlp_ratio: float = 4.0, drop_path: float = 0.0, norm_layer: nn.Module = nn.LayerNorm, @@ -93,8 +137,36 @@ def __init__( q_stride: List[int] = [1, 1, 1], patches_per_mask_unit: List[int] = [2, 12, 12], ): + """ + Parameters + ---------- + dim : int + Dimension of the input features. + dim_out : int + Dimension of the output features. + spatial_dims : int, optional + Number of spatial dimensions, by default 3. + num_heads : int, optional + Number of attention heads, by default 8. + qkv_bias : bool, optional + If True, add a learnable bias to query, key, value, by default False. + qk_scale : float, optional + Override default qk scale of head_dim ** -0.5 if set, by default None. + attn_drop : float, optional + Dropout rate for attention, by default 0.0. + proj_drop : float, optional + Dropout rate for projection, by default 0.0. + q_stride : List[int], optional + Stride for query, by default [1, 1, 1]. + patches_per_mask_unit : List[int], optional + Number of patches per mask unit, by default [2, 12, 12]. + """ super().__init__() + patches_per_mask_unit, q_stride = validate_spatial_dims( + spatial_dims, [patches_per_mask_unit, q_stride] + ) + self.spatial_dims = spatial_dims self.dim = dim self.dim_out = dim_out self.q_stride = q_stride @@ -106,6 +178,7 @@ def __init__( self.attn = MaskUnitAttention( dim, dim_out, + spatial_dims=spatial_dims, num_heads=heads, q_stride=q_stride, patches_per_mask_unit=patches_per_mask_unit, @@ -117,16 +190,26 @@ def __init__( self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() # max pooling by q stride within a mask unit - skip_connection_pooling = Reduce( - "b n (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_z n_patches_y n_patches_x) c", - reduction="mean", - q_stride_z=self.q_stride[0], - q_stride_y=self.q_stride[1], - q_stride_x=self.q_stride[2], - n_patches_z=self.attn.pooled_patches_per_mask_unit[0], - n_patches_y=self.attn.pooled_patches_per_mask_unit[1], - n_patches_x=self.attn.pooled_patches_per_mask_unit[2], - ) + if self.spatial_dims == 3: + skip_connection_pooling = Reduce( + "b n (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_z n_patches_y n_patches_x) c", + reduction="mean", + q_stride_z=self.q_stride[0], + q_stride_y=self.q_stride[1], + q_stride_x=self.q_stride[2], + n_patches_z=self.attn.pooled_patches_per_mask_unit[0], + n_patches_y=self.attn.pooled_patches_per_mask_unit[1], + n_patches_x=self.attn.pooled_patches_per_mask_unit[2], + ) + elif self.spatial_dims == 2: + skip_connection_pooling = Reduce( + "b n (n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_y n_patches_x) c", + reduction="mean", + q_stride_y=self.q_stride[0], + q_stride_x=self.q_stride[1], + n_patches_y=self.attn.pooled_patches_per_mask_unit[0], + n_patches_x=self.attn.pooled_patches_per_mask_unit[1], + ) self.proj = ( torch.nn.Sequential(skip_connection_pooling, nn.Linear(dim, dim_out)) From 038d57382fc813ef060b99915764eef6a14bb35b Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 19 Aug 2024 13:25:17 -0700 Subject: [PATCH 16/27] precommit --- cyto_dl/nn/vits/blocks/patchify/patchify.py | 15 ++++--- .../nn/vits/blocks/patchify/patchify_base.py | 1 - cyto_dl/nn/vits/decoder.py | 39 +++++++++++-------- 3 files changed, 31 insertions(+), 24 deletions(-) diff 
--git a/cyto_dl/nn/vits/blocks/patchify/patchify.py b/cyto_dl/nn/vits/blocks/patchify/patchify.py index 3f6c21ac6..4da4e788e 100644 --- a/cyto_dl/nn/vits/blocks/patchify/patchify.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify.py @@ -1,12 +1,15 @@ -from cyto_dl.nn.vits.blocks.patchify.patchify_base import PatchifyBase from typing import List, Optional -from einops.layers.torch import Rearrange + import numpy as np +from einops.layers.torch import Rearrange + +from cyto_dl.nn.vits.blocks.patchify.patchify_base import PatchifyBase from cyto_dl.nn.vits.utils import take_indexes class Patchify(PatchifyBase): """Class for converting images to a masked sequence of patches with positional embeddings.""" + def __init__( self, patch_size: List[int], @@ -28,11 +31,11 @@ def __init__( tasks=tasks, learnable_pos_embedding=learnable_pos_embedding, ) - + @property def img2token(self): return self.create_img2token() - + def get_mask_args(self, mask_ratio): num_patches = np.prod(self.n_patches) n_visible_patches = int(num_patches * (1 - mask_ratio)) @@ -44,6 +47,6 @@ def create_img2token(self): return Rearrange("b c z y x -> (z y x) b c") elif self.spatial_dims == 2: return Rearrange("b c y x -> (y x) b c") - + def extract_visible_tokens(self, tokens, forward_indexes, n_visible_patches): - return take_indexes(tokens, forward_indexes)[:n_visible_patches] \ No newline at end of file + return take_indexes(tokens, forward_indexes)[:n_visible_patches] diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_base.py b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py index 8d2b99ff9..f2b34024d 100644 --- a/cyto_dl/nn/vits/blocks/patchify/patchify_base.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py @@ -168,7 +168,6 @@ def forward(self, img, mask_ratio, task=None): mask, forward_indexes, backward_indexes = self.get_mask( img, n_visible_patches, num_patches ) - # generate patches tokens = self.conv(img * mask) tokens = self.img2token(tokens) diff --git a/cyto_dl/nn/vits/decoder.py b/cyto_dl/nn/vits/decoder.py index bf98de76c..eb7b5081c 100644 --- a/cyto_dl/nn/vits/decoder.py +++ b/cyto_dl/nn/vits/decoder.py @@ -9,15 +9,8 @@ from timm.models.layers import trunc_normal_ from timm.models.vision_transformer import Block -from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes - -from typing import List, Optional - -import torch -from einops import rearrange - from cyto_dl.nn.vits.blocks import CrossAttentionBlock -from cyto_dl.nn.vits.utils import take_indexes +from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes class MAE_Decoder(torch.nn.Module): @@ -98,26 +91,24 @@ def init_weight(self): def adjust_indices_for_cls(self, indexes): if self.has_cls_token: - # add all zeros to indices - this keeps the class token as the first index in the + # add all zeros to indices - this keeps the class token as the first index in the # tensor. 
We also have to add 1 to all the indices to account for the zeros we added return torch.cat( [ - torch.zeros( - 1, indexes.shape[1], device=indexes.device, dtype=torch.long - ), + torch.zeros(1, indexes.shape[1], device=indexes.device, dtype=torch.long), indexes + 1, ], dim=0, ) return indexes - + def add_mask_tokens(self, features, backward_indexes): # fill in deleted masked regions with mask token num_visible_tokens, B, _ = features.shape # total number of tokens - number of visible tokens num_mask_tokens = backward_indexes.shape[0] - num_visible_tokens mask_tokens = repeat(self.mask_token, "1 1 c -> t b c", t=num_mask_tokens, b=B) - return torch.cat([features, mask_tokens],dim=0) + return torch.cat([features, mask_tokens], dim=0) def forward(self, features, forward_indexes, backward_indexes): # project from encoder dimension to decoder dimension @@ -135,7 +126,7 @@ def forward(self, features, forward_indexes, backward_indexes): features = rearrange(features, "t b c -> b t c") features = self.transformer(features) features = rearrange(features, "b t c -> t b c") - + if self.has_cls_token: features = features[1:] # remove global feature @@ -183,7 +174,17 @@ def __init__( learnable_pos_embedding: bool If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images. """ - super().__init__(num_patches, spatial_dims, patch_size, enc_dim, emb_dim, num_layer, num_head, has_cls_token, learnable_pos_embedding) + super().__init__( + num_patches, + spatial_dims, + patch_size, + enc_dim, + emb_dim, + num_layer, + num_head, + has_cls_token, + learnable_pos_embedding, + ) self.transformer = torch.nn.ParameterList( [ @@ -244,7 +245,11 @@ def forward(self, features, forward_indexes, backward_indexes): ], dim=0, ) - patches = take_indexes(patches, backward_indexes[1:] - 1) if self.has_cls_token else take_indexes(patches, backward_indexes) + patches = ( + take_indexes(patches, backward_indexes[1:] - 1) + if self.has_cls_token + else take_indexes(patches, backward_indexes) + ) # patches to image img = self.patch2img(patches) return img From 5098732aeb8e47edf106423aa9c3d70ae3c123c2 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 19 Aug 2024 14:27:49 -0700 Subject: [PATCH 17/27] update configs --- configs/data/im2im/mae.yaml | 19 ++++++++----------- configs/experiment/im2im/hiera.yaml | 5 +---- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/configs/data/im2im/mae.yaml b/configs/data/im2im/mae.yaml index a310e15ba..6cd6a946e 100644 --- a/configs/data/im2im/mae.yaml +++ b/configs/data/im2im/mae.yaml @@ -32,12 +32,12 @@ transforms: - _target_: monai.transforms.NormalizeIntensityd keys: ${source_col} channel_wise: True - - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + - _target_: monai.transforms.RandSpatialCropSamplesd keys: - ${source_col} - patch_shape: ${data._aux.patch_shape} - patch_per_image: 1 - scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + roi_size: ${data._aux.patch_shape} + num_samples: 1 + random_size: False test: _target_: monai.transforms.Compose @@ -104,14 +104,11 @@ transforms: - _target_: monai.transforms.NormalizeIntensityd keys: ${source_col} channel_wise: True - - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + - _target_: monai.transforms.RandSpatialCropSamplesd keys: - ${source_col} - patch_shape: ${data._aux.patch_shape} - patch_per_image: 1 - scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + roi_size: 
${data._aux.patch_shape} + num_samples: 1 + random_size: False _aux: - _scales_dict: - - - ${source_col} - - [1] diff --git a/configs/experiment/im2im/hiera.yaml b/configs/experiment/im2im/hiera.yaml index 2363a947f..f1351e492 100644 --- a/configs/experiment/im2im/hiera.yaml +++ b/configs/experiment/im2im/hiera.yaml @@ -29,10 +29,7 @@ trainer: data: path: ${paths.data_dir}/example_experiment_data/segmentation cache_dir: ${paths.data_dir}/example_experiment_data/cache - batch_size: 128 - num_workers: 8 - subsample: - train: 10000 + batch_size: 1 _aux: # 2D # patch_shape: [16, 16] From 453f0e5012ea5395591dff6845efeb7c7e9e08e4 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 19 Aug 2024 14:28:10 -0700 Subject: [PATCH 18/27] update hiera model config --- configs/model/im2im/hiera.yaml | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/configs/model/im2im/hiera.yaml b/configs/model/im2im/hiera.yaml index 901bbe02c..06105ec4c 100644 --- a/configs/model/im2im/hiera.yaml +++ b/configs/model/im2im/hiera.yaml @@ -7,28 +7,28 @@ x_key: ${source_col} backbone: _target_: cyto_dl.nn.vits.mae.HieraMAE - spatial_dims: 3 - patch_size: [2, 2, 2] # patch_size* num_patches should be your patch shape - num_patches: [8, 8, 8] # patch_size * num_patches = img_shape - num_mask_units: [4, 4, 4] #img_shape / num_mask_units = size of each mask unit in pixels, num_patches/num_mask_units = number of patches permask unit - emb_dim: 2 + spatial_dims: ${spatial_dims} + patch_size: 2 # patch_size* num_patches should be your patch shape + num_patches: 8 # patch_size * num_patches = img_shape + num_mask_units: 4 #img_shape / num_mask_units = size of each mask unit in pixels, num_patches/num_mask_units = number of patches permask unit + emb_dim: 4 architecture: # mask_unit_attention blocks - attention is only done within a mask unit and not across mask units # the total amount of q_stride across the architecture must be less than the number of patches per mask unit - repeat: 1 - q_stride: [1,1,1] + q_stride: 2 num_heads: 1 - repeat: 1 - q_stride: [2,2,2] - num_heads: 4 + q_stride: 1 + num_heads: 2 # self attention transformer - attention is done across all patches, irrespective of which mask unit they're in - repeat: 2 - num_heads: 8 + num_heads: 4 self_attention: True decoder_layer: 1 decoder_dim: 16 mask_ratio: 0.66666666666 - context_pixels: [4,4,4] + context_pixels: 3 use_crossmae: True task_heads: ${kv_to_dict:${model._aux._tasks}} From ec07e39187411aa5234dba6195bfda05b8a51874 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 19 Aug 2024 14:29:11 -0700 Subject: [PATCH 19/27] update deafults --- .../nn/vits/blocks/patchify/patchify_hiera.py | 21 ++++++++++--------- cyto_dl/nn/vits/encoder.py | 17 +++++++++------ cyto_dl/nn/vits/mae.py | 15 ++++++------- 3 files changed, 30 insertions(+), 23 deletions(-) diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py index 8712de088..9c8166860 100644 --- a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py @@ -2,7 +2,7 @@ import numpy as np import torch -from einops import repeat +from einops import rearrange, repeat from einops.layers.torch import Rearrange from timm.models.layers import trunc_normal_ @@ -79,27 +79,28 @@ def __init__( ).astype(int) patches_per_mask_unit = mask_unit_size_pix // patch_size + + # rearrange patch embeddings to mask units self.pos_embedding = torch.nn.Parameter( - torch.zeros(1, 
self.total_n_mask_units, np.prod(patches_per_mask_unit), emb_dim) + rearrange( + self.pos_embedding, + "(ppmu total_n_mu) 1 emb_dim -> 1 total_n_mu ppmu emb_dim", + total_n_mu=self.total_n_mask_units, + ppmu=patches_per_mask_unit.prod(), + emb_dim=emb_dim, + ) ) self.mask_units_per_dim = mask_units_per_dim self.patch2img = self.create_patch2img(mask_units_per_dim, mask_unit_size_pix) - self._init_weight() - - def _init_weight(self): - trunc_normal_(self.pos_embedding, std=0.02) - for task in self.task_embedding: - trunc_normal_(self.task_embedding[task], std=0.02) - @property def img2token(self): # redefine this to work with mask units instead of patches return self.create_img2token(self.mask_units_per_dim) - # in hiera, the level of masking is at the mask unit, not the patch level + # in hiera, the masking is done at the mask unit, not the patch level def get_mask_args(self, mask_ratio): n_visible_patches = int(self.total_n_mask_units * (1 - mask_ratio)) return n_visible_patches, self.total_n_mask_units diff --git a/cyto_dl/nn/vits/encoder.py b/cyto_dl/nn/vits/encoder.py index 11653a4eb..56ce707a2 100644 --- a/cyto_dl/nn/vits/encoder.py +++ b/cyto_dl/nn/vits/encoder.py @@ -159,6 +159,7 @@ def __init__( spatial_dims: int = 3, patch_size: List[int] = (16, 16, 16), context_pixels: Optional[List[int]] = [0, 0, 0], + input_channels: Optional[int] = 1, save_layers: Optional[bool] = False, ) -> None: """ @@ -186,6 +187,8 @@ def __init__( Size of each patch context_pixels: List[int] Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. + input_channels: int + Number of input channels save_layers: bool Whether to save the intermediate layer outputs """ @@ -202,11 +205,12 @@ def __init__( self.save_layers = save_layers self.patchify = PatchifyHiera( - patch_size, - num_patches, - emb_dim, - spatial_dims, - context_pixels, + patch_size=patch_size, + n_patches=num_patches, + emb_dim=emb_dim, + spatial_dims=spatial_dims, + context_pixels=context_pixels, + input_channels=input_channels, mask_units_per_dim=num_mask_units, ) @@ -223,6 +227,7 @@ def __init__( patches_per_mask_unit % total_downsampling_per_axis == 0 ), f"Number of mask units must be divisible by the total downsampling ratio, got {patches_per_mask_unit} patches per mask unit and {total_downsampling_per_axis} total downsampling ratio. 
Please adjust your q_stride" + # number of output features doubles in each masked unit attention block, stays constant during self attention blocks self.final_dim = emb_dim * (2 ** len(architecture)) self.save_block_idxs = [] @@ -276,7 +281,7 @@ def __init__( else torch.nn.Identity() ) - # at end of each layer, patches per mask unit is reduced as we pool spatially + # at end of each layer, patches per mask unit is reduced as we pool spatially within mask units patches_per_mask_unit = patches_per_mask_unit // np.array(stage["q_stride"]) num_blocks += 1 self.mask_unit_transformer = torch.nn.Sequential(*transformer) diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py index b7d601760..ab792e2ff 100644 --- a/cyto_dl/nn/vits/mae.py +++ b/cyto_dl/nn/vits/mae.py @@ -56,8 +56,8 @@ class MAE(MAE_Base): def __init__( self, spatial_dims: int = 3, - num_patches: Optional[List[int]] = [2, 32, 32], - patch_size: Optional[List[int]] = [16, 16, 16], + num_patches: Optional[List[int]] = 16, + patch_size: Optional[List[int]] = 4, emb_dim: Optional[int] = 768, encoder_layer: Optional[int] = 12, encoder_head: Optional[int] = 8, @@ -66,7 +66,7 @@ def __init__( decoder_dim: Optional[int] = 192, mask_ratio: Optional[int] = 0.75, use_crossmae: Optional[bool] = False, - context_pixels: Optional[List[int]] = [0, 0, 0], + context_pixels: Optional[List[int]] = 0, input_channels: Optional[int] = 1, features_only: Optional[bool] = False, learnable_pos_embedding: Optional[bool] = True, @@ -152,16 +152,16 @@ def __init__( self, architecture: List[Dict], spatial_dims: int = 3, - num_patches: Optional[Union[int, List[int]]] = [2, 32, 32], - num_mask_units: Optional[Union[int, List[int]]] = [2, 12, 12], - patch_size: Optional[Union[int, List[int]]] = [16, 16, 16], + num_patches: Optional[Union[int, List[int]]] = 16, + num_mask_units: Optional[Union[int, List[int]]] = 8, + patch_size: Optional[Union[int, List[int]]] = 4, emb_dim: Optional[int] = 64, decoder_layer: Optional[int] = 4, decoder_head: Optional[int] = 8, decoder_dim: Optional[int] = 192, mask_ratio: Optional[int] = 0.75, use_crossmae: Optional[bool] = False, - context_pixels: Optional[List[int]] = [0, 0, 0], + context_pixels: Optional[List[int]] = 0, input_channels: Optional[int] = 1, features_only: Optional[bool] = False, ) -> None: @@ -223,6 +223,7 @@ def __init__( spatial_dims=self.spatial_dims, patch_size=self.patch_size, context_pixels=self.context_pixels, + input_channels=input_channels, ) # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size mask_unit_size = (np.array(self.num_patches) * np.array(self.patch_size)) / np.array( From 128baa51db860497783a3663bb74d123d2b17753 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 19 Aug 2024 14:29:23 -0700 Subject: [PATCH 20/27] update tests --- tests/conftest.py | 11 ++++++++++- tests/utils.py | 2 +- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 5094ae856..ed987812d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,16 @@ OmegaConf.register_new_resolver("eval", eval) # Experiment configs to test -experiment_types = ["mae", "ijepa", "iwm", "segmentation", "labelfree", "gan", "instance_seg"] +experiment_types = [ + "hiera", + "mae", + "ijepa", + "iwm", + "segmentation", + "labelfree", + "gan", + "instance_seg", +] @pytest.fixture(scope="package", params=experiment_types) diff --git a/tests/utils.py b/tests/utils.py index 61778c230..99d097067 100644 --- 
a/tests/utils.py +++ b/tests/utils.py @@ -9,4 +9,4 @@ def resolve_readonly(cfg): def skip_test(test_name): """Skip pretraining models for testing.""" - return bool(np.any([x in test_name for x in ("mae", "ijepa", "iwm")])) + return bool(np.any([x in test_name for x in ("mae", "ijepa", "iwm", "hiera")])) From 3c97933c9ef05747cf952cbda01aa18c7199315a Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 19 Aug 2024 14:33:42 -0700 Subject: [PATCH 21/27] delete patchify_conv --- .../nn/vits/blocks/patchify/patchify_conv.py | 122 ------------------ 1 file changed, 122 deletions(-) delete mode 100644 cyto_dl/nn/vits/blocks/patchify/patchify_conv.py diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_conv.py b/cyto_dl/nn/vits/blocks/patchify/patchify_conv.py deleted file mode 100644 index 685f65c62..000000000 --- a/cyto_dl/nn/vits/blocks/patchify/patchify_conv.py +++ /dev/null @@ -1,122 +0,0 @@ -from monai.networks.nets import Regressor - - -class PatchifyConv(torch.nn.Module): - """Class for converting images to a masked sequence of patches with positional embeddings.""" - - def __init__( - self, - patch_size: List[int], - emb_dim: int, - n_patches: List[int], - spatial_dims: int = 3, - input_channels: int = 1, - tasks: Optional[List[str]] = [], - ): - """ - Parameters - ---------- - patch_size: List[int] - Size of each patch - emb_dim: int - Dimension of encoder - n_patches: List[int] - Number of patches in each spatial dimension - spatial_dims: int - Number of spatial dimensions - context_pixels: List[int] - Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. - input_channels: int - Number of input channels - tasks: List[str] - List of tasks to encode - """ - super().__init__() - self.n_patches = np.asarray(n_patches) - self.spatial_dims = spatial_dims - - self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(n_patches), 1, emb_dim)) - - self.conv = Regressor( - in_shape=patch_size, - out_shape=emb_dim, - channels=[16, 64, 256, 512], - strides=[2, 2, 2, 1], - kernel_size=3, - ) - - if spatial_dims == 3: - self.img2token = Rearrange("b c z y x -> (z y x) b c") - self.patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x", - n_patch_z=n_patches[0], - n_patch_y=n_patches[1], - n_patch_x=n_patches[2], - ) - elif spatial_dims == 2: - self.img2token = Rearrange("b c y x -> (y x) b c") - self.patch2img = Rearrange( - "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x", - n_patch_y=n_patches[0], - n_patch_x=n_patches[1], - ) - - self.task_embedding = torch.nn.ParameterDict( - {task: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for task in tasks} - ) - self._init_weight() - - def _init_weight(self): - trunc_normal_(self.pos_embedding, std=0.02) - for task in self.task_embedding: - trunc_normal_(self.task_embedding[task], std=0.02) - - def get_mask(self, img, n_visible_patches, num_patches): - B = img.shape[0] - indexes = [random_indexes(num_patches, device=img.device) for _ in range(B)] - # forward indexes : index in image -> shuffledpatch - forward_indexes = torch.stack([i[0] for i in indexes], axis=-1) - # backward indexes : shuffled patch -> index in image - backward_indexes = torch.stack([i[1] for i in indexes], axis=-1) - - mask = torch.zeros(num_patches, B, 1, device=img.device, dtype=torch.uint8) - # visible patches are first - mask[:n_visible_patches] = 1 - mask = take_indexes(mask, backward_indexes) - mask = self.patch2img(mask) - # one pixel per masked patch, interpolate to size of 
input image - mask = torch.nn.functional.interpolate( - mask, img.shape[-self.spatial_dims :], mode="nearest" - ) - return mask, forward_indexes, backward_indexes - - def forward(self, img, mask_ratio=0.75, n_visible_patches=None, task=None): - mask = torch.ones_like(img) - forward_indexes, backward_indexes = None, None - if mask_ratio > 0: - # generate mask - num_patches = np.prod(self.n_patches) - n_visible_patches = n_visible_patches or int(num_patches * (1 - mask_ratio)) - mask, forward_indexes, backward_indexes = self.get_mask( - img, n_visible_patches, num_patches - ) - # generate patches - tokens = self.conv(img * mask) - tokens = self.img2token(tokens) - - pos_emb = rearrange(self.pos_embedding, "t b c -> b c t") - pos_emb = torch.nn.functional.interpolate(pos_emb, tokens.shape[0], mode="linear") - pos_emb = rearrange(pos_emb, "b c t -> t b c") - - # add position embedding - tokens = tokens + pos_emb - if mask_ratio > 0: - # extract visible patches - tokens = take_indexes(tokens, forward_indexes)[:n_visible_patches] - - if task in self.task_embedding: - tokens = tokens + self.task_embedding[task] - - # mask is used above to mask out patches, we need to invert it for loss calculation - mask = (1 - mask).bool() - return tokens, mask, forward_indexes, backward_indexes From 1d4a76d040ac73a958bb3984b089d899de12d246 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Mon, 19 Aug 2024 16:23:17 -0700 Subject: [PATCH 22/27] fix jepa tests --- configs/data/im2im/ijepa.yaml | 3 + configs/data/im2im/iwm.yaml | 2 + configs/experiment/im2im/ijepa.yaml | 7 - configs/experiment/im2im/iwm.yaml | 7 - configs/model/im2im/ijepa.yaml | 9 +- configs/model/im2im/iwm.yaml | 9 +- .../nn/vits/blocks/patchify/patchify_base.py | 2 +- cyto_dl/nn/vits/decoder.py | 32 +-- cyto_dl/nn/vits/encoder.py | 82 ++++++- cyto_dl/nn/vits/jepa.py | 205 ------------------ 10 files changed, 109 insertions(+), 249 deletions(-) delete mode 100644 cyto_dl/nn/vits/jepa.py diff --git a/configs/data/im2im/ijepa.yaml b/configs/data/im2im/ijepa.yaml index 4c1131daa..add94c074 100644 --- a/configs/data/im2im/ijepa.yaml +++ b/configs/data/im2im/ijepa.yaml @@ -39,6 +39,7 @@ transforms: - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator mask_size: 4 num_patches: ${model._aux.num_patches} + spatial_dims: ${spatial_dims} test: _target_: monai.transforms.Compose @@ -69,6 +70,7 @@ transforms: - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator mask_size: 4 num_patches: ${model._aux.num_patches} + spatial_dims: ${spatial_dims} predict: _target_: monai.transforms.Compose @@ -127,6 +129,7 @@ transforms: - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator mask_size: 4 num_patches: ${model._aux.num_patches} + spatial_dims: ${spatial_dims} _aux: _scales_dict: diff --git a/configs/data/im2im/iwm.yaml b/configs/data/im2im/iwm.yaml index 526920ba2..9042f6652 100644 --- a/configs/data/im2im/iwm.yaml +++ b/configs/data/im2im/iwm.yaml @@ -58,6 +58,7 @@ transforms: - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator mask_size: 4 num_patches: ${model._aux.num_patches} + spatial_dims: ${spatial_dims} test: _target_: monai.transforms.Compose @@ -176,5 +177,6 @@ transforms: - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator mask_size: 4 num_patches: ${model._aux.num_patches} + spatial_dims: ${spatial_dims} _aux: diff --git a/configs/experiment/im2im/ijepa.yaml b/configs/experiment/im2im/ijepa.yaml index df5529d36..d0af7cd6f 100644 --- 
a/configs/experiment/im2im/ijepa.yaml +++ b/configs/experiment/im2im/ijepa.yaml @@ -37,10 +37,3 @@ data: # patch_shape: [16, 16] # 3D patch_shape: [16, 16, 16] - -model: - _aux: - # 3D - num_patches: [8, 8, 8] - # 2d - # num_patches: [8, 8] diff --git a/configs/experiment/im2im/iwm.yaml b/configs/experiment/im2im/iwm.yaml index e097933b0..3cee1f777 100644 --- a/configs/experiment/im2im/iwm.yaml +++ b/configs/experiment/im2im/iwm.yaml @@ -38,13 +38,6 @@ data: # 3D patch_shape: [16, 16, 16] -model: - _aux: - # 3D - num_patches: [8, 8, 8] - # 2d - # num_patches: [8, 8] - callbacks: prediction_saver: _target_: cyto_dl.callbacks.csv_saver.CSVSaver diff --git a/configs/model/im2im/ijepa.yaml b/configs/model/im2im/ijepa.yaml index 1e11a57d3..40ea24201 100644 --- a/configs/model/im2im/ijepa.yaml +++ b/configs/model/im2im/ijepa.yaml @@ -5,15 +5,16 @@ max_epochs: ${trainer.max_epochs} save_dir: ${paths.output_dir} encoder: - _target_: cyto_dl.nn.vits.jepa.JEPAEncoder - patch_size: 2 + _target_: cyto_dl.nn.vits.encoder.JEPAEncoder + patch_size: 2 # patch_size * num_patches should equl data._aux.patch_shape num_patches: ${model._aux.num_patches} emb_dim: 16 num_layer: 2 num_head: 1 + spatial_dims: ${spatial_dims} predictor: - _target_: cyto_dl.nn.vits.jepa.JEPAPredictor + _target_: cyto_dl.nn.vits.predictor.JEPAPredictor num_patches: ${model._aux.num_patches} input_dim: ${model.encoder.emb_dim} emb_dim: 8 @@ -34,4 +35,4 @@ lr_scheduler: pct_start: 0.1 _aux: - num_patches: + num_patches: 8 diff --git a/configs/model/im2im/iwm.yaml b/configs/model/im2im/iwm.yaml index 21357ed3e..66ccb7f68 100644 --- a/configs/model/im2im/iwm.yaml +++ b/configs/model/im2im/iwm.yaml @@ -8,15 +8,16 @@ max_epochs: ${trainer.max_epochs} save_dir: ${paths.output_dir} encoder: - _target_: cyto_dl.nn.vits.jepa.JEPAEncoder - patch_size: 2 + _target_: cyto_dl.nn.vits.encoder.JEPAEncoder + patch_size: 2 # patch_size * num_patches should be the same as data._aux.patch_shape num_patches: ${model._aux.num_patches} emb_dim: 16 num_layer: 1 num_head: 1 + spatial_dims: ${spatial_dims} predictor: - _target_: cyto_dl.nn.vits.jepa.IWMPredictor + _target_: cyto_dl.nn.vits.predictor.IWMPredictor domains: [SEC61B] num_patches: ${model._aux.num_patches} input_dim: ${model.encoder.emb_dim} @@ -38,4 +39,4 @@ lr_scheduler: pct_start: 0.1 _aux: - num_patches: + num_patches: 8 diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_base.py b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py index f2b34024d..1b2573084 100644 --- a/cyto_dl/nn/vits/blocks/patchify/patchify_base.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py @@ -161,7 +161,7 @@ def get_mask(self, img, n_visible_patches, num_patches): def forward(self, img, mask_ratio, task=None): # generate mask - mask = torch.ones_like(img) + mask = torch.ones_like(img).bool() forward_indexes, backward_indexes = None, None if mask_ratio > 0: n_visible_patches, num_patches = self.get_mask_args(mask_ratio) diff --git a/cyto_dl/nn/vits/decoder.py b/cyto_dl/nn/vits/decoder.py index eb7b5081c..fe296121b 100644 --- a/cyto_dl/nn/vits/decoder.py +++ b/cyto_dl/nn/vits/decoder.py @@ -1,6 +1,6 @@ # modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 -from typing import List, Optional +from typing import List, Optional, Union import torch import torch.nn as nn @@ -10,15 +10,19 @@ from timm.models.vision_transformer import Block from cyto_dl.nn.vits.blocks import CrossAttentionBlock -from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes +from cyto_dl.nn.vits.utils 
import ( + get_positional_embedding, + take_indexes, + validate_spatial_dims, +) class MAE_Decoder(torch.nn.Module): def __init__( self, - num_patches: List[int], + num_patches: Union[int, List[int]], spatial_dims: int = 3, - patch_size: Optional[List[int]] = [4, 8, 8], + patch_size: Optional[Union[int, List[int]]] = 4, enc_dim: Optional[int] = 768, emb_dim: Optional[int] = 192, num_layer: Optional[int] = 4, @@ -29,10 +33,10 @@ def __init__( """ Parameters ---------- - num_patches: List[int] - Number of patches in each dimension - patch_size: Tuple[int] - Size of each patch + num_patches: List[int], int + Number of patches in each dimension. If int, the same number of patches is used for all dimensions. + patch_size: Tuple[int], int + Size of each patch. If int, the same patch size is used for all dimensions. enc_dim: int Dimension of encoder emb_dim: int @@ -47,6 +51,8 @@ def __init__( If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. """ super().__init__() + num_patches, patch_size = validate_spatial_dims(spatial_dims, [num_patches, patch_size]) + self.has_cls_token = has_cls_token self.projection_norm = nn.LayerNorm(emb_dim) @@ -144,9 +150,9 @@ class CrossMAE_Decoder(MAE_Decoder): def __init__( self, - num_patches: List[int], + num_patches: Union[int, List[int]], spatial_dims: int = 3, - patch_size: Optional[List[int]] = [4, 8, 8], + patch_size: Optional[Union[int, List[int]]] = 4, enc_dim: Optional[int] = 768, emb_dim: Optional[int] = 192, num_layer: Optional[int] = 4, @@ -157,10 +163,10 @@ def __init__( """ Parameters ---------- - num_patches: List[int] - Number of patches in each dimension + num_patches: List[int], int + Number of patches in each dimension. If int, the same number of patches is used for all dimensions. patch_size: Tuple[int] - Size of each patch + Size of each patch in each dimension. If int, the same patch size is used for all dimensions. enc_dim: int Dimension of encoder emb_dim: int diff --git a/cyto_dl/nn/vits/encoder.py b/cyto_dl/nn/vits/encoder.py index 56ce707a2..570b704ba 100644 --- a/cyto_dl/nn/vits/encoder.py +++ b/cyto_dl/nn/vits/encoder.py @@ -23,19 +23,19 @@ def __init__( self, num_patches: List[int], spatial_dims: int = 3, - patch_size: List[int] = (16, 16, 16), + patch_size: Union[int, List[int]] = 4, emb_dim: Optional[int] = 192, num_layer: Optional[int] = 12, num_head: Optional[int] = 3, - context_pixels: Optional[List[int]] = [0, 0, 0], + context_pixels: Optional[Union[int, List[int]]] = 0, input_channels: Optional[int] = 1, n_intermediate_weights: Optional[int] = -1, ) -> None: """ Parameters ---------- - num_patches: List[int] - Number of patches in each dimension + num_patches: List[int], int + Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same. spatial_dims: int Number of spatial dimensions patch_size: List[int] @@ -46,8 +46,8 @@ def __init__( Number of transformer layers num_head: int Number of heads in transformer - context_pixels: List[int] - Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. + context_pixels: List[int], int + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same. 
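The int-or-list broadcasting mentioned throughout these docstrings is handled by the validate_spatial_dims helper added to utils.py earlier in this series; a minimal self-contained sketch of its behavior, wrapping monai's ensure_tuple_rep as the patch does:

from monai.utils.misc import ensure_tuple_rep


def validate_spatial_dims(spatial_dims, tuples):
    # same logic as the helper in cyto_dl/nn/vits/utils.py: broadcast ints per spatial dimension
    assert spatial_dims in (2, 3), "spatial_dims must be 2 or 3"
    return [ensure_tuple_rep(t, spatial_dims) for t in tuples]


# scalars are broadcast, correctly sized sequences pass through unchanged
num_patches, patch_size, context_pixels = validate_spatial_dims(3, [16, 4, 0])
print(num_patches, patch_size, context_pixels)  # (16, 16, 16) (4, 4, 4) (0, 0, 0)

num_patches, patch_size = validate_spatial_dims(2, [[8, 8], 4])
print(num_patches, patch_size)                  # (8, 8) (4, 4)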
input_channels: int Number of input channels n_intermediate_weights: bool @@ -104,6 +104,72 @@ def forward(self, img, mask_ratio=0.75): return features +class JEPAEncoder(torch.nn.Module): + def __init__( + self, + num_patches: Union[int, List[int]], + spatial_dims: int = 3, + patch_size: Union[int, List[int]] = 4, + emb_dim: Optional[int] = 192, + num_layer: Optional[int] = 12, + num_head: Optional[int] = 3, + context_pixels: Optional[Union[int, List[int]]] = 0, + input_channels: Optional[int] = 1, + learnable_pos_embedding: Optional[bool] = True, + ) -> None: + """ + Parameters + ---------- + num_patches: List[int], int + Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same. + spatial_dims: int + Number of spatial dimensions + patch_size: List[int] + Size of each patch + emb_dim: int + Dimension of embedding + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + context_pixels: List[int], int + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same. + input_channels: int + Number of input channels + learnable_pos_embedding: bool + If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. + """ + super().__init__() + num_patches, patch_size, context_pixels = validate_spatial_dims( + spatial_dims, [num_patches, patch_size, context_pixels] + ) + + self.patchify = Patchify( + patch_size=patch_size, + emb_dim=emb_dim, + n_patches=num_patches, + spatial_dims=spatial_dims, + context_pixels=context_pixels, + input_channels=input_channels, + learnable_pos_embedding=learnable_pos_embedding, + ) + + self.transformer = torch.nn.Sequential( + *[Block(emb_dim, num_head) for _ in range(num_layer)] + ) + + self.layer_norm = torch.nn.LayerNorm(emb_dim) + + def forward(self, img, patchify=True): + if patchify: + patches, _, _, _ = self.patchify(img, mask_ratio=0) + else: + patches = img + patches = rearrange(patches, "t b c -> b t c") + features = self.layer_norm(self.transformer(patches)) + return features + + class SpatialMerger(nn.Module): """Class for converting multi-resolution Hiera features to the same (lowest) spatial resolution via convolution.""" @@ -157,8 +223,8 @@ def __init__( architecture: List[Dict], emb_dim: int = 64, spatial_dims: int = 3, - patch_size: List[int] = (16, 16, 16), - context_pixels: Optional[List[int]] = [0, 0, 0], + patch_size: Union[int, List[int]] = 4, + context_pixels: Optional[Union[int, List[int]]] = 0, input_channels: Optional[int] = 1, save_layers: Optional[bool] = False, ) -> None: diff --git a/cyto_dl/nn/vits/jepa.py b/cyto_dl/nn/vits/jepa.py deleted file mode 100644 index 6b45303d2..000000000 --- a/cyto_dl/nn/vits/jepa.py +++ /dev/null @@ -1,205 +0,0 @@ -from typing import List, Optional, Union - -import torch -import torch.nn.functional -from einops import rearrange -from timm.models.layers import trunc_normal_ -from timm.models.vision_transformer import Block - -from cyto_dl.nn.vits.blocks import CrossAttentionBlock, Patchify -from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes - - -class JEPAEncoder(torch.nn.Module): - def __init__( - self, - num_patches: Union[int, List[int]], - spatial_dims: int = 3, - patch_size: Union[int, List[int]] = (16, 16, 16), - emb_dim: Optional[int] 
= 192, - num_layer: Optional[int] = 12, - num_head: Optional[int] = 3, - context_pixels: Optional[List[int]] = [0, 0, 0], - input_channels: Optional[int] = 1, - learnable_pos_embedding: Optional[bool] = True, - ) -> None: - """ - Parameters - ---------- - num_patches: List[int] - Number of patches in each dimension - spatial_dims: int - Number of spatial dimensions - patch_size: List[int] - Size of each patch - emb_dim: int - Dimension of embedding - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer - context_pixels: List[int] - Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. - input_channels: int - Number of input channels - learnable_pos_embedding: bool - If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. - """ - super().__init__() - if isinstance(num_patches, int): - num_patches = [num_patches] * spatial_dims - if isinstance(patch_size, int): - patch_size = [patch_size] * spatial_dims - self.patchify = Patchify( - patch_size, - emb_dim, - num_patches, - spatial_dims, - context_pixels, - input_channels, - learnable_pos_embedding=learnable_pos_embedding, - ) - - self.transformer = torch.nn.Sequential( - *[Block(emb_dim, num_head) for _ in range(num_layer)] - ) - - self.layer_norm = torch.nn.LayerNorm(emb_dim) - - def forward(self, img, patchify=True): - if patchify: - patches, _, _, _ = self.patchify(img, mask_ratio=0) - else: - patches = img - patches = rearrange(patches, "t b c -> b t c") - features = self.layer_norm(self.transformer(patches)) - return features - - -class JEPAPredictor(torch.nn.Module): - """Class for predicting target features from context embedding.""" - - def __init__( - self, - num_patches: List[int], - input_dim: Optional[int] = 192, - emb_dim: Optional[int] = 192, - num_layer: Optional[int] = 12, - num_head: Optional[int] = 3, - learnable_pos_embedding: Optional[bool] = True, - ) -> None: - """ - Parameters - ---------- - num_patches: List[int] - Number of patches in each dimension - emb_dim: int - Dimension of embedding - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer - learnable_pos_embedding: bool - If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
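Since JEPAEncoder is only moving (the deleted jepa.py copy below is replaced by the encoder.py version above), here is a usage sketch of the relocated class with values mirroring configs/model/im2im/ijepa.yaml; it assumes the cyto-dl package from this PR is installed.

import torch
from cyto_dl.nn.vits.encoder import JEPAEncoder

encoder = JEPAEncoder(
    num_patches=8,   # broadcast to (8, 8, 8) for spatial_dims=3
    spatial_dims=3,
    patch_size=2,    # 8 patches * 2 pixels = 16 pixel input per axis
    emb_dim=16,
    num_layer=2,
    num_head=1,
)

img = torch.randn(1, 1, 16, 16, 16)  # (batch, channel, Z, Y, X)
features = encoder(img)              # mask_ratio=0 inside patchify: every patch is embedded
print(features.shape)                # torch.Size([1, 512, 16]) -> (batch, 8*8*8 tokens, emb_dim)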
- """ - super().__init__() - self.transformer = torch.nn.ParameterList( - [ - CrossAttentionBlock( - encoder_dim=emb_dim, - decoder_dim=emb_dim, - num_heads=num_head, - ) - for _ in range(num_layer) - ] - ) - - self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) - self.pos_embedding = get_positional_embedding( - num_patches, emb_dim, use_cls_token=False, learnable=learnable_pos_embedding - ) - - self.predictor_embed = torch.nn.Linear(input_dim, emb_dim) - - self.projector_embed = torch.nn.Linear(emb_dim, input_dim) - self.norm = torch.nn.LayerNorm(emb_dim) - self.init_weight() - - def init_weight(self): - trunc_normal_(self.mask_token, std=0.02) - trunc_normal_(self.pos_embedding, std=0.02) - - def predict_target_features(self, context_emb, target_masks): - t, b = target_masks.shape - # add masked positional embedding to mask tokens - mask = self.mask_token.expand(t, b, -1) - pe = self.pos_embedding.expand(-1, b, -1) - pe = take_indexes(pe, target_masks) - mask = mask + pe - mask = rearrange(mask, "t b c -> b t c") - - # cross attention from mask tokens to context embedding - for transformer in self.transformer: - mask = transformer(mask, context_emb) - - # norm and project back to input dimension - mask = self.projector_embed(self.norm(mask)) - return mask - - def forward(self, context_emb, target_masks): - # map context embedding to predictor dimension - context_emb = self.predictor_embed(context_emb) - target_features = self.predict_target_features(context_emb, target_masks) - return target_features - - -class IWMPredictor(JEPAPredictor): - """Specialized JEPA predictor that can conditionally predict between different domains (e.g. - predict from brightfield to multiple fluorescent tags)""" - - def __init__( - self, - domains: List[str], - num_patches: List[int], - input_dim: Optional[int] = 192, - emb_dim: Optional[int] = 192, - num_layer: Optional[int] = 12, - num_head: Optional[int] = 3, - ) -> None: - """ - Parameters - ---------- - domains: List[str] - List of names of target domains - num_patches: List[int] - Number of patches in each dimension - emb_dim: int - Dimension of embedding - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer - """ - super().__init__(num_patches, input_dim, emb_dim, num_layer, num_head) - - self.domain_embeddings = torch.nn.ParameterDict( - {d: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for d in domains} - ) - self.context_mixer = torch.nn.Linear(2 * emb_dim, emb_dim, 1) - - def forward(self, context_emb, target_masks, target_domain): - _, b = target_masks.shape - if len(target_domain) == 1: - target_domain = target_domain * b - # map context embedding to predictor dimension - context_emb = self.predictor_embed(context_emb) - - # add target domain information via concatenation + token mixing - target_domain_embedding = torch.cat( - [self.domain_embeddings[td] for td in target_domain] - ).repeat(1, context_emb.shape[1], 1) - context_emb = torch.cat([context_emb, target_domain_embedding], dim=-1) - context_emb = self.context_mixer(context_emb) - - target_features = self.predict_target_features(context_emb, target_masks) - return target_features From 663e52c8dfca5be9a0e7993e432e9efd14894617 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Tue, 20 Aug 2024 10:53:45 -0700 Subject: [PATCH 23/27] update mask transform --- cyto_dl/image/transforms/generate_jepa_masks.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/cyto_dl/image/transforms/generate_jepa_masks.py 
b/cyto_dl/image/transforms/generate_jepa_masks.py index 0a4dae69e..445587a25 100644 --- a/cyto_dl/image/transforms/generate_jepa_masks.py +++ b/cyto_dl/image/transforms/generate_jepa_masks.py @@ -5,6 +5,8 @@ from monai.transforms import RandomizableTransform from skimage.segmentation import find_boundaries +from cyto_dl.nn.vits.utils import validate_spatial_dims + class JEPAMaskGenerator(RandomizableTransform): """Transform for generating Block-contiguous masks for JEPA training. @@ -15,6 +17,7 @@ class JEPAMaskGenerator(RandomizableTransform): def __init__( self, + spatial_dims: int, mask_size: int = 12, block_aspect_ratio: Tuple[float] = (0.5, 1.5), num_patches: Tuple[float] = (6, 24, 24), @@ -23,6 +26,8 @@ def __init__( """ Parameters ---------- + spatial_dims : int + The number of spatial dimensions of the image (2 or 3) mask_size : int, optional The size of the blocks used to generate mask. Block dimensions are determined by the mask size and an aspect ratio sampled from the range `block_aspect_ratio` block_aspect_ratio : Tuple[float], optional @@ -32,7 +37,9 @@ def __init__( mask_ratio : float, optional The proportion of the image to be masked """ - assert mask_ratio < 1, "mask_ratio must be less than 1" + assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1" + + num_patches = validate_spatial_dims(spatial_dims, [num_patches])[0] assert mask_size * max(block_aspect_ratio) < min( num_patches[-2:] ), "mask_size * max mask aspect ratio must be less than the smallest dimension of num_patches" @@ -46,7 +53,7 @@ def __init__( self.mask = np.zeros(num_patches) self.edge_mask = np.ones(num_patches) - self.spatial_dims = len(num_patches) + self.spatial_dims = spatial_dims # create a mask that identified pixels on the edge of the image if self.spatial_dims == 3: self.edge_mask[1:-1, 1:-1, 1:-1] = 0 From 30a0526c9136da1aab99466865f044cfe3a361fe Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Tue, 20 Aug 2024 11:53:44 -0700 Subject: [PATCH 24/27] add predictor --- cyto_dl/nn/vits/predictor.py | 157 +++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 cyto_dl/nn/vits/predictor.py diff --git a/cyto_dl/nn/vits/predictor.py b/cyto_dl/nn/vits/predictor.py new file mode 100644 index 000000000..70f227586 --- /dev/null +++ b/cyto_dl/nn/vits/predictor.py @@ -0,0 +1,157 @@ +from typing import List, Optional, Union + +import torch +import torch.nn.functional +from einops import rearrange +from timm.models.layers import trunc_normal_ + +from cyto_dl.nn.vits.blocks import CrossAttentionBlock +from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes, validate_spatial_dims + + +class JEPAPredictor(torch.nn.Module): + """Class for predicting target features from context embedding.""" + + def __init__( + self, + num_patches: Union[int, List[int]], + spatial_dims: int = 3, + input_dim: Optional[int] = 192, + emb_dim: Optional[int] = 192, + num_layer: Optional[int] = 12, + num_head: Optional[int] = 3, + learnable_pos_embedding: Optional[bool] = True, + ) -> None: + """ + Parameters + ---------- + num_patches: List[int], int + Number of patches in each dimension. If int, the same number of patches is used for all spatial dimensions + spatial_dims: int + Number of spatial dimensions + input_dim: int + Dimension of input + emb_dim: int + Dimension of embedding + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + learnable_pos_embedding: bool + If True, learnable positional embeddings are used. 
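To illustrate what JEPAMaskGenerator above produces, a loose numpy sketch of block-contiguous masking over a 2D patch grid; the real transform has additional logic (an edge mask, per-block aspect-ratio sampling over the configured range) and returns the patch indexes consumed by the predictor, so treat this only as the general idea.

import numpy as np

rng = np.random.default_rng(0)

num_patches = (24, 24)  # patch grid (YX), illustrative
mask_size = 4
mask_ratio = 0.25

mask = np.zeros(num_patches, dtype=bool)
while mask.mean() < mask_ratio:
    # sample a block roughly mask_size patches on a side, with a random aspect ratio
    aspect = rng.uniform(0.5, 1.5)
    h = int(round(mask_size * aspect))
    w = int(round(mask_size / aspect))
    y = rng.integers(0, num_patches[0] - h)
    x = rng.integers(0, num_patches[1] - w)
    mask[y : y + h, x : x + w] = True  # contiguous block of target patches

target_idx = np.flatnonzero(mask)  # indexes of patches the predictor must reconstruct
print(round(mask.mean(), 3), target_idx[:5])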
If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. + """ + super().__init__() + self.transformer = torch.nn.ParameterList( + [ + CrossAttentionBlock( + encoder_dim=emb_dim, + decoder_dim=emb_dim, + num_heads=num_head, + ) + for _ in range(num_layer) + ] + ) + + num_patches = validate_spatial_dims(spatial_dims, [num_patches])[0] + + self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) + self.pos_embedding = get_positional_embedding( + num_patches, emb_dim, use_cls_token=False, learnable=learnable_pos_embedding + ) + + self.predictor_embed = torch.nn.Linear(input_dim, emb_dim) + + self.projector_embed = torch.nn.Linear(emb_dim, input_dim) + self.norm = torch.nn.LayerNorm(emb_dim) + self.init_weight() + + def init_weight(self): + trunc_normal_(self.mask_token, std=0.02) + trunc_normal_(self.pos_embedding, std=0.02) + + def predict_target_features(self, context_emb, target_masks): + t, b = target_masks.shape + # add masked positional embedding to mask tokens + mask = self.mask_token.expand(t, b, -1) + pe = self.pos_embedding.expand(-1, b, -1) + pe = take_indexes(pe, target_masks) + mask = mask + pe + mask = rearrange(mask, "t b c -> b t c") + + # cross attention from mask tokens to context embedding + for transformer in self.transformer: + mask = transformer(mask, context_emb) + + # norm and project back to input dimension + mask = self.projector_embed(self.norm(mask)) + return mask + + def forward(self, context_emb, target_masks): + # map context embedding to predictor dimension + context_emb = self.predictor_embed(context_emb) + target_features = self.predict_target_features(context_emb, target_masks) + return target_features + + +class IWMPredictor(JEPAPredictor): + """Specialized JEPA predictor that can conditionally predict between different domains (e.g. + predict from brightfield to multiple fluorescent tags)""" + + def __init__( + self, + domains: List[str], + num_patches: Union[int, List[int]], + spatial_dims: int = 3, + input_dim: Optional[int] = 192, + emb_dim: Optional[int] = 192, + num_layer: Optional[int] = 12, + num_head: Optional[int] = 3, + ) -> None: + """ + Parameters + ---------- + domains: List[str] + List of names of target domains + num_patches: List[int] + Number of patches in each dimension. 
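A minimal usage sketch for the JEPAPredictor defined in this patch, assuming the module path cyto_dl.nn.vits.predictor matches the file added here and using the tensor shapes implied by predict_target_features (context_emb is batch x context tokens x input_dim; target_masks is target tokens x batch of patch indices). All sizes below are illustrative, not taken from any config in this series.

import torch
from cyto_dl.nn.vits.predictor import JEPAPredictor

# illustrative: a 6x24x24 patch grid with 192-dim tokens
predictor = JEPAPredictor(
    num_patches=[6, 24, 24],
    spatial_dims=3,
    input_dim=192,
    emb_dim=192,
    num_layer=6,
    num_head=3,
)
context_emb = torch.randn(2, 100, 192)  # (batch, context tokens, input_dim)
target_masks = torch.randint(0, 6 * 24 * 24, (50, 2))  # (target tokens, batch) patch indices
preds = predictor(context_emb, target_masks)  # -> (2, 50, 192) predicted target features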
If int, the same number of patches is used for all spatial dimensions + spatial_dims: int + Number of spatial dimensions + spatial_dims: int + Number of spatial dimensions + emb_dim: int + Dimension of embedding + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + """ + super().__init__( + num_patches=num_patches, + spatial_dims=spatial_dims, + input_dim=input_dim, + emb_dim=emb_dim, + num_layer=num_layer, + num_head=num_head + ) + + self.domain_embeddings = torch.nn.ParameterDict( + {d: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for d in domains} + ) + self.context_mixer = torch.nn.Linear(2 * emb_dim, emb_dim, 1) + + def forward(self, context_emb, target_masks, target_domain): + _, b = target_masks.shape + if len(target_domain) == 1: + target_domain = target_domain * b + # map context embedding to predictor dimension + context_emb = self.predictor_embed(context_emb) + + # add target domain information via concatenation + token mixing + target_domain_embedding = torch.cat( + [self.domain_embeddings[td] for td in target_domain] + ).repeat(1, context_emb.shape[1], 1) + context_emb = torch.cat([context_emb, target_domain_embedding], dim=-1) + context_emb = self.context_mixer(context_emb) + + target_features = self.predict_target_features(context_emb, target_masks) + return target_features From e6b58dfe397057c85f1b4f09ec820d7667fa7bb7 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Tue, 20 Aug 2024 11:58:41 -0700 Subject: [PATCH 25/27] precommit --- cyto_dl/nn/vits/predictor.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cyto_dl/nn/vits/predictor.py b/cyto_dl/nn/vits/predictor.py index 70f227586..476901b72 100644 --- a/cyto_dl/nn/vits/predictor.py +++ b/cyto_dl/nn/vits/predictor.py @@ -6,7 +6,11 @@ from timm.models.layers import trunc_normal_ from cyto_dl.nn.vits.blocks import CrossAttentionBlock -from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes, validate_spatial_dims +from cyto_dl.nn.vits.utils import ( + get_positional_embedding, + take_indexes, + validate_spatial_dims, +) class JEPAPredictor(torch.nn.Module): @@ -126,12 +130,12 @@ def __init__( Number of heads in transformer """ super().__init__( - num_patches=num_patches, + num_patches=num_patches, spatial_dims=spatial_dims, input_dim=input_dim, emb_dim=emb_dim, num_layer=num_layer, - num_head=num_head + num_head=num_head, ) self.domain_embeddings = torch.nn.ParameterDict( From 84717b71fd93e9f63f7abd6a716082fccf270e4a Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Thu, 22 Aug 2024 15:51:50 -0700 Subject: [PATCH 26/27] update with ritviks comments --- configs/model/im2im/hiera.yaml | 9 +++-- .../nn/vits/blocks/masked_unit_attention.py | 37 +++++++++---------- .../nn/vits/blocks/patchify/patchify_hiera.py | 3 +- cyto_dl/nn/vits/encoder.py | 13 ++++--- cyto_dl/nn/vits/utils.py | 6 ++- 5 files changed, 36 insertions(+), 32 deletions(-) diff --git a/configs/model/im2im/hiera.yaml b/configs/model/im2im/hiera.yaml index 06105ec4c..88a10b127 100644 --- a/configs/model/im2im/hiera.yaml +++ b/configs/model/im2im/hiera.yaml @@ -8,15 +8,16 @@ x_key: ${source_col} backbone: _target_: cyto_dl.nn.vits.mae.HieraMAE spatial_dims: ${spatial_dims} - patch_size: 2 # patch_size* num_patches should be your patch shape + patch_size: 2 # patch_size * num_patches should be your image shape (data._aux.patch_shape) num_patches: 8 # patch_size * num_patches = img_shape - num_mask_units: 4 #img_shape / num_mask_units = size of each mask unit in pixels, 
num_patches/num_mask_units = number of patches permask unit + num_mask_units: 4 #Mask units are used for local attention. img_shape / num_mask_units = size of each mask unit in pixels, num_patches/num_mask_units = number of patches per mask unit emb_dim: 4 + # NOTE: this is a very small model for testing - for best performance, the downsampling ratios, embedding dimension, number of layers and number of heads should be adjusted to your data architecture: # mask_unit_attention blocks - attention is only done within a mask unit and not across mask units # the total amount of q_stride across the architecture must be less than the number of patches per mask unit - - repeat: 1 - q_stride: 2 + - repeat: 1 # number of times to repeat this block + q_stride: 2 # size of downsampling within a mask unit num_heads: 1 - repeat: 1 q_stride: 1 diff --git a/cyto_dl/nn/vits/blocks/masked_unit_attention.py b/cyto_dl/nn/vits/blocks/masked_unit_attention.py index 9b61d18e3..ad1a02a99 100644 --- a/cyto_dl/nn/vits/blocks/masked_unit_attention.py +++ b/cyto_dl/nn/vits/blocks/masked_unit_attention.py @@ -9,7 +9,7 @@ from einops.layers.torch import Reduce from timm.models.layers import DropPath, Mlp -from cyto_dl.nn.vits.utils import validate_spatial_dims +from cyto_dl.nn.vits.utils import match_tuple_dimensions class MaskUnitAttention(torch.nn.Module): @@ -20,7 +20,6 @@ def __init__( spatial_dims: int = 3, num_heads=8, qkv_bias=False, - qk_scale=None, attn_drop=0.0, proj_drop=0.0, q_stride=[1, 1, 1], @@ -39,8 +38,6 @@ def __init__( Number of attention heads, by default 8. qkv_bias : bool, optional If True, add a learnable bias to query, key, value, by default False. - qk_scale : float, optional - Override default qk scale of head_dim ** -0.5 if set, by default None. attn_drop : float, optional Dropout rate for attention, by default 0.0. proj_drop : float, optional @@ -51,14 +48,13 @@ def __init__( Number of patches per mask unit, by default [2, 12, 12]. """ super().__init__() - q_stride, patches_per_mask_unit = validate_spatial_dims( + q_stride, patches_per_mask_unit = match_tuple_dimensions( spatial_dims, [q_stride, patches_per_mask_unit] ) self.spatial_dims = spatial_dims self.num_heads = num_heads self.head_dim = dim_out // num_heads - self.scale = qk_scale or self.head_dim**-0.5 self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias) self.attn_drop = attn_drop self.proj = nn.Linear(dim_out, dim_out) @@ -87,7 +83,7 @@ def forward(self, x): if self.spatial_dims == 3: q = reduce( q, - "b n h (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c ->b n h (n_patches_z n_patches_y n_patches_x) c", + "batch num_mask_units num_heads (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> batch num_mask_units num_heads (n_patches_z n_patches_y n_patches_x) c", reduction="max", q_stride_z=self.q_stride[0], q_stride_y=self.q_stride[1], @@ -99,7 +95,7 @@ def forward(self, x): elif self.spatial_dims == 2: q = reduce( q, - "b n h (n_patches_y q_stride_y n_patches_x q_stride_x) c ->b n h (n_patches_y n_patches_x) c", + "batch num_mask_units num_heads (n_patches_y q_stride_y n_patches_x q_stride_x) c ->batch num_mask_units num_heads (n_patches_y n_patches_x) c", reduction="max", q_stride_y=self.q_stride[0], q_stride_x=self.q_stride[1], @@ -144,25 +140,25 @@ def __init__( Dimension of the input features. dim_out : int Dimension of the output features. + heads : int + Number of attention heads. spatial_dims : int, optional Number of spatial dimensions, by default 3.
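As a worked example of the bookkeeping in the hiera.yaml comments above, here is a short sketch using the toy config values (patch_size=2, num_patches=8, num_mask_units=4, q_stride=2); the arithmetic is illustrative only and mirrors how MaskUnitAttention derives pooled_patches_per_mask_unit.

import numpy as np

patch_size, num_patches, num_mask_units, q_stride = 2, 8, 4, 2

img_shape = patch_size * num_patches  # 16 pixels per spatial dim
mask_unit_pixels = img_shape // num_mask_units  # 4 pixels per mask unit
patches_per_mask_unit = num_patches // num_mask_units  # 2 patches per mask unit along each dim
# pooling queries by q_stride within each mask unit leaves 1 patch per dim
pooled_patches_per_mask_unit = (np.array([patches_per_mask_unit] * 3) // q_stride).astype(int)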
- num_heads : int, optional - Number of attention heads, by default 8. - qkv_bias : bool, optional - If True, add a learnable bias to query, key, value, by default False. - qk_scale : float, optional - Override default qk scale of head_dim ** -0.5 if set, by default None. - attn_drop : float, optional - Dropout rate for attention, by default 0.0. - proj_drop : float, optional - Dropout rate for projection, by default 0.0. + mlp_ratio : float, optional + Ratio of MLP hidden dim to embedding dim, by default 4.0. + drop_path : float, optional + Dropout rate for the path, by default 0.0. + norm_layer : nn.Module, optional + Normalization layer, by default nn.LayerNorm. + act_layer : nn.Module, optional + Activation layer for the MLP, by default nn.GELU. q_stride : List[int], optional Stride for query, by default [1, 1, 1]. patches_per_mask_unit : List[int], optional Number of patches per mask unit, by default [2, 12, 12]. """ super().__init__() - patches_per_mask_unit, q_stride = validate_spatial_dims( + patches_per_mask_unit, q_stride = match_tuple_dimensions( spatial_dims, [patches_per_mask_unit, q_stride] ) @@ -189,7 +185,7 @@ def __init__( self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() - # max pooling by q stride within a mask unit + # mean pooling by q stride within a mask unit if self.spatial_dims == 3: skip_connection_pooling = Reduce( "b n (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_z n_patches_y n_patches_x) c", @@ -223,6 +219,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """ # Attention + Q Pooling x_norm = self.norm1(x) + # change dimension and subsample within mask unit for skip connection x = self.proj(x_norm) diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py index 9c8166860..b9b7471e2 100644 --- a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py @@ -30,7 +30,8 @@ def take_indexes_mask(sequences, indexes): class PatchifyHiera(PatchifyBase): - """Class for converting images to a masked sequence of patches with positional embeddings.""" + """Class for converting images to a sequence of patches with positional embeddings, masked at + the level of mask units (groups of patches specified by mask_units_per_dim).""" def __init__( self, diff --git a/cyto_dl/nn/vits/encoder.py b/cyto_dl/nn/vits/encoder.py index 570b704ba..1c579c9d8 100644 --- a/cyto_dl/nn/vits/encoder.py +++ b/cyto_dl/nn/vits/encoder.py @@ -15,7 +15,7 @@ from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify from cyto_dl.nn.vits.blocks.masked_unit_attention import HieraBlock from cyto_dl.nn.vits.blocks.patchify import PatchifyHiera -from cyto_dl.nn.vits.utils import validate_spatial_dims +from cyto_dl.nn.vits.utils import match_tuple_dimensions class MAE_Encoder(torch.nn.Module): @@ -54,7 +54,7 @@ def __init__( Whether to use intermediate weights for weighted sum of intermediate layers """ super().__init__() - num_patches, patch_size, context_pixels = validate_spatial_dims( + num_patches, patch_size, context_pixels = match_tuple_dimensions( spatial_dims, [num_patches, patch_size, context_pixels] ) @@ -140,7 +140,7 @@ def __init__( If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
""" super().__init__() - num_patches, patch_size, context_pixels = validate_spatial_dims( + num_patches, patch_size, context_pixels = match_tuple_dimensions( spatial_dims, [num_patches, patch_size, context_pixels] ) @@ -178,7 +178,7 @@ def __init__( self, downsample_factor: List[int], in_dim: int, out_dim: int, spatial_dims: int = 3 ): super().__init__() - downsample_factor = validate_spatial_dims(spatial_dims, [downsample_factor])[0] + downsample_factor = match_tuple_dimensions(spatial_dims, [downsample_factor])[0] self.spatial_dims = spatial_dims conv_fn = nn.Conv3d if spatial_dims == 3 else nn.Conv2d @@ -245,6 +245,7 @@ def __init__( Stride for the query in each spatial dimension - self_attention: bool Whether to use self attention or mask unit attention + On the last repeat of each non-self-attention block, the embedding dimension is doubled and spatial pooling with `q_stride` is performed within each mask unit. For example, a block with a embed_dim=4, q_stride=2, and repeat=2, the first repeat just does mask unit attention, while the second will produce an 8-dimensional output that has been spatially pooled. emb_dim: int Dimension of embedding spatial_dims: int @@ -259,13 +260,13 @@ def __init__( Whether to save the intermediate layer outputs """ super().__init__() - num_patches, num_mask_units, patch_size, context_pixels = validate_spatial_dims( + num_patches, num_mask_units, patch_size, context_pixels = match_tuple_dimensions( spatial_dims, [num_patches, num_mask_units, patch_size, context_pixels] ) # make sure q stride shape matches spatial dims for i in range(len(architecture)): if "q_stride" in architecture[i]: - architecture[i]["q_stride"] = validate_spatial_dims( + architecture[i]["q_stride"] = match_tuple_dimensions( spatial_dims, [architecture[i]["q_stride"]] )[0] diff --git a/cyto_dl/nn/vits/utils.py b/cyto_dl/nn/vits/utils.py index caeb20237..aa96557fe 100644 --- a/cyto_dl/nn/vits/utils.py +++ b/cyto_dl/nn/vits/utils.py @@ -49,6 +49,10 @@ def get_positional_embedding( return torch.nn.Parameter(pe, requires_grad=False) -def validate_spatial_dims(spatial_dims, tuples): +def match_tuple_dimensions(spatial_dims, tuples): + """Ensure that each element in a list of tuples has the same length as spatial_dims. + + If a single element, the element is repeated to match the spatial_dims. 
+ """ assert spatial_dims in (2, 3), "spatial_dims must be 2 or 3" return [ensure_tuple_rep(t, spatial_dims) for t in tuples] From 316a93a7e540d62ba1660ad8b8b857bf9e5f97d5 Mon Sep 17 00:00:00 2001 From: Benjamin Morris Date: Fri, 23 Aug 2024 11:58:36 -0700 Subject: [PATCH 27/27] replace all function names --- cyto_dl/image/transforms/generate_jepa_masks.py | 4 ++-- cyto_dl/nn/vits/decoder.py | 4 ++-- cyto_dl/nn/vits/mae.py | 6 +++--- cyto_dl/nn/vits/predictor.py | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/cyto_dl/image/transforms/generate_jepa_masks.py b/cyto_dl/image/transforms/generate_jepa_masks.py index 445587a25..da8201fba 100644 --- a/cyto_dl/image/transforms/generate_jepa_masks.py +++ b/cyto_dl/image/transforms/generate_jepa_masks.py @@ -5,7 +5,7 @@ from monai.transforms import RandomizableTransform from skimage.segmentation import find_boundaries -from cyto_dl.nn.vits.utils import validate_spatial_dims +from cyto_dl.nn.vits.utils import match_tuple_dimensions class JEPAMaskGenerator(RandomizableTransform): @@ -39,7 +39,7 @@ def __init__( """ assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1" - num_patches = validate_spatial_dims(spatial_dims, [num_patches])[0] + num_patches = match_tuple_dimensions(spatial_dims, [num_patches])[0] assert mask_size * max(block_aspect_ratio) < min( num_patches[-2:] ), "mask_size * max mask aspect ratio must be less than the smallest dimension of num_patches" diff --git a/cyto_dl/nn/vits/decoder.py b/cyto_dl/nn/vits/decoder.py index fe296121b..0ef96cda0 100644 --- a/cyto_dl/nn/vits/decoder.py +++ b/cyto_dl/nn/vits/decoder.py @@ -12,8 +12,8 @@ from cyto_dl.nn.vits.blocks import CrossAttentionBlock from cyto_dl.nn.vits.utils import ( get_positional_embedding, + match_tuple_dimensions, take_indexes, - validate_spatial_dims, ) @@ -51,7 +51,7 @@ def __init__( If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
""" super().__init__() - num_patches, patch_size = validate_spatial_dims(spatial_dims, [num_patches, patch_size]) + num_patches, patch_size = match_tuple_dimensions(spatial_dims, [num_patches, patch_size]) self.has_cls_token = has_cls_token diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py index ab792e2ff..17d3e4729 100644 --- a/cyto_dl/nn/vits/mae.py +++ b/cyto_dl/nn/vits/mae.py @@ -8,7 +8,7 @@ from cyto_dl.nn.vits.decoder import CrossMAE_Decoder, MAE_Decoder from cyto_dl.nn.vits.encoder import HieraEncoder, MAE_Encoder -from cyto_dl.nn.vits.utils import validate_spatial_dims +from cyto_dl.nn.vits.utils import match_tuple_dimensions class MAE_Base(torch.nn.Module, ABC): @@ -16,7 +16,7 @@ def __init__( self, spatial_dims, num_patches, patch_size, mask_ratio, features_only, context_pixels ): super().__init__() - num_patches, patch_size, context_pixels = validate_spatial_dims( + num_patches, patch_size, context_pixels = match_tuple_dimensions( spatial_dims, [num_patches, patch_size, context_pixels] ) @@ -213,7 +213,7 @@ def __init__( features_only=features_only, context_pixels=context_pixels, ) - num_mask_units = validate_spatial_dims(self.spatial_dims, [num_mask_units])[0] + num_mask_units = match_tuple_dimensions(self.spatial_dims, [num_mask_units])[0] self._encoder = HieraEncoder( num_patches=self.num_patches, diff --git a/cyto_dl/nn/vits/predictor.py b/cyto_dl/nn/vits/predictor.py index 476901b72..8eb5978fb 100644 --- a/cyto_dl/nn/vits/predictor.py +++ b/cyto_dl/nn/vits/predictor.py @@ -8,8 +8,8 @@ from cyto_dl.nn.vits.blocks import CrossAttentionBlock from cyto_dl.nn.vits.utils import ( get_positional_embedding, + match_tuple_dimensions, take_indexes, - validate_spatial_dims, ) @@ -56,7 +56,7 @@ def __init__( ] ) - num_patches = validate_spatial_dims(spatial_dims, [num_patches])[0] + num_patches = match_tuple_dimensions(spatial_dims, [num_patches])[0] self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) self.pos_embedding = get_positional_embedding(