-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add DWI shell detection utils. Add the corresponding tests.
- Loading branch information
1 parent
356431a
commit f971213
Showing
3 changed files
with
394 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- | ||
# vi: set ft=python sts=4 ts=4 sw=4 et: | ||
# | ||
# Copyright 2024 The NiPreps Developers <[email protected]> | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY kIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# We support and encourage derived works from this project, please read | ||
# about our expectations at | ||
# | ||
# https://www.nipreps.org/community/licensing/ | ||
# | ||
import numpy as np | ||
from dipy.core.gradients import get_bval_indices | ||
from sklearn.cluster import KMeans | ||
|
||
B0_THRESHOLD = 50 # from dmriprep | ||
SHELL_DIFF_THRES = 20 # 150 in dmriprep | ||
|
||
|
||
def extract_dwi_shell(dwi, bvals, bvecs, bvals_to_extract, tol=SHELL_DIFF_THRES): | ||
"""Extract the DWI volumes that are on the given b-value shells. Multiple | ||
shells can be extracted at once by specifying multiple b-values. The | ||
extracted volumes will be in the same order as in the original file. | ||
Parameters | ||
---------- | ||
dwi : nib.Nifti1Image | ||
Original DWI multi-shell volume. | ||
bvals : ndarray | ||
b-values in FSL format. | ||
bvecs : ndarray | ||
b-vectors in FSL format. | ||
bvals_to_extract : list of int | ||
List of b-values to extract. | ||
tol : int, optional | ||
Tolerance between the b-values to extract and the actual b-values. | ||
Returns | ||
------- | ||
indices : ndarray | ||
Indices of the volumes corresponding to the given ``bvals``. | ||
shell_data : ndarray | ||
Volumes corresponding to the given ``bvals``. | ||
output_bvals : ndarray | ||
Selected b-values (as extracted from ``bvals``). | ||
output_bvecs : ndarray | ||
Selected b-vectors. | ||
""" | ||
|
||
indices = [get_bval_indices(bvals, shell, tol=tol) for shell in bvals_to_extract] | ||
indices = np.unique(np.sort(np.hstack(indices))) | ||
|
||
if len(indices) == 0: | ||
raise ValueError( | ||
f"No DWI volumes found corresponding to the given b-values: {bvals_to_extract}" | ||
) | ||
|
||
shell_data = dwi.get_fdata()[..., indices] | ||
output_bvals = bvals[indices].astype(int) | ||
output_bvecs = bvecs[indices, :] | ||
|
||
return indices, shell_data, output_bvals, output_bvecs | ||
|
||
|
||
# ToDo | ||
# Long term: use DiffusionGradientTable from dmriprep, normalizing gradients, | ||
# etc. ? | ||
def find_shelling_scheme(bvals, tol=SHELL_DIFF_THRES): | ||
"""Find the shelling scheme on the given b-values: extract the b-value | ||
shells as the b-values centroids using k-means clustering. | ||
Parameters | ||
---------- | ||
bvals : :obj:`ndarray` | ||
b-values in FSL format. | ||
tol : :obj:`int`, optional | ||
Tolerance between the b-values and the centroids in the average squared | ||
distance sense. | ||
Returns | ||
------- | ||
shells : :obj:`ndarray` | ||
b-value shells. | ||
bval_centroids : :obj:`ndarray` | ||
Shell value corresponding to each value in ``bvals``. | ||
""" | ||
|
||
# Use kmeans to find the shelling scheme | ||
for k in range(1, len(np.unique(bvals)) + 1): | ||
kmeans_res = KMeans(n_clusters=k).fit(bvals.reshape(-1, 1)) | ||
# ToDo | ||
# The tolerance is not a very intuitive value, as it has to do with the | ||
# sum of squared distances across all samples to the centroids | ||
# (_inertia) | ||
# Alternatives: | ||
# - We could accept the number of clusters as a parameter and do | ||
# kmeans_res = KMeans(n_clusters=n_clusters) | ||
# Setting that to 3 in the last testing case, where tol = 60 is not | ||
# intuitive would give the expected 6, 1000, 2000 clusters. | ||
# Passes all tests. But maybe not tested corner cases | ||
# We could have both k and tol as optional parameters, set to None by | ||
# default to force the user set one | ||
# - Use get_bval_indices to get the cluster centroids and then | ||
# substitute the values in bvals with the corresponding values | ||
# indices = [get_bval_indices(bvals, shell, tol=tol) for shell in bvals_to_extract] | ||
# result = np.zeros_like(bvals) | ||
# for i, idx in enumerate(indices): | ||
# result[idx] = bvals_to_extract[i] | ||
|
||
if kmeans_res.inertia_ / len(bvals) < tol: | ||
break | ||
else: | ||
raise ValueError(f"bvals parsing failed: no shells found more than {tol} apart") | ||
|
||
# Convert the kclust labels to an array | ||
shells = kmeans_res.cluster_centers_ | ||
bval_centroids = np.zeros(bvals.shape) | ||
for i in range(shells.size): | ||
bval_centroids[kmeans_res.labels_ == i] = shells[i][0] | ||
|
||
return np.sort(np.squeeze(shells, axis=-1)), bval_centroids |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,262 @@ | ||
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- | ||
# vi: set ft=python sts=4 ts=4 sw=4 et: | ||
# | ||
# Copyright 2024 The NiPreps Developers <[email protected]> | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# We support and encourage derived works from this project, please read | ||
# about our expectations at | ||
# | ||
# https://www.nipreps.org/community/licensing/ | ||
# | ||
import nibabel as nib | ||
import numpy as np | ||
|
||
from eddymotion.model.gradient_utils import ( | ||
extract_dwi_shell, | ||
find_shelling_scheme, | ||
) | ||
|
||
|
||
def test_extract_dwi_shell(): | ||
# dMRI volume with 5 gradients | ||
bvals = np.asarray([0, 1980, 12, 990, 2000]) | ||
bval_count = len(bvals) | ||
vols_size = (10, 15, 20) | ||
dwi = np.ones((*vols_size, bval_count)) | ||
bvecs = np.ones((bval_count, 3)) | ||
# Set all i-th gradient dMRI volume data and bvecs values to i | ||
for i in range(bval_count): | ||
dwi[..., i] = i | ||
bvecs[i, :] = i | ||
dwi_img = nib.Nifti1Image(dwi, affine=np.eye(4)) | ||
|
||
bvals_to_extract = [0, 2000] | ||
tol = 15 | ||
|
||
expected_indices = np.asarray([0, 2, 4]) | ||
expected_shell_data = np.stack([i * np.ones(vols_size) for i in expected_indices], axis=-1) | ||
expected_shell_bvals = np.asarray([0, 12, 2000]) | ||
expected_shell_bvecs = np.asarray([[i] * 3 for i in expected_indices]) | ||
|
||
(obtained_indices, obtained_shell_data, obtained_shell_bvals, obtained_shell_bvecs) = ( | ||
extract_dwi_shell(dwi_img, bvals, bvecs, bvals_to_extract=bvals_to_extract, tol=tol) | ||
) | ||
|
||
assert np.array_equal(obtained_indices, expected_indices) | ||
assert np.array_equal(obtained_shell_data, expected_shell_data) | ||
assert np.array_equal(obtained_shell_bvals, expected_shell_bvals) | ||
assert np.array_equal(obtained_shell_bvecs, expected_shell_bvecs) | ||
|
||
bvals = np.asarray([0, 1010, 12, 990, 2000]) | ||
bval_count = len(bvals) | ||
vols_size = (10, 15, 20) | ||
dwi = np.ones((*vols_size, bval_count)) | ||
bvecs = np.ones((bval_count, 3)) | ||
# Set all i-th gradient dMRI volume data and bvecs values to i | ||
for i in range(bval_count): | ||
dwi[..., i] = i | ||
bvecs[i, :] = i | ||
dwi_img = nib.Nifti1Image(dwi, affine=np.eye(4)) | ||
|
||
bvals_to_extract = [0, 1000] | ||
tol = 20 | ||
|
||
expected_indices = np.asarray([0, 1, 2, 3]) | ||
expected_shell_data = np.stack([i * np.ones(vols_size) for i in expected_indices], axis=-1) | ||
expected_shell_bvals = np.asarray([0, 1010, 12, 990]) | ||
expected_shell_bvecs = np.asarray([[i] * 3 for i in expected_indices]) | ||
|
||
(obtained_indices, obtained_shell_data, obtained_shell_bvals, obtained_shell_bvecs) = ( | ||
extract_dwi_shell(dwi_img, bvals, bvecs, bvals_to_extract=bvals_to_extract, tol=tol) | ||
) | ||
|
||
assert np.array_equal(obtained_indices, expected_indices) | ||
assert np.array_equal(obtained_shell_data, expected_shell_data) | ||
assert np.array_equal(obtained_shell_bvals, expected_shell_bvals) | ||
assert np.array_equal(obtained_shell_bvecs, expected_shell_bvecs) | ||
|
||
|
||
def test_find_shelling_scheme(): | ||
tol = 20 | ||
bvals = np.asarray([0, 0]) | ||
expected_shells = np.asarray([0]) | ||
expected_bval_centroids = np.asarray([0, 0]) | ||
obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol=tol) | ||
|
||
assert np.array_equal(obtained_shells, expected_shells) | ||
assert np.array_equal(obtained_bval_centroids, expected_bval_centroids) | ||
|
||
bvals = np.asarray( | ||
[ | ||
5, | ||
300, | ||
300, | ||
300, | ||
300, | ||
300, | ||
305, | ||
1005, | ||
995, | ||
1000, | ||
1000, | ||
1005, | ||
1000, | ||
1000, | ||
1005, | ||
995, | ||
1000, | ||
1005, | ||
5, | ||
995, | ||
1000, | ||
1000, | ||
995, | ||
1005, | ||
995, | ||
1000, | ||
995, | ||
995, | ||
2005, | ||
2000, | ||
2005, | ||
2005, | ||
1995, | ||
2000, | ||
2005, | ||
2000, | ||
1995, | ||
2005, | ||
5, | ||
1995, | ||
2005, | ||
1995, | ||
1995, | ||
2005, | ||
2005, | ||
1995, | ||
2000, | ||
2000, | ||
2000, | ||
1995, | ||
2000, | ||
2000, | ||
2005, | ||
2005, | ||
1995, | ||
2005, | ||
2005, | ||
1990, | ||
1995, | ||
1995, | ||
1995, | ||
2005, | ||
2000, | ||
1990, | ||
2010, | ||
5, | ||
] | ||
) | ||
expected_shells = np.asarray([5.0, 300.83333333, 999.5, 2000.0]) | ||
expected_bval_centroids = [ | ||
5.0, | ||
300.83333333, | ||
300.83333333, | ||
300.83333333, | ||
300.83333333, | ||
300.83333333, | ||
300.83333333, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
5.0, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
999.5, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
5.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
2000.0, | ||
5.0, | ||
] | ||
obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol=tol) | ||
|
||
# ToDo | ||
# Giving a tolerance of 15 this fails because it finds 5 clusters | ||
assert np.allclose(obtained_shells, expected_shells) | ||
assert np.allclose(obtained_bval_centroids, expected_bval_centroids) | ||
|
||
bvals = np.asarray([0, 1980, 12, 990, 2000]) | ||
expected_shells = np.asarray([6, 990, 1980, 2000]) | ||
expected_bval_centroids = np.asarray([6, 1980, 6, 990, 2000]) | ||
obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol=tol) | ||
|
||
assert np.allclose(obtained_shells, expected_shells) | ||
assert np.allclose(obtained_bval_centroids, expected_bval_centroids) | ||
|
||
bvals = np.asarray([0, 1010, 12, 990, 2000]) | ||
tol = 60 | ||
expected_shells = np.asarray([6, 1000, 2000]) | ||
expected_bval_centroids = np.asarray([6, 1000, 6, 1000, 2000]) | ||
obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol) | ||
|
||
assert np.allclose(obtained_shells, expected_shells) | ||
assert np.allclose(obtained_bval_centroids, expected_bval_centroids) |