diff --git a/.gitignore b/.gitignore index 813a86c973..f9df95100e 100644 --- a/.gitignore +++ b/.gitignore @@ -14,6 +14,7 @@ *.pkl *.orig *.trk +*.trx *.pam5 build *~ diff --git a/dipy/data/fetcher.py b/dipy/data/fetcher.py index 7d27d2e49a..1c2cf7f121 100644 --- a/dipy/data/fetcher.py +++ b/dipy/data/fetcher.py @@ -932,6 +932,19 @@ def get_fnames(name='small_64D'): seed_coords_name = pjoin(folder, 'ptt_seed_coords.txt') seed_image_name = pjoin(folder, 'ptt_seed_image.nii') return fod_name, seed_coords_name, seed_image_name + if name == "gold_standard_tracks": + filepath_dix = {} + files, folder = fetch_gold_standard_io() + for filename in files: + filepath_dix[filename] = os.path.join(folder, filename) + + with open(filepath_dix['points_data.json']) as json_file: + points_data = dict(json.load(json_file)) + + with open(filepath_dix['streamlines_data.json']) as json_file: + streamlines_data = dict(json.load(json_file)) + + return filepath_dix, points_data, streamlines_data def read_qtdMRI_test_retest_2subjects(): diff --git a/dipy/io/streamline.py b/dipy/io/streamline.py index b4d284da29..f3919d241b 100644 --- a/dipy/io/streamline.py +++ b/dipy/io/streamline.py @@ -36,7 +36,8 @@ def save_tractogram(sft, filename, bbox_valid_check=True): """ _, extension = os.path.splitext(filename) - if extension not in ['.trk', '.tck', '.trx', '.vtk', '.vtp', '.fib', '.dpy']: + if extension not in ['.trk', '.tck', '.trx', '.vtk', '.vtp', '.fib', + '.dpy']: raise TypeError('Output filename is not one of the supported format.') if bbox_valid_check and not sft.is_bbox_in_vox_valid(): @@ -121,7 +122,8 @@ def load_tractogram(filename, reference, to_space=Space.RASMM, The tractogram to load (must have been saved properly) """ _, extension = os.path.splitext(filename) - if extension not in ['.trk', '.tck', '.trx', '.vtk', '.vtp', '.fib', '.dpy']: + if extension not in ['.trk', '.tck', '.trx', '.vtk', '.vtp', '.fib', + '.dpy']: logging.error('Output filename is not one of the supported format.') return False diff --git a/dipy/utils/meson.build b/dipy/utils/meson.build index dfc1e21dc9..8b7800fd59 100644 --- a/dipy/utils/meson.build +++ b/dipy/utils/meson.build @@ -30,6 +30,7 @@ python_sources = [ 'multiproc.py', 'optpkg.py', 'parallel.py', + 'tractogram.py', 'tripwire.py', 'volume.py', ] diff --git a/dipy/utils/tests/meson.build b/dipy/utils/tests/meson.build index 3662f2d553..fef2aabb87 100644 --- a/dipy/utils/tests/meson.build +++ b/dipy/utils/tests/meson.build @@ -7,6 +7,7 @@ python_sources = [ 'test_omp.py', 'test_optpkg.py', 'test_parallel.py', + 'test_tractogram.py', 'test_tripwire.py', 'test_volume.py', ] diff --git a/dipy/utils/tests/test_tractogram.py b/dipy/utils/tests/test_tractogram.py new file mode 100644 index 0000000000..05f91aea91 --- /dev/null +++ b/dipy/utils/tests/test_tractogram.py @@ -0,0 +1,17 @@ + +import trx.trx_file_memmap as tmm + +from dipy.data import get_fnames +from dipy.io.streamline import load_tractogram +from dipy.utils.tractogram import concatenate_tractogram + + +def test_concatenate(): + filepath_dix, _, _ = get_fnames('gold_standard_tracks') + sft = load_tractogram(filepath_dix['gs.trk'], filepath_dix['gs.nii']) + trx = tmm.load(filepath_dix['gs.trx']) + concat = concatenate_tractogram([sft, trx]) + + assert len(concat) == 2 * len(trx) + trx.close() + concat.close() diff --git a/dipy/utils/tractogram.py b/dipy/utils/tractogram.py new file mode 100644 index 0000000000..8159d3ff65 --- /dev/null +++ b/dipy/utils/tractogram.py @@ -0,0 +1,201 @@ + +"""This module is dedicated to the handling of tractograms.""" +import os +import logging + +import numpy as np +import trx.trx_file_memmap as tmm + + +def concatenate_tractogram(tractogram_list, *, delete_dpv=False, + delete_dps=False, delete_groups=False, + check_space_attributes=True, preallocation=False): + """Concatenate multiple tractograms into one. + + If the data_per_point or data_per_streamline is not the same for all + tractograms, the data must be deleted first. + + Parameters + ---------- + tractogram_list : List[StatefulTractogram or TrxFile] + The stateful tractogram to concatenate + delete_dpv: bool, optional + Delete dpv keys that do not exist in all the provided TrxFiles + delete_dps: bool, optional + Delete dps keys that do not exist in all the provided TrxFile + delete_groups: bool, optional + Delete all the groups that currently exist in the TrxFiles + check_space_attributes: bool, optional + Verify that dimensions and size of data are similar between all the + TrxFiles + preallocation: bool, optional + Preallocated TrxFile has already been generated and is the first + element in trx_list (Note: delete_groups must be set to True as well) + + Returns + ------- + new_trx: TrxFile + TrxFile representing the concatenated data + + """ + trx_list = [] + for sft in tractogram_list: + if not isinstance(sft, tmm.TrxFile): + sft = tmm.TrxFile.from_sft(sft) + elif len(sft.groups): + delete_groups = True + trx_list.append(sft) + + trx_list = [curr_trx for curr_trx in trx_list + if curr_trx.header["NB_STREAMLINES"] > 0] + + if not trx_list: + logging.warning("Inputs of concatenation were empty.") + return tmm.TrxFile() + + if len(trx_list) == 1: + if len(tractogram_list) > 1: + logging.warning("Only 1 valid tractogram returned.") + return trx_list[0] + + ref_trx = trx_list[0] + all_dps = [] + all_dpv = [] + for curr_trx in trx_list: + all_dps.extend(list(curr_trx.data_per_streamline.keys())) + all_dpv.extend(list(curr_trx.data_per_vertex.keys())) + all_dps, all_dpv = set(all_dps), set(all_dpv) + + if check_space_attributes: + for curr_trx in trx_list[1:]: + if not np.allclose(ref_trx.header["VOXEL_TO_RASMM"], + curr_trx.header["VOXEL_TO_RASMM"]) or \ + not np.array_equal( + ref_trx.header["DIMENSIONS"], + curr_trx.header["DIMENSIONS"] + ): + raise ValueError("Wrong space attributes.") + + if preallocation and not delete_groups: + raise ValueError( + "Groups are variables, cannot be handled with " "preallocation" + ) + + # Verifying the validity of fixed-size arrays, coherence between inputs + for curr_trx in trx_list: + for key in all_dpv: + if key not in ref_trx.data_per_vertex.keys() \ + or key not in curr_trx.data_per_vertex.keys(): + if not delete_dpv: + logging.debug( + "{} dpv key does not exist in all TrxFile.".format(key) + ) + raise ValueError( + "TrxFile must be sharing identical dpv " "keys.") + elif ( + ref_trx.data_per_vertex[key]._data.dtype + != curr_trx.data_per_vertex[key]._data.dtype + ): + logging.debug( + "{} dpv key is not declared with the same dtype " + "in all TrxFile.".format(key) + ) + raise ValueError("Shared dpv key, has different dtype.") + + for curr_trx in trx_list: + for key in all_dps: + if key not in ref_trx.data_per_streamline.keys() \ + or key not in curr_trx.data_per_streamline.keys(): + if not delete_dps: + logging.debug( + "{} dps key does not exist in all " "TrxFile.".format( + key) + ) + raise ValueError( + "TrxFile must be sharing identical dps " "keys.") + elif ( + ref_trx.data_per_streamline[key].dtype + != curr_trx.data_per_streamline[key].dtype + ): + logging.debug( + "{} dps key is not declared with the same dtype " + "in all TrxFile.".format(key) + ) + raise ValueError("Shared dps key, has different dtype.") + + all_groups_len = {} + all_groups_dtype = {} + # Variable-size arrays do not have to exist in all TrxFile + if not delete_groups: + for trx_1 in trx_list: + for group_key in trx_1.groups.keys(): + # Concatenating groups together + if group_key in all_groups_len: + all_groups_len[group_key] += len(trx_1.groups[group_key]) + else: + all_groups_len[group_key] = len(trx_1.groups[group_key]) + if group_key in all_groups_dtype and \ + trx_1.groups[group_key].dtype != \ + all_groups_dtype[group_key]: + raise ValueError("Shared group key, has different dtype.") + else: + all_groups_dtype[group_key] = trx_1.groups[group_key].dtype + + # Once the checks are done, actually concatenate + to_concat_list = trx_list[1:] if preallocation else trx_list + if not preallocation: + nb_vertices = 0 + nb_streamlines = 0 + for curr_trx in to_concat_list: + curr_strs_len, curr_pts_len = curr_trx._get_real_len() + nb_streamlines += curr_strs_len + nb_vertices += curr_pts_len + + new_trx = tmm.TrxFile( + nb_vertices=nb_vertices, nb_streamlines=nb_streamlines, + init_as=ref_trx + ) + if delete_dps: + new_trx.data_per_streamline = {} + if delete_dpv: + new_trx.data_per_vertex = {} + if delete_groups: + new_trx.groups = {} + + tmp_dir = new_trx._uncompressed_folder_handle.name + + # When memory is allocated on the spot, groups and data_per_group can + # be concatenated together + for group_key in all_groups_len.keys(): + if not os.path.isdir(os.path.join(tmp_dir, "groups/")): + os.mkdir(os.path.join(tmp_dir, "groups/")) + dtype = all_groups_dtype[group_key] + group_filename = os.path.join( + tmp_dir, "groups/" "{}.{}".format(group_key, dtype.name) + ) + group_len = all_groups_len[group_key] + new_trx.groups[group_key] = tmm._create_memmap( + group_filename, mode="w+", shape=(group_len,), dtype=dtype + ) + if delete_groups: + continue + pos = 0 + count = 0 + for curr_trx in trx_list: + curr_len = len(curr_trx.groups[group_key]) + new_trx.groups[group_key][pos: pos + curr_len] = \ + curr_trx.groups[group_key] + count + pos += curr_len + count += curr_trx.header["NB_STREAMLINES"] + + strs_end, pts_end = 0, 0 + else: + new_trx = ref_trx + strs_end, pts_end = new_trx._get_real_len() + + for curr_trx in to_concat_list: + # Copy the TrxFile fixed-size info (the right chunk) + strs_end, pts_end = new_trx._copy_fixed_arrays_from( + curr_trx, strs_start=strs_end, pts_start=pts_end + ) + return new_trx diff --git a/dipy/workflows/cli.py b/dipy/workflows/cli.py index 79587598f5..83dec894d7 100644 --- a/dipy/workflows/cli.py +++ b/dipy/workflows/cli.py @@ -31,6 +31,8 @@ "dipy_gibbs_ringing": ("dipy.workflows.denoise", "GibbsRingingFlow"), "dipy_horizon": ("dipy.workflows.viz", "HorizonFlow"), "dipy_info": ("dipy.workflows.io", "IoInfoFlow"), + "dipy_concatenate_tractograms": ("dipy.workflows.io", + "ConcatenateTractogramFlow"), "dipy_labelsbundles": ("dipy.workflows.segment", "LabelsBundlesFlow"), "dipy_median_otsu": ("dipy.workflows.segment", "MedianOtsuFlow"), "dipy_recobundles": ("dipy.workflows.segment", "RecoBundlesFlow"), diff --git a/dipy/workflows/docstring_parser.py b/dipy/workflows/docstring_parser.py index 22480e7f27..326b557c87 100644 --- a/dipy/workflows/docstring_parser.py +++ b/dipy/workflows/docstring_parser.py @@ -305,7 +305,7 @@ def _parse_summary(self): while True: summary = self._doc.read_to_next_empty_line() summary_str = " ".join([s.strip() for s in summary]).strip() - if re.compile('^([\w., ]+=)?\s*[\w\.]+\(.*\)$').match(summary_str): + if re.compile(r'^([\w., ]+=)?\s*[\w\.]+\(.*\)$').match(summary_str): self['Signature'] = summary_str if not self._is_at_section(): continue @@ -349,7 +349,7 @@ def _str_indent(self, doc, indent=4): def _str_signature(self): if self['Signature']: - return [self['Signature'].replace('*', '\*')] + [''] + return [self['Signature'].replace('*', r'\*')] + [''] else: return [''] diff --git a/dipy/workflows/io.py b/dipy/workflows/io.py index bbfb79e195..ca8427264b 100644 --- a/dipy/workflows/io.py +++ b/dipy/workflows/io.py @@ -1,12 +1,16 @@ +import importlib import os import sys - -import numpy as np import logging -import importlib from inspect import getmembers, isfunction + +import trx.trx_file_memmap as tmm +import numpy as np + from dipy.io.image import load_nifti, save_nifti +from dipy.io.streamline import load_tractogram, save_tractogram from dipy.reconst.shm import convert_sh_descoteaux_tournier +from dipy.utils.tractogram import concatenate_tractogram from dipy.workflows.workflow import Workflow @@ -220,9 +224,9 @@ def run(self, data_names, out_dir=''): else: os.environ.pop('DIPY_HOME', None) - # We load the module again so that if we run another one of these in - # the same process, we don't have the env variable pointing to the - # wrong place + # We load the module again so that if we run another one of these + # in the same process, we don't have the env variable pointing + # to the wrong place self.load_module('dipy.data.fetcher') @@ -260,6 +264,84 @@ def run(self, input_files, vol_idx=0, out_dir='', logging.info('Split volume saved as {0}'.format(osplit)) +class ConcatenateTractogramFlow(Workflow): + @classmethod + def get_short_name(cls): + return 'concatracks' + + def run(self, tractogram_files, reference=None, delete_dpv=False, + delete_dps=False, delete_groups=False, check_space_attributes=True, + preallocation=False, out_dir='', + out_extension='trx', + out_tractogram='concatenated_tractogram'): + """Concatenate multiple tractograms into one. + + Parameters + ---------- + tractogram_list : variable string + The stateful tractogram filenames to concatenate + reference : string, optional + Reference anatomy for tck/vtk/fib/dpy file. + support (.nii or .nii.gz). + delete_dpv : bool, optional + Delete dpv keys that do not exist in all the provided TrxFiles + delete_dps : bool, optional + Delete dps keys that do not exist in all the provided TrxFile + delete_groups : bool, optional + Delete all the groups that currently exist in the TrxFiles + check_space_attributes : bool, optional + Verify that dimensions and size of data are similar between all the + TrxFiles + preallocation : bool, optional + Preallocated TrxFile has already been generated and is the first + element in trx_list (Note: delete_groups must be set to True as + well) + out_dir : string, optional + Output directory. (default current directory) + out_extension : string, optional + Extension of the resulting tractogram + out_tractogram : string, optional + Name of the resulting tractogram + + """ + io_it = self.get_io_iterator() + + trx_list = [] + has_group = False + for fpath, oext, otracks in io_it: + + if fpath.lower().endswith('.trx') or \ + fpath.lower().endswith('.trk'): + reference = 'same' + + if not reference: + raise ValueError("No reference provided. It is needed for tck," + "fib, dpy or vtk files") + + tractogram_obj = load_tractogram(fpath, reference, + bbox_valid_check=False) + + if not isinstance(tractogram_obj, tmm.TrxFile): + tractogram_obj = tmm.TrxFile.from_sft(tractogram_obj) + elif len(tractogram_obj.groups): + has_group = True + trx_list.append(tractogram_obj) + + trx = concatenate_tractogram( + trx_list, delete_dpv=delete_dpv, delete_dps=delete_dps, + delete_groups=delete_groups or not has_group, + check_space_attributes=check_space_attributes, + preallocation=preallocation) + + valid_extensions = ['trk', 'trx', "tck", "fib", "dpy", "vtk"] + if out_extension.lower() not in valid_extensions: + raise ValueError("Invalid extension. Valid extensions are: " + "{0}".format(valid_extensions)) + + out_fpath = os.path.join(out_dir, f"{out_tractogram}.{out_extension}") + save_tractogram(trx.to_sft(), out_fpath, bbox_valid_check=False) + + class ConvertSHFlow(Workflow): @classmethod def get_short_name(cls): diff --git a/dipy/workflows/tests/test_io.py b/dipy/workflows/tests/test_io.py index 12d9a40658..c3c870cbe7 100644 --- a/dipy/workflows/tests/test_io.py +++ b/dipy/workflows/tests/test_io.py @@ -6,11 +6,14 @@ import numpy.testing as npt from dipy.data import get_fnames +from dipy.data.fetcher import dipy_home from dipy.io.image import load_nifti, save_nifti +from dipy.io.streamline import load_tractogram from dipy.testing import assert_true -from dipy.data.fetcher import dipy_home from dipy.reconst.shm import convert_sh_descoteaux_tournier -from dipy.workflows.io import IoInfoFlow, FetchFlow, SplitFlow, ConvertSHFlow +from dipy.workflows.io import (IoInfoFlow, FetchFlow, SplitFlow, + ConcatenateTractogramFlow, ConvertSHFlow) + fname_log = mkstemp()[1] @@ -95,6 +98,27 @@ def test_split_flow(): npt.assert_array_almost_equal(split_affine, affine) +def test_concatenate_flow(): + with TemporaryDirectory() as out_dir: + concatenate_flow = ConcatenateTractogramFlow() + data_path, _, _ = get_fnames('gold_standard_tracks') + input_files = [v for k, v in data_path.items() + if k in ['gs.trk', 'gs.tck', 'gs.trx', 'gs.fib'] + ] + concatenate_flow.run(*input_files, out_dir=out_dir) + assert_true( + concatenate_flow.last_generated_outputs['out_extension'].endswith( + 'trx')) + assert_true(os.path.isfile( + concatenate_flow.last_generated_outputs['out_tractogram'] + + ".trx")) + + trk = load_tractogram( + concatenate_flow.last_generated_outputs['out_tractogram'] + ".trx", + 'same') + npt.assert_equal(len(trk), 13) + + def test_convert_sh_flow(): with TemporaryDirectory() as out_dir: filepath_in = os.path.join(out_dir, 'sh_coeff_img.nii.gz') diff --git a/pyproject.toml b/pyproject.toml index 6a44a0e39a..95fd7849b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ dipy_split = "dipy.workflows.cli:run" dipy_track = "dipy.workflows.cli:run" dipy_track_pft = "dipy.workflows.cli:run" dipy_slr = "dipy.workflows.cli:run" +dipy_concatenate_tractograms = "dipy.workflows.cli:run" [project.optional-dependencies] all = ["dipy[dev,doc,style,test, viz, ml, extra]"]