Skip to content

Commit

Permalink
Merge pull request scilus#918 from EmmaRenauld/unit_tests_dwi
Browse files Browse the repository at this point in the history
Unit tests dwi
  • Loading branch information
arnaudbore authored Feb 27, 2024
2 parents f79ffa3 + 73646bd commit 4c2848c
Show file tree
Hide file tree
Showing 7 changed files with 278 additions and 67 deletions.
162 changes: 117 additions & 45 deletions scilpy/dwi/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,25 @@
import numpy as np

from scilpy.gradients.bvec_bval_tools import identify_shells, \
round_bvals_to_shell, DEFAULT_B0_THRESHOLD
round_bvals_to_shell, DEFAULT_B0_THRESHOLD, is_normalized_bvecs, \
normalize_bvecs


def apply_bias_field(dwi_data, bias_field_data, mask_data):
"""
ToDo: Explain formula why applying field = dividing?
+ why we need to rescale after?
To apply a bias field (computed beforehands), we need to
1) Divide the dwi by the bias field. This is the correction itself.
See the following references:
https://simpleitk.readthedocs.io/en/master/link_N4BiasFieldCorrection_docs.html
https://mrtrix.readthedocs.io/en/dev/reference/commands/dwibiascorrect.html
2) Rescale the dwi, to ensure that the initial min-max range is kept.
Parameters
----------
dwi_data: np.ndarray
The 4D dwi data.
bias_field_data: np.ndarray
The 3D bias field.
The 3D bias field. Typically comes from ANTS'S N4BiasFieldCorrection.
mask_data: np.ndarray
The mask where to apply the bias field.
Expand All @@ -28,8 +33,7 @@ def apply_bias_field(dwi_data, bias_field_data, mask_data):
The modified 4D dwi_data.
"""
nuc_dwi_data = np.divide(
dwi_data[mask_data],
bias_field_data[mask_data].reshape((len(mask_data[0]), 1)))
dwi_data[mask_data, :], bias_field_data[mask_data][:, None])

rescaled_nuc_data = _rescale_dwi(dwi_data[mask_data], nuc_dwi_data)
dwi_data[mask_data] = rescaled_nuc_data
Expand All @@ -47,7 +51,7 @@ def _rescale_intensity(val, slope, in_max, bc_max):
----------
val: float
Value to be scaled
scale: float
slope: float
Scaling factor to be applied
in_max: float
Max possible value
Expand Down Expand Up @@ -94,12 +98,12 @@ def _rescale_dwi(in_data, bc_data):

chunk = np.arange(0, len(in_data), 100000)
chunk = np.append(chunk, len(in_data))
for i in range(len(chunk)-1):
nz_bc_data = bc_data[chunk[i]:chunk[i+1]]
for i in range(len(chunk) - 1):
nz_bc_data = bc_data[chunk[i]:chunk[i + 1]]
rescale_func = np.vectorize(_rescale_intensity, otypes=[np.float32])

rescaled_data = rescale_func(nz_bc_data, slope, in_max, bc_max)
bc_data[chunk[i]:chunk[i+1]] = rescaled_data
bc_data[chunk[i]:chunk[i + 1]] = rescaled_data

return bc_data

Expand All @@ -119,104 +123,172 @@ def compute_dwi_attenuation(dwi_weights: np.ndarray, b0: np.ndarray):
dwi_attenuation : np.ndarray
Signal attenuation (Diffusion weights normalized by the B0).
"""
# Avoid division by 0. Remember coordinates where the b0 was 0. We will set
# those voxels to 0 in the final result.
zeros_mask = b0 == 0

b0 = b0[..., None] # Easier to work if it is a 4D array.

# Make sure that, in every voxels, weights are lower in the b0. Should
# always be the case, but with the noise we never know!
# Make sure that, in every voxel, weights are lower in the dwi than in the
# b0. Should always be the case, but with the noise we never know!
erroneous_voxels = np.any(dwi_weights > b0, axis=3)
nb_erroneous_voxels = np.sum(erroneous_voxels)
if nb_erroneous_voxels != 0:
logging.info("# of voxels where `dwi_signal > b0` in any direction: "
"{}".format(nb_erroneous_voxels))
"{}. They were set to the b0 value to allow computing "
"signal attenuation."
.format(nb_erroneous_voxels))
dwi_weights = np.minimum(dwi_weights, b0)

# Compute attenuation
b0[zeros_mask] = 1e-10
dwi_attenuation = dwi_weights / b0

# Make sure we didn't divide by 0.
dwi_attenuation[np.logical_not(np.isfinite(dwi_attenuation))] = 0.
dwi_attenuation *= ~zeros_mask[:, :, :, None]

return dwi_attenuation


def detect_volume_outliers(data, bvecs, bvals, std_scale,
def detect_volume_outliers(data, bvals, bvecs, std_scale,
b0_thr=DEFAULT_B0_THRESHOLD):
"""
Detects outliers. Finds the 3 closest angular neighbors of each direction
(per shell) and computes the voxel-wise correlation.
If the angles or correlations to neighbors are below the shell average (by
std_scale x STD) it will flag the volume as a potential outlier.
Parameters
----------
data: np.ndarray
The dwi data.
bvecs: np.ndarray
The bvecs
bvals: np.array
The b-values vector.
4D Input diffusion volume with shape (X, Y, Z, N)
bvals : ndarray
1D bvals array with shape (N,)
bvecs : ndarray
2D bvecs array with shape (N, 3)
std_scale: float
How many deviation from the mean are required to be considered an
outlier.
b0_thr: float
Value below which b-values are considered as b0.
Returns
-------
results_dict: dict
The resulting statistics.
One key per shell (its b-value). For each key, the associated entry is
an array of shape [nb_points, 3] where columns are:
- point_idx: int, the index of the bvector in the input bvecs.
- mean_angle: float, the mean angles of the 3 closest bvecs, in
degree
- mean_correlation: float, the mean correlation of the 3D data
associated to the 3 closest bvecs.
outliers_dict: dict
The resulting outliers.
One key per shell (its b-value). For each key, the associated entry is
a dict {'outliers_angle': list[int],
'outliers_corr': list[int]}
The indices of outliers (indices in the original bvecs).
"""
if not is_normalized_bvecs(bvecs):
logging.warning("Your b-vectors do not seem normalized... Normalizing")
bvecs = normalize_bvecs(bvecs)

results_dict = {}
shells_to_extract = identify_shells(bvals, b0_thr, sort=True)[0]
bvals = round_bvals_to_shell(bvals, shells_to_extract)
for bval in shells_to_extract[shells_to_extract > b0_thr]:
shell_idx = np.where(bvals == bval)[0]
shell = bvecs[shell_idx]

# Requires at least 3 values per shell to find 3 closest values!
# Requires at least 5 values to use argpartition, below.
if len(shell_idx) < 5:
raise NotImplementedError(
"This outlier detection method is only available with at "
"least 5 points per shell. Got {} on shell {}."
.format(len(shell_idx), bval))

shell = bvecs[shell_idx, :] # All bvecs on that shell
results_dict[bval] = np.ones((len(shell), 3)) * -1
for i, vec in enumerate(shell):
if np.linalg.norm(vec) < 0.001:
continue

# Supposing that vectors are normalized, cos(angle) = dot
dot_product = np.clip(np.tensordot(shell, vec, axes=1), -1, 1)
angle = np.arccos(dot_product) * 180 / math.pi
angle[np.isnan(angle)] = 0
idx = np.argpartition(angle, 4).tolist()
angles = np.rad2deg(np.arccos(dot_product))
angles[np.isnan(angles)] = 0

# Managing the symmetry between b-vectors:
# if angle is > 90, it becomes 180 - x
big_angles = angles > 90
angles[big_angles] = 180 - angles[big_angles]

# Using argpartition rather than sort; faster. With kth=4, the 4th
# element is correctly positioned, and smaller elements are
# placed before. Considering that we will then remove the b-vec
# itself (angle 0), we are left with the 3 closest angles in
# idx[0:3] (not necessarily sorted, but ok).
idx = np.argpartition(angles, 4).tolist()
idx.remove(i)

avg_angle = np.average(angle[idx[:3]])
avg_angle = np.average(angles[idx[:3]])

corr = np.corrcoef([data[..., shell_idx[i]].ravel(),
data[..., shell_idx[idx[0]]].ravel(),
data[..., shell_idx[idx[1]]].ravel(),
data[..., shell_idx[idx[2]]].ravel()])
# Corr is a triangular matrix. The interesting line is the first:
# current data vs the 3 others. First value is with itself = 1.
results_dict[bval][i] = [shell_idx[i], avg_angle,
np.average(corr[0, 1:])]

# Computation done. Now verifying if above scale.
# Loop on shells:
logging.info("Analysing, for each bvec, the mean angle of the 3 closest "
"bvecs, and the mean correlation of their associated data.")
outliers_dict = {}
for key in results_dict.keys():
avg_angle = np.round(np.average(results_dict[key][:, 1]), 4)
std_angle = np.round(np.std(results_dict[key][:, 1]), 4)
# Column #1 = The mean_angle for all bvecs
avg_angle = np.average(results_dict[key][:, 1])
std_angle = np.std(results_dict[key][:, 1])

avg_corr = np.round(np.average(results_dict[key][:, 2]), 4)
std_corr = np.round(np.std(results_dict[key][:, 2]), 4)
# Column #2 = The mean_corr for all bvecs
avg_corr = np.average(results_dict[key][:, 2])
std_corr = np.std(results_dict[key][:, 2])

# Only looking if some data are *below* the average - n*std.
outliers_angle = np.argwhere(
results_dict[key][:, 1] < avg_angle - (std_scale * std_angle))
outliers_corr = np.argwhere(
results_dict[key][:, 2] < avg_corr - (std_scale * std_corr))

logging.info('Results for shell {} with {} directions:'
.format(key, len(results_dict[key])))
logging.info('AVG and STD of angles: {} +/- {}'
logging.info('AVG and STD of angles: {:.2f} +/- {:.2f}'
.format(avg_angle, std_angle))
logging.info('AVG and STD of correlations: {} +/- {}'
logging.info('AVG and STD of correlations: {:.4f} +/- {:.4f}'
.format(avg_corr, std_corr))

if len(outliers_angle) or len(outliers_corr):
logging.info('Possible outliers ({} STD below or above average):'
.format(std_scale))
logging.info('Outliers based on angle [position (4D), value]')
for i in outliers_angle:
logging.info(results_dict[key][i, :][0][0:2])
logging.info('Outliers based on correlation [position (4D), ' +
'value]')
for i in outliers_corr:
logging.info(results_dict[key][i, :][0][0::2])
logging.info('Possible outliers ({} STD below average):'
.format(std_scale))
if len(outliers_angle):
logging.info('Outliers based on angle [position (4D), value]')
for i in outliers_angle:
logging.info(" {}".format(results_dict[key][i, 0::2]))
if len(outliers_corr):
logging.info('Outliers based on correlation [position (4D), '
'value]')
for i in outliers_corr:
logging.info(" {}".format(results_dict[key][i, 0::2]))
else:
logging.info('No outliers detected.')

outliers_dict[key] = {
'outliers_angle': results_dict[key][outliers_angle, 0],
'outliers_corr': results_dict[key][outliers_corr, 0]}
logging.debug('Shell with b-value {}'.format(key))
logging.debug("\n" + pprint.pformat(results_dict[key]))
print()

return results_dict, outliers_dict


def compute_residuals(predicted_data, real_data, b0s_mask=None, mask=None):
"""
Expand Down
80 changes: 77 additions & 3 deletions scilpy/dwi/tests/test_operations.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,93 @@
# -*- coding: utf-8 -*-
import numpy as np

from scilpy.dwi.operations import compute_dwi_attenuation, \
detect_volume_outliers, apply_bias_field


def test_apply_bias_field():
pass

# DWI is 1 everywhere, one voxel at 0.
dwi = np.ones((10, 10, 10, 5))
dwi[0, 0, 0, :] = 0
mask = np.ones((10, 10, 10), dtype=bool)

# bias field is 2 everywhere
bias_field = 2 * np.ones((10, 10, 10))

# result should be 1/2 everywhere, one voxel at 0. Rescaled to 0-1.
out_dwi = apply_bias_field(dwi, bias_field, mask)
assert np.max(out_dwi) == 1
assert out_dwi[0, 0, 0, 0] == 0


def test_compute_dwi_attenuation():
pass
fake_b0 = np.ones((10, 10, 10))
fake_dwi = np.ones((10, 10, 10, 4)) * 0.5

# Test 1: attenuation of 0.5 / 1 = 0.5 everywhere
res = compute_dwi_attenuation(fake_dwi, fake_b0)
expected = np.ones((10, 10, 10, 4)) * 0.5
assert np.array_equal(res, expected)

# Test 2: noisy data: one voxel is not attenuated, and has a value > b0 for
# one gradient. Should give attenuation=1.
fake_dwi[2, 2, 2, 2] = 2
expected[2, 2, 2, 2] = 1

# + Test 3: a 0 in the b0. Can divide correctly?
fake_b0[4, 4, 4] = 0
expected[4, 4, 4, :] = 0

res = compute_dwi_attenuation(fake_dwi, fake_b0)
assert np.array_equal(res, expected)


def test_detect_volume_outliers():
pass
# For this test: all 90 or 180 degrees on one shell.
bvals = 1000 * np.ones(5)
bvecs = np.asarray([[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[-1, 0, 0], # inverse of first
[0, -1, 0]]) # inverse of second

# DWI associated with the last bvec is very different. Others are highly
# correlated (but not equal, or the correlation is NaN: one voxel
# different). One voxel different for the first 4 gradients. Random for
# the last.
dwi = np.ones((10, 10, 10, 5))
dwi[0, 0, 0, 0:4] = np.random.rand(4)
dwi[..., -1] = np.random.rand(10, 10, 10)

res, outliers = detect_volume_outliers(dwi, bvals, bvecs, std_scale=1)

# Should get one shell
keys = list(res.keys())
assert len(keys) == 1
assert keys[0] == 1000
res = res[1000]
outliers = outliers[1000]

# Should get a table 5x3.
assert np.array_equal(res.shape, [5, 3])

# First column: index of the bvecs. They should all be managed.
assert np.array_equal(np.sort(res[:, 0]), np.arange(5))

# Second column = Mean angle. The most different should be the 3rd (#2)
# But not an outlier.
assert np.argmax(res[:, 1]) == 2
assert len(outliers['outliers_angle']) == 0

# Thirst column = corr. The most uncorrelated should be the 5th (#4)
# Should also be an outlier with STD 1
assert np.argmin(res[:, 2]) == 4
assert outliers['outliers_corr'][0] == 4


def test_compute_residuals():
# Quite simple. Not testing.
pass


Expand Down
Loading

0 comments on commit 4c2848c

Please sign in to comment.