From cf9e4fc3748cfab995a76158db4887454ca3a6d0 Mon Sep 17 00:00:00 2001 From: He Date: Fri, 24 May 2024 19:27:23 -0700 Subject: [PATCH 1/9] device selection feature --- tests/test_python_device_selection_api.sh | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 tests/test_python_device_selection_api.sh diff --git a/tests/test_python_device_selection_api.sh b/tests/test_python_device_selection_api.sh new file mode 100644 index 000000000..b5fc7ecae --- /dev/null +++ b/tests/test_python_device_selection_api.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +base_path=$(realpath ../..) +script_path="$base_path/totalsegmentator/bin/TotalSegmentator.py" + +python3 "$script_path" -i "$base_path/tests/reference_files/example_ct_sm.nii.gz" -o "$base_path/tests/unittest_prediction.nii.gz" -bs --ml -d 1 +pytest -v "$base_path/tests/test_end_to_end.py"::test_end_to_end::test_prediction_multilabel \ No newline at end of file From f0db94646c01f20116d131640317cfbf5224ff30 Mon Sep 17 00:00:00 2001 From: hym97 <78656286+hym97@users.noreply.github.com> Date: Fri, 24 May 2024 22:12:15 -0700 Subject: [PATCH 2/9] Finish Device Selection Feature. --- tests/test_python_device_selection_api.sh | 7 ------- tests/tests.sh | 3 ++- totalsegmentator/bin/TotalSegmentator.py | 17 ++++++++++++++--- totalsegmentator/nnunet.py | 4 +++- totalsegmentator/python_api.py | 15 ++++++++++++--- 5 files changed, 31 insertions(+), 15 deletions(-) delete mode 100644 tests/test_python_device_selection_api.sh diff --git a/tests/test_python_device_selection_api.sh b/tests/test_python_device_selection_api.sh deleted file mode 100644 index b5fc7ecae..000000000 --- a/tests/test_python_device_selection_api.sh +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/bash - -base_path=$(realpath ../..) -script_path="$base_path/totalsegmentator/bin/TotalSegmentator.py" - -python3 "$script_path" -i "$base_path/tests/reference_files/example_ct_sm.nii.gz" -o "$base_path/tests/unittest_prediction.nii.gz" -bs --ml -d 1 -pytest -v "$base_path/tests/test_end_to_end.py"::test_end_to_end::test_prediction_multilabel \ No newline at end of file diff --git a/tests/tests.sh b/tests/tests.sh index cc7a33cbf..792f9e1c4 100755 --- a/tests/tests.sh +++ b/tests/tests.sh @@ -4,8 +4,9 @@ set -e # ./tests/tests.sh + # Test multilabel prediction -TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d cpu +TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d 0 pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel # Test organ prediction - roi subset diff --git a/totalsegmentator/bin/TotalSegmentator.py b/totalsegmentator/bin/TotalSegmentator.py index 625b53525..dab2c505f 100644 --- a/totalsegmentator/bin/TotalSegmentator.py +++ b/totalsegmentator/bin/TotalSegmentator.py @@ -103,9 +103,20 @@ def main(): # "mps" is for apple silicon; the latest pytorch nightly version supports 3D Conv but not ConvTranspose3D which is # also needed by nnU-Net. So "mps" not working for now. # https://github.com/pytorch/pytorch/issues/77818 - parser.add_argument("-d", "--device", choices=["gpu", "cpu", "mps"], - help="Device to run on (default: gpu).", - default="gpu") + + def validate_device_type(value): + valid_strings = {"gpu", "cpu", "mps"} + if value in valid_strings: + return value + if value.isdigit(): + return int(value) + raise argparse.ArgumentTypeError(f"Invalid device type: '{value}'. Must be 'gpu', 'cpu', 'mps', or or an desired device ID (integer).") + + parser.add_argument("-d",'--device', type=validate_device_type, required=True, + help="Device type: 'GPU', 'CPU', 'MPS', or an desired device ID (integer).") + # parser.add_argument("-d", "--device", choices=["gpu", "cpu", "mps"], + # help="Device to run on (default: gpu).", + # default="gpu") parser.add_argument("-q", "--quiet", action="store_true", help="Print no intermediate outputs", default=False) diff --git a/totalsegmentator/nnunet.py b/totalsegmentator/nnunet.py index 7732c7073..828a3d720 100644 --- a/totalsegmentator/nnunet.py +++ b/totalsegmentator/nnunet.py @@ -170,7 +170,7 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None, model_folder = get_output_folder(task_id, trainer, plans, model) assert device in ['cpu', 'cuda', - 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {device}.' + 'mps'] or isinstance(device, torch.device), f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {device}.' if device == 'cpu': # let's allow torch to use hella threads import multiprocessing @@ -181,6 +181,8 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None, torch.set_num_threads(1) # torch.set_num_interop_threads(1) # throws error if setting the second time device = torch.device('cuda') + elif isinstance(device, torch.device): + device = device else: device = torch.device('mps') disable_tta = not tta diff --git a/totalsegmentator/python_api.py b/totalsegmentator/python_api.py index 458e05314..421e0ae32 100644 --- a/totalsegmentator/python_api.py +++ b/totalsegmentator/python_api.py @@ -67,14 +67,23 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa nora_tag = "None" if nora_tag is None else nora_tag - if not quiet: - print("\nIf you use this tool please cite: https://pubs.rsna.org/doi/10.1148/ryai.230024\n") - # available devices: gpu | cpu | mps + # available devices: gpu | cpu | mps | your desired device id (integer) if device == "gpu": device = "cuda" if device == "cuda" and not torch.cuda.is_available(): print("No GPU detected. Running on CPU. This can be very slow. The '--fast' or the `--roi_subset` option can help to reduce runtime.") device = "cpu" + elif isinstance(device, int): + if device < torch.cuda.device_count(): + device = torch.device(device) + else: + print("Invalid GPU config, running on the CPU") + device = "cpu" + print(f"Using Deivce: {device}") + + + if not quiet: + print("\nIf you use this tool please cite: https://pubs.rsna.org/doi/10.1148/ryai.230024\n") setup_nnunet() setup_totalseg() From 294efcc56c8740bfc77d9b4d686521e28041171e Mon Sep 17 00:00:00 2001 From: hym97 <78656286+hym97@users.noreply.github.com> Date: Fri, 24 May 2024 22:44:24 -0700 Subject: [PATCH 3/9] Update README.md and testcase. --- README.md | 11 +++++++---- tests/tests.sh | 5 ++++- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 0f9515da2..8480734aa 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ The mapping from label ID to class name can be found [here](https://github.com/w ### Advanced settings -* `--device`: Choose `cpu` or `gpu` +* `--device`: Choose `cpu` or `gpu` or `desired gpu id (e.g., 1 -> cuda:1)` * `--fast`: For faster runtime and less memory requirements use this option. It will run a lower resolution model (3mm instead of 1.5mm). * `--roi_subset`: Takes a space-separated list of class names (e.g. `spleen colon brain`) and only predicts those classes. Saves a lot of runtime and memory. Might be less accurate especially for small classes (e.g. prostate). * `--preview`: This will generate a 3D rendering of all classes, giving you a quick overview if the segmentation worked and where it failed (see `preview.png` in output directory). @@ -145,12 +145,15 @@ import nibabel as nib from totalsegmentator.python_api import totalsegmentator if __name__ == "__main__": + # specify the device + # 'gpu', 'cpu', 'mps', '1', ... + device = '1' ## cuda:1 # option 1: provide input and output as file paths - totalsegmentator(input_path, output_path) - + totalsegmentator(input_path, output_path, device=device) + # option 2: provide input and output as nifti image objects input_img = nib.load(input_path) - output_img = totalsegmentator(input_img) + output_img = totalsegmentator(input_img, device=device) nib.save(output_img, output_path) ``` You can see all available arguments [here](https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/python_api.py). Running from within the main environment should avoid some multiprocessing issues. diff --git a/tests/tests.sh b/tests/tests.sh index 792f9e1c4..f5a573acc 100755 --- a/tests/tests.sh +++ b/tests/tests.sh @@ -6,7 +6,10 @@ set -e # Test multilabel prediction -TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d 0 +TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d 1 +pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel + +TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d cpu pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel # Test organ prediction - roi subset From 72533742a4cce822ab7f5bd60ee9818c767fc46b Mon Sep 17 00:00:00 2001 From: hym97 <78656286+hym97@users.noreply.github.com> Date: Fri, 24 May 2024 22:45:47 -0700 Subject: [PATCH 4/9] Update README.md and testcase. --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 8480734aa..7c1640037 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,7 @@ if __name__ == "__main__": # specify the device # 'gpu', 'cpu', 'mps', '1', ... device = '1' ## cuda:1 + # option 1: provide input and output as file paths totalsegmentator(input_path, output_path, device=device) From c3ed1eccefe315c5f122c86dc020a4434c171293 Mon Sep 17 00:00:00 2001 From: hym97 <78656286+hym97@users.noreply.github.com> Date: Tue, 28 May 2024 22:06:09 -0700 Subject: [PATCH 5/9] Update syntax and add one testcase. --- README.md | 6 ++-- tests/test_device_type.py | 20 ++++++++++++++ tests/tests.sh | 5 ++-- totalsegmentator/bin/TotalSegmentator.py | 35 +++++++++++++++++------- totalsegmentator/python_api.py | 7 +++-- 5 files changed, 55 insertions(+), 18 deletions(-) create mode 100644 tests/test_device_type.py diff --git a/README.md b/README.md index 7c1640037..1e574140e 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ The mapping from label ID to class name can be found [here](https://github.com/w ### Advanced settings -* `--device`: Choose `cpu` or `gpu` or `desired gpu id (e.g., 1 -> cuda:1)` +* `--device`: Choose `cpu` or `gpu` or `gpu:X (e.g., gpu:1 -> cuda:1)` * `--fast`: For faster runtime and less memory requirements use this option. It will run a lower resolution model (3mm instead of 1.5mm). * `--roi_subset`: Takes a space-separated list of class names (e.g. `spleen colon brain`) and only predicts those classes. Saves a lot of runtime and memory. Might be less accurate especially for small classes (e.g. prostate). * `--preview`: This will generate a 3D rendering of all classes, giving you a quick overview if the segmentation worked and where it failed (see `preview.png` in output directory). @@ -146,8 +146,8 @@ from totalsegmentator.python_api import totalsegmentator if __name__ == "__main__": # specify the device - # 'gpu', 'cpu', 'mps', '1', ... - device = '1' ## cuda:1 + # 'gpu', 'cpu', 'mps', 'gpu:X' + device = 'gpu:1' ## cuda:1 # option 1: provide input and output as file paths totalsegmentator(input_path, output_path, device=device) diff --git a/tests/test_device_type.py b/tests/test_device_type.py new file mode 100644 index 000000000..9f4a78a17 --- /dev/null +++ b/tests/test_device_type.py @@ -0,0 +1,20 @@ +from totalsegmentator.bin.TotalSegmentator import validate_device_type +import unittest +import argparse +class TestValidateDeviceType(unittest.TestCase): + def test_valid_inputs(self): + self.assertEqual(validate_device_type("gpu"), "gpu") + self.assertEqual(validate_device_type("cpu"), "cpu") + self.assertEqual(validate_device_type("mps"), "mps") + self.assertEqual(validate_device_type("gpu:0"), "cuda:0") + self.assertEqual(validate_device_type("gpu:1"), "cuda:1") + + def test_invalid_inputs(self): + with self.assertRaises(argparse.ArgumentTypeError): + validate_device_type("invalid") + with self.assertRaises(argparse.ArgumentTypeError): + validate_device_type("gpu:invalid") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tests.sh b/tests/tests.sh index f5a573acc..b5f8cc2c9 100755 --- a/tests/tests.sh +++ b/tests/tests.sh @@ -3,10 +3,11 @@ set -e # To run these tests do # ./tests/tests.sh - +# Test device type selection function +pytest -v tests/test_device_type.py # Test multilabel prediction -TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d 1 +TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d cpu pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d cpu diff --git a/totalsegmentator/bin/TotalSegmentator.py b/totalsegmentator/bin/TotalSegmentator.py index dab2c505f..dce1aab9f 100644 --- a/totalsegmentator/bin/TotalSegmentator.py +++ b/totalsegmentator/bin/TotalSegmentator.py @@ -4,10 +4,26 @@ import argparse from pkg_resources import require from pathlib import Path - +import re from totalsegmentator.python_api import totalsegmentator +def validate_device_type(value): + valid_strings = {"gpu", "cpu", "mps"} + if value in valid_strings: + return value + + # Check if the value matches the pattern "gpu:X" where X is an integer + pattern = r"^gpu:(\d+)$" + match = re.match(pattern, value) + if match: + device_id = int(match.group(1)) + return f"cuda:{device_id}" + + raise argparse.ArgumentTypeError( + f"Invalid device type: '{value}'. Must be 'gpu', 'cpu', 'mps', or 'gpu:X' where X is an integer representing the GPU device ID.") + + def main(): parser = argparse.ArgumentParser(description="Segment 104 anatomical structures in CT images.", epilog="Written by Jakob Wasserthal. If you use this tool please cite https://pubs.rsna.org/doi/10.1148/ryai.230024") @@ -103,17 +119,16 @@ def main(): # "mps" is for apple silicon; the latest pytorch nightly version supports 3D Conv but not ConvTranspose3D which is # also needed by nnU-Net. So "mps" not working for now. # https://github.com/pytorch/pytorch/issues/77818 - - def validate_device_type(value): - valid_strings = {"gpu", "cpu", "mps"} - if value in valid_strings: - return value - if value.isdigit(): - return int(value) - raise argparse.ArgumentTypeError(f"Invalid device type: '{value}'. Must be 'gpu', 'cpu', 'mps', or or an desired device ID (integer).") + # def validate_device_type(value): + # valid_strings = {"gpu", "cpu", "mps"} + # if value in valid_strings: + # return value + # if value.isdigit(): + # return int(value) + # raise argparse.ArgumentTypeError(f"Invalid device type: '{value}'. Must be 'gpu', 'cpu', 'mps', or or an desired device ID (integer).") parser.add_argument("-d",'--device', type=validate_device_type, required=True, - help="Device type: 'GPU', 'CPU', 'MPS', or an desired device ID (integer).") + help="Device type: 'gpu', 'cpu', 'mps', or 'gpu:X' where X is an integer representing the GPU device ID.") # parser.add_argument("-d", "--device", choices=["gpu", "cpu", "mps"], # help="Device to run on (default: gpu).", # default="gpu") diff --git a/totalsegmentator/python_api.py b/totalsegmentator/python_api.py index 421e0ae32..3d10978a0 100644 --- a/totalsegmentator/python_api.py +++ b/totalsegmentator/python_api.py @@ -68,13 +68,14 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa nora_tag = "None" if nora_tag is None else nora_tag - # available devices: gpu | cpu | mps | your desired device id (integer) + # available devices: gpu | cpu | mps | gpu:1, gpu:2, etc. if device == "gpu": device = "cuda" if device == "cuda" and not torch.cuda.is_available(): print("No GPU detected. Running on CPU. This can be very slow. The '--fast' or the `--roi_subset` option can help to reduce runtime.") device = "cpu" - elif isinstance(device, int): - if device < torch.cuda.device_count(): + elif device.startswith('cuda:'): + device_id = int(device[5:]) + if device_id < torch.cuda.device_count(): device = torch.device(device) else: print("Invalid GPU config, running on the CPU") From 571e51646f3883468fd26e01d231149033092e8f Mon Sep 17 00:00:00 2001 From: hym97 <78656286+hym97@users.noreply.github.com> Date: Tue, 28 May 2024 23:04:31 -0700 Subject: [PATCH 6/9] Update syntax and add one testcase. --- tests/tests.sh | 3 --- totalsegmentator/bin/TotalSegmentator.py | 13 ------------- 2 files changed, 16 deletions(-) diff --git a/tests/tests.sh b/tests/tests.sh index b5f8cc2c9..299357b2c 100755 --- a/tests/tests.sh +++ b/tests/tests.sh @@ -10,9 +10,6 @@ pytest -v tests/test_device_type.py TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d cpu pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel -TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d cpu -pytest -v tests/test_end_to_end.py::test_end_to_end::test_prediction_multilabel - # Test organ prediction - roi subset # 2 cpus: # example_ct_sm.nii.gz: 34s, 3.0GB diff --git a/totalsegmentator/bin/TotalSegmentator.py b/totalsegmentator/bin/TotalSegmentator.py index dce1aab9f..f8470c9e2 100644 --- a/totalsegmentator/bin/TotalSegmentator.py +++ b/totalsegmentator/bin/TotalSegmentator.py @@ -1,6 +1,4 @@ #!/usr/bin/env python -import sys -import os import argparse from pkg_resources import require from pathlib import Path @@ -119,19 +117,8 @@ def main(): # "mps" is for apple silicon; the latest pytorch nightly version supports 3D Conv but not ConvTranspose3D which is # also needed by nnU-Net. So "mps" not working for now. # https://github.com/pytorch/pytorch/issues/77818 - # def validate_device_type(value): - # valid_strings = {"gpu", "cpu", "mps"} - # if value in valid_strings: - # return value - # if value.isdigit(): - # return int(value) - # raise argparse.ArgumentTypeError(f"Invalid device type: '{value}'. Must be 'gpu', 'cpu', 'mps', or or an desired device ID (integer).") - parser.add_argument("-d",'--device', type=validate_device_type, required=True, help="Device type: 'gpu', 'cpu', 'mps', or 'gpu:X' where X is an integer representing the GPU device ID.") - # parser.add_argument("-d", "--device", choices=["gpu", "cpu", "mps"], - # help="Device to run on (default: gpu).", - # default="gpu") parser.add_argument("-q", "--quiet", action="store_true", help="Print no intermediate outputs", default=False) From f0b684c13715dedb5ba4e0a73375543a6e3349fa Mon Sep 17 00:00:00 2001 From: hym97 <78656286+hym97@users.noreply.github.com> Date: Wed, 29 May 2024 00:22:25 -0700 Subject: [PATCH 7/9] Add more test case. --- tests/test_device_type.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_device_type.py b/tests/test_device_type.py index 9f4a78a17..5faa8e7b6 100644 --- a/tests/test_device_type.py +++ b/tests/test_device_type.py @@ -14,6 +14,12 @@ def test_invalid_inputs(self): validate_device_type("invalid") with self.assertRaises(argparse.ArgumentTypeError): validate_device_type("gpu:invalid") + with self.assertRaises(argparse.ArgumentTypeError): + validate_device_type("gpu:-1") + with self.assertRaises(argparse.ArgumentTypeError): + validate_device_type("gpu:3.1415926") + with self.assertRaises(argparse.ArgumentTypeError): + validate_device_type("gpu:") if __name__ == "__main__": From acd6e01a42d45ef1e6435f4af28220feea9de5c7 Mon Sep 17 00:00:00 2001 From: hym97 <78656286+hym97@users.noreply.github.com> Date: Wed, 29 May 2024 15:35:27 -0700 Subject: [PATCH 8/9] resolve typo issue. --- totalsegmentator/python_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/totalsegmentator/python_api.py b/totalsegmentator/python_api.py index 3d10978a0..f711fc37e 100644 --- a/totalsegmentator/python_api.py +++ b/totalsegmentator/python_api.py @@ -80,7 +80,7 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa else: print("Invalid GPU config, running on the CPU") device = "cpu" - print(f"Using Deivce: {device}") + print(f"Using Device: {device}") if not quiet: From f7fb9eb66b368ee010d98954fb0051aef00a876c Mon Sep 17 00:00:00 2001 From: hym97 <78656286+hym97@users.noreply.github.com> Date: Thu, 6 Jun 2024 18:05:43 -0700 Subject: [PATCH 9/9] resolve typo issue. --- totalsegmentator/python_api.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/totalsegmentator/python_api.py b/totalsegmentator/python_api.py index 3c552be36..44b5f717d 100644 --- a/totalsegmentator/python_api.py +++ b/totalsegmentator/python_api.py @@ -10,7 +10,6 @@ import nibabel as nib from nibabel.nifti1 import Nifti1Image import torch - from totalsegmentator.statistics import get_basic_statistics, get_radiomics_features_for_entire_dir from totalsegmentator.libs import download_pretrained_weights from totalsegmentator.config import setup_nnunet, setup_totalseg, increase_prediction_counter @@ -18,6 +17,21 @@ from totalsegmentator.config import get_config_key, set_config_key from totalsegmentator.map_to_binary import class_map from totalsegmentator.map_to_total import map_to_total +import re +def validate_device_type_api(value): + valid_strings = {"gpu", "cpu", "mps"} + if value in valid_strings: + return value + + # Check if the value matches the pattern "gpu:X" where X is an integer + pattern = r"^gpu:(\d+)$" + match = re.match(pattern, value) + if match: + device_id = int(match.group(1)) + return f"cuda:{device_id}" + + raise ValueError( + f"Invalid device type: '{value}'. Must be 'gpu', 'cpu', 'mps', or 'gpu:X' where X is an integer representing the GPU device ID.") def show_license_info(): @@ -67,6 +81,7 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa nora_tag = "None" if nora_tag is None else nora_tag + device = validate_device_type_api(device) # available devices: gpu | cpu | mps | gpu:1, gpu:2, etc. if device == "gpu": device = "cuda"