From 48e7cb91579dd196284f4cc07e9ee4819878f932 Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Mon, 26 Aug 2024 10:42:46 -0700 Subject: [PATCH] Feature/hiera (#418) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Bump version: 0.1.5 → 0.1.6 * add hiera * start of mask2former * first take at transfomer * fix dimensionality, now updating instance queries instead of mask * give instance queries own dim * add mask creation * remove experimental code * update to base patchify * wip * update configs * update patchify * rearrange encoder/decoder/mae * add 2d hiera * 2d masked unit attention * precommit * update configs * update hiera model config * update deafults * update tests * delete patchify_conv * fix jepa tests * update mask transform * add predictor * precommit * update with ritviks comments * replace all function names --------- Co-authored-by: Benjamin Morris Co-authored-by: Benjamin Morris --- configs/data/im2im/ijepa.yaml | 3 + configs/data/im2im/iwm.yaml | 2 + configs/data/im2im/mae.yaml | 19 +- configs/experiment/im2im/hiera.yaml | 52 +++ configs/experiment/im2im/ijepa.yaml | 7 - configs/experiment/im2im/iwm.yaml | 7 - configs/model/im2im/hiera.yaml | 70 ++++ configs/model/im2im/ijepa.yaml | 9 +- configs/model/im2im/iwm.yaml | 9 +- configs/model/im2im/mae.yaml | 6 +- .../model/im2im/vit_segmentation_decoder.yaml | 4 +- .../image/transforms/generate_jepa_masks.py | 11 +- cyto_dl/nn/vits/__init__.py | 5 +- .../nn/vits/blocks/masked_unit_attention.py | 229 ++++++++++ cyto_dl/nn/vits/blocks/patchify/__init__.py | 2 + cyto_dl/nn/vits/blocks/patchify/patchify.py | 52 +++ .../patchify_base.py} | 108 +++-- .../nn/vits/blocks/patchify/patchify_hiera.py | 125 ++++++ cyto_dl/nn/vits/cross_mae.py | 170 -------- cyto_dl/nn/vits/decoder.py | 261 ++++++++++++ cyto_dl/nn/vits/encoder.py | 383 +++++++++++++++++ cyto_dl/nn/vits/mae.py | 394 ++++++++---------- cyto_dl/nn/vits/{jepa.py => predictor.py} | 106 ++--- cyto_dl/nn/vits/utils.py | 16 + tests/conftest.py | 11 +- tests/utils.py | 2 +- 26 files changed, 1508 insertions(+), 555 deletions(-) create mode 100644 configs/experiment/im2im/hiera.yaml create mode 100644 configs/model/im2im/hiera.yaml create mode 100644 cyto_dl/nn/vits/blocks/masked_unit_attention.py create mode 100644 cyto_dl/nn/vits/blocks/patchify/__init__.py create mode 100644 cyto_dl/nn/vits/blocks/patchify/patchify.py rename cyto_dl/nn/vits/blocks/{patchify.py => patchify/patchify_base.py} (72%) create mode 100644 cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py delete mode 100644 cyto_dl/nn/vits/cross_mae.py create mode 100644 cyto_dl/nn/vits/decoder.py create mode 100644 cyto_dl/nn/vits/encoder.py rename cyto_dl/nn/vits/{jepa.py => predictor.py} (64%) diff --git a/configs/data/im2im/ijepa.yaml b/configs/data/im2im/ijepa.yaml index 4c1131daa..add94c074 100644 --- a/configs/data/im2im/ijepa.yaml +++ b/configs/data/im2im/ijepa.yaml @@ -39,6 +39,7 @@ transforms: - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator mask_size: 4 num_patches: ${model._aux.num_patches} + spatial_dims: ${spatial_dims} test: _target_: monai.transforms.Compose @@ -69,6 +70,7 @@ transforms: - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator mask_size: 4 num_patches: ${model._aux.num_patches} + spatial_dims: ${spatial_dims} predict: _target_: monai.transforms.Compose @@ -127,6 +129,7 @@ transforms: - _target_: 
cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator mask_size: 4 num_patches: ${model._aux.num_patches} + spatial_dims: ${spatial_dims} _aux: _scales_dict: diff --git a/configs/data/im2im/iwm.yaml b/configs/data/im2im/iwm.yaml index 526920ba2..9042f6652 100644 --- a/configs/data/im2im/iwm.yaml +++ b/configs/data/im2im/iwm.yaml @@ -58,6 +58,7 @@ transforms: - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator mask_size: 4 num_patches: ${model._aux.num_patches} + spatial_dims: ${spatial_dims} test: _target_: monai.transforms.Compose @@ -176,5 +177,6 @@ transforms: - _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator mask_size: 4 num_patches: ${model._aux.num_patches} + spatial_dims: ${spatial_dims} _aux: diff --git a/configs/data/im2im/mae.yaml b/configs/data/im2im/mae.yaml index a310e15ba..6cd6a946e 100644 --- a/configs/data/im2im/mae.yaml +++ b/configs/data/im2im/mae.yaml @@ -32,12 +32,12 @@ transforms: - _target_: monai.transforms.NormalizeIntensityd keys: ${source_col} channel_wise: True - - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + - _target_: monai.transforms.RandSpatialCropSamplesd keys: - ${source_col} - patch_shape: ${data._aux.patch_shape} - patch_per_image: 1 - scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + roi_size: ${data._aux.patch_shape} + num_samples: 1 + random_size: False test: _target_: monai.transforms.Compose @@ -104,14 +104,11 @@ transforms: - _target_: monai.transforms.NormalizeIntensityd keys: ${source_col} channel_wise: True - - _target_: cyto_dl.image.transforms.RandomMultiScaleCropd + - _target_: monai.transforms.RandSpatialCropSamplesd keys: - ${source_col} - patch_shape: ${data._aux.patch_shape} - patch_per_image: 1 - scales_dict: ${kv_to_dict:${data._aux._scales_dict}} + roi_size: ${data._aux.patch_shape} + num_samples: 1 + random_size: False _aux: - _scales_dict: - - - ${source_col} - - [1] diff --git a/configs/experiment/im2im/hiera.yaml b/configs/experiment/im2im/hiera.yaml new file mode 100644 index 000000000..f1351e492 --- /dev/null +++ b/configs/experiment/im2im/hiera.yaml @@ -0,0 +1,52 @@ +# @package _global_ +# to execute this experiment run: +# python train.py experiment=example +defaults: + - override /data: im2im/mae.yaml + - override /model: im2im/hiera.yaml + - override /callbacks: default.yaml + - override /trainer: gpu.yaml + - override /logger: csv.yaml + +# all parameters below will be merged with parameters from default configurations set above +# this allows you to overwrite only specified parameters + +tags: ["dev"] +seed: 12345 + +experiment_name: YOUR_EXP_NAME +run_name: YOUR_RUN_NAME + +# only source_col is needed for masked autoencoder +source_col: raw +spatial_dims: 3 +raw_im_channels: 1 + +trainer: + max_epochs: 100 + gradient_clip_val: 10 + +data: + path: ${paths.data_dir}/example_experiment_data/segmentation + cache_dir: ${paths.data_dir}/example_experiment_data/cache + batch_size: 1 + _aux: + # 2D + # patch_shape: [16, 16] + # 3D + patch_shape: [16, 16, 16] + +callbacks: + # prediction + # saving: + # _target_: cyto_dl.callbacks.ImageSaver + # save_dir: ${paths.output_dir} + # save_every_n_epochs: ${model.save_images_every_n_epochs} + # stages: ["predict"] + # save_input: False + # training + saving: + _target_: cyto_dl.callbacks.ImageSaver + save_dir: ${paths.output_dir} + save_every_n_epochs: ${model.save_images_every_n_epochs} + stages: ["train", "test", "val"] diff --git a/configs/experiment/im2im/ijepa.yaml b/configs/experiment/im2im/ijepa.yaml index 
df5529d36..d0af7cd6f 100644 --- a/configs/experiment/im2im/ijepa.yaml +++ b/configs/experiment/im2im/ijepa.yaml @@ -37,10 +37,3 @@ data: # patch_shape: [16, 16] # 3D patch_shape: [16, 16, 16] - -model: - _aux: - # 3D - num_patches: [8, 8, 8] - # 2d - # num_patches: [8, 8] diff --git a/configs/experiment/im2im/iwm.yaml b/configs/experiment/im2im/iwm.yaml index e097933b0..3cee1f777 100644 --- a/configs/experiment/im2im/iwm.yaml +++ b/configs/experiment/im2im/iwm.yaml @@ -38,13 +38,6 @@ data: # 3D patch_shape: [16, 16, 16] -model: - _aux: - # 3D - num_patches: [8, 8, 8] - # 2d - # num_patches: [8, 8] - callbacks: prediction_saver: _target_: cyto_dl.callbacks.csv_saver.CSVSaver diff --git a/configs/model/im2im/hiera.yaml b/configs/model/im2im/hiera.yaml new file mode 100644 index 000000000..88a10b127 --- /dev/null +++ b/configs/model/im2im/hiera.yaml @@ -0,0 +1,70 @@ +_target_: cyto_dl.models.im2im.MultiTaskIm2Im + +save_images_every_n_epochs: 1 +save_dir: ${paths.output_dir} + +x_key: ${source_col} + +backbone: + _target_: cyto_dl.nn.vits.mae.HieraMAE + spatial_dims: ${spatial_dims} + patch_size: 2 # patch_size * num_patches should be your image shape (data._aux.patch_shape) + num_patches: 8 # patch_size * num_patches = img_shape + num_mask_units: 4 # Mask units are used for local attention. img_shape / num_mask_units = size of each mask unit in pixels, num_patches / num_mask_units = number of patches per mask unit + emb_dim: 4 + # NOTE: this is a very small model for testing - for best performance, the downsampling ratios, embedding dimension, number of layers and number of heads should be adjusted to your data + architecture: + # mask_unit_attention blocks - attention is only done within a mask unit and not across mask units + # the product of q_strides across the architecture must evenly divide the number of patches per mask unit + - repeat: 1 # number of times to repeat this block + q_stride: 2 # size of downsampling within a mask unit + num_heads: 1 + - repeat: 1 + q_stride: 1 + num_heads: 2 + # self attention transformer - attention is done across all patches, irrespective of which mask unit they're in + - repeat: 2 + num_heads: 4 + self_attention: True + decoder_layer: 1 + decoder_dim: 16 + mask_ratio: 0.66666666666 + context_pixels: 3 + use_crossmae: True + +task_heads: ${kv_to_dict:${model._aux._tasks}} + +optimizer: + generator: + _partial_: True + _target_: torch.optim.AdamW + weight_decay: 0.05 + +lr_scheduler: + generator: + _partial_: True + _target_: torch.optim.lr_scheduler.OneCycleLR + max_lr: 0.0001 + epochs: ${trainer.max_epochs} + steps_per_epoch: 1 + pct_start: 0.1 + +inference_args: + sw_batch_size: 1 + roi_size: ${data._aux.patch_shape} + overlap: 0 + progress: True + mode: "gaussian" + +_aux: + _tasks: + - - ${source_col} + - _target_: cyto_dl.nn.head.mae_head.MAEHead + loss: + postprocess: + input: + _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel + rescale_dtype: numpy.uint8 + prediction: + _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel + rescale_dtype: numpy.uint8 diff --git a/configs/model/im2im/ijepa.yaml b/configs/model/im2im/ijepa.yaml index 1e11a57d3..40ea24201 100644 --- a/configs/model/im2im/ijepa.yaml +++ b/configs/model/im2im/ijepa.yaml @@ -5,15 +5,16 @@ max_epochs: ${trainer.max_epochs} save_dir: ${paths.output_dir} encoder: - _target_: cyto_dl.nn.vits.jepa.JEPAEncoder - patch_size: 2 + _target_: cyto_dl.nn.vits.encoder.JEPAEncoder + patch_size: 2 # patch_size * num_patches should equal data._aux.patch_shape
num_patches: ${model._aux.num_patches} emb_dim: 16 num_layer: 2 num_head: 1 + spatial_dims: ${spatial_dims} predictor: - _target_: cyto_dl.nn.vits.jepa.JEPAPredictor + _target_: cyto_dl.nn.vits.predictor.JEPAPredictor num_patches: ${model._aux.num_patches} input_dim: ${model.encoder.emb_dim} emb_dim: 8 @@ -34,4 +35,4 @@ lr_scheduler: pct_start: 0.1 _aux: - num_patches: + num_patches: 8 diff --git a/configs/model/im2im/iwm.yaml b/configs/model/im2im/iwm.yaml index 21357ed3e..66ccb7f68 100644 --- a/configs/model/im2im/iwm.yaml +++ b/configs/model/im2im/iwm.yaml @@ -8,15 +8,16 @@ max_epochs: ${trainer.max_epochs} save_dir: ${paths.output_dir} encoder: - _target_: cyto_dl.nn.vits.jepa.JEPAEncoder - patch_size: 2 + _target_: cyto_dl.nn.vits.encoder.JEPAEncoder + patch_size: 2 # patch_size * num_patches should be the same as data._aux.patch_shape num_patches: ${model._aux.num_patches} emb_dim: 16 num_layer: 1 num_head: 1 + spatial_dims: ${spatial_dims} predictor: - _target_: cyto_dl.nn.vits.jepa.IWMPredictor + _target_: cyto_dl.nn.vits.predictor.IWMPredictor domains: [SEC61B] num_patches: ${model._aux.num_patches} input_dim: ${model.encoder.emb_dim} @@ -38,4 +39,4 @@ lr_scheduler: pct_start: 0.1 _aux: - num_patches: + num_patches: 8 diff --git a/configs/model/im2im/mae.yaml b/configs/model/im2im/mae.yaml index 615f8c3c4..b1172f655 100644 --- a/configs/model/im2im/mae.yaml +++ b/configs/model/im2im/mae.yaml @@ -6,10 +6,10 @@ save_dir: ${paths.output_dir} x_key: ${source_col} backbone: - _target_: cyto_dl.nn.vits.MAE_ViT + _target_: cyto_dl.nn.vits.MAE spatial_dims: ${spatial_dims} - # base_patch_size* num_patches should be your patch shape - base_patch_size: 2 + # patch_size* num_patches should be your patch shape + patch_size: 2 num_patches: 8 emb_dim: 16 encoder_layer: 2 diff --git a/configs/model/im2im/vit_segmentation_decoder.yaml b/configs/model/im2im/vit_segmentation_decoder.yaml index 6c1ad137b..a0c72bd69 100644 --- a/configs/model/im2im/vit_segmentation_decoder.yaml +++ b/configs/model/im2im/vit_segmentation_decoder.yaml @@ -8,8 +8,8 @@ x_key: ${source_col} backbone: _target_: cyto_dl.nn.vits.Seg_ViT spatial_dims: ${spatial_dims} - # base_patch_size* num_patches should be your patch shape - base_patch_size: 2 + # patch_size* num_patches should be your patch shape + patch_size: 2 num_patches: 8 emb_dim: 16 encoder_layer: 2 diff --git a/cyto_dl/image/transforms/generate_jepa_masks.py b/cyto_dl/image/transforms/generate_jepa_masks.py index 0a4dae69e..da8201fba 100644 --- a/cyto_dl/image/transforms/generate_jepa_masks.py +++ b/cyto_dl/image/transforms/generate_jepa_masks.py @@ -5,6 +5,8 @@ from monai.transforms import RandomizableTransform from skimage.segmentation import find_boundaries +from cyto_dl.nn.vits.utils import match_tuple_dimensions + class JEPAMaskGenerator(RandomizableTransform): """Transform for generating Block-contiguous masks for JEPA training. @@ -15,6 +17,7 @@ class JEPAMaskGenerator(RandomizableTransform): def __init__( self, + spatial_dims: int, mask_size: int = 12, block_aspect_ratio: Tuple[float] = (0.5, 1.5), num_patches: Tuple[float] = (6, 24, 24), @@ -23,6 +26,8 @@ def __init__( """ Parameters ---------- + spatial_dims : int + The number of spatial dimensions of the image (2 or 3) mask_size : int, optional The size of the blocks used to generate mask. 
Block dimensions are determined by the mask size and an aspect ratio sampled from the range `block_aspect_ratio` block_aspect_ratio : Tuple[float], optional @@ -32,7 +37,9 @@ def __init__( mask_ratio : float, optional The proportion of the image to be masked """ - assert mask_ratio < 1, "mask_ratio must be less than 1" + assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1" + + num_patches = match_tuple_dimensions(spatial_dims, [num_patches])[0] assert mask_size * max(block_aspect_ratio) < min( num_patches[-2:] ), "mask_size * max mask aspect ratio must be less than the smallest dimension of num_patches" @@ -46,7 +53,7 @@ def __init__( self.mask = np.zeros(num_patches) self.edge_mask = np.ones(num_patches) - self.spatial_dims = len(num_patches) + self.spatial_dims = spatial_dims # create a mask that identified pixels on the edge of the image if self.spatial_dims == 3: self.edge_mask[1:-1, 1:-1, 1:-1] = 0 diff --git a/cyto_dl/nn/vits/__init__.py b/cyto_dl/nn/vits/__init__.py index ee5076c52..5207f6f9d 100644 --- a/cyto_dl/nn/vits/__init__.py +++ b/cyto_dl/nn/vits/__init__.py @@ -1,4 +1,5 @@ -from .cross_mae import CrossMAE_Decoder -from .mae import MAE_Decoder, MAE_Encoder, MAE_ViT +from .decoder import CrossMAE_Decoder, MAE_Decoder +from .encoder import HieraEncoder, MAE_Encoder +from .mae import MAE, HieraMAE from .seg import Seg_ViT, SuperresDecoder from .utils import take_indexes diff --git a/cyto_dl/nn/vits/blocks/masked_unit_attention.py b/cyto_dl/nn/vits/blocks/masked_unit_attention.py new file mode 100644 index 000000000..ad1a02a99 --- /dev/null +++ b/cyto_dl/nn/vits/blocks/masked_unit_attention.py @@ -0,0 +1,229 @@ +# inspired by https://github.com/facebookresearch/hiera/tree/main +from typing import List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, reduce +from einops.layers.torch import Reduce +from timm.models.layers import DropPath, Mlp + +from cyto_dl.nn.vits.utils import match_tuple_dimensions + + +class MaskUnitAttention(torch.nn.Module): + def __init__( + self, + dim, + dim_out, + spatial_dims: int = 3, + num_heads=8, + qkv_bias=False, + attn_drop=0.0, + proj_drop=0.0, + q_stride=[1, 1, 1], + patches_per_mask_unit=[2, 12, 12], + ): + """ + Parameters + ---------- + dim : int + Dimension of the input features. + dim_out : int + Dimension of the output features. + spatial_dims : int, optional + Number of spatial dimensions, by default 3. + num_heads : int, optional + Number of attention heads, by default 8. + qkv_bias : bool, optional + If True, add a learnable bias to query, key, value, by default False. + attn_drop : float, optional + Dropout rate for attention, by default 0.0. + proj_drop : float, optional + Dropout rate for projection, by default 0.0. + q_stride : List[int], optional + Stride for query, by default [1, 1, 1]. + patches_per_mask_unit : List[int], optional + Number of patches per mask unit, by default [2, 12, 12]. 
+ """ + super().__init__() + q_stride, patches_per_mask_unit = match_tuple_dimensions( + spatial_dims, [q_stride, patches_per_mask_unit] + ) + + self.spatial_dims = spatial_dims + self.num_heads = num_heads + self.head_dim = dim_out // num_heads + self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias) + self.attn_drop = attn_drop + self.proj = nn.Linear(dim_out, dim_out) + self.proj_drop = nn.Dropout(proj_drop) + self.dim_out = dim_out + self.q_stride = np.array(q_stride) + self.pooled_patches_per_mask_unit = ( + np.array(patches_per_mask_unit) / self.q_stride + ).astype(int) + + def forward(self, x): + # project and split into q,k,v embeddings + qkv = rearrange( + self.qkv(x), + "batch num_mask_units tokens_per_mask_unit (head_dim num_heads qkv) -> qkv batch num_mask_units num_heads tokens_per_mask_unit head_dim", + head_dim=self.head_dim, + qkv=3, + num_heads=self.num_heads, + ) + + q, k, v = qkv[0], qkv[1], qkv[2] + + if np.any(self.q_stride > 1): + # within a mask unit, tokens are spatially ordered + # perform spatial 2x2x2 max pooling over tokens + if self.spatial_dims == 3: + q = reduce( + q, + "batch num_mask_units num_heads (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> batch num_mask_units num_heads (n_patches_z n_patches_y n_patches_x) c", + reduction="max", + q_stride_z=self.q_stride[0], + q_stride_y=self.q_stride[1], + q_stride_x=self.q_stride[2], + n_patches_z=self.pooled_patches_per_mask_unit[0], + n_patches_y=self.pooled_patches_per_mask_unit[1], + n_patches_x=self.pooled_patches_per_mask_unit[2], + ) + elif self.spatial_dims == 2: + q = reduce( + q, + "batch num_mask_units num_heads (n_patches_y q_stride_y n_patches_x q_stride_x) c ->batch num_mask_units num_heads (n_patches_y n_patches_x) c", + reduction="max", + q_stride_y=self.q_stride[0], + q_stride_x=self.q_stride[1], + n_patches_y=self.pooled_patches_per_mask_unit[0], + n_patches_x=self.pooled_patches_per_mask_unit[1], + ) + + attn = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop, + ) + # combine heads into single channel dimension + x = rearrange( + attn, "b mask_units n_heads t c -> b mask_units t (n_heads c)", n_heads=self.num_heads + ) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class HieraBlock(nn.Module): + def __init__( + self, + dim: int, + dim_out: int, + heads: int, + spatial_dims: int = 3, + mlp_ratio: float = 4.0, + drop_path: float = 0.0, + norm_layer: nn.Module = nn.LayerNorm, + act_layer: nn.Module = nn.GELU, + q_stride: List[int] = [1, 1, 1], + patches_per_mask_unit: List[int] = [2, 12, 12], + ): + """ + Parameters + ---------- + dim : int + Dimension of the input features. + dim_out : int + Dimension of the output features. + heads : int + Number of attention heads. + spatial_dims : int, optional + Number of spatial dimensions, by default 3. + mlp_ratio : float, optional + Ratio of MLP hidden dim to embedding dim, by default 4.0. + drop_path : float, optional + Dropout rate for the path, by default 0.0. + norm_layer : nn.Module, optional + Normalization layer, by default nn.LayerNorm. + act_layer : nn.Module, optional + Activation layer for the MLP, by default nn.GELU. + q_stride : List[int], optional + Stride for query, by default [1, 1, 1]. + patches_per_mask_unit : List[int], optional + Number of patches per mask unit, by default [2, 12, 12]. 
+ """ + super().__init__() + patches_per_mask_unit, q_stride = match_tuple_dimensions( + spatial_dims, [patches_per_mask_unit, q_stride] + ) + + self.spatial_dims = spatial_dims + self.dim = dim + self.dim_out = dim_out + self.q_stride = q_stride + + self.norm1 = norm_layer(dim) + + do_pool = np.any(np.array(q_stride) > 1) or dim != dim_out + + self.attn = MaskUnitAttention( + dim, + dim_out, + spatial_dims=spatial_dims, + num_heads=heads, + q_stride=q_stride, + patches_per_mask_unit=patches_per_mask_unit, + ) + + self.norm2 = norm_layer(dim_out) + self.mlp = Mlp(dim_out, int(dim_out * mlp_ratio), act_layer=act_layer) + + self.drop_path = DropPath(drop_path) if drop_path > 0 else nn.Identity() + + # mean pooling by q stride within a mask unit + if self.spatial_dims == 3: + skip_connection_pooling = Reduce( + "b n (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_z n_patches_y n_patches_x) c", + reduction="mean", + q_stride_z=self.q_stride[0], + q_stride_y=self.q_stride[1], + q_stride_x=self.q_stride[2], + n_patches_z=self.attn.pooled_patches_per_mask_unit[0], + n_patches_y=self.attn.pooled_patches_per_mask_unit[1], + n_patches_x=self.attn.pooled_patches_per_mask_unit[2], + ) + elif self.spatial_dims == 2: + skip_connection_pooling = Reduce( + "b n (n_patches_y q_stride_y n_patches_x q_stride_x) c -> b n (n_patches_y n_patches_x) c", + reduction="mean", + q_stride_y=self.q_stride[0], + q_stride_x=self.q_stride[1], + n_patches_y=self.attn.pooled_patches_per_mask_unit[0], + n_patches_x=self.attn.pooled_patches_per_mask_unit[1], + ) + + self.proj = ( + torch.nn.Sequential(skip_connection_pooling, nn.Linear(dim, dim_out)) + if do_pool + else torch.nn.Identity() + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + x: batch x mask units x tokens x emb_dim + """ + # Attention + Q Pooling + x_norm = self.norm1(x) + + # change dimension and subsample within mask unit for skip connection + x = self.proj(x_norm) + + x = x + self.drop_path(self.attn(x_norm)) + # MLP + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x diff --git a/cyto_dl/nn/vits/blocks/patchify/__init__.py b/cyto_dl/nn/vits/blocks/patchify/__init__.py new file mode 100644 index 000000000..02b79746c --- /dev/null +++ b/cyto_dl/nn/vits/blocks/patchify/__init__.py @@ -0,0 +1,2 @@ +from .patchify import Patchify +from .patchify_hiera import PatchifyHiera diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify.py b/cyto_dl/nn/vits/blocks/patchify/patchify.py new file mode 100644 index 000000000..4da4e788e --- /dev/null +++ b/cyto_dl/nn/vits/blocks/patchify/patchify.py @@ -0,0 +1,52 @@ +from typing import List, Optional + +import numpy as np +from einops.layers.torch import Rearrange + +from cyto_dl.nn.vits.blocks.patchify.patchify_base import PatchifyBase +from cyto_dl.nn.vits.utils import take_indexes + + +class Patchify(PatchifyBase): + """Class for converting images to a masked sequence of patches with positional embeddings.""" + + def __init__( + self, + patch_size: List[int], + emb_dim: int, + n_patches: List[int], + spatial_dims: int = 3, + context_pixels: List[int] = [0, 0, 0], + input_channels: int = 1, + tasks: Optional[List[str]] = [], + learnable_pos_embedding: bool = True, + ): + super().__init__( + patch_size=patch_size, + emb_dim=emb_dim, + n_patches=n_patches, + spatial_dims=spatial_dims, + context_pixels=context_pixels, + input_channels=input_channels, + tasks=tasks, + learnable_pos_embedding=learnable_pos_embedding, + ) + + @property + def img2token(self): + 
return self.create_img2token() + + def get_mask_args(self, mask_ratio): + num_patches = np.prod(self.n_patches) + n_visible_patches = int(num_patches * (1 - mask_ratio)) + return n_visible_patches, num_patches + + def create_img2token(self): + """Rearranges the image tensor to a sequence of patches.""" + if self.spatial_dims == 3: + return Rearrange("b c z y x -> (z y x) b c") + elif self.spatial_dims == 2: + return Rearrange("b c y x -> (y x) b c") + + def extract_visible_tokens(self, tokens, forward_indexes, n_visible_patches): + return take_indexes(tokens, forward_indexes)[:n_visible_patches] diff --git a/cyto_dl/nn/vits/blocks/patchify.py b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py similarity index 72% rename from cyto_dl/nn/vits/blocks/patchify.py rename to cyto_dl/nn/vits/blocks/patchify/patchify_base.py index 98d615924..1b2573084 100644 --- a/cyto_dl/nn/vits/blocks/patchify.py +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_base.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from typing import List, Optional import numpy as np @@ -6,16 +7,10 @@ from einops.layers.torch import Rearrange, Reduce from timm.models.layers import trunc_normal_ -from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes +from cyto_dl.nn.vits.utils import get_positional_embedding, random_indexes, take_indexes -def random_indexes(size: int, device): - forward_indexes = torch.randperm(size, device=device, dtype=torch.long) - backward_indexes = torch.argsort(forward_indexes) - return forward_indexes, backward_indexes - - -class Patchify(torch.nn.Module): +class PatchifyBase(torch.nn.Module, ABC): """Class for converting images to a masked sequence of patches with positional embeddings.""" def __init__( @@ -50,33 +45,76 @@ def __init__( If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
""" super().__init__() - self.n_patches = np.asarray(n_patches) + + if spatial_dims not in (2, 3): + raise ValueError("Only 2D and 3D images are supported") self.spatial_dims = spatial_dims + self.n_patches = np.asarray(n_patches) self.pos_embedding = get_positional_embedding( n_patches, emb_dim, learnable=learnable_pos_embedding, use_cls_token=False ) - context_pixels = context_pixels[:spatial_dims] + self.patch2img = self.create_patch2img(n_patches, patch_size) + self.conv = self.create_conv(input_channels, emb_dim, patch_size, context_pixels) + + self.task_embedding = torch.nn.ParameterDict( + {task: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for task in tasks} + ) + self._init_weight() + + def _init_weight(self): + for task in self.task_embedding: + trunc_normal_(self.task_embedding[task], std=0.02) + + @property + @abstractmethod + def img2token(self): + pass + + @abstractmethod + def get_mask_args(self): + pass + + @abstractmethod + def extract_visible_tokens(self): + pass + + def create_conv(self, input_channels, emb_dim, patch_size, context_pixels): + context_pixels = context_pixels[: self.spatial_dims] weight_size = np.asarray(patch_size) + np.round(np.array(context_pixels) * 2).astype(int) - if spatial_dims == 3: - self.conv = nn.Conv3d( + if self.spatial_dims == 3: + return nn.Conv3d( in_channels=input_channels, out_channels=emb_dim, kernel_size=weight_size, stride=patch_size, padding=context_pixels, ) - self.img2token = Rearrange("b c z y x -> (z y x) b c") - self.patch2img = torch.nn.Sequential( + elif self.spatial_dims == 2: + return nn.Conv2d( + in_channels=input_channels, + out_channels=emb_dim, + kernel_size=weight_size, + stride=patch_size, + padding=context_pixels, + ) + + def create_patch2img(self, n_patches, patch_size): + """Converts boolean array of whether to keep index of each patch to an image-shaped mask of + same size as input image.""" + if self.spatial_dims == 3: + return torch.nn.Sequential( *[ + # rearrange tokens to image Rearrange( "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x", n_patch_z=n_patches[0], n_patch_y=n_patches[1], n_patch_x=n_patches[2], ), + # nearest neighbor resize image to match input image size Reduce( "b c n_patch_z n_patch_y n_patch_x -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", reduction="repeat", @@ -86,17 +124,8 @@ def __init__( ), ] ) - - elif spatial_dims == 2: - self.conv = nn.Conv2d( - in_channels=input_channels, - out_channels=emb_dim, - kernel_size=weight_size, - stride=patch_size, - padding=context_pixels, - ) - self.img2token = Rearrange("b c y x -> (y x) b c") - self.patch2img = torch.nn.Sequential( + elif self.spatial_dims == 2: + return torch.nn.Sequential( *[ Rearrange( "(n_patch_y n_patch_x) b c -> b c n_patch_y n_patch_x", @@ -104,21 +133,13 @@ def __init__( n_patch_x=n_patches[1], ), Reduce( - "b c n_patch_y n_patch_x -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + "b c n_patch_y n_patch_x -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", reduction="repeat", patch_size_y=patch_size[0], patch_size_x=patch_size[1], ), ] ) - self.task_embedding = torch.nn.ParameterDict( - {task: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for task in tasks} - ) - self._init_weight() - - def _init_weight(self): - for task in self.task_embedding: - trunc_normal_(self.task_embedding[task], std=0.02) def get_mask(self, img, n_visible_patches, num_patches): B = img.shape[0] @@ -140,20 +161,25 @@ def get_mask(self, img, n_visible_patches, 
num_patches): def forward(self, img, mask_ratio, task=None): # generate mask - num_patches = np.prod(self.n_patches) - n_visible_patches = int(num_patches * (1 - mask_ratio)) - mask, forward_indexes, backward_indexes = self.get_mask( - img, n_visible_patches, num_patches - ) + mask = torch.ones_like(img).bool() + forward_indexes, backward_indexes = None, None + if mask_ratio > 0: + n_visible_patches, num_patches = self.get_mask_args(mask_ratio) + mask, forward_indexes, backward_indexes = self.get_mask( + img, n_visible_patches, num_patches + ) # generate patches tokens = self.conv(img * mask) tokens = self.img2token(tokens) + # add position embedding tokens = tokens + self.pos_embedding + + # extract visible patches if mask_ratio > 0: - # extract visible patches - tokens = take_indexes(tokens, forward_indexes)[:n_visible_patches] + tokens = self.extract_visible_tokens(tokens, forward_indexes, n_visible_patches) + # add task embedding if task in self.task_embedding: tokens = tokens + self.task_embedding[task] diff --git a/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py new file mode 100644 index 000000000..b9b7471e2 --- /dev/null +++ b/cyto_dl/nn/vits/blocks/patchify/patchify_hiera.py @@ -0,0 +1,125 @@ +from typing import List, Optional + +import numpy as np +import torch +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_ + +from cyto_dl.nn.vits.blocks.patchify.patchify_base import PatchifyBase + + +def take_indexes_mask(sequences, indexes): + """ + sequences: batch x mask units x patches x emb_dim + indexes: mask_units x batch + """ + # always gather across tokens dimension + return torch.gather( + sequences, + 1, + repeat( + indexes, + "mu b -> b mu p c", + b=sequences.shape[0], + c=sequences.shape[-1], + mu=sequences.shape[1], + p=sequences.shape[2], + ), + ) + + +class PatchifyHiera(PatchifyBase): + """Class for converting images to a sequence of patches with positional embeddings, masked at + the level of mask units (groups of patches specified by mask_units_per_dim).""" + + def __init__( + self, + patch_size: List[int], + n_patches: List[int], + emb_dim: int = 64, + spatial_dims: int = 3, + context_pixels: List[int] = [0, 0, 0], + input_channels: int = 1, + tasks: Optional[List[str]] = [], + mask_units_per_dim: List[int] = [8, 8, 8], + ): + """ + patch_size: List[int] + Size of each patch in pix (ZYX order for 3D, YX order for 2D) + n_patches: List[int] + Number of patches in each spatial dimension (ZYX order for 3D, YX order for 2D) + emb_dim: int + Dimension of encoder + spatial_dims: int + Number of spatial dimensions + context_pixels: List[int] + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. 
+ input_channels: int + Number of input channels + tasks: List[str] + List of tasks to encode + mask_units_per_dim: List[int] + Number of mask units in each spatial dimension (ZYX order for 3D, YX order for 2D) + """ + super().__init__( + patch_size=patch_size, + emb_dim=emb_dim, + n_patches=n_patches, + spatial_dims=spatial_dims, + context_pixels=context_pixels, + input_channels=input_channels, + tasks=tasks, + learnable_pos_embedding=True, + ) + + self.total_n_mask_units = np.prod(mask_units_per_dim) + # mask_unit_size is the img shape / mask_units_per_dim, img_shape is size per patch * n_patches + mask_unit_size_pix = ( + (np.array(patch_size) * np.array(n_patches)) / np.array(mask_units_per_dim) + ).astype(int) + + patches_per_mask_unit = mask_unit_size_pix // patch_size + + # rearrange patch embeddings to mask units + self.pos_embedding = torch.nn.Parameter( + rearrange( + self.pos_embedding, + "(ppmu total_n_mu) 1 emb_dim -> 1 total_n_mu ppmu emb_dim", + total_n_mu=self.total_n_mask_units, + ppmu=patches_per_mask_unit.prod(), + emb_dim=emb_dim, + ) + ) + + self.mask_units_per_dim = mask_units_per_dim + + self.patch2img = self.create_patch2img(mask_units_per_dim, mask_unit_size_pix) + + @property + def img2token(self): + # redefine this to work with mask units instead of patches + return self.create_img2token(self.mask_units_per_dim) + + # in hiera, the masking is done at the mask unit, not the patch level + def get_mask_args(self, mask_ratio): + n_visible_patches = int(self.total_n_mask_units * (1 - mask_ratio)) + return n_visible_patches, self.total_n_mask_units + + def create_img2token(self, mask_units_per_dim): + if self.spatial_dims == 3: + return Rearrange( + "b c (n_mu_z z) (n_mu_y y) (n_mu_x x) -> b (n_mu_z n_mu_y n_mu_x) (z y x) c ", + n_mu_z=mask_units_per_dim[0], + n_mu_y=mask_units_per_dim[1], + n_mu_x=mask_units_per_dim[2], + ) + elif self.spatial_dims == 2: + return Rearrange( + "b c (n_mu_y y) (n_mu_x x) -> b (n_mu_y n_mu_x) (y x) c ", + n_mu_y=mask_units_per_dim[0], + n_mu_x=mask_units_per_dim[1], + ) + + def extract_visible_tokens(self, tokens, forward_indexes, n_visible_patches): + return take_indexes_mask(tokens, forward_indexes)[:, :n_visible_patches] diff --git a/cyto_dl/nn/vits/cross_mae.py b/cyto_dl/nn/vits/cross_mae.py deleted file mode 100644 index 3981de6e7..000000000 --- a/cyto_dl/nn/vits/cross_mae.py +++ /dev/null @@ -1,170 +0,0 @@ -from typing import List, Optional - -import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_ - -from cyto_dl.nn.vits.blocks import CrossAttentionBlock -from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes - - -class CrossMAE_Decoder(torch.nn.Module): - """Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masked tokens only attend - to visible tokens.""" - - def __init__( - self, - num_patches: List[int], - spatial_dims: int = 3, - base_patch_size: Optional[List[int]] = [4, 8, 8], - enc_dim: Optional[int] = 768, - emb_dim: Optional[int] = 192, - num_layer: Optional[int] = 4, - num_head: Optional[int] = 3, - learnable_pos_embedding: Optional[bool] = True, - ) -> None: - """ - Parameters - ---------- - num_patches: List[int] - Number of patches in each dimension - base_patch_size: Tuple[int] - Size of each patch - enc_dim: int - Dimension of encoder - emb_dim: int - Dimension of embedding - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer - 
learnable_pos_embedding: bool - If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images. - """ - super().__init__() - - self.transformer = torch.nn.ParameterList( - [ - CrossAttentionBlock( - encoder_dim=emb_dim, - decoder_dim=emb_dim, - num_heads=num_head, - ) - for _ in range(num_layer) - ] - ) - self.decoder_norm = nn.LayerNorm(emb_dim) - self.projection_norm = nn.LayerNorm(emb_dim) - - self.projection = torch.nn.Linear(enc_dim, emb_dim) - self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) - - self.pos_embedding = get_positional_embedding( - num_patches, emb_dim, learnable=learnable_pos_embedding - ) - - self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size))) - self.num_patches = torch.as_tensor(num_patches) - - if spatial_dims == 3: - self.patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_z=num_patches[0], - n_patch_y=num_patches[1], - n_patch_x=num_patches[2], - patch_size_z=base_patch_size[0], - patch_size_y=base_patch_size[1], - patch_size_x=base_patch_size[2], - ) - elif spatial_dims == 2: - self.patch2img = Rearrange( - "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_y=num_patches[0], - n_patch_x=num_patches[1], - patch_size_y=base_patch_size[0], - patch_size_x=base_patch_size[1], - ) - - self.init_weight() - - def init_weight(self): - trunc_normal_(self.mask_token, std=0.02) - - def forward(self, features, forward_indexes, backward_indexes): - # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 0 layers - # features can be n t b c (if intermediate feature weighter used) or t b c if not - features = features[0] if len(features.shape) == 4 else features - T, B, C = features.shape - # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now - features = self.projection_norm(self.projection(features)) - - # add cls token - backward_indexes = torch.cat( - [ - torch.zeros( - 1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long - ), - backward_indexes + 1, - ], - dim=0, - ) - forward_indexes = torch.cat( - [ - torch.zeros( - 1, forward_indexes.shape[1], device=forward_indexes.device, dtype=torch.long - ), - forward_indexes + 1, - ], - dim=0, - ) - # fill in masked regions - features = torch.cat( - [ - features, - self.mask_token.expand( - backward_indexes.shape[0] - features.shape[0], features.shape[1], -1 - ), - ], - dim=0, - ) - - # unshuffle to original positions for positional embedding so we can do cross attention during decoding - features = take_indexes(features, backward_indexes) - features = features + self.pos_embedding - - # reshuffle to shuffled positions for cross attention - features = take_indexes(features, forward_indexes) - features, masked = features[:T], features[T:] - - masked = rearrange(masked, "t b c -> b t c") - features = rearrange(features, "t b c -> b t c") - - for transformer in self.transformer: - masked = transformer(masked, features) - - # noneed to remove cls token, it's a part of the features vector - masked = rearrange(masked, "b t c -> t b c") - - # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * 
npatches) b (z y x) - masked = self.decoder_norm(masked) - patches = self.head(masked) - - # add back in visible/encoded tokens that we don't calculate loss on - patches = torch.cat( - [ - torch.zeros( - (T - 1, B, patches.shape[-1]), - requires_grad=False, - device=patches.device, - dtype=patches.dtype, - ), - patches, - ], - dim=0, - ) - patches = take_indexes(patches, backward_indexes[1:] - 1) - # patches to image - img = self.patch2img(patches) - return img diff --git a/cyto_dl/nn/vits/decoder.py b/cyto_dl/nn/vits/decoder.py new file mode 100644 index 000000000..0ef96cda0 --- /dev/null +++ b/cyto_dl/nn/vits/decoder.py @@ -0,0 +1,261 @@ +# modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 + +from typing import List, Optional, Union + +import torch +import torch.nn as nn +from einops import rearrange, repeat +from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_ +from timm.models.vision_transformer import Block + +from cyto_dl.nn.vits.blocks import CrossAttentionBlock +from cyto_dl.nn.vits.utils import ( + get_positional_embedding, + match_tuple_dimensions, + take_indexes, +) + + +class MAE_Decoder(torch.nn.Module): + def __init__( + self, + num_patches: Union[int, List[int]], + spatial_dims: int = 3, + patch_size: Optional[Union[int, List[int]]] = 4, + enc_dim: Optional[int] = 768, + emb_dim: Optional[int] = 192, + num_layer: Optional[int] = 4, + num_head: Optional[int] = 3, + has_cls_token: Optional[bool] = False, + learnable_pos_embedding: Optional[bool] = True, + ) -> None: + """ + Parameters + ---------- + num_patches: List[int], int + Number of patches in each dimension. If int, the same number of patches is used for all dimensions. + patch_size: Tuple[int], int + Size of each patch. If int, the same patch size is used for all dimensions. + enc_dim: int + Dimension of encoder + emb_dim: int + Dimension of decoder + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + has_cls_token: bool + Whether encoder features have a cls token + learnable_pos_embedding: bool + If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
+ """ + super().__init__() + num_patches, patch_size = match_tuple_dimensions(spatial_dims, [num_patches, patch_size]) + + self.has_cls_token = has_cls_token + + self.projection_norm = nn.LayerNorm(emb_dim) + self.projection = torch.nn.Linear(enc_dim, emb_dim) + self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) + + self.pos_embedding = get_positional_embedding( + num_patches, emb_dim, use_cls_token=has_cls_token, learnable=learnable_pos_embedding + ) + + self.transformer = torch.nn.Sequential( + *[Block(emb_dim, num_head) for _ in range(num_layer)] + ) + out_dim = torch.prod(torch.as_tensor(patch_size)).item() + self.decoder_norm = nn.LayerNorm(emb_dim) + self.head = torch.nn.Linear(emb_dim, out_dim) + self.num_patches = torch.as_tensor(num_patches) + + if spatial_dims == 3: + self.patch2img = Rearrange( + "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + n_patch_z=num_patches[0], + n_patch_y=num_patches[1], + n_patch_x=num_patches[2], + patch_size_z=patch_size[0], + patch_size_y=patch_size[1], + patch_size_x=patch_size[2], + ) + elif spatial_dims == 2: + self.patch2img = Rearrange( + "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + n_patch_y=num_patches[0], + n_patch_x=num_patches[1], + patch_size_y=patch_size[0], + patch_size_x=patch_size[1], + ) + + self.init_weight() + + def init_weight(self): + trunc_normal_(self.mask_token, std=0.02) + + def adjust_indices_for_cls(self, indexes): + if self.has_cls_token: + # add all zeros to indices - this keeps the class token as the first index in the + # tensor. We also have to add 1 to all the indices to account for the zeros we added + return torch.cat( + [ + torch.zeros(1, indexes.shape[1], device=indexes.device, dtype=torch.long), + indexes + 1, + ], + dim=0, + ) + return indexes + + def add_mask_tokens(self, features, backward_indexes): + # fill in deleted masked regions with mask token + num_visible_tokens, B, _ = features.shape + # total number of tokens - number of visible tokens + num_mask_tokens = backward_indexes.shape[0] - num_visible_tokens + mask_tokens = repeat(self.mask_token, "1 1 c -> t b c", t=num_mask_tokens, b=B) + return torch.cat([features, mask_tokens], dim=0) + + def forward(self, features, forward_indexes, backward_indexes): + # project from encoder dimension to decoder dimension + features = self.projection_norm(self.projection(features)) + + backward_indexes = self.adjust_indices_for_cls(backward_indexes) + + features = self.add_mask_tokens(features, backward_indexes) + + # unshuffle to original positions + features = take_indexes(features, backward_indexes) + features = features + self.pos_embedding + + # decode + features = rearrange(features, "t b c -> b t c") + features = self.transformer(features) + features = rearrange(features, "b t c -> t b c") + + if self.has_cls_token: + features = features[1:] # remove global feature + + # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x) + patches = self.head(self.decoder_norm(features)) + + # patches to image + img = self.patch2img(patches) + return img + + +class CrossMAE_Decoder(MAE_Decoder): + """Decoder inspired by [CrossMAE](https://crossmae.github.io/) where masked tokens only attend + to visible tokens.""" + + def __init__( + self, + num_patches: Union[int, List[int]], + spatial_dims: int = 3, + patch_size: Optional[Union[int, 
List[int]]] = 4, + enc_dim: Optional[int] = 768, + emb_dim: Optional[int] = 192, + num_layer: Optional[int] = 4, + num_head: Optional[int] = 3, + has_cls_token: Optional[bool] = True, + learnable_pos_embedding: Optional[bool] = True, + ) -> None: + """ + Parameters + ---------- + num_patches: List[int], int + Number of patches in each dimension. If int, the same number of patches is used for all dimensions. + patch_size: Tuple[int] + Size of each patch in each dimension. If int, the same patch size is used for all dimensions. + enc_dim: int + Dimension of encoder + emb_dim: int + Dimension of embedding + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + has_cls_token: bool + Whether encoder features have a cls token + learnable_pos_embedding: bool + If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings are used. Empirically, fixed positional embeddings work better for brightfield images. + """ + super().__init__( + num_patches, + spatial_dims, + patch_size, + enc_dim, + emb_dim, + num_layer, + num_head, + has_cls_token, + learnable_pos_embedding, + ) + + self.transformer = torch.nn.ParameterList( + [ + CrossAttentionBlock( + encoder_dim=emb_dim, + decoder_dim=emb_dim, + num_heads=num_head, + ) + for _ in range(num_layer) + ] + ) + + def forward(self, features, forward_indexes, backward_indexes): + # HACK TODO allow usage of multiple intermediate feature weights, this works when decoder is 1 layer + # features can be n t b c (if intermediate feature weighter used) or t b c if not + features = features[0] if len(features.shape) == 4 else features + T, B, C = features.shape + # we could do cross attention between decoder_dim queries and encoder_dim features, but it seems to work fine having both at decoder_dim for now + features = self.projection_norm(self.projection(features)) + + backward_indexes = self.adjust_indices_for_cls(backward_indexes) + forward_indexes = self.adjust_indices_for_cls(forward_indexes) + + features = self.add_mask_tokens(features, backward_indexes) + + # unshuffle to original positions for positional embedding so we can do cross attention during decoding + features = take_indexes(features, backward_indexes) + features = features + self.pos_embedding + + # reshuffle to shuffled positions for cross attention + features = take_indexes(features, forward_indexes) + features, masked = features[:T], features[T:] + + masked = rearrange(masked, "t b c -> b t c") + features = rearrange(features, "t b c -> b t c") + + for transformer in self.transformer: + masked = transformer(masked, features) + + # no need to remove cls token, it's part of the features vector + masked = rearrange(masked, "b t c -> t b c") + + # (npatches x npatches x npatches) b (emb dim) -> (npatches * npatches * npatches) b (z y x) + masked = self.decoder_norm(masked) + patches = self.head(masked) + + # add back in visible/encoded tokens that we don't calculate loss on + patches = torch.cat( + [ + torch.zeros( + # T-1 accounts for cls token + (T - self.has_cls_token, B, patches.shape[-1]), + requires_grad=False, + device=patches.device, + dtype=patches.dtype, + ), + patches, + ], + dim=0, + ) + patches = ( + take_indexes(patches, backward_indexes[1:] - 1) + if self.has_cls_token + else take_indexes(patches, backward_indexes) + ) + # patches to image + img = self.patch2img(patches) + return img diff --git a/cyto_dl/nn/vits/encoder.py b/cyto_dl/nn/vits/encoder.py new file mode 100644 index 000000000..1c579c9d8 ---
/dev/null +++ b/cyto_dl/nn/vits/encoder.py @@ -0,0 +1,383 @@ +# modified from https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 +# inspired by https://github.com/facebookresearch/hiera + +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional +from einops import rearrange +from einops.layers.torch import Rearrange +from timm.models.layers import trunc_normal_ +from timm.models.vision_transformer import Block + +from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify +from cyto_dl.nn.vits.blocks.masked_unit_attention import HieraBlock +from cyto_dl.nn.vits.blocks.patchify import PatchifyHiera +from cyto_dl.nn.vits.utils import match_tuple_dimensions + + +class MAE_Encoder(torch.nn.Module): + def __init__( + self, + num_patches: List[int], + spatial_dims: int = 3, + patch_size: Union[int, List[int]] = 4, + emb_dim: Optional[int] = 192, + num_layer: Optional[int] = 12, + num_head: Optional[int] = 3, + context_pixels: Optional[Union[int, List[int]]] = 0, + input_channels: Optional[int] = 1, + n_intermediate_weights: Optional[int] = -1, + ) -> None: + """ + Parameters + ---------- + num_patches: List[int], int + Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same. + spatial_dims: int + Number of spatial dimensions + patch_size: List[int] + Size of each patch + emb_dim: int + Dimension of embedding + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + context_pixels: List[int], int + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same. 
+ input_channels: int + Number of input channels + n_intermediate_weights: bool + Whether to use intermediate weights for weighted sum of intermediate layers + """ + super().__init__() + num_patches, patch_size, context_pixels = match_tuple_dimensions( + spatial_dims, [num_patches, patch_size, context_pixels] + ) + + self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) + self.patchify = Patchify( + patch_size, emb_dim, num_patches, spatial_dims, context_pixels, input_channels + ) + weight_intermediates = n_intermediate_weights > 0 + if weight_intermediates: + self.transformer = torch.nn.ModuleList( + [Block(emb_dim, num_head) for _ in range(num_layer)] + ) + else: + self.transformer = torch.nn.Sequential( + *[Block(emb_dim, num_head) for _ in range(num_layer)] + ) + + self.layer_norm = torch.nn.LayerNorm(emb_dim) + + self.intermediate_weighter = ( + IntermediateWeigher(num_layer, emb_dim, n_intermediate_weights) + if weight_intermediates + else None + ) + self.init_weight() + + def init_weight(self): + trunc_normal_(self.cls_token, std=0.02) + + def forward(self, img, mask_ratio=0.75): + patches, mask, forward_indexes, backward_indexes = self.patchify(img, mask_ratio) + patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0) + patches = rearrange(patches, "t b c -> b t c") + + if self.intermediate_weighter is not None: + intermediates = [patches] + for block in self.transformer: + patches = block(patches) + intermediates.append(patches) + features = self.layer_norm(self.intermediate_weighter(intermediates)) + features = rearrange(features, "n b t c -> n t b c") + else: + features = self.layer_norm(self.transformer(patches)) + features = rearrange(features, "b t c -> t b c") + if mask_ratio > 0: + return features, mask, forward_indexes, backward_indexes + return features + + +class JEPAEncoder(torch.nn.Module): + def __init__( + self, + num_patches: Union[int, List[int]], + spatial_dims: int = 3, + patch_size: Union[int, List[int]] = 4, + emb_dim: Optional[int] = 192, + num_layer: Optional[int] = 12, + num_head: Optional[int] = 3, + context_pixels: Optional[Union[int, List[int]]] = 0, + input_channels: Optional[int] = 1, + learnable_pos_embedding: Optional[bool] = True, + ) -> None: + """ + Parameters + ---------- + num_patches: List[int], int + Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same. + spatial_dims: int + Number of spatial dimensions + patch_size: List[int] + Size of each patch + emb_dim: int + Dimension of embedding + num_layer: int + Number of transformer layers + num_head: int + Number of heads in transformer + context_pixels: List[int], int + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. If a single int is provided, the number of context pixels in each dimension will be the same. + input_channels: int + Number of input channels + learnable_pos_embedding: bool + If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
+ """ + super().__init__() + num_patches, patch_size, context_pixels = match_tuple_dimensions( + spatial_dims, [num_patches, patch_size, context_pixels] + ) + + self.patchify = Patchify( + patch_size=patch_size, + emb_dim=emb_dim, + n_patches=num_patches, + spatial_dims=spatial_dims, + context_pixels=context_pixels, + input_channels=input_channels, + learnable_pos_embedding=learnable_pos_embedding, + ) + + self.transformer = torch.nn.Sequential( + *[Block(emb_dim, num_head) for _ in range(num_layer)] + ) + + self.layer_norm = torch.nn.LayerNorm(emb_dim) + + def forward(self, img, patchify=True): + if patchify: + patches, _, _, _ = self.patchify(img, mask_ratio=0) + else: + patches = img + patches = rearrange(patches, "t b c -> b t c") + features = self.layer_norm(self.transformer(patches)) + return features + + +class SpatialMerger(nn.Module): + """Class for converting multi-resolution Hiera features to the same (lowest) spatial resolution + via convolution.""" + + def __init__( + self, downsample_factor: List[int], in_dim: int, out_dim: int, spatial_dims: int = 3 + ): + super().__init__() + downsample_factor = match_tuple_dimensions(spatial_dims, [downsample_factor])[0] + + self.spatial_dims = spatial_dims + conv_fn = nn.Conv3d if spatial_dims == 3 else nn.Conv2d + conv = conv_fn( + in_channels=in_dim, + out_channels=out_dim, + kernel_size=downsample_factor, + stride=downsample_factor, + padding=0, + bias=False, + ) + if spatial_dims == 3: + tokens2img = Rearrange( + "b n_mu (z y x) c -> (b n_mu) c z y x", + z=downsample_factor[0], + y=downsample_factor[1], + x=downsample_factor[2], + ) + else: + tokens2img = Rearrange( + "b n_mu (y x) c -> (b n_mu) c y x", + y=downsample_factor[0], + x=downsample_factor[1], + ) + self.model = nn.Sequential(tokens2img, conv) + + def forward(self, x): + b, n_mu, _, _ = x.shape + x = self.model(x) + if self.spatial_dims == 3: + x = rearrange(x, "(b n_mu) c z y x -> b n_mu (z y x) c", b=b, n_mu=n_mu) + else: + x = rearrange(x, "(b n_mu) c y x -> b n_mu (y x) c", b=b, n_mu=n_mu) + return x + + +class HieraEncoder(torch.nn.Module): + def __init__( + self, + num_patches: Union[int, List[int]], + num_mask_units: Union[int, List[int]], + architecture: List[Dict], + emb_dim: int = 64, + spatial_dims: int = 3, + patch_size: Union[int, List[int]] = 4, + context_pixels: Optional[Union[int, List[int]]] = 0, + input_channels: Optional[int] = 1, + save_layers: Optional[bool] = False, + ) -> None: + """ + Parameters + ---------- + num_patches: int, List[int] + Number of patches in each dimension. If a single int is provided, the number of patches in each dimension will be the same. + num_mask_units: int, List[int] + Number of mask units in each dimension. If a single int is provided, the number of mask units in each dimension will be the same. + architecture: List[Dict] + List of dictionaries specifying the architecture of the transformer. Each dictionary should have the following keys: + - repeat: int + Number of times to repeat the block + - num_heads: int + Number of heads in the multihead attention + - q_stride: int, List[int] + Stride for the query in each spatial dimension + - self_attention: bool + Whether to use self attention or mask unit attention + On the last repeat of each non-self-attention block, the embedding dimension is doubled and spatial pooling with `q_stride` is performed within each mask unit. 
For example, in a block with embed_dim=4, q_stride=2, and repeat=2, the first repeat does only mask unit attention, while the second produces an 8-dimensional output that has been spatially pooled. + emb_dim: int + Dimension of embedding + spatial_dims: int + Number of spatial dimensions + patch_size: List[int] + Size of each patch + context_pixels: List[int] + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. + input_channels: int + Number of input channels + save_layers: bool + Whether to save the intermediate layer outputs + """ + super().__init__() + num_patches, num_mask_units, patch_size, context_pixels = match_tuple_dimensions( + spatial_dims, [num_patches, num_mask_units, patch_size, context_pixels] + ) + # make sure q stride shape matches spatial dims + for i in range(len(architecture)): + if "q_stride" in architecture[i]: + architecture[i]["q_stride"] = match_tuple_dimensions( + spatial_dims, [architecture[i]["q_stride"]] + )[0] + + self.save_layers = save_layers + self.patchify = PatchifyHiera( + patch_size=patch_size, + n_patches=num_patches, + emb_dim=emb_dim, + spatial_dims=spatial_dims, + context_pixels=context_pixels, + input_channels=input_channels, + mask_units_per_dim=num_mask_units, + ) + + patches_per_mask_unit = np.array(num_patches) // np.array(num_mask_units) + + total_downsampling_per_axis = np.prod( + [block.get("q_stride", [1] * spatial_dims) for block in architecture], axis=0 + ) + + assert np.all( + patches_per_mask_unit - total_downsampling_per_axis >= 0 + ), f"Number of patches per mask unit must be at least the total downsampling ratio, got {patches_per_mask_unit} patches per mask unit and {total_downsampling_per_axis} total downsampling ratio. Please adjust your q_stride or increase the number of patches per mask unit." + assert np.all( + patches_per_mask_unit % total_downsampling_per_axis == 0 + ), f"Number of patches per mask unit must be divisible by the total downsampling ratio, got {patches_per_mask_unit} patches per mask unit and {total_downsampling_per_axis} total downsampling ratio.
Please adjust your q_stride" + + # number of output features doubles in each masked unit attention block, stays constant during self attention blocks + self.final_dim = emb_dim * (2 ** len(architecture)) + + self.save_block_idxs = [] + self.save_block_dims = [] + self.spatial_mergers = torch.nn.ParameterDict({}) + transformer = [] + num_blocks = 0 + for stage_num, stage in enumerate(architecture): + # use mask unit attention until first layer that uses self attention + if stage.get("self_attention", False): + break + print(f"Stage: {stage_num}") + for block in range(stage["repeat"]): + is_last = block == stage["repeat"] - 1 + # do spatial pooling within mask unit on last block of stage + q_stride = stage["q_stride"] if is_last else [1] * spatial_dims + + # double embedding dimension in last block of stage + dim_in = emb_dim * (2**stage_num) + dim_out = dim_in if not is_last else dim_in * 2 + print( + f"\tBlock {block}:\t\tdim_in: {dim_in}, dim_out: {dim_out}, num_heads: {stage['num_heads']}, q_stride: {q_stride}, patches_per_mask_unit: {patches_per_mask_unit}" + ) + transformer.append( + HieraBlock( + dim=dim_in, + dim_out=dim_out, + heads=stage["num_heads"], + spatial_dims=spatial_dims, + q_stride=q_stride, + patches_per_mask_unit=patches_per_mask_unit, + ) + ) + if is_last: + # save the block before the spatial pooling unless it's the final stage + save_block = ( + num_blocks - 1 if stage_num < len(architecture) - 1 else num_blocks + ) + self.save_block_idxs.append(save_block) + self.save_block_dims.append(dim_in) + + # create a spatial merger for combining tokens pre-downsampling, last stage doesn't need merging since it has expected num channels, spatial shape + self.spatial_mergers[f"block_{save_block}"] = ( + SpatialMerger( + patches_per_mask_unit, + dim_in, + self.final_dim, + spatial_dims=spatial_dims, + ) + if stage_num < len(architecture) - 1 + else torch.nn.Identity() + ) + + # at end of each layer, patches per mask unit is reduced as we pool spatially within mask units + patches_per_mask_unit = patches_per_mask_unit // np.array(stage["q_stride"]) + num_blocks += 1 + self.mask_unit_transformer = torch.nn.Sequential(*transformer) + self.save_block_dims.append(self.final_dim) + self.save_block_dims.reverse() + + self.self_attention_transformer = torch.nn.Sequential( + *[Block(self.final_dim, stage["num_heads"]) for _ in range(stage["repeat"])] + ) + + self.layer_norm = torch.nn.LayerNorm(self.final_dim) + + def forward(self, img, mask_ratio): + patches, mask, forward_indexes, backward_indexes = self.patchify(img, mask_ratio) + + # mask unit attention + mask_unit_embeddings = 0.0 + save_layers = [] + for i, block in enumerate(self.mask_unit_transformer): + patches = block(patches) + if i in self.save_block_idxs: + mask_unit_embeddings += self.spatial_mergers[f"block_{i}"](patches) + if self.save_layers: + save_layers.append(patches) + + # combine mask units and tokens for full self attention transformer + mask_unit_embeddings = rearrange(mask_unit_embeddings, "b n_mu t c -> b (n_mu t) c") + mask_unit_embeddings = self.self_attention_transformer(mask_unit_embeddings) + mask_unit_embeddings = self.layer_norm(mask_unit_embeddings) + mask_unit_embeddings = rearrange(mask_unit_embeddings, "b t c -> t b c") + + return mask_unit_embeddings, mask, forward_indexes, backward_indexes # , save_layers diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py index 1617bf687..17d3e4729 100644 --- a/cyto_dl/nn/vits/mae.py +++ b/cyto_dl/nn/vits/mae.py @@ -1,220 +1,63 @@ # modified from 
https://github.com/IcarusWizard/MAE/blob/main/model.py#L124 -from typing import List, Optional +from abc import ABC, abstractmethod +from typing import Dict, List, Optional, Union import numpy as np import torch -import torch.nn as nn -from einops import rearrange -from einops.layers.torch import Rearrange -from timm.models.layers import trunc_normal_ -from timm.models.vision_transformer import Block -from cyto_dl.nn.vits.blocks import IntermediateWeigher, Patchify -from cyto_dl.nn.vits.cross_mae import CrossMAE_Decoder -from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes +from cyto_dl.nn.vits.decoder import CrossMAE_Decoder, MAE_Decoder +from cyto_dl.nn.vits.encoder import HieraEncoder, MAE_Encoder +from cyto_dl.nn.vits.utils import match_tuple_dimensions -class MAE_Encoder(torch.nn.Module): +class MAE_Base(torch.nn.Module, ABC): def __init__( - self, - num_patches: List[int], - spatial_dims: int = 3, - base_patch_size: List[int] = (16, 16, 16), - emb_dim: Optional[int] = 192, - num_layer: Optional[int] = 12, - num_head: Optional[int] = 3, - context_pixels: Optional[List[int]] = [0, 0, 0], - input_channels: Optional[int] = 1, - n_intermediate_weights: Optional[int] = -1, - ) -> None: - """ - Parameters - ---------- - num_patches: List[int] - Number of patches in each dimension - spatial_dims: int - Number of spatial dimensions - base_patch_size: List[int] - Size of each patch - emb_dim: int - Dimension of embedding - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer - context_pixels: List[int] - Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. - input_channels: int - Number of input channels - weight_intermediates: bool - Whether to output linear combination of intermediate layers as final output like CrossMAE - """ - super().__init__() - self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) - self.patchify = Patchify( - base_patch_size, emb_dim, num_patches, spatial_dims, context_pixels, input_channels - ) - weight_intermediates = n_intermediate_weights > 0 - if weight_intermediates: - self.transformer = torch.nn.ModuleList( - [Block(emb_dim, num_head) for _ in range(num_layer)] - ) - else: - self.transformer = torch.nn.Sequential( - *[Block(emb_dim, num_head) for _ in range(num_layer)] - ) - - self.layer_norm = torch.nn.LayerNorm(emb_dim) - - self.intermediate_weighter = ( - IntermediateWeigher(num_layer, emb_dim, n_intermediate_weights) - if weight_intermediates - else None - ) - self.init_weight() - - def init_weight(self): - trunc_normal_(self.cls_token, std=0.02) - - def forward(self, img, mask_ratio=0.75): - patches, mask, forward_indexes, backward_indexes = self.patchify(img, mask_ratio) - patches = torch.cat([self.cls_token.expand(-1, patches.shape[1], -1), patches], dim=0) - patches = rearrange(patches, "t b c -> b t c") - - if self.intermediate_weighter is not None: - intermediates = [patches] - for block in self.transformer: - patches = block(patches) - intermediates.append(patches) - features = self.layer_norm(self.intermediate_weighter(intermediates)) - features = rearrange(features, "n b t c -> n t b c") - else: - features = self.layer_norm(self.transformer(patches)) - features = rearrange(features, "b t c -> t b c") - if mask_ratio > 0: - return features, mask, forward_indexes, backward_indexes - return features - - -class MAE_Decoder(torch.nn.Module): - def __init__( - self, - num_patches: List[int], - spatial_dims: int = 3, - base_patch_size: 
Optional[List[int]] = [4, 8, 8], - enc_dim: Optional[int] = 768, - emb_dim: Optional[int] = 192, - num_layer: Optional[int] = 4, - num_head: Optional[int] = 3, - learnable_pos_embedding: Optional[bool] = True, - ) -> None: - """ - Parameters - ---------- - num_patches: List[int] - Number of patches in each dimension - base_patch_size: Tuple[int] - Size of each patch - enc_dim: int - Dimension of encoder - emb_dim: int - Dimension of decoder - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer - learnable_pos_embedding: bool - If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. - """ + self, spatial_dims, num_patches, patch_size, mask_ratio, features_only, context_pixels + ): super().__init__() - self.projection_norm = nn.LayerNorm(emb_dim) - self.projection = torch.nn.Linear(enc_dim, emb_dim) - self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) - - self.pos_embedding = get_positional_embedding( - num_patches, emb_dim, learnable=learnable_pos_embedding + num_patches, patch_size, context_pixels = match_tuple_dimensions( + spatial_dims, [num_patches, patch_size, context_pixels] ) - self.transformer = torch.nn.Sequential( - *[Block(emb_dim, num_head) for _ in range(num_layer)] - ) - out_dim = torch.prod(torch.as_tensor(base_patch_size)).item() - self.head_norm = nn.LayerNorm(out_dim) - self.head = torch.nn.Linear(emb_dim, out_dim) - self.num_patches = torch.as_tensor(num_patches) - - if spatial_dims == 3: - self.patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_z=num_patches[0], - n_patch_y=num_patches[1], - n_patch_x=num_patches[2], - patch_size_z=base_patch_size[0], - patch_size_y=base_patch_size[1], - patch_size_x=base_patch_size[2], - ) - elif spatial_dims == 2: - self.patch2img = Rearrange( - "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_y=num_patches[0], - n_patch_x=num_patches[1], - patch_size_y=base_patch_size[0], - patch_size_x=base_patch_size[1], - ) - - self.init_weight() - - def init_weight(self): - trunc_normal_(self.mask_token, std=0.02) + self.spatial_dims = spatial_dims + self.num_patches = num_patches + self.patch_size = patch_size + self.mask_ratio = mask_ratio + self.features_only = features_only + self.context_pixels = context_pixels - def forward(self, features, forward_indexes, backward_indexes): - # project from encoder dimension to decoder dimension - features = self.projection_norm(self.projection(features)) + # encoder and decoder must be defined in subclasses + @property + @abstractmethod + def encoder(self): + pass - backward_indexes = torch.cat( - [ - torch.zeros( - 1, backward_indexes.shape[1], device=backward_indexes.device, dtype=torch.long - ), - backward_indexes + 1, - ], - dim=0, - ) - # fill in masked regions - features = torch.cat( - [ - features, - self.mask_token.expand( - backward_indexes.shape[0] - features.shape[0], features.shape[1], -1 - ), - ], - dim=0, - ) - # unshuffle to original positions - features = take_indexes(features, backward_indexes) - features = features + self.pos_embedding + @property + @abstractmethod + def decoder(self): + pass - # decode - features = rearrange(features, "t b c -> b t c") - features = self.transformer(features) - 
features = rearrange(features, "b t c -> t b c") - features = features[1:] # remove global feature + def init_encoder(self): + raise NotImplementedError - # (npatches x npatches x npatches) b (emb dim) -> (npatches* npatches * npatches) b (z y x) - patches = self.head_norm(self.head(features)) + def init_decoder(self): + raise NotImplementedError - # patches to image - img = self.patch2img(patches) - return img + def forward(self, img): + features, mask, forward_indexes, backward_indexes = self.encoder(img, self.mask_ratio) + if self.features_only: + return features + predicted_img = self.decoder(features, forward_indexes, backward_indexes) + return predicted_img, mask -class MAE_ViT(torch.nn.Module): +class MAE(MAE_Base): def __init__( self, spatial_dims: int = 3, - num_patches: Optional[List[int]] = [2, 32, 32], - base_patch_size: Optional[List[int]] = [16, 16, 16], + num_patches: Optional[List[int]] = 16, + patch_size: Optional[List[int]] = 4, emb_dim: Optional[int] = 768, encoder_layer: Optional[int] = 12, encoder_head: Optional[int] = 8, @@ -223,7 +66,7 @@ def __init__( decoder_dim: Optional[int] = 192, mask_ratio: Optional[int] = 0.75, use_crossmae: Optional[bool] = False, - context_pixels: Optional[List[int]] = [0, 0, 0], + context_pixels: Optional[List[int]] = 0, input_channels: Optional[int] = 1, features_only: Optional[bool] = False, learnable_pos_embedding: Optional[bool] = True, @@ -235,7 +78,7 @@ def __init__( Number of spatial dimensions num_patches: List[int] Number of patches in each dimension (ZYX order) - base_patch_size: List[int] + patch_size: List[int] Size of each patch (ZYX order) emb_dim: int Dimension of encoder embedding @@ -260,30 +103,23 @@ def __init__( learnable_pos_embedding: bool If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
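+ Examples
+ --------
+ A minimal sketch with illustrative values (not a tested configuration); the input spatial size is assumed to be num_patches * patch_size along each axis:
+ >>> import torch
+ >>> model = MAE(spatial_dims=3, num_patches=8, patch_size=4, emb_dim=96, encoder_layer=2, encoder_head=8, decoder_layer=1)
+ >>> predicted_img, mask = model(torch.zeros(1, 1, 32, 32, 32))  # predicted_img reconstructs the input, shape (1, 1, 32, 32, 32)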
""" - super().__init__() - assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" - - if isinstance(num_patches, int): - num_patches = [num_patches] * spatial_dims - if isinstance(base_patch_size, int): - base_patch_size = [base_patch_size] * spatial_dims - - assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" - assert ( - len(base_patch_size) == spatial_dims - ), "base_patch_size must be of length spatial_dims" - - self.mask_ratio = mask_ratio - self.features_only = features_only + super().__init__( + spatial_dims=spatial_dims, + num_patches=num_patches, + patch_size=patch_size, + mask_ratio=mask_ratio, + features_only=features_only, + context_pixels=context_pixels, + ) - self.encoder = MAE_Encoder( - num_patches, + self._encoder = MAE_Encoder( + self.num_patches, spatial_dims, - base_patch_size, + self.patch_size, emb_dim, encoder_layer, encoder_head, - context_pixels, + self.context_pixels, input_channels, n_intermediate_weights=-1 if not use_crossmae else decoder_layer, ) @@ -291,10 +127,10 @@ def __init__( decoder_class = MAE_Decoder if use_crossmae: decoder_class = CrossMAE_Decoder - self.decoder = decoder_class( - num_patches=num_patches, + self._decoder = decoder_class( + num_patches=self.num_patches, spatial_dims=spatial_dims, - base_patch_size=base_patch_size, + patch_size=self.patch_size, enc_dim=emb_dim, emb_dim=decoder_dim, num_layer=decoder_layer, @@ -302,9 +138,117 @@ def __init__( learnable_pos_embedding=learnable_pos_embedding, ) - def forward(self, img): - features, mask, forward_indexes, backward_indexes = self.encoder(img, self.mask_ratio) - if self.features_only: - return features - predicted_img = self.decoder(features, forward_indexes, backward_indexes) - return predicted_img, mask + @property + def encoder(self): + return self._encoder + + @property + def decoder(self): + return self._decoder + + +class HieraMAE(MAE_Base): + def __init__( + self, + architecture: List[Dict], + spatial_dims: int = 3, + num_patches: Optional[Union[int, List[int]]] = 16, + num_mask_units: Optional[Union[int, List[int]]] = 8, + patch_size: Optional[Union[int, List[int]]] = 4, + emb_dim: Optional[int] = 64, + decoder_layer: Optional[int] = 4, + decoder_head: Optional[int] = 8, + decoder_dim: Optional[int] = 192, + mask_ratio: Optional[int] = 0.75, + use_crossmae: Optional[bool] = False, + context_pixels: Optional[List[int]] = 0, + input_channels: Optional[int] = 1, + features_only: Optional[bool] = False, + ) -> None: + """ + Parameters + ---------- + architecture: List[Dict] + List of dictionaries specifying the architecture of the transformer. Each dictionary should have the following keys: + - repeat: int + Number of times to repeat the block + - num_heads: int + Number of heads in the multihead attention + - q_stride: List[int] + Stride for the query in each spatial dimension + - self_attention: bool + Whether to use self attention or mask unit attention + spatial_dims: int + Number of spatial dimensions + num_patches: int, List[int] + Number of patches in each dimension (Z)YX order. If int, the same number of patches is used in each dimension. + num_mask_units: int, List[int] + Number of mask units in each dimension (Z)YX order. If int, the same number of mask units is used in each dimension. + patch_size: int, List[int] + Size of each patch (Z)YX order. If int, the same patch size is used in each dimension. 
+ emb_dim: int + Dimension of embedding + decoder_layer: int + Number of layers in the decoder + decoder_head: int + Number of heads in the decoder + decoder_dim: int + Dimension of the decoder + mask_ratio: float + Fraction of mask units to remove + use_crossmae: bool + Use CrossMAE-style decoder instead of MAE decoder + context_pixels: List[int] + Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. + input_channels: int + Number of input channels + features_only: bool + Only use encoder to extract features + """ + super().__init__( + spatial_dims=spatial_dims, + num_patches=num_patches, + patch_size=patch_size, + mask_ratio=mask_ratio, + features_only=features_only, + context_pixels=context_pixels, + ) + num_mask_units = match_tuple_dimensions(self.spatial_dims, [num_mask_units])[0] + + self._encoder = HieraEncoder( + num_patches=self.num_patches, + num_mask_units=num_mask_units, + architecture=architecture, + emb_dim=emb_dim, + spatial_dims=self.spatial_dims, + patch_size=self.patch_size, + context_pixels=self.context_pixels, + input_channels=input_channels, + ) + # "patches" to the decoder are actually mask units, so num_patches is num_mask_units, patch_size is mask unit size + mask_unit_size = (np.array(self.num_patches) * np.array(self.patch_size)) / np.array( + num_mask_units + ) + + decoder_class = MAE_Decoder + if use_crossmae: + decoder_class = CrossMAE_Decoder + + self._decoder = decoder_class( + num_patches=num_mask_units, + spatial_dims=spatial_dims, + patch_size=mask_unit_size.astype(int), + enc_dim=self.encoder.final_dim, + emb_dim=decoder_dim, + num_layer=decoder_layer, + num_head=decoder_head, + has_cls_token=False, + ) + + @property + def encoder(self): + return self._encoder + + @property + def decoder(self): + return self._decoder diff --git a/cyto_dl/nn/vits/jepa.py b/cyto_dl/nn/vits/predictor.py similarity index 64% rename from cyto_dl/nn/vits/jepa.py rename to cyto_dl/nn/vits/predictor.py index 6b45303d2..8eb5978fb 100644 --- a/cyto_dl/nn/vits/jepa.py +++ b/cyto_dl/nn/vits/predictor.py @@ -4,76 +4,13 @@ import torch.nn.functional from einops import rearrange from timm.models.layers import trunc_normal_ -from timm.models.vision_transformer import Block -from cyto_dl.nn.vits.blocks import CrossAttentionBlock, Patchify -from cyto_dl.nn.vits.utils import get_positional_embedding, take_indexes - - -class JEPAEncoder(torch.nn.Module): - def __init__( - self, - num_patches: Union[int, List[int]], - spatial_dims: int = 3, - patch_size: Union[int, List[int]] = (16, 16, 16), - emb_dim: Optional[int] = 192, - num_layer: Optional[int] = 12, - num_head: Optional[int] = 3, - context_pixels: Optional[List[int]] = [0, 0, 0], - input_channels: Optional[int] = 1, - learnable_pos_embedding: Optional[bool] = True, - ) -> None: - """ - Parameters - ---------- - num_patches: List[int] - Number of patches in each dimension - spatial_dims: int - Number of spatial dimensions - patch_size: List[int] - Size of each patch - emb_dim: int - Dimension of embedding - num_layer: int - Number of transformer layers - num_head: int - Number of heads in transformer - context_pixels: List[int] - Number of extra pixels around each patch to include in convolutional embedding to encoder dimension. - input_channels: int - Number of input channels - learnable_pos_embedding: bool - If True, learnable positional embeddings are used. If False, fixed sin/cos positional embeddings. Empirically, fixed positional embeddings work better for brightfield images. 
- """ - super().__init__() - if isinstance(num_patches, int): - num_patches = [num_patches] * spatial_dims - if isinstance(patch_size, int): - patch_size = [patch_size] * spatial_dims - self.patchify = Patchify( - patch_size, - emb_dim, - num_patches, - spatial_dims, - context_pixels, - input_channels, - learnable_pos_embedding=learnable_pos_embedding, - ) - - self.transformer = torch.nn.Sequential( - *[Block(emb_dim, num_head) for _ in range(num_layer)] - ) - - self.layer_norm = torch.nn.LayerNorm(emb_dim) - - def forward(self, img, patchify=True): - if patchify: - patches, _, _, _ = self.patchify(img, mask_ratio=0) - else: - patches = img - patches = rearrange(patches, "t b c -> b t c") - features = self.layer_norm(self.transformer(patches)) - return features +from cyto_dl.nn.vits.blocks import CrossAttentionBlock +from cyto_dl.nn.vits.utils import ( + get_positional_embedding, + match_tuple_dimensions, + take_indexes, +) class JEPAPredictor(torch.nn.Module): @@ -81,7 +18,8 @@ class JEPAPredictor(torch.nn.Module): def __init__( self, - num_patches: List[int], + num_patches: Union[int, List[int]], + spatial_dims: int = 3, input_dim: Optional[int] = 192, emb_dim: Optional[int] = 192, num_layer: Optional[int] = 12, @@ -91,8 +29,12 @@ def __init__( """ Parameters ---------- - num_patches: List[int] - Number of patches in each dimension + num_patches: List[int], int + Number of patches in each dimension. If int, the same number of patches is used for all spatial dimensions + spatial_dims: int + Number of spatial dimensions + input_dim: int + Dimension of input emb_dim: int Dimension of embedding num_layer: int @@ -114,6 +56,8 @@ def __init__( ] ) + num_patches = match_tuple_dimensions(spatial_dims, [num_patches])[0] + self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) self.pos_embedding = get_positional_embedding( num_patches, emb_dim, use_cls_token=False, learnable=learnable_pos_embedding @@ -160,7 +104,8 @@ class IWMPredictor(JEPAPredictor): def __init__( self, domains: List[str], - num_patches: List[int], + num_patches: Union[int, List[int]], + spatial_dims: int = 3, input_dim: Optional[int] = 192, emb_dim: Optional[int] = 192, num_layer: Optional[int] = 12, @@ -172,7 +117,11 @@ def __init__( domains: List[str] List of names of target domains num_patches: List[int] - Number of patches in each dimension + Number of patches in each dimension. 
If int, the same number of patches is used for all spatial dimensions + spatial_dims: int + Number of spatial dimensions + spatial_dims: int + Number of spatial dimensions emb_dim: int Dimension of embedding num_layer: int @@ -180,7 +129,14 @@ def __init__( num_head: int Number of heads in transformer """ - super().__init__(num_patches, input_dim, emb_dim, num_layer, num_head) + super().__init__( + num_patches=num_patches, + spatial_dims=spatial_dims, + input_dim=input_dim, + emb_dim=emb_dim, + num_layer=num_layer, + num_head=num_head, + ) self.domain_embeddings = torch.nn.ParameterDict( {d: torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) for d in domains} diff --git a/cyto_dl/nn/vits/utils.py b/cyto_dl/nn/vits/utils.py index b918fb73c..aa96557fe 100644 --- a/cyto_dl/nn/vits/utils.py +++ b/cyto_dl/nn/vits/utils.py @@ -3,6 +3,7 @@ import numpy as np import torch from einops import rearrange, repeat +from monai.utils.misc import ensure_tuple_rep from positional_encodings.torch_encodings import ( PositionalEncoding2D, PositionalEncoding3D, @@ -14,6 +15,12 @@ def take_indexes(sequences, indexes): return torch.gather(sequences, 0, repeat(indexes, "t b -> t b c", c=sequences.shape[-1])) +def random_indexes(size: int, device): + forward_indexes = torch.randperm(size, device=device, dtype=torch.long) + backward_indexes = torch.argsort(forward_indexes) + return forward_indexes, backward_indexes + + def get_positional_embedding( num_patches: Sequence[int], emb_dim: int, use_cls_token: bool = True, learnable: bool = True ): @@ -40,3 +47,12 @@ def get_positional_embedding( cls_token = torch.zeros(1, 1, emb_dim) pe = torch.cat([cls_token, pe], dim=0) return torch.nn.Parameter(pe, requires_grad=False) + + +def match_tuple_dimensions(spatial_dims, tuples): + """Ensure that each element in a list of tuples has the same length as spatial_dims. + + If a single element, the element is repeated to match the spatial_dims. + """ + assert spatial_dims in (2, 3), "spatial_dims must be 2 or 3" + return [ensure_tuple_rep(t, spatial_dims) for t in tuples] diff --git a/tests/conftest.py b/tests/conftest.py index 5094ae856..ed987812d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,16 @@ OmegaConf.register_new_resolver("eval", eval) # Experiment configs to test -experiment_types = ["mae", "ijepa", "iwm", "segmentation", "labelfree", "gan", "instance_seg"] +experiment_types = [ + "hiera", + "mae", + "ijepa", + "iwm", + "segmentation", + "labelfree", + "gan", + "instance_seg", +] @pytest.fixture(scope="package", params=experiment_types) diff --git a/tests/utils.py b/tests/utils.py index 61778c230..99d097067 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,4 +9,4 @@ def resolve_readonly(cfg): def skip_test(test_name): """Skip pretraining models for testing.""" - return bool(np.any([x in test_name for x in ("mae", "ijepa", "iwm")])) + return bool(np.any([x in test_name for x in ("mae", "ijepa", "iwm", "hiera")]))
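
Example usage of the new HieraMAE (a hypothetical sketch; the values below are illustrative, are not taken from configs/model/im2im/hiera.yaml, and are chosen only to satisfy the constraint asserted in HieraEncoder that the patches per mask unit are divisible by the product of the q_strides):

import torch

from cyto_dl.nn.vits.mae import HieraMAE

# Two mask-unit-attention stages followed by a self-attention stage.
# Each q_stride=2 stage pools by 2 along every axis within each mask unit on its
# last repeat, so the total pooling (2 * 2 = 4) matches the 4 patches per mask
# unit per axis (16 patches / 4 mask units).
architecture = [
    {"repeat": 2, "num_heads": 2, "q_stride": 2},
    {"repeat": 2, "num_heads": 4, "q_stride": 2},
    {"repeat": 2, "num_heads": 8, "self_attention": True},
]

model = HieraMAE(
    architecture=architecture,
    spatial_dims=3,
    num_patches=16,  # per axis
    num_mask_units=4,  # per axis
    patch_size=4,  # per axis -> expects crops of 16 * 4 = 64 voxels per axis
    emb_dim=16,
    mask_ratio=0.75,
)

img = torch.zeros(1, 1, 64, 64, 64)  # (batch, channels, Z, Y, X)
predicted_img, mask = model(img)

Setting features_only=True skips the decoder and returns only the encoder tokens, following the shared forward pass in MAE_Base.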