feat: 3d local window attention
Signed-off-by: vgrau98 <[email protected]>
vgrau98 committed Apr 28, 2024
1 parent c2f121b commit 46ee2b0
Showing 1 changed file with 75 additions and 7 deletions.
82 changes: 75 additions & 7 deletions monai/networks/blocks/transformerblock.py
@@ -90,17 +90,32 @@ def forward(self, x: torch.Tensor):
                    f"Input tensor spatial dimension {x.shape[1]} should be equal to {self.input_size} product"
                )

            h, w = self.input_size
            x = rearrange(x, "b (h w) d -> b h w d", h=h, w=w)
            x, pad_hw = window_partition(x, self.window_size)
            x = rearrange(x, "b h w d -> b (h w) d", h=self.window_size, w=self.window_size)
            if len(self.input_size) == 2:
                x = rearrange(x, "b (h w) c -> b h w c", h=self.input_size[0], w=self.input_size[1])
                x, pad_hw = window_partition(x, self.window_size)
                x = rearrange(x, "b h w c -> b (h w) c", h=self.window_size, w=self.window_size)
            elif len(self.input_size) == 3:
                x = rearrange(
                    x, "b (h w d) c -> b h w d c", h=self.input_size[0], w=self.input_size[1], d=self.input_size[2]
                )
                x, pad_hwd = window_partition_3d(x, self.window_size)
                x = rearrange(x, "b h w d c -> b (h w d) c", h=self.window_size, w=self.window_size, d=self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = rearrange(x, "b (h w) d -> b h w d", h=self.window_size, w=self.window_size)
            x = window_unpartition(x, self.window_size, pad_hw, (h, w))
            x = rearrange(x, "b h w d -> b (h w) d", h=h, w=w)
            if len(self.input_size) == 2:
                x = rearrange(x, "b (h w) c -> b h w c", h=self.window_size, w=self.window_size)
                x = window_unpartition(x, self.window_size, pad_hw, (self.input_size[0], self.input_size[1]))
                x = rearrange(x, "b h w c -> b (h w) c", h=self.input_size[0], w=self.input_size[1])
            elif len(self.input_size) == 3:
                x = rearrange(x, "b (h w d) c -> b h w d c", h=self.window_size, w=self.window_size, d=self.window_size)
                x = window_unpartition_3d(
                    x, self.window_size, pad_hwd, (self.input_size[0], self.input_size[1], self.input_size[2])
                )
                x = rearrange(
                    x, "b h w d c -> b (h w d) c", h=self.input_size[0], w=self.input_size[1], d=self.input_size[2]
                )

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
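
The 3-d branch above follows the same recipe as the existing 2-d path: reshape the token sequence into a volume, cut it into cubic windows so attention is computed locally within each window, then stitch the windows back together. Below is a minimal, self-contained sketch of that shape bookkeeping using einops only; the sizes are illustrative and chosen divisible by the window size so no padding is involved, and the grouping order matches the permute used by window_partition_3d further down.

import torch
from einops import rearrange

b, c = 2, 16
h = w = d = 4   # input_size, chosen divisible by the window size
ws = 2          # window_size

x = torch.randn(b, h * w * d, c)  # token sequence [B, H*W*D, C]
x = rearrange(x, "b (h w d) c -> b h w d c", h=h, w=w, d=d)

# partition: every ws x ws x ws block becomes its own batch entry
windows = rearrange(
    x, "b (nh ws1) (nw ws2) (nd ws3) c -> (b nh nw nd) (ws1 ws2 ws3) c", ws1=ws, ws2=ws, ws3=ws
)
print(windows.shape)  # torch.Size([16, 8, 16]): 2 batches x 8 windows, 8 tokens each

# ... attention would run on each window sequence here ...

# reverse: invert the same grouping and flatten back to a token sequence
x = rearrange(
    windows, "(b nh nw nd) (ws1 ws2 ws3) c -> b (nh ws1) (nw ws2) (nd ws3) c",
    b=b, nh=h // ws, nw=w // ws, nd=d // ws, ws1=ws, ws2=ws, ws3=ws,
)
x = rearrange(x, "b h w d c -> b (h w d) c")
print(x.shape)  # torch.Size([2, 64, 16])
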
@@ -131,6 +146,32 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
    return windows, (hp, wp)


def window_partition_3d(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
    """
    Partition into non-overlapping windows with padding if needed. 3d implementation.
    Args:
        x (tensor): input tokens with [B, H, W, D, C].
        window_size (int): window size.
    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, window_size, C].
        (Hp, Wp, Dp): padded height, width and depth before partition
    """
    batch, h, w, d, c = x.shape

    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    pad_d = (window_size - d % window_size) % window_size
    if pad_h > 0 or pad_w > 0 or pad_d > 0:
        x = F.pad(x, (0, 0, 0, pad_d, 0, pad_w, 0, pad_h))
    hp, wp, dp = h + pad_h, w + pad_w, d + pad_d

    x = x.view(batch, hp // window_size, window_size, wp // window_size, window_size, dp // window_size, window_size, c)
    windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size, window_size, window_size, c)
    return windows, (hp, wp, dp)
...
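
A short usage sketch for window_partition_3d, showing how spatial sizes that are not multiples of the window size are padded up before partitioning. The import path is an assumption taken from the file shown in this diff; adjust it if the helper lives elsewhere in your MONAI version.

import torch
from monai.networks.blocks.transformerblock import window_partition_3d  # assumed location

x = torch.randn(1, 5, 6, 7, 32)  # [B, H, W, D, C], none of H/W/D divisible by 4
windows, (hp, wp, dp) = window_partition_3d(x, window_size=4)
print(hp, wp, dp)      # 8 8 8: each spatial dim padded up to the next multiple of 4
print(windows.shape)   # torch.Size([8, 4, 4, 4, 32]): 2*2*2 windows for the single batch
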


def window_unpartition(
    windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
) -> torch.Tensor:
@@ -154,3 +195,30 @@ def window_unpartition(
    if hp > h or wp > w:
        x = x[:, :h, :w, :].contiguous()
    return x


def window_unpartition_3d(
    windows: torch.Tensor, window_size: int, pad_hwd: Tuple[int, int, int], hwd: Tuple[int, int, int]
) -> torch.Tensor:
    """
    Window unpartition into original sequences and removing padding. 3d implementation.
    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, window_size, C].
        window_size (int): window size.
        pad_hwd (Tuple): padded height, width and depth (hp, wp, dp).
        hwd (Tuple): original height, width and depth (H, W, D) before padding.
    Returns:
        x: unpartitioned sequences with [B, H, W, D, C].
    """
    hp, wp, dp = pad_hwd
    h, w, d = hwd
    batch = windows.shape[0] // (hp * wp * dp // window_size // window_size // window_size)
    x = windows.view(
        batch, hp // window_size, wp // window_size, dp // window_size, window_size, window_size, window_size, -1
    )
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(batch, hp, wp, dp, -1)

    if hp > h or wp > w or dp > d:
        x = x[:, :h, :w, :d, :].contiguous()
    return x
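
As a sanity check for the pair of helpers, a round trip should recover the input exactly: the zero padding added by window_partition_3d is cropped away again by window_unpartition_3d. The sketch below assumes the same import location as above.

import torch
from monai.networks.blocks.transformerblock import window_partition_3d, window_unpartition_3d  # assumed location

x = torch.randn(2, 5, 6, 7, 16)  # [B, H, W, D, C]
windows, pad_hwd = window_partition_3d(x, window_size=4)
y = window_unpartition_3d(windows, window_size=4, pad_hwd=pad_hwd, hwd=(5, 6, 7))
print(torch.equal(x, y))  # expected: True
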
