diff --git a/sdcflows/tests/test_transform.py b/sdcflows/tests/test_transform.py index 72851876be..865d3c4cad 100644 --- a/sdcflows/tests/test_transform.py +++ b/sdcflows/tests/test_transform.py @@ -348,16 +348,7 @@ def test_grid_bspline_weights(): nb.Nifti1Image(np.zeros(target_shape), target_aff), nb.Nifti1Image(np.zeros(ctrl_shape), ctrl_aff), ).tocsr() - assert weights.shape == (np.prod(np.array(ctrl_shape) + 2), np.prod(target_shape)) - - # Calculate the legacy mask to index the weights below - legacy_mask = np.pad( - np.ones(ctrl_shape, dtype=bool), - ((1, 1), (1, 1), (1, 1)), - ).reshape(-1) - - # Drop scipy's padding - weights = weights[legacy_mask] + assert weights.shape == (64, 1000) # Empirically determined numbers intended to indicate that something # significant has changed. If it turns out we've been doing this wrong, # these numbers will probably change. diff --git a/sdcflows/transform.py b/sdcflows/transform.py index 92394fc206..0f5fb1d19c 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -57,7 +57,7 @@ from warnings import warn from scipy import ndimage as ndi from scipy.interpolate import BSpline -from scipy.sparse import vstack as sparse_vstack, kron +from scipy.sparse import vstack as sparse_vstack, kron, lil_array import nibabel as nb import nitransforms as nt @@ -309,7 +309,6 @@ def fit( atol=1e-3, ) - weights = [] if approx: from sdcflows.utils.tools import deoblique_and_zooms @@ -321,16 +320,12 @@ def fit( ) # Generate tensor-product B-Spline weights + weights = [] coeffs_data = [] for level in coeffs: wmat = grid_bspline_weights(target_reference, level) weights.append(wmat) - - # Coefficients must be zero-padded because Scipy's BSpline pads the knots. - coeffs_data.append(np.pad( - level.get_fdata(dtype="float32"), - ((1, 1), (1, 1), (1, 1)), - ).reshape(-1)) + coeffs_data.append(level.get_fdata(dtype="float32").reshape(-1)) # Reconstruct the fieldmap (in Hz) from coefficients fmap = np.zeros(projected_reference.shape[:3], dtype="float32") @@ -743,10 +738,17 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"): # Size of knots is K + 6 so that all locations are fully covered by basis knots = np.arange(-3, knots_shape[axis] + 3, dtype=dtype) - # Our original design matrices had size (K, L) - # However, BSpline.design_matrix() generates a size of (L, K + 2), - # hence the transposition (and zero-padding of 1 at every face when using these) - wd.append(BSpline.design_matrix(locs, knots, 3).T) + bspl = BSpline(knots, np.eye(len(knots) - 3 - 1), 3) + + # Construct a sparse design matrix (L, K) + distance = np.abs(locs[..., np.newaxis] - knots[np.newaxis, 3:-3]) + within_support = distance < 2.0 + + matrix = lil_array(distance.shape, dtype=dtype) + matrix[within_support] = bspl(locs)[:, 1:-1][within_support] + + # Transpose to (K, L) and convert to CSR for efficient multiplication + wd.append(matrix.T.tocsr()) # Calculate the tensor product of the three design matrices return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype)