-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Bump version: 0.1.5 → 0.1.6 * add hiera * start of mask2former * first take at transfomer * fix dimensionality, now updating instance queries instead of mask * give instance queries own dim * add mask creation * remove experimental code * update to base patchify * wip * update configs * update patchify * rearrange encoder/decoder/mae * add 2d hiera * 2d masked unit attention * precommit * update configs * update hiera model config * update deafults * update tests * delete patchify_conv * fix jepa tests * update mask transform * add predictor * precommit * update with ritviks comments * replace all function names --------- Co-authored-by: Benjamin Morris <[email protected]> Co-authored-by: Benjamin Morris <[email protected]>
- Loading branch information
1 parent
5f45508
commit 48e7cb9
Showing
26 changed files
with
1,508 additions
and
555 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# @package _global_ | ||
# to execute this experiment run: | ||
# python train.py experiment=example | ||
defaults: | ||
- override /data: im2im/mae.yaml | ||
- override /model: im2im/hiera.yaml | ||
- override /callbacks: default.yaml | ||
- override /trainer: gpu.yaml | ||
- override /logger: csv.yaml | ||
|
||
# all parameters below will be merged with parameters from default configurations set above | ||
# this allows you to overwrite only specified parameters | ||
|
||
tags: ["dev"] | ||
seed: 12345 | ||
|
||
experiment_name: YOUR_EXP_NAME | ||
run_name: YOUR_RUN_NAME | ||
|
||
# only source_col is needed for masked autoencoder | ||
source_col: raw | ||
spatial_dims: 3 | ||
raw_im_channels: 1 | ||
|
||
trainer: | ||
max_epochs: 100 | ||
gradient_clip_val: 10 | ||
|
||
data: | ||
path: ${paths.data_dir}/example_experiment_data/segmentation | ||
cache_dir: ${paths.data_dir}/example_experiment_data/cache | ||
batch_size: 1 | ||
_aux: | ||
# 2D | ||
# patch_shape: [16, 16] | ||
# 3D | ||
patch_shape: [16, 16, 16] | ||
|
||
callbacks: | ||
# prediction | ||
# saving: | ||
# _target_: cyto_dl.callbacks.ImageSaver | ||
# save_dir: ${paths.output_dir} | ||
# save_every_n_epochs: ${model.save_images_every_n_epochs} | ||
# stages: ["predict"] | ||
# save_input: False | ||
# training | ||
saving: | ||
_target_: cyto_dl.callbacks.ImageSaver | ||
save_dir: ${paths.output_dir} | ||
save_every_n_epochs: ${model.save_images_every_n_epochs} | ||
stages: ["train", "test", "val"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
_target_: cyto_dl.models.im2im.MultiTaskIm2Im | ||
|
||
save_images_every_n_epochs: 1 | ||
save_dir: ${paths.output_dir} | ||
|
||
x_key: ${source_col} | ||
|
||
backbone: | ||
_target_: cyto_dl.nn.vits.mae.HieraMAE | ||
spatial_dims: ${spatial_dims} | ||
patch_size: 2 # patch_size * num_patches should be your image shape (data._aux.patch_shape) | ||
num_patches: 8 # patch_size * num_patches = img_shape | ||
num_mask_units: 4 #Mask units are used for local attention. img_shape / num_mask_units = size of each mask unit in pixels, num_patches/num_mask_units = number of patches permask unit | ||
emb_dim: 4 | ||
# NOTE: this is a very small model for testing - for best performance, the downsampling ratios, embedding dimension, number of layers and number of heads should be adjusted to your data | ||
architecture: | ||
# mask_unit_attention blocks - attention is only done within a mask unit and not across mask units | ||
# the total amount of q_stride across the architecture must be less than the number of patches per mask unit | ||
- repeat: 1 # number of times to repeat this block | ||
q_stride: 2 # size of downsampling within a mask unit | ||
num_heads: 1 | ||
- repeat: 1 | ||
q_stride: 1 | ||
num_heads: 2 | ||
# self attention transformer - attention is done across all patches, irrespective of which mask unit they're in | ||
- repeat: 2 | ||
num_heads: 4 | ||
self_attention: True | ||
decoder_layer: 1 | ||
decoder_dim: 16 | ||
mask_ratio: 0.66666666666 | ||
context_pixels: 3 | ||
use_crossmae: True | ||
|
||
task_heads: ${kv_to_dict:${model._aux._tasks}} | ||
|
||
optimizer: | ||
generator: | ||
_partial_: True | ||
_target_: torch.optim.AdamW | ||
weight_decay: 0.05 | ||
|
||
lr_scheduler: | ||
generator: | ||
_partial_: True | ||
_target_: torch.optim.lr_scheduler.OneCycleLR | ||
max_lr: 0.0001 | ||
epochs: ${trainer.max_epochs} | ||
steps_per_epoch: 1 | ||
pct_start: 0.1 | ||
|
||
inference_args: | ||
sw_batch_size: 1 | ||
roi_size: ${data._aux.patch_shape} | ||
overlap: 0 | ||
progress: True | ||
mode: "gaussian" | ||
|
||
_aux: | ||
_tasks: | ||
- - ${source_col} | ||
- _target_: cyto_dl.nn.head.mae_head.MAEHead | ||
loss: | ||
postprocess: | ||
input: | ||
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel | ||
rescale_dtype: numpy.uint8 | ||
prediction: | ||
_target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel | ||
rescale_dtype: numpy.uint8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .cross_mae import CrossMAE_Decoder | ||
from .mae import MAE_Decoder, MAE_Encoder, MAE_ViT | ||
from .decoder import CrossMAE_Decoder, MAE_Decoder | ||
from .encoder import HieraEncoder, MAE_Encoder | ||
from .mae import MAE, HieraMAE | ||
from .seg import Seg_ViT, SuperresDecoder | ||
from .utils import take_indexes |
Oops, something went wrong.