Skip to content

Commit

Permalink
FIX: Add mask to fieldmap application
Browse files Browse the repository at this point in the history
  • Loading branch information
effigies committed Sep 26, 2023
1 parent 58ecbe7 commit 42fa25b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 4 deletions.
4 changes: 3 additions & 1 deletion sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec):
mandatory=True,
desc="input coefficients as calculated in the estimation stage",
)
fmap_mask = File(exist=True, desc="Mask used to calculate coefficients")
fmap2data_xfm = InputMultiObject(
File(exists=True),
desc="the transform by which the target EPI can be resampled on the fieldmap's grid.",
Expand Down Expand Up @@ -393,7 +394,8 @@ def _run_interface(self, runtime):

# Pre-cached interpolator object
unwarp = B0FieldTransform(
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff]
coeffs=[nb.load(cname) for cname in self.inputs.in_coeff],
mask=self.inputs.fmap_mask or None,
)

# We can now write out the fieldmap
Expand Down
15 changes: 12 additions & 3 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ class B0FieldTransform:

coeffs = attr.ib(default=None)
"""B-Spline coefficients (one value per control point)."""
mask = attr.ib(default=None)
"""B-Spline coefficients (one value per control point)."""
mapped = attr.ib(default=None, init=False)
"""
A cache of the interpolated field in Hz (i.e., the fieldmap *mapped* on to the
Expand Down Expand Up @@ -279,6 +281,8 @@ def fit(
``False`` if cache was valid and will be reused.
"""
from nitransforms.linear import Affine

# Calculate the physical coordinates of target grid
if isinstance(target_reference, (str, bytes, Path)):
target_reference = nb.load(target_reference)
Expand Down Expand Up @@ -319,6 +323,12 @@ def fit(
target_reference,
)

if self.mask is None:
mask = np.ones(projected_reference.shape, dtype=bool)
else:
mask_img = nb.load(self.mask)
mask = np.bool_(Affine(reference=projected_reference).apply(mask_img).dataobj)

# Generate tensor-product B-Spline weights
colmat = sparse_hstack(
[grid_bspline_weights(target_reference, level) for level in coeffs]
Expand All @@ -328,7 +338,8 @@ def fit(
)

# Reconstruct the fieldmap (in Hz) from coefficients
fmap = np.reshape(colmat @ coefficients, projected_reference.shape[:3])
fmap = np.zeros(projected_reference.shape[:3], dtype='float32')
fmap[mask] = colmat[mask.reshape(-1)] @ coefficients

# Generate a NIfTI object
hdr = target_reference.header.copy()
Expand All @@ -341,8 +352,6 @@ def fit(
self.mapped = nb.Nifti1Image(fmap, projected_reference.affine, hdr)

if approx:
from nitransforms.linear import Affine

_tmp_reference = nb.Nifti1Image(
np.zeros(
target_reference.shape[:3], dtype=target_reference.get_data_dtype()
Expand Down
2 changes: 2 additions & 0 deletions sdcflows/workflows/apply/correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def init_unwarp_wf(*, free_mem=None, omp_nthreads=1, debug=False, name="unwarp_w
"distorted",
"metadata",
"fmap_coeff",
"fmap_mask",
"fmap2data_xfm",
"data2fmap_xfm",
"hmc_xforms",
Expand Down Expand Up @@ -139,6 +140,7 @@ def init_unwarp_wf(*, free_mem=None, omp_nthreads=1, debug=False, name="unwarp_w
("metadata", "metadata")]),
(inputnode, resample, [("distorted", "in_data"),
("fmap_coeff", "in_coeff"),
("fmap_mask", "fmap_mask"),
("fmap2data_xfm", "fmap2data_xfm"),
("data2fmap_xfm", "data2fmap_xfm"),
("hmc_xforms", "in_xfms")]),
Expand Down

0 comments on commit 42fa25b

Please sign in to comment.