diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 9e4497d5e41..fde4e70ec71 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union import numpy as np import torch @@ -89,7 +89,7 @@ def __init__( batch_size: int, num_workers: int, inferrer_fn: Callable, - device: Optional[Union[str, torch.device]] = "cuda" if torch.cuda.is_available() else "cpu", + device: Union[str, torch.device] = "cpu", image_key=CommonKeys.IMAGE, label_key=CommonKeys.LABEL, meta_key_postfix="meta_dict", diff --git a/tests/test_set_visible_devices.py b/tests/test_set_visible_devices.py new file mode 100644 index 00000000000..b6da879f4be --- /dev/null +++ b/tests/test_set_visible_devices.py @@ -0,0 +1,38 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +from tests.utils import skip_if_no_cuda + + +class TestVisibleDevices(unittest.TestCase): + @staticmethod + def run_process_and_get_exit_code(code_to_execute): + value = os.system(code_to_execute) + return int(bin(value).replace("0b", "").rjust(16, "0")[:8], 2) + + @skip_if_no_cuda + def test_visible_devices(self): + num_gpus_before = self.run_process_and_get_exit_code( + 'python -c "import os; import torch; ' + + "os.environ['CUDA_VISIBLE_DEVICES'] = ''; exit(torch.cuda.device_count())\"" + ) + num_gpus_after = self.run_process_and_get_exit_code( + 'python -c "import os; import monai; import torch; ' + + "os.environ['CUDA_VISIBLE_DEVICES'] = ''; exit(torch.cuda.device_count())\"" + ) + self.assertEqual(num_gpus_before, num_gpus_after) + + +if __name__ == "__main__": + unittest.main()