diff --git a/scilpy/io/image.py b/scilpy/io/image.py index 727e9b4a2..f2739d38d 100644 --- a/scilpy/io/image.py +++ b/scilpy/io/image.py @@ -2,9 +2,35 @@ from dipy.io.utils import is_header_compatible import logging +import nibabel as nib import numpy as np import os +from scilpy.utils.util import is_float + + +def load_img(arg): + if is_float(arg): + img = float(arg) + dtype = np.float64 + else: + if not os.path.isfile(arg): + raise ValueError('Input file {} does not exist.'.format(arg)) + img = nib.load(arg) + shape = img.header.get_data_shape() + dtype = img.header.get_data_dtype() + logging.info('Loaded {} of shape {} and data_type {}.'.format( + arg, shape, dtype)) + + if len(shape) > 3: + logging.warning('{} has {} dimensions, be careful.'.format( + arg, len(shape))) + elif len(shape) < 3: + raise ValueError('{} has {} dimensions, not valid.'.format( + arg, len(shape))) + + return img, dtype + def merge_labels_into_mask(atlas, filtering_args): """ diff --git a/scripts/scil_image_math.py b/scripts/scil_image_math.py index 4819280c7..508380a66 100755 --- a/scripts/scil_image_math.py +++ b/scripts/scil_image_math.py @@ -22,6 +22,7 @@ import numpy as np from scilpy.image.volume_math import (get_image_ops, get_operations_doc) +from scilpy.io.image import load_img from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, assert_outputs_exist) @@ -59,29 +60,6 @@ def _build_arg_parser(): return p -def load_img(arg): - if is_float(arg): - img = float(arg) - dtype = np.float64 - else: - if not os.path.isfile(arg): - raise ValueError('Input file {} does not exist.'.format(arg)) - img = nib.load(arg) - shape = img.header.get_data_shape() - dtype = img.header.get_data_dtype() - logging.info('Loaded {} of shape {} and data_type {}.'.format( - arg, shape, dtype)) - - if len(shape) > 3: - logging.warning('{} has {} dimensions, be careful.'.format( - arg, len(shape))) - elif len(shape) < 3: - raise ValueError('{} has {} dimensions, not valid.'.format( - arg, len(shape))) - - return img, dtype - - def main(): parser = _build_arg_parser() args = parser.parse_args()