From 1874fcdb2e98da14ec9a3d0ca1db0ded19b10603 Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Wed, 9 Oct 2024 15:06:22 -0700
Subject: [PATCH 1/3] Streamlined rearrange

Signed-off-by: Boris Fomitchev
---
 monai/networks/blocks/spatialattention.py | 23 ++++-------------------
 1 file changed, 4 insertions(+), 19 deletions(-)

diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py
index 665442b55e..c6ce8487e0 100644
--- a/monai/networks/blocks/spatialattention.py
+++ b/monai/networks/blocks/spatialattention.py
@@ -19,7 +19,6 @@
 from monai.networks.blocks import SABlock
 from monai.utils import optional_import
 
-Rearrange, _ = optional_import("einops.layers.torch", name="Rearrange")
 
 
 class SpatialAttentionBlock(nn.Module):
@@ -71,27 +70,13 @@ def __init__(
             use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
         )
-
+
     def forward(self, x: torch.Tensor):
         residual = x
-
-        if self.spatial_dims == 1:
-            h = x.shape[2]
-            rearrange_input = Rearrange("b c h -> b h c")
-            rearrange_output = Rearrange("b h c -> b c h", h=h)
-        if self.spatial_dims == 2:
-            h, w = x.shape[2], x.shape[3]
-            rearrange_input = Rearrange("b c h w -> b (h w) c")
-            rearrange_output = Rearrange("b (h w) c -> b c h w", h=h, w=w)
-        else:
-            h, w, d = x.shape[2], x.shape[3], x.shape[4]
-            rearrange_input = Rearrange("b c h w d -> b (h w d) c")
-            rearrange_output = Rearrange("b (h w d) c -> b c h w d", h=h, w=w, d=d)
-
+        shape = x.shape
         x = self.norm(x)
-        x = rearrange_input(x)  # B x (x_dim * y_dim [ * z_dim]) x C
-
+        x = x.reshape(*shape[:2], -1).transpose(1,2)  # "b c h w d -> b (h w d) c"
         x = self.attn(x)
-        x = rearrange_output(x)  # B x x C x x_dim * y_dim * [z_dim]
+        x = x.transpose(1,2).reshape(shape)  # "b (h w d) c -> b c h w d"
         x = x + residual
         return x

From e059f3c9b0c2e886b00555be3a8ac5bfdb838a76 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 9 Oct 2024 22:15:11 +0000
Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 monai/networks/blocks/spatialattention.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py
index c6ce8487e0..8037eddcd1 100644
--- a/monai/networks/blocks/spatialattention.py
+++ b/monai/networks/blocks/spatialattention.py
@@ -17,7 +17,6 @@
 import torch.nn as nn
 
 from monai.networks.blocks import SABlock
-from monai.utils import optional_import
 
 
 
@@ -70,7 +69,7 @@ def __init__(
             use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
         )
-
+
     def forward(self, x: torch.Tensor):
         residual = x
         shape = x.shape

From 6b92c8e1c9cf54278390ff95ae6a3aeed583c898 Mon Sep 17 00:00:00 2001
From: Boris Fomitchev
Date: Thu, 10 Oct 2024 00:34:49 -0700
Subject: [PATCH 3/3] reformat

Signed-off-by: Boris Fomitchev
---
 monai/networks/blocks/spatialattention.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/monai/networks/blocks/spatialattention.py b/monai/networks/blocks/spatialattention.py
index c6ce8487e0..903d11942d 100644
--- a/monai/networks/blocks/spatialattention.py
+++ b/monai/networks/blocks/spatialattention.py
@@ -20,7 +20,6 @@
 from monai.utils import optional_import
 
 
-
 
 class SpatialAttentionBlock(nn.Module):
     """Perform spatial self-attention on the input tensor.
@@ -70,13 +69,13 @@ def __init__(
             use_combined_linear=use_combined_linear,
             use_flash_attention=use_flash_attention,
         )
-
+
     def forward(self, x: torch.Tensor):
         residual = x
         shape = x.shape
         x = self.norm(x)
-        x = x.reshape(*shape[:2], -1).transpose(1,2)  # "b c h w d -> b (h w d) c"
+        x = x.reshape(*shape[:2], -1).transpose(1, 2)  # "b c h w d -> b (h w d) c"
         x = self.attn(x)
-        x = x.transpose(1,2).reshape(shape)  # "b (h w d) c -> b c h w d"
+        x = x.transpose(1, 2).reshape(shape)  # "b (h w d) c -> b c h w d"
         x = x + residual
         return x
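
For reference, a minimal standalone sketch (not part of the patch series; it assumes einops is installed and uses made-up tensor sizes) of the equivalence the change relies on: reshape/transpose produces the same flattening and un-flattening as the removed Rearrange layers for a 3-D spatial input.

    # Illustrative check, not project code: compare the new reshape/transpose path
    # against the einops Rearrange layers it replaces, for a (b, c, h, w, d) tensor.
    import torch
    from einops.layers.torch import Rearrange

    x = torch.randn(2, 8, 4, 5, 6)  # hypothetical sizes: b=2, c=8, h=4, w=5, d=6
    shape = x.shape

    flat_new = x.reshape(*shape[:2], -1).transpose(1, 2)      # "b c h w d -> b (h w d) c"
    flat_old = Rearrange("b c h w d -> b (h w d) c")(x)
    assert torch.equal(flat_new, flat_old)

    back_new = flat_new.transpose(1, 2).reshape(shape)        # "b (h w d) c -> b c h w d"
    back_old = Rearrange("b (h w d) c -> b c h w d", h=4, w=5, d=6)(flat_old)
    assert torch.equal(back_new, back_old)

Besides removing the einops dependency at this call site, the plain tensor ops avoid constructing new Rearrange modules on every forward pass.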