From 8a200c4e251cefb87a8e511e35a6a230ad26bc94 Mon Sep 17 00:00:00 2001 From: frheault Date: Mon, 23 Oct 2023 09:11:47 -0400 Subject: [PATCH 01/18] Transfer script --- scripts/scil_compare_tractograms.py | 324 ++++++++++++++++++++++++++++ 1 file changed, 324 insertions(+) create mode 100644 scripts/scil_compare_tractograms.py diff --git a/scripts/scil_compare_tractograms.py b/scripts/scil_compare_tractograms.py new file mode 100644 index 000000000..00e74fc45 --- /dev/null +++ b/scripts/scil_compare_tractograms.py @@ -0,0 +1,324 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" + +""" + +import argparse +from concurrent.futures import ProcessPoolExecutor, as_completed +from copy import deepcopy +import logging +import os +from tqdm import tqdm +import warnings + +from dipy.data import get_sphere +from dipy.reconst.shm import sh_to_sf_matrix +from dipy.segment.fss import FastStreamlineSearch +import nibabel as nib +import numpy as np +import numpy.ma as ma +from scipy.spatial import cKDTree + +from scilpy.image.volume_math import correlation +from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, + assert_inputs_exist, + assert_outputs_exist, + assert_output_dirs_exist_and_empty, + add_processes_arg, + add_verbose_arg, + is_header_compatible_multiple_files, + load_tractogram_with_reference, + validate_nbr_processes) +from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map +from scilpy.tractanalysis.todi import TrackOrientationDensityImaging +from scilpy.tractograms.streamline_operations import resample_streamlines_step_size + +def _build_arg_parser(): + + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawTextHelpFormatter) + + p.add_argument('in_tractogram_1', + help='Input tractogram 1.') + p.add_argument('in_tractogram_2', + help='Input tractogram 2.') + + p.add_argument('--out_dir', default='', + help='Directory where all output files will be saved. ' + '\nIf not specified, outputs will be saved in the current ' + 'directory.') + p.add_argument('--out_prefix', default='out', + help='Prefix for output files. Useful for distinguishing between ' + 'different runs.') + + p.add_argument('--in_mask', metavar='IN_FILE', + help='Optional input mask.') + + add_processes_arg(p) + add_reference_arg(p) + add_verbose_arg(p) + add_overwrite_arg(p) + + return p + + +def generate_matched_points(sft): + """ + Generate an array where each element i is set to the index of the streamline + that contributes the ith point. + + Parameters: + ----------- + sft : StatefulTractogram + The stateful tractogram containing the streamlines. + + Returns: + -------- + matched_points : ndarray + An array where each element is set to the index of the streamline + that contributes that point. + """ + total_points = sft.streamlines._data.shape[0] + offsets = sft.streamlines._offsets + + matched_points = np.zeros(total_points, dtype=np.uint64) + + for i in range(len(offsets) - 1): + matched_points[offsets[i]:offsets[i+1]] = i + + matched_points[offsets[-1]:] = len(offsets) - 1 + + return matched_points + + +def compute_difference_for_voxel(chunk_indices): + """ + Compute the difference for a single voxel index. + """ + global sft_1, sft_2, matched_points_1, matched_points_2, tree_1, tree_2, \ + sh_data_1, sh_data_2 + results = [] + for vox_ind in chunk_indices: + vox_ind = tuple(vox_ind) + + # Get the streamlines in the neighborhood (i.e., 2mm away) + pts_ind_1 = tree_1.query_ball_point(vox_ind, 1.5) + if not pts_ind_1: + results.append([-1, -1]) + continue + strs_ind_1 = np.unique(matched_points_1[pts_ind_1]) + neighb_streamlines_1 = sft_1.streamlines[strs_ind_1] + + # Get the streamlines in the neighborhood (i.e., 1mm away) + pts_ind_2 = tree_2.query_ball_point(vox_ind, 1.5) + if not pts_ind_2: + results.append([-1, -1]) + continue + strs_ind_2 = np.unique(matched_points_2[pts_ind_2]) + neighb_streamlines_2 = sft_2.streamlines[strs_ind_2] + + with warnings.catch_warnings(record=True) as _: + fss = FastStreamlineSearch(neighb_streamlines_1, 10, resampling=12) + dist_mat = fss.radius_search(neighb_streamlines_2, 10) + sparse_dist_mat = np.abs(dist_mat.tocsr()).toarray() + sparse_ma_dist_mat = np.ma.masked_where(sparse_dist_mat < 1e-3, + sparse_dist_mat) + sparse_ma_dist_vec = np.squeeze(np.min(sparse_ma_dist_mat, + axis=0)) + + if np.any(sparse_ma_dist_vec): + global B + sf_1 = np.dot(sh_data_1[vox_ind], B) + sf_2 = np.dot(sh_data_2[vox_ind], B) + dist = np.average(sparse_ma_dist_vec) + corr = np.corrcoef(sf_1, sf_2)[0, 1] + results.append([dist, corr]) + else: + results.append([-1, -1]) + + return results + + +def normalize_metric(metric, reverse=False): + """ + Normalize a metric to be in the range [0, 1], ignoring specified values. + """ + mask = np.isnan(metric) + masked_metric = ma.masked_array(metric, mask) + + min_val, max_val = masked_metric.min(), masked_metric.max() + normalized_metric = (masked_metric - min_val) / (max_val - min_val) + + if reverse: + normalized_metric = 1 - normalized_metric + + return ma.filled(normalized_metric, fill_value=np.nan) + + +def merge_metrics(acc, corr, diff, beta=1.0): + """ + Merge the three metrics into a single heatmap using a weighted geometric mean, + ignoring specified values. + """ + mask = np.isnan(acc) | np.isnan(corr) | np.isnan(diff) + masked_acc, masked_corr, masked_diff = [ma.masked_array(x, mask) for x in [acc, corr, diff]] + + # Calculate the geometric mean for valid data + geometric_mean = np.cbrt(masked_acc * masked_corr * masked_diff) + + # Apply a boosting factor + boosted_mean = geometric_mean ** beta + + return ma.filled(boosted_mean, fill_value=np.nan) + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + + if args.verbose: + logging.basicConfig(level=logging.INFO) + + assert_inputs_exist(parser, [args.in_tractogram_1, + args.in_tractogram_2]) + is_header_compatible_multiple_files(parser, [args.in_tractogram_1, + args.in_tractogram_2], + verbose_all_compatible=True, + reference=args.reference) + + if args.out_prefix and args.out_prefix[-1] == '_': + args.out_prefix = args.out_prefix[:-1] + out_corr_filename = os.path.join(args.out_dir, + '{}_correlation.nii.gz'.format(args.out_prefix)) + out_acc_filename = os.path.join(args.out_dir, + '{}_acc.nii.gz'.format(args.out_prefix)) + out_diff_filename = os.path.join(args.out_dir, + '{}_diff.nii.gz'.format(args.out_prefix)) + out_merge_filename = os.path.join(args.out_dir, + '{}_heatmap.nii.gz'.format(args.out_prefix)) + assert_output_dirs_exist_and_empty(parser, args, [], optional=args.out_dir) + assert_outputs_exist(parser, args, [out_corr_filename, + out_acc_filename, + out_diff_filename, + out_merge_filename]) + nbr_cpu = validate_nbr_processes(parser, args) + + logging.info('Loading tractograms...') + global sft_1, sft_2 + sft_1 = load_tractogram_with_reference(parser, args, args.in_tractogram_1) + #sft_1 = resample_streamlines_step_size(sft_1, 0.5) + sft_2 = load_tractogram_with_reference(parser, args, args.in_tractogram_2) + #sft_2 = resample_streamlines_step_size(sft_2, 0.5) + sft_1.to_vox() + sft_2.to_vox() + sft_1.streamlines._data = sft_1.streamlines._data.astype(np.float16) + sft_2.streamlines._data = sft_2.streamlines._data.astype(np.float16) + affine, dimensions = sft_1.affine, sft_1.dimensions + + global matched_points_1, matched_points_2 + matched_points_1 = generate_matched_points(sft_1) + matched_points_2 = generate_matched_points(sft_2) + + logging.info('Computing KDTree...') + global tree_1, tree_2 + tree_1 = cKDTree(sft_1.streamlines._data) + tree_2 = cKDTree(sft_2.streamlines._data) + + # Limits computation to mask AND streamlines (using density) + if args.in_mask: + mask = nib.load(args.in_mask).get_fdata() + else: + mask = np.ones(dimensions) + + logging.info('Computing density maps...') + density_1 = compute_tract_counts_map(sft_1.streamlines, + dimensions).astype(float) + density_2 = compute_tract_counts_map(sft_2.streamlines, + dimensions).astype(float) + mask = density_1 * density_2 * mask + mask[mask > 0] = 1 + + logging.info('Computing correlation map...') + corr_data = correlation([density_1, density_2], None) * mask + nib.save(nib.Nifti1Image(corr_data, affine), out_corr_filename) + + logging.info('Computing TODI #1...') + global sh_data_1, sh_data_2 + sft_1.to_corner() + todi_obj = TrackOrientationDensityImaging(tuple(dimensions), 'repulsion724') + todi_obj.compute_todi(deepcopy(sft_1.streamlines), length_weights=True) + todi_obj.mask_todi(mask) + sh_data_1 = todi_obj.get_sh('descoteaux07', 8) + sh_data_1 = todi_obj.reshape_to_3d(sh_data_1) + sft_1.to_center() + + logging.info('Computing TODI #2...') + sft_2.to_corner() + todi_obj = TrackOrientationDensityImaging(tuple(dimensions), 'repulsion724') + todi_obj.compute_todi(deepcopy(sft_2.streamlines), length_weights=True) + todi_obj.mask_todi(mask) + sh_data_2 = todi_obj.get_sh('descoteaux07', 8) + sh_data_2 = todi_obj.reshape_to_3d(sh_data_2) + sft_2.to_center() + + global B + B, _ = sh_to_sf_matrix(get_sphere('repulsion724'), 8, 'descoteaux07') + + # Initialize multiprocessing + indices = np.argwhere(mask > 0) + diff_data = np.zeros(dimensions) + diff_data[:] = np.nan + acc_data = np.zeros(dimensions) + acc_data[:] = np.nan + def chunked_indices(indices, chunk_size=1000): + """Yield successive chunk_size chunks from indices.""" + for i in range(0, len(indices), chunk_size): + yield indices[i:i + chunk_size] + + # Initialize tqdm progress bar + progress_bar = tqdm(total=len(indices)) + + # Create chunks of indices + np.random.shuffle(indices) + index_chunks = list(chunked_indices(indices)) + + with ProcessPoolExecutor(max_workers=nbr_cpu) as executor: + futures = {executor.submit(compute_difference_for_voxel, chunk): chunk for chunk in index_chunks} + + for future in as_completed(futures): + chunk = futures[future] + try: + results = future.result() + except Exception as exc: + print(f'Generated an exception: {exc}') + else: + results = np.array(results) + diff_data[tuple(chunk.T)] = results[:, 0] + acc_data[tuple(chunk.T)] = results[:, 1] + + # Update tqdm progress bar + progress_bar.update(len(chunk)) + + logging.info('Saving results...') + nib.save(nib.Nifti1Image(diff_data, affine), out_diff_filename) + nib.save(nib.Nifti1Image(acc_data, affine), out_acc_filename) + + # Normalize metrics + acc_norm = normalize_metric(acc_data) + corr_norm = normalize_metric(corr_data) + diff_norm = normalize_metric(diff_data, reverse=True) + indices_minus_one = np.where((acc_data == -1) | (corr_data == -1) | \ + (diff_data == -1)) + + # Merge into a single heatmap + heatmap = merge_metrics(acc_norm, corr_norm, diff_norm) + + # Save as a new NIFTI file + heatmap[indices_minus_one] = np.nan + nib.save(nib.Nifti1Image(heatmap, affine), out_merge_filename) + + +if __name__ == "__main__": + main() From abf57c3b2995b5df0db6a34c6452529261744e26 Mon Sep 17 00:00:00 2001 From: frheault Date: Mon, 23 Oct 2023 09:25:03 -0400 Subject: [PATCH 02/18] Pep8 and docstring --- scripts/scil_compare_tractograms.py | 89 +++++++++++++++++------------ 1 file changed, 51 insertions(+), 38 deletions(-) diff --git a/scripts/scil_compare_tractograms.py b/scripts/scil_compare_tractograms.py index 00e74fc45..3fc6f7676 100644 --- a/scripts/scil_compare_tractograms.py +++ b/scripts/scil_compare_tractograms.py @@ -2,7 +2,14 @@ # -*- coding: utf-8 -*- """ - +This script is designed to compare and help visualize differences between +two tractograms. Which can be especially useful in studies where multiple +tractograms from different algorithms or parameters need to be compared. + +The difference is computed in terms of +- A voxel-wise spatial distance between streamlines, out_diff.nii.gz +- A correlation (ACC) between streamline orientation (TODI), out_acc.nii.gz +- A correlation between streamline density, out_corr.nii.gz """ import argparse @@ -33,14 +40,14 @@ validate_nbr_processes) from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map from scilpy.tractanalysis.todi import TrackOrientationDensityImaging -from scilpy.tractograms.streamline_operations import resample_streamlines_step_size + def _build_arg_parser(): p = argparse.ArgumentParser( description=__doc__, formatter_class=argparse.RawTextHelpFormatter) - + p.add_argument('in_tractogram_1', help='Input tractogram 1.') p.add_argument('in_tractogram_2', @@ -53,15 +60,15 @@ def _build_arg_parser(): p.add_argument('--out_prefix', default='out', help='Prefix for output files. Useful for distinguishing between ' 'different runs.') - + p.add_argument('--in_mask', metavar='IN_FILE', help='Optional input mask.') - + add_processes_arg(p) add_reference_arg(p) add_verbose_arg(p) add_overwrite_arg(p) - + return p @@ -69,12 +76,12 @@ def generate_matched_points(sft): """ Generate an array where each element i is set to the index of the streamline that contributes the ith point. - + Parameters: ----------- sft : StatefulTractogram The stateful tractogram containing the streamlines. - + Returns: -------- matched_points : ndarray @@ -85,12 +92,12 @@ def generate_matched_points(sft): offsets = sft.streamlines._offsets matched_points = np.zeros(total_points, dtype=np.uint64) - + for i in range(len(offsets) - 1): matched_points[offsets[i]:offsets[i+1]] = i - + matched_points[offsets[-1]:] = len(offsets) - 1 - + return matched_points @@ -99,7 +106,7 @@ def compute_difference_for_voxel(chunk_indices): Compute the difference for a single voxel index. """ global sft_1, sft_2, matched_points_1, matched_points_2, tree_1, tree_2, \ - sh_data_1, sh_data_2 + sh_data_1, sh_data_2 results = [] for vox_ind in chunk_indices: vox_ind = tuple(vox_ind) @@ -127,8 +134,8 @@ def compute_difference_for_voxel(chunk_indices): sparse_ma_dist_mat = np.ma.masked_where(sparse_dist_mat < 1e-3, sparse_dist_mat) sparse_ma_dist_vec = np.squeeze(np.min(sparse_ma_dist_mat, - axis=0)) - + axis=0)) + if np.any(sparse_ma_dist_vec): global B sf_1 = np.dot(sh_data_1[vox_ind], B) @@ -148,13 +155,13 @@ def normalize_metric(metric, reverse=False): """ mask = np.isnan(metric) masked_metric = ma.masked_array(metric, mask) - + min_val, max_val = masked_metric.min(), masked_metric.max() normalized_metric = (masked_metric - min_val) / (max_val - min_val) - + if reverse: normalized_metric = 1 - normalized_metric - + return ma.filled(normalized_metric, fill_value=np.nan) @@ -164,16 +171,18 @@ def merge_metrics(acc, corr, diff, beta=1.0): ignoring specified values. """ mask = np.isnan(acc) | np.isnan(corr) | np.isnan(diff) - masked_acc, masked_corr, masked_diff = [ma.masked_array(x, mask) for x in [acc, corr, diff]] - + masked_acc, masked_corr, masked_diff = [ + ma.masked_array(x, mask) for x in [acc, corr, diff]] + # Calculate the geometric mean for valid data geometric_mean = np.cbrt(masked_acc * masked_corr * masked_diff) - + # Apply a boosting factor boosted_mean = geometric_mean ** beta - + return ma.filled(boosted_mean, fill_value=np.nan) + def main(): parser = _build_arg_parser() args = parser.parse_args() @@ -187,17 +196,17 @@ def main(): args.in_tractogram_2], verbose_all_compatible=True, reference=args.reference) - + if args.out_prefix and args.out_prefix[-1] == '_': args.out_prefix = args.out_prefix[:-1] out_corr_filename = os.path.join(args.out_dir, - '{}_correlation.nii.gz'.format(args.out_prefix)) + '{}_correlation.nii.gz'.format(args.out_prefix)) out_acc_filename = os.path.join(args.out_dir, - '{}_acc.nii.gz'.format(args.out_prefix)) + '{}_acc.nii.gz'.format(args.out_prefix)) out_diff_filename = os.path.join(args.out_dir, - '{}_diff.nii.gz'.format(args.out_prefix)) + '{}_diff.nii.gz'.format(args.out_prefix)) out_merge_filename = os.path.join(args.out_dir, - '{}_heatmap.nii.gz'.format(args.out_prefix)) + '{}_heatmap.nii.gz'.format(args.out_prefix)) assert_output_dirs_exist_and_empty(parser, args, [], optional=args.out_dir) assert_outputs_exist(parser, args, [out_corr_filename, out_acc_filename, @@ -208,9 +217,9 @@ def main(): logging.info('Loading tractograms...') global sft_1, sft_2 sft_1 = load_tractogram_with_reference(parser, args, args.in_tractogram_1) - #sft_1 = resample_streamlines_step_size(sft_1, 0.5) + # sft_1 = resample_streamlines_step_size(sft_1, 0.5) sft_2 = load_tractogram_with_reference(parser, args, args.in_tractogram_2) - #sft_2 = resample_streamlines_step_size(sft_2, 0.5) + # sft_2 = resample_streamlines_step_size(sft_2, 0.5) sft_1.to_vox() sft_2.to_vox() sft_1.streamlines._data = sft_1.streamlines._data.astype(np.float16) @@ -234,20 +243,21 @@ def main(): logging.info('Computing density maps...') density_1 = compute_tract_counts_map(sft_1.streamlines, - dimensions).astype(float) + dimensions).astype(float) density_2 = compute_tract_counts_map(sft_2.streamlines, - dimensions).astype(float) + dimensions).astype(float) mask = density_1 * density_2 * mask mask[mask > 0] = 1 logging.info('Computing correlation map...') corr_data = correlation([density_1, density_2], None) * mask nib.save(nib.Nifti1Image(corr_data, affine), out_corr_filename) - + logging.info('Computing TODI #1...') global sh_data_1, sh_data_2 sft_1.to_corner() - todi_obj = TrackOrientationDensityImaging(tuple(dimensions), 'repulsion724') + todi_obj = TrackOrientationDensityImaging( + tuple(dimensions), 'repulsion724') todi_obj.compute_todi(deepcopy(sft_1.streamlines), length_weights=True) todi_obj.mask_todi(mask) sh_data_1 = todi_obj.get_sh('descoteaux07', 8) @@ -256,7 +266,8 @@ def main(): logging.info('Computing TODI #2...') sft_2.to_corner() - todi_obj = TrackOrientationDensityImaging(tuple(dimensions), 'repulsion724') + todi_obj = TrackOrientationDensityImaging( + tuple(dimensions), 'repulsion724') todi_obj.compute_todi(deepcopy(sft_2.streamlines), length_weights=True) todi_obj.mask_todi(mask) sh_data_2 = todi_obj.get_sh('descoteaux07', 8) @@ -272,6 +283,7 @@ def main(): diff_data[:] = np.nan acc_data = np.zeros(dimensions) acc_data[:] = np.nan + def chunked_indices(indices, chunk_size=1000): """Yield successive chunk_size chunks from indices.""" for i in range(0, len(indices), chunk_size): @@ -285,7 +297,8 @@ def chunked_indices(indices, chunk_size=1000): index_chunks = list(chunked_indices(indices)) with ProcessPoolExecutor(max_workers=nbr_cpu) as executor: - futures = {executor.submit(compute_difference_for_voxel, chunk): chunk for chunk in index_chunks} + futures = {executor.submit( + compute_difference_for_voxel, chunk): chunk for chunk in index_chunks} for future in as_completed(futures): chunk = futures[future] @@ -297,10 +310,10 @@ def chunked_indices(indices, chunk_size=1000): results = np.array(results) diff_data[tuple(chunk.T)] = results[:, 0] acc_data[tuple(chunk.T)] = results[:, 1] - + # Update tqdm progress bar progress_bar.update(len(chunk)) - + logging.info('Saving results...') nib.save(nib.Nifti1Image(diff_data, affine), out_diff_filename) nib.save(nib.Nifti1Image(acc_data, affine), out_acc_filename) @@ -309,8 +322,8 @@ def chunked_indices(indices, chunk_size=1000): acc_norm = normalize_metric(acc_data) corr_norm = normalize_metric(corr_data) diff_norm = normalize_metric(diff_data, reverse=True) - indices_minus_one = np.where((acc_data == -1) | (corr_data == -1) | \ - (diff_data == -1)) + indices_minus_one = np.where((acc_data == -1) | (corr_data == -1) | + (diff_data == -1)) # Merge into a single heatmap heatmap = merge_metrics(acc_norm, corr_norm, diff_norm) From 06799e6d69a283a1b0962fcbe02e21bdc359cbaa Mon Sep 17 00:00:00 2001 From: frheault Date: Fri, 3 Nov 2023 11:01:09 -0400 Subject: [PATCH 03/18] Skip --- scripts/scil_compare_tractograms.py | 37 ++++++++++++++++++----------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/scripts/scil_compare_tractograms.py b/scripts/scil_compare_tractograms.py index 3fc6f7676..b1ffde139 100644 --- a/scripts/scil_compare_tractograms.py +++ b/scripts/scil_compare_tractograms.py @@ -63,7 +63,9 @@ def _build_arg_parser(): p.add_argument('--in_mask', metavar='IN_FILE', help='Optional input mask.') - + p.add_argument('--skip_streamlines_distance', action='store_true', + help='Skip computation of the spatial distance between ' + 'streamlines.') add_processes_arg(p) add_reference_arg(p) add_verbose_arg(p) @@ -101,7 +103,8 @@ def generate_matched_points(sft): return matched_points -def compute_difference_for_voxel(chunk_indices): +def compute_difference_for_voxel(chunk_indices, + skip_streamlines_distance=False): """ Compute the difference for a single voxel index. """ @@ -111,6 +114,19 @@ def compute_difference_for_voxel(chunk_indices): for vox_ind in chunk_indices: vox_ind = tuple(vox_ind) + global B + has_data = sh_data_1[vox_ind].any() and sh_data_2[vox_ind].any() + if has_data: + sf_1 = np.dot(sh_data_1[vox_ind], B) + sf_2 = np.dot(sh_data_2[vox_ind], B) + corr = np.corrcoef(sf_1, sf_2)[0, 1] + else: + corr = -1 + + if skip_streamlines_distance or not has_data: + results.append([-1, corr]) + continue + # Get the streamlines in the neighborhood (i.e., 2mm away) pts_ind_1 = tree_1.query_ball_point(vox_ind, 1.5) if not pts_ind_1: @@ -136,15 +152,8 @@ def compute_difference_for_voxel(chunk_indices): sparse_ma_dist_vec = np.squeeze(np.min(sparse_ma_dist_mat, axis=0)) - if np.any(sparse_ma_dist_vec): - global B - sf_1 = np.dot(sh_data_1[vox_ind], B) - sf_2 = np.dot(sh_data_2[vox_ind], B) - dist = np.average(sparse_ma_dist_vec) - corr = np.corrcoef(sf_1, sf_2)[0, 1] - results.append([dist, corr]) - else: - results.append([-1, -1]) + dist = np.average(sparse_ma_dist_vec) + results.append([dist, corr]) return results @@ -217,9 +226,8 @@ def main(): logging.info('Loading tractograms...') global sft_1, sft_2 sft_1 = load_tractogram_with_reference(parser, args, args.in_tractogram_1) - # sft_1 = resample_streamlines_step_size(sft_1, 0.5) sft_2 = load_tractogram_with_reference(parser, args, args.in_tractogram_2) - # sft_2 = resample_streamlines_step_size(sft_2, 0.5) + sft_1.to_vox() sft_2.to_vox() sft_1.streamlines._data = sft_1.streamlines._data.astype(np.float16) @@ -298,7 +306,8 @@ def chunked_indices(indices, chunk_size=1000): with ProcessPoolExecutor(max_workers=nbr_cpu) as executor: futures = {executor.submit( - compute_difference_for_voxel, chunk): chunk for chunk in index_chunks} + compute_difference_for_voxel, chunk, + args.skip_streamlines_distance): chunk for chunk in index_chunks} for future in as_completed(futures): chunk = futures[future] From 54db51afcee4f74384755d00b4a2a1e5b6bd7b8e Mon Sep 17 00:00:00 2001 From: frheault Date: Thu, 14 Dec 2023 10:33:37 -0500 Subject: [PATCH 04/18] rename script --- ... => scil_tractogram_pairwise_agreement.py} | 2 - .../scil_tractogram_pairwise_comparison.py | 96 +++++++++++++++++++ 2 files changed, 96 insertions(+), 2 deletions(-) rename scripts/{scil_compare_tractograms.py => scil_tractogram_pairwise_agreement.py} (99%) create mode 100644 scripts/tests/scil_tractogram_pairwise_comparison.py diff --git a/scripts/scil_compare_tractograms.py b/scripts/scil_tractogram_pairwise_agreement.py similarity index 99% rename from scripts/scil_compare_tractograms.py rename to scripts/scil_tractogram_pairwise_agreement.py index 3fc6f7676..0872a2e0a 100644 --- a/scripts/scil_compare_tractograms.py +++ b/scripts/scil_tractogram_pairwise_agreement.py @@ -217,9 +217,7 @@ def main(): logging.info('Loading tractograms...') global sft_1, sft_2 sft_1 = load_tractogram_with_reference(parser, args, args.in_tractogram_1) - # sft_1 = resample_streamlines_step_size(sft_1, 0.5) sft_2 = load_tractogram_with_reference(parser, args, args.in_tractogram_2) - # sft_2 = resample_streamlines_step_size(sft_2, 0.5) sft_1.to_vox() sft_2.to_vox() sft_1.streamlines._data = sft_1.streamlines._data.astype(np.float16) diff --git a/scripts/tests/scil_tractogram_pairwise_comparison.py b/scripts/tests/scil_tractogram_pairwise_comparison.py new file mode 100644 index 000000000..a5b825e84 --- /dev/null +++ b/scripts/tests/scil_tractogram_pairwise_comparison.py @@ -0,0 +1,96 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import tempfile + +from scilpy.io.fetcher import get_testing_files_dict, fetch_data, get_home + + +# If they already exist, this only takes 5 seconds (check md5sum) +fetch_data(get_testing_files_dict(), keys=['bundles.zip']) +tmp_dir = tempfile.TemporaryDirectory() + + +def test_help_option(script_runner): + ret = script_runner.run( + 'scil_bundle_pairwise_comparison.py', '--help') + assert ret.success + + +def test_execution_bundles(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + in_1 = os.path.join(get_home(), 'bundles', 'bundle_0_reco.tck') + in_2 = os.path.join(get_home(), 'bundles', 'voting_results', + 'bundle_0.trk') + in_ref = os.path.join(get_home(), 'bundles', 'bundle_all_1mm.nii.gz') + ret = script_runner.run( + 'scil_bundle_pairwise_comparison.py', + in_1, in_2, 'AF_L_similarity.json', + '--streamline_dice', '--reference', in_ref, + '--processes', '1') + assert ret.success + + +def test_single(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + in_1 = os.path.join(get_home(), 'bundles', 'bundle_0_reco.tck') + in_2 = os.path.join(get_home(), 'bundles', 'voting_results', + 'bundle_0.trk') + in_ref = os.path.join(get_home(), 'bundles', 'bundle_all_1mm.nii.gz') + ret = script_runner.run( + 'scil_bundle_pairwise_comparison.py', + in_2, 'AF_L_similarity_single.json', + '--streamline_dice', '--reference', in_ref, + '--single_compare', in_1, + '--processes', '1') + assert ret.success + + +def test_no_overlap(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + in_1 = os.path.join(get_home(), 'bundles', 'bundle_0_reco.tck') + in_2 = os.path.join(get_home(), 'bundles', 'voting_results', + 'bundle_0.trk') + in_ref = os.path.join(get_home(), 'bundles', 'bundle_all_1mm.nii.gz') + ret = script_runner.run( + 'scil_bundle_pairwise_comparison.py', in_1, + in_2, 'AF_L_similarity_no_overlap.json', + '--streamline_dice', '--reference', in_ref, + '--bundle_adjency_no_overlap', + '--processes', '1') + assert ret.success + + +def test_ratio(script_runner): + os.chdir(os.path.expanduser(tmp_dir.name)) + in_1 = os.path.join(get_home(), 'bundles', 'bundle_0_reco.tck') + in_2 = os.path.join(get_home(), 'bundles', 'voting_results', + 'bundle_0.trk') + in_ref = os.path.join(get_home(), 'bundles', 'bundle_all_1mm.nii.gz') + ret = script_runner.run( + 'scil_bundle_pairwise_comparison.py', + in_2, 'AF_L_similarity_ratio.json', + '--streamline_dice', '--reference', in_ref, + '--single_compare', in_1, + '--processes', '1', + '--ratio') + assert ret.success + + +def test_ratio_fail(script_runner): + """ Test ratio without single_compare argument. + The test should fail. + """ + os.chdir(os.path.expanduser(tmp_dir.name)) + in_1 = os.path.join(get_home(), 'bundles', 'bundle_0_reco.tck') + in_2 = os.path.join(get_home(), 'bundles', 'voting_results', + 'bundle_0.trk') + in_ref = os.path.join(get_home(), 'bundles', 'bundle_all_1mm.nii.gz') + ret = script_runner.run( + 'scil_bundle_pairwise_comparison.py', + in_1, in_2, 'AF_L_similarity_fail.json', + '--streamline_dice', '--reference', in_ref, + '--processes', '1', + '--ratio') + assert not ret.success From 9199845e67d684ac4f2f1f050e594906e654e08c Mon Sep 17 00:00:00 2001 From: frheault Date: Mon, 29 Jan 2024 13:56:37 -0500 Subject: [PATCH 05/18] Converting the main to a wrapper function --- scilpy/image/volume_operations.py | 67 +++++ scilpy/tractograms/streamline_operations.py | 29 ++ scilpy/tractograms/tractogram_operations.py | 266 +++++++++++++++++- .../scil_tractogram_pairwise_comparison.py | 264 ++--------------- 4 files changed, 373 insertions(+), 253 deletions(-) diff --git a/scilpy/image/volume_operations.py b/scilpy/image/volume_operations.py index 463c39c8d..e391c9880 100644 --- a/scilpy/image/volume_operations.py +++ b/scilpy/image/volume_operations.py @@ -14,6 +14,7 @@ from dipy.segment.mask import crop, median_otsu import nibabel as nib import numpy as np +import numpy.ma as ma from scipy.ndimage import binary_dilation from scilpy.image.reslice import reslice # Don't use Dipy's reslice. Buggy. @@ -466,3 +467,69 @@ def crop_data_with_default_cube(data): roi_mask = _mask_from_roi(shape, roi_center, roi_radii) return data * roi_mask + + +def normalize_metric(metric, reverse=False): + """ + Normalize a metric array to a range between 0 and 1, + optionally reversing the normalization. + + Parameters + ---------- + metric : ndarray + The input metric array to be normalized. + reverse : bool, optional + If True, reverse the normalization (i.e., 1 - normalized value). + Default is False. + + Returns + ------- + ndarray + The normalized (and possibly reversed) metric array. + NaN values in the input are retained. + """ + mask = np.isnan(metric) + masked_metric = ma.masked_array(metric, mask) + + min_val, max_val = masked_metric.min(), masked_metric.max() + normalized_metric = (masked_metric - min_val) / (max_val - min_val) + + if reverse: + normalized_metric = 1 - normalized_metric + + return ma.filled(normalized_metric, fill_value=np.nan) + + +def merge_metrics(*arrays, beta=1.0): + """ + Merge an arbitrary number of metrics into a single heatmap using a weighted + geometric mean, ignoring NaN values. Each input array contributes equally + to the geometric mean, and the result is boosted by a specified factor. + + Parameters + ---------- + *arrays : ndarray + An arbitrary number of input arrays (ndarrays). + All arrays must have the same shape. + beta : float, optional + Boosting factor for the geometric mean. The default is 1.0. + + Returns + ------- + ndarray + Boosted geometric mean of the inputs (same shape as the input arrays) + NaN values in any input array are propagated to the output. + """ + + # Create a mask for NaN values in any of the arrays + mask = np.any([np.isnan(arr) for arr in arrays], axis=0) + masked_arrays = [ma.masked_array(arr, mask) for arr in arrays] + + # Calculate the product of the arrays for the geometric mean + array_product = np.prod(masked_arrays, axis=0) + + # Calculate the geometric mean for valid data + geometric_mean = np.power(array_product, 1 / len(arrays)) + boosted_mean = geometric_mean ** beta + + return ma.filled(boosted_mean, fill_value=np.nan) diff --git a/scilpy/tractograms/streamline_operations.py b/scilpy/tractograms/streamline_operations.py index 18d247df1..d3dc34251 100644 --- a/scilpy/tractograms/streamline_operations.py +++ b/scilpy/tractograms/streamline_operations.py @@ -389,3 +389,32 @@ def smooth_line_spline(streamline, smoothing_parameter, nb_ctrl_points): smoothed_streamline[-1] = streamline[-1] return smoothed_streamline + + +def generate_matched_points(sft): + """ + Generate an array where each element i is set to the index of the + streamline that contributes the ith point. + + Parameters: + ----------- + sft : StatefulTractogram + The stateful tractogram containing the streamlines. + + Returns: + -------- + matched_points : ndarray + An array where each element is set to the index of the streamline + that contributes that point. + """ + total_points = sft.streamlines._data.shape[0] + offsets = sft.streamlines._offsets + + matched_points = np.zeros(total_points, dtype=np.uint64) + + for i in range(len(offsets) - 1): + matched_points[offsets[i]:offsets[i+1]] = i + + matched_points[offsets[-1]:] = len(offsets) - 1 + + return matched_points diff --git a/scilpy/tractograms/tractogram_operations.py b/scilpy/tractograms/tractogram_operations.py index e896f2300..b42acacf8 100644 --- a/scilpy/tractograms/tractogram_operations.py +++ b/scilpy/tractograms/tractogram_operations.py @@ -7,24 +7,35 @@ individually. See scilpy.tractograms.streamline_operations.py for the latter. """ -import itertools - -import random +from concurrent.futures import ProcessPoolExecutor, as_completed +from copy import deepcopy from functools import reduce +import itertools import logging +import random +import warnings +from dipy.data import get_sphere from dipy.io.stateful_tractogram import StatefulTractogram, Space from dipy.io.utils import get_reference_info, is_header_compatible from dipy.segment.clustering import qbx_and_merge +from dipy.segment.fss import FastStreamlineSearch from dipy.tracking.streamline import transform_streamlines +from dipy.reconst.shm import sh_to_sf_matrix from nibabel.streamlines.array_sequence import ArraySequence import numpy as np from scipy.ndimage import map_coordinates from scipy.spatial import cKDTree - -from scilpy.tractograms.streamline_operations import smooth_line_gaussian, \ - smooth_line_spline +from tqdm import tqdm + +from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map +from scilpy.tractanalysis.todi import TrackOrientationDensityImaging +from scilpy.tractograms.streamline_operations import (generate_matched_points, + smooth_line_gaussian, + smooth_line_spline) +from scilpy.image.volume_operations import (normalize_metric, merge_metrics) from scilpy.utils.streamlines import cut_invalid_streamlines +from scilpy.image.volume_math import correlation MIN_NB_POINTS = 10 KEY_INDEX = np.concatenate((range(5), range(-1, -6, -1))) @@ -843,3 +854,246 @@ def split_sft_randomly_per_cluster(orig_sft, chunk_sizes, seed, thresholds): 'concatenate': 'concatenate', 'lazy_concatenate': 'lazy_concatenate' } + + +def _compute_difference_for_voxel(chunk_indices, + skip_streamlines_distance=False): + """ + Compute the difference between two sets of streamlines for a given voxel. + This function uses global variable to avoid duplicating the data for each + chunk of voxels. + + Use the function tractogram_pairwise_comparison() as an entry point. + + Parameters + ---------- + chunk_indices: list + List of indices of the voxel to process. + skip_streamlines_distance: bool + If true, skip the computation of the distance between streamlines. + + Returns + ------- + results: list + List of the computed differences in the same order as the input voxel. + """ + global sft_1, sft_2, matched_points_1, matched_points_2, tree_1, tree_2, \ + sh_data_1, sh_data_2 + results = [] + for vox_ind in chunk_indices: + vox_ind = tuple(vox_ind) + + global B + has_data = sh_data_1[vox_ind].any() and sh_data_2[vox_ind].any() + if has_data: + sf_1 = np.dot(sh_data_1[vox_ind], B) + sf_2 = np.dot(sh_data_2[vox_ind], B) + corr = np.corrcoef(sf_1, sf_2)[0, 1] + else: + corr = -1 + + if skip_streamlines_distance or not has_data: + results.append([-1, corr]) + continue + + # Get the streamlines in the neighborhood (i.e., 2mm away) + pts_ind_1 = tree_1.query_ball_point(vox_ind, 1.5) + if not pts_ind_1: + results.append([-1, -1]) + continue + strs_ind_1 = np.unique(matched_points_1[pts_ind_1]) + neighb_streamlines_1 = sft_1.streamlines[strs_ind_1] + + # Get the streamlines in the neighborhood (i.e., 1mm away) + pts_ind_2 = tree_2.query_ball_point(vox_ind, 1.5) + if not pts_ind_2: + results.append([-1, -1]) + continue + strs_ind_2 = np.unique(matched_points_2[pts_ind_2]) + neighb_streamlines_2 = sft_2.streamlines[strs_ind_2] + + with warnings.catch_warnings(record=True) as _: + fss = FastStreamlineSearch(neighb_streamlines_1, 10, resampling=12) + dist_mat = fss.radius_search(neighb_streamlines_2, 10) + sparse_dist_mat = np.abs(dist_mat.tocsr()).toarray() + sparse_ma_dist_mat = np.ma.masked_where(sparse_dist_mat < 1e-3, + sparse_dist_mat) + sparse_ma_dist_vec = np.squeeze(np.min(sparse_ma_dist_mat, + axis=0)) + + dist = np.average(sparse_ma_dist_vec) + results.append([dist, corr]) + + return results + + +def compare_tractogram_wrapper(mask, nbr_cpu, skip_streamlines_distance): + """ + Wrapper for the comparison of two tractograms. This function uses + multiprocessing to compute the difference between two sets of streamlines + for each voxel. + + This function simple calls the function _compute_difference_for_voxel(), + which expect chunks of indices to process and use global variables to avoid + duplicating the data for each chunk of voxels. + + Use the function tractogram_pairwise_comparison() as an entry point. + + Parameters + ---------- + mask: np.ndarray + Mask of the data to compare. + nbr_cpu: int + Number of CPU to use. + skip_streamlines_distance: bool + If true, skip the computation of the distance between streamlines. + + Returns + ------- + Tuple of np.ndarray + diff_data: np.ndarray + Array containing the computed differences (mm). + acc_data: np.ndarray + Array containing the computed angular correlation. + """ + dimensions = mask.shape + + # Initialize multiprocessing + indices = np.argwhere(mask > 0) + diff_data = np.zeros(dimensions) + diff_data[:] = np.nan + acc_data = np.zeros(dimensions) + acc_data[:] = np.nan + + def chunked_indices(indices, chunk_size=1000): + """Yield successive chunk_size chunks from indices.""" + for i in range(0, len(indices), chunk_size): + yield indices[i:i + chunk_size] + + # Initialize tqdm progress bar + progress_bar = tqdm(total=len(indices)) + + # Create chunks of indices + np.random.shuffle(indices) + index_chunks = list(chunked_indices(indices)) + + with ProcessPoolExecutor(max_workers=nbr_cpu) as executor: + futures = {executor.submit( + _compute_difference_for_voxel, chunk, + skip_streamlines_distance): chunk for chunk in index_chunks} + + for future in as_completed(futures): + chunk = futures[future] + try: + results = future.result() + except Exception as exc: + print(f'Generated an exception: {exc}') + else: + results = np.array(results) + diff_data[tuple(chunk.T)] = results[:, 0] + acc_data[tuple(chunk.T)] = results[:, 1] + + # Update tqdm progress bar + progress_bar.update(len(chunk)) + + return diff_data, acc_data + + +def tractogram_pairwise_comparison(sft_one, sft_two, mask, nbr_cpu=1, + skip_streamlines_distance=True): + """ + Compute the difference between two sets of streamlines for each voxel in + the mask. This function uses multiprocessing to compute the difference + between two sets of streamlines for each voxel. + + Parameters + ---------- + sft_one: StatefulTractogram + First tractogram to compare. + sft_two: StatefulTractogram + Second tractogram to compare. + mask: np.ndarray + Mask of the data to compare (optional). + nbr_cpu: int + Number of CPU to use (default: 1). + skip_streamlines_distance: bool + If true, skip the computation of the distance between streamlines. + (default: True) + + Returns + ------- + List of np.ndarray + acc_norm: Angular correlation coefficient. + corr_norm: Correlation coefficient of density maps. + diff_norm: Voxelwise distance between sets of streamlines. + heatmap: Merged heatmap of the three metrics using harmonic mean. + """ + global sft_1, sft_2 + sft_1, sft_2 = sft_one, sft_two + + sft_1.to_vox() + sft_2.to_vox() + sft_1.streamlines._data = sft_1.streamlines._data.astype(np.float16) + sft_2.streamlines._data = sft_2.streamlines._data.astype(np.float16) + dimensions = tuple(sft_1.dimensions) + + global matched_points_1, matched_points_2 + matched_points_1 = generate_matched_points(sft_1) + matched_points_2 = generate_matched_points(sft_2) + + logging.info('Computing KDTree...') + global tree_1, tree_2 + tree_1 = cKDTree(sft_1.streamlines._data) + tree_2 = cKDTree(sft_2.streamlines._data) + + # Limits computation to mask AND streamlines (using density) + if not mask: + mask = np.ones(dimensions) + + logging.info('Computing density maps...') + density_1 = compute_tract_counts_map(sft_1.streamlines, + dimensions).astype(float) + density_2 = compute_tract_counts_map(sft_2.streamlines, + dimensions).astype(float) + mask = density_1 * density_2 * mask + mask[mask > 0] = 1 + + logging.info('Computing correlation map...') + corr_data = correlation([density_1, density_2], None) * mask + + logging.info('Computing TODI #1...') + global sh_data_1, sh_data_2 + sft_1.to_corner() + todi_obj = TrackOrientationDensityImaging(dimensions, 'repulsion724') + todi_obj.compute_todi(deepcopy(sft_1.streamlines), length_weights=True) + todi_obj.mask_todi(mask) + sh_data_1 = todi_obj.get_sh('descoteaux07', 8) + sh_data_1 = todi_obj.reshape_to_3d(sh_data_1) + sft_1.to_center() + + logging.info('Computing TODI #2...') + sft_2.to_corner() + todi_obj = TrackOrientationDensityImaging(dimensions, 'repulsion724') + todi_obj.compute_todi(deepcopy(sft_2.streamlines), length_weights=True) + todi_obj.mask_todi(mask) + sh_data_2 = todi_obj.get_sh('descoteaux07', 8) + sh_data_2 = todi_obj.reshape_to_3d(sh_data_2) + sft_2.to_center() + + global B + B, _ = sh_to_sf_matrix(get_sphere('repulsion724'), 8, 'descoteaux07') + + diff_data, acc_data = compare_tractogram_wrapper(mask, nbr_cpu, + skip_streamlines_distance) + + # Normalize metrics + acc_norm = normalize_metric(acc_data) + corr_norm = normalize_metric(corr_data) + diff_norm = normalize_metric(diff_data, reverse=True) + indices_minus_one = np.where((acc_data == -1) | (corr_data == -1) | + (diff_data == -1)) + # Merge into a single heatmap + heatmap = merge_metrics(acc_norm, corr_norm, diff_norm) + heatmap[indices_minus_one] = np.nan + + return acc_norm, corr_norm, diff_norm, heatmap diff --git a/scripts/scil_tractogram_pairwise_comparison.py b/scripts/scil_tractogram_pairwise_comparison.py index 02256c1d6..cdac5ea3b 100644 --- a/scripts/scil_tractogram_pairwise_comparison.py +++ b/scripts/scil_tractogram_pairwise_comparison.py @@ -13,22 +13,11 @@ """ import argparse -from concurrent.futures import ProcessPoolExecutor, as_completed -from copy import deepcopy import logging import os -from tqdm import tqdm -import warnings -from dipy.data import get_sphere -from dipy.reconst.shm import sh_to_sf_matrix -from dipy.segment.fss import FastStreamlineSearch import nibabel as nib -import numpy as np -import numpy.ma as ma -from scipy.spatial import cKDTree -from scilpy.image.volume_math import correlation from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, assert_inputs_exist, assert_outputs_exist, @@ -38,8 +27,7 @@ is_header_compatible_multiple_files, load_tractogram_with_reference, validate_nbr_processes) -from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map -from scilpy.tractanalysis.todi import TrackOrientationDensityImaging +from scilpy.tractograms.tractogram_operations import tractogram_pairwise_comparison def _build_arg_parser(): @@ -55,11 +43,11 @@ def _build_arg_parser(): p.add_argument('--out_dir', default='', help='Directory where all output files will be saved. ' - '\nIf not specified, outputs will be saved in the current ' - 'directory.') + '\nIf not specified, outputs will be saved in the ' + 'current directory.') p.add_argument('--out_prefix', default='out', - help='Prefix for output files. Useful for distinguishing between ' - 'different runs.') + help='Prefix for output files. Useful for distinguishing ' + 'between different runs.') p.add_argument('--in_mask', metavar='IN_FILE', help='Optional input mask.') @@ -74,124 +62,6 @@ def _build_arg_parser(): return p -def generate_matched_points(sft): - """ - Generate an array where each element i is set to the index of the streamline - that contributes the ith point. - - Parameters: - ----------- - sft : StatefulTractogram - The stateful tractogram containing the streamlines. - - Returns: - -------- - matched_points : ndarray - An array where each element is set to the index of the streamline - that contributes that point. - """ - total_points = sft.streamlines._data.shape[0] - offsets = sft.streamlines._offsets - - matched_points = np.zeros(total_points, dtype=np.uint64) - - for i in range(len(offsets) - 1): - matched_points[offsets[i]:offsets[i+1]] = i - - matched_points[offsets[-1]:] = len(offsets) - 1 - - return matched_points - - -def compute_difference_for_voxel(chunk_indices, - skip_streamlines_distance=False): - """ - Compute the difference for a single voxel index. - """ - global sft_1, sft_2, matched_points_1, matched_points_2, tree_1, tree_2, \ - sh_data_1, sh_data_2 - results = [] - for vox_ind in chunk_indices: - vox_ind = tuple(vox_ind) - - global B - has_data = sh_data_1[vox_ind].any() and sh_data_2[vox_ind].any() - if has_data: - sf_1 = np.dot(sh_data_1[vox_ind], B) - sf_2 = np.dot(sh_data_2[vox_ind], B) - corr = np.corrcoef(sf_1, sf_2)[0, 1] - else: - corr = -1 - - if skip_streamlines_distance or not has_data: - results.append([-1, corr]) - continue - - # Get the streamlines in the neighborhood (i.e., 2mm away) - pts_ind_1 = tree_1.query_ball_point(vox_ind, 1.5) - if not pts_ind_1: - results.append([-1, -1]) - continue - strs_ind_1 = np.unique(matched_points_1[pts_ind_1]) - neighb_streamlines_1 = sft_1.streamlines[strs_ind_1] - - # Get the streamlines in the neighborhood (i.e., 1mm away) - pts_ind_2 = tree_2.query_ball_point(vox_ind, 1.5) - if not pts_ind_2: - results.append([-1, -1]) - continue - strs_ind_2 = np.unique(matched_points_2[pts_ind_2]) - neighb_streamlines_2 = sft_2.streamlines[strs_ind_2] - - with warnings.catch_warnings(record=True) as _: - fss = FastStreamlineSearch(neighb_streamlines_1, 10, resampling=12) - dist_mat = fss.radius_search(neighb_streamlines_2, 10) - sparse_dist_mat = np.abs(dist_mat.tocsr()).toarray() - sparse_ma_dist_mat = np.ma.masked_where(sparse_dist_mat < 1e-3, - sparse_dist_mat) - sparse_ma_dist_vec = np.squeeze(np.min(sparse_ma_dist_mat, - axis=0)) - - dist = np.average(sparse_ma_dist_vec) - results.append([dist, corr]) - - return results - - -def normalize_metric(metric, reverse=False): - """ - Normalize a metric to be in the range [0, 1], ignoring specified values. - """ - mask = np.isnan(metric) - masked_metric = ma.masked_array(metric, mask) - - min_val, max_val = masked_metric.min(), masked_metric.max() - normalized_metric = (masked_metric - min_val) / (max_val - min_val) - - if reverse: - normalized_metric = 1 - normalized_metric - - return ma.filled(normalized_metric, fill_value=np.nan) - - -def merge_metrics(acc, corr, diff, beta=1.0): - """ - Merge the three metrics into a single heatmap using a weighted geometric mean, - ignoring specified values. - """ - mask = np.isnan(acc) | np.isnan(corr) | np.isnan(diff) - masked_acc, masked_corr, masked_diff = [ - ma.masked_array(x, mask) for x in [acc, corr, diff]] - - # Calculate the geometric mean for valid data - geometric_mean = np.cbrt(masked_acc * masked_corr * masked_diff) - - # Apply a boosting factor - boosted_mean = geometric_mean ** beta - - return ma.filled(boosted_mean, fill_value=np.nan) - - def main(): parser = _build_arg_parser() args = parser.parse_args() @@ -201,8 +71,10 @@ def main(): assert_inputs_exist(parser, [args.in_tractogram_1, args.in_tractogram_2]) - is_header_compatible_multiple_files(parser, [args.in_tractogram_1, - args.in_tractogram_2], + to_verify = [args.in_tractogram_1, args.in_tractogram_2] + if args.in_mask: + to_verify.append(args.in_mask) + is_header_compatible_multiple_files(parser, to_verify, verbose_all_compatible=True, reference=args.reference) @@ -224,121 +96,19 @@ def main(): nbr_cpu = validate_nbr_processes(parser, args) logging.info('Loading tractograms...') - global sft_1, sft_2 sft_1 = load_tractogram_with_reference(parser, args, args.in_tractogram_1) sft_2 = load_tractogram_with_reference(parser, args, args.in_tractogram_2) - sft_1.to_vox() - sft_2.to_vox() - sft_1.streamlines._data = sft_1.streamlines._data.astype(np.float16) - sft_2.streamlines._data = sft_2.streamlines._data.astype(np.float16) - affine, dimensions = sft_1.affine, sft_1.dimensions - - global matched_points_1, matched_points_2 - matched_points_1 = generate_matched_points(sft_1) - matched_points_2 = generate_matched_points(sft_2) - - logging.info('Computing KDTree...') - global tree_1, tree_2 - tree_1 = cKDTree(sft_1.streamlines._data) - tree_2 = cKDTree(sft_2.streamlines._data) - - # Limits computation to mask AND streamlines (using density) - if args.in_mask: - mask = nib.load(args.in_mask).get_fdata() - else: - mask = np.ones(dimensions) - - logging.info('Computing density maps...') - density_1 = compute_tract_counts_map(sft_1.streamlines, - dimensions).astype(float) - density_2 = compute_tract_counts_map(sft_2.streamlines, - dimensions).astype(float) - mask = density_1 * density_2 * mask - mask[mask > 0] = 1 - - logging.info('Computing correlation map...') - corr_data = correlation([density_1, density_2], None) * mask - nib.save(nib.Nifti1Image(corr_data, affine), out_corr_filename) - - logging.info('Computing TODI #1...') - global sh_data_1, sh_data_2 - sft_1.to_corner() - todi_obj = TrackOrientationDensityImaging( - tuple(dimensions), 'repulsion724') - todi_obj.compute_todi(deepcopy(sft_1.streamlines), length_weights=True) - todi_obj.mask_todi(mask) - sh_data_1 = todi_obj.get_sh('descoteaux07', 8) - sh_data_1 = todi_obj.reshape_to_3d(sh_data_1) - sft_1.to_center() + mask = nib.load(args.in_mask) if args.in_mask else None - logging.info('Computing TODI #2...') - sft_2.to_corner() - todi_obj = TrackOrientationDensityImaging( - tuple(dimensions), 'repulsion724') - todi_obj.compute_todi(deepcopy(sft_2.streamlines), length_weights=True) - todi_obj.mask_todi(mask) - sh_data_2 = todi_obj.get_sh('descoteaux07', 8) - sh_data_2 = todi_obj.reshape_to_3d(sh_data_2) - sft_2.to_center() - - global B - B, _ = sh_to_sf_matrix(get_sphere('repulsion724'), 8, 'descoteaux07') - - # Initialize multiprocessing - indices = np.argwhere(mask > 0) - diff_data = np.zeros(dimensions) - diff_data[:] = np.nan - acc_data = np.zeros(dimensions) - acc_data[:] = np.nan - - def chunked_indices(indices, chunk_size=1000): - """Yield successive chunk_size chunks from indices.""" - for i in range(0, len(indices), chunk_size): - yield indices[i:i + chunk_size] - - # Initialize tqdm progress bar - progress_bar = tqdm(total=len(indices)) - - # Create chunks of indices - np.random.shuffle(indices) - index_chunks = list(chunked_indices(indices)) - - with ProcessPoolExecutor(max_workers=nbr_cpu) as executor: - futures = {executor.submit( - compute_difference_for_voxel, chunk, - args.skip_streamlines_distance): chunk for chunk in index_chunks} - - for future in as_completed(futures): - chunk = futures[future] - try: - results = future.result() - except Exception as exc: - print(f'Generated an exception: {exc}') - else: - results = np.array(results) - diff_data[tuple(chunk.T)] = results[:, 0] - acc_data[tuple(chunk.T)] = results[:, 1] - - # Update tqdm progress bar - progress_bar.update(len(chunk)) + acc_data, corr_data, diff_data, heatmap = \ + tractogram_pairwise_comparison(sft_1, sft_2, mask, nbr_cpu, + args.skip_streamlines_distance) logging.info('Saving results...') - nib.save(nib.Nifti1Image(diff_data, affine), out_diff_filename) - nib.save(nib.Nifti1Image(acc_data, affine), out_acc_filename) - - # Normalize metrics - acc_norm = normalize_metric(acc_data) - corr_norm = normalize_metric(corr_data) - diff_norm = normalize_metric(diff_data, reverse=True) - indices_minus_one = np.where((acc_data == -1) | (corr_data == -1) | - (diff_data == -1)) - - # Merge into a single heatmap - heatmap = merge_metrics(acc_norm, corr_norm, diff_norm) - - # Save as a new NIFTI file - heatmap[indices_minus_one] = np.nan - nib.save(nib.Nifti1Image(heatmap, affine), out_merge_filename) + nib.save(nib.Nifti1Image(acc_data, sft_1.affine), out_acc_filename) + nib.save(nib.Nifti1Image(corr_data, sft_1.affine), out_corr_filename) + nib.save(nib.Nifti1Image(diff_data, sft_1.affine), out_diff_filename) + nib.save(nib.Nifti1Image(heatmap, sft_1.affine), out_merge_filename) if __name__ == "__main__": From 602948cea4da0347913eab2b5d2b79b389cf2b2c Mon Sep 17 00:00:00 2001 From: frheault Date: Mon, 29 Jan 2024 13:57:36 -0500 Subject: [PATCH 06/18] Add _ to wrapper --- scilpy/tractograms/tractogram_operations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scilpy/tractograms/tractogram_operations.py b/scilpy/tractograms/tractogram_operations.py index b42acacf8..67761e3f7 100644 --- a/scilpy/tractograms/tractogram_operations.py +++ b/scilpy/tractograms/tractogram_operations.py @@ -927,7 +927,7 @@ def _compute_difference_for_voxel(chunk_indices, return results -def compare_tractogram_wrapper(mask, nbr_cpu, skip_streamlines_distance): +def _compare_tractogram_wrapper(mask, nbr_cpu, skip_streamlines_distance): """ Wrapper for the comparison of two tractograms. This function uses multiprocessing to compute the difference between two sets of streamlines @@ -1083,8 +1083,8 @@ def tractogram_pairwise_comparison(sft_one, sft_two, mask, nbr_cpu=1, global B B, _ = sh_to_sf_matrix(get_sphere('repulsion724'), 8, 'descoteaux07') - diff_data, acc_data = compare_tractogram_wrapper(mask, nbr_cpu, - skip_streamlines_distance) + diff_data, acc_data = _compare_tractogram_wrapper(mask, nbr_cpu, + skip_streamlines_distance) # Normalize metrics acc_norm = normalize_metric(acc_data) From 569aa3d6a1267c628f3cd881a8098e9e3104e333 Mon Sep 17 00:00:00 2001 From: frheault Date: Fri, 2 Feb 2024 12:48:44 -0500 Subject: [PATCH 07/18] First pass emmanuel comments --- scilpy/tractograms/tractogram_operations.py | 9 ++++++-- .../scil_tractogram_pairwise_comparison.py | 21 ++++++++++++++----- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/scilpy/tractograms/tractogram_operations.py b/scilpy/tractograms/tractogram_operations.py index 67761e3f7..274fe8e3a 100644 --- a/scilpy/tractograms/tractogram_operations.py +++ b/scilpy/tractograms/tractogram_operations.py @@ -896,7 +896,7 @@ def _compute_difference_for_voxel(chunk_indices, results.append([-1, corr]) continue - # Get the streamlines in the neighborhood (i.e., 2mm away) + # Get the streamlines in the neighborhood (i.e., 1.5mm away) pts_ind_1 = tree_1.query_ball_point(vox_ind, 1.5) if not pts_ind_1: results.append([-1, -1]) @@ -904,7 +904,7 @@ def _compute_difference_for_voxel(chunk_indices, strs_ind_1 = np.unique(matched_points_1[pts_ind_1]) neighb_streamlines_1 = sft_1.streamlines[strs_ind_1] - # Get the streamlines in the neighborhood (i.e., 1mm away) + # Get the streamlines in the neighborhood (i.e., 1.5mm away) pts_ind_2 = tree_2.query_ball_point(vox_ind, 1.5) if not pts_ind_2: results.append([-1, -1]) @@ -912,6 +912,9 @@ def _compute_difference_for_voxel(chunk_indices, strs_ind_2 = np.unique(matched_points_2[pts_ind_2]) neighb_streamlines_2 = sft_2.streamlines[strs_ind_2] + # Using neighb_streamlines (all streamlines in the neighborhood of our + # voxel), we can compute the distance between the two sets of + # streamlines using FSS (FastStreamlineSearch). with warnings.catch_warnings(record=True) as _: fss = FastStreamlineSearch(neighb_streamlines_1, 10, resampling=12) dist_mat = fss.radius_search(neighb_streamlines_2, 10) @@ -921,6 +924,8 @@ def _compute_difference_for_voxel(chunk_indices, sparse_ma_dist_vec = np.squeeze(np.min(sparse_ma_dist_mat, axis=0)) + # dists will represent the average distance between the two sets of + # streamlines in the neighborhood of the voxel. dist = np.average(sparse_ma_dist_vec) results.append([dist, corr]) diff --git a/scripts/scil_tractogram_pairwise_comparison.py b/scripts/scil_tractogram_pairwise_comparison.py index cdac5ea3b..635a25f67 100644 --- a/scripts/scil_tractogram_pairwise_comparison.py +++ b/scripts/scil_tractogram_pairwise_comparison.py @@ -3,13 +3,24 @@ """ This script is designed to compare and help visualize differences between -two tractograms. Which can be especially useful in studies where multiple +two tractograms. This can be especially useful in studies where multiple tractograms from different algorithms or parameters need to be compared. +A similar script (scil_bundle_pairwise_comparison.py) is available for bundles, +with metrics more adapted to bundles (and spatial aggrement). + The difference is computed in terms of -- A voxel-wise spatial distance between streamlines, out_diff.nii.gz -- A correlation (ACC) between streamline orientation (TODI), out_acc.nii.gz -- A correlation between streamline density, out_corr.nii.gz +- A voxel-wise spatial distance between streamlines crossing each voxel. + This can help to see if both tractography reconstructions at each voxel + looks similar (out_diff.nii.gz) +- An angular correlation (ACC) between streamline orientation from TODI. + This compares the local orientation of streamlines at each voxel + (out_acc.nii.gz) +- A patch-wise correlation between streamline density maps from both + tractograms. This compares where the high/low density regions agree or not + (out_corr.nii.gz) +- A heatmap combining all the previous metrics using an harmonic means of the + normalized metrics to summarize general agreement (out_heatmap.nii.gz) """ import argparse @@ -47,7 +58,7 @@ def _build_arg_parser(): 'current directory.') p.add_argument('--out_prefix', default='out', help='Prefix for output files. Useful for distinguishing ' - 'between different runs.') + 'between different runs [%(default)s].') p.add_argument('--in_mask', metavar='IN_FILE', help='Optional input mask.') From 824561bb2c8b070b17cb83452a3349b9ece60178 Mon Sep 17 00:00:00 2001 From: frheault Date: Fri, 2 Feb 2024 14:09:50 -0500 Subject: [PATCH 08/18] Fix typo/errors --- scripts/scil_tractogram_pairwise_comparison.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/scil_tractogram_pairwise_comparison.py b/scripts/scil_tractogram_pairwise_comparison.py index 635a25f67..4e5ad693a 100644 --- a/scripts/scil_tractogram_pairwise_comparison.py +++ b/scripts/scil_tractogram_pairwise_comparison.py @@ -7,12 +7,12 @@ tractograms from different algorithms or parameters need to be compared. A similar script (scil_bundle_pairwise_comparison.py) is available for bundles, -with metrics more adapted to bundles (and spatial aggrement). +with metrics more adapted to bundles (and spatial agreement). The difference is computed in terms of - A voxel-wise spatial distance between streamlines crossing each voxel. This can help to see if both tractography reconstructions at each voxel - looks similar (out_diff.nii.gz) + look similar (out_diff.nii.gz) - An angular correlation (ACC) between streamline orientation from TODI. This compares the local orientation of streamlines at each voxel (out_acc.nii.gz) From aa78ad2a13e3d28d59f72491385a349e62396d7f Mon Sep 17 00:00:00 2001 From: frheault Date: Mon, 12 Feb 2024 10:21:52 -0500 Subject: [PATCH 09/18] test start --- scilpy/image/volume_operations.py | 2 +- .../tests/test_tractogram_operations.py | 33 +++++++++++++++++-- scilpy/tractograms/tractogram_operations.py | 22 ++++++------- 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/scilpy/image/volume_operations.py b/scilpy/image/volume_operations.py index e391c9880..89e175278 100644 --- a/scilpy/image/volume_operations.py +++ b/scilpy/image/volume_operations.py @@ -532,4 +532,4 @@ def merge_metrics(*arrays, beta=1.0): geometric_mean = np.power(array_product, 1 / len(arrays)) boosted_mean = geometric_mean ** beta - return ma.filled(boosted_mean, fill_value=np.nan) + return ma.filled(boosted_mean, fill_value=-2) diff --git a/scilpy/tractograms/tests/test_tractogram_operations.py b/scilpy/tractograms/tests/test_tractogram_operations.py index e488a6a43..8a345cf98 100644 --- a/scilpy/tractograms/tests/test_tractogram_operations.py +++ b/scilpy/tractograms/tests/test_tractogram_operations.py @@ -5,15 +5,19 @@ import numpy as np from dipy.io.streamline import load_tractogram +from dipy.io.stateful_tractogram import StatefulTractogram from scilpy.io.fetcher import fetch_data, get_testing_files_dict, get_home +from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map from scilpy.tractograms.tractogram_operations import flip_sft, \ shuffle_streamlines, perform_tractogram_operation_on_lines, intersection, union, \ difference, intersection_robust, difference_robust, union_robust, \ - concatenate_sft, perform_tractogram_operation_on_sft + concatenate_sft, perform_tractogram_operation_on_sft, \ + tractogram_pairwise_comparison # Prepare SFT -fetch_data(get_testing_files_dict(), keys='surface_vtk_fib.zip') +fetch_data(get_testing_files_dict(), keys=['surface_vtk_fib.zip', + 'bst.zip']) tmp_dir = tempfile.TemporaryDirectory() in_sft = os.path.join(get_home(), 'surface_vtk_fib', 'gyri_fanning.trk') @@ -148,3 +152,28 @@ def test_combining_sft(): perform_tractogram_operation_on_sft('union', [sft, sft], precision=None, fake_metadata=False, no_metadata=False) + +def test_tractogram_pairwise_comparison(): + sft_path = os.path.join(get_home(), 'bst', 'template', 'rpt_m.trk') + print(sft_path) + sft = load_tractogram(sft_path, 'same') + sft_1 = StatefulTractogram.from_sft(sft.streamlines[0:100], sft) + sft_2 = StatefulTractogram.from_sft(sft.streamlines[100:200], sft) + + sft.to_vox() + sft.to_corner() + mask = compute_tract_counts_map(sft.streamlines, sft.dimensions) + mask[mask > 0] = 1 + + results = tractogram_pairwise_comparison(sft_1, sft_2, mask, + skip_streamlines_distance=False) + assert len(results) == 4 + for r in results: + assert np.array_equal(r.shape, sft.dimensions) + + assert np.mean(results[0][~np.isnan(results[0])]) == 0.7526672821759841 + assert np.mean(results[1][~np.isnan(results[1])]) == 0.2573811906676235 + assert np.mean(results[2][~np.isnan(results[2])]) == 0.6381088710917843 + assert np.mean(results[3][~np.isnan(results[3])]) == 0.7035940807380183 + + assert False diff --git a/scilpy/tractograms/tractogram_operations.py b/scilpy/tractograms/tractogram_operations.py index 274fe8e3a..b549a2e4c 100644 --- a/scilpy/tractograms/tractogram_operations.py +++ b/scilpy/tractograms/tractogram_operations.py @@ -864,6 +864,8 @@ def _compute_difference_for_voxel(chunk_indices, chunk of voxels. Use the function tractogram_pairwise_comparison() as an entry point. + To differentiate empty voxels from voxels with no data, the function + returns NaN if no data is found. Parameters ---------- @@ -888,18 +890,18 @@ def _compute_difference_for_voxel(chunk_indices, if has_data: sf_1 = np.dot(sh_data_1[vox_ind], B) sf_2 = np.dot(sh_data_2[vox_ind], B) - corr = np.corrcoef(sf_1, sf_2)[0, 1] + acc = np.corrcoef(sf_1, sf_2)[0, 1] else: - corr = -1 + acc = np.nan - if skip_streamlines_distance or not has_data: - results.append([-1, corr]) + if skip_streamlines_distance: + results.append([np.nan, acc]) continue # Get the streamlines in the neighborhood (i.e., 1.5mm away) pts_ind_1 = tree_1.query_ball_point(vox_ind, 1.5) if not pts_ind_1: - results.append([-1, -1]) + results.append([np.nan, acc]) continue strs_ind_1 = np.unique(matched_points_1[pts_ind_1]) neighb_streamlines_1 = sft_1.streamlines[strs_ind_1] @@ -907,7 +909,7 @@ def _compute_difference_for_voxel(chunk_indices, # Get the streamlines in the neighborhood (i.e., 1.5mm away) pts_ind_2 = tree_2.query_ball_point(vox_ind, 1.5) if not pts_ind_2: - results.append([-1, -1]) + results.append([np.nan, acc]) continue strs_ind_2 = np.unique(matched_points_2[pts_ind_2]) neighb_streamlines_2 = sft_2.streamlines[strs_ind_2] @@ -927,7 +929,7 @@ def _compute_difference_for_voxel(chunk_indices, # dists will represent the average distance between the two sets of # streamlines in the neighborhood of the voxel. dist = np.average(sparse_ma_dist_vec) - results.append([dist, corr]) + results.append([dist, acc]) return results @@ -1052,7 +1054,7 @@ def tractogram_pairwise_comparison(sft_one, sft_two, mask, nbr_cpu=1, tree_2 = cKDTree(sft_2.streamlines._data) # Limits computation to mask AND streamlines (using density) - if not mask: + if mask is None: mask = np.ones(dimensions) logging.info('Computing density maps...') @@ -1095,10 +1097,8 @@ def tractogram_pairwise_comparison(sft_one, sft_two, mask, nbr_cpu=1, acc_norm = normalize_metric(acc_data) corr_norm = normalize_metric(corr_data) diff_norm = normalize_metric(diff_data, reverse=True) - indices_minus_one = np.where((acc_data == -1) | (corr_data == -1) | - (diff_data == -1)) + # Merge into a single heatmap heatmap = merge_metrics(acc_norm, corr_norm, diff_norm) - heatmap[indices_minus_one] = np.nan return acc_norm, corr_norm, diff_norm, heatmap From d674e17b9156bddd3b1487110bbe4ce45c768a2e Mon Sep 17 00:00:00 2001 From: frheault Date: Mon, 12 Feb 2024 12:02:34 -0500 Subject: [PATCH 10/18] Added test --- .../tests/test_tractogram_operations.py | 17 ++++++++++------- scilpy/tractograms/tractogram_operations.py | 11 ++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/scilpy/tractograms/tests/test_tractogram_operations.py b/scilpy/tractograms/tests/test_tractogram_operations.py index 8a345cf98..04d029036 100644 --- a/scilpy/tractograms/tests/test_tractogram_operations.py +++ b/scilpy/tractograms/tests/test_tractogram_operations.py @@ -166,14 +166,17 @@ def test_tractogram_pairwise_comparison(): mask[mask > 0] = 1 results = tractogram_pairwise_comparison(sft_1, sft_2, mask, - skip_streamlines_distance=False) + skip_streamlines_distance=False) assert len(results) == 4 for r in results: assert np.array_equal(r.shape, sft.dimensions) - - assert np.mean(results[0][~np.isnan(results[0])]) == 0.7526672821759841 - assert np.mean(results[1][~np.isnan(results[1])]) == 0.2573811906676235 - assert np.mean(results[2][~np.isnan(results[2])]) == 0.6381088710917843 - assert np.mean(results[3][~np.isnan(results[3])]) == 0.7035940807380183 - assert False + assert np.mean(results[0][~np.isnan(results[0])]) == 0.7171550368952226 + assert np.mean(results[1][~np.isnan(results[1])]) == 0.6063336089511456 + assert np.mean(results[2][~np.isnan(results[2])]) == 0.722988562131705 + assert np.mean(results[3][~np.isnan(results[3])]) == 0.7526672393158469 + + assert np.count_nonzero(np.isnan(results[0])) == 877627 + assert np.count_nonzero(np.isnan(results[1])) == 877014 + assert np.count_nonzero(np.isnan(results[2])) == 877034 + assert np.count_nonzero(np.isnan(results[3])) == 877671 diff --git a/scilpy/tractograms/tractogram_operations.py b/scilpy/tractograms/tractogram_operations.py index b549a2e4c..fd3d1ca1b 100644 --- a/scilpy/tractograms/tractogram_operations.py +++ b/scilpy/tractograms/tractogram_operations.py @@ -1067,6 +1067,7 @@ def tractogram_pairwise_comparison(sft_one, sft_two, mask, nbr_cpu=1, logging.info('Computing correlation map...') corr_data = correlation([density_1, density_2], None) * mask + corr_data[mask == 0] = np.nan logging.info('Computing TODI #1...') global sh_data_1, sh_data_2 @@ -1093,12 +1094,8 @@ def tractogram_pairwise_comparison(sft_one, sft_two, mask, nbr_cpu=1, diff_data, acc_data = _compare_tractogram_wrapper(mask, nbr_cpu, skip_streamlines_distance) - # Normalize metrics - acc_norm = normalize_metric(acc_data) - corr_norm = normalize_metric(corr_data) + # Normalize metrics and merge into a single heatmap diff_norm = normalize_metric(diff_data, reverse=True) + heatmap = merge_metrics(acc_data, corr_data, diff_norm) - # Merge into a single heatmap - heatmap = merge_metrics(acc_norm, corr_norm, diff_norm) - - return acc_norm, corr_norm, diff_norm, heatmap + return acc_data, corr_data, diff_norm, heatmap From 122db3d0ca18f911a9b179fc38281f2d17361663 Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 21 Feb 2024 13:03:56 -0500 Subject: [PATCH 11/18] Address Emmanuel comments --- scilpy/tractograms/streamline_operations.py | 9 ++++++--- scilpy/tractograms/tractogram_operations.py | 8 ++++---- scripts/scil_tractogram_pairwise_comparison.py | 11 ++++++----- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/scilpy/tractograms/streamline_operations.py b/scilpy/tractograms/streamline_operations.py index ce9cc7f4a..77f8b59ae 100644 --- a/scilpy/tractograms/streamline_operations.py +++ b/scilpy/tractograms/streamline_operations.py @@ -84,6 +84,7 @@ def _get_point_on_line(first_point, second_point, vox_lower_corner): return first_point + ray * (t0 + t1) / 2. + def filter_streamlines_by_length(sft, min_length=0., max_length=np.inf): """ Filter streamlines using minimum and max length. @@ -392,8 +393,8 @@ def smooth_line_spline(streamline, smoothing_parameter, nb_ctrl_points): def generate_matched_points(sft): """ - Generate an array where each element i is set to the index of the - streamline that contributes the ith point. + Generates an array where each element i is set to the index of the + streamline to which it belongs Parameters: ----------- @@ -404,7 +405,9 @@ def generate_matched_points(sft): -------- matched_points : ndarray An array where each element is set to the index of the streamline - that contributes that point. + to which it belongs + + """ total_points = sft.streamlines._data.shape[0] offsets = sft.streamlines._offsets diff --git a/scilpy/tractograms/tractogram_operations.py b/scilpy/tractograms/tractogram_operations.py index fd3d1ca1b..55b80bd82 100644 --- a/scilpy/tractograms/tractogram_operations.py +++ b/scilpy/tractograms/tractogram_operations.py @@ -877,7 +877,7 @@ def _compute_difference_for_voxel(chunk_indices, Returns ------- results: list - List of the computed differences in the same order as the input voxel. + List of the computed differences in the same order as the input voxels. """ global sft_1, sft_2, matched_points_1, matched_points_2, tree_1, tree_2, \ sh_data_1, sh_data_2 @@ -940,7 +940,7 @@ def _compare_tractogram_wrapper(mask, nbr_cpu, skip_streamlines_distance): multiprocessing to compute the difference between two sets of streamlines for each voxel. - This function simple calls the function _compute_difference_for_voxel(), + This function simply calls the function _compute_difference_for_voxel(), which expect chunks of indices to process and use global variables to avoid duplicating the data for each chunk of voxels. @@ -1069,7 +1069,7 @@ def tractogram_pairwise_comparison(sft_one, sft_two, mask, nbr_cpu=1, corr_data = correlation([density_1, density_2], None) * mask corr_data[mask == 0] = np.nan - logging.info('Computing TODI #1...') + logging.info('Computing TODI from tractogram #1...') global sh_data_1, sh_data_2 sft_1.to_corner() todi_obj = TrackOrientationDensityImaging(dimensions, 'repulsion724') @@ -1079,7 +1079,7 @@ def tractogram_pairwise_comparison(sft_one, sft_two, mask, nbr_cpu=1, sh_data_1 = todi_obj.reshape_to_3d(sh_data_1) sft_1.to_center() - logging.info('Computing TODI #2...') + logging.info('Computing TODI from tractogram #2...') sft_2.to_corner() todi_obj = TrackOrientationDensityImaging(dimensions, 'repulsion724') todi_obj.compute_todi(deepcopy(sft_2.streamlines), length_weights=True) diff --git a/scripts/scil_tractogram_pairwise_comparison.py b/scripts/scil_tractogram_pairwise_comparison.py index 4e5ad693a..461c0fbb0 100644 --- a/scripts/scil_tractogram_pairwise_comparison.py +++ b/scripts/scil_tractogram_pairwise_comparison.py @@ -80,8 +80,8 @@ def main(): if args.verbose: logging.basicConfig(level=logging.INFO) - assert_inputs_exist(parser, [args.in_tractogram_1, - args.in_tractogram_2]) + assert_inputs_exist(parser, [args.in_tractogram_1, args.in_tractogram_2], + [args.in_mask, args.reference]) to_verify = [args.in_tractogram_1, args.in_tractogram_2] if args.in_mask: to_verify.append(args.in_mask) @@ -102,8 +102,8 @@ def main(): assert_output_dirs_exist_and_empty(parser, args, [], optional=args.out_dir) assert_outputs_exist(parser, args, [out_corr_filename, out_acc_filename, - out_diff_filename, - out_merge_filename]) + out_merge_filename], + out_diff_filename) nbr_cpu = validate_nbr_processes(parser, args) logging.info('Loading tractograms...') @@ -118,7 +118,8 @@ def main(): logging.info('Saving results...') nib.save(nib.Nifti1Image(acc_data, sft_1.affine), out_acc_filename) nib.save(nib.Nifti1Image(corr_data, sft_1.affine), out_corr_filename) - nib.save(nib.Nifti1Image(diff_data, sft_1.affine), out_diff_filename) + if not args.skip_streamlines_distance: + nib.save(nib.Nifti1Image(diff_data, sft_1.affine), out_diff_filename) nib.save(nib.Nifti1Image(heatmap, sft_1.affine), out_merge_filename) From b7e0d4a5dadb7fffa884be5959563c4387c6abdd Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 21 Feb 2024 16:12:42 -0500 Subject: [PATCH 12/18] Reorganize functions and tests --- scilpy/image/tests/test_volume_operations.py | 38 ++- scilpy/image/volume_operations.py | 2 +- .../tractanalysis/reproducibility_measures.py | 265 +++++++++++++++++- .../tests/test_reproducibility_measures.py | 41 +++ .../tests/test_tractogram_operations.py | 35 +-- scilpy/tractograms/tractogram_operations.py | 259 +---------------- .../scil_tractogram_pairwise_comparison.py | 7 +- 7 files changed, 345 insertions(+), 302 deletions(-) create mode 100644 scilpy/tractanalysis/tests/test_reproducibility_measures.py diff --git a/scilpy/image/tests/test_volume_operations.py b/scilpy/image/tests/test_volume_operations.py index 3b051b716..4cc7e47df 100644 --- a/scilpy/image/tests/test_volume_operations.py +++ b/scilpy/image/tests/test_volume_operations.py @@ -6,13 +6,15 @@ from dipy.io.gradients import read_bvals_bvecs import nibabel as nib import numpy as np -from numpy.testing import assert_equal +from numpy.testing import assert_equal, assert_almost_equal from scilpy.image.volume_operations import (flip_volume, crop_volume, apply_transform, compute_snr, - resample_volume) + resample_volume, + normalize_metric, + merge_metrics) from scilpy.io.fetcher import fetch_data, get_testing_files_dict, get_home from scilpy.utils.util import compute_nifti_bounding_box @@ -126,3 +128,35 @@ def test_resample_volume(): resampled_img = resample_volume(moving3d_img, res=(2, 2, 2), interp='nn') assert_equal(resampled_img.get_fdata(), ref3d) + + +def test_normalize_metric_basic(): + metric = np.array([1, 2, 3, 4, 5]) + expected_output = np.array([0., 0.25, 0.5, 0.75, 1.]) + normalized_metric = normalize_metric(metric) + assert_almost_equal(normalized_metric, expected_output) + + +def test_normalize_metric_nan_handling(): + metric = np.array([1, np.nan, 3, np.nan, 5]) + expected_output = np.array([0., np.nan, 0.5, np.nan, 1.]) + normalized_metric = normalize_metric(metric) + + assert_almost_equal(normalized_metric, expected_output) + + +def test_merge_metrics_basic(): + arrays = [np.array([1, 2, 3]), np.array([4, 5, 6])] + # Geometric mean boosted by beta=1 + expected_output = np.array([2.0, 3.162278, 4.242641]) + merged_metric = merge_metrics(*arrays) + + assert_almost_equal(merged_metric, expected_output, decimal=6) + + +def test_merge_metrics_nan_propagation(): + arrays = [np.array([1, np.nan, 3]), np.array([4, 5, 6])] + expected_output = np.array([2., np.nan, 4.242641]) # NaN replaced with -2 + merged_metric = merge_metrics(*arrays) + + assert_almost_equal(merged_metric, expected_output, decimal=6) diff --git a/scilpy/image/volume_operations.py b/scilpy/image/volume_operations.py index 727e78c04..b7963ca32 100644 --- a/scilpy/image/volume_operations.py +++ b/scilpy/image/volume_operations.py @@ -526,4 +526,4 @@ def merge_metrics(*arrays, beta=1.0): geometric_mean = np.power(array_product, 1 / len(arrays)) boosted_mean = geometric_mean ** beta - return ma.filled(boosted_mean, fill_value=-2) + return ma.filled(boosted_mean, fill_value=np.nan) diff --git a/scilpy/tractanalysis/reproducibility_measures.py b/scilpy/tractanalysis/reproducibility_measures.py index 5d3ca9544..3bdacdd59 100755 --- a/scilpy/tractanalysis/reproducibility_measures.py +++ b/scilpy/tractanalysis/reproducibility_measures.py @@ -1,17 +1,29 @@ # -*- coding: utf-8 -*- +from concurrent.futures import ProcessPoolExecutor, as_completed +from copy import deepcopy +import logging +import warnings +from dipy.data import get_sphere +from dipy.reconst.shm import sh_to_sf_matrix from dipy.segment.clustering import qbx_and_merge +from dipy.segment.fss import FastStreamlineSearch from dipy.tracking.distances import bundles_distances_mdf import numpy as np from numpy.random import RandomState from scipy.spatial import cKDTree from sklearn.metrics import cohen_kappa_score -from sklearn.neighbors import KDTree +from tqdm import tqdm +from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map +from scilpy.tractanalysis.todi import TrackOrientationDensityImaging +from scilpy.tractograms.streamline_operations import generate_matched_points from scilpy.tractograms.tractogram_operations import (difference_robust, intersection_robust, union_robust) +from scilpy.image.volume_operations import (normalize_metric, merge_metrics) +from scilpy.image.volume_math import correlation def binary_classification(segmentation_indices, @@ -104,9 +116,9 @@ def approximate_surface_node(roi): int: the number of surface voxels """ ind = np.argwhere(roi > 0) - tree = KDTree(ind) - count = np.sum(7 - tree.query_radius(ind, r=1.0, - count_only=True)) + tree = cKDTree(ind) + neighbors = np.sum(7 - tree.query_radius(ind, r=1.0)) + count = [len(neighbor) for neighbor in neighbors] return count @@ -382,3 +394,248 @@ def compute_dice_streamlines(bundle_1, bundle_2): dice = np.nan return dice, streamlines_intersect, streamlines_union_robust + + +def _compute_difference_for_voxel(chunk_indices, + skip_streamlines_distance=False): + """ + Compute the difference between two sets of streamlines for a given voxel. + This function uses global variable to avoid duplicating the data for each + chunk of voxels. + + Use the function tractogram_pairwise_comparison() as an entry point. + To differentiate empty voxels from voxels with no data, the function + returns NaN if no data is found. + + Parameters + ---------- + chunk_indices: list + List of indices of the voxel to process. + skip_streamlines_distance: bool + If true, skip the computation of the distance between streamlines. + + Returns + ------- + results: list + List of the computed differences in the same order as the input voxels. + """ + global sft_1, sft_2, matched_points_1, matched_points_2, tree_1, tree_2, \ + sh_data_1, sh_data_2 + results = [] + for vox_ind in chunk_indices: + vox_ind = tuple(vox_ind) + + global B + has_data = sh_data_1[vox_ind].any() and sh_data_2[vox_ind].any() + if has_data: + sf_1 = np.dot(sh_data_1[vox_ind], B) + sf_2 = np.dot(sh_data_2[vox_ind], B) + acc = np.corrcoef(sf_1, sf_2)[0, 1] + else: + acc = np.nan + + if skip_streamlines_distance: + results.append([np.nan, acc]) + continue + + # Get the streamlines in the neighborhood (i.e., 1.5mm away) + pts_ind_1 = tree_1.query_ball_point(vox_ind, 1.5) + if not pts_ind_1: + results.append([np.nan, acc]) + continue + strs_ind_1 = np.unique(matched_points_1[pts_ind_1]) + neighb_streamlines_1 = sft_1.streamlines[strs_ind_1] + + # Get the streamlines in the neighborhood (i.e., 1.5mm away) + pts_ind_2 = tree_2.query_ball_point(vox_ind, 1.5) + if not pts_ind_2: + results.append([np.nan, acc]) + continue + strs_ind_2 = np.unique(matched_points_2[pts_ind_2]) + neighb_streamlines_2 = sft_2.streamlines[strs_ind_2] + + # Using neighb_streamlines (all streamlines in the neighborhood of our + # voxel), we can compute the distance between the two sets of + # streamlines using FSS (FastStreamlineSearch). + with warnings.catch_warnings(record=True) as _: + fss = FastStreamlineSearch(neighb_streamlines_1, 10, resampling=12) + dist_mat = fss.radius_search(neighb_streamlines_2, 10) + sparse_dist_mat = np.abs(dist_mat.tocsr()).toarray() + sparse_ma_dist_mat = np.ma.masked_where(sparse_dist_mat < 1e-3, + sparse_dist_mat) + sparse_ma_dist_vec = np.squeeze(np.min(sparse_ma_dist_mat, + axis=0)) + + # dists will represent the average distance between the two sets of + # streamlines in the neighborhood of the voxel. + dist = np.average(sparse_ma_dist_vec) + results.append([dist, acc]) + + return results + + +def _compare_tractogram_wrapper(mask, nbr_cpu, skip_streamlines_distance): + """ + Wrapper for the comparison of two tractograms. This function uses + multiprocessing to compute the difference between two sets of streamlines + for each voxel. + + This function simply calls the function _compute_difference_for_voxel(), + which expect chunks of indices to process and use global variables to avoid + duplicating the data for each chunk of voxels. + + Use the function tractogram_pairwise_comparison() as an entry point. + + Parameters + ---------- + mask: np.ndarray + Mask of the data to compare. + nbr_cpu: int + Number of CPU to use. + skip_streamlines_distance: bool + If true, skip the computation of the distance between streamlines. + + Returns + ------- + Tuple of np.ndarray + diff_data: np.ndarray + Array containing the computed differences (mm). + acc_data: np.ndarray + Array containing the computed angular correlation. + """ + dimensions = mask.shape + + # Initialize multiprocessing + indices = np.argwhere(mask > 0) + diff_data = np.zeros(dimensions) + diff_data[:] = np.nan + acc_data = np.zeros(dimensions) + acc_data[:] = np.nan + + def chunked_indices(indices, chunk_size=1000): + """Yield successive chunk_size chunks from indices.""" + for i in range(0, len(indices), chunk_size): + yield indices[i:i + chunk_size] + + # Initialize tqdm progress bar + progress_bar = tqdm(total=len(indices)) + + # Create chunks of indices + np.random.shuffle(indices) + index_chunks = list(chunked_indices(indices)) + + with ProcessPoolExecutor(max_workers=nbr_cpu) as executor: + futures = {executor.submit( + _compute_difference_for_voxel, chunk, + skip_streamlines_distance): chunk for chunk in index_chunks} + + for future in as_completed(futures): + chunk = futures[future] + try: + results = future.result() + except Exception as exc: + print(f'Generated an exception: {exc}') + else: + results = np.array(results) + diff_data[tuple(chunk.T)] = results[:, 0] + acc_data[tuple(chunk.T)] = results[:, 1] + + # Update tqdm progress bar + progress_bar.update(len(chunk)) + + return diff_data, acc_data + + +def tractogram_pairwise_comparison(sft_one, sft_two, mask, nbr_cpu=1, + skip_streamlines_distance=True): + """ + Compute the difference between two sets of streamlines for each voxel in + the mask. This function uses multiprocessing to compute the difference + between two sets of streamlines for each voxel. + + Parameters + ---------- + sft_one: StatefulTractogram + First tractogram to compare. + sft_two: StatefulTractogram + Second tractogram to compare. + mask: np.ndarray + Mask of the data to compare (optional). + nbr_cpu: int + Number of CPU to use (default: 1). + skip_streamlines_distance: bool + If true, skip the computation of the distance between streamlines. + (default: True) + + Returns + ------- + List of np.ndarray + acc_norm: Angular correlation coefficient. + corr_norm: Correlation coefficient of density maps. + diff_norm: Voxelwise distance between sets of streamlines. + heatmap: Merged heatmap of the three metrics using harmonic mean. + """ + global sft_1, sft_2 + sft_1, sft_2 = sft_one, sft_two + + sft_1.to_vox() + sft_2.to_vox() + sft_1.streamlines._data = sft_1.streamlines._data.astype(np.float16) + sft_2.streamlines._data = sft_2.streamlines._data.astype(np.float16) + dimensions = tuple(sft_1.dimensions) + + global matched_points_1, matched_points_2 + matched_points_1 = generate_matched_points(sft_1) + matched_points_2 = generate_matched_points(sft_2) + + logging.info('Computing KDTree...') + global tree_1, tree_2 + tree_1 = cKDTree(sft_1.streamlines._data) + tree_2 = cKDTree(sft_2.streamlines._data) + + # Limits computation to mask AND streamlines (using density) + if mask is None: + mask = np.ones(dimensions) + + logging.info('Computing density maps...') + density_1 = compute_tract_counts_map(sft_1.streamlines, + dimensions).astype(float) + density_2 = compute_tract_counts_map(sft_2.streamlines, + dimensions).astype(float) + mask = density_1 * density_2 * mask + mask[mask > 0] = 1 + + logging.info('Computing correlation map...') + corr_data = correlation([density_1, density_2], None) * mask + corr_data[mask == 0] = np.nan + + logging.info('Computing TODI from tractogram #1...') + global sh_data_1, sh_data_2 + sft_1.to_corner() + todi_obj = TrackOrientationDensityImaging(dimensions, 'repulsion724') + todi_obj.compute_todi(deepcopy(sft_1.streamlines), length_weights=True) + todi_obj.mask_todi(mask) + sh_data_1 = todi_obj.get_sh('descoteaux07', 8) + sh_data_1 = todi_obj.reshape_to_3d(sh_data_1) + sft_1.to_center() + + logging.info('Computing TODI from tractogram #2...') + sft_2.to_corner() + todi_obj = TrackOrientationDensityImaging(dimensions, 'repulsion724') + todi_obj.compute_todi(deepcopy(sft_2.streamlines), length_weights=True) + todi_obj.mask_todi(mask) + sh_data_2 = todi_obj.get_sh('descoteaux07', 8) + sh_data_2 = todi_obj.reshape_to_3d(sh_data_2) + sft_2.to_center() + + global B + B, _ = sh_to_sf_matrix(get_sphere('repulsion724'), 8, 'descoteaux07') + + diff_data, acc_data = _compare_tractogram_wrapper(mask, nbr_cpu, + skip_streamlines_distance) + + # Normalize metrics and merge into a single heatmap + diff_data_norm = normalize_metric(diff_data, reverse=True) + heatmap = merge_metrics(acc_data, corr_data, diff_data_norm) + + return acc_data, corr_data, diff_data_norm, heatmap diff --git a/scilpy/tractanalysis/tests/test_reproducibility_measures.py b/scilpy/tractanalysis/tests/test_reproducibility_measures.py new file mode 100644 index 000000000..968471e20 --- /dev/null +++ b/scilpy/tractanalysis/tests/test_reproducibility_measures.py @@ -0,0 +1,41 @@ + +# -*- coding: utf-8 -*- + +import os + +from dipy.io.stateful_tractogram import StatefulTractogram +from dipy.io.streamline import load_tractogram +import numpy as np + +from scilpy.io.fetcher import get_home +from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map +from scilpy.tractanalysis.reproducibility_measures import tractogram_pairwise_comparison + + +def test_tractogram_pairwise_comparison(): + sft_path = os.path.join(get_home(), 'bst', 'template', 'rpt_m.trk') + print(sft_path) + sft = load_tractogram(sft_path, 'same') + sft_1 = StatefulTractogram.from_sft(sft.streamlines[0:100], sft) + sft_2 = StatefulTractogram.from_sft(sft.streamlines[100:200], sft) + + sft.to_vox() + sft.to_corner() + mask = compute_tract_counts_map(sft.streamlines, sft.dimensions) + mask[mask > 0] = 1 + + results = tractogram_pairwise_comparison(sft_1, sft_2, mask, + skip_streamlines_distance=False) + assert len(results) == 4 + for r in results: + assert np.array_equal(r.shape, sft.dimensions) + + assert np.mean(results[0][~np.isnan(results[0])]) == 0.7171550368952226 + assert np.mean(results[1][~np.isnan(results[1])]) == 0.6063336089511456 + assert np.mean(results[2][~np.isnan(results[2])]) == 0.722988562131705 + assert np.mean(results[3][~np.isnan(results[3])]) == 0.7526672393158469 + + assert np.count_nonzero(np.isnan(results[0])) == 877627 + assert np.count_nonzero(np.isnan(results[1])) == 877014 + assert np.count_nonzero(np.isnan(results[2])) == 877034 + assert np.count_nonzero(np.isnan(results[3])) == 877671 diff --git a/scilpy/tractograms/tests/test_tractogram_operations.py b/scilpy/tractograms/tests/test_tractogram_operations.py index 04d029036..96fd633f8 100644 --- a/scilpy/tractograms/tests/test_tractogram_operations.py +++ b/scilpy/tractograms/tests/test_tractogram_operations.py @@ -1,19 +1,17 @@ # -*- coding: utf-8 -*- + import logging import os import tempfile import numpy as np from dipy.io.streamline import load_tractogram -from dipy.io.stateful_tractogram import StatefulTractogram from scilpy.io.fetcher import fetch_data, get_testing_files_dict, get_home -from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map from scilpy.tractograms.tractogram_operations import flip_sft, \ shuffle_streamlines, perform_tractogram_operation_on_lines, intersection, union, \ difference, intersection_robust, difference_robust, union_robust, \ - concatenate_sft, perform_tractogram_operation_on_sft, \ - tractogram_pairwise_comparison + concatenate_sft, perform_tractogram_operation_on_sft # Prepare SFT fetch_data(get_testing_files_dict(), keys=['surface_vtk_fib.zip', @@ -151,32 +149,3 @@ def test_combining_sft(): # todo perform_tractogram_operation_on_sft('union', [sft, sft], precision=None, fake_metadata=False, no_metadata=False) - - -def test_tractogram_pairwise_comparison(): - sft_path = os.path.join(get_home(), 'bst', 'template', 'rpt_m.trk') - print(sft_path) - sft = load_tractogram(sft_path, 'same') - sft_1 = StatefulTractogram.from_sft(sft.streamlines[0:100], sft) - sft_2 = StatefulTractogram.from_sft(sft.streamlines[100:200], sft) - - sft.to_vox() - sft.to_corner() - mask = compute_tract_counts_map(sft.streamlines, sft.dimensions) - mask[mask > 0] = 1 - - results = tractogram_pairwise_comparison(sft_1, sft_2, mask, - skip_streamlines_distance=False) - assert len(results) == 4 - for r in results: - assert np.array_equal(r.shape, sft.dimensions) - - assert np.mean(results[0][~np.isnan(results[0])]) == 0.7171550368952226 - assert np.mean(results[1][~np.isnan(results[1])]) == 0.6063336089511456 - assert np.mean(results[2][~np.isnan(results[2])]) == 0.722988562131705 - assert np.mean(results[3][~np.isnan(results[3])]) == 0.7526672393158469 - - assert np.count_nonzero(np.isnan(results[0])) == 877627 - assert np.count_nonzero(np.isnan(results[1])) == 877014 - assert np.count_nonzero(np.isnan(results[2])) == 877034 - assert np.count_nonzero(np.isnan(results[3])) == 877671 diff --git a/scilpy/tractograms/tractogram_operations.py b/scilpy/tractograms/tractogram_operations.py index 64c794d42..0d508f9ac 100644 --- a/scilpy/tractograms/tractogram_operations.py +++ b/scilpy/tractograms/tractogram_operations.py @@ -7,35 +7,23 @@ individually. See scilpy.tractograms.streamline_operations.py for the latter. """ -from concurrent.futures import ProcessPoolExecutor, as_completed -from copy import deepcopy from functools import reduce import itertools import logging import random -import warnings -from dipy.data import get_sphere from dipy.io.stateful_tractogram import StatefulTractogram, Space from dipy.io.utils import get_reference_info, is_header_compatible from dipy.segment.clustering import qbx_and_merge -from dipy.segment.fss import FastStreamlineSearch from dipy.tracking.streamline import transform_streamlines -from dipy.reconst.shm import sh_to_sf_matrix from nibabel.streamlines.array_sequence import ArraySequence import numpy as np from scipy.ndimage import map_coordinates from scipy.spatial import cKDTree -from tqdm import tqdm -from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map -from scilpy.tractanalysis.todi import TrackOrientationDensityImaging -from scilpy.tractograms.streamline_operations import (generate_matched_points, - smooth_line_gaussian, +from scilpy.tractograms.streamline_operations import (smooth_line_gaussian, smooth_line_spline) -from scilpy.image.volume_operations import (normalize_metric, merge_metrics) from scilpy.utils.streamlines import cut_invalid_streamlines -from scilpy.image.volume_math import correlation MIN_NB_POINTS = 10 KEY_INDEX = np.concatenate((range(5), range(-1, -6, -1))) @@ -854,248 +842,3 @@ def split_sft_randomly_per_cluster(orig_sft, chunk_sizes, seed, thresholds): 'concatenate': 'concatenate', 'lazy_concatenate': 'lazy_concatenate' } - - -def _compute_difference_for_voxel(chunk_indices, - skip_streamlines_distance=False): - """ - Compute the difference between two sets of streamlines for a given voxel. - This function uses global variable to avoid duplicating the data for each - chunk of voxels. - - Use the function tractogram_pairwise_comparison() as an entry point. - To differentiate empty voxels from voxels with no data, the function - returns NaN if no data is found. - - Parameters - ---------- - chunk_indices: list - List of indices of the voxel to process. - skip_streamlines_distance: bool - If true, skip the computation of the distance between streamlines. - - Returns - ------- - results: list - List of the computed differences in the same order as the input voxels. - """ - global sft_1, sft_2, matched_points_1, matched_points_2, tree_1, tree_2, \ - sh_data_1, sh_data_2 - results = [] - for vox_ind in chunk_indices: - vox_ind = tuple(vox_ind) - - global B - has_data = sh_data_1[vox_ind].any() and sh_data_2[vox_ind].any() - if has_data: - sf_1 = np.dot(sh_data_1[vox_ind], B) - sf_2 = np.dot(sh_data_2[vox_ind], B) - acc = np.corrcoef(sf_1, sf_2)[0, 1] - else: - acc = np.nan - - if skip_streamlines_distance: - results.append([np.nan, acc]) - continue - - # Get the streamlines in the neighborhood (i.e., 1.5mm away) - pts_ind_1 = tree_1.query_ball_point(vox_ind, 1.5) - if not pts_ind_1: - results.append([np.nan, acc]) - continue - strs_ind_1 = np.unique(matched_points_1[pts_ind_1]) - neighb_streamlines_1 = sft_1.streamlines[strs_ind_1] - - # Get the streamlines in the neighborhood (i.e., 1.5mm away) - pts_ind_2 = tree_2.query_ball_point(vox_ind, 1.5) - if not pts_ind_2: - results.append([np.nan, acc]) - continue - strs_ind_2 = np.unique(matched_points_2[pts_ind_2]) - neighb_streamlines_2 = sft_2.streamlines[strs_ind_2] - - # Using neighb_streamlines (all streamlines in the neighborhood of our - # voxel), we can compute the distance between the two sets of - # streamlines using FSS (FastStreamlineSearch). - with warnings.catch_warnings(record=True) as _: - fss = FastStreamlineSearch(neighb_streamlines_1, 10, resampling=12) - dist_mat = fss.radius_search(neighb_streamlines_2, 10) - sparse_dist_mat = np.abs(dist_mat.tocsr()).toarray() - sparse_ma_dist_mat = np.ma.masked_where(sparse_dist_mat < 1e-3, - sparse_dist_mat) - sparse_ma_dist_vec = np.squeeze(np.min(sparse_ma_dist_mat, - axis=0)) - - # dists will represent the average distance between the two sets of - # streamlines in the neighborhood of the voxel. - dist = np.average(sparse_ma_dist_vec) - results.append([dist, acc]) - - return results - - -def _compare_tractogram_wrapper(mask, nbr_cpu, skip_streamlines_distance): - """ - Wrapper for the comparison of two tractograms. This function uses - multiprocessing to compute the difference between two sets of streamlines - for each voxel. - - This function simply calls the function _compute_difference_for_voxel(), - which expect chunks of indices to process and use global variables to avoid - duplicating the data for each chunk of voxels. - - Use the function tractogram_pairwise_comparison() as an entry point. - - Parameters - ---------- - mask: np.ndarray - Mask of the data to compare. - nbr_cpu: int - Number of CPU to use. - skip_streamlines_distance: bool - If true, skip the computation of the distance between streamlines. - - Returns - ------- - Tuple of np.ndarray - diff_data: np.ndarray - Array containing the computed differences (mm). - acc_data: np.ndarray - Array containing the computed angular correlation. - """ - dimensions = mask.shape - - # Initialize multiprocessing - indices = np.argwhere(mask > 0) - diff_data = np.zeros(dimensions) - diff_data[:] = np.nan - acc_data = np.zeros(dimensions) - acc_data[:] = np.nan - - def chunked_indices(indices, chunk_size=1000): - """Yield successive chunk_size chunks from indices.""" - for i in range(0, len(indices), chunk_size): - yield indices[i:i + chunk_size] - - # Initialize tqdm progress bar - progress_bar = tqdm(total=len(indices)) - - # Create chunks of indices - np.random.shuffle(indices) - index_chunks = list(chunked_indices(indices)) - - with ProcessPoolExecutor(max_workers=nbr_cpu) as executor: - futures = {executor.submit( - _compute_difference_for_voxel, chunk, - skip_streamlines_distance): chunk for chunk in index_chunks} - - for future in as_completed(futures): - chunk = futures[future] - try: - results = future.result() - except Exception as exc: - print(f'Generated an exception: {exc}') - else: - results = np.array(results) - diff_data[tuple(chunk.T)] = results[:, 0] - acc_data[tuple(chunk.T)] = results[:, 1] - - # Update tqdm progress bar - progress_bar.update(len(chunk)) - - return diff_data, acc_data - - -def tractogram_pairwise_comparison(sft_one, sft_two, mask, nbr_cpu=1, - skip_streamlines_distance=True): - """ - Compute the difference between two sets of streamlines for each voxel in - the mask. This function uses multiprocessing to compute the difference - between two sets of streamlines for each voxel. - - Parameters - ---------- - sft_one: StatefulTractogram - First tractogram to compare. - sft_two: StatefulTractogram - Second tractogram to compare. - mask: np.ndarray - Mask of the data to compare (optional). - nbr_cpu: int - Number of CPU to use (default: 1). - skip_streamlines_distance: bool - If true, skip the computation of the distance between streamlines. - (default: True) - - Returns - ------- - List of np.ndarray - acc_norm: Angular correlation coefficient. - corr_norm: Correlation coefficient of density maps. - diff_norm: Voxelwise distance between sets of streamlines. - heatmap: Merged heatmap of the three metrics using harmonic mean. - """ - global sft_1, sft_2 - sft_1, sft_2 = sft_one, sft_two - - sft_1.to_vox() - sft_2.to_vox() - sft_1.streamlines._data = sft_1.streamlines._data.astype(np.float16) - sft_2.streamlines._data = sft_2.streamlines._data.astype(np.float16) - dimensions = tuple(sft_1.dimensions) - - global matched_points_1, matched_points_2 - matched_points_1 = generate_matched_points(sft_1) - matched_points_2 = generate_matched_points(sft_2) - - logging.info('Computing KDTree...') - global tree_1, tree_2 - tree_1 = cKDTree(sft_1.streamlines._data) - tree_2 = cKDTree(sft_2.streamlines._data) - - # Limits computation to mask AND streamlines (using density) - if mask is None: - mask = np.ones(dimensions) - - logging.info('Computing density maps...') - density_1 = compute_tract_counts_map(sft_1.streamlines, - dimensions).astype(float) - density_2 = compute_tract_counts_map(sft_2.streamlines, - dimensions).astype(float) - mask = density_1 * density_2 * mask - mask[mask > 0] = 1 - - logging.info('Computing correlation map...') - corr_data = correlation([density_1, density_2], None) * mask - corr_data[mask == 0] = np.nan - - logging.info('Computing TODI from tractogram #1...') - global sh_data_1, sh_data_2 - sft_1.to_corner() - todi_obj = TrackOrientationDensityImaging(dimensions, 'repulsion724') - todi_obj.compute_todi(deepcopy(sft_1.streamlines), length_weights=True) - todi_obj.mask_todi(mask) - sh_data_1 = todi_obj.get_sh('descoteaux07', 8) - sh_data_1 = todi_obj.reshape_to_3d(sh_data_1) - sft_1.to_center() - - logging.info('Computing TODI from tractogram #2...') - sft_2.to_corner() - todi_obj = TrackOrientationDensityImaging(dimensions, 'repulsion724') - todi_obj.compute_todi(deepcopy(sft_2.streamlines), length_weights=True) - todi_obj.mask_todi(mask) - sh_data_2 = todi_obj.get_sh('descoteaux07', 8) - sh_data_2 = todi_obj.reshape_to_3d(sh_data_2) - sft_2.to_center() - - global B - B, _ = sh_to_sf_matrix(get_sphere('repulsion724'), 8, 'descoteaux07') - - diff_data, acc_data = _compare_tractogram_wrapper(mask, nbr_cpu, - skip_streamlines_distance) - - # Normalize metrics and merge into a single heatmap - diff_norm = normalize_metric(diff_data, reverse=True) - heatmap = merge_metrics(acc_data, corr_data, diff_norm) - - return acc_data, corr_data, diff_norm, heatmap diff --git a/scripts/scil_tractogram_pairwise_comparison.py b/scripts/scil_tractogram_pairwise_comparison.py index 461c0fbb0..47d261810 100644 --- a/scripts/scil_tractogram_pairwise_comparison.py +++ b/scripts/scil_tractogram_pairwise_comparison.py @@ -38,7 +38,7 @@ is_header_compatible_multiple_files, load_tractogram_with_reference, validate_nbr_processes) -from scilpy.tractograms.tractogram_operations import tractogram_pairwise_comparison +from scilpy.tractanalysis.reproducibility_measures import tractogram_pairwise_comparison def _build_arg_parser(): @@ -64,7 +64,7 @@ def _build_arg_parser(): help='Optional input mask.') p.add_argument('--skip_streamlines_distance', action='store_true', help='Skip computation of the spatial distance between ' - 'streamlines.') + 'streamlines. Slowest part of the computation.') add_processes_arg(p) add_reference_arg(p) add_verbose_arg(p) @@ -77,8 +77,7 @@ def main(): parser = _build_arg_parser() args = parser.parse_args() - if args.verbose: - logging.basicConfig(level=logging.INFO) + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) assert_inputs_exist(parser, [args.in_tractogram_1, args.in_tractogram_2], [args.in_mask, args.reference]) From 33ce72f4e0415b3a5892d6b043255349a20bcf37 Mon Sep 17 00:00:00 2001 From: frheault Date: Tue, 27 Feb 2024 15:53:42 -0500 Subject: [PATCH 13/18] Fix errors --- scilpy/image/volume_operations.py | 1 + scilpy/tractanalysis/reproducibility_measures.py | 2 +- scripts/scil_tractogram_pairwise_comparison.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/scilpy/image/volume_operations.py b/scilpy/image/volume_operations.py index d805bedae..716fd32ef 100644 --- a/scilpy/image/volume_operations.py +++ b/scilpy/image/volume_operations.py @@ -14,6 +14,7 @@ from dipy.segment.mask import crop, median_otsu import nibabel as nib import numpy as np +from numpy import ma from scipy.ndimage import binary_dilation, gaussian_filter from scilpy.image.reslice import reslice # Don't use Dipy's reslice. Buggy. diff --git a/scilpy/tractanalysis/reproducibility_measures.py b/scilpy/tractanalysis/reproducibility_measures.py index a7cfa2ce6..e80cc9d0c 100755 --- a/scilpy/tractanalysis/reproducibility_measures.py +++ b/scilpy/tractanalysis/reproducibility_measures.py @@ -122,7 +122,7 @@ def approximate_surface_node(roi): """ ind = np.argwhere(roi > 0) tree = cKDTree(ind) - neighbors = np.sum(7 - tree.query_radius(ind, r=1.0)) + neighbors = np.sum(7 - tree.query_ball_point(ind, r=1.0)) count = [len(neighbor) for neighbor in neighbors] return count diff --git a/scripts/scil_tractogram_pairwise_comparison.py b/scripts/scil_tractogram_pairwise_comparison.py index 47d261810..e373d9dd6 100644 --- a/scripts/scil_tractogram_pairwise_comparison.py +++ b/scripts/scil_tractogram_pairwise_comparison.py @@ -29,6 +29,7 @@ import nibabel as nib +from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, assert_inputs_exist, assert_outputs_exist, @@ -36,7 +37,6 @@ add_processes_arg, add_verbose_arg, is_header_compatible_multiple_files, - load_tractogram_with_reference, validate_nbr_processes) from scilpy.tractanalysis.reproducibility_measures import tractogram_pairwise_comparison From 8593c925f5ec881cd581f7b7296bf5f827017e20 Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 28 Feb 2024 08:46:28 -0500 Subject: [PATCH 14/18] Fix kdtree type error --- scilpy/tractanalysis/reproducibility_measures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scilpy/tractanalysis/reproducibility_measures.py b/scilpy/tractanalysis/reproducibility_measures.py index e80cc9d0c..5797d2afb 100755 --- a/scilpy/tractanalysis/reproducibility_measures.py +++ b/scilpy/tractanalysis/reproducibility_measures.py @@ -122,7 +122,7 @@ def approximate_surface_node(roi): """ ind = np.argwhere(roi > 0) tree = cKDTree(ind) - neighbors = np.sum(7 - tree.query_ball_point(ind, r=1.0)) + neighbors = np.sum(7 - len(tree.query_ball_point(ind, r=1.0))) count = [len(neighbor) for neighbor in neighbors] return count From 813f3e7158aeb6910d8a61912343ad356bdfb3a2 Mon Sep 17 00:00:00 2001 From: frheault Date: Wed, 28 Feb 2024 08:56:00 -0500 Subject: [PATCH 15/18] Fix kdtree type error --- scilpy/tractanalysis/reproducibility_measures.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scilpy/tractanalysis/reproducibility_measures.py b/scilpy/tractanalysis/reproducibility_measures.py index 5797d2afb..8bda05051 100755 --- a/scilpy/tractanalysis/reproducibility_measures.py +++ b/scilpy/tractanalysis/reproducibility_measures.py @@ -14,6 +14,7 @@ from numpy.random import RandomState from scipy.spatial import cKDTree from sklearn.metrics import cohen_kappa_score +from sklearn.neighbors import KDTree from tqdm import tqdm from scilpy.tractanalysis.streamlines_metrics import compute_tract_counts_map @@ -121,9 +122,11 @@ def approximate_surface_node(roi): int: the number of surface voxels """ ind = np.argwhere(roi > 0) - tree = cKDTree(ind) - neighbors = np.sum(7 - len(tree.query_ball_point(ind, r=1.0))) - count = [len(neighbor) for neighbor in neighbors] + tree = KDTree(ind) + count = np.sum(7 - tree.query_radius(ind, r=1.0, + count_only=True)) + + return count return count From 607581038b227eb974303d620cbbbe6c18040286 Mon Sep 17 00:00:00 2001 From: frheault Date: Fri, 1 Mar 2024 12:59:49 -0500 Subject: [PATCH 16/18] Remove data from moved test --- scilpy/tractograms/tests/test_tractogram_operations.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scilpy/tractograms/tests/test_tractogram_operations.py b/scilpy/tractograms/tests/test_tractogram_operations.py index 96fd633f8..4ab1db8cd 100644 --- a/scilpy/tractograms/tests/test_tractogram_operations.py +++ b/scilpy/tractograms/tests/test_tractogram_operations.py @@ -14,8 +14,7 @@ concatenate_sft, perform_tractogram_operation_on_sft # Prepare SFT -fetch_data(get_testing_files_dict(), keys=['surface_vtk_fib.zip', - 'bst.zip']) +fetch_data(get_testing_files_dict(), keys=['surface_vtk_fib.zip']) tmp_dir = tempfile.TemporaryDirectory() in_sft = os.path.join(get_home(), 'surface_vtk_fib', 'gyri_fanning.trk') From e5272e5bdf34cbf452ab94257b734970a9f91873 Mon Sep 17 00:00:00 2001 From: frheault Date: Sun, 17 Mar 2024 21:46:15 -0400 Subject: [PATCH 17/18] Passing test --- scripts/scil_tractogram_pairwise_comparison.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/scripts/scil_tractogram_pairwise_comparison.py b/scripts/scil_tractogram_pairwise_comparison.py index e373d9dd6..c555f3b11 100644 --- a/scripts/scil_tractogram_pairwise_comparison.py +++ b/scripts/scil_tractogram_pairwise_comparison.py @@ -36,7 +36,7 @@ assert_output_dirs_exist_and_empty, add_processes_arg, add_verbose_arg, - is_header_compatible_multiple_files, + assert_headers_compatible, validate_nbr_processes) from scilpy.tractanalysis.reproducibility_measures import tractogram_pairwise_comparison @@ -84,9 +84,7 @@ def main(): to_verify = [args.in_tractogram_1, args.in_tractogram_2] if args.in_mask: to_verify.append(args.in_mask) - is_header_compatible_multiple_files(parser, to_verify, - verbose_all_compatible=True, - reference=args.reference) + assert_headers_compatible(parser, to_verify, reference=args.reference) if args.out_prefix and args.out_prefix[-1] == '_': args.out_prefix = args.out_prefix[:-1] From 3d52b37f1cdc86b052b342ccf6ef88c63bb8d1e7 Mon Sep 17 00:00:00 2001 From: frheault Date: Mon, 18 Mar 2024 12:11:54 -0400 Subject: [PATCH 18/18] Missing import --- scilpy/image/tests/test_volume_operations.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scilpy/image/tests/test_volume_operations.py b/scilpy/image/tests/test_volume_operations.py index 5aab78f4e..02d52e3b1 100644 --- a/scilpy/image/tests/test_volume_operations.py +++ b/scilpy/image/tests/test_volume_operations.py @@ -6,13 +6,15 @@ import nibabel as nib import numpy as np from dipy.io.gradients import read_bvals_bvecs -from numpy.testing import assert_equal +from numpy.testing import assert_equal, assert_almost_equal from scilpy import SCILPY_HOME from scilpy.image.volume_operations import (apply_transform, compute_snr, crop_volume, flip_volume, + merge_metrics, + normalize_metric, resample_volume) from scilpy.io.fetcher import fetch_data, get_testing_files_dict from scilpy.utils.util import compute_nifti_bounding_box