add 2d vits (#330)
* add 2d vits

* update configs and fix 2d

---------

Co-authored-by: Benjamin Morris <[email protected]>
benjijamorris and Benjamin Morris authored Feb 8, 2024
1 parent 3974dc2 commit c66e57d
Showing 8 changed files with 138 additions and 78 deletions.
16 changes: 12 additions & 4 deletions configs/data/im2im/mae.yaml
@@ -19,8 +19,10 @@ transforms:
         keys: ${source_col}
         reader:
           - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: CZYX
+            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
+            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
             C: 5
+            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
       - _target_: monai.transforms.Zoomd
         keys: ${source_col}
         zoom: 0.25
@@ -45,8 +47,10 @@ transforms:
         keys: ${source_col}
         reader:
           - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: CZYX
+            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
+            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
             C: 5
+            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
       - _target_: monai.transforms.Zoomd
         keys: ${source_col}
         zoom: 0.25
@@ -65,8 +69,10 @@ transforms:
         keys: ${source_col}
         reader:
           - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: CZYX
+            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
+            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
             C: 5
+            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
       - _target_: monai.transforms.Zoomd
         keys: ${source_col}
         zoom: 0.25
@@ -85,8 +91,10 @@ transforms:
         keys: ${source_col}
         reader:
           - _target_: cyto_dl.image.io.MonaiBioReader
-            dimension_order_out: CZYX
+            # NOTE: eval is used so only the experiment file is required to change for beginning users. This is not recommended when creating your own configs.
+            dimension_order_out: ${eval:'"CZYX" if ${spatial_dims}==3 else "CYX"'}
             C: 5
+            Z: ${eval:'None if ${spatial_dims}==3 else 38'}
       - _target_: monai.transforms.Zoomd
         keys: ${source_col}
         zoom: 0.25
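
The `${eval:...}` interpolations above let the single `spatial_dims` value in the experiment config drive both the reader's output dimension order and the fixed Z slice used in 2D mode. A minimal sketch of how these strings resolve, assuming an `eval` resolver registered with OmegaConf the way cyto-dl's config machinery does (the registration line here is illustrative, not cyto-dl's actual setup code):

    from omegaconf import OmegaConf

    # Illustrative registration; cyto-dl wires up its own "eval" resolver.
    OmegaConf.register_new_resolver("eval", eval)

    cfg = OmegaConf.create(
        {
            "spatial_dims": 2,
            "dimension_order_out": "${eval:'\"CZYX\" if ${spatial_dims}==3 else \"CYX\"'}",
            "Z": "${eval:'None if ${spatial_dims}==3 else 38'}",
        }
    )
    assert cfg.dimension_order_out == "CYX"
    assert cfg.Z == 38  # in 2D mode a single Z slice (here 38) is read

With `spatial_dims: 3` the same config resolves to `CZYX` and `Z: None`, which is why only the experiment file needs to change.
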
5 changes: 2 additions & 3 deletions configs/experiment/im2im/mae.yaml
@@ -19,7 +19,6 @@ run_name: YOUR_RUN_NAME

 # only source_col is needed for masked autoencoder
 source_col: raw
-# only 3d MAE is currently supported
 spatial_dims: 3
 raw_im_channels: 1

@@ -33,6 +32,6 @@ data:
   batch_size: 1
   _aux:
     # 2D
-    # patch_shape: [64, 64]
+    # patch_shape: [16, 16]
     # 3D
-    patch_shape: [16, 32, 32]
+    patch_shape: [16, 16, 16]
5 changes: 3 additions & 2 deletions configs/experiment/im2im/vit_segmentation.yaml
@@ -21,7 +21,6 @@ experiment_name: YOUR_EXP_NAME
 run_name: YOUR_RUN_NAME
 source_col: raw
 target_col: seg
-# dimensionality of your data - VITs currently on support 3d
 spatial_dims: 3
 # number of channels in your input images
 raw_im_channels: 1
@@ -34,5 +33,7 @@ data:
   cache_dir: ${paths.data_dir}/example_experiment_data/cache
   batch_size: 1
   _aux:
+    # 2D
+    # patch_shape: [16, 16]
     # 3D
-    patch_shape: [16, 32, 32]
+    patch_shape: [16, 16, 16]
5 changes: 3 additions & 2 deletions configs/model/im2im/mae.yaml
@@ -7,9 +7,10 @@ x_key: ${source_col}

 backbone:
   _target_: cyto_dl.nn.vits.MAE_ViT
+  spatial_dims: ${spatial_dims}
   # base_patch_size* num_patches should be your patch shape
-  base_patch_size: [2, 2, 2]
-  num_patches: [8, 16, 16]
+  base_patch_size: 2
+  num_patches: 8
   emb_dim: 16
   encoder_layer: 2
   encoder_head: 1
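
`base_patch_size` and `num_patches` are now scalars that get broadcast across `spatial_dims`, and the constraint in the comment still holds per dimension: base_patch_size * num_patches should equal the crop's patch_shape. A quick sanity check against the values in these configs (the helper is hypothetical, not part of cyto-dl):

    # Hypothetical helper: scalar settings broadcast to every spatial dimension.
    def expected_patch_shape(base_patch_size: int, num_patches: int, spatial_dims: int):
        return [base_patch_size * num_patches] * spatial_dims

    assert expected_patch_shape(2, 8, 3) == [16, 16, 16]  # 3D patch_shape above
    assert expected_patch_shape(2, 8, 2) == [16, 16]      # the commented 2D option
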
12 changes: 6 additions & 6 deletions configs/model/im2im/vit_segmentation_decoder.yaml
@@ -7,15 +7,15 @@ x_key: ${source_col}

 backbone:
   _target_: cyto_dl.nn.vits.Seg_ViT
+  spatial_dims: ${spatial_dims}
   # base_patch_size* num_patches should be your patch shape
-  base_patch_size: [2, 2, 2]
-  num_patches: [8, 16, 16]
+  base_patch_size: 2
+  num_patches: 8
   emb_dim: 16
   encoder_layer: 2
   encoder_head: 1
   encoder_ckpt:
   decoder_layer: 1
-  upsample_factor: [1, 1, 1]
-  mask_ratio: 0.75
+
 task_heads: ${kv_to_dict:${model._aux._tasks}}

@@ -46,13 +46,13 @@ _aux:
       - _target_: cyto_dl.nn.BaseHead
         loss:
           _target_: cyto_dl.models.im2im.utils.InstanceSegLoss
-          dim: 3
+          dim: ${spatial_dims}
         save_raw: True
         postprocess:
           input:
             _target_: cyto_dl.models.im2im.utils.postprocessing.ActThreshLabel
             dtype: numpy.float32
           prediction:
             _target_: cyto_dl.models.im2im.utils.instance_seg.InstanceSegCluster
-            dim: 3
+            dim: ${spatial_dims}
             min_size: 100
5 changes: 4 additions & 1 deletion cyto_dl/nn/head/mae_head.py
@@ -19,7 +19,10 @@ def run_head(
         else:
             raise ValueError("MAE head is only intended for use during training.")
         loss = (batch[self.head_name] - y_hat) ** 2
-        loss = loss[mask.bool()].mean()
+        if mask.sum() > 0:
+            loss = loss[mask.bool()].mean()
+        else:
+            loss = loss.mean()

         y_hat_out, y_out, out_paths = None, None, None
         if save_image:
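
The new guard matters because `loss[mask.bool()]` is empty when no patches are masked, and the mean of an empty tensor is NaN. A standalone sketch of the failure mode and the fallback:

    import torch

    target = torch.rand(1, 1, 4, 4)
    y_hat = torch.rand(1, 1, 4, 4)
    loss = (target - y_hat) ** 2

    mask = torch.zeros_like(loss)  # no masked patches
    assert torch.isnan(loss[mask.bool()].mean())  # mean over an empty selection

    # the fixed head falls back to the unmasked mean
    loss = loss[mask.bool()].mean() if mask.sum() > 0 else loss.mean()
    assert not torch.isnan(loss)
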
93 changes: 57 additions & 36 deletions cyto_dl/nn/vits/mae.py
@@ -50,31 +50,33 @@ class Patchify(torch.nn.Module):
     Convolutional weights are resized to match the `base_patch_size`.
     """
-
-    def __init__(self, base_patch_size, emb_dim, n_patches):
+    def __init__(self, base_patch_size, emb_dim, n_patches, spatial_dims=3):
         super().__init__()
         self.n_patches = np.asarray(n_patches)
         self.weight = torch.nn.Parameter(torch.zeros(emb_dim, 1, *base_patch_size))
-        self.norm = torch.nn.LayerNorm([emb_dim, n_patches[0], n_patches[1], n_patches[2]])
+        self.norm = torch.nn.LayerNorm([emb_dim, *n_patches[:spatial_dims]])
         self.emb_dim = emb_dim
+        self.spatial_dims = spatial_dims
+        self.conv = torch.nn.functional.conv3d if spatial_dims == 3 else torch.nn.functional.conv2d

     def resample_weight(self, length):
-        return torch.nn.functional.interpolate(self.weight, size=length, mode="trilinear")
+        return torch.nn.functional.interpolate(self.weight, size=length)

     def forward(self, img):
         # all images in batch assumed to be same resolution
-        patch_size = (np.asarray(img.shape[-3:]) / self.n_patches).astype(int).tolist()
-        tokens = torch.nn.functional.conv3d(
-            img, weight=self.resample_weight(patch_size), stride=patch_size
+        patch_size = (
+            (np.asarray(img.shape[-self.spatial_dims :]) / self.n_patches).astype(int).tolist()
         )
+        tokens = self.conv(img, weight=self.resample_weight(patch_size), stride=patch_size)
         tokens = self.norm(tokens)
-        assert np.all(tokens.shape[-3:] == self.n_patches)
+        assert np.all(tokens.shape[-self.spatial_dims :] == self.n_patches)
         return tokens, patch_size


 class MAE_Encoder(torch.nn.Module):
     def __init__(
         self,
         num_patches: List[int],
+        spatial_dims: int = 3,
         base_patch_size: List[int] = (16, 16, 16),
         emb_dim: Optional[int] = 192,
         num_layer: Optional[int] = 12,
@@ -86,6 +88,8 @@ def __init__(
         ----------
         num_patches: List[int]
             Number of patches in each dimension
+        spatial_dims: int
+            Number of spatial dimensions
         base_patch_size: List[int]
             Size of each patch
         emb_dim: int
@@ -98,24 +102,20 @@
             Ratio of patches to mask out
         """
         super().__init__()
-
         self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
         self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches), 1, emb_dim))
         self.shuffle = PatchShuffle(mask_ratio)
-
-        self.patchify = Patchify(base_patch_size, emb_dim, num_patches)
+        self.patchify = Patchify(base_patch_size, emb_dim, num_patches, spatial_dims)

         self.transformer = torch.nn.Sequential(
             *[Block(emb_dim, num_head) for _ in range(num_layer)]
         )

         self.layer_norm = torch.nn.LayerNorm(emb_dim)
-        self.patch2img = Rearrange(
-            "(n_patch_z n_patch_y n_patch_x) b c -> b c n_patch_z n_patch_y n_patch_x",
-            n_patch_z=num_patches[0],
-            n_patch_y=num_patches[1],
-            n_patch_x=num_patches[2],
-        )
+        if spatial_dims == 3:
+            self.img2token = Rearrange("b c z y x -> (z y x) b c")
+        elif spatial_dims == 2:
+            self.img2token = Rearrange("b c y x -> (y x) b c")

         self.init_weight()

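The encoder now flattens conv tokens with a dimension-specific `img2token` pattern instead of a hard-coded 3D rearrange. A shape-level sketch of the two patterns used above:

    import torch
    from einops.layers.torch import Rearrange

    img2token_3d = Rearrange("b c z y x -> (z y x) b c")
    img2token_2d = Rearrange("b c y x -> (y x) b c")

    # toy token grids: batch 1, emb_dim 16, 8 patches per spatial dim
    assert img2token_3d(torch.rand(1, 16, 8, 8, 8)).shape == (512, 1, 16)
    assert img2token_2d(torch.rand(1, 16, 8, 8)).shape == (64, 1, 16)
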
@@ -125,7 +125,7 @@ def init_weight(self):

     def forward(self, img, do_mask=True):
         patches, patch_size = self.patchify(img)
-        patches = rearrange(patches, "b c z y x -> (z y x) b c")
+        patches = self.img2token(patches)
         patches = patches + self.pos_embedding

         backward_indexes = None
@@ -138,14 +138,14 @@ def forward(self, img, do_mask=True):
         features = rearrange(features, "b t c -> t b c")
         if do_mask:
             return features, backward_indexes, patch_size
-
         return features


 class MAE_Decoder(torch.nn.Module):
     def __init__(
         self,
         num_patches: List[int],
+        spatial_dims: int = 3,
         base_patch_size: Optional[List[int]] = [4, 8, 8],
         emb_dim: Optional[int] = 192,
         num_layer: Optional[int] = 4,
@@ -166,7 +166,6 @@ def __init__(
             Number of heads in transformer
         """
         super().__init__()
-
         self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, emb_dim))
         self.pos_embedding = torch.nn.Parameter(torch.zeros(np.prod(num_patches) + 1, 1, emb_dim))

@@ -177,15 +176,24 @@
         self.head = torch.nn.Linear(emb_dim, torch.prod(torch.as_tensor(base_patch_size)))
         self.num_patches = torch.as_tensor(num_patches)

-        self.patch2img = Rearrange(
-            "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
-            n_patch_z=num_patches[0],
-            n_patch_y=num_patches[1],
-            n_patch_x=num_patches[2],
-            patch_size_z=base_patch_size[0],
-            patch_size_y=base_patch_size[1],
-            patch_size_x=base_patch_size[2],
-        )
+        if spatial_dims == 3:
+            self.patch2img = Rearrange(
+                "(n_patch_z n_patch_y n_patch_x) b (c patch_size_z patch_size_y patch_size_x) -> b c (n_patch_z patch_size_z) (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
+                n_patch_z=num_patches[0],
+                n_patch_y=num_patches[1],
+                n_patch_x=num_patches[2],
+                patch_size_z=base_patch_size[0],
+                patch_size_y=base_patch_size[1],
+                patch_size_x=base_patch_size[2],
+            )
+        elif spatial_dims == 2:
+            self.patch2img = Rearrange(
+                "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
+                n_patch_y=num_patches[0],
+                n_patch_x=num_patches[1],
+                patch_size_y=base_patch_size[0],
+                patch_size_x=base_patch_size[1],
+            )

         self.init_weight()

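The decoder mirrors the encoder with a dimension-specific `patch2img`. A 2D sketch with toy sizes showing tokens folding back into an image:

    import torch
    from einops.layers.torch import Rearrange

    num_patches, base_patch_size = (8, 8), (2, 2)
    patch2img_2d = Rearrange(
        "(n_patch_y n_patch_x) b (c patch_size_y patch_size_x) -> b c (n_patch_y patch_size_y) (n_patch_x patch_size_x)",
        n_patch_y=num_patches[0],
        n_patch_x=num_patches[1],
        patch_size_y=base_patch_size[0],
        patch_size_x=base_patch_size[1],
    )

    patches = torch.rand(64, 1, 4)  # (8*8 tokens, batch, 1 channel * 2 * 2)
    assert patch2img_2d(patches).shape == (1, 1, 16, 16)
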
@@ -227,20 +235,20 @@ def forward(self, features, backward_indexes, patch_size):
         # patches to image
         img = self.patch2img(patches)
         img = torch.nn.functional.interpolate(
-            img, tuple(torch.as_tensor(patch_size) * self.num_patches), mode="trilinear"
+            img, tuple(torch.as_tensor(patch_size) * self.num_patches)
         )

         mask = self.patch2img(mask)
         mask = torch.nn.functional.interpolate(
             mask, tuple(torch.as_tensor(patch_size) * self.num_patches), mode="nearest"
         )
-
         return img, mask


 class MAE_ViT(torch.nn.Module):
     def __init__(
         self,
+        spatial_dims: int = 3,
         num_patches: Optional[List[int]] = [2, 32, 32],
         base_patch_size: Optional[List[int]] = [16, 16, 16],
         emb_dim: Optional[int] = 768,
@@ -253,6 +261,8 @@
         """
         Parameters
         ----------
+        spatial_dims: int
+            Number of spatial dimensions
         num_patches: List[int]
             Number of patches in each dimension (ZYX order)
         base_patch_size: List[int]
@@ -270,19 +280,30 @@ def __init__(
         mask_ratio: float
             Ratio of patches to mask out
         """
-
        super().__init__()
+        assert spatial_dims in (2, 3), "Spatial dims must be 2 or 3"

         if isinstance(num_patches, int):
-            num_patches = [num_patches] * 3
+            num_patches = [num_patches] * spatial_dims
         if isinstance(base_patch_size, int):
-            base_patch_size = [base_patch_size] * 3
+            base_patch_size = [base_patch_size] * spatial_dims

+        assert len(num_patches) == spatial_dims, "num_patches must be of length spatial_dims"
+        assert (
+            len(base_patch_size) == spatial_dims
+        ), "base_patch_size must be of length spatial_dims"
+
         self.encoder = MAE_Encoder(
-            num_patches, base_patch_size, emb_dim, encoder_layer, encoder_head, mask_ratio
+            num_patches,
+            spatial_dims,
+            base_patch_size,
+            emb_dim,
+            encoder_layer,
+            encoder_head,
+            mask_ratio,
         )
         self.decoder = MAE_Decoder(
-            num_patches, base_patch_size, emb_dim, decoder_layer, decoder_head
+            num_patches, spatial_dims, base_patch_size, emb_dim, decoder_layer, decoder_head
         )

     def forward(self, img):
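
Putting it together, a hedged usage sketch of the new 2D path, assuming the constructor arguments shown in this diff (with `decoder_head` left at its default) and that `forward` returns the reconstruction alongside the mask, as the MAE head's usage suggests:

    import torch
    from cyto_dl.nn.vits import MAE_ViT

    model = MAE_ViT(
        spatial_dims=2,
        num_patches=8,      # scalar broadcast to [8, 8]
        base_patch_size=2,  # scalar broadcast to [2, 2]
        emb_dim=16,
        encoder_layer=2,
        encoder_head=1,
        decoder_layer=1,
        mask_ratio=0.75,
    )
    img = torch.rand(1, 1, 16, 16)  # b c y x, with 8 patches * patch size 2 per dim
    predicted_img, mask = model(img)
    assert predicted_img.shape == img.shape
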
(The diff for the eighth changed file did not load on this page and is not shown.)