From c66e57d6b95ae17b95bdaf9b5c3574369fecfa03 Mon Sep 17 00:00:00 2001 From: benjijamorris <54606172+benjijamorris@users.noreply.github.com> Date: Thu, 8 Feb 2024 12:26:34 -0800 Subject: [PATCH] add 2d vits (#330) * add 2d vits * update configs and fix 2d --------- Co-authored-by: Benjamin Morris --- configs/data/im2im/mae.yaml | 16 +++- configs/experiment/im2im/mae.yaml | 5 +- .../experiment/im2im/vit_segmentation.yaml | 5 +- configs/model/im2im/mae.yaml | 5 +- .../model/im2im/vit_segmentation_decoder.yaml | 12 +-- cyto_dl/nn/head/mae_head.py | 5 +- cyto_dl/nn/vits/mae.py | 93 ++++++++++++------- cyto_dl/nn/vits/seg.py | 75 ++++++++++----- 8 files changed, 138 insertions(+), 78 deletions(-) diff --git a/configs/data/im2im/mae.yaml b/configs/data/im2im/mae.yaml index dfaf86b58..a310e15ba 100644 --- a/configs/data/im2im/mae.yaml +++ b/configs/data/im2im/mae.yaml @@ -19,8 +19,10 @@ transforms: keys: ${source_col} reader: - _target_: cyto_dl.image.io.MonaiBioReader - dimension_order_out: CZYX + # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs. + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} - _target_: monai.transforms.Zoomd keys: ${source_col} zoom: 0.25 @@ -45,8 +47,10 @@ transforms: keys: ${source_col} reader: - _target_: cyto_dl.image.io.MonaiBioReader - dimension_order_out: CZYX + # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs. + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} - _target_: monai.transforms.Zoomd keys: ${source_col} zoom: 0.25 @@ -65,8 +69,10 @@ transforms: keys: ${source_col} reader: - _target_: cyto_dl.image.io.MonaiBioReader - dimension_order_out: CZYX + # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs. + dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} - _target_: monai.transforms.Zoomd keys: ${source_col} zoom: 0.25 @@ -85,8 +91,10 @@ transforms: keys: ${source_col} reader: - _target_: cyto_dl.image.io.MonaiBioReader - dimension_order_out: CZYX + # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs. 
+ dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'} C: 5 + Z: ${eval:'None if ${spatial_dims}==3 else 38'} - _target_: monai.transforms.Zoomd keys: ${source_col} zoom: 0.25 diff --git a/configs/experiment/im2im/mae.yaml b/configs/experiment/im2im/mae.yaml index 5929a73b1..459335748 100644 --- a/configs/experiment/im2im/mae.yaml +++ b/configs/experiment/im2im/mae.yaml @@ -19,7 +19,6 @@ run_name: YOUR_RUN_NAME # only source_col is needed for masked autoencoder source_col: raw -# only 3d MAE is currently supported spatial_dims: 3 raw_im_channels: 1 @@ -33,6 +32,6 @@ data: batch_size: 1 _aux: # 2D - # patch_shape: [64, 64] + # patch_shape: [16, 16] # 3D - patch_shape: [16, 32, 32] + patch_shape: [16, 16, 16] diff --git a/configs/experiment/im2im/vit_segmentation.yaml b/configs/experiment/im2im/vit_segmentation.yaml index 2e66642cd..8c9c834b6 100644 --- a/configs/experiment/im2im/vit_segmentation.yaml +++ b/configs/experiment/im2im/vit_segmentation.yaml @@ -21,7 +21,6 @@ experiment_name: YOUR_EXP_NAME run_name: YOUR_RUN_NAME source_col: raw target_col: seg -# dimensionality of your data - VITs currently on support 3d spatial_dims: 3 # number of channels in your input images raw_im_channels: 1 @@ -34,5 +33,7 @@ data: cache_dir: ${paths.data_dir}/example_experiment_data/cache batch_size: 1 _aux: + # 2D + # patch_shape: [16, 16] # 3D - patch_shape: [16, 32, 32] + patch_shape: [16, 16, 16] diff --git a/configs/model/im2im/mae.yaml b/configs/model/im2im/mae.yaml index ae9b92cf3..f03ce6742 100644 --- a/configs/model/im2im/mae.yaml +++ b/configs/model/im2im/mae.yaml @@ -7,9 +7,10 @@ x_key: ${source_col} backbone: _target_: cyto_dl.nn.vits.MAE_ViT + spatial_dims: ${spatial_dims} # base_patch_size* num_patches should be your patch shape - base_patch_size: [2, 2, 2] - num_patches: [8, 16, 16] + base_patch_size: 2 + num_patches: 8 emb_dim: 16 encoder_layer: 2 encoder_head: 1 diff --git a/configs/model/im2im/vit_segmentation_decoder.yaml b/configs/model/im2im/vit_segmentation_decoder.yaml index 69d504fae..88b7a39a3 100644 --- a/configs/model/im2im/vit_segmentation_decoder.yaml +++ b/configs/model/im2im/vit_segmentation_decoder.yaml @@ -7,15 +7,15 @@ x_key: ${source_col} backbone: _target_: cyto_dl.nn.vits.Seg_ViT + spatial_dims: ${spatial_dims} # base_patch_size* num_patches should be your patch shape - base_patch_size: [2, 2, 2] - num_patches: [8, 16, 16] + base_patch_size: 2 + num_patches: 8 emb_dim: 16 encoder_layer: 2 encoder_head: 1 - encoder_ckpt: decoder_layer: 1 - upsample_factor: [1, 1, 1] + mask_ratio: 0.75 task_heads: ${kv_to_dict:${model._aux._tasks}} @@ -46,7 +46,7 @@ _aux: - _target_: cyto_dl.nn.BaseHead loss: _target_: cyto_dl.models.im2im.utils.InstanceSegLoss - dim: 3 + dim: ${spatial_dims} save_raw: True postprocess: input: @@ -54,5 +54,5 @@ _aux: dtype: numpy.float32 prediction: _target_: cyto_dl.models.im2im.utils.instance_seg.InstanceSegCluster - dim: 3 + dim: ${spatial_dims} min_size: 100 diff --git a/cyto_dl/nn/head/mae_head.py b/cyto_dl/nn/head/mae_head.py index 04a18d105..e8dbcddf2 100644 --- a/cyto_dl/nn/head/mae_head.py +++ b/cyto_dl/nn/head/mae_head.py @@ -19,7 +19,10 @@ def run_head( else: raise ValueError("MAE head is only intended for use during training.") loss = (batch[self.head_name] - y_hat) ** 2 - loss = loss[mask.bool()].mean() + if mask.sum() > 0: + loss = loss[mask.bool()].mean() + else: + loss = loss.mean() y_hat_out, y_out, out_paths = None, None, None if save_image: diff --git a/cyto_dl/nn/vits/mae.py b/cyto_dl/nn/vits/mae.py index 
b9bd0a0b1..d565a6da3 100644 --- a/cyto_dl/nn/vits/mae.py +++ b/cyto_dl/nn/vits/mae.py @@ -50,24 +50,25 @@ class Patchify(torch.nn.Module): Convolutional weights are resized to match the `base_patch_size`. """ - def __init__(self, base_patch_size, emb_dim, n_patches): + def __init__(self, base_patch_size, emb_dim, n_patches, spatial_dims=3): super().__init__() self.n_patches = np.asarray(n_patches) self.weight = torch.nn.Parameter(torch.zeros(emb_dim, 1, *base_patch_size)) - self.norm = torch.nn.LayerNorm([emb_dim, n_patches[0], n_patches[1], n_patches[2]]) + self.norm = torch.nn.LayerNorm([emb_dim, *n_patches[:spatial_dims]]) self.emb_dim = emb_dim + self.spatial_dims = spatial_dims + self.conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d def resample_weight(self, length): - return torch.nn.functional.interpolate(self.weight, size=length, mode="trilinear") + return torch.nn.functional.interpolate(self.weight, size=length) def forward(self, img): - # all images in batch assumed to be same resolution - patch_size = (np.asarray(img.shape[-3:]) / self.n_patches).astype(int).tolist() - tokens = torch.nn.functional.conv3d( - img, weight=self.resample_weight(patch_size), stride=patch_size + patch_size = ( + (np.asarray(img.shape[-self.spatial_dims :]) / self.n_patches).astype(int).tolist() ) + tokens = self.conv(img, weight=self.resample_weight(patch_size), stride=patch_size) tokens = self.norm(tokens) - assert np.all(tokens.shape[-3:] == self.n_patches) + assert np.all(tokens.shape[-self.spatial_dims :] == self.n_patches) return tokens, patch_size @@ -75,6 +76,7 @@ class MAE_Encoder(torch.nn.Module): def __init__( self, num_patches: List[int], + spatial_dims: int = 3, base_patch_size: List[int] = (16, 16, 16), emb_dim: Optional[int] = 192, num_layer: Optional[int] = 12, @@ -86,6 +88,8 @@ def __init__( ---------- num_patches: List[int] Number of patches in each dimension + spatial_dims: int + Number of spatial dimensions base_patch_size: List[int] Size of each patch emb_dim: int @@ -98,24 +102,20 @@ def __init__( Ratio of patches to mask out """ super().__init__() - self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches), 1, emb_dim)) self.shuffle = PatchShuffle(mask_ratio) - - self.patchify = Patchify(base_patch_size, emb_dim, num_patches) + self.patchify = Patchify(base_patch_size, emb_dim, num_patches, spatial_dims) self.transformer = torch.nn.Sequential( *[Block(emb_dim, num_head) for _ in range(num_layer)] ) self.layer_norm = torch.nn.LayerNorm(emb_dim) - self.patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x", - n_patch_z=num_patches[0], - n_patch_y=num_patches[1], - n_patch_x=num_patches[2], - ) + if spatial_dims == 3: + self.img2token = Rearrange("b c z y x -> (z y x) b c") + elif spatial_dims == 2: + self.img2token = Rearrange("b c y x -> (y x) b c") self.init_weight() @@ -125,7 +125,7 @@ def init_weight(self): def forward(self, img, do_mask=True): patches, patch_size = self.patchify(img) - patches = rearrange(patches, "b c z y x -> (z y x) b c") + patches = self.img2token(patches) patches = patches + self.pos_embedding backward_indexes = None @@ -138,7 +138,6 @@ def forward(self, img, do_mask=True): features = rearrange(features, "b t c -> t b c") if do_mask: return features, backward_indexes, patch_size - return features @@ -146,6 +145,7 @@ class MAE_Decoder(torch.nn.Module): def __init__( self, num_patches: List[int], + 
spatial_dims: int = 3, base_patch_size: Optional[List[int]] = [4, 8, 8], emb_dim: Optional[int] = 192, num_layer: Optional[int] = 4, @@ -166,7 +166,6 @@ def __init__( Number of heads in transformer """ super().__init__() - self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim)) self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim)) @@ -177,15 +176,24 @@ def __init__( self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size))) self.num_patches = torch.as_tensor(num_patches) - self.patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_z=num_patches[0], - n_patch_y=num_patches[1], - n_patch_x=num_patches[2], - patch_size_z=base_patch_size[0], - patch_size_y=base_patch_size[1], - patch_size_x=base_patch_size[2], - ) + if spatial_dims == 3: + self.patch2img = Rearrange( + "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + n_patch_z=num_patches[0], + n_patch_y=num_patches[1], + n_patch_x=num_patches[2], + patch_size_z=base_patch_size[0], + patch_size_y=base_patch_size[1], + patch_size_x=base_patch_size[2], + ) + elif spatial_dims == 2: + self.patch2img = Rearrange( + "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + n_patch_y=num_patches[0], + n_patch_x=num_patches[1], + patch_size_y=base_patch_size[0], + patch_size_x=base_patch_size[1], + ) self.init_weight() @@ -227,20 +235,20 @@ def forward(self, features, backward_indexes, patch_size): # patches to image img = self.patch2img(patches) img = torch.nn.functional.interpolate( - img, tuple(torch.as_tensor(patch_size) * self.num_patches), mode="trilinear" + img, tuple(torch.as_tensor(patch_size) * self.num_patches) ) mask = self.patch2img(mask) mask = torch.nn.functional.interpolate( mask, tuple(torch.as_tensor(patch_size) * self.num_patches), mode="nearest" ) - return img, mask class MAE_ViT(torch.nn.Module): def __init__( self, + spatial_dims: int = 3, num_patches: Optional[List[int]] = [2, 32, 32], base_patch_size: Optional[List[int]] = [16, 16, 16], emb_dim: Optional[int] = 768, @@ -253,6 +261,8 @@ def __init__( """ Parameters ---------- + spatial_dims: int + Number of spatial dimensions num_patches: List[int] Number of patches in each dimension (ZYX order) base_patch_size: List[int] @@ -270,19 +280,30 @@ def __init__( mask_ratio: float Ratio of patches to mask out """ - super().__init__() + assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3" if isinstance(num_patches, int): - num_patches = [num_patches] * 3 + num_patches = [num_patches] * spatial_dims if isinstance(base_patch_size, int): - base_patch_size = [base_patch_size] * 3 + base_patch_size = [base_patch_size] * spatial_dims + + assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims" + assert ( + len(base_patch_size) == spatial_dims + ), "base_patch_size must be of length spatial_dims" self.encoder = MAE_Encoder( - num_patches, base_patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio + num_patches, + spatial_dims, + base_patch_size, + emb_dim, + encoder_layer, + encoder_head, + mask_ratio, ) self.decoder = MAE_Decoder( - num_patches, base_patch_size, emb_dim, decoder_layer, decoder_head + num_patches, spatial_dims, base_patch_size, emb_dim, decoder_layer, 
decoder_head ) def forward(self, img): diff --git a/cyto_dl/nn/vits/seg.py b/cyto_dl/nn/vits/seg.py index 99155071c..c3942d2d5 100644 --- a/cyto_dl/nn/vits/seg.py +++ b/cyto_dl/nn/vits/seg.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Union import torch from einops.layers.torch import Rearrange @@ -10,17 +10,20 @@ class SupperresDecoder(torch.nn.Module): def __init__( self, + spatial_dims: int = 3, num_patches: Optional[List[int]] = [2, 32, 32], base_patch_size: Optional[List[int]] = [4, 8, 8], emb_dim: Optional[int] = 192, num_layer: Optional[int] = 3, n_decoder_filters: Optional[int] = 16, out_channels: Optional[int] = 6, - upsample_factor: Optional[List[int]] = [2.6134, 2.5005, 2.5005], + upsample_factor: Optional[Union[int, List[int]]] = [2.6134, 2.5005, 2.5005], ) -> None: """ Parameters ---------- + spatial_dims: Optional[int]=3 + Number of spatial dimensions num_patches: Optional[List[int]]=[2, 32, 32] Number of patches in each dimension (ZYX) order base_patch_size: Optional[List[int]]=[16, 16, 16] @@ -34,7 +37,7 @@ def __init__( out_channels: Optional[int] =6 Number of output channels in convolutional decoder. Should be 6 for instance segmentation. upsample_factor:Optional[List[int]] = [2.6134, 2.5005, 2.5005] - Upsampling factor for each dimension (ZYX) order. Default is AICS 20x to 100x object upsampling + Upsampling factor for each dimension (ZYX) order. Default is AICS 20x to 100x objective upsampling """ super().__init__() @@ -46,7 +49,7 @@ def __init__( num_channels = 16 self.lr_conv.append( UnetResBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=num_channels, out_channels=n_decoder_filters, stride=1, @@ -60,7 +63,7 @@ def __init__( self.upsampler = torch.nn.Sequential( UpSample( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=n_decoder_filters, out_channels=n_decoder_filters, scale_factor=upsample_factor, @@ -68,7 +71,7 @@ def __init__( interp_mode="trilinear", ), UnetResBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=n_decoder_filters, out_channels=n_decoder_filters, stride=1, @@ -77,7 +80,7 @@ def __init__( dropout=0, ), UnetOutBlock( - spatial_dims=3, + spatial_dims=spatial_dims, in_channels=n_decoder_filters, out_channels=out_channels, dropout=0, @@ -87,15 +90,24 @@ def __init__( self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size))) self.num_patches = torch.as_tensor(num_patches) - self.patch2img = Rearrange( - "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", - n_patch_z=num_patches[0], - n_patch_y=num_patches[1], - n_patch_x=num_patches[2], - patch_size_z=base_patch_size[0], - patch_size_y=base_patch_size[1], - patch_size_x=base_patch_size[2], - ) + if spatial_dims == 3: + self.patch2img = Rearrange( + "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + n_patch_z=num_patches[0], + n_patch_y=num_patches[1], + n_patch_x=num_patches[2], + patch_size_z=base_patch_size[0], + patch_size_y=base_patch_size[1], + patch_size_x=base_patch_size[2], + ) + elif spatial_dims == 2: + self.patch2img = Rearrange( + "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)", + n_patch_y=num_patches[0], + n_patch_x=num_patches[1], + patch_size_y=base_patch_size[0], + patch_size_x=base_patch_size[1], 
+ ) def forward(self, features): # remove global feature @@ -117,6 +129,7 @@ class Seg_ViT(torch.nn.Module): def __init__( self, + spatial_dims: int = 3, num_patches: Optional[List[int]] = [2, 32, 32], base_patch_size: Optional[List[int]] = [16, 16, 16], emb_dim: Optional[int] = 768, @@ -128,10 +141,13 @@ def __init__( mask_ratio: Optional[int] = 0.75, upsample_factor: Optional[List[int]] = [2.6134, 2.5005, 2.5005], encoder_ckpt: Optional[str] = None, + freeze_encoder: Optional[bool] = True, ) -> None: """ Parameters ---------- + spatial_dims: Optional[int]=3 + Number of spatial dimensions num_patches: Optional[List[int]]=[2, 32, 32] Number of patches in each dimension (ZYX) order base_patch_size: Optional[List[int]]=[16, 16, 16] @@ -156,14 +172,25 @@ def __init__( Path to pretrained ViT backbone checkpoint """ super().__init__() - + assert spatial_dims in (2, 3) if isinstance(num_patches, int): - num_patches = [num_patches] * 3 + num_patches = [num_patches] * spatial_dims if isinstance(base_patch_size, int): - base_patch_size = [base_patch_size] * 3 + base_patch_size = [base_patch_size] * spatial_dims + if isinstance(upsample_factor, int): + upsample_factor = [upsample_factor] * spatial_dims + assert len(num_patches) == spatial_dims + assert len(base_patch_size) == spatial_dims + assert len(upsample_factor) == spatial_dims self.encoder = MAE_Encoder( - num_patches, base_patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio + spatial_dims=spatial_dims, + num_patches=num_patches, + base_patch_size=base_patch_size, + emb_dim=emb_dim, + num_layer=encoder_layer, + num_head=encoder_head, + mask_ratio=mask_ratio, ) if encoder_ckpt is not None: model = torch.load(encoder_ckpt) @@ -174,12 +201,12 @@ def __init__( } self.encoder.load_state_dict(enc_state_dict) - - # freeze encoder - for param in self.encoder.parameters(): - param.requires_grad = False + if freeze_encoder: + for param in self.encoder.parameters(): + param.requires_grad = False self.decoder = SupperresDecoder( + spatial_dims, num_patches, base_patch_size, emb_dim,
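
Usage note on the new config interpolation: the data configs above derive `dimension_order_out` and `Z` from `${spatial_dims}` through an `${eval:...}` resolver. Below is a minimal sketch of that mechanism, assuming a plain `eval`-backed OmegaConf resolver; cyto_dl ships its own registration, so the `register_new_resolver` line here is only a self-contained stand-in, not the project's actual setup.

```python
from omegaconf import OmegaConf

# Stand-in registration: cyto_dl provides its own "eval" resolver; this line
# just makes the snippet runnable on its own. replace=True avoids clashing
# with a registration that may already exist.
OmegaConf.register_new_resolver("eval", eval, replace=True)

cfg = OmegaConf.create(
"""
spatial_dims: 2
dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
Z: ${eval:'None if ${spatial_dims}==3 else 38'}
"""
)
print(cfg.dimension_order_out)  # -> CYX
print(cfg.Z)                    # -> 38 (the Z kwarg is only needed when loading 2D data)
```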
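
To exercise the new 2D path end to end, `MAE_ViT` can be driven directly: integer `num_patches`/`base_patch_size` are broadcast to `spatial_dims` entries, `Patchify` switches to `conv2d`, and the decoder rebuilds a `(B, C, Y, X)` image. A minimal sketch using the sizes from `configs/model/im2im/mae.yaml` where they are visible in this patch; the decoder settings and the `(reconstruction, mask)` return value are assumptions about parts of the class not shown in these hunks.

```python
import torch

from cyto_dl.nn.vits import MAE_ViT

# Sizes follow configs/model/im2im/mae.yaml where shown in this patch; the
# decoder_layer / decoder_head values are assumptions chosen to divide emb_dim.
model = MAE_ViT(
    spatial_dims=2,
    num_patches=8,        # broadcast to [8, 8]
    base_patch_size=2,    # broadcast to [2, 2] -> 8 * 2 = 16 px per side
    emb_dim=16,
    encoder_layer=2,
    encoder_head=1,
    decoder_layer=1,
    decoder_head=1,
    mask_ratio=0.75,
)

img = torch.randn(1, 1, 16, 16)    # (B, C, Y, X), matching patch_shape: [16, 16]
reconstruction, mask = model(img)  # assumed return: reconstruction + binary patch mask
print(reconstruction.shape, mask.shape)  # both expected to match img's spatial size
```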
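
`Seg_ViT` follows the same pattern and additionally gains a `freeze_encoder` flag in place of the previous unconditional encoder freeze. A construction-only sketch for 2D; values not present in `configs/model/im2im/vit_segmentation_decoder.yaml` are assumed defaults, and checkpoint loading is left out because the expected checkpoint layout is not shown in this hunk.

```python
from cyto_dl.nn.vits import Seg_ViT

# Construction-only sketch; values mirror configs/model/im2im/vit_segmentation_decoder.yaml
# where visible, everything else is an assumed default.
backbone = Seg_ViT(
    spatial_dims=2,
    num_patches=8,        # broadcast to [8, 8]
    base_patch_size=2,    # broadcast to [2, 2]
    emb_dim=16,
    encoder_layer=2,
    encoder_head=1,
    decoder_layer=1,
    mask_ratio=0.75,
    upsample_factor=1,    # ints are broadcast per the new isinstance check
    encoder_ckpt=None,    # path to a pretrained MAE checkpoint enables warm-starting
    freeze_encoder=True,  # new flag; replaces the previous unconditional freeze
)
# Expected input layout for 2D is (B, C, Y, X), e.g. a (1, 1, 16, 16) crop.
```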