RF: Use scipy.interpolate.BSpline to construct spline basis #393

Merged · 5 commits · Sep 29, 2023
14 changes: 7 additions & 7 deletions sdcflows/interfaces/bspline.py
@@ -130,7 +130,7 @@ class BSplineApprox(SimpleInterface):

def _run_interface(self, runtime):
from sklearn import linear_model as lm
-from scipy.sparse import vstack as sparse_vstack
+from scipy.sparse import hstack as sparse_hstack

# Output name baseline
out_name = fname_presuffix(
@@ -197,9 +197,9 @@ def _run_interface(self, runtime):
data -= center

# Calculate collocation matrix from (possibly resized) image and knot grids
-colmat = sparse_vstack(
-    grid_bspline_weights(fmapnii, grid) for grid in bs_grids
-).T.tocsr()
+colmat = sparse_hstack(
+    [grid_bspline_weights(fmapnii, grid) for grid in bs_grids]
+).tocsr()

bs_grids_str = ["x".join(str(s) for s in grid.shape) for grid in bs_grids]
bs_grids_str[-1] = f"and {bs_grids_str[-1]}"
@@ -254,9 +254,9 @@ def _run_interface(self, runtime):
mask = np.asanyarray(masknii.dataobj) > 1e-4
else:
mask = np.ones_like(fmapnii.dataobj, dtype=bool)
-colmat = sparse_vstack(
-    grid_bspline_weights(fmapnii, grid) for grid in bs_grids
-).T.tocsr()
+colmat = sparse_hstack(
+    [grid_bspline_weights(fmapnii, grid) for grid in bs_grids]
+).tocsr()

regressors = colmat[mask.reshape(-1), :]
interp_data = np.zeros_like(data)
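Note (not part of the diff): horizontally stacking the per-grid weight matrices is equivalent to the previous vertical stack followed by a transpose, now that `grid_bspline_weights` returns voxels-by-coefficients matrices. A minimal sketch with arbitrary toy shapes:

```python
import numpy as np
from scipy.sparse import hstack as sparse_hstack, vstack as sparse_vstack, random as sparse_random

# Two hypothetical per-grid weight matrices: N=20 voxels, K1=4 and K2=6 coefficients
w1 = sparse_random(20, 4, density=0.3, format="csr", random_state=0)
w2 = sparse_random(20, 6, density=0.3, format="csr", random_state=1)

# New orientation: hstack of (N, K_i) blocks gives the (N, K1 + K2) design matrix
colmat_new = sparse_hstack([w1, w2]).tocsr()

# Old orientation: vstack of (K_i, N) blocks, then transpose
colmat_old = sparse_vstack([w1.T, w2.T]).T.tocsr()

assert colmat_new.shape == (20, 10)
assert np.allclose(colmat_new.toarray(), colmat_old.toarray())
```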
6 changes: 3 additions & 3 deletions sdcflows/tests/test_transform.py
@@ -348,11 +348,11 @@ def test_grid_bspline_weights():
nb.Nifti1Image(np.zeros(target_shape), target_aff),
nb.Nifti1Image(np.zeros(ctrl_shape), ctrl_aff),
).tocsr()
-assert weights.shape == (64, 1000)
+assert weights.shape == (1000, 64)
# Empirically determined numbers intended to indicate that something
# significant has changed. If it turns out we've been doing this wrong,
# these numbers will probably change.
assert np.isclose(weights[0, 0], 0.00089725334)
assert np.isclose(weights[-1, -1], 0.18919244)
-assert np.isclose(weights.sum(axis=1).max(), 129.3907)
-assert np.isclose(weights.sum(axis=1).min(), 0.0052327816)
+assert np.isclose(weights.sum(axis=0).max(), 129.3907)
+assert np.isclose(weights.sum(axis=0).min(), 0.0052327816)
44 changes: 23 additions & 21 deletions sdcflows/transform.py
@@ -56,8 +56,8 @@
import numpy as np
from warnings import warn
from scipy import ndimage as ndi
-from scipy.signal import cubic
-from scipy.sparse import vstack as sparse_vstack, kron, lil_array
+from scipy.interpolate import BSpline
+from scipy.sparse import hstack as sparse_hstack, kron, lil_array

import nibabel as nb
import nitransforms as nt
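Note (not part of the diff): `scipy.signal.cubic` was deprecated upstream, and the same cubic B-spline basis can be obtained from `scipy.interpolate.BSpline` by passing an identity matrix as coefficients, so that evaluating the spline returns one column per basis function. A minimal sketch of the idea, mirroring the construction used further down:

```python
import numpy as np
from scipy.interpolate import BSpline

n_knots = 4  # arbitrary number of control points along one axis

# Knots extended by 3 on each side so every location is fully covered
# by the cubic (degree-3) basis
knots = np.arange(-3, n_knots + 3, dtype="float32")

# Identity coefficients: bspl(x) returns the value of every basis function at x
bspl = BSpline(knots, np.eye(len(knots) - 3 - 1), 3)

locs = np.linspace(0, n_knots - 1, 7)
basis = bspl(locs)
print(basis.shape)  # (7, 6): 7 locations x (len(knots) - 4) basis functions
```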
@@ -309,7 +309,6 @@ def fit(
atol=1e-3,
)

-weights = []
if approx:
from sdcflows.utils.tools import deoblique_and_zooms

@@ -321,17 +320,15 @@
)

# Generate tensor-product B-Spline weights
-coeffs_data = []
-for level in coeffs:
-    wmat = grid_bspline_weights(target_reference, level)
-    weights.append(wmat)
-    coeffs_data.append(level.get_fdata(dtype="float32").reshape(-1))
+colmat = sparse_hstack(
+    [grid_bspline_weights(projected_reference, level) for level in coeffs]
+).tocsr()
+coefficients = np.hstack(
+    [level.get_fdata(dtype="float32").reshape(-1) for level in coeffs]
+)

# Reconstruct the fieldmap (in Hz) from coefficients
-fmap = np.zeros(projected_reference.shape[:3], dtype="float32")
-fmap = (np.squeeze(np.hstack(coeffs_data).T) @ sparse_vstack(weights)).reshape(
-    fmap.shape
-)
+fmap = np.reshape(colmat @ coefficients, projected_reference.shape[:3])

# Generate a NIfTI object
hdr = target_reference.header.copy()
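Note (not part of the diff): with `colmat` of shape (N, K), for N target voxels and K coefficients summed over all levels, and `coefficients` a flat (K,) vector, the product is the fieldmap in Hz at every voxel. A toy illustration with made-up sizes:

```python
import numpy as np
from scipy.sparse import random as sparse_random

target_shape = (4, 4, 3)  # hypothetical target grid, N = 48 voxels
n_vox = int(np.prod(target_shape))
colmat = sparse_random(n_vox, 35, density=0.2, format="csr", random_state=0)
coefficients = np.random.default_rng(0).normal(size=35).astype("float32")

# Same reconstruction step as in fit() above
fmap = np.reshape(colmat @ coefficients, target_shape)
assert fmap.shape == target_shape
```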
@@ -703,7 +700,7 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):

Returns
-------
-weights : :obj:`numpy.ndarray` (:math:`K \times N`)
+weights : :obj:`numpy.ndarray` (:math:`N \times K`)
A sparse matrix of interpolating weights :math:`\Psi^3(\mathbf{k}, \mathbf{s})`
for the *N* voxels of the target EPI, for each of the total *K* knots.
This sparse matrix can be directly used as design matrix for the fitting
@@ -732,21 +729,26 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
coords[axis] = np.arange(sample_shape[axis], dtype=dtype)

# Calculate the index component of samples w.r.t. B-Spline knots along current axis
+# Size of locations is L
locs = nb.affines.apply_affine(target_to_grid, coords.T)[:, axis]
-knots = np.arange(knots_shape[axis], dtype=dtype)
-
-distance = np.abs(locs[np.newaxis, ...] - knots[..., np.newaxis])
+# Size of knots is K + 6 so that all locations are fully covered by basis
+knots = np.arange(-3, knots_shape[axis] + 3, dtype=dtype)
+
+bspl = BSpline(knots, np.eye(len(knots) - 3 - 1), 3)
+
+# Construct a sparse design matrix (L, K)
+distance = np.abs(locs[..., np.newaxis] - knots[np.newaxis, 3:-3])
within_support = distance < 2.0
-d_vals, d_idxs = np.unique(distance[within_support], return_inverse=True)
-bs_w = cubic(d_vals)
-
-colloc_ax = lil_array((knots_shape[axis], sample_shape[axis]), dtype=dtype)
-colloc_ax[within_support] = bs_w[d_idxs]
+colloc_ax = lil_array(distance.shape, dtype=dtype)
+colloc_ax[within_support] = bspl(locs)[:, 1:-1][within_support]

-wd.append(colloc_ax)
+# Convert to CSR for efficient multiplication
+wd.append(colloc_ax.tocsr())

# Calculate the tensor product of the three design matrices
-return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype)
+return kron(kron(wd[0], wd[1]), wd[2])


def _move_coeff(in_coeff, fmap_ref, transform, fmap_target=None):
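Note (not part of the diff): each per-axis matrix has shape (L, K) for L samples and K control points along that axis, and the Kronecker product of the three yields the N × K design matrix described in the docstring above. A self-contained sketch under those assumptions (`collocation_1d` is a hypothetical helper, sizes are arbitrary):

```python
import numpy as np
from scipy.interpolate import BSpline
from scipy.sparse import kron, lil_array


def collocation_1d(n_locs, n_knots, dtype="float32"):
    """Toy 1-D analogue of the per-axis design matrix built above."""
    locs = np.linspace(0, n_knots - 1, n_locs, dtype=dtype)
    knots = np.arange(-3, n_knots + 3, dtype=dtype)
    bspl = BSpline(knots, np.eye(len(knots) - 3 - 1), 3)

    # Distance of each sample to each interior knot; cubic support is (-2, 2)
    distance = np.abs(locs[..., np.newaxis] - knots[np.newaxis, 3:-3])
    within_support = distance < 2.0

    colloc = lil_array(distance.shape, dtype=dtype)
    colloc[within_support] = bspl(locs)[:, 1:-1][within_support]
    return colloc.tocsr()


# Axes with (L, K) = (5, 3), (4, 3), (6, 4); kron gives (5*4*6, 3*3*4) = (120, 36)
wd = [collocation_1d(5, 3), collocation_1d(4, 3), collocation_1d(6, 4)]
design = kron(kron(wd[0], wd[1]), wd[2])
assert design.shape == (120, 36)
```

Each row of that matrix can then serve directly as the regressors for a linear model, which is how `BSplineApprox` uses it above.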