Skip to content

Commit

Permalink
[Attention block] relative positional embedding (Project-MONAI#7346)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#7356

### Description

Add relative positinoal embedding in attention block as described in
https://arxiv.org/pdf/2112.01526.pdf
Largely inspired by
https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py

Can be useful for Project-MONAI#6357

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: vgrau98 <[email protected]>
Signed-off-by: vgrau98 <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: YunLiu <[email protected]>
Signed-off-by: Mark Graham <[email protected]>
  • Loading branch information
4 people authored and marksgraham committed Jan 30, 2024
1 parent e15b570 commit fdbc611
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 10 deletions.
6 changes: 6 additions & 0 deletions docs/source/networks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,12 @@ Blocks
.. autoclass:: monai.apps.reconstruction.networks.blocks.varnetblock.VarNetBlock
:members:

`Attention utilities`
~~~~~~~~~~~~~~~~~~~~~
.. automodule:: monai.networks.blocks.attention_utils
.. autofunction:: monai.networks.blocks.attention_utils.get_rel_pos
.. autofunction:: monai.networks.blocks.attention_utils.add_decomposed_rel_pos

N-Dim Fourier Transform
~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: monai.networks.blocks.fft_utils_t
Expand Down
128 changes: 128 additions & 0 deletions monai/networks/blocks/attention_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn


def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
"""
Get relative positional embeddings according to the relative positions of
query and key sizes.
Args:
q_size (int): size of query q.
k_size (int): size of key k.
rel_pos (Tensor): relative position embeddings (L, C).
Returns:
Extracted positional embeddings according to relative positions.
"""
rel_pos_resized: torch.Tensor = torch.Tensor()
max_rel_dist = int(2 * max(q_size, k_size) - 1)
# Interpolate rel pos if needed.
if rel_pos.shape[0] != max_rel_dist:
# Interpolate rel pos.
rel_pos_resized = F.interpolate(
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear"
)
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
else:
rel_pos_resized = rel_pos

# Scale the coords with short length if shapes for q and k are different.
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

return rel_pos_resized[relative_coords.long()]


def add_decomposed_rel_pos(
attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple
) -> torch.Tensor:
r"""
Calculate decomposed Relative Positional Embeddings from mvitv2 implementation:
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py
Only 2D and 3D are supported.
Encoding the relative position of tokens in the attention matrix: tokens spaced a distance
`d` apart will have the same embedding value (unlike absolute positional embedding).
.. math::
Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale
where
.. math::
E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)}
with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`,
respectively spatial positions of element :math:`i` and :math:`j`
When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow:
.. math::
R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}
with :math:`n = 1...dim`
Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to
:math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding.
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis.
q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).
k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n).
Returns:
attn (Tensor): attention logits with added relative positional embeddings.
"""
rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0])
rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1])

batch, _, dim = q.shape

if len(rel_pos_lst) == 2:
q_h, q_w = q_size[:2]
k_h, k_w = k_size[:2]
r_q = q.reshape(batch, q_h, q_w, dim)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw)

attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(
batch, q_h * q_w, k_h * k_w
)
elif len(rel_pos_lst) == 3:
q_h, q_w, q_d = q_size[:3]
k_h, k_w, k_d = k_size[:3]

rd = get_rel_pos(q_d, k_d, rel_pos_lst[2])

r_q = q.reshape(batch, q_h, q_w, q_d, dim)
rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh)
rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw)
rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd)

attn = (
attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d)
+ rel_h[:, :, :, :, None, None]
+ rel_w[:, :, :, None, :, None]
+ rel_d[:, :, :, None, None, :]
).view(batch, q_h * q_w * q_d, k_h * k_w * k_d)

return attn
56 changes: 56 additions & 0 deletions monai/networks/blocks/rel_pos_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import Iterable, Tuple

import torch
from torch import nn

from monai.networks.blocks.attention_utils import add_decomposed_rel_pos
from monai.utils.misc import ensure_tuple_size


class DecomposedRelativePosEmbedding(nn.Module):
def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None:
"""
Args:
s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D)
c_dim (int): channel dimension
num_heads(int): number of attention heads
"""
super().__init__()

# validate inputs
if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]:
raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)")

self.s_input_dims = s_input_dims
self.c_dim = c_dim
self.num_heads = num_heads
self.rel_pos_arr = nn.ParameterList(
[nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims]
)

def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor:
""""""
batch = x.shape[0]
h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1)

att_mat = add_decomposed_rel_pos(
att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d),
q.contiguous().view(batch * self.num_heads, h * w * d, -1),
self.rel_pos_arr,
(h, w) if d == 1 else (h, w, d),
(h, w) if d == 1 else (h, w, d),
)

att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d)
return att_mat
33 changes: 31 additions & 2 deletions monai/networks/blocks/selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn

from monai.networks.layers.utils import get_rel_pos_embedding_layer
from monai.utils import optional_import

Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
Expand All @@ -23,6 +26,7 @@ class SABlock(nn.Module):
"""
A self-attention block, based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
"""

def __init__(
Expand All @@ -32,13 +36,19 @@ def __init__(
dropout_rate: float = 0.0,
qkv_bias: bool = False,
save_attn: bool = False,
rel_pos_embedding: Optional[str] = None,
input_size: Optional[Tuple] = None,
) -> None:
"""
Args:
hidden_size (int): dimension of hidden layer.
num_heads (int): number of attention heads.
dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
positional parameter size.
save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
"""
Expand All @@ -62,11 +72,30 @@ def __init__(
self.scale = self.head_dim**-0.5
self.save_attn = save_attn
self.att_mat = torch.Tensor()
self.rel_positional_embedding = (
get_rel_pos_embedding_layer(rel_pos_embedding, input_size, self.head_dim, self.num_heads)
if rel_pos_embedding is not None
else None
)
self.input_size = input_size

def forward(self, x: torch.Tensor):
"""
Args:
x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
def forward(self, x):
Return:
torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
"""
output = self.input_rearrange(self.qkv(x))
q, k, v = output[0], output[1], output[2]
att_mat = (torch.einsum("blxd,blyd->blxy", q, k) * self.scale).softmax(dim=-1)
att_mat = torch.einsum("blxd,blyd->blxy", q, k) * self.scale

# apply relative positional embedding if defined
att_mat = self.rel_positional_embedding(x, att_mat, q) if self.rel_positional_embedding is not None else att_mat

att_mat = att_mat.softmax(dim=-1)

if self.save_attn:
# no gradients and new tensor;
# https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
Expand Down
13 changes: 12 additions & 1 deletion monai/networks/layers/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def use_factory(fact_args):
from monai.networks.utils import has_nvfuser_instance_norm
from monai.utils import ComponentStore, look_up_option, optional_import

__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "split_args"]
__all__ = ["LayerFactory", "Dropout", "Norm", "Act", "Conv", "Pool", "Pad", "RelPosEmbedding", "split_args"]


class LayerFactory(ComponentStore):
Expand Down Expand Up @@ -201,6 +201,10 @@ def split_args(args):
Conv = LayerFactory(name="Convolution layers", description="Factory for creating convolution layers.")
Pool = LayerFactory(name="Pooling layers", description="Factory for creating pooling layers.")
Pad = LayerFactory(name="Padding layers", description="Factory for creating padding layers.")
RelPosEmbedding = LayerFactory(
name="Relative positional embedding layers",
description="Factory for creating relative positional embedding factory",
)


@Dropout.factory_function("dropout")
Expand Down Expand Up @@ -468,3 +472,10 @@ def constant_pad_factory(dim: int) -> type[nn.ConstantPad1d | nn.ConstantPad2d |
"""
types = (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d)
return types[dim - 1]


@RelPosEmbedding.factory_function("decomposed")
def decomposed_rel_pos_embedding() -> type[nn.Module]:
from monai.networks.blocks.rel_pos_embedding import DecomposedRelativePosEmbedding

return DecomposedRelativePosEmbedding
15 changes: 14 additions & 1 deletion monai/networks/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@

from __future__ import annotations

from typing import Optional

import torch.nn

from monai.networks.layers.factories import Act, Dropout, Norm, Pool, split_args
from monai.networks.layers.factories import Act, Dropout, Norm, Pool, RelPosEmbedding, split_args
from monai.utils import has_option

__all__ = ["get_norm_layer", "get_act_layer", "get_dropout_layer", "get_pool_layer"]
Expand Down Expand Up @@ -124,3 +126,14 @@ def get_pool_layer(name: tuple | str, spatial_dims: int | None = 1):
pool_name, pool_args = split_args(name)
pool_type = Pool[pool_name, spatial_dims]
return pool_type(**pool_args)


def get_rel_pos_embedding_layer(name: tuple | str, s_input_dims: Optional[tuple], c_dim: int, num_heads: int):
embedding_name, embedding_args = split_args(name)
embedding_type = RelPosEmbedding[embedding_name]
# create a dictionary with the default values which can be overridden by embedding_args
kw_args = {"s_input_dims": s_input_dims, "c_dim": c_dim, "num_heads": num_heads, **embedding_args}
# filter out unused argument names
kw_args = {k: v for k, v in kw_args.items() if has_option(embedding_type, k)}

return embedding_type(**kw_args)
21 changes: 15 additions & 6 deletions tests/test_selfattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from monai.networks import eval_mode
from monai.networks.blocks.selfattention import SABlock
from monai.networks.layers.factories import RelPosEmbedding
from monai.utils import optional_import

einops, has_einops = optional_import("einops")
Expand All @@ -28,12 +29,20 @@
for dropout_rate in np.linspace(0, 1, 4):
for hidden_size in [360, 480, 600, 768]:
for num_heads in [4, 6, 8, 12]:
test_case = [
{"hidden_size": hidden_size, "num_heads": num_heads, "dropout_rate": dropout_rate},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_SABLOCK.append(test_case)
for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
for input_size in [(16, 32), (8, 8, 8)]:
test_case = [
{
"hidden_size": hidden_size,
"num_heads": num_heads,
"dropout_rate": dropout_rate,
"rel_pos_embedding": rel_pos_embedding,
"input_size": input_size,
},
(2, 512, hidden_size),
(2, 512, hidden_size),
]
TEST_CASE_SABLOCK.append(test_case)


class TestResBlock(unittest.TestCase):
Expand Down

0 comments on commit fdbc611

Please sign in to comment.