Skip to content

Commit

Permalink
Merge pull request #885 from EmmaRenauld/test_gradients
Browse files Browse the repository at this point in the history
Unit tests: gradients.bvec_bval_tools
  • Loading branch information
arnaudbore authored Feb 14, 2024
2 parents 1e9f814 + f0bc75a commit 6872d0d
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 131 deletions.
130 changes: 32 additions & 98 deletions scilpy/gradients/bvec_bval_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from dipy.core.gradients import get_bval_indices
import numpy as np

from scilpy.io.gradients import (save_gradient_sampling_fsl,
save_gradient_sampling_mrtrix)

DEFAULT_B0_THRESHOLD = 20


Expand All @@ -31,37 +28,30 @@ def is_normalized_bvecs(bvecs):
-------
True/False
"""

bvecs_norm = np.linalg.norm(bvecs, axis=1)
return np.all(np.logical_or(np.abs(bvecs_norm - 1) < 1e-3,
bvecs_norm == 0))


def normalize_bvecs(bvecs, filename=None):
def normalize_bvecs(bvecs):
"""
Normalize b-vectors
Parameters
----------
bvecs : (N, 3) array
input b-vectors (N, 3) array
filename : string
output filename where to save the normalized bvecs
Returns
-------
bvecs : (N, 3)
normalized b-vectors
"""

bvecs = bvecs.copy() # Avoid in-place modification.
bvecs_norm = np.linalg.norm(bvecs, axis=1)
idx = bvecs_norm != 0
bvecs[idx] /= bvecs_norm[idx, None]

if filename is not None:
logging.info('Saving new bvecs: {}'.format(filename))
np.savetxt(filename, np.transpose(bvecs), "%.8f")

return bvecs


Expand Down Expand Up @@ -120,76 +110,7 @@ def check_b0_threshold(
return b0_thr


def fsl2mrtrix(fsl_bval_filename, fsl_bvec_filename, mrtrix_filename):
"""
Convert a fsl dir_grad.bvec/.bval files to mrtrix encoding.b file.
Parameters
----------
fsl_bval_filename: str
path to input fsl bval file.
fsl_bvec_filename: str
path to input fsl bvec file.
mrtrix_filename : str
path to output mrtrix encoding.b file.
Returns
-------
"""

shells = np.loadtxt(fsl_bval_filename)
points = np.loadtxt(fsl_bvec_filename)
bvals = np.unique(shells).tolist()

# Remove .bval and .bvec if present
mrtrix_filename = mrtrix_filename.replace('.b', '')

if not points.shape[0] == 3:
points = points.transpose()
logging.warning('WARNING: Your bvecs seem transposed. ' +
'Transposing them.')

shell_idx = [int(np.where(bval == bvals)[0]) for bval in shells]
save_gradient_sampling_mrtrix(points,
shell_idx,
bvals,
mrtrix_filename + '.b')


def mrtrix2fsl(mrtrix_filename, fsl_filename):
"""
Convert a mrtrix encoding.b file to fsl dir_grad.bvec/.bval files.
Parameters
----------
mrtrix_filename : str
path to mrtrix encoding.b file.
fsl_bval_filename: str
path to the output fsl files. Files will be named
fsl_bval_filename.bval and fsl_bval_filename.bvec.
"""
# Remove .bval and .bvec if present
fsl_filename = fsl_filename.replace('.bval', '')
fsl_filename = fsl_filename.replace('.bvec', '')

mrtrix_b = np.loadtxt(mrtrix_filename)
if not len(mrtrix_b.shape) == 2 or not mrtrix_b.shape[1] == 4:
raise ValueError('mrtrix file must have 4 columns')

points = np.array([mrtrix_b[:, 0], mrtrix_b[:, 1], mrtrix_b[:, 2]])
shells = np.array(mrtrix_b[:, 3])

bvals = np.unique(shells).tolist()
shell_idx = [int(np.where(bval == bvals)[0]) for bval in shells]

save_gradient_sampling_fsl(points,
shell_idx,
bvals,
filename_bval=fsl_filename + '.bval',
filename_bvec=fsl_filename + '.bvec')


def identify_shells(bvals, threshold=40.0, roundCentroids=False, sort=False):
def identify_shells(bvals, tol=40.0, round_centroids=False, sort=False):
"""
Guessing the shells from the b-values. Returns the list of shells and, for
each b-value, the associated shell.
Expand All @@ -206,10 +127,11 @@ def identify_shells(bvals, threshold=40.0, roundCentroids=False, sort=False):
----------
bvals: array (N,)
Array of bvals
threshold: float
Limit value to consider that a b-value is on an existing shell. Above
this limit, the b-value is placed on a new shell.
roundCentroids: bool
tol: float
Limit difference to centroid to consider that a b-value is on an
existing shell. On or above this limit, the b-value is placed on a new
shell.
round_centroids: bool
If true will round shell values to the nearest 10.
sort: bool
Sort centroids and shell_indices associated.
Expand All @@ -229,7 +151,7 @@ def identify_shells(bvals, threshold=40.0, roundCentroids=False, sort=False):
bval_centroids = [bvals[0]]
for bval in bvals[1:]:
diffs = np.abs(np.asarray(bval_centroids, dtype=float) - bval)
if not len(np.where(diffs < threshold)[0]):
if not len(np.where(diffs < tol)[0]):
# Found no bval in bval centroids close enough to the current one.
# Create new centroid (i.e. new shell)
bval_centroids.append(bval)
Expand All @@ -241,15 +163,23 @@ def identify_shells(bvals, threshold=40.0, roundCentroids=False, sort=False):

shell_indices = np.argmin(np.abs(bvals_for_diffs - centroids), axis=1)

if roundCentroids:
if round_centroids:
centroids = np.round(centroids, decimals=-1)

# Ex: with bvals [0, 5], threshold 5, we get centroids 0, 5.
# Rounded, we get centroids 0, 0.
if len(np.unique(centroids)) != len(centroids):
logging.warning("With option to round the centroids to the "
"nearest 10, with tolerance {}, we get unclear "
"division of the shells. Use this data carefully."
.format(tol))

if sort:
sort_index = np.argsort(centroids)
sorted_centroids = np.zeros(centroids.shape)
sorted_indices = np.zeros(shell_indices.shape)
sorted_centroids = centroids[sort_index]

sorted_indices = np.zeros(shell_indices.shape, dtype=int)
for i in range(len(centroids)):
sorted_centroids[i] = centroids[sort_index[i]]
sorted_indices[shell_indices == i] = sort_index[i]
return sorted_centroids, sorted_indices

Expand Down Expand Up @@ -286,19 +216,22 @@ def flip_gradient_sampling(bvecs, axes, sampling_type):
Parameters
----------
bvecs: np.ndarray
Loaded bvecs. In the case 'mrtrix' the bvecs actually also contain the
bvals.
bvecs loaded directly, not re-formatted. Careful! Must respect the
format (not verified here).
axes: list of int
List of axes to flip (e.g. [0, 1])
List of axes to flip (e.g. [0, 1]). See str_to_axis_index.
sampling_type: str
Either 'mrtrix' or 'fsl'.
Either 'mrtrix': bvecs are of shape (N, 4) or
'fsl': bvecs are of shape (3, N)
Returns
-------
bvecs: np.array
The final bvecs.
"""
assert sampling_type in ['mrtrix', 'fsl']

bvecs = bvecs.copy() # Avoid in-place modification.
if sampling_type == 'mrtrix':
for axis in axes:
bvecs[:, axis] *= -1
Expand All @@ -315,12 +248,13 @@ def swap_gradient_axis(bvecs, final_order, sampling_type):
Parameters
----------
bvecs: np.array
Loaded bvecs. In the case 'mrtrix' the bvecs actually also contain the
bvals.
bvecs loaded directly, not re-formatted. Careful! Must respect the
format (not verified here).
final_order: new order
Final order (ex, 2 1 0)
sampling_type: str
Either 'mrtrix' or 'fsl'.
Either 'mrtrix': bvecs are of shape (N, 4) or
'fsl': bvecs are of shape (3, N)
Returns
-------
Expand Down
87 changes: 64 additions & 23 deletions scilpy/gradients/tests/test_bvec_bval_tools.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,93 @@
# -*- coding: utf-8 -*-
import numpy as np

from scilpy.gradients.bvec_bval_tools import round_bvals_to_shell
from scilpy.gradients.bvec_bval_tools import (
identify_shells, is_normalized_bvecs, flip_gradient_sampling,
normalize_bvecs, round_bvals_to_shell, str_to_axis_index,
swap_gradient_axis)

bvecs = np.asarray([[1.0, 1.0, 1.0],
[1.0, 0.0, 1.0],
[0.0, 1.0, 0.0],
[8.0, 1.0, 1.0]])


def test_is_normalized_bvecs():
# toDO
pass
assert not is_normalized_bvecs(bvecs)
assert is_normalized_bvecs(
bvecs / np.linalg.norm(bvecs, axis=1, keepdims=True))


def test_normalize_bvecs():
# toDo
pass
assert is_normalized_bvecs(normalize_bvecs(bvecs))


def test_check_b0_threshold():
# toDo
pass


def test_fsl2mrtrix():
# toDo
pass


def test_mrtrix2fsl():
# toDo
# toDo To be modified (see PR#867).
pass


def test_identify_shells():
# toDo
pass
def _subtest_identify_shells(bvals, threshold,
expected_raw_centroids, expected_raw_shells,
expected_round_sorted_centroids,
expected_round_sorted_shells):
bvals = np.asarray(bvals)

# 1) Not rounded, not sorted
c, s = identify_shells(bvals, threshold)
assert np.array_equal(c, expected_raw_centroids)
assert np.array_equal(s, expected_raw_shells)

# 2) Rounded, sorted
c, s = identify_shells(bvals, threshold, round_centroids=True,
sort=True)
assert np.array_equal(c, expected_round_sorted_centroids)
assert np.array_equal(s, expected_round_sorted_shells)

# Test 1. All easy. Over the limit for 0, 5, 15. Clear difference for
# 100, 2000.
_subtest_identify_shells(bvals=[0, 0, 5, 15, 2000, 100], threshold=50,
expected_raw_centroids=[0, 2000, 100],
expected_raw_shells=[0, 0, 0, 0, 1, 2],
expected_round_sorted_centroids=[0, 100, 2000],
expected_round_sorted_shells=[0, 0, 0, 0, 2, 1])

# Test 2. Threshold on the limit.
# Additional difficulty with option rounded: two shells with the same
# value, but a warning is printed. Should it raise an error?
_subtest_identify_shells(bvals=[0, 0, 5, 2000, 100], threshold=5,
expected_raw_centroids=[0, 5, 2000, 100],
expected_raw_shells=[0, 0, 1, 2, 3],
expected_round_sorted_centroids=[0, 0, 100, 2000],
expected_round_sorted_shells=[0, 0, 1, 3, 2])


def test_str_to_axis_index():
# Very simple, nothing to do
pass
assert str_to_axis_index('x') == 0
assert str_to_axis_index('y') == 1
assert str_to_axis_index('z') == 2
assert str_to_axis_index('v') is None


def test_flip_gradient_sampling():
# toDo
pass
fsl_bvecs = bvecs.T
b = flip_gradient_sampling(fsl_bvecs, axes=[0], sampling_type='fsl')
assert np.array_equal(b, np.asarray([[-1.0, 1.0, 1.0],
[-1.0, 0.0, 1.0],
[-0.0, 1.0, 0.0],
[-8.0, 1.0, 1.0]]).T)


def test_swap_gradient_axis():
# toDo
pass
fsl_bvecs = bvecs.T
final_order = [1, 0, 2]
b = swap_gradient_axis(fsl_bvecs, final_order, sampling_type='fsl')
assert np.array_equal(b, np.asarray([[1.0, 1.0, 1.0],
[0.0, 1.0, 1.0],
[1.0, 0.0, 0.0],
[1.0, 8.0, 1.0]]).T)


def test_round_bvals_to_shell():
Expand Down
4 changes: 2 additions & 2 deletions scilpy/image/volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ def compute_snr(dwi, bval, bvec, b0_thr, mask,
mask = get_data_as_mask(mask, dtype=bool)

if split_shells:
centroids, shell_indices = identify_shells(bval, threshold=40.0,
roundCentroids=False,
centroids, shell_indices = identify_shells(bval, tol=40.0,
round_centroids=False,
sort=False)
bval = centroids[shell_indices]

Expand Down
Loading

0 comments on commit 6872d0d

Please sign in to comment.