Skip to content

Commit

Permalink
ENH: Add DWI shell detection utils
Browse files Browse the repository at this point in the history
Add DWI shell detection utils.

Add the corresponding tests.
  • Loading branch information
jhlegarreta committed May 28, 2024
1 parent 356431a commit f971213
Show file tree
Hide file tree
Showing 3 changed files with 394 additions and 0 deletions.
132 changes: 132 additions & 0 deletions src/eddymotion/model/dmri_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2024 The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY kIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
import numpy as np
from dipy.core.gradients import get_bval_indices
from sklearn.cluster import KMeans

B0_THRESHOLD = 50 # from dmriprep
SHELL_DIFF_THRES = 20 # 150 in dmriprep


def extract_dwi_shell(dwi, bvals, bvecs, bvals_to_extract, tol=SHELL_DIFF_THRES):
"""Extract the DWI volumes that are on the given b-value shells. Multiple
shells can be extracted at once by specifying multiple b-values. The
extracted volumes will be in the same order as in the original file.
Parameters
----------
dwi : nib.Nifti1Image
Original DWI multi-shell volume.
bvals : ndarray
b-values in FSL format.
bvecs : ndarray
b-vectors in FSL format.
bvals_to_extract : list of int
List of b-values to extract.
tol : int, optional
Tolerance between the b-values to extract and the actual b-values.
Returns
-------
indices : ndarray
Indices of the volumes corresponding to the given ``bvals``.
shell_data : ndarray
Volumes corresponding to the given ``bvals``.
output_bvals : ndarray
Selected b-values (as extracted from ``bvals``).
output_bvecs : ndarray
Selected b-vectors.
"""

indices = [get_bval_indices(bvals, shell, tol=tol) for shell in bvals_to_extract]
indices = np.unique(np.sort(np.hstack(indices)))

if len(indices) == 0:
raise ValueError(
f"No DWI volumes found corresponding to the given b-values: {bvals_to_extract}"
)

shell_data = dwi.get_fdata()[..., indices]
output_bvals = bvals[indices].astype(int)
output_bvecs = bvecs[indices, :]

return indices, shell_data, output_bvals, output_bvecs


# ToDo
# Long term: use DiffusionGradientTable from dmriprep, normalizing gradients,
# etc. ?
def find_shelling_scheme(bvals, tol=SHELL_DIFF_THRES):
"""Find the shelling scheme on the given b-values: extract the b-value
shells as the b-values centroids using k-means clustering.
Parameters
----------
bvals : :obj:`ndarray`
b-values in FSL format.
tol : :obj:`int`, optional
Tolerance between the b-values and the centroids in the average squared
distance sense.
Returns
-------
shells : :obj:`ndarray`
b-value shells.
bval_centroids : :obj:`ndarray`
Shell value corresponding to each value in ``bvals``.
"""

# Use kmeans to find the shelling scheme
for k in range(1, len(np.unique(bvals)) + 1):
kmeans_res = KMeans(n_clusters=k).fit(bvals.reshape(-1, 1))
# ToDo
# The tolerance is not a very intuitive value, as it has to do with the
# sum of squared distances across all samples to the centroids
# (_inertia)
# Alternatives:
# - We could accept the number of clusters as a parameter and do
# kmeans_res = KMeans(n_clusters=n_clusters)
# Setting that to 3 in the last testing case, where tol = 60 is not
# intuitive would give the expected 6, 1000, 2000 clusters.
# Passes all tests. But maybe not tested corner cases
# We could have both k and tol as optional parameters, set to None by
# default to force the user set one
# - Use get_bval_indices to get the cluster centroids and then
# substitute the values in bvals with the corresponding values
# indices = [get_bval_indices(bvals, shell, tol=tol) for shell in bvals_to_extract]
# result = np.zeros_like(bvals)
# for i, idx in enumerate(indices):
# result[idx] = bvals_to_extract[i]

if kmeans_res.inertia_ / len(bvals) < tol:
break
else:
raise ValueError(f"bvals parsing failed: no shells found more than {tol} apart")

# Convert the kclust labels to an array
shells = kmeans_res.cluster_centers_
bval_centroids = np.zeros(bvals.shape)
for i in range(shells.size):
bval_centroids[kmeans_res.labels_ == i] = shells[i][0]

return np.sort(np.squeeze(shells, axis=-1)), bval_centroids
Empty file.
262 changes: 262 additions & 0 deletions src/eddymotion/model/tests/test_dmri_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
#
# Copyright 2024 The NiPreps Developers <[email protected]>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# We support and encourage derived works from this project, please read
# about our expectations at
#
# https://www.nipreps.org/community/licensing/
#
import nibabel as nib
import numpy as np

from eddymotion.model.gradient_utils import (
extract_dwi_shell,
find_shelling_scheme,
)


def test_extract_dwi_shell():
# dMRI volume with 5 gradients
bvals = np.asarray([0, 1980, 12, 990, 2000])
bval_count = len(bvals)
vols_size = (10, 15, 20)
dwi = np.ones((*vols_size, bval_count))
bvecs = np.ones((bval_count, 3))
# Set all i-th gradient dMRI volume data and bvecs values to i
for i in range(bval_count):
dwi[..., i] = i
bvecs[i, :] = i
dwi_img = nib.Nifti1Image(dwi, affine=np.eye(4))

bvals_to_extract = [0, 2000]
tol = 15

expected_indices = np.asarray([0, 2, 4])
expected_shell_data = np.stack([i * np.ones(vols_size) for i in expected_indices], axis=-1)
expected_shell_bvals = np.asarray([0, 12, 2000])
expected_shell_bvecs = np.asarray([[i] * 3 for i in expected_indices])

(obtained_indices, obtained_shell_data, obtained_shell_bvals, obtained_shell_bvecs) = (
extract_dwi_shell(dwi_img, bvals, bvecs, bvals_to_extract=bvals_to_extract, tol=tol)
)

assert np.array_equal(obtained_indices, expected_indices)
assert np.array_equal(obtained_shell_data, expected_shell_data)
assert np.array_equal(obtained_shell_bvals, expected_shell_bvals)
assert np.array_equal(obtained_shell_bvecs, expected_shell_bvecs)

bvals = np.asarray([0, 1010, 12, 990, 2000])
bval_count = len(bvals)
vols_size = (10, 15, 20)
dwi = np.ones((*vols_size, bval_count))
bvecs = np.ones((bval_count, 3))
# Set all i-th gradient dMRI volume data and bvecs values to i
for i in range(bval_count):
dwi[..., i] = i
bvecs[i, :] = i
dwi_img = nib.Nifti1Image(dwi, affine=np.eye(4))

bvals_to_extract = [0, 1000]
tol = 20

expected_indices = np.asarray([0, 1, 2, 3])
expected_shell_data = np.stack([i * np.ones(vols_size) for i in expected_indices], axis=-1)
expected_shell_bvals = np.asarray([0, 1010, 12, 990])
expected_shell_bvecs = np.asarray([[i] * 3 for i in expected_indices])

(obtained_indices, obtained_shell_data, obtained_shell_bvals, obtained_shell_bvecs) = (
extract_dwi_shell(dwi_img, bvals, bvecs, bvals_to_extract=bvals_to_extract, tol=tol)
)

assert np.array_equal(obtained_indices, expected_indices)
assert np.array_equal(obtained_shell_data, expected_shell_data)
assert np.array_equal(obtained_shell_bvals, expected_shell_bvals)
assert np.array_equal(obtained_shell_bvecs, expected_shell_bvecs)


def test_find_shelling_scheme():
tol = 20
bvals = np.asarray([0, 0])
expected_shells = np.asarray([0])
expected_bval_centroids = np.asarray([0, 0])
obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol=tol)

assert np.array_equal(obtained_shells, expected_shells)
assert np.array_equal(obtained_bval_centroids, expected_bval_centroids)

bvals = np.asarray(
[
5,
300,
300,
300,
300,
300,
305,
1005,
995,
1000,
1000,
1005,
1000,
1000,
1005,
995,
1000,
1005,
5,
995,
1000,
1000,
995,
1005,
995,
1000,
995,
995,
2005,
2000,
2005,
2005,
1995,
2000,
2005,
2000,
1995,
2005,
5,
1995,
2005,
1995,
1995,
2005,
2005,
1995,
2000,
2000,
2000,
1995,
2000,
2000,
2005,
2005,
1995,
2005,
2005,
1990,
1995,
1995,
1995,
2005,
2000,
1990,
2010,
5,
]
)
expected_shells = np.asarray([5.0, 300.83333333, 999.5, 2000.0])
expected_bval_centroids = [
5.0,
300.83333333,
300.83333333,
300.83333333,
300.83333333,
300.83333333,
300.83333333,
999.5,
999.5,
999.5,
999.5,
999.5,
999.5,
999.5,
999.5,
999.5,
999.5,
999.5,
5.0,
999.5,
999.5,
999.5,
999.5,
999.5,
999.5,
999.5,
999.5,
999.5,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
5.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
2000.0,
5.0,
]
obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol=tol)

# ToDo
# Giving a tolerance of 15 this fails because it finds 5 clusters
assert np.allclose(obtained_shells, expected_shells)
assert np.allclose(obtained_bval_centroids, expected_bval_centroids)

bvals = np.asarray([0, 1980, 12, 990, 2000])
expected_shells = np.asarray([6, 990, 1980, 2000])
expected_bval_centroids = np.asarray([6, 1980, 6, 990, 2000])
obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol=tol)

assert np.allclose(obtained_shells, expected_shells)
assert np.allclose(obtained_bval_centroids, expected_bval_centroids)

bvals = np.asarray([0, 1010, 12, 990, 2000])
tol = 60
expected_shells = np.asarray([6, 1000, 2000])
expected_bval_centroids = np.asarray([6, 1000, 6, 1000, 2000])
obtained_shells, obtained_bval_centroids = find_shelling_scheme(bvals, tol)

assert np.allclose(obtained_shells, expected_shells)
assert np.allclose(obtained_bval_centroids, expected_bval_centroids)

0 comments on commit f971213

Please sign in to comment.