Skip to content

Commit

Permalink
Merge pull request #311 from hym97/cuda_device_selection
Browse files Browse the repository at this point in the history
Cuda Device Selection
  • Loading branch information
wasserth authored Jun 13, 2024
2 parents 5d2a859 + f7fb9eb commit 71d3584
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 15 deletions.
12 changes: 8 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ The mapping from label ID to class name can be found [here](https://github.com/w


### Advanced settings
* `--device`: Choose `cpu` or `gpu`
* `--device`: Choose `cpu`, `gpu`, or `gpu:X` to select a specific GPU (e.g. `gpu:1` maps to `cuda:1`)
* `--fast`: For faster runtime and less memory requirements use this option. It will run a lower resolution model (3mm instead of 1.5mm).
* `--roi_subset`: Takes a space-separated list of class names (e.g. `spleen colon brain`) and only predicts those classes. Saves a lot of runtime and memory. Might be less accurate especially for small classes (e.g. prostate).
* `--preview`: This will generate a 3D rendering of all classes, giving you a quick overview if the segmentation worked and where it failed (see `preview.png` in output directory).
Expand Down Expand Up @@ -140,12 +140,16 @@ import nibabel as nib
from totalsegmentator.python_api import totalsegmentator

if __name__ == "__main__":
# Specify the device: 'gpu', 'cpu', 'mps', or 'gpu:X' (e.g. 'gpu:1' maps to 'cuda:1')
device = 'gpu:1'

# option 1: provide input and output as file paths
totalsegmentator(input_path, output_path)

totalsegmentator(input_path, output_path, device=device)
# option 2: provide input and output as nifti image objects
input_img = nib.load(input_path)
output_img = totalsegmentator(input_img)
output_img = totalsegmentator(input_img, device=device)
nib.save(output_img, output_path)
```
You can see all available arguments [here](https://github.com/wasserth/TotalSegmentator/blob/master/totalsegmentator/python_api.py). Running from within the main environment should avoid some multiprocessing issues.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_device_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from totalsegmentator.bin.TotalSegmentator import validate_device_type
import unittest
import argparse
class TestValidateDeviceType(unittest.TestCase):
    """Tests for the --device argument parser used by the CLI entry point."""

    def test_valid_inputs(self):
        # Plain device names pass through unchanged; 'gpu:X' is normalized
        # to the 'cuda:X' spelling that torch expects.
        expected = {
            "gpu": "gpu",
            "cpu": "cpu",
            "mps": "mps",
            "gpu:0": "cuda:0",
            "gpu:1": "cuda:1",
        }
        for raw, normalized in expected.items():
            self.assertEqual(validate_device_type(raw), normalized)

    def test_invalid_inputs(self):
        # Anything else must be rejected with argparse's error type so the
        # CLI reports a usage error instead of crashing later.
        for bad in ("invalid", "gpu:invalid", "gpu:-1", "gpu:3.1415926", "gpu:"):
            with self.assertRaises(argparse.ArgumentTypeError):
                validate_device_type(bad)


# Allow running this file directly (e.g. `python tests/test_device_type.py`)
# in addition to pytest discovery.
if __name__ == "__main__":
    unittest.main()
2 changes: 2 additions & 0 deletions tests/tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ set -e
# To run these tests do
# ./tests/tests.sh

# Test device type selection function
pytest -v tests/test_device_type.py

# Test - multilabel prediction
TotalSegmentator -i tests/reference_files/example_ct_sm.nii.gz -o tests/unittest_prediction.nii.gz -bs --ml -d cpu
Expand Down
25 changes: 19 additions & 6 deletions totalsegmentator/bin/TotalSegmentator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,27 @@
#!/usr/bin/env python
import sys
import os
import argparse
from pkg_resources import require
from pathlib import Path

import re
from totalsegmentator.python_api import totalsegmentator


def validate_device_type(value):
    """argparse ``type=`` validator for the ``--device`` option.

    Accepted values:
      * ``'gpu'``, ``'cpu'``, ``'mps'`` -- returned unchanged.
      * ``'gpu:X'`` where X is a non-negative integer -- normalized to
        ``'cuda:X'``, the spelling understood by ``torch.device``.

    Raises:
        argparse.ArgumentTypeError: for any other input, so argparse can
            report a proper usage error instead of failing later.
    """
    valid_strings = {"gpu", "cpu", "mps"}
    if value in valid_strings:
        return value

    # fullmatch (instead of match with a trailing "$") so a stray trailing
    # newline such as "gpu:1\n" is rejected rather than silently accepted.
    match = re.fullmatch(r"gpu:(\d+)", value)
    if match:
        # int() normalizes leading zeros, e.g. 'gpu:01' -> 'cuda:1'.
        device_id = int(match.group(1))
        return f"cuda:{device_id}"

    raise argparse.ArgumentTypeError(
        f"Invalid device type: '{value}'. Must be 'gpu', 'cpu', 'mps', or 'gpu:X' where X is an integer representing the GPU device ID.")


def main():
parser = argparse.ArgumentParser(description="Segment 104 anatomical structures in CT images.",
epilog="Written by Jakob Wasserthal. If you use this tool please cite https://pubs.rsna.org/doi/10.1148/ryai.230024")
Expand Down Expand Up @@ -103,9 +117,8 @@ def main():
# "mps" is for apple silicon; the latest pytorch nightly version supports 3D Conv but not ConvTranspose3D which is
# also needed by nnU-Net. So "mps" not working for now.
# https://github.com/pytorch/pytorch/issues/77818
parser.add_argument("-d", "--device", choices=["gpu", "cpu", "mps"],
help="Device to run on (default: gpu).",
default="gpu")
parser.add_argument("-d",'--device', type=validate_device_type, required=True,
help="Device type: 'gpu', 'cpu', 'mps', or 'gpu:X' where X is an integer representing the GPU device ID.")

parser.add_argument("-q", "--quiet", action="store_true", help="Print no intermediate outputs",
default=False)
Expand Down
4 changes: 3 additions & 1 deletion totalsegmentator/nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None,
model_folder = get_output_folder(task_id, trainer, plans, model)

assert device in ['cpu', 'cuda',
'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {device}.'
'mps'] or isinstance(device, torch.device), f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {device}.'
if device == 'cpu':
# let's allow torch to use hella threads
import multiprocessing
Expand All @@ -181,6 +181,8 @@ def nnUNetv2_predict(dir_in, dir_out, task_id, model="3d_fullres", folds=None,
torch.set_num_threads(1)
# torch.set_num_interop_threads(1) # throws error if setting the second time
device = torch.device('cuda')
elif isinstance(device, torch.device):
device = device
else:
device = torch.device('mps')
disable_tta = not tta
Expand Down
33 changes: 29 additions & 4 deletions totalsegmentator/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,28 @@
import nibabel as nib
from nibabel.nifti1 import Nifti1Image
import torch

from totalsegmentator.statistics import get_basic_statistics, get_radiomics_features_for_entire_dir
from totalsegmentator.libs import download_pretrained_weights
from totalsegmentator.config import setup_nnunet, setup_totalseg, increase_prediction_counter
from totalsegmentator.config import send_usage_stats, set_license_number, has_valid_license_offline
from totalsegmentator.config import get_config_key, set_config_key
from totalsegmentator.map_to_binary import class_map
from totalsegmentator.map_to_total import map_to_total
import re
def validate_device_type_api(value):
    """Normalize the ``device`` argument of the Python API.

    ``'gpu'``, ``'cpu'`` and ``'mps'`` are returned as-is; ``'gpu:X'``
    (X a non-negative integer) is rewritten to the ``'cuda:X'`` form used
    by torch. Every other value raises ``ValueError``.
    """
    if value in ("gpu", "cpu", "mps"):
        return value

    # Translate 'gpu:X' into 'cuda:X' (X must be a non-negative integer).
    gpu_match = re.match(r"^gpu:(\d+)$", value)
    if gpu_match is not None:
        return f"cuda:{int(gpu_match.group(1))}"

    raise ValueError(
        f"Invalid device type: '{value}'. Must be 'gpu', 'cpu', 'mps', or 'gpu:X' where X is an integer representing the GPU device ID.")


def show_license_info():
Expand Down Expand Up @@ -67,14 +81,25 @@ def totalsegmentator(input: Union[str, Path, Nifti1Image], output: Union[str, Pa

nora_tag = "None" if nora_tag is None else nora_tag

if not quiet:
print("\nIf you use this tool please cite: https://pubs.rsna.org/doi/10.1148/ryai.230024\n")
device = validate_device_type_api(device)

# available devices: gpu | cpu | mps
# available devices: gpu | cpu | mps | gpu:1, gpu:2, etc.
if device == "gpu": device = "cuda"
if device == "cuda" and not torch.cuda.is_available():
print("No GPU detected. Running on CPU. This can be very slow. The '--fast' or the `--roi_subset` option can help to reduce runtime.")
device = "cpu"
elif device.startswith('cuda:'):
device_id = int(device[5:])
if device_id < torch.cuda.device_count():
device = torch.device(device)
else:
print("Invalid GPU config, running on the CPU")
device = "cpu"
print(f"Using Device: {device}")


if not quiet:
print("\nIf you use this tool please cite: https://pubs.rsna.org/doi/10.1148/ryai.230024\n")

setup_nnunet()
setup_totalseg()
Expand Down

0 comments on commit 71d3584

Please sign in to comment.