Skip to content

Commit

Permalink
RF: Simplify spline fit/apply by keeping in (L, K) shape
Browse files Browse the repository at this point in the history
  • Loading branch information
effigies committed Sep 24, 2023
1 parent fa9121e commit a278c90
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 23 deletions.
10 changes: 5 additions & 5 deletions sdcflows/interfaces/bspline.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ class BSplineApprox(SimpleInterface):

def _run_interface(self, runtime):
from sklearn import linear_model as lm
from scipy.sparse import vstack as sparse_vstack
from scipy.sparse import hstack as sparse_hstack

# Output name baseline
out_name = fname_presuffix(
Expand Down Expand Up @@ -197,9 +197,9 @@ def _run_interface(self, runtime):
data -= center

# Calculate collocation matrix from (possibly resized) image and knot grids
colmat = sparse_vstack(
colmat = sparse_hstack(
grid_bspline_weights(fmapnii, grid) for grid in bs_grids
).T.tocsr()
).tocsr()

bs_grids_str = ["x".join(str(s) for s in grid.shape) for grid in bs_grids]
bs_grids_str[-1] = f"and {bs_grids_str[-1]}"
Expand Down Expand Up @@ -254,9 +254,9 @@ def _run_interface(self, runtime):
mask = np.asanyarray(masknii.dataobj) > 1e-4
else:
mask = np.ones_like(fmapnii.dataobj, dtype=bool)
colmat = sparse_vstack(
colmat = sparse_hstack(
grid_bspline_weights(fmapnii, grid) for grid in bs_grids
).T.tocsr()
).tocsr()

regressors = colmat[mask.reshape(-1), :]
interp_data = np.zeros_like(data)
Expand Down
6 changes: 3 additions & 3 deletions sdcflows/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,11 +348,11 @@ def test_grid_bspline_weights():
nb.Nifti1Image(np.zeros(target_shape), target_aff),
nb.Nifti1Image(np.zeros(ctrl_shape), ctrl_aff),
).tocsr()
assert weights.shape == (64, 1000)
assert weights.shape == (1000, 64)
# Empirically determined numbers intended to indicate that something
# significant has changed. If it turns out we've been doing this wrong,
# these numbers will probably change.
assert np.isclose(weights[0, 0], 0.00089725334)
assert np.isclose(weights[-1, -1], 0.18919244)
assert np.isclose(weights.sum(axis=1).max(), 129.3907)
assert np.isclose(weights.sum(axis=1).min(), 0.0052327816)
assert np.isclose(weights.sum(axis=0).max(), 129.3907)
assert np.isclose(weights.sum(axis=0).min(), 0.0052327816)
27 changes: 12 additions & 15 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from warnings import warn
from scipy import ndimage as ndi
from scipy.interpolate import BSpline
from scipy.sparse import vstack as sparse_vstack, kron, lil_array
from scipy.sparse import hstack as sparse_hstack, kron, lil_array

import nibabel as nb
import nitransforms as nt
Expand Down Expand Up @@ -320,18 +320,15 @@ def fit(
)

# Generate tensor-product B-Spline weights
weights = []
coeffs_data = []
for level in coeffs:
wmat = grid_bspline_weights(target_reference, level)
weights.append(wmat)
coeffs_data.append(level.get_fdata(dtype="float32").reshape(-1))
colmat = sparse_hstack(
[grid_bspline_weights(target_reference, level) for level in coeffs]
)
coefficients = np.hstack(
[level.get_fdata(dtype="float32").reshape(-1) for level in coeffs]
)

# Reconstruct the fieldmap (in Hz) from coefficients
fmap = np.zeros(projected_reference.shape[:3], dtype="float32")
fmap = (np.squeeze(np.hstack(coeffs_data).T) @ sparse_vstack(weights)).reshape(
fmap.shape
)
fmap = np.reshape(colmat @ coefficients, projected_reference.shape[:3])

# Generate a NIfTI object
hdr = target_reference.header.copy()
Expand Down Expand Up @@ -703,7 +700,7 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
Returns
-------
weights : :obj:`numpy.ndarray` (:math:`K \times N`)
weights : :obj:`numpy.ndarray` (:math:`N \times K`)
A sparse matrix of interpolating weights :math:`\Psi^3(\mathbf{k}, \mathbf{s})`
for the *N* voxels of the target EPI, for each of the total *K* knots.
This sparse matrix can be directly used as design matrix for the fitting
Expand Down Expand Up @@ -747,11 +744,11 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
colloc_ax = lil_array(distance.shape, dtype=dtype)
colloc_ax[within_support] = bspl(locs)[:, 1:-1][within_support]

# Transpose to (K, L) and convert to CSR for efficient multiplication
wd.append(colloc_ax.T.tocsr())
# Convert to CSR for efficient multiplication
wd.append(colloc_ax.tocsr())

# Calculate the tensor product of the three design matrices
return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype)
return kron(kron(wd[0], wd[1]), wd[2])


def _move_coeff(in_coeff, fmap_ref, transform, fmap_target=None):
Expand Down

0 comments on commit a278c90

Please sign in to comment.