diff --git a/src/eddymotion/model/dmri_utils.py b/src/eddymotion/model/dmri_utils.py new file mode 100644 index 00000000..e3f2941b --- /dev/null +++ b/src/eddymotion/model/dmri_utils.py @@ -0,0 +1,132 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright 2024 The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY kIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +import numpy as np +from dipy.core.gradients import get_bval_indices +from sklearn.cluster import KMeans + +B0_THRESHOLD = 50 # from dmriprep +SHELL_DIFF_THRES = 20 # 150 in dmriprep + + +def extract_dwi_shell(dwi, bvals, bvecs, bvals_to_extract, tol=SHELL_DIFF_THRES): + """Extract the DWI volumes that are on the given b-value shells. Multiple + shells can be extracted at once by specifying multiple b-values. The + extracted volumes will be in the same order as in the original file. + + Parameters + ---------- + dwi : nib.Nifti1Image + Original DWI multi-shell volume. + bvals : ndarray + b-values in FSL format. + bvecs : ndarray + b-vectors in FSL format. + bvals_to_extract : list of int + List of b-values to extract. + tol : int, optional + Tolerance between the b-values to extract and the actual b-values. + + Returns + ------- + indices : ndarray + Indices of the volumes corresponding to the given ``bvals``. + shell_data : ndarray + Volumes corresponding to the given ``bvals``. + output_bvals : ndarray + Selected b-values (as extracted from ``bvals``). + output_bvecs : ndarray + Selected b-vectors. + """ + + indices = [get_bval_indices(bvals, shell, tol=tol) for shell in bvals_to_extract] + indices = np.unique(np.sort(np.hstack(indices))) + + if len(indices) == 0: + raise ValueError( + f"No DWI volumes found corresponding to the given b-values: {bvals_to_extract}" + ) + + shell_data = dwi.get_fdata()[..., indices] + output_bvals = bvals[indices].astype(int) + output_bvecs = bvecs[indices, :] + + return indices, shell_data, output_bvals, output_bvecs + + +# ToDo +# Long term: use DiffusionGradientTable from dmriprep, normalizing gradients, +# etc. ? +def find_shelling_scheme(bvals, tol=SHELL_DIFF_THRES): + """Find the shelling scheme on the given b-values: extract the b-value + shells as the b-values centroids using k-means clustering. + + Parameters + ---------- + bvals : :obj:`ndarray` + b-values in FSL format. + tol : :obj:`int`, optional + Tolerance between the b-values and the centroids in the average squared + distance sense. + + Returns + ------- + shells : :obj:`ndarray` + b-value shells. + bval_centroids : :obj:`ndarray` + Shell value corresponding to each value in ``bvals``. + """ + + # Use kmeans to find the shelling scheme + for k in range(1, len(np.unique(bvals)) + 1): + kmeans_res = KMeans(n_clusters=k).fit(bvals.reshape(-1, 1)) + # ToDo + # The tolerance is not a very intuitive value, as it has to do with the + # sum of squared distances across all samples to the centroids + # (_inertia) + # Alternatives: + # - We could accept the number of clusters as a parameter and do + # kmeans_res = KMeans(n_clusters=n_clusters) + # Setting that to 3 in the last testing case, where tol = 60 is not + # intuitive would give the expected 6, 1000, 2000 clusters. + # Passes all tests. But maybe not tested corner cases + # We could have both k and tol as optional parameters, set to None by + # default to force the user set one + # - Use get_bval_indices to get the cluster centroids and then + # substitute the values in bvals with the corresponding values + # indices = [get_bval_indices(bvals, shell, tol=tol) for shell in bvals_to_extract] + # result = np.zeros_like(bvals) + # for i, idx in enumerate(indices): + # result[idx] = bvals_to_extract[i] + + if kmeans_res.inertia_ / len(bvals) < tol: + break + else: + raise ValueError(f"bvals parsing failed: no shells found more than {tol} apart") + + # Convert the kclust labels to an array + shells = kmeans_res.cluster_centers_ + bval_centroids = np.zeros(bvals.shape) + for i in range(shells.size): + bval_centroids[kmeans_res.labels_ == i] = shells[i][0] + + return np.sort(np.squeeze(shells, axis=-1)), bval_centroids diff --git a/src/eddymotion/model/tests/__init__.py b/src/eddymotion/model/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/eddymotion/model/tests/test_dmri_utils.py b/src/eddymotion/model/tests/test_dmri_utils.py new file mode 100644 index 00000000..24b37487 --- /dev/null +++ b/src/eddymotion/model/tests/test_dmri_utils.py @@ -0,0 +1,262 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: +# +# Copyright 2024 The NiPreps Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# We support and encourage derived works from this project, please read +# about our expectations at +# +# https://www.nipreps.org/community/licensing/ +# +import nibabel as nib +import numpy as np + +from eddymotion.model.gradient_utils import ( + extract_dwi_shell, + find_shelling_scheme, +) + + +def test_extract_dwi_shell(): + # dMRI volume with 5 gradients + bvals = np.asarray([0, 1980, 12, 990, 2000]) + bval_count = len(bvals) + vols_size = (10, 15, 20) + dwi = np.ones((*vols_size, bval_count)) + bvecs = np.ones((bval_count, 3)) + # Set all i-th gradient dMRI volume data and bvecs values to i + for i in range(bval_count): + dwi[..., i] = i + bvecs[i, :] = i + dwi_img = nib.Nifti1Image(dwi, affine=np.eye(4)) + + bvals_to_extract = [0, 2000] + tol = 15 + + expected_indices = np.asarray([0, 2, 4]) + expected_shell_data = np.stack([i * np.ones(vols_size) for i in expected_indices], axis=-1) + expected_shell_bvals = np.asarray([0, 12, 2000]) + expected_shell_bvecs = np.asarray([[i] * 3 for i in expected_indices]) + + (obtained_indices, obtained_shell_data, obtained_shell_bvals, obtained_shell_bvecs) = ( + extract_dwi_shell(dwi_img, bvals, bvecs, bvals_to_extract=bvals_to_extract, tol=tol) + ) + + assert np.array_equal(obtained_indices, expected_indices) + assert np.array_equal(obtained_shell_data, expected_shell_data) + assert np.array_equal(obtained_shell_bvals, expected_shell_bvals) + assert np.array_equal(obtained_shell_bvecs, expected_shell_bvecs) + + bvals = np.asarray([0, 1010, 12, 990, 2000]) + bval_count = len(bvals) + vols_size = (10, 15, 20) + dwi = np.ones((*vols_size, bval_count)) + bvecs = np.ones((bval_count, 3)) + # Set all i-th gradient dMRI volume data and bvecs values to i + for i in range(bval_count): + dwi[..., i] = i + bvecs[i, :] = i + dwi_img = nib.Nifti1Image(dwi, affine=np.eye(4)) + + bvals_to_extract = [0, 1000] + tol = 20 + + expected_indices = np.asarray([0, 1, 2, 3]) + expected_shell_data = np.stack([i * np.ones(vols_size) for i in expected_indices], axis=-1) + expected_shell_bvals = np.asarray([0, 1010, 12, 990]) + expected_shell_bvecs = np.asarray([[i] * 3 for i in expected_indices]) + + (obtained_indices, obtained_shell_data, obtained_shell_bvals, obtained_shell_bvecs) = ( + extract_dwi_shell(dwi_img, bvals, bvecs, bvals_to_extract=bvals_to_extract, tol=tol) + ) + + assert np.array_equal(obtained_indices, expected_indices) + assert np.array_equal(obtained_shell_data, expected_shell_data) + assert np.array_equal(obtained_shell_bvals, expected_shell_bvals) + assert np.array_equal(obtained_shell_bvecs, expected_shell_bvecs) + + +def test_find_shelling_scheme(): + tol = 20 + bvals = np.asarray([0, 0]) + expected_shells = np.asarray([0]) + expected_bval_centroids = np.asarray([0, 0]) + obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol=tol) + + assert np.array_equal(obtained_shells, expected_shells) + assert np.array_equal(obtained_bval_centroids, expected_bval_centroids) + + bvals = np.asarray( + [ + 5, + 300, + 300, + 300, + 300, + 300, + 305, + 1005, + 995, + 1000, + 1000, + 1005, + 1000, + 1000, + 1005, + 995, + 1000, + 1005, + 5, + 995, + 1000, + 1000, + 995, + 1005, + 995, + 1000, + 995, + 995, + 2005, + 2000, + 2005, + 2005, + 1995, + 2000, + 2005, + 2000, + 1995, + 2005, + 5, + 1995, + 2005, + 1995, + 1995, + 2005, + 2005, + 1995, + 2000, + 2000, + 2000, + 1995, + 2000, + 2000, + 2005, + 2005, + 1995, + 2005, + 2005, + 1990, + 1995, + 1995, + 1995, + 2005, + 2000, + 1990, + 2010, + 5, + ] + ) + expected_shells = np.asarray([5.0, 300.83333333, 999.5, 2000.0]) + expected_bval_centroids = [ + 5.0, + 300.83333333, + 300.83333333, + 300.83333333, + 300.83333333, + 300.83333333, + 300.83333333, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 5.0, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 999.5, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 5.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 2000.0, + 5.0, + ] + obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol=tol) + + # ToDo + # Giving a tolerance of 15 this fails because it finds 5 clusters + assert np.allclose(obtained_shells, expected_shells) + assert np.allclose(obtained_bval_centroids, expected_bval_centroids) + + bvals = np.asarray([0, 1980, 12, 990, 2000]) + expected_shells = np.asarray([6, 990, 1980, 2000]) + expected_bval_centroids = np.asarray([6, 1980, 6, 990, 2000]) + obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol=tol) + + assert np.allclose(obtained_shells, expected_shells) + assert np.allclose(obtained_bval_centroids, expected_bval_centroids) + + bvals = np.asarray([0, 1010, 12, 990, 2000]) + tol = 60 + expected_shells = np.asarray([6, 1000, 2000]) + expected_bval_centroids = np.asarray([6, 1000, 6, 1000, 2000]) + obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol) + + assert np.allclose(obtained_shells, expected_shells) + assert np.allclose(obtained_bval_centroids, expected_bval_centroids)