Skip to content

Commit

Permalink
Merge pull request #734 from frheault/synb0_integration
Browse files Browse the repository at this point in the history
Working synb0 wrapper in scilpy
  • Loading branch information
arnaudbore authored Feb 22, 2024
2 parents 89ac8ce + bc117fb commit 594fbb7
Show file tree
Hide file tree
Showing 8 changed files with 267 additions and 12 deletions.
Binary file not shown.
48 changes: 43 additions & 5 deletions scilpy/image/volume_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,17 +194,55 @@ def transform_dwi(reg_obj, static, dwi, interpolation='linear'):


def register_image(static, static_grid2world, moving, moving_grid2world,
transformation_type='affine', dwi=None):
transformation_type='affine', dwi=None, fine=False):
"""
Register a moving image to a static image using either rigid or affine
transformations. If a DWI (4D) is provided, it applies the transformation
to each volume.
Parameters
----------
static : ndarray
The static image volume to which the moving image will be registered.
static_grid2world : ndarray
The grid-to-world (vox2ras) transformation associated with the static
image.
moving : ndarray
The moving image volume that needs to be registered to the static image.
moving_grid2world : ndarray
The grid-to-world (vox2ras) transformation associated with the moving
image.
transformation_type : str, optional
The type of transformation ('rigid' or 'affine'). Default is 'affine'.
dwi : ndarray, optional
Diffusion-weighted imaging data (if applicable). Default is None.
fine : bool, optional
Whether to use fine or coarse settings for the registration.
Default is False.
Raises
------
ValueError
If the transformation_type is neither 'rigid' nor 'affine'.
Returns
-------
ndarray or tuple
If `dwi` is None, returns transformed moving image and transformation
matrix.
If `dwi` is not None, returns transformed DWI and transformation matrix.
"""

if transformation_type not in ['rigid', 'affine']:
raise ValueError('Transformation type not available in Dipy')

# Set all parameters for registration
nbins = 32
nbins = 64 if fine else 32
params0 = None
sampling_prop = None
level_iters = [50, 25, 5]
sigmas = [8.0, 4.0, 2.0]
factors = [8, 4, 2]
level_iters = [250, 100, 50, 25] if fine else [50, 25, 5]
sigmas = [8.0, 4.0, 2.0, 1.0] if fine else [8.0, 4.0, 2.0]
factors = [8, 4, 2, 1.0] if fine else [8, 4, 2]
metric = MutualInformationMetric(nbins, sampling_prop)
reg_obj = AffineRegistration(metric=metric, level_iters=level_iters,
sigmas=sigmas, factors=factors, verbosity=0)
Expand Down
18 changes: 18 additions & 0 deletions scilpy/io/fetcher.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-

import inspect
import logging
import hashlib
import os
Expand Down Expand Up @@ -131,3 +132,20 @@ def fetch_data(files_dict, keys=None):
else:
# toDo. Verify that data on disk is the right one.
logging.warning("Not fetching data; already on disk.")


def get_synb0_template_path():
"""
Return MNI 2.5mm template in scilpy repository
Returns
-------
path: str
Template path
"""
import scilpy # ToDo. Is this the only way?
module_path = inspect.getfile(scilpy)
module_path = os.path.dirname(os.path.dirname(module_path))

path = os.path.join(module_path, 'data/',
'mni_icbm152_t1_tal_nlin_asym_09c_masked_2_5.nii.gz')
return path
12 changes: 10 additions & 2 deletions scilpy/preprocessing/distortion_correction.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# -*- coding: utf-8 -*-

import logging

import numpy as np


def create_acqparams(readout, encoding_direction, nb_b0s=1, nb_rev_b0s=1):
def create_acqparams(readout, encoding_direction, synb0=False,
nb_b0s=1, nb_rev_b0s=1):
"""
Create acqparams for Topup and Eddy
Expand All @@ -23,13 +26,18 @@ def create_acqparams(readout, encoding_direction, nb_b0s=1, nb_rev_b0s=1):
acqparams: np.array
acqparams
"""
if synb0:
logging.warning('Using SyNb0, untested feature. Be careful.')

acqparams = np.zeros((nb_b0s + nb_rev_b0s, 4))
acqparams[:, 3] = readout

enum_direction = {'x': 0, 'y': 1, 'z': 2}
acqparams[0:nb_b0s, enum_direction[encoding_direction]] = 1
if nb_rev_b0s > 0:
acqparams[nb_b0s:, enum_direction[encoding_direction]] = -1
val = -1 if not synb0 else 1
acqparams[nb_b0s:, enum_direction[encoding_direction]] = val
acqparams[nb_b0s:, 3] = readout if not synb0 else 0

return acqparams

Expand Down
5 changes: 2 additions & 3 deletions scilpy/tractanalysis/afd_along_streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ def afd_map_along_streamlines(sft, fodf, fodf_basis, length_weighting):
rdAFD map (weighted if length_weighting)
"""


afd_sum, rd_sum, weights = \
afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,
length_weighting)
Expand Down Expand Up @@ -112,7 +111,8 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,

normalization_weights = np.ones_like(seg_lengths)
if length_weighting:
normalization_weights = seg_lengths / np.linalg.norm(fodf.header.get_zooms()[:3])
normalization_weights = seg_lengths / \
np.linalg.norm(fodf.header.get_zooms()[:3])

for vox_idx, closest_vertex_index, norm_weight in zip(vox_indices,
closest_vertex_indices,
Expand All @@ -130,5 +130,4 @@ def afd_and_rd_sums_along_streamlines(sft, fodf, fodf_basis,
weight_map[vox_idx] += norm_weight

rd_sum_map[rd_sum_map < 0.] = 0.

return afd_sum_map, rd_sum_map, weight_map
7 changes: 5 additions & 2 deletions scripts/scil_dwi_prepare_topup_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def _build_arg_parser():
p.add_argument('--config', default='b02b0.cnf',
help='Topup config file [%(default)s].')

p.add_argument('--synb0', action='store_true',
help='If set, will use SyNb0 custom acqparams file.')

p.add_argument('--encoding_direction', default='y',
choices=['x', 'y', 'z'],
help='Acquisition direction of the forward b0 '
Expand Down Expand Up @@ -119,8 +122,8 @@ def main():
fused_b0s_path = os.path.join(args.out_directory, args.out_b0s)
nib.save(nib.Nifti1Image(fused_b0s, b0_img.affine), fused_b0s_path)

acqparams = create_acqparams(
args.readout, args.encoding_direction, b0.shape[-1], rev_b0.shape[-1])
acqparams = create_acqparams(args.readout, args.encoding_direction,
args.synb0, b0.shape[-1], rev_b0.shape[-1])

if not os.path.exists(args.out_directory):
os.makedirs(args.out_directory)
Expand Down
142 changes: 142 additions & 0 deletions scripts/scil_volume_b0_synthesis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Wrapper for SyNb0 available in Dipy, to run it on a single subject.
Requires Skull-Strip b0 and t1w images as input, the script will normalize the
t1w's WM to 110, co-register both images, then register it to the appropriate
template, run SyNb0 and then transform the result back to the original space.
This script must be used carefully, as it is not meant to be used in an
environment with the following dependencies already installed (not default
in Scilpy):
- tensorflow-addons
- tensorrt
- tensorflow
"""


import argparse
import logging
import os
import sys
import warnings

# Disable tensorflow warnings
with warnings.catch_warnings():
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
warnings.simplefilter("ignore")
from dipy.nn.synb0 import Synb0

from dipy.align.imaffine import AffineMap
from dipy.segment.tissue import TissueClassifierHMRF
import nibabel as nib
import numpy as np
from scipy.ndimage import gaussian_filter

from scilpy.io.fetcher import get_synb0_template_path
from scilpy.io.utils import (add_overwrite_arg,
add_verbose_arg,
assert_inputs_exist,
assert_outputs_exist)
from scilpy.image.volume_operations import register_image


def _build_arg_parser():
p = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawTextHelpFormatter)
p.add_argument('in_b0',
help='Input b0 image.')
p.add_argument('in_b0_mask',
help='Input b0 mask.')
p.add_argument('in_t1',
help='Input t1w image.')
p.add_argument('in_t1_mask',
help='Input t1w mask.')
p.add_argument('out_b0',
help='Output b0 image without distortion.')

add_verbose_arg(p)
add_overwrite_arg(p)

return p


def main():
parser = _build_arg_parser()
args = parser.parse_args()
assert_inputs_exist(parser, [args.in_b0, args.in_t1])
assert_outputs_exist(parser, args, args.out_b0)

logging.getLogger().setLevel(logging.getLevelName(args.verbose))
logging.info('The usage of synthetic b0 is not fully tested.'
'Be careful when using it.')

template_img = nib.load(get_synb0_template_path())
template_data = template_img.get_fdata()

b0_img = nib.load(args.in_b0)
b0_skull_data = b0_img.get_fdata()
b0_mask_img = nib.load(args.in_b0_mask)
b0_mask_data = b0_mask_img.get_fdata()

t1_img = nib.load(args.in_t1)
t1_skull_data = t1_img.get_fdata()
t1_mask_img = nib.load(args.in_t1_mask)
t1_mask_data = t1_mask_img.get_fdata()

b0_bet_data = np.zeros(b0_skull_data.shape)
b0_bet_data[b0_mask_data > 0] = b0_skull_data[b0_mask_data > 0]
t1_bet_data = np.zeros(t1_skull_data.shape)
t1_bet_data[t1_mask_data > 0] = t1_skull_data[t1_mask_data > 0]

# Crude estimation of the WM mean intensity and normalization
logging.info('Estimating WM mean intensity')
hmrf = TissueClassifierHMRF()
t1_bet_data = gaussian_filter(t1_bet_data, 2)
_, final_segmentation, _ = hmrf.classify(t1_bet_data, 3, 0.25,
tolerance=1e-4, max_iter=5)
avg_wm = np.mean(t1_skull_data[final_segmentation == 3])
t1_skull_data /= avg_wm
t1_skull_data *= 110

# SyNB0 works only in a standard space, so we need to register the images
logging.info('Registering images')
# Use the BET image for registration
t1_bet_to_b0, t1_bet_to_b0_transform = register_image(b0_bet_data,
b0_img.affine,
t1_bet_data,
t1_img.affine,
fine=True)
affine_map = AffineMap(t1_bet_to_b0_transform,
b0_skull_data.shape, b0_img.affine,
t1_skull_data.shape, t1_img.affine)
t1_skull_to_b0 = affine_map.transform(t1_skull_data.astype(np.float64))

# Then register to MNI (using the BET again)
_, t1_bet_to_b0_to_mni_transform = register_image(template_data,
template_img.affine,
t1_bet_to_b0,
b0_img.affine,
fine=True)
affine_map = AffineMap(t1_bet_to_b0_to_mni_transform,
template_data.shape, template_img.affine,
b0_skull_data.shape, b0_img.affine)

# But for prediction, we want the skull
b0_skull_to_mni = affine_map.transform(b0_skull_data.astype(np.float64))
t1_skull_to_mni = affine_map.transform(t1_skull_to_b0.astype(np.float64))

logging.info('Running SyN-B0')
SyNb0 = Synb0(args.verbose)
rev_b0 = SyNb0.predict(b0_skull_to_mni, t1_skull_to_mni)
rev_b0 = affine_map.transform_inverse(rev_b0.astype(np.float64))

dtype = b0_img.get_data_dtype()
nib.save(nib.Nifti1Image(rev_b0.astype(dtype), b0_img.affine),
args.out_b0)


if __name__ == "__main__":
main()
47 changes: 47 additions & 0 deletions scripts/tests/test_volume_b0_synthesis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from scilpy.io.fetcher import fetch_data, get_home, get_testing_files_dict
import os
import tempfile

import pytest
import nibabel as nib
import numpy as np
tensorflow = pytest.importorskip("tensorflow")


# If they already exist, this only takes 5 seconds (check md5sum)
fetch_data(get_testing_files_dict(), keys=['others.zip', 'processing.zip'])
tmp_dir = tempfile.TemporaryDirectory()


def test_help_option(script_runner):
ret = script_runner.run('scil_volume_b0_synthesis.py', '--help')
assert ret.success


@pytest.mark.skipif(tensorflow is None, reason="Tensorflow not installed")
def test_synthesis(script_runner):
os.chdir(os.path.expanduser(tmp_dir.name))
in_t1 = os.path.join(get_home(), 'others',
't1.nii.gz')
in_b0 = os.path.join(get_home(), 'processing',
'b0_mean.nii.gz')

t1_img = nib.load(in_t1)
b0_img = nib.load(in_b0)
t1_data = t1_img.get_fdata()
b0_data = b0_img.get_fdata()
t1_data[t1_data > 0] = 1
b0_data[b0_data > 0] = 1
nib.save(nib.Nifti1Image(t1_data.astype(np.uint8), t1_img.affine),
't1_mask.nii.gz')
nib.save(nib.Nifti1Image(b0_data.astype(np.uint8), b0_img.affine),
'b0_mask.nii.gz')

ret = script_runner.run('scil_volume_b0_synthesis.py',
in_t1, 't1_mask.nii.gz',
in_b0, 'b0_mask.nii.gz',
'b0_synthesized.nii.gz', '-v')
assert ret.success

0 comments on commit 594fbb7

Please sign in to comment.