Skip to content

Commit

Permalink
Merge pull request #686 from frheault/use_mask_as_whole_uniformize
Browse files Browse the repository at this point in the history
Switch to kdtree for uniformize --target
  • Loading branch information
arnaudbore authored Mar 7, 2023
2 parents 957dfc3 + 286c648 commit 711e389
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
30 changes: 18 additions & 12 deletions scilpy/utils/streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,30 +173,38 @@ def uniformize_bundle_sft(sft, axis=None, ref_bundle=None, swap=False):
sft.to_origin(old_origin)


def uniformize_bundle_sft_using_mask_barycenter(sft, mask, swap=False):
def uniformize_bundle_sft_using_mask(sft, mask, swap=False):
"""Uniformize the streamlines in the given tractogram so head is closer to
to the barycenter.
to a region of interest.
Parameters
----------
sft: StatefulTractogram
The tractogram that contains the list of streamlines to be uniformized
mask: np.ndarray
Mask to use as a reference for the barycenter.
Mask to use as a reference for the ROI.
swap: boolean, optional
Swap the orientation of streamlines
"""

barycenter = np.average(np.argwhere(mask), axis=0)
# barycenter = np.average(np.argwhere(mask), axis=0)
old_space = sft.space
old_origin = sft.origin
sft.to_vox()
sft.to_corner()

tree = cKDTree(np.argwhere(mask))
for i in range(len(sft.streamlines)):
if (np.linalg.norm(sft.streamlines[i][0] - barycenter) >
np.linalg.norm(sft.streamlines[i][-1] - barycenter)) ^ bool(swap):
head_dist = tree.query(sft.streamlines[i][0])[0]
tail_dist = tree.query(sft.streamlines[i][-1])[0]
if bool(head_dist > tail_dist) ^ bool(swap):
sft.streamlines[i] = sft.streamlines[i][::-1]
for key in sft.data_per_point[i]:
sft.data_per_point[key][i] = \
sft.data_per_point[key][i][::-1]

sft.to_space(old_space)
sft.to_origin(old_origin)

def get_color_streamlines_along_length(sft, colormap='jet'):
"""Color streamlines according to their length.
Expand Down Expand Up @@ -817,10 +825,9 @@ def cut_invalid_streamlines(sft):
return new_sft, cutting_counter


def upsample_tractogram(
sft, nb, point_wise_std=None,
streamline_wise_std=None, gaussian=None, spline=None, seed=None
):
def upsample_tractogram(sft, nb, point_wise_std=None,
streamline_wise_std=None, gaussian=None, spline=None,
seed=None):
"""
Generate new streamlines by either adding gaussian noise around
streamlines' points, or by translating copies of existing streamlines
Expand Down Expand Up @@ -870,8 +877,7 @@ def upsample_tractogram(
if point_wise_std:
noise = rng.normal(scale=point_wise_std, size=s.shape)
elif streamline_wise_std:
noise = rng.normal(
scale=streamline_wise_std, size=s.shape[-1])
noise = rng.normal(scale=streamline_wise_std, size=s.shape[-1])
new_s = s + noise
if gaussian:
new_s = smooth_line_gaussian(new_s, gaussian)
Expand Down
6 changes: 3 additions & 3 deletions scripts/scil_uniformize_streamlines_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
assert_inputs_exist)
from scilpy.segment.streamlines import filter_grid_roi
from scilpy.utils.streamlines import (uniformize_bundle_sft,
uniformize_bundle_sft_using_mask_barycenter)
uniformize_bundle_sft_using_mask)


def _build_arg_parser():
Expand Down Expand Up @@ -96,10 +96,10 @@ def main():
mask = atlas > 0
else:
mask = merge_labels_into_mask(atlas, " ".join(args.target_roi[1:]))

# Uncomment if the user wants to filter the streamlines
# sft, _ = filter_grid_roi(sft, mask, 'either_end', False)
uniformize_bundle_sft_using_mask_barycenter(sft, mask, swap=args.swap)
uniformize_bundle_sft_using_mask(sft, mask, swap=args.swap)

if args.axis:
uniformize_bundle_sft(sft, args.axis, swap=args.swap)
Expand Down

0 comments on commit 711e389

Please sign in to comment.