From 22ed5d67cc22c652b9462c3fab34fc037e29b3ff Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 1 Dec 2023 16:16:58 +0800 Subject: [PATCH 1/2] fix #7250 Signed-off-by: KumoLiu --- requirements-dev.txt | 1 - tests/min_tests.py | 1 - tests/test_transchex.py | 84 ----------------------------------------- 3 files changed, 86 deletions(-) delete mode 100644 tests/test_transchex.py diff --git a/requirements-dev.txt b/requirements-dev.txt index 6332d5b0a5..63b13a494a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,7 +33,6 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas requests einops -transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib!=3.5.0 diff --git a/tests/min_tests.py b/tests/min_tests.py index 8128bb7b84..1387fdb91a 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -183,7 +183,6 @@ def run_testsuit(): "test_testtimeaugmentation", "test_torchvision", "test_torchvisiond", - "test_transchex", "test_transformerblock", "test_unetr", "test_unetr_block", diff --git a/tests/test_transchex.py b/tests/test_transchex.py deleted file mode 100644 index 9ad847cdaa..0000000000 --- a/tests/test_transchex.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import unittest - -import torch -from parameterized import parameterized - -from monai.networks import eval_mode -from monai.networks.nets.transchex import Transchex -from tests.utils import skip_if_quick - -TEST_CASE_TRANSCHEX = [] -for drop_out in [0.4]: - for in_channels in [3]: - for img_size in [224]: - for patch_size in [16, 32]: - for num_language_layers in [2]: - for num_vision_layers in [4]: - for num_mixed_layers in [3]: - for num_classes in [8]: - test_case = [ - { - "in_channels": in_channels, - "img_size": (img_size,) * 2, - "patch_size": (patch_size,) * 2, - "num_vision_layers": num_vision_layers, - "num_mixed_layers": num_mixed_layers, - "num_language_layers": num_language_layers, - "num_classes": num_classes, - "drop_out": drop_out, - }, - (2, num_classes), - ] - TEST_CASE_TRANSCHEX.append(test_case) - - -@skip_if_quick -class TestTranschex(unittest.TestCase): - @parameterized.expand(TEST_CASE_TRANSCHEX) - def test_shape(self, input_param, expected_shape): - net = Transchex(**input_param) - with eval_mode(net): - result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224))) - self.assertEqual(result.shape, expected_shape) - - def test_ill_arg(self): - with self.assertRaises(ValueError): - Transchex( - in_channels=3, - img_size=(128, 128), - patch_size=(16, 16), - num_language_layers=2, - num_mixed_layers=4, - num_vision_layers=2, - num_classes=2, - drop_out=5.0, - ) - - with self.assertRaises(ValueError): - Transchex( - in_channels=1, - img_size=(97, 97), - patch_size=(16, 16), - num_language_layers=6, - num_mixed_layers=6, - num_vision_layers=8, - num_classes=8, - drop_out=0.4, - ) - - -if __name__ == "__main__": - unittest.main() From e18a579baa236d31accc1d5939ace17a5c1607ef Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Mon, 4 Dec 2023 17:18:34 +0800 Subject: [PATCH 2/2] address comments Signed-off-by: KumoLiu --- docs/requirements.txt | 2 +- requirements-dev.txt | 1 + setup.cfg | 4 +- tests/min_tests.py | 1 + tests/test_transchex.py | 85 +++++++++++++++++++++++++++++++++++++++++ 5 files changed, 90 insertions(+), 3 deletions(-) create mode 100644 tests/test_transchex.py diff --git a/docs/requirements.txt b/docs/requirements.txt index a9bbc384f8..e5bedf8552 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -21,7 +21,7 @@ sphinxcontrib-serializinghtml sphinx-autodoc-typehints==1.11.1 pandas einops -transformers<4.22 # https://github.com/Project-MONAI/MONAI/issues/5157 +transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 mlflow>=1.28.0 clearml>=1.10.0rc0 tensorboardX diff --git a/requirements-dev.txt b/requirements-dev.txt index 63b13a494a..cacbefe234 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -33,6 +33,7 @@ tifffile; platform_system == "Linux" or platform_system == "Darwin" pandas requests einops +transformers<4.22; python_version <= '3.10' # https://github.com/Project-MONAI/MONAI/issues/5157 mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib!=3.5.0 diff --git a/setup.cfg b/setup.cfg index 123da68dfa..0370d0062d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -65,7 +65,7 @@ all = imagecodecs pandas einops - transformers<4.22 + transformers<4.22; python_version <= '3.10' mlflow>=1.28.0 clearml>=1.10.0rc0 matplotlib @@ -123,7 +123,7 @@ pandas = einops = einops transformers = - transformers<4.22 + transformers<4.22; python_version <= '3.10' mlflow = mlflow matplotlib = diff --git a/tests/min_tests.py b/tests/min_tests.py index 1387fdb91a..8128bb7b84 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -183,6 +183,7 @@ def run_testsuit(): "test_testtimeaugmentation", "test_torchvision", "test_torchvisiond", + "test_transchex", "test_transformerblock", "test_unetr", "test_unetr_block", diff --git a/tests/test_transchex.py b/tests/test_transchex.py new file mode 100644 index 0000000000..8fb1f56715 --- /dev/null +++ b/tests/test_transchex.py @@ -0,0 +1,85 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import torch +from parameterized import parameterized + +from monai.networks import eval_mode +from monai.networks.nets.transchex import Transchex +from tests.utils import SkipIfAtLeastPyTorchVersion, skip_if_quick + +TEST_CASE_TRANSCHEX = [] +for drop_out in [0.4]: + for in_channels in [3]: + for img_size in [224]: + for patch_size in [16, 32]: + for num_language_layers in [2]: + for num_vision_layers in [4]: + for num_mixed_layers in [3]: + for num_classes in [8]: + test_case = [ + { + "in_channels": in_channels, + "img_size": (img_size,) * 2, + "patch_size": (patch_size,) * 2, + "num_vision_layers": num_vision_layers, + "num_mixed_layers": num_mixed_layers, + "num_language_layers": num_language_layers, + "num_classes": num_classes, + "drop_out": drop_out, + }, + (2, num_classes), + ] + TEST_CASE_TRANSCHEX.append(test_case) + + +@skip_if_quick +@SkipIfAtLeastPyTorchVersion((1, 10)) +class TestTranschex(unittest.TestCase): + @parameterized.expand(TEST_CASE_TRANSCHEX) + def test_shape(self, input_param, expected_shape): + net = Transchex(**input_param) + with eval_mode(net): + result = net(torch.randint(2, (2, 512)), torch.randint(2, (2, 512)), torch.randn((2, 3, 224, 224))) + self.assertEqual(result.shape, expected_shape) + + def test_ill_arg(self): + with self.assertRaises(ValueError): + Transchex( + in_channels=3, + img_size=(128, 128), + patch_size=(16, 16), + num_language_layers=2, + num_mixed_layers=4, + num_vision_layers=2, + num_classes=2, + drop_out=5.0, + ) + + with self.assertRaises(ValueError): + Transchex( + in_channels=1, + img_size=(97, 97), + patch_size=(16, 16), + num_language_layers=6, + num_mixed_layers=6, + num_vision_layers=8, + num_classes=8, + drop_out=0.4, + ) + + +if __name__ == "__main__": + unittest.main()