From bd64526369937d48b8708841916044e62ae64cb1 Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Thu, 4 Jan 2024 15:17:37 -0500 Subject: [PATCH 1/3] [NF] add concatenate tractogram --- dipy/io/streamline.py | 195 +++++++++++++++++++++++++++++++ dipy/io/tests/test_streamline.py | 14 ++- 2 files changed, 208 insertions(+), 1 deletion(-) diff --git a/dipy/io/streamline.py b/dipy/io/streamline.py index b4d284da29..a9d32e51b6 100644 --- a/dipy/io/streamline.py +++ b/dipy/io/streamline.py @@ -260,3 +260,198 @@ def f_gen(sft, filename, bbox_valid_check=True): save_vtp = save_generator('.vtp') save_fib = save_generator('.fib') save_dpy = save_generator('.dpy') + + +def concatenate_tractogram(tractogram_list, *, delete_dpv=False, + delete_dps=False, delete_groups=False, + check_space_attributes=True, preallocation=False): + """Concatenate multiple tractograms into one. + + If the data_per_point or data_per_streamline is not the same for all + tractograms, the data must be deleted first. + + Parameters + ---------- + tractogram_list : List[StatefulTractogram or TrxFile] + The stateful tractogram to concatenate + delete_dpv: bool, optional + Delete dpv keys that do not exist in all the provided TrxFiles + delete_dps: bool, optional + Delete dps keys that do not exist in all the provided TrxFile + delete_groups: bool, optional + Delete all the groups that currently exist in the TrxFiles + check_space_attributes: bool, optional + Verify that dimensions and size of data are similar between all the + TrxFiles + preallocation: bool, optional + Preallocated TrxFile has already been generated and is the first + element in trx_list (Note: delete_groups must be set to True as well) + + Returns + ------- + new_trx: TrxFile + TrxFile representing the concatenated data + + """ + trx_list = [] + for sft in tractogram_list: + if not isinstance(sft, tmm.TrxFile): + sft = tmm.TrxFile.from_sft(sft) + elif len(sft.groups): + delete_groups = True + trx_list.append(sft) + + trx_list = [curr_trx for curr_trx in trx_list + if curr_trx.header["NB_STREAMLINES"] > 0] + + if not trx_list: + logging.warning("Inputs of concatenation were empty.") + return tmm.TrxFile() + + if len(trx_list) == 1: + if len(tractogram_list) > 1: + logging.warning("Only 1 valid tractogram returned.") + return trx_list[0] + + ref_trx = trx_list[0] + all_dps = [] + all_dpv = [] + for curr_trx in trx_list: + all_dps.extend(list(curr_trx.data_per_streamline.keys())) + all_dpv.extend(list(curr_trx.data_per_vertex.keys())) + all_dps, all_dpv = set(all_dps), set(all_dpv) + + if check_space_attributes: + for curr_trx in trx_list[1:]: + if not np.allclose(ref_trx.header["VOXEL_TO_RASMM"], + curr_trx.header["VOXEL_TO_RASMM"]) or \ + not np.array_equal( + ref_trx.header["DIMENSIONS"], + curr_trx.header["DIMENSIONS"] + ): + raise ValueError("Wrong space attributes.") + + if preallocation and not delete_groups: + raise ValueError( + "Groups are variables, cannot be handled with " "preallocation" + ) + + # Verifying the validity of fixed-size arrays, coherence between inputs + for curr_trx in trx_list: + for key in all_dpv: + if key not in ref_trx.data_per_vertex.keys() \ + or key not in curr_trx.data_per_vertex.keys(): + if not delete_dpv: + logging.debug( + "{} dpv key does not exist in all TrxFile.".format(key) + ) + raise ValueError( + "TrxFile must be sharing identical dpv " "keys.") + elif ( + ref_trx.data_per_vertex[key]._data.dtype + != curr_trx.data_per_vertex[key]._data.dtype + ): + logging.debug( + "{} dpv key is not declared with the same dtype " + "in all TrxFile.".format(key) + ) + raise ValueError("Shared dpv key, has different dtype.") + + for curr_trx in trx_list: + for key in all_dps: + if key not in ref_trx.data_per_streamline.keys() \ + or key not in curr_trx.data_per_streamline.keys(): + if not delete_dps: + logging.debug( + "{} dps key does not exist in all " "TrxFile.".format( + key) + ) + raise ValueError( + "TrxFile must be sharing identical dps " "keys.") + elif ( + ref_trx.data_per_streamline[key].dtype + != curr_trx.data_per_streamline[key].dtype + ): + logging.debug( + "{} dps key is not declared with the same dtype " + "in all TrxFile.".format(key) + ) + raise ValueError("Shared dps key, has different dtype.") + + all_groups_len = {} + all_groups_dtype = {} + # Variable-size arrays do not have to exist in all TrxFile + if not delete_groups: + for trx_1 in trx_list: + for group_key in trx_1.groups.keys(): + # Concatenating groups together + if group_key in all_groups_len: + all_groups_len[group_key] += len(trx_1.groups[group_key]) + else: + all_groups_len[group_key] = len(trx_1.groups[group_key]) + if (group_key in all_groups_dtype and \ + trx_1.groups[group_key].dtype != all_groups_dtype[group_key] + ): + raise ValueError("Shared group key, has different dtype.") + else: + all_groups_dtype[group_key] = trx_1.groups[group_key].dtype + + # Once the checks are done, actually concatenate + to_concat_list = trx_list[1:] if preallocation else trx_list + if not preallocation: + nb_vertices = 0 + nb_streamlines = 0 + for curr_trx in to_concat_list: + curr_strs_len, curr_pts_len = curr_trx._get_real_len() + nb_streamlines += curr_strs_len + nb_vertices += curr_pts_len + + new_trx = tmm.TrxFile( + nb_vertices=nb_vertices, nb_streamlines=nb_streamlines, + init_as=ref_trx + ) + if delete_dps: + new_trx.data_per_streamline = {} + if delete_dpv: + new_trx.data_per_vertex = {} + if delete_groups: + new_trx.groups = {} + + tmp_dir = new_trx._uncompressed_folder_handle.name + + # When memory is allocated on the spot, groups and data_per_group can + # be concatenated together + for group_key in all_groups_len.keys(): + if not os.path.isdir(os.path.join(tmp_dir, "groups/")): + os.mkdir(os.path.join(tmp_dir, "groups/")) + dtype = all_groups_dtype[group_key] + group_filename = os.path.join( + tmp_dir, "groups/" "{}.{}".format(group_key, dtype.name) + ) + group_len = all_groups_len[group_key] + new_trx.groups[group_key] = tmm._create_memmap( + group_filename, mode="w+", shape=(group_len,), dtype=dtype + ) + if delete_groups: + continue + pos = 0 + count = 0 + for curr_trx in trx_list: + curr_len = len(curr_trx.groups[group_key]) + new_trx.groups[group_key][pos: pos + curr_len] = \ + curr_trx.groups[group_key] + count + pos += curr_len + count += curr_trx.header["NB_STREAMLINES"] + + strs_end, pts_end = 0, 0 + else: + new_trx = ref_trx + strs_end, pts_end = new_trx._get_real_len() + + for curr_trx in to_concat_list: + # Copy the TrxFile fixed-size info (the right chunk) + strs_end, pts_end = new_trx._copy_fixed_arrays_from( + curr_trx, strs_start=strs_end, pts_start=pts_end + ) + return new_trx + diff --git a/dipy/io/tests/test_streamline.py b/dipy/io/tests/test_streamline.py index 7fee97d6b7..1f9c5bec15 100644 --- a/dipy/io/tests/test_streamline.py +++ b/dipy/io/tests/test_streamline.py @@ -2,9 +2,11 @@ import os from tempfile import TemporaryDirectory +import trx.trx_file_memmap as tmm + from dipy.data import fetch_gold_standard_io from dipy.io.streamline import (load_tractogram, save_tractogram, - load_trk, save_trk) + load_trk, save_trk, concatenate_tractogram) from dipy.io.stateful_tractogram import Space, StatefulTractogram from dipy.io.utils import create_nifti_header from dipy.io.vtk import save_vtk_streamlines, load_vtk_streamlines @@ -265,3 +267,13 @@ def test_io_trk_save(): msg='trk_saver should not be able to save a fib') npt.assert_(not trk_saver(filepath_dix['gs.dpy']), msg='trk_saver should not be able to save a dpy') + + +def test_concatenate(): + sft = load_tractogram(filepath_dix['gs.trk'], filepath_dix['gs.nii']) + trx = tmm.load(filepath_dix['gs.trx']) + concat = concatenate_tractogram([sft, trx]) + + assert len(concat) == 2 * len(trx) + trx.close() + concat.close() From 8568045532b7e9ec8302f099f29f5a668e6aef07 Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Thu, 4 Jan 2024 15:24:32 -0500 Subject: [PATCH 2/3] [NF] add concatenate workflows --- .gitignore | 1 + dipy/data/fetcher.py | 13 +++++ dipy/io/streamline.py | 13 +++-- dipy/workflows/cli.py | 2 + dipy/workflows/docstring_parser.py | 4 +- dipy/workflows/io.py | 94 ++++++++++++++++++++++++++++-- dipy/workflows/tests/test_io.py | 25 +++++++- pyproject.toml | 1 + 8 files changed, 138 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index 813a86c973..f9df95100e 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ *.pkl *.orig *.trk +*.trx *.pam5 build *~ diff --git a/dipy/data/fetcher.py b/dipy/data/fetcher.py index 7d27d2e49a..1c2cf7f121 100644 --- a/dipy/data/fetcher.py +++ b/dipy/data/fetcher.py @@ -932,6 +932,19 @@ def get_fnames(name='small_64D'): seed_coords_name = pjoin(folder, 'ptt_seed_coords.txt') seed_image_name = pjoin(folder, 'ptt_seed_image.nii') return fod_name, seed_coords_name, seed_image_name + if name == "gold_standard_tracks": + filepath_dix = {} + files, folder = fetch_gold_standard_io() + for filename in files: + filepath_dix[filename] = os.path.join(folder, filename) + + with open(filepath_dix['points_data.json']) as json_file: + points_data = dict(json.load(json_file)) + + with open(filepath_dix['streamlines_data.json']) as json_file: + streamlines_data = dict(json.load(json_file)) + + return filepath_dix, points_data, streamlines_data def read_qtdMRI_test_retest_2subjects(): diff --git a/dipy/io/streamline.py b/dipy/io/streamline.py index a9d32e51b6..7abcd371a9 100644 --- a/dipy/io/streamline.py +++ b/dipy/io/streamline.py @@ -36,7 +36,8 @@ def save_tractogram(sft, filename, bbox_valid_check=True): """ _, extension = os.path.splitext(filename) - if extension not in ['.trk', '.tck', '.trx', '.vtk', '.vtp', '.fib', '.dpy']: + if extension not in ['.trk', '.tck', '.trx', '.vtk', '.vtp', '.fib', + '.dpy']: raise TypeError('Output filename is not one of the supported format.') if bbox_valid_check and not sft.is_bbox_in_vox_valid(): @@ -121,7 +122,8 @@ def load_tractogram(filename, reference, to_space=Space.RASMM, The tractogram to load (must have been saved properly) """ _, extension = os.path.splitext(filename) - if extension not in ['.trk', '.tck', '.trx', '.vtk', '.vtp', '.fib', '.dpy']: + if extension not in ['.trk', '.tck', '.trx', '.vtk', '.vtp', '.fib', + '.dpy']: logging.error('Output filename is not one of the supported format.') return False @@ -389,9 +391,9 @@ def concatenate_tractogram(tractogram_list, *, delete_dpv=False, all_groups_len[group_key] += len(trx_1.groups[group_key]) else: all_groups_len[group_key] = len(trx_1.groups[group_key]) - if (group_key in all_groups_dtype and \ - trx_1.groups[group_key].dtype != all_groups_dtype[group_key] - ): + if group_key in all_groups_dtype and \ + trx_1.groups[group_key].dtype != \ + all_groups_dtype[group_key]: raise ValueError("Shared group key, has different dtype.") else: all_groups_dtype[group_key] = trx_1.groups[group_key].dtype @@ -454,4 +456,3 @@ def concatenate_tractogram(tractogram_list, *, delete_dpv=False, curr_trx, strs_start=strs_end, pts_start=pts_end ) return new_trx - diff --git a/dipy/workflows/cli.py b/dipy/workflows/cli.py index f2f233f359..183fde5b57 100644 --- a/dipy/workflows/cli.py +++ b/dipy/workflows/cli.py @@ -31,6 +31,8 @@ "dipy_gibbs_ringing": ("dipy.workflows.denoise", "GibbsRingingFlow"), "dipy_horizon": ("dipy.workflows.viz", "HorizonFlow"), "dipy_info": ("dipy.workflows.io", "IoInfoFlow"), + "dipy_concatenate_tracks": ("dipy.workflows.io", + "ConcatenateTractogramFlow"), "dipy_labelsbundles": ("dipy.workflows.segment", "LabelsBundlesFlow"), "dipy_median_otsu": ("dipy.workflows.segment", "MedianOtsuFlow"), "dipy_recobundles": ("dipy.workflows.segment", "RecoBundlesFlow"), diff --git a/dipy/workflows/docstring_parser.py b/dipy/workflows/docstring_parser.py index 22480e7f27..326b557c87 100644 --- a/dipy/workflows/docstring_parser.py +++ b/dipy/workflows/docstring_parser.py @@ -305,7 +305,7 @@ def _parse_summary(self): while True: summary = self._doc.read_to_next_empty_line() summary_str = " ".join([s.strip() for s in summary]).strip() - if re.compile('^([\w., ]+=)?\s*[\w\.]+\(.*\)$').match(summary_str): + if re.compile(r'^([\w., ]+=)?\s*[\w\.]+\(.*\)$').match(summary_str): self['Signature'] = summary_str if not self._is_at_section(): continue @@ -349,7 +349,7 @@ def _str_indent(self, doc, indent=4): def _str_signature(self): if self['Signature']: - return [self['Signature'].replace('*', '\*')] + [''] + return [self['Signature'].replace('*', r'\*')] + [''] else: return [''] diff --git a/dipy/workflows/io.py b/dipy/workflows/io.py index 3add53d45f..7f717900ee 100644 --- a/dipy/workflows/io.py +++ b/dipy/workflows/io.py @@ -1,11 +1,15 @@ +import importlib import os import sys - -import numpy as np import logging -import importlib from inspect import getmembers, isfunction + +import trx.trx_file_memmap as tmm +import numpy as np + from dipy.io.image import load_nifti, save_nifti +from dipy.io.streamline import (concatenate_tractogram, load_tractogram, + save_tractogram) from dipy.workflows.workflow import Workflow @@ -219,9 +223,9 @@ def run(self, data_names, out_dir=''): else: os.environ.pop('DIPY_HOME', None) - # We load the module again so that if we run another one of these in - # the same process, we don't have the env variable pointing to the - # wrong place + # We load the module again so that if we run another one of these + # in the same process, we don't have the env variable pointing + # to the wrong place self.load_module('dipy.data.fetcher') @@ -257,3 +261,81 @@ def run(self, input_files, vol_idx=0, out_dir='', save_nifti(osplit, split_vol, affine, image.header) logging.info('Split volume saved as {0}'.format(osplit)) + + +class ConcatenateTractogramFlow(Workflow): + @classmethod + def get_short_name(cls): + return 'concatracks' + + def run(self, tractogram_files, reference=None, delete_dpv=False, + delete_dps=False, delete_groups=False, check_space_attributes=True, + preallocation=False, out_dir='', + out_extension='trx', + out_tractogram='concatenated_tractogram'): + """Concatenate multiple tractograms into one. + + Parameters + ---------- + tractogram_list : variable string + The stateful tractogram filenames to concatenate + reference : string, optional + Reference anatomy for tck/vtk/fib/dpy file. + support (.nii or .nii.gz). + delete_dpv : bool, optional + Delete dpv keys that do not exist in all the provided TrxFiles + delete_dps : bool, optional + Delete dps keys that do not exist in all the provided TrxFile + delete_groups : bool, optional + Delete all the groups that currently exist in the TrxFiles + check_space_attributes : bool, optional + Verify that dimensions and size of data are similar between all the + TrxFiles + preallocation : bool, optional + Preallocated TrxFile has already been generated and is the first + element in trx_list (Note: delete_groups must be set to True as + well) + out_dir : string, optional + Output directory. (default current directory) + out_extension : string, optional + Extension of the resulting tractogram + out_tractogram : string, optional + Name of the resulting tractogram + + """ + io_it = self.get_io_iterator() + + trx_list = [] + has_group = False + for fpath, oext, otracks in io_it: + + if fpath.lower().endswith('.trx') or \ + fpath.lower().endswith('.trk'): + reference = 'same' + + if not reference: + raise ValueError("No reference provided. It is needed for tck," + "fib, dpy or vtk files") + + tractogram_obj = load_tractogram(fpath, reference, + bbox_valid_check=False) + + if not isinstance(tractogram_obj, tmm.TrxFile): + tractogram_obj = tmm.TrxFile.from_sft(tractogram_obj) + elif len(tractogram_obj.groups): + has_group = True + trx_list.append(tractogram_obj) + + trx = concatenate_tractogram( + trx_list, delete_dpv=delete_dpv, delete_dps=delete_dps, + delete_groups=delete_groups or not has_group, + check_space_attributes=check_space_attributes, + preallocation=preallocation) + + valid_extensions = ['trk', 'trx', "tck,", "fib", "dpy", "vtk"] + if out_extension.lower() not in valid_extensions: + raise ValueError("Invalid extension. Valid extensions are: " + "{0}".format(valid_extensions)) + + out_fpath = os.path.join(out_dir, f"{out_tractogram}.{out_extension}") + save_tractogram(trx.to_sft(), out_fpath, bbox_valid_check=False) diff --git a/dipy/workflows/tests/test_io.py b/dipy/workflows/tests/test_io.py index 41d48a3be9..6216fbced4 100644 --- a/dipy/workflows/tests/test_io.py +++ b/dipy/workflows/tests/test_io.py @@ -6,9 +6,11 @@ from dipy.data import get_fnames from dipy.io.image import load_nifti +from dipy.io.streamline import load_tractogram from dipy.testing import assert_true from dipy.data.fetcher import dipy_home -from dipy.workflows.io import IoInfoFlow, FetchFlow, SplitFlow +from dipy.workflows.io import (IoInfoFlow, FetchFlow, SplitFlow, + ConcatenateTractogramFlow) fname_log = mkstemp()[1] @@ -91,3 +93,24 @@ def test_split_flow(): split_data, split_affine = load_nifti(split_path) npt.assert_equal(split_data.shape, volume[..., 0].shape) npt.assert_array_almost_equal(split_affine, affine) + + +def test_concatenate_flow(): + with TemporaryDirectory() as out_dir: + concatenate_flow = ConcatenateTractogramFlow() + data_path, _, _ = get_fnames('gold_standard_tracks') + input_files = [v for k, v in data_path.items() + if k in ['gs.trk', 'gs.tck', 'gs.trx', 'gs.fib'] + ] + concatenate_flow.run(*input_files, out_dir=out_dir) + assert_true( + concatenate_flow.last_generated_outputs['out_extension'].endswith( + 'trx')) + assert_true(os.path.isfile( + concatenate_flow.last_generated_outputs['out_tractogram'] + + ".trx")) + + trk = load_tractogram( + concatenate_flow.last_generated_outputs['out_tractogram'] + ".trx", + 'same') + npt.assert_equal(len(trk), 13) diff --git a/pyproject.toml b/pyproject.toml index d21f0a9783..b0b0517e48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,6 +81,7 @@ dipy_split = "dipy.workflows.cli:run" dipy_track = "dipy.workflows.cli:run" dipy_track_pft = "dipy.workflows.cli:run" dipy_slr = "dipy.workflows.cli:run" +dipy_concatenate_tracks = "dipy.workflows.cli:run" [project.optional-dependencies] all = ["dipy[dev,doc,style,test, viz, ml, extra]"] From 43c97d6d5290fa19561e5e54a723171fa3ffd7f2 Mon Sep 17 00:00:00 2001 From: Serge Koudoro Date: Mon, 8 Jan 2024 12:27:19 -0500 Subject: [PATCH 3/3] [RF] rename workflow --- dipy/workflows/cli.py | 4 ++-- dipy/workflows/io.py | 2 +- pyproject.toml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dipy/workflows/cli.py b/dipy/workflows/cli.py index 183fde5b57..b40a4dbd03 100644 --- a/dipy/workflows/cli.py +++ b/dipy/workflows/cli.py @@ -31,8 +31,8 @@ "dipy_gibbs_ringing": ("dipy.workflows.denoise", "GibbsRingingFlow"), "dipy_horizon": ("dipy.workflows.viz", "HorizonFlow"), "dipy_info": ("dipy.workflows.io", "IoInfoFlow"), - "dipy_concatenate_tracks": ("dipy.workflows.io", - "ConcatenateTractogramFlow"), + "dipy_concatenate_tractograms": ("dipy.workflows.io", + "ConcatenateTractogramFlow"), "dipy_labelsbundles": ("dipy.workflows.segment", "LabelsBundlesFlow"), "dipy_median_otsu": ("dipy.workflows.segment", "MedianOtsuFlow"), "dipy_recobundles": ("dipy.workflows.segment", "RecoBundlesFlow"), diff --git a/dipy/workflows/io.py b/dipy/workflows/io.py index 7f717900ee..1325056a79 100644 --- a/dipy/workflows/io.py +++ b/dipy/workflows/io.py @@ -332,7 +332,7 @@ def run(self, tractogram_files, reference=None, delete_dpv=False, check_space_attributes=check_space_attributes, preallocation=preallocation) - valid_extensions = ['trk', 'trx', "tck,", "fib", "dpy", "vtk"] + valid_extensions = ['trk', 'trx', "tck", "fib", "dpy", "vtk"] if out_extension.lower() not in valid_extensions: raise ValueError("Invalid extension. Valid extensions are: " "{0}".format(valid_extensions)) diff --git a/pyproject.toml b/pyproject.toml index b0b0517e48..71f3be8996 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,7 +81,7 @@ dipy_split = "dipy.workflows.cli:run" dipy_track = "dipy.workflows.cli:run" dipy_track_pft = "dipy.workflows.cli:run" dipy_slr = "dipy.workflows.cli:run" -dipy_concatenate_tracks = "dipy.workflows.cli:run" +dipy_concatenate_tractograms = "dipy.workflows.cli:run" [project.optional-dependencies] all = ["dipy[dev,doc,style,test, viz, ml, extra]"]