diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 9219ddf7fd..4125d5fe9b 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -298,6 +298,7 @@ class _ApplyCoeffsFieldInputSpec(BaseInterfaceInputSpec): mandatory=True, desc="input coefficients as calculated in the estimation stage", ) + fmap_mask = File(exist=True, desc="Mask used to calculate coefficients") fmap2data_xfm = InputMultiObject( File(exists=True), desc="the transform by which the target EPI can be resampled on the fieldmap's grid.", @@ -393,7 +394,8 @@ def _run_interface(self, runtime): # Pre-cached interpolator object unwarp = B0FieldTransform( - coeffs=[nb.load(cname) for cname in self.inputs.in_coeff] + coeffs=[nb.load(cname) for cname in self.inputs.in_coeff], + mask=self.inputs.fmap_mask or None, ) # We can now write out the fieldmap diff --git a/sdcflows/transform.py b/sdcflows/transform.py index 2be476ae21..4e866acfdc 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -233,6 +233,8 @@ class B0FieldTransform: coeffs = attr.ib(default=None) """B-Spline coefficients (one value per control point).""" + mask = attr.ib(default=None) + """B-Spline coefficients (one value per control point).""" mapped = attr.ib(default=None, init=False) """ A cache of the interpolated field in Hz (i.e., the fieldmap *mapped* on to the @@ -279,6 +281,8 @@ def fit( ``False`` if cache was valid and will be reused. """ + from nitransforms.linear import Affine + # Calculate the physical coordinates of target grid if isinstance(target_reference, (str, bytes, Path)): target_reference = nb.load(target_reference) @@ -319,6 +323,12 @@ def fit( target_reference, ) + if self.mask is None: + mask = np.ones(projected_reference.shape, dtype=bool) + else: + mask_img = nb.load(self.mask) + mask = np.bool_(Affine(reference=projected_reference).apply(mask_img).dataobj) + # Generate tensor-product B-Spline weights colmat = sparse_hstack( [grid_bspline_weights(target_reference, level) for level in coeffs] @@ -328,7 +338,8 @@ def fit( ) # Reconstruct the fieldmap (in Hz) from coefficients - fmap = np.reshape(colmat @ coefficients, projected_reference.shape[:3]) + fmap = np.zeros(projected_reference.shape[:3], dtype='float32') + fmap[mask] = colmat[mask.reshape(-1)] @ coefficients # Generate a NIfTI object hdr = target_reference.header.copy() @@ -341,8 +352,6 @@ def fit( self.mapped = nb.Nifti1Image(fmap, projected_reference.affine, hdr) if approx: - from nitransforms.linear import Affine - _tmp_reference = nb.Nifti1Image( np.zeros( target_reference.shape[:3], dtype=target_reference.get_data_dtype() diff --git a/sdcflows/workflows/apply/correction.py b/sdcflows/workflows/apply/correction.py index 1c5b53db1f..1186d22830 100644 --- a/sdcflows/workflows/apply/correction.py +++ b/sdcflows/workflows/apply/correction.py @@ -88,6 +88,7 @@ def init_unwarp_wf(*, free_mem=None, omp_nthreads=1, debug=False, name="unwarp_w "distorted", "metadata", "fmap_coeff", + "fmap_mask", "fmap2data_xfm", "data2fmap_xfm", "hmc_xforms", @@ -139,6 +140,7 @@ def init_unwarp_wf(*, free_mem=None, omp_nthreads=1, debug=False, name="unwarp_w ("metadata", "metadata")]), (inputnode, resample, [("distorted", "in_data"), ("fmap_coeff", "in_coeff"), + ("fmap_mask", "fmap_mask"), ("fmap2data_xfm", "fmap2data_xfm"), ("data2fmap_xfm", "data2fmap_xfm"), ("hmc_xforms", "in_xfms")]),