diff --git a/monai/networks/blocks/transformerblock.py b/monai/networks/blocks/transformerblock.py
index 06e38bc8b9..d0d305ad42 100644
--- a/monai/networks/blocks/transformerblock.py
+++ b/monai/networks/blocks/transformerblock.py
@@ -50,7 +50,7 @@ def __init__(
             qkv_bias (bool, optional): apply bias term for the qkv linear layer. Defaults to False.
             save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
             window_size (int): Window size for local attention as used in Segment Anything https://arxiv.org/abs/2304.02643.
-                If 0, global attention used. Only 2D inputs are supported for local attention (window_size > 0).
+                If 0, global attention used. See https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/image_encoder.py.
             input_size (Tuple): spatial input dimensions (h, w, and d). Has to be set if local window attention is used.
 
 
diff --git a/tests/test_transformerblock.py b/tests/test_transformerblock.py
index 193efdd16c..5c49afe96b 100644
--- a/tests/test_transformerblock.py
+++ b/tests/test_transformerblock.py
@@ -57,6 +57,22 @@
     ]
     TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN.append(test_case)
 
+TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN_3D = []
+for window_size in [0, 2, 3, 4]:
+    test_case = [
+        {
+            "hidden_size": 360,
+            "num_heads": 4,
+            "mlp_dim": 1024,
+            "dropout_rate": 0,
+            "window_size": window_size,
+            "input_size": (3, 3, 3),
+        },
+        (2, 27, 360),
+        (2, 27, 360),
+    ]
+    TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN_3D.append(test_case)
+
 
 class TestTransformerBlock(unittest.TestCase):
     @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK)
@@ -119,6 +135,14 @@ def test_local_window(self, input_param, input_shape, expected_shape):
             result = net(torch.randn(input_shape))
             self.assertEqual(result.shape, expected_shape)
 
+    @parameterized.expand(TEST_CASE_TRANSFORMERBLOCK_LOCAL_WIN_3D)
+    @skipUnless(has_einops, "Requires einops")
+    def test_local_window_3d(self, input_param, input_shape, expected_shape):
+        net = TransformerBlock(**input_param)
+        with eval_mode(net):
+            result = net(torch.randn(input_shape))
+            self.assertEqual(result.shape, expected_shape)
+
 
 if __name__ == "__main__":
     unittest.main()
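For reference, a minimal usage sketch of the 3D local-window path exercised by the new test cases. It assumes the window_size and input_size constructor arguments documented in the patch above, and it needs einops installed for the local-attention path (the test is guarded by skipUnless(has_einops, ...) for the same reason); the parameter values are copied from the new 3D test case, not prescribed by the library.

# Usage sketch (assumption: window_size / input_size arguments from this branch;
# einops must be installed for the local-attention path, as in the test's skipUnless guard).
import torch

from monai.networks.blocks.transformerblock import TransformerBlock

# Values mirror the new 3D test case: a 3 x 3 x 3 token grid -> 27 tokens of width 360.
block = TransformerBlock(
    hidden_size=360,
    mlp_dim=1024,
    num_heads=4,
    dropout_rate=0.0,
    window_size=2,         # local attention window; 0 falls back to global attention
    input_size=(3, 3, 3),  # 3D spatial dimensions (h, w, d) of the token grid
)

x = torch.randn(2, 27, 360)  # (batch, tokens, hidden_size), tokens = 3 * 3 * 3
block.eval()
with torch.no_grad():
    y = block(x)
assert y.shape == (2, 27, 360)

With window_size=0 the same call exercises the global-attention path, matching the first parameterization in the test's window_size loop.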