6973 sincos pos embed (Project-MONAI#6986)
Fixes Project-MONAI#6973 

### Description

Adds support for a sin-cos positional embedding in the
`monai.networks.blocks.patchembedding.PatchEmbeddingBlock` class.

This pull request addresses the open issue Project-MONAI#6973.
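
As a quick illustration of the new API (a minimal sketch, not part of the
diff; the sizes below are chosen for the example), the embedding style is
selected via the new `pos_embed_type` argument:

```python
import torch

from monai.networks.blocks import PatchEmbeddingBlock

# hidden_size must be divisible by num_heads, and by 6 for 3D sin-cos
# embeddings (by 4 in 2D) -- hence 96 here.
block = PatchEmbeddingBlock(
    in_channels=4,
    img_size=32,
    patch_size=8,
    hidden_size=96,
    num_heads=4,
    proj_type="conv",
    pos_embed_type="sincos",
    spatial_dims=3,
)

x = torch.randn(2, 4, 32, 32, 32)  # (batch, channel, H, W, D)
print(block(x).shape)              # torch.Size([2, 64, 96]): 4*4*4 patches
```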

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: NoTody <[email protected]>
NoTody authored Sep 16, 2023
1 parent 56ca224 commit 281cb01
Showing 10 changed files with 215 additions and 53 deletions.
2 changes: 1 addition & 1 deletion monai/networks/blocks/mlp.py
@@ -32,7 +32,7 @@ def __init__(
Args:
hidden_size: dimension of hidden layer.
mlp_dim: dimension of feedforward layer. If 0, `hidden_size` will be used.
dropout_rate: faction of the input units to drop.
dropout_rate: fraction of the input units to drop.
act: activation type and arguments. Defaults to GELU. Also supports "GEGLU" and others.
dropout_mode: dropout mode, can be "vit" or "swin".
"vit" mode uses two dropout instances as implemented in
51 changes: 37 additions & 14 deletions monai/networks/blocks/patchembedding.py
@@ -19,12 +19,14 @@
import torch.nn.functional as F
from torch.nn import LayerNorm

from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding
from monai.networks.layers import Conv, trunc_normal_
from monai.utils import ensure_tuple_rep, optional_import
from monai.utils import deprecated_arg, ensure_tuple_rep, optional_import
from monai.utils.module import look_up_option

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
SUPPORTED_EMBEDDING_TYPES = {"conv", "perceptron"}
SUPPORTED_PATCH_EMBEDDING_TYPES = {"conv", "perceptron"}
SUPPORTED_POS_EMBEDDING_TYPES = {"none", "learnable", "sincos"}


class PatchEmbeddingBlock(nn.Module):
@@ -35,18 +37,22 @@ class PatchEmbeddingBlock(nn.Module):
Example::
>>> from monai.networks.blocks import PatchEmbeddingBlock
>>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=32, num_heads=4, pos_embed="conv")
>>> PatchEmbeddingBlock(in_channels=4, img_size=32, patch_size=8, hidden_size=96, num_heads=4,
>>>                     proj_type="conv", pos_embed_type="sincos")
"""

@deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
def __init__(
self,
in_channels: int,
img_size: Sequence[int] | int,
patch_size: Sequence[int] | int,
hidden_size: int,
num_heads: int,
pos_embed: str,
pos_embed: str = "conv",
proj_type: str = "conv",
pos_embed_type: str = "learnable",
dropout_rate: float = 0.0,
spatial_dims: int = 3,
) -> None:
@@ -57,11 +63,12 @@ def __init__(
patch_size: dimension of patch size.
hidden_size: dimension of hidden layer.
num_heads: number of attention heads.
pos_embed: position embedding layer type.
dropout_rate: faction of the input units to drop.
proj_type: patch embedding layer type.
pos_embed_type: position embedding layer type.
dropout_rate: fraction of the input units to drop.
spatial_dims: number of spatial dimensions.
.. deprecated:: 1.4
``pos_embed`` is deprecated in favor of ``proj_type``.
"""

super().__init__()
@@ -72,24 +79,25 @@ def __init__(
if hidden_size % num_heads != 0:
raise ValueError(f"hidden size {hidden_size} should be divisible by num_heads {num_heads}.")

self.pos_embed = look_up_option(pos_embed, SUPPORTED_EMBEDDING_TYPES)
self.proj_type = look_up_option(proj_type, SUPPORTED_PATCH_EMBEDDING_TYPES)
self.pos_embed_type = look_up_option(pos_embed_type, SUPPORTED_POS_EMBEDDING_TYPES)

img_size = ensure_tuple_rep(img_size, spatial_dims)
patch_size = ensure_tuple_rep(patch_size, spatial_dims)
for m, p in zip(img_size, patch_size):
if m < p:
raise ValueError("patch_size should be smaller than img_size.")
if self.pos_embed == "perceptron" and m % p != 0:
if self.proj_type == "perceptron" and m % p != 0:
raise ValueError("patch_size should be divisible by img_size for perceptron.")
self.n_patches = np.prod([im_d // p_d for im_d, p_d in zip(img_size, patch_size)])
self.patch_dim = int(in_channels * np.prod(patch_size))

self.patch_embeddings: nn.Module
if self.pos_embed == "conv":
if self.proj_type == "conv":
self.patch_embeddings = Conv[Conv.CONV, spatial_dims](
in_channels=in_channels, out_channels=hidden_size, kernel_size=patch_size, stride=patch_size
)
elif self.pos_embed == "perceptron":
elif self.proj_type == "perceptron":
# for 3d: "b c (h p1) (w p2) (d p3)-> b (h w d) (p1 p2 p3 c)"
chars = (("h", "p1"), ("w", "p2"), ("d", "p3"))[:spatial_dims]
from_chars = "b c " + " ".join(f"({k} {v})" for k, v in chars)
@@ -100,7 +108,22 @@ def __init__(
)
self.position_embeddings = nn.Parameter(torch.zeros(1, self.n_patches, hidden_size))
self.dropout = nn.Dropout(dropout_rate)
trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)

if self.pos_embed_type == "none":
pass
elif self.pos_embed_type == "learnable":
trunc_normal_(self.position_embeddings, mean=0.0, std=0.02, a=-2.0, b=2.0)
elif self.pos_embed_type == "sincos":
grid_size = []
for in_size, pa_size in zip(img_size, patch_size):
grid_size.append(in_size // pa_size)

with torch.no_grad():
pos_embeddings = build_sincos_position_embedding(grid_size, hidden_size, spatial_dims)
self.position_embeddings.data.copy_(pos_embeddings.float())
else:
raise ValueError(f"pos_embed_type {self.pos_embed_type} not supported.")

self.apply(self._init_weights)

def _init_weights(self, m):
@@ -114,7 +137,7 @@ def _init_weights(self, m):

def forward(self, x):
x = self.patch_embeddings(x)
if self.pos_embed == "conv":
if self.proj_type == "conv":
x = x.flatten(2).transpose(-1, -2)
embeddings = x + self.position_embeddings
embeddings = self.dropout(embeddings)
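
Taken together, the constructor now initializes `position_embeddings` in one
of three ways: left at zeros (`"none"`), truncated-normal (`"learnable"`, the
previous behavior), or filled from a fixed sin-cos table (`"sincos"`). A small
sketch of the 2D case (sizes are illustrative, not from the diff):

```python
from monai.networks.blocks import PatchEmbeddingBlock

blk = PatchEmbeddingBlock(
    in_channels=1, img_size=(32, 32), patch_size=(8, 8),
    hidden_size=64, num_heads=4,
    proj_type="perceptron", pos_embed_type="sincos", spatial_dims=2,
)

# 32/8 = 4 patches per axis -> 16 patches, each embedded in 64 dimensions
print(blk.position_embeddings.shape)  # torch.Size([1, 16, 64])

# copy_() fills the existing parameter with sin-cos values; the parameter
# object itself stays trainable unless the caller freezes it.
print(blk.position_embeddings.requires_grad)  # True
```
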
103 changes: 103 additions & 0 deletions monai/networks/blocks/pos_embed_utils.py
@@ -0,0 +1,103 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import collections.abc
from itertools import repeat
from typing import List, Union

import torch
import torch.nn as nn

__all__ = ["build_sincos_position_embedding"]


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def build_sincos_position_embedding(
    grid_size: Union[int, List[int]], embed_dim: int, spatial_dims: int = 3, temperature: float = 10000.0
) -> torch.nn.Parameter:
    """
    Builds a sin-cos position embedding based on the given grid size, embed dimension, spatial dimensions, and temperature.
    Reference: https://github.com/cvlab-stonybrook/SelfMedMAE/blob/68d191dfcc1c7d0145db93a6a570362de29e3b30/lib/models/mae3d.py

    Args:
        grid_size (List[int]): The size of the grid in each spatial dimension.
        embed_dim (int): The dimension of the embedding.
        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
        temperature (float): The temperature for the sin-cos position embedding.

    Returns:
        pos_embed (nn.Parameter): The sin-cos position embedding as a fixed (non-trainable) parameter.
    """

    if spatial_dims == 2:
        to_2tuple = _ntuple(2)
        grid_size_t = to_2tuple(grid_size)
        h, w = grid_size_t
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w = torch.arange(w, dtype=torch.float32)

        grid_h, grid_w = torch.meshgrid(grid_h, grid_w, indexing="ij")

        assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"

        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1.0 / (temperature**omega)
        out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
        out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
        pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
    elif spatial_dims == 3:
        to_3tuple = _ntuple(3)
        grid_size_t = to_3tuple(grid_size)
        h, w, d = grid_size_t
        grid_h = torch.arange(h, dtype=torch.float32)
        grid_w = torch.arange(w, dtype=torch.float32)
        grid_d = torch.arange(d, dtype=torch.float32)

        grid_h, grid_w, grid_d = torch.meshgrid(grid_h, grid_w, grid_d, indexing="ij")

        assert embed_dim % 6 == 0, "Embed dimension must be divisible by 6 for 3D sin-cos position embedding"

        pos_dim = embed_dim // 6
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1.0 / (temperature**omega)
        out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
        out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
        out_d = torch.einsum("m,d->md", [grid_d.flatten(), omega])
        pos_emb = torch.cat(
            [
                torch.sin(out_w),
                torch.cos(out_w),
                torch.sin(out_h),
                torch.cos(out_h),
                torch.sin(out_d),
                torch.cos(out_d),
            ],
            dim=1,
        )[None, :, :]
    else:
        raise NotImplementedError(f"Spatial Dimension Size {spatial_dims} Not Implemented!")

    pos_embed = nn.Parameter(pos_emb)
    pos_embed.requires_grad = False

    return pos_embed
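
For each spatial axis the helper allocates `embed_dim // (2 * spatial_dims)`
sine channels and as many cosine channels, at frequencies
`1 / temperature ** (i / pos_dim)` as in the standard transformer encoding.
A quick sanity check of the new function (sizes chosen for illustration):

```python
from monai.networks.blocks.pos_embed_utils import build_sincos_position_embedding

# 4x4x4 grid of patches, embedding dim 96 (divisible by 6, as required in 3D)
pos = build_sincos_position_embedding(grid_size=[4, 4, 4], embed_dim=96, spatial_dims=3)

print(pos.shape)          # torch.Size([1, 64, 96]) -- one row per grid cell
print(pos.requires_grad)  # False: returned as a frozen nn.Parameter
```
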
2 changes: 1 addition & 1 deletion monai/networks/blocks/selfattention.py
@@ -37,7 +37,7 @@ def __init__(
Args:
hidden_size (int): dimension of hidden layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
2 changes: 1 addition & 1 deletion monai/networks/blocks/transformerblock.py
@@ -37,7 +37,7 @@ def __init__(
hidden_size (int): dimension of hidden layer.
mlp_dim (int): dimension of feedforward layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
2 changes: 1 addition & 1 deletion monai/networks/nets/transchex.py
@@ -314,7 +314,7 @@ def __init__(
num_language_layers: number of language transformer layers.
num_vision_layers: number of vision transformer layers.
num_mixed_layers: number of mixed transformer layers.
drop_out: faction of the input units to drop.
drop_out: fraction of the input units to drop.
The other parameters are part of the `bert_config` to `MultiModal.from_pretrained`.
13 changes: 9 additions & 4 deletions monai/networks/nets/unetr.py
@@ -18,7 +18,7 @@
from monai.networks.blocks.dynunet_block import UnetOutBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock
from monai.networks.nets.vit import ViT
from monai.utils import ensure_tuple_rep
from monai.utils import deprecated_arg, ensure_tuple_rep


class UNETR(nn.Module):
@@ -27,6 +27,7 @@ class UNETR(nn.Module):
UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>"
"""

@deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
def __init__(
self,
in_channels: int,
@@ -37,6 +38,7 @@ def __init__(
mlp_dim: int = 3072,
num_heads: int = 12,
pos_embed: str = "conv",
proj_type: str = "conv",
norm_name: tuple | str = "instance",
conv_block: bool = True,
res_block: bool = True,
@@ -54,7 +56,7 @@ def __init__(
hidden_size: dimension of hidden layer. Defaults to 768.
mlp_dim: dimension of feedforward layer. Defaults to 3072.
num_heads: number of attention heads. Defaults to 12.
pos_embed: position embedding layer type. Defaults to "conv".
proj_type: patch embedding layer type. Defaults to "conv".
norm_name: feature normalization type and arguments. Defaults to "instance".
conv_block: if convolutional block is used. Defaults to True.
res_block: if residual block is used. Defaults to True.
@@ -63,6 +65,9 @@ def __init__(
qkv_bias: apply the bias term for the qkv linear layer in self attention block. Defaults to False.
save_attn: to make accessible the attention in self attention block. Defaults to False.
.. deprecated:: 1.4
``pos_embed`` is deprecated in favor of ``proj_type``.
Examples::
# for single channel input 4-channel output with image size of (96,96,96), feature size of 32 and batch norm
@@ -72,7 +77,7 @@ def __init__(
>>> net = UNETR(in_channels=1, out_channels=4, img_size=96, feature_size=32, norm_name='batch', spatial_dims=2)
# for 4-channel input 3-channel output with image size of (128,128,128), conv position embedding and instance norm
>>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance')
>>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), proj_type='conv', norm_name='instance')
"""

@@ -98,7 +103,7 @@ def __init__(
mlp_dim=mlp_dim,
num_layers=self.num_layers,
num_heads=num_heads,
pos_embed=pos_embed,
proj_type=proj_type,
classification=self.classification,
dropout_rate=dropout_rate,
spatial_dims=spatial_dims,
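
The `deprecated_arg` decorator keeps the old keyword working during the
deprecation window. A sketch of both spellings (the exact warning class comes
from MONAI's deprecation utilities):

```python
from monai.networks.nets import UNETR

# Preferred spelling after this change:
net = UNETR(in_channels=1, out_channels=4, img_size=(96, 96, 96), proj_type="conv")

# Still accepted, but warns and is forwarded to proj_type:
net_legacy = UNETR(in_channels=1, out_channels=4, img_size=(96, 96, 96), pos_embed="conv")
```
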
22 changes: 16 additions & 6 deletions monai/networks/nets/vit.py
@@ -18,6 +18,7 @@

from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.transformerblock import TransformerBlock
from monai.utils import deprecated_arg

__all__ = ["ViT"]

@@ -30,6 +31,7 @@ class ViT(nn.Module):
ViT supports Torchscript but only works for Pytorch after 1.8.
"""

@deprecated_arg(name="pos_embed", since="1.2", new_name="proj_type", msg_suffix="please use `proj_type` instead.")
def __init__(
self,
in_channels: int,
@@ -40,6 +42,8 @@ def __init__(
num_layers: int = 12,
num_heads: int = 12,
pos_embed: str = "conv",
proj_type: str = "conv",
pos_embed_type: str = "learnable",
classification: bool = False,
num_classes: int = 2,
dropout_rate: float = 0.0,
@@ -57,27 +61,32 @@ def __init__(
mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
num_layers (int, optional): number of transformer blocks. Defaults to 12.
num_heads (int, optional): number of attention heads. Defaults to 12.
pos_embed (str, optional): position embedding layer type. Defaults to "conv".
proj_type (str, optional): patch embedding layer type. Defaults to "conv".
pos_embed_type (str, optional): position embedding type. Defaults to "learnable".
classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
num_classes (int, optional): number of classes if classification is used. Defaults to 2.
dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
post_activation (str, optional): add a final acivation function to the classification head
when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
Set to other values to remove this function.
qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.
.. deprecated:: 1.4
``pos_embed`` is deprecated in favor of ``proj_type``.
Examples::
# for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
>>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
>>> net = ViT(in_channels=1, img_size=(96,96,96), proj_type='conv', pos_embed_type='sincos')
# for 3-channel with image size of (128,128,128), 24 layers and classification backbone
>>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
>>> net = ViT(in_channels=3, img_size=(128,128,128), proj_type='conv', pos_embed_type='sincos', classification=True)
# for 3-channel with image size of (224,224), 12 layers and classification backbone
>>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
>>> net = ViT(in_channels=3, img_size=(224,224), proj_type='conv', pos_embed_type='sincos', classification=True,
>>> spatial_dims=2)
"""

@@ -96,7 +105,8 @@ def __init__(
patch_size=patch_size,
hidden_size=hidden_size,
num_heads=num_heads,
pos_embed=pos_embed,
proj_type=proj_type,
pos_embed_type=pos_embed_type,
dropout_rate=dropout_rate,
spatial_dims=spatial_dims,
)
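
With both new arguments wired through, a ViT can combine a convolutional
patch projection with fixed sin-cos position embeddings. A minimal sketch
(note that `patch_size` is a required constructor argument even though the
docstring examples omit it):

```python
from monai.networks.nets import ViT

net = ViT(
    in_channels=3, img_size=(224, 224), patch_size=(16, 16),
    proj_type="conv", pos_embed_type="sincos",
    classification=True, spatial_dims=2,
)
```
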
