diff --git a/scilpy/image/tests/test_volume_operations.py b/scilpy/image/tests/test_volume_operations.py index d9d09a1d8..3772167e7 100644 --- a/scilpy/image/tests/test_volume_operations.py +++ b/scilpy/image/tests/test_volume_operations.py @@ -13,7 +13,8 @@ crop_volume, flip_volume, merge_metrics, normalize_metric, resample_volume, register_image, - mask_data_with_default_cube) + mask_data_with_default_cube, + compute_distance_map) from scilpy.io.fetcher import fetch_data, get_testing_files_dict from scilpy.image.utils import compute_nifti_bounding_box @@ -240,3 +241,60 @@ def test_mask_data_with_default_cube(): assert out[0, 0, 0] == 0 assert out[-1, -1, -1] == 0 assert out[6, 6, 6] == 1 + + +def test_distance_map_smallest_first(): + mask_1 = np.zeros((3, 3, 3)) + mask_1[0, 0, 0] = 1 + + mask_2 = np.zeros((3, 3, 3)) + mask_2[1:3, 1:3, 1:3] = 1 + + distance = compute_distance_map(mask_1, mask_2) + assert np.abs(np.sum(distance) - 1.732050) < 1e-6 + + +def test_compute_distance_map_biggest_first(): + # Swap both masks + mask_2 = np.zeros((3, 3, 3)) + mask_2[0, 0, 0] = 1 + + mask_1 = np.zeros((3, 3, 3)) + mask_1[1:3, 1:3, 1:3] = 1 + + distance = compute_distance_map(mask_1, mask_2) + assert np.abs(np.sum(distance) - 21.544621) < 1e-6 + + +def test_compute_distance_map_symmetric(): + mask_1 = np.zeros((3, 3, 3)) + mask_1[0, 0, 0] = 1 + + mask_2 = np.zeros((3, 3, 3)) + mask_2[1:3, 1:3, 1:3] = 1 + + distance = compute_distance_map(mask_1, mask_2, symmetric=True) + assert np.abs(np.sum(distance) - 23.276672) < 1e-6 + + +def test_compute_distance_map_overlap(): + mask_1 = np.zeros((3, 3, 3)) + mask_1[1, 1, 1] = 1 + + mask_2 = np.zeros((3, 3, 3)) + mask_2[1:3, 1:3, 1:3] = 1 + + distance = compute_distance_map(mask_1, mask_2) + assert np.all(distance == 0) + + +def test_compute_distance_map_wrong_shape(): + mask_1 = np.zeros((3, 3, 3)) + mask_2 = np.zeros((3, 3, 4)) + + # Different shapes, test should fail + try: + compute_distance_map(mask_1, mask_2) + assert False + except ValueError: + assert True diff --git a/scilpy/image/volume_operations.py b/scilpy/image/volume_operations.py index e2e120df1..56b2731e1 100644 --- a/scilpy/image/volume_operations.py +++ b/scilpy/image/volume_operations.py @@ -16,6 +16,7 @@ import numpy as np from numpy import ma from scipy.ndimage import binary_dilation, gaussian_filter +from scipy.spatial import KDTree from sklearn import linear_model from scilpy.image.reslice import reslice # Don't use Dipy's reslice. Buggy. @@ -688,3 +689,50 @@ def merge_metrics(*arrays, beta=1.0): boosted_mean = geometric_mean ** beta return ma.filled(boosted_mean, fill_value=np.nan) + + +def compute_distance_map(mask_1, mask_2, symmetric=False, + max_distance=np.inf): + """ + Compute the distance map between two binary masks. + The distance is computed using the Euclidean distance between the + first mask and the closest point in the second mask. + + Use the symmetric flag to compute the distance map in both directions. + + WARNING: This function will work even if inputs are not binary masks, + just make sure that you know what you are doing. + + Parameters + ---------- + mask_1: np.ndarray + First binary mask. + mask_2: np.ndarray + Second binary mask. + symmetric: bool, optional + If True, compute the symmetric distance map. Default is np.inf + max_distance: float, optional + Maximum distance to consider for kdtree exploration. Default is None. + + Returns + ------- + distance_map: np.ndarray + Distance map between the two masks. + """ + if mask_1.shape != mask_2.shape: + raise ValueError("Masks must have the same shape.") + + tree = KDTree(np.argwhere(mask_2)) + distance_map = np.zeros(mask_1.shape) + distance = tree.query(np.argwhere(mask_1), + distance_upper_bound=max_distance)[0] + distance_map[np.where(mask_1)] = distance + + if symmetric: + # Compute the symmetric distance map and merge it with the previous one + tree = KDTree(np.argwhere(mask_1)) + distance = tree.query(np.argwhere(mask_2), + distance_upper_bound=max_distance)[0] + distance_map[np.where(mask_2)] = distance + + return distance_map diff --git a/scripts/scil_volume_distance_map.py b/scripts/scil_volume_distance_map.py new file mode 100755 index 000000000..4b14918c2 --- /dev/null +++ b/scripts/scil_volume_distance_map.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Compute distance map between two binary masks. The distance map is the +Euclidean distance from each voxel of the first mask to the closest +voxel of the second mask. + +Slowest scenarios are 1) two very large masks that are far appart or 2) a very +small mask completely inside a very large mask (around 20-30 seconds). + +Take this command as an example: + scil_volume_distance_map.py brain_mask.nii.gz AF_L.nii.gz \ + AF_L_to_brain_mask.nii.gz + +We have a brain mask and a bundle, the second is 100% inside the first. +The output will be a distance map from the brain mask to the bundle. + +If we take the bundle as the first input and the brain mask as the second, +The output will be a distance map from the bundle to the brain mask, which +will be all zeros (because the bundle is fully inside the brain mask). + +If you want both distance maps at once, you can use the --symmetric_distance +option. +""" + +import argparse +import logging + +import nibabel as nib +import numpy as np + +from scilpy.image.volume_operations import compute_distance_map +from scilpy.io.image import get_data_as_mask +from scilpy.io.utils import add_overwrite_arg, add_verbose_arg, \ + assert_headers_compatible, assert_inputs_exist, assert_outputs_exist + + +def _build_arg_parser(): + p = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawTextHelpFormatter) + p.add_argument('in_mask_1', metavar='IN_SOURCE', + help='Input file name, in nifti format.') + p.add_argument('in_mask_2', metavar='IN_TARGET', + help='Input file name, in nifti format.') + p.add_argument('out_distance', metavar='OUT_DISTANCE_MAP', + help='Input file name, in nifti format.') + + p.add_argument('--symmetric_distance', action='store_true', + help='Compute the distance from mask 1 to mask 2 and the ' + 'distance from mask 2 to mask 1 and sum them up.') + add_verbose_arg(p) + add_overwrite_arg(p) + + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + + assert_inputs_exist(parser, [args.in_mask_1, args.in_mask_2]) + assert_outputs_exist(parser, args, args.out_distance) + assert_headers_compatible(parser, [args.in_mask_1, args.in_mask_2]) + + img_1 = nib.load(args.in_mask_1) + img_2 = nib.load(args.in_mask_2) + + mask_1 = get_data_as_mask(img_1) + mask_2 = get_data_as_mask(img_2) + logging.debug(f'Loaded two masks with {np.count_nonzero(mask_1)} and ' + f'{np.count_nonzero(mask_2)} voxels') + + # Compute distance map using KDTree + distance_map = compute_distance_map(mask_1, mask_2, + args.symmetric_distance) + + out_img = nib.Nifti1Image(distance_map.astype(float), img_1.affine) + nib.save(out_img, args.out_distance) + + +if __name__ == "__main__": + main() diff --git a/scripts/tests/test_volume_distance_map.py b/scripts/tests/test_volume_distance_map.py new file mode 100644 index 000000000..ed5802019 --- /dev/null +++ b/scripts/tests/test_volume_distance_map.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import tempfile + +import nibabel as nib + +from scilpy import SCILPY_HOME +from scilpy.io.fetcher import fetch_data, get_testing_files_dict + +# If they already exist, this only takes 5 seconds (check md5sum) +fetch_data(get_testing_files_dict(), keys=['tractograms.zip']) +tmp_dir = tempfile.TemporaryDirectory() + + +def test_help_option(script_runner): + ret = script_runner.run('scil_volume_distance_map.py', '--help') + assert ret.success + + +def test_execution(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_mask_1 = os.path.join(SCILPY_HOME, 'tractograms', + 'streamline_and_mask_operations', + 'bundle_4_head_tail.nii.gz') + in_mask_2 = os.path.join(SCILPY_HOME, 'tractograms', + 'streamline_and_mask_operations', + 'bundle_4_center.nii.gz') + ret = script_runner.run('scil_volume_distance_map.py', + in_mask_1, in_mask_2, + 'distance_map.nii.gz') + + img = nib.load('distance_map.nii.gz') + data = img.get_fdata() + assert data[data > 0].mean() - 17.7777 < 0.0001 + assert ret.success