Skip to content

Commit

Permalink
Merge pull request #719 from arnaudbore/enh_scil_filtering
Browse files Browse the repository at this point in the history
[ENH] Add soft distance for each filter
  • Loading branch information
arnaudbore authored May 23, 2023
2 parents 0a8b5fa + 28f2413 commit 06a2715
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 54 deletions.
11 changes: 9 additions & 2 deletions scilpy/segment/streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from dipy.tracking.streamline import set_number_of_points
from dipy.tracking.vox2track import _streamlines_in_mask
from nibabel.affines import apply_affine
from scipy.ndimage import map_coordinates
from scipy.ndimage import (map_coordinates, generate_binary_structure,
binary_dilation)

import numpy as np

Expand Down Expand Up @@ -104,7 +105,7 @@ def filter_grid_roi_both(sft, mask_1, mask_2):
return new_sft, line_based_indices


def filter_grid_roi(sft, mask, filter_type, is_exclude):
def filter_grid_roi(sft, mask, filter_type, is_exclude, filter_distance=0):
"""
Parameters
----------
Expand All @@ -123,6 +124,12 @@ def filter_grid_roi(sft, mask, filter_type, is_exclude):
ids: list
Ids of the streamlines passing through the mask.
"""

if filter_distance != 0:
bin_struct = generate_binary_structure(3, 2)
mask = binary_dilation(mask, bin_struct,
iterations=filter_distance)

line_based_indices = []
if filter_type in ['any', 'all']:
line_based_indices = streamlines_in_mask(sft, mask,
Expand Down
120 changes: 68 additions & 52 deletions scripts/scil_filter_tractogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@

"""
Now supports sequential filtering condition and mixed filtering object.
For example, --atlas_roi ROI_NAME ID MODE CRITERIA
For example, --atlas_roi ROI_NAME ID MODE CRITERIA DISTANCE
- ROI_NAME is the filename of a Nifti
- ID is one or multiple integer values in the atlas. If multiple values,
ID needs to be between quotes.
Example: "1:6 9 10:15" will use values between 1 and 6 and values
between 10 and 15 included as well as value 9.
- MODE must be one of these values: ['any', 'all', 'either_end', 'both_ends']
- CRITERIA must be one of these values: ['include', 'exclude']
- DISTANCE must be a int and is optional
If any meant any part of the streamline must be in the mask, all means that
all part of the streamline must be in the mask.
Expand All @@ -28,8 +29,9 @@
A logical AND is the only behavior available. All theses filtering
conditions will be sequentially applied.
WARNING: --soft_distance should be used carefully with large voxel size
(e.g > 2.5mm).
WARNING: DISTANCE is optional and it should be used carefully with large
voxel size (e.g > 2.5mm). The value is in voxel for ROIs and in mm for
bounding box. Anisotropic data will affect each direction differently
"""

import argparse
Expand All @@ -43,7 +45,6 @@
from dipy.io.utils import is_header_compatible
import nibabel as nib
import numpy as np
from scipy import ndimage

from scilpy.io.image import (get_data_as_mask,
merge_labels_into_mask)
Expand All @@ -69,34 +70,35 @@ def _build_arg_parser():
p.add_argument('out_tractogram',
help='Path of the output tractogram file.')

p.add_argument('--drawn_roi', nargs=3, action='append',
metavar=('ROI_NAME', 'MODE', 'CRITERIA'),
help='Filename of a hand drawn ROI (.nii or .nii.gz).')
p.add_argument('--atlas_roi', nargs=4, action='append',
metavar=('ROI_NAME', 'ID', 'MODE', 'CRITERIA'),
help='Filename of an atlas (.nii or .nii.gz).')
p.add_argument('--bdo', nargs=3, action='append',
metavar=('BDO_NAME', 'MODE', 'CRITERIA'),
help='Filename of a bounding box (bdo) file from MI-Brain.')

p.add_argument('--x_plane', nargs=3, action='append',
metavar=('PLANE', 'MODE', 'CRITERIA'),
help='Slice number in X, in voxel space.')
p.add_argument('--y_plane', nargs=3, action='append',
metavar=('PLANE', 'MODE', 'CRITERIA'),
help='Slice number in Y, in voxel space.')
p.add_argument('--z_plane', nargs=3, action='append',
metavar=('PLANE', 'MODE', 'CRITERIA'),
help='Slice number in Z, in voxel space.')
p.add_argument('--drawn_roi', nargs='+', action='append',
help="ROI_NAME MODE CRITERIA DISTANCE "
"(distance in voxel is optional)\n"
"Filename of a hand drawn ROI (.nii or .nii.gz).")
p.add_argument('--atlas_roi', nargs='+', action='append',
help="ROI_NAME ID MODE CRITERIA DISTANCE "
"(distance in voxel is optional)\n"
"Filename of an atlas (.nii or .nii.gz).")
p.add_argument('--bdo', nargs='+', action='append',
help="BDO_NAME MODE CRITERIA DISTANCE "
"(distance in mm is optional)\n"
"Filename of a bounding box (bdo) file from MI-Brain.")

p.add_argument('--x_plane', nargs='+', action='append',
help="PLANE MODE CRITERIA DISTANCE "
"(distance in voxel is optional)\n"
"Slice number in X, in voxel space.")
p.add_argument('--y_plane', nargs='+', action='append',
help="PLANE MODE CRITERIA DISTANCE "
"(distance in voxel is optional)\n"
"Slice number in Y, in voxel space.")
p.add_argument('--z_plane', nargs='+', action='append',
help="PLANE MODE CRITERIA DISTANCE "
"(distance in voxel is optional)\n"
"Slice number in Z, in voxel space.")
p.add_argument('--filtering_list',
help='Text file containing one rule per line\n'
'(i.e. drawn_roi mask.nii.gz both_ends include).')
'(i.e. drawn_roi mask.nii.gz both_ends include 1).')

p.add_argument('--soft_distance', type=int,
help='All ROIs are enlarged by the specified value.\n'
'The value is in voxel (NOT mm).\n'
'Anisotropic data will affect each direction '
'differently')
p.add_argument('--extract_masks_atlas_roi', action='store_true',
help='Extract atlas roi masks.')
p.add_argument('--no_empty', action='store_true',
Expand All @@ -118,13 +120,6 @@ def prepare_filtering_list(parser, args):
roi_opt_list = []
only_filtering_list = True

if args.soft_distance is not None:
if args.soft_distance < 1:
parser.error('The minimum soft distance is 1 voxel.')
elif args.soft_distance > 5:
logging.warning('Soft distance above 5 voxels leads to weird'
' results.')

if args.drawn_roi:
only_filtering_list = False
for roi_opt in args.drawn_roi:
Expand Down Expand Up @@ -160,11 +155,27 @@ def prepare_filtering_list(parser, args):
else:
roi_opt_list.append(roi_opt.strip().split())

for roi_opt in roi_opt_list:
if (len(roi_opt_list[-1]) < 4 or len(roi_opt_list) > 5) and roi_opt_list[-1][0] != 'atlas_roi':
logging.error("Please specify 3 or 4 values "
"for {} filtering.".format(roi_opt_list[-1][0]))
elif (len(roi_opt_list[-1]) < 5 or len(roi_opt_list) > 6) and roi_opt_list[-1][0] == 'atlas_roi':
logging.error("Please specify 4 or 5 values"
" for {} filtering.".format(roi_opt_list[-1][0]))

filter_distance = 0
for index, roi_opt in enumerate(roi_opt_list):
if roi_opt[0] == 'atlas_roi':
filter_type, filter_arg, _, filter_mode, filter_criteria = roi_opt
else:
if len(roi_opt) == 5:
filter_type, filter_arg, _, filter_mode, filter_criteria = roi_opt
roi_opt_list[index].append(0)
else:
filter_type, filter_arg, _, filter_mode, filter_criteria, filter_distance = roi_opt
elif len(roi_opt) == 4:
filter_type, filter_arg, filter_mode, filter_criteria = roi_opt
roi_opt_list[index].append(0)
else:
filter_type, filter_arg, filter_mode, filter_criteria, filter_distance = roi_opt

if filter_type not in ['x_plane', 'y_plane', 'z_plane']:
if not os.path.isfile(filter_arg):
parser.error('{} does not exist'.format(filter_arg))
Expand All @@ -175,6 +186,10 @@ def prepare_filtering_list(parser, args):
parser.error('{} is not a valid option for filter_criteria'.format(
filter_criteria))

if int(filter_distance) < 0:
parser.error("Distance should be positive. "
"{} is not a valid option.".format(filter_distance))

return roi_opt_list, only_filtering_list


Expand All @@ -184,6 +199,7 @@ def main():

assert_inputs_exist(parser, args.in_tractogram)
assert_outputs_exist(parser, args, args.out_tractogram, args.save_rejected)

if args.verbose:
logging.getLogger().setLevel(logging.DEBUG)
set_sft_logger_level('WARNING')
Expand All @@ -195,7 +211,6 @@ def main():
sft = load_tractogram_with_reference(parser, args, args.in_tractogram)
if args.save_rejected:
initial_sft = deepcopy(sft)
bin_struct = ndimage.generate_binary_structure(3, 2)

# Streamline count before filtering
o_dict['streamline_count_before_filtering'] = len(sft.streamlines)
Expand All @@ -209,14 +224,17 @@ def main():
# Atlas needs an extra argument (value in the LUT)
if roi_opt[0] == 'atlas_roi':
filter_type, filter_arg, filter_arg_2, \
filter_mode, filter_criteria = roi_opt
filter_mode, filter_criteria, filter_distance = roi_opt
else:
filter_type, filter_arg, filter_mode, filter_criteria = roi_opt
filter_type, filter_arg, filter_mode, filter_criteria, filter_distance = roi_opt

curr_dict['filename'] = os.path.abspath(filter_arg)
curr_dict['type'] = filter_type
curr_dict['mode'] = filter_mode
curr_dict['criteria'] = filter_criteria
curr_dict['distance'] = filter_distance

filter_distance = int(filter_distance)

is_exclude = False if filter_criteria == 'include' else True

Expand All @@ -236,11 +254,9 @@ def main():
nib.Nifti1Image(mask.astype(np.uint16),
img.affine).to_filename('mask_atlas_roi_{}.nii.gz'.format(str(atlas_roi_item)))

if args.soft_distance is not None:
mask = ndimage.binary_dilation(mask, bin_struct,
iterations=args.soft_distance)
filtered_sft, kept_ids = filter_grid_roi(sft, mask,
filter_mode, is_exclude)
filter_mode, is_exclude,
filter_distance)

# For every case, the input number must be greater or equal to 0 and
# below the dimension, since this is a voxel space operation
Expand Down Expand Up @@ -271,16 +287,16 @@ def main():
parser.error('{} is not valid according to the '
'tractogram header.'.format(error_msg))

if args.soft_distance is not None:
mask = ndimage.binary_dilation(mask, bin_struct,
iterations=args.soft_distance)
filtered_sft, kept_ids = filter_grid_roi(sft, mask,
filter_mode, is_exclude)
filter_mode, is_exclude,
filter_distance)

elif filter_type == 'bdo':
geometry, radius, center = read_info_from_mb_bdo(filter_arg)
if args.soft_distance is not None:
radius += args.soft_distance * sft.space_attributes[2]

if filter_distance != 0:
radius += filter_distance * sft.space_attributes[2]

if geometry == 'Ellipsoid':
filtered_sft, kept_ids = filter_ellipsoid(
sft, radius, center, filter_mode, is_exclude)
Expand Down

0 comments on commit 06a2715

Please sign in to comment.