From a278c90633ebf02fba2ac37e6d171350e7d4c4d2 Mon Sep 17 00:00:00 2001 From: Chris Markiewicz Date: Sun, 24 Sep 2023 10:58:49 -0400 Subject: [PATCH] RF: Simplify spline fit/apply by keeping in (L, K) shape --- sdcflows/interfaces/bspline.py | 10 +++++----- sdcflows/tests/test_transform.py | 6 +++--- sdcflows/transform.py | 27 ++++++++++++--------------- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/sdcflows/interfaces/bspline.py b/sdcflows/interfaces/bspline.py index 767e055927..5674bdd486 100644 --- a/sdcflows/interfaces/bspline.py +++ b/sdcflows/interfaces/bspline.py @@ -130,7 +130,7 @@ class BSplineApprox(SimpleInterface): def _run_interface(self, runtime): from sklearn import linear_model as lm - from scipy.sparse import vstack as sparse_vstack + from scipy.sparse import hstack as sparse_hstack # Output name baseline out_name = fname_presuffix( @@ -197,9 +197,9 @@ def _run_interface(self, runtime): data -= center # Calculate collocation matrix from (possibly resized) image and knot grids - colmat = sparse_vstack( + colmat = sparse_hstack( grid_bspline_weights(fmapnii, grid) for grid in bs_grids - ).T.tocsr() + ).tocsr() bs_grids_str = ["x".join(str(s) for s in grid.shape) for grid in bs_grids] bs_grids_str[-1] = f"and {bs_grids_str[-1]}" @@ -254,9 +254,9 @@ def _run_interface(self, runtime): mask = np.asanyarray(masknii.dataobj) > 1e-4 else: mask = np.ones_like(fmapnii.dataobj, dtype=bool) - colmat = sparse_vstack( + colmat = sparse_hstack( grid_bspline_weights(fmapnii, grid) for grid in bs_grids - ).T.tocsr() + ).tocsr() regressors = colmat[mask.reshape(-1), :] interp_data = np.zeros_like(data) diff --git a/sdcflows/tests/test_transform.py b/sdcflows/tests/test_transform.py index 865d3c4cad..eb8bd09bb2 100644 --- a/sdcflows/tests/test_transform.py +++ b/sdcflows/tests/test_transform.py @@ -348,11 +348,11 @@ def test_grid_bspline_weights(): nb.Nifti1Image(np.zeros(target_shape), target_aff), nb.Nifti1Image(np.zeros(ctrl_shape), ctrl_aff), ).tocsr() - assert weights.shape == (64, 1000) + assert weights.shape == (1000, 64) # Empirically determined numbers intended to indicate that something # significant has changed. If it turns out we've been doing this wrong, # these numbers will probably change. assert np.isclose(weights[0, 0], 0.00089725334) assert np.isclose(weights[-1, -1], 0.18919244) - assert np.isclose(weights.sum(axis=1).max(), 129.3907) - assert np.isclose(weights.sum(axis=1).min(), 0.0052327816) + assert np.isclose(weights.sum(axis=0).max(), 129.3907) + assert np.isclose(weights.sum(axis=0).min(), 0.0052327816) diff --git a/sdcflows/transform.py b/sdcflows/transform.py index 583ef0d0a3..a0cfcfbb9e 100644 --- a/sdcflows/transform.py +++ b/sdcflows/transform.py @@ -57,7 +57,7 @@ from warnings import warn from scipy import ndimage as ndi from scipy.interpolate import BSpline -from scipy.sparse import vstack as sparse_vstack, kron, lil_array +from scipy.sparse import hstack as sparse_hstack, kron, lil_array import nibabel as nb import nitransforms as nt @@ -320,18 +320,15 @@ def fit( ) # Generate tensor-product B-Spline weights - weights = [] - coeffs_data = [] - for level in coeffs: - wmat = grid_bspline_weights(target_reference, level) - weights.append(wmat) - coeffs_data.append(level.get_fdata(dtype="float32").reshape(-1)) + colmat = sparse_hstack( + [grid_bspline_weights(target_reference, level) for level in coeffs] + ) + coefficients = np.hstack( + [level.get_fdata(dtype="float32").reshape(-1) for level in coeffs] + ) # Reconstruct the fieldmap (in Hz) from coefficients - fmap = np.zeros(projected_reference.shape[:3], dtype="float32") - fmap = (np.squeeze(np.hstack(coeffs_data).T) @ sparse_vstack(weights)).reshape( - fmap.shape - ) + fmap = np.reshape(colmat @ coefficients, projected_reference.shape[:3]) # Generate a NIfTI object hdr = target_reference.header.copy() @@ -703,7 +700,7 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"): Returns ------- - weights : :obj:`numpy.ndarray` (:math:`K \times N`) + weights : :obj:`numpy.ndarray` (:math:`N \times K`) A sparse matrix of interpolating weights :math:`\Psi^3(\mathbf{k}, \mathbf{s})` for the *N* voxels of the target EPI, for each of the total *K* knots. This sparse matrix can be directly used as design matrix for the fitting @@ -747,11 +744,11 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"): colloc_ax = lil_array(distance.shape, dtype=dtype) colloc_ax[within_support] = bspl(locs)[:, 1:-1][within_support] - # Transpose to (K, L) and convert to CSR for efficient multiplication - wd.append(colloc_ax.T.tocsr()) + # Convert to CSR for efficient multiplication + wd.append(colloc_ax.tocsr()) # Calculate the tensor product of the three design matrices - return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype) + return kron(kron(wd[0], wd[1]), wd[2]) def _move_coeff(in_coeff, fmap_ref, transform, fmap_target=None):