Skip to content

Commit

Permalink
Merge pull request #982 from EmmaRenauld/easy_tests_reconst
Browse files Browse the repository at this point in the history
Easy tests reconst
  • Loading branch information
arnaudbore authored Apr 26, 2024
2 parents 9b79ea4 + 3c87ad3 commit 35b40e1
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 77 deletions.
60 changes: 35 additions & 25 deletions scilpy/reconst/fodf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,23 @@
cvx, have_cvxpy, _ = optional_package("cvxpy")


def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, small_dims,
fa_threshold, md_threshold, is_legacy=True):
def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis,
fa_threshold, md_threshold,
small_dims=False, is_legacy=True):
"""
Compute mean maximal fodf value in ventricules. Given heuristics thresholds
on FA and MD values, finds the voxels of the ventricules or CSF and
computes a mean fODF value. This is described in
Dell'Acqua et al. HBM 2013.
Ventricles are searched in a window in the middle of the data to increase
speed. No need to scan the whole image.
Parameters
----------
data: ndarray (x, y, z, ncoeffs)
Input fODF file in spherical harmonics coefficients.
Input fODF file in spherical harmonics coefficients. Uses sphere
'repulsion100' to convert to SF values.
fa: ndarray (x, y, z)
FA (Fractional Anisotropy) volume from DTI
md: ndarray (x, y, z)
Expand All @@ -38,15 +43,16 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, small_dims,
Either 'tournier07' or 'descoteaux07'
small_dims: bool
If set, takes the full range of data to search the max fodf amplitude
in ventricles. Useful when the data has small dimensions.
in ventricles, rather than a window center in the data. Useful when the
data has small dimensions.
fa_threshold: float
Maximal threshold of FA (voxels under that threshold are considered
for evaluation).
for evaluation). Suggested value: 0.1.
md_threshold: float
Minimal threshold of MD in mm2/s (voxels above that threshold are
considered for evaluation).
considered for evaluation). Suggested value: 0.003.
is_legacy : bool, optional
Whether or not the SH basis is in its legacy form.
Whether the SH basis is in its legacy form.
Returns
-------
Expand All @@ -57,26 +63,17 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, small_dims,
order = find_order_from_nb_coeff(data)
sphere = get_sphere('repulsion100')
b_matrix, _ = sh_to_sf_matrix(sphere, order, sh_basis, legacy=is_legacy)
sum_of_max = 0
count = 0

mask = np.zeros(data.shape[:-1])

if np.min(data.shape[:-1]) > 40:
step = 20
else:
if np.min(data.shape[:-1]) > 20:
step = 10
else:
step = 5

# 1000 works well at 2x2x2 = 8 mm3
# Hence, we multiply by the volume of a voxel
vol = (zoom[0] * zoom[1] * zoom[2])
if vol != 0:
max_number_of_voxels = 1000 * 8 // vol
else:
max_number_of_voxels = 1000
logging.debug("Searching for ventricle voxels, up to a maximum of {} "
"voxels.".format(max_number_of_voxels))

# In the case of 2D-like data (3D data with one dimension size of 1), or
# a small 3D dataset, the full range of data is scanned.
Expand All @@ -85,14 +82,27 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, small_dims,
all_j = list(range(0, data.shape[1]))
all_k = list(range(0, data.shape[2]))
# In the case of a normal 3D dataset, a window is created in the middle of
# the image to capture the ventricules. No need to scan the whole image.
# the image to capture the ventricles. No need to scan the whole image.
# (Automatic definition of window's radius based on the shape of the data.)
else:
all_i = list(range(int(data.shape[0]/2) - step,
int(data.shape[0]/2) + step))
all_j = list(range(int(data.shape[1]/2) - step,
int(data.shape[1]/2) + step))
all_k = list(range(int(data.shape[2]/2) - step,
int(data.shape[2]/2) + step))
if np.min(data.shape[:-1]) > 40:
radius = 20
else:
if np.min(data.shape[:-1]) > 20:
radius = 10
else:
radius = 5

all_i = list(range(int(data.shape[0]/2) - radius,
int(data.shape[0]/2) + radius))
all_j = list(range(int(data.shape[1]/2) - radius,
int(data.shape[1]/2) + radius))
all_k = list(range(int(data.shape[2]/2) - radius,
int(data.shape[2]/2) + radius))

# Ok. Now find ventricle voxels.
sum_of_max = 0
count = 0
for i in all_i:
for j in all_j:
for k in all_k:
Expand Down
52 changes: 34 additions & 18 deletions scilpy/reconst/frf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
def compute_ssst_frf(data, bvals, bvecs, b0_threshold=DEFAULT_B0_THRESHOLD,
mask=None, mask_wm=None, fa_thresh=0.7, min_fa_thresh=0.5,
min_nvox=300, roi_radii=10, roi_center=None):
"""Compute a single-shell (under b=1500), single-tissue single Fiber
Response Function from a DWI volume.
A DTI fit is made, and voxels containing a single fiber population are
found using a threshold on the FA.
"""
Computes a single-shell (under b=1500), single-tissue single Fiber
Response Function from a DWI volume. A DTI fit is made, and voxels
containing a single fiber population are found using either a threshold on
the FA, inside a white matter mask.
Parameters
----------
Expand All @@ -43,7 +44,7 @@ def compute_ssst_frf(data, bvals, bvecs, b0_threshold=DEFAULT_B0_THRESHOLD,
3D mask with shape (X,Y,Z)
Binary white matter mask. Only the data inside this mask and above the
threshold defined by fa_thresh will be used to estimate the fiber
response function.
response function. If not given, all voxels inside `mask` will be used.
fa_thresh : float, optional
Use this threshold as the initial threshold to select single fiber
voxels. Defaults to 0.7
Expand All @@ -63,7 +64,7 @@ def compute_ssst_frf(data, bvals, bvecs, b0_threshold=DEFAULT_B0_THRESHOLD,
Returns
-------
full_reponse : ndarray
full_response : ndarray
Fiber Response Function, with shape (4,)
Raises
Expand Down Expand Up @@ -139,10 +140,11 @@ def compute_msmt_frf(data, bvals, bvecs, btens=None, data_dti=None,
fa_thr_wm=0.7, fa_thr_gm=0.2, fa_thr_csf=0.1,
md_thr_gm=0.0007, md_thr_csf=0.003, min_nvox=300,
roi_radii=10, roi_center=None, tol=20):
"""Compute a multi-shell, multi-tissue single Fiber
Response Function from a DWI volume.
A DTI fit is made, and voxels containing a single fiber population are
found using a threshold on the FA and MD.
"""
Computes a multi-shell, multi-tissue single Fiber Response Function from a
DWI volume. A DTI fit is made, and voxels containing a single fiber
population are found using a threshold on the FA and MD, inside a mask of
each tissue type.
Parameters
----------
Expand Down Expand Up @@ -304,33 +306,47 @@ def compute_msmt_frf(data, bvals, bvecs, btens=None, data_dti=None,
return responses, frf_masks


def replace_frf(old_frf, new_frf, no_factor):
def replace_frf(old_frf, new_frf, no_factor=False):
"""
Replace old_frf with new_frf
Replaces the 3 first values of old_frf with new_frf. Formats the new_frf
from a string value and verifies that the number of shells corresponds.
Parameters
----------
old_frf: np.ndarray
A loaded frf file, of shape (n, 4).
new_frf: tuple
The new frf, to be interpreted with a 10**-4 factor. Ex: (15,4,4)
A loaded frf file, of shape (N, 4), where N is the number of shells.
new_frf: str
The new frf, to be interpreted with a 10**-4 factor. Ex: 15,4,4. With
multishell: all values, concatenated into one string.
Ex: 15,4,4,13,5,5,12,5,5.
no_factor: bool
If true, the fiber response function is evaluated without the
10**-4 factor.
Returns
-------
response: np.ndarray
Formatted new frf, of shape (n, 4)
"""
old_frf = old_frf.T
new_frf = np.array(literal_eval(new_frf), dtype=np.float64)
if len(old_frf.shape) == 1: # When loading from one shell, we get (4, )
old_frf = old_frf[None, :]
old_nb_shells = old_frf.shape[0]
b0_mean = old_frf[:, 3]

new_frf = np.array(literal_eval(new_frf), dtype=np.float64)
if not no_factor:
new_frf *= 10 ** -4
b0_mean = old_frf[3]

if new_frf.shape[0] % 3 != 0:
raise ValueError('Inputed new frf is not valid. There should be '
'three values per shell, and thus the total number '
'of values should be a multiple of three.')

nb_shells = int(new_frf.shape[0] / 3)
if nb_shells != old_nb_shells:
raise ValueError("The old frf contained {} shell(s). Cannot replace "
"with {} shell(s).".format(old_nb_shells, nb_shells))

new_frf = new_frf.reshape((nb_shells, 3))

response = np.empty((nb_shells, 4))
Expand Down
36 changes: 34 additions & 2 deletions scilpy/reconst/tests/test_fodf.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,41 @@
# -*- coding: utf-8 -*-
import numpy as np
from dipy.data import get_sphere
from dipy.reconst.shm import sh_to_sf_matrix

from scilpy.reconst.fodf import get_ventricles_max_fodf
from scilpy.reconst.utils import find_order_from_nb_coeff
from scilpy.tests.arrays import fodf_3x3_order8_descoteaux07


def test_get_ventricles_max_fodf():
# toDO
pass
fake_fa = np.ones((3, 3, 1)) # High FA
fake_fa[1:3, 0:2, 0] = 0 # Low in ventricles
fake_md = np.zeros((3, 3, 1)) # Low MD
fake_md[0:2, 0:2, 0] = 1 # High in ventricles
zoom = [1, 1, 1]
fa_threshold = 0.5
md_threshold = 0.5
sh_basis = 'descoteaux07'

# Should find that the only 2 ventricle voxels are at [1, 0:2, 0]
mean, mask = get_ventricles_max_fodf(
fodf_3x3_order8_descoteaux07, fake_fa, fake_md, zoom, sh_basis,
fa_threshold, md_threshold, small_dims=True)

expected_mask = np.logical_and(~fake_fa.astype(bool), fake_md)
assert np.count_nonzero(mask) == 2
assert np.array_equal(mask.astype(bool), expected_mask)

# Reconstruct SF values same as in method.
order = find_order_from_nb_coeff(fodf_3x3_order8_descoteaux07)
sphere = get_sphere('repulsion100')
b_matrix, _ = sh_to_sf_matrix(sphere, order, sh_basis, legacy=True)

sf1 = np.dot(fodf_3x3_order8_descoteaux07[1, 0, 0], b_matrix)
sf2 = np.dot(fodf_3x3_order8_descoteaux07[1, 1, 0], b_matrix)

assert mean == np.mean([np.max(sf1), np.max(sf2)])


def test_fit_from_model():
Expand Down
75 changes: 69 additions & 6 deletions scilpy/reconst/tests/test_frf.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,79 @@
# -*- coding: utf-8 -*-
import os
import tempfile

import nibabel as nib
import numpy as np
from dipy.io import read_bvals_bvecs

from scilpy import SCILPY_HOME
from scilpy.io.fetcher import fetch_data, get_testing_files_dict
from scilpy.reconst.frf import compute_ssst_frf, compute_msmt_frf, replace_frf

# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['processing.zip'])
tmp_dir = tempfile.TemporaryDirectory()
in_dwi = os.path.join(SCILPY_HOME, 'processing', 'dwi_crop.nii.gz')
in_bval = os.path.join(SCILPY_HOME, 'processing', 'dwi.bval')
in_bvec = os.path.join(SCILPY_HOME, 'processing', 'dwi.bvec')


def test_compute_ssst_frf():
# toDO
pass
# Uses data from our test data.
# To use a smaller subset, we need to ensure that it has at least one
# voxel with FA higher than 0.7. Quite fast as is, so, ok.
dwi = nib.load(in_dwi).get_fdata() # Shape: 57, 67, 56, 64
bvals, bvecs = read_bvals_bvecs(in_bval, in_bvec)

result = compute_ssst_frf(dwi, bvals, bvecs)

# Value with current data at the date of test creation:
expected_result = [1.03068237e-03, 2.44994949e-04,
2.44994949e-04, 3.26903486e+03]
assert np.allclose(result, expected_result)


def test_compute_msmt_frf():
# toDO
pass
# Uses data from our test data.
# To use a smaller subset, we need to ensure that it has at least one
# voxel with each tissue type.
dwi = nib.load(in_dwi).get_fdata() # Shape: 57, 67, 56, 64
bvals, bvecs = read_bvals_bvecs(in_bval, in_bvec)

responses, masks = compute_msmt_frf(dwi, bvals, bvecs)

# Value with current data at the date of test creation:
expected_result_wm = [[1.56925332e-03, 4.68706503e-04,
4.68706503e-04, 3.26903486e+03],
[1.15181122e-03, 3.75303294e-04,
3.75303294e-04, 3.26903486e+03],
[8.61299793e-04, 3.14541494e-04,
3.14541494e-04, 3.26903486e+03]]
expected_result_gm = [[9.74471606e-04, 8.34628732e-04,
8.34628732e-04, 3.42007686e+03],
[7.76991313e-04, 6.89550835e-04,
6.89550835e-04, 3.42007686e+03],
[6.26617550e-04, 5.73389066e-04,
5.73389066e-04, 3.42007686e+03]]
expected_result_csf = [[9.33140592e-04, 8.31445917e-04,
8.31445917e-04, 3.62805637e+03],
[7.69894406e-04, 7.07255607e-04,
7.07255607e-04, 3.62805637e+03],
[6.34735398e-04, 5.96451860e-04,
5.96451860e-04, 3.62805637e+03]]
assert np.allclose(responses[0], expected_result_wm)
assert np.allclose(responses[1], expected_result_gm)
assert np.allclose(responses[2], expected_result_csf)

assert np.count_nonzero(masks[0]) == 845 # wm
assert np.count_nonzero(masks[1]) == 1779 # gm
assert np.count_nonzero(masks[2]) == 449 # csf


def test_replace_frf():
# toDo
pass
old_frf = np.random.rand(4)
new_frf = "15,4,4"
result = replace_frf(old_frf, new_frf, no_factor=True)

# Rounds to float64
assert np.allclose(result, [15, 4, 4, old_frf[-1]])
4 changes: 2 additions & 2 deletions scripts/scil_fodf_max_in_ventricles.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ def main():
sh_basis, is_legacy = parse_sh_basis_arg(args)

value, mask = get_ventricles_max_fodf(fodf, fa, md, zoom, sh_basis,
args.small_dims, args.fa_threshold,
args.md_threshold,
args.fa_threshold, args.md_threshold,
small_dims=args.small_dims,
is_legacy=is_legacy)

if args.mask_output:
Expand Down
9 changes: 3 additions & 6 deletions scripts/tests/test_frf_set_diffusivities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,23 @@ def test_help_option(script_runner):

def test_execution_processing_ssst(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_frf = os.path.join(SCILPY_HOME, 'processing',
'frf.txt')
in_frf = os.path.join(SCILPY_HOME, 'processing', 'frf.txt')
ret = script_runner.run('scil_frf_set_diffusivities.py', in_frf,
'15,4,4', 'new_frf.txt', '-f')
assert ret.success


def test_execution_processing_msmt(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_frf = os.path.join(SCILPY_HOME, 'commit_amico',
'wm_frf.txt')
in_frf = os.path.join(SCILPY_HOME, 'commit_amico', 'wm_frf.txt')
ret = script_runner.run('scil_frf_set_diffusivities.py', in_frf,
'15,4,4,13,4,4,12,5,5', 'new_frf.txt', '-f')
assert ret.success


def test_execution_processing__wrong_input(script_runner, monkeypatch):
monkeypatch.chdir(os.path.expanduser(tmp_dir.name))
in_frf = os.path.join(SCILPY_HOME, 'commit_amico',
'wm_frf.txt')
in_frf = os.path.join(SCILPY_HOME, 'commit_amico', 'wm_frf.txt')
ret = script_runner.run('scil_frf_set_diffusivities.py', in_frf,
'15,4,4,13,4,4', 'new_frf.txt', '-f')
assert not ret.success
Loading

0 comments on commit 35b40e1

Please sign in to comment.