Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cuda Device Selection #311

Merged
merged 11 commits into from
Jun 13, 2024
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ The mapping from label ID to class name can be found [here](https://github.com/w


### Advanced settings
* `--device`: Choose `cpu` or `gpu`
* `--device`: Choose `cpu`, `gpu`, or `gpu:X` to select a specific GPU (e.g. `gpu:1` maps to `cuda:1`)
* `--fast`: For faster runtime and less memory requirements use this option. It will run a lower resolution model (3mm instead of 1.5mm).
* `--roi_subset`: Takes a space-separated list of class names (e.g. `spleen colon brain`) and only predicts those classes. Saves a lot of runtime and memory. Might be less accurate especially for small classes (e.g. prostate).
* `--preview`: This will generate a 3D rendering of all classes, giving you a quick overview if the segmentation worked and where it failed (see `preview.png` in output directory).
Expand Down Expand Up @@ -140,12 +140,16 @@ import nibabel as nib
from totalsegmentator.python_api import totalsegmentator

if __name__ == "__main__":
# specify the device: 'gpu', 'cpu', 'mps', or 'gpu:X' (e.g. 'gpu:1' maps to 'cuda:1')
device = 'gpu:1'

# option 1: provide input and output as file paths
totalsegmentator(input_path, output_path, device=device)

# option 2: provide input and output as nifti image objects
input_img = nib.load(input_path)
output_img = totalsegmentator(input_img, device=device)
nib.save(output_img, output_path)
```
You can see all available arguments [here](https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/python_api.py). Running from within the main environment should avoid some multiprocessing issues.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_device_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from totalsegmentator.bin.TotalSegmentator import validate_device_type
import unittest
import argparse
class TestValidateDeviceType(unittest.TestCase):
    """Unit tests for the `--device` CLI argument validator."""

    def test_valid_inputs(self):
        # Plain device names pass through unchanged; 'gpu:X' is
        # rewritten to the torch-style 'cuda:X'.
        expected = {
            "gpu": "gpu",
            "cpu": "cpu",
            "mps": "mps",
            "gpu:0": "cuda:0",
            "gpu:1": "cuda:1",
        }
        for given, want in expected.items():
            self.assertEqual(validate_device_type(given), want)

    def test_invalid_inputs(self):
        # Anything that is not a known name or 'gpu:<non-negative int>'
        # must be rejected with an argparse error.
        for bad in ("invalid", "gpu:invalid", "gpu:-1", "gpu:3.1415926", "gpu:"):
            with self.assertRaises(argparse.ArgumentTypeError):
                validate_device_type(bad)


if __name__ == "__main__":
    unittest.main()
2 changes: 2 additions & 0 deletions tests/tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ set -e
# To run these tests do
# ./tests/tests.sh

# Test device type selection function
pytest -v tests/test_device_type.py

# Test - multilabel prediction
TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d cpu
Expand Down
25 changes: 19 additions & 6 deletions totalsegmentator/bin/TotalSegmentator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
#!/usr/bin/env python
import sys
import os
import argparse
from pkg_resources import require
from pathlib import Path

import re
from totalsegmentator.python_api import totalsegmentator


def validate_device_type(value):
    """Validate and normalize the `--device` CLI argument.

    'gpu', 'cpu' and 'mps' are accepted verbatim; 'gpu:X' (X a
    non-negative integer GPU index) is translated to the torch-style
    'cuda:X'.

    Raises argparse.ArgumentTypeError for any other input.
    """
    # Plain device names need no translation.
    if value in ("gpu", "cpu", "mps"):
        return value

    # 'gpu:X' -> 'cuda:X' where X is a non-negative integer device index.
    gpu_match = re.fullmatch(r"gpu:(\d+)", value)
    if gpu_match is not None:
        return f"cuda:{int(gpu_match.group(1))}"

    raise argparse.ArgumentTypeError(
        f"Invalid device type: '{value}'. Must be 'gpu', 'cpu', 'mps', or 'gpu:X' where X is an integer representing the GPU device ID.")


def main():
parser = argparse.ArgumentParser(description="Segment 104 anatomical structures in CT images.",
epilog="Written by Jakob Wasserthal. If you use this tool please cite https://pubs.rsna.org/doi/10.1148/ryai.230024")
Expand Down Expand Up @@ -103,9 +117,8 @@ def main():
# "mps" is for apple silicon; the latest pytorch nightly version supports 3D Conv but not ConvTranspose3D which is
# also needed by nnU-Net. So "mps" not working for now.
# https://github.com/pytorch/pytorch/issues/77818
parser.add_argument("-d", "--device", choices=["gpu", "cpu", "mps"],
help="Device to run on (default: gpu).",
default="gpu")
parser.add_argument("-d",'--device', type=validate_device_type, required=True,
help="Device type: 'gpu', 'cpu', 'mps', or 'gpu:X' where X is an integer representing the GPU device ID.")

parser.add_argument("-q", "--quiet", action="store_true", help="Print no intermediate outputs",
default=False)
Expand Down
4 changes: 3 additions & 1 deletion totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None,
model_folder = get_output_folder(task_id, trainer, plans, model)

assert device in ['cpu', 'cuda',
'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {device}.'
'mps'] or isinstance(device, torch.device), f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {device}.'
if device == 'cpu':
# let's allow torch to use hella threads
import multiprocessing
Expand All @@ -181,6 +181,8 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None,
torch.set_num_threads(1)
# torch.set_num_interop_threads(1) # throws error if setting the second time
device = torch.device('cuda')
elif isinstance(device, torch.device):
device = device
else:
device = torch.device('mps')
disable_tta = not tta
Expand Down
33 changes: 29 additions & 4 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,28 @@
import nibabel as nib
from nibabel.nifti1 import Nifti1Image
import torch

from totalsegmentator.statistics import get_basic_statistics, get_radiomics_features_for_entire_dir
from totalsegmentator.libs import download_pretrained_weights
from totalsegmentator.config import setup_nnunet, setup_totalseg, increase_prediction_counter
from totalsegmentator.config import send_usage_stats, set_license_number, has_valid_license_offline
from totalsegmentator.config import get_config_key, set_config_key
from totalsegmentator.map_to_binary import class_map
from totalsegmentator.map_to_total import map_to_total
import re
def validate_device_type_api(value):
    """Validate and normalize a device string passed to the Python API.

    'gpu', 'cpu' and 'mps' are returned unchanged; 'gpu:X' (X a
    non-negative integer GPU index) is rewritten to the torch-style
    'cuda:X'.

    Raises ValueError for any other input.
    """
    # Known plain device names pass straight through.
    if value in ("gpu", "cpu", "mps"):
        return value

    # Rewrite 'gpu:X' into 'cuda:X' for torch device selection.
    gpu_match = re.fullmatch(r"gpu:(\d+)", value)
    if gpu_match is not None:
        return f"cuda:{int(gpu_match.group(1))}"

    raise ValueError(
        f"Invalid device type: '{value}'. Must be 'gpu', 'cpu', 'mps', or 'gpu:X' where X is an integer representing the GPU device ID.")


def show_license_info():
Expand Down Expand Up @@ -67,14 +81,25 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa

nora_tag = "None" if nora_tag is None else nora_tag

if not quiet:
print("\nIf you use this tool please cite: https://pubs.rsna.org/doi/10.1148/ryai.230024\n")
device = validate_device_type_api(device)

# available devices: gpu | cpu | mps
# available devices: gpu | cpu | mps | gpu:1, gpu:2, etc.
if device == "gpu": device = "cuda"
if device == "cuda" and not torch.cuda.is_available():
print("No GPU detected. Running on CPU. This can be very slow. The '--fast' or the `--roi_subset` option can help to reduce runtime.")
device = "cpu"
elif device.startswith('cuda:'):
device_id = int(device[5:])
if device_id < torch.cuda.device_count():
device = torch.device(device)
else:
print("Invalid GPU config, running on the CPU")
device = "cpu"
print(f"Using Device: {device}")


if not quiet:
print("\nIf you use this tool please cite: https://pubs.rsna.org/doi/10.1148/ryai.230024\n")

setup_nnunet()
setup_totalseg()
Expand Down