Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/hiera #418

Merged
merged 31 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions configs/data/im2im/ijepa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ transforms:
- _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
mask_size: 4
num_patches: ${model._aux.num_patches}
spatial_dims: ${spatial_dims}

test:
_target_: monai.transforms.Compose
Expand Down Expand Up @@ -69,6 +70,7 @@ transforms:
- _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
mask_size: 4
num_patches: ${model._aux.num_patches}
spatial_dims: ${spatial_dims}

predict:
_target_: monai.transforms.Compose
Expand Down Expand Up @@ -127,6 +129,7 @@ transforms:
- _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
mask_size: 4
num_patches: ${model._aux.num_patches}
spatial_dims: ${spatial_dims}

_aux:
_scales_dict:
Expand Down
2 changes: 2 additions & 0 deletions configs/data/im2im/iwm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ transforms:
- _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
mask_size: 4
num_patches: ${model._aux.num_patches}
spatial_dims: ${spatial_dims}

test:
_target_: monai.transforms.Compose
Expand Down Expand Up @@ -176,5 +177,6 @@ transforms:
- _target_: cyto_dl.image.transforms.generate_jepa_masks.JEPAMaskGenerator
mask_size: 4
num_patches: ${model._aux.num_patches}
spatial_dims: ${spatial_dims}

_aux:
19 changes: 8 additions & 11 deletions configs/data/im2im/mae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ transforms:
- _target_: monai.transforms.NormalizeIntensityd
keys: ${source_col}
channel_wise: True
- _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
- _target_: monai.transforms.RandSpatialCropSamplesd
keys:
- ${source_col}
patch_shape: ${data._aux.patch_shape}
patch_per_image: 1
scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
roi_size: ${data._aux.patch_shape}
num_samples: 1
random_size: False

test:
_target_: monai.transforms.Compose
Expand Down Expand Up @@ -104,14 +104,11 @@ transforms:
- _target_: monai.transforms.NormalizeIntensityd
keys: ${source_col}
channel_wise: True
- _target_: cyto_dl.image.transforms.RandomMultiScaleCropd
- _target_: monai.transforms.RandSpatialCropSamplesd
keys:
- ${source_col}
patch_shape: ${data._aux.patch_shape}
patch_per_image: 1
scales_dict: ${kv_to_dict:${data._aux._scales_dict}}
roi_size: ${data._aux.patch_shape}
num_samples: 1
random_size: False

_aux:
_scales_dict:
- - ${source_col}
- [1]
52 changes: 52 additions & 0 deletions configs/experiment/im2im/hiera.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# @package _global_
# to execute this experiment run:
# python train.py experiment=example
defaults:
- override /data: im2im/mae.yaml
- override /model: im2im/hiera.yaml
- override /callbacks: default.yaml
- override /trainer: gpu.yaml
- override /logger: csv.yaml

# all parameters below will be merged with parameters from default configurations set above
# this allows you to overwrite only specified parameters

tags: ["dev"]
seed: 12345

experiment_name: YOUR_EXP_NAME
run_name: YOUR_RUN_NAME

# only source_col is needed for masked autoencoder
source_col: raw
spatial_dims: 3
raw_im_channels: 1

trainer:
max_epochs: 100
gradient_clip_val: 10

data:
path: ${paths.data_dir}/example_experiment_data/segmentation
cache_dir: ${paths.data_dir}/example_experiment_data/cache
batch_size: 1
_aux:
# 2D
# patch_shape: [16, 16]
# 3D
patch_shape: [16, 16, 16]

callbacks:
# prediction
# saving:
# _target_: cyto_dl.callbacks.ImageSaver
# save_dir: ${paths.output_dir}
# save_every_n_epochs: ${model.save_images_every_n_epochs}
# stages: ["predict"]
# save_input: False
# training
saving:
_target_: cyto_dl.callbacks.ImageSaver
save_dir: ${paths.output_dir}
save_every_n_epochs: ${model.save_images_every_n_epochs}
stages: ["train", "test", "val"]
7 changes: 0 additions & 7 deletions configs/experiment/im2im/ijepa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,3 @@ data:
# patch_shape: [16, 16]
# 3D
patch_shape: [16, 16, 16]

model:
_aux:
# 3D
num_patches: [8, 8, 8]
# 2d
# num_patches: [8, 8]
7 changes: 0 additions & 7 deletions configs/experiment/im2im/iwm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,6 @@ data:
# 3D
patch_shape: [16, 16, 16]

model:
_aux:
# 3D
num_patches: [8, 8, 8]
# 2d
# num_patches: [8, 8]

callbacks:
prediction_saver:
_target_: cyto_dl.callbacks.csv_saver.CSVSaver
Expand Down
70 changes: 70 additions & 0 deletions configs/model/im2im/hiera.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
_target_: cyto_dl.models.im2im.MultiTaskIm2Im

save_images_every_n_epochs: 1
save_dir: ${paths.output_dir}

x_key: ${source_col}

backbone:
_target_: cyto_dl.nn.vits.mae.HieraMAE
spatial_dims: ${spatial_dims}
patch_size: 2 # patch_size * num_patches should be your image shape (data._aux.patch_shape)
num_patches: 8 # patch_size * num_patches = img_shape
num_mask_units: 4 #Mask units are used for local attention. img_shape / num_mask_units = size of each mask unit in pixels, num_patches/num_mask_units = number of patches permask unit
emb_dim: 4
# NOTE: this is a very small model for testing - for best performance, the downsampling ratios, embedding dimension, number of layers and number of heads should be adjusted to your data
architecture:
# mask_unit_attention blocks - attention is only done within a mask unit and not across mask units
# the total amount of q_stride across the architecture must be less than the number of patches per mask unit
- repeat: 1 # number of times to repeat this block
q_stride: 2 # size of downsampling within a mask unit
num_heads: 1
- repeat: 1
q_stride: 1
num_heads: 2
# self attention transformer - attention is done across all patches, irrespective of which mask unit they're in
- repeat: 2
num_heads: 4
self_attention: True
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so last layer is global attention and first 2 layers are local attention? Is 3 layers the recommended hierarchy?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correct. 3 layers is small enough to test quickly. All of the models with unit tests are tiny by default in the configs and I have somewhere in the docs that you should increase the model size if you want good performance.

decoder_layer: 1
decoder_dim: 16
mask_ratio: 0.66666666666
context_pixels: 3
use_crossmae: True

task_heads: ${kv_to_dict:${model._aux._tasks}}

optimizer:
generator:
_partial_: True
_target_: torch.optim.AdamW
weight_decay: 0.05

lr_scheduler:
generator:
_partial_: True
_target_: torch.optim.lr_scheduler.OneCycleLR
max_lr: 0.0001
epochs: ${trainer.max_epochs}
steps_per_epoch: 1
pct_start: 0.1

inference_args:
sw_batch_size: 1
roi_size: ${data._aux.patch_shape}
overlap: 0
progress: True
mode: "gaussian"

_aux:
_tasks:
- - ${source_col}
- _target_: cyto_dl.nn.head.mae_head.MAEHead
loss:
postprocess:
input:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
rescale_dtype: numpy.uint8
prediction:
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
rescale_dtype: numpy.uint8
9 changes: 5 additions & 4 deletions configs/model/im2im/ijepa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ max_epochs: ${trainer.max_epochs}
save_dir: ${paths.output_dir}

encoder:
_target_: cyto_dl.nn.vits.jepa.JEPAEncoder
patch_size: 2
_target_: cyto_dl.nn.vits.encoder.JEPAEncoder
patch_size: 2 # patch_size * num_patches should equl data._aux.patch_shape
num_patches: ${model._aux.num_patches}
emb_dim: 16
num_layer: 2
num_head: 1
spatial_dims: ${spatial_dims}

predictor:
_target_: cyto_dl.nn.vits.jepa.JEPAPredictor
_target_: cyto_dl.nn.vits.predictor.JEPAPredictor
num_patches: ${model._aux.num_patches}
input_dim: ${model.encoder.emb_dim}
emb_dim: 8
Expand All @@ -34,4 +35,4 @@ lr_scheduler:
pct_start: 0.1

_aux:
num_patches:
num_patches: 8
9 changes: 5 additions & 4 deletions configs/model/im2im/iwm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@ max_epochs: ${trainer.max_epochs}
save_dir: ${paths.output_dir}

encoder:
_target_: cyto_dl.nn.vits.jepa.JEPAEncoder
patch_size: 2
_target_: cyto_dl.nn.vits.encoder.JEPAEncoder
patch_size: 2 # patch_size * num_patches should be the same as data._aux.patch_shape
num_patches: ${model._aux.num_patches}
emb_dim: 16
num_layer: 1
num_head: 1
spatial_dims: ${spatial_dims}

predictor:
_target_: cyto_dl.nn.vits.jepa.IWMPredictor
_target_: cyto_dl.nn.vits.predictor.IWMPredictor
domains: [SEC61B]
num_patches: ${model._aux.num_patches}
input_dim: ${model.encoder.emb_dim}
Expand All @@ -38,4 +39,4 @@ lr_scheduler:
pct_start: 0.1

_aux:
num_patches:
num_patches: 8
6 changes: 3 additions & 3 deletions configs/model/im2im/mae.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ save_dir: ${paths.output_dir}
x_key: ${source_col}

backbone:
_target_: cyto_dl.nn.vits.MAE_ViT
_target_: cyto_dl.nn.vits.MAE
spatial_dims: ${spatial_dims}
# base_patch_size* num_patches should be your patch shape
base_patch_size: 2
# patch_size* num_patches should be your patch shape
patch_size: 2
num_patches: 8
emb_dim: 16
encoder_layer: 2
Expand Down
4 changes: 2 additions & 2 deletions configs/model/im2im/vit_segmentation_decoder.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ x_key: ${source_col}
backbone:
_target_: cyto_dl.nn.vits.Seg_ViT
spatial_dims: ${spatial_dims}
# base_patch_size* num_patches should be your patch shape
base_patch_size: 2
# patch_size* num_patches should be your patch shape
patch_size: 2
num_patches: 8
emb_dim: 16
encoder_layer: 2
Expand Down
11 changes: 9 additions & 2 deletions cyto_dl/image/transforms/generate_jepa_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from monai.transforms import RandomizableTransform
from skimage.segmentation import find_boundaries

from cyto_dl.nn.vits.utils import validate_spatial_dims


class JEPAMaskGenerator(RandomizableTransform):
"""Transform for generating Block-contiguous masks for JEPA training.
Expand All @@ -15,6 +17,7 @@ class JEPAMaskGenerator(RandomizableTransform):

def __init__(
self,
spatial_dims: int,
mask_size: int = 12,
block_aspect_ratio: Tuple[float] = (0.5, 1.5),
num_patches: Tuple[float] = (6, 24, 24),
Expand All @@ -23,6 +26,8 @@ def __init__(
"""
Parameters
----------
spatial_dims : int
The number of spatial dimensions of the image (2 or 3)
mask_size : int, optional
The size of the blocks used to generate mask. Block dimensions are determined by the mask size and an aspect ratio sampled from the range `block_aspect_ratio`
block_aspect_ratio : Tuple[float], optional
Expand All @@ -32,7 +37,9 @@ def __init__(
mask_ratio : float, optional
The proportion of the image to be masked
"""
assert mask_ratio < 1, "mask_ratio must be less than 1"
assert 0 < mask_ratio < 1, "mask_ratio must be between 0 and 1"

num_patches = validate_spatial_dims(spatial_dims, [num_patches])[0]
assert mask_size * max(block_aspect_ratio) < min(
num_patches[-2:]
), "mask_size * max mask aspect ratio must be less than the smallest dimension of num_patches"
Expand All @@ -46,7 +53,7 @@ def __init__(
self.mask = np.zeros(num_patches)
self.edge_mask = np.ones(num_patches)

self.spatial_dims = len(num_patches)
self.spatial_dims = spatial_dims
# create a mask that identified pixels on the edge of the image
if self.spatial_dims == 3:
self.edge_mask[1:-1, 1:-1, 1:-1] = 0
Expand Down
5 changes: 3 additions & 2 deletions cyto_dl/nn/vits/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .cross_mae import CrossMAE_Decoder
from .mae import MAE_Decoder, MAE_Encoder, MAE_ViT
from .decoder import CrossMAE_Decoder, MAE_Decoder
from .encoder import HieraEncoder, MAE_Encoder
from .mae import MAE, HieraMAE
from .seg import Seg_ViT, SuperresDecoder
from .utils import take_indexes
Loading
Loading