diff --git a/sdcflows/tests/test_transform.py b/sdcflows/tests/test_transform.py index 865d3c4cad..72851876be 100644 --- a/sdcflows/tests/test_transform.py +++ b/sdcflows/tests/test_transform.py @@ -348,7 +348,16 @@ def test_grid_bspline_weights(): nb.Nifti1Image(np.zeros(target_shape), target_aff), nb.Nifti1Image(np.zeros(ctrl_shape), ctrl_aff), ).tocsr() - assert weights.shape == (64, 1000) + assert weights.shape == (np.prod(np.array(ctrl_shape) + 2), np.prod(target_shape)) + + # Calculate the legacy mask to index the weights below + legacy_mask = np.pad( + np.ones(ctrl_shape, dtype=bool), + ((1, 1), (1, 1), (1, 1)), + ).reshape(-1) + + # Drop scipy's padding + weights = weights[legacy_mask] # Empirically determined numbers intended to indicate that something # significant has changed. If it turns out we've been doing this wrong, # these numbers will probably change. diff --git a/sdcflows/transform.py b/sdcflows/transform.py index 7247b97c2f..a56e5cfdad 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -56,8 +56,8 @@ import numpy as np from warnings import warn from scipy import ndimage as ndi -from scipy.signal import cubic -from scipy.sparse import vstack as sparse_vstack, kron, lil_array +from scipy.interpolate import BSpline +from scipy.sparse import vstack as sparse_vstack, kron import nibabel as nb import nitransforms as nt @@ -325,7 +325,12 @@ def fit( for level in coeffs: wmat = grid_bspline_weights(target_reference, level) weights.append(wmat) - coeffs_data.append(level.get_fdata(dtype="float32").reshape(-1)) + + # Coefficients must be zero-padded because Scipy's BSpline pads the knots. + coeffs_data.append(np.pad( + level.get_fdata(dtype="float32"), + ((1, 1), (1, 1), (1, 1)), + ).reshape(-1)) # Reconstruct the fieldmap (in Hz) from coefficients fmap = np.zeros(projected_reference.shape[:3], dtype="float32") @@ -732,18 +737,16 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"): coords[axis] = np.arange(sample_shape[axis], dtype=dtype) # Calculate the index component of samples w.r.t. B-Spline knots along current axis + # Size of locations is L locs = nb.affines.apply_affine(target_to_grid, coords.T)[:, axis] - knots = np.arange(knots_shape[axis], dtype=dtype) - - distance = np.abs(locs[np.newaxis, ...] - knots[..., np.newaxis]) - within_support = distance < 2.0 - d_vals, d_idxs = np.unique(distance[within_support], return_inverse=True) - bs_w = cubic(d_vals) - colloc_ax = lil_array((knots_shape[axis], sample_shape[axis]), dtype=dtype) - colloc_ax[within_support] = bs_w[d_idxs] + # Size of knots is K + 6 so that all locations are fully covered by basis + knots = np.arange(-3, knots_shape[axis] + 3, dtype=dtype) - wd.append(colloc_ax) + # Our original design matrices had size (K, L) + # However, BSpline.design_matrix() generates a size of (L, K + 2), + # hence the trasposition (and zero-padding of 1 at every face when using these) + wd.append(BSpline.design_matrix(locs, knots, 3).T) # Calculate the tensor product of the three design matrices return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype)