Skip to content

Commit

Permalink
Feature/hiera (#418)
Browse files Browse the repository at this point in the history
* Bump version: 0.1.5 → 0.1.6

* add hiera

* start of mask2former

* first take at transfomer

* fix dimensionality, now updating instance queries instead of mask

* give instance queries own dim

* add mask creation

* remove experimental code

* update to base patchify

* wip

* update configs

* update patchify

* rearrange encoder/decoder/mae

* add 2d hiera

* 2d masked unit attention

* precommit

* update configs

* update hiera model config

* update deafults

* update tests

* delete patchify_conv

* fix jepa tests

* update mask transform

* add predictor

* precommit

* update with ritviks comments

* replace all function names

---------

Co-authored-by: Benjamin Morris <[email protected]>
Co-authored-by: Benjamin Morris <[email protected]>
  • Loading branch information
3 people authored Aug 26, 2024
1 parent 5f45508 commit 48e7cb9
Show file tree
Hide file tree
Showing 26 changed files with 1,508 additions and 555 deletions.
3 changes: 3 additions & 0 deletions configs/data/im2im/ijepa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ transforms:
- _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
mask_size: 4
num_patches: ${model._aux.num_patches}
spatial_dims: ${spatial_dims}

test:
_target_: monai.transforms.Compose
Expand Down Expand Up @@ -69,6 +70,7 @@ transforms:
- _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
mask_size: 4
num_patches: ${model._aux.num_patches}
spatial_dims: ${spatial_dims}

predict:
_target_: monai.transforms.Compose
Expand Down Expand Up @@ -127,6 +129,7 @@ transforms:
- _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
mask_size: 4
num_patches: ${model._aux.num_patches}
spatial_dims: ${spatial_dims}

_aux:
_scales_dict:
Expand Down
2 changes: 2 additions & 0 deletions configs/data/im2im/iwm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ transforms:
- _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
mask_size: 4
num_patches: ${model._aux.num_patches}
spatial_dims: ${spatial_dims}

test:
_target_: monai.transforms.Compose
Expand Down Expand Up @@ -176,5 +177,6 @@ transforms:
- _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
mask_size: 4
num_patches: ${model._aux.num_patches}
spatial_dims: ${spatial_dims}

_aux:
19 changes: 8 additions & 11 deletions configs/data/im2im/mae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ transforms:
- _target_: monai.transforms.NormalizeIntensityd
keys: ${source_col}
channel_wise: True
- _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
- _target_: monai.transforms.RandSpatialCropSamplesd
keys:
- ${source_col}
patch_shape: ${data._aux.patch_shape}
patch_per_image: 1
scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
roi_size: ${data._aux.patch_shape}
num_samples: 1
random_size: False

test:
_target_: monai.transforms.Compose
Expand Down Expand Up @@ -104,14 +104,11 @@ transforms:
- _target_: monai.transforms.NormalizeIntensityd
keys: ${source_col}
channel_wise: True
- _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
- _target_: monai.transforms.RandSpatialCropSamplesd
keys:
- ${source_col}
patch_shape: ${data._aux.patch_shape}
patch_per_image: 1
scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
roi_size: ${data._aux.patch_shape}
num_samples: 1
random_size: False

_aux:
_scales_dict:
- - ${source_col}
- [1]
52 changes: 52 additions & 0 deletions configs/experiment/im2im/hiera.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=example
defaults:
- override /data: im2im/mae.yaml
- override /model: im2im/hiera.yaml
- override /callbacks: default.yaml
- override /trainer: gpu.yaml
- override /logger: csv.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["dev"]
seed: 12345

experiment_name: YOUR_EXP_NAME
run_name: YOUR_RUN_NAME

# only source_col is needed for masked autoencoder
source_col: raw
spatial_dims: 3
raw_im_channels: 1

trainer:
max_epochs: 100
gradient_clip_val: 10

data:
path: ${paths.data_dir}/example_experiment_data/segmentation
cache_dir: ${paths.data_dir}/example_experiment_data/cache
batch_size: 1
_aux:
# 2D
# patch_shape: [16, 16]
# 3D
patch_shape: [16, 16, 16]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
7 changes: 0 additions & 7 deletions configs/experiment/im2im/ijepa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,3 @@ data:
# patch_shape: [16, 16]
# 3D
patch_shape: [16, 16, 16]

model:
_aux:
# 3D
num_patches: [8, 8, 8]
# 2d
# num_patches: [8, 8]
7 changes: 0 additions & 7 deletions configs/experiment/im2im/iwm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ data:
# 3D
patch_shape: [16, 16, 16]

model:
_aux:
# 3D
num_patches: [8, 8, 8]
# 2d
# num_patches: [8, 8]

callbacks:
prediction_saver:
_target_: cyto_dl.callbacks.csv_saver.CSVSaver
Expand Down
70 changes: 70 additions & 0 deletions configs/model/im2im/hiera.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
_target_: cyto_dl.models.im2im.MultiTaskIm2Im

save_images_every_n_epochs: 1
save_dir: ${paths.output_dir}

x_key: ${source_col}

backbone:
_target_: cyto_dl.nn.vits.mae.HieraMAE
spatial_dims: ${spatial_dims}
patch_size: 2 # patch_size * num_patches should be your image shape (data._aux.patch_shape)
num_patches: 8 # patch_size * num_patches = img_shape
num_mask_units: 4 #Mask units are used for local attention. img_shape / num_mask_units = size of each mask unit in pixels, num_patches/num_mask_units = number of patches permask unit
emb_dim: 4
# NOTE: this is a very small model for testing - for best performance, the downsampling ratios, embedding dimension, number of layers and number of heads should be adjusted to your data
architecture:
# mask_unit_attention blocks - attention is only done within a mask unit and not across mask units
# the total amount of q_stride across the architecture must be less than the number of patches per mask unit
- repeat: 1 # number of times to repeat this block
q_stride: 2 # size of downsampling within a mask unit
num_heads: 1
- repeat: 1
q_stride: 1
num_heads: 2
# self attention transformer - attention is done across all patches, irrespective of which mask unit they're in
- repeat: 2
num_heads: 4
self_attention: True
decoder_layer: 1
decoder_dim: 16
mask_ratio: 0.66666666666
context_pixels: 3
use_crossmae: True

task_heads: ${kv_to_dict:${model._aux._tasks}}

optimizer:
generator:
_partial_: True
_target_: torch.optim.AdamW
weight_decay: 0.05

lr_scheduler:
generator:
_partial_: True
_target_: torch.optim.lr_scheduler.OneCycleLR
max_lr: 0.0001
epochs: ${trainer.max_epochs}
steps_per_epoch: 1
pct_start: 0.1

inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0
progress: True
mode: "gaussian"

_aux:
_tasks:
- - ${source_col}
- _target_: cyto_dl.nn.head.mae_head.MAEHead
loss:
postprocess:
input:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
rescale_dtype: numpy.uint8
prediction:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
rescale_dtype: numpy.uint8
9 changes: 5 additions & 4 deletions configs/model/im2im/ijepa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ max_epochs: ${trainer.max_epochs}
save_dir: ${paths.output_dir}

encoder:
_target_: cyto_dl.nn.vits.jepa.JEPAEncoder
patch_size: 2
_target_: cyto_dl.nn.vits.encoder.JEPAEncoder
patch_size: 2 # patch_size * num_patches should equl data._aux.patch_shape
num_patches: ${model._aux.num_patches}
emb_dim: 16
num_layer: 2
num_head: 1
spatial_dims: ${spatial_dims}

predictor:
_target_: cyto_dl.nn.vits.jepa.JEPAPredictor
_target_: cyto_dl.nn.vits.predictor.JEPAPredictor
num_patches: ${model._aux.num_patches}
input_dim: ${model.encoder.emb_dim}
emb_dim: 8
Expand All @@ -34,4 +35,4 @@ lr_scheduler:
pct_start: 0.1

_aux:
num_patches:
num_patches: 8
9 changes: 5 additions & 4 deletions configs/model/im2im/iwm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ max_epochs: ${trainer.max_epochs}
save_dir: ${paths.output_dir}

encoder:
_target_: cyto_dl.nn.vits.jepa.JEPAEncoder
patch_size: 2
_target_: cyto_dl.nn.vits.encoder.JEPAEncoder
patch_size: 2 # patch_size * num_patches should be the same as data._aux.patch_shape
num_patches: ${model._aux.num_patches}
emb_dim: 16
num_layer: 1
num_head: 1
spatial_dims: ${spatial_dims}

predictor:
_target_: cyto_dl.nn.vits.jepa.IWMPredictor
_target_: cyto_dl.nn.vits.predictor.IWMPredictor
domains: [SEC61B]
num_patches: ${model._aux.num_patches}
input_dim: ${model.encoder.emb_dim}
Expand All @@ -38,4 +39,4 @@ lr_scheduler:
pct_start: 0.1

_aux:
num_patches:
num_patches: 8
6 changes: 3 additions & 3 deletions configs/model/im2im/mae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ save_dir: ${paths.output_dir}
x_key: ${source_col}

backbone:
_target_: cyto_dl.nn.vits.MAE_ViT
_target_: cyto_dl.nn.vits.MAE
spatial_dims: ${spatial_dims}
# base_patch_size* num_patches should be your patch shape
base_patch_size: 2
# patch_size* num_patches should be your patch shape
patch_size: 2
num_patches: 8
emb_dim: 16
encoder_layer: 2
Expand Down
4 changes: 2 additions & 2 deletions configs/model/im2im/vit_segmentation_decoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ x_key: ${source_col}
backbone:
_target_: cyto_dl.nn.vits.Seg_ViT
spatial_dims: ${spatial_dims}
# base_patch_size* num_patches should be your patch shape
base_patch_size: 2
# patch_size* num_patches should be your patch shape
patch_size: 2
num_patches: 8
emb_dim: 16
encoder_layer: 2
Expand Down
11 changes: 9 additions & 2 deletions cyto_dl/image/transforms/generate_jepa_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from monai.transforms import RandomizableTransform
from skimage.segmentation import find_boundaries

from cyto_dl.nn.vits.utils import match_tuple_dimensions


class JEPAMaskGenerator(RandomizableTransform):
"""Transform for generating Block-contiguous masks for JEPA training.
Expand All @@ -15,6 +17,7 @@ class JEPAMaskGenerator(RandomizableTransform):

def __init__(
self,
spatial_dims: int,
mask_size: int = 12,
block_aspect_ratio: Tuple[float] = (0.5, 1.5),
num_patches: Tuple[float] = (6, 24, 24),
Expand All @@ -23,6 +26,8 @@ def __init__(
"""
Parameters
----------
spatial_dims : int
The number of spatial dimensions of the image (2 or 3)
mask_size : int, optional
The size of the blocks used to generate mask. Block dimensions are determined by the mask size and an aspect ratio sampled from the range `block_aspect_ratio`
block_aspect_ratio : Tuple[float], optional
Expand All @@ -32,7 +37,9 @@ def __init__(
mask_ratio : float, optional
The proportion of the image to be masked
"""
assert mask_ratio < 1, "mask_ratio must be less than 1"
assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"

num_patches = match_tuple_dimensions(spatial_dims, [num_patches])[0]
assert mask_size * max(block_aspect_ratio) < min(
num_patches[-2:]
), "mask_size * max mask aspect ratio must be less than the smallest dimension of num_patches"
Expand All @@ -46,7 +53,7 @@ def __init__(
self.mask = np.zeros(num_patches)
self.edge_mask = np.ones(num_patches)

self.spatial_dims = len(num_patches)
self.spatial_dims = spatial_dims
# create a mask that identified pixels on the edge of the image
if self.spatial_dims == 3:
self.edge_mask[1:-1, 1:-1, 1:-1] = 0
Expand Down
5 changes: 3 additions & 2 deletions cyto_dl/nn/vits/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .cross_mae import CrossMAE_Decoder
from .mae import MAE_Decoder, MAE_Encoder, MAE_ViT
from .decoder import CrossMAE_Decoder, MAE_Decoder
from .encoder import HieraEncoder, MAE_Encoder
from .mae import MAE, HieraMAE
from .seg import Seg_ViT, SuperresDecoder
from .utils import take_indexes
Loading

0 comments on commit 48e7cb9

Please sign in to comment.