3d local attention window tests
Signed-off-by: vgrau98 <[email protected]>
vgrau98 committed Apr 28, 2024
1 parent 46ee2b0 commit dc7efd4
Showing 2 changed files with 25 additions and 1 deletion.
2 changes: 1 addition & 1 deletion monai/networks/blocks/transformerblock.py
@@ -50,7 +50,7 @@ def __init__(
             qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
             window_size (int): Window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643.
-                If 0, global attention used. Only 2D inputs are supported for local attention (window_size > 0).
+                If 0, global attention used.
                 See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py.
             input_size (Tuple): spatial input dimensions (h, w, and d). Has to be set if local window attention is used.
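
For reference, below is a minimal usage sketch of local window attention on a 3D input, mirroring the parameters exercised by the new test cases further down (illustrative only, not part of the commit; it assumes einops is installed, as the local-window tests do):

    import torch
    from monai.networks.blocks.transformerblock import TransformerBlock

    # window_size > 0 enables local attention, so input_size must be set;
    # window_size = 0 falls back to global attention.
    block = TransformerBlock(
        hidden_size=360,
        mlp_dim=1024,
        num_heads=4,
        dropout_rate=0.0,
        window_size=2,
        input_size=(3, 3, 3),  # (h, w, d) spatial grid, i.e. 27 tokens
    )
    x = torch.randn(2, 27, 360)  # (batch, tokens, hidden_size)
    out = block(x)               # output keeps the input shape: (2, 27, 360)
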
24 changes: 24 additions & 0 deletions tests/test_transformerblock.py
@@ -57,6 +57,22 @@
     ]
     TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN.append(test_case)
 
+TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN_3D = []
+for window_size in [0, 2, 3, 4]:
+    test_case = [
+        {
+            "hidden_size": 360,
+            "num_heads": 4,
+            "mlp_dim": 1024,
+            "dropout_rate": 0,
+            "window_size": window_size,
+            "input_size": (3, 3, 3),
+        },
+        (2, 27, 360),
+        (2, 27, 360),
+    ]
+    TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN_3D.append(test_case)
+
 
 class TestTransformerBlock(unittest.TestCase):
     @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK)
@@ -119,6 +135,14 @@ def test_local_window(self, input_param, input_shape, expected_shape):
             result = net(torch.randn(input_shape))
             self.assertEqual(result.shape, expected_shape)
 
+    @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN_3D)
+    @skipUnless(has_einops, "Requires einops")
+    def test_local_window_3d(self, input_param, input_shape, expected_shape):
+        net = TransformerBlock(**input_param)
+        with eval_mode(net):
+            result = net(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+
 
 if __name__ == "__main__":
     unittest.main()
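
As a quick cross-check of the 3D cases above (illustrative only, not part of the commit): the expected (2, 27, 360) shapes follow from flattening the (3, 3, 3) spatial grid into tokens, with the block preserving the (batch, tokens, hidden_size) layout:

    import math

    batch_size, hidden_size = 2, 360
    input_size = (3, 3, 3)
    num_tokens = math.prod(input_size)      # 3 * 3 * 3 = 27
    input_shape = (batch_size, num_tokens, hidden_size)
    expected_shape = input_shape            # block output keeps the input shape
    assert expected_shape == (2, 27, 360)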
