forked from Project-MONAI/MONAI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Attention block] relative positional embedding (Project-MONAI#7346)
Fixes Project-MONAI#7356 ### Description Add relative positinoal embedding in attention block as described in https://arxiv.org/pdf/2112.01526.pdf Largely inspired by https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py Can be useful for Project-MONAI#6357 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: vgrau98 <[email protected]> Signed-off-by: vgrau98 <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <[email protected]> Signed-off-by: Mark Graham <[email protected]>
- Loading branch information
1 parent
e15b570
commit fdbc611
Showing
7 changed files
with
262 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Tuple | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch import nn | ||
|
||
|
||
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: | ||
""" | ||
Get relative positional embeddings according to the relative positions of | ||
query and key sizes. | ||
Args: | ||
q_size (int): size of query q. | ||
k_size (int): size of key k. | ||
rel_pos (Tensor): relative position embeddings (L, C). | ||
Returns: | ||
Extracted positional embeddings according to relative positions. | ||
""" | ||
rel_pos_resized: torch.Tensor = torch.Tensor() | ||
max_rel_dist = int(2 * max(q_size, k_size) - 1) | ||
# Interpolate rel pos if needed. | ||
if rel_pos.shape[0] != max_rel_dist: | ||
# Interpolate rel pos. | ||
rel_pos_resized = F.interpolate( | ||
rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), size=max_rel_dist, mode="linear" | ||
) | ||
rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) | ||
else: | ||
rel_pos_resized = rel_pos | ||
|
||
# Scale the coords with short length if shapes for q and k are different. | ||
q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) | ||
k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) | ||
relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) | ||
|
||
return rel_pos_resized[relative_coords.long()] | ||
|
||
|
||
def add_decomposed_rel_pos( | ||
attn: torch.Tensor, q: torch.Tensor, rel_pos_lst: nn.ParameterList, q_size: Tuple, k_size: Tuple | ||
) -> torch.Tensor: | ||
r""" | ||
Calculate decomposed Relative Positional Embeddings from mvitv2 implementation: | ||
https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py | ||
Only 2D and 3D are supported. | ||
Encoding the relative position of tokens in the attention matrix: tokens spaced a distance | ||
`d` apart will have the same embedding value (unlike absolute positional embedding). | ||
.. math:: | ||
Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale | ||
where | ||
.. math:: | ||
E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)} | ||
with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`, | ||
respectively spatial positions of element :math:`i` and :math:`j` | ||
When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow: | ||
.. math:: | ||
R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)} | ||
with :math:`n = 1...dim` | ||
Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to | ||
:math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding. | ||
Args: | ||
attn (Tensor): attention map. | ||
q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C). | ||
rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis. | ||
q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n). | ||
k_size (Tuple): spatial sequence size of key k with (k_dim_1, ..., k_dim_n). | ||
Returns: | ||
attn (Tensor): attention logits with added relative positional embeddings. | ||
""" | ||
rh = get_rel_pos(q_size[0], k_size[0], rel_pos_lst[0]) | ||
rw = get_rel_pos(q_size[1], k_size[1], rel_pos_lst[1]) | ||
|
||
batch, _, dim = q.shape | ||
|
||
if len(rel_pos_lst) == 2: | ||
q_h, q_w = q_size[:2] | ||
k_h, k_w = k_size[:2] | ||
r_q = q.reshape(batch, q_h, q_w, dim) | ||
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, rh) | ||
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, rw) | ||
|
||
attn = (attn.view(batch, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view( | ||
batch, q_h * q_w, k_h * k_w | ||
) | ||
elif len(rel_pos_lst) == 3: | ||
q_h, q_w, q_d = q_size[:3] | ||
k_h, k_w, k_d = k_size[:3] | ||
|
||
rd = get_rel_pos(q_d, k_d, rel_pos_lst[2]) | ||
|
||
r_q = q.reshape(batch, q_h, q_w, q_d, dim) | ||
rel_h = torch.einsum("bhwdc,hkc->bhwdk", r_q, rh) | ||
rel_w = torch.einsum("bhwdc,wkc->bhwdk", r_q, rw) | ||
rel_d = torch.einsum("bhwdc,wkc->bhwdk", r_q, rd) | ||
|
||
attn = ( | ||
attn.view(batch, q_h, q_w, q_d, k_h, k_w, k_d) | ||
+ rel_h[:, :, :, :, None, None] | ||
+ rel_w[:, :, :, None, :, None] | ||
+ rel_d[:, :, :, None, None, :] | ||
).view(batch, q_h * q_w * q_d, k_h * k_w * k_d) | ||
|
||
return attn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import annotations | ||
|
||
from typing import Iterable, Tuple | ||
|
||
import torch | ||
from torch import nn | ||
|
||
from monai.networks.blocks.attention_utils import add_decomposed_rel_pos | ||
from monai.utils.misc import ensure_tuple_size | ||
|
||
|
||
class DecomposedRelativePosEmbedding(nn.Module): | ||
def __init__(self, s_input_dims: Tuple[int, int] | Tuple[int, int, int], c_dim: int, num_heads: int) -> None: | ||
""" | ||
Args: | ||
s_input_dims (Tuple): input spatial dimension. (H, W) or (H, W, D) | ||
c_dim (int): channel dimension | ||
num_heads(int): number of attention heads | ||
""" | ||
super().__init__() | ||
|
||
# validate inputs | ||
if not isinstance(s_input_dims, Iterable) or len(s_input_dims) not in [2, 3]: | ||
raise ValueError("s_input_dims must be set as follows: (H, W) or (H, W, D)") | ||
|
||
self.s_input_dims = s_input_dims | ||
self.c_dim = c_dim | ||
self.num_heads = num_heads | ||
self.rel_pos_arr = nn.ParameterList( | ||
[nn.Parameter(torch.zeros(2 * dim_input_size - 1, c_dim)) for dim_input_size in s_input_dims] | ||
) | ||
|
||
def forward(self, x: torch.Tensor, att_mat: torch.Tensor, q: torch.Tensor) -> torch.Tensor: | ||
"""""" | ||
batch = x.shape[0] | ||
h, w, d = ensure_tuple_size(self.s_input_dims, 3, 1) | ||
|
||
att_mat = add_decomposed_rel_pos( | ||
att_mat.contiguous().view(batch * self.num_heads, h * w * d, h * w * d), | ||
q.contiguous().view(batch * self.num_heads, h * w * d, -1), | ||
self.rel_pos_arr, | ||
(h, w) if d == 1 else (h, w, d), | ||
(h, w) if d == 1 else (h, w, d), | ||
) | ||
|
||
att_mat = att_mat.reshape(batch, self.num_heads, h * w * d, h * w * d) | ||
return att_mat |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters