Skip to content

Commit

Permalink
RF: Sample spline basis less densely
Browse files Browse the repository at this point in the history
Use a construction that will potentially allow us to calculate
derivatives more easily.
  • Loading branch information
effigies committed Sep 23, 2023
1 parent 2bf307f commit 253ad91
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 22 deletions.
11 changes: 1 addition & 10 deletions sdcflows/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,16 +348,7 @@ def test_grid_bspline_weights():
nb.Nifti1Image(np.zeros(target_shape), target_aff),
nb.Nifti1Image(np.zeros(ctrl_shape), ctrl_aff),
).tocsr()
assert weights.shape == (np.prod(np.array(ctrl_shape) + 2), np.prod(target_shape))

# Calculate the legacy mask to index the weights below
legacy_mask = np.pad(
np.ones(ctrl_shape, dtype=bool),
((1, 1), (1, 1), (1, 1)),
).reshape(-1)

# Drop scipy's padding
weights = weights[legacy_mask]
assert weights.shape == (64, 1000)
# Empirically determined numbers intended to indicate that something
# significant has changed. If it turns out we've been doing this wrong,
# these numbers will probably change.
Expand Down
26 changes: 14 additions & 12 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
from warnings import warn
from scipy import ndimage as ndi
from scipy.interpolate import BSpline
from scipy.sparse import vstack as sparse_vstack, kron
from scipy.sparse import vstack as sparse_vstack, kron, lil_array

import nibabel as nb
import nitransforms as nt
Expand Down Expand Up @@ -309,7 +309,6 @@ def fit(
atol=1e-3,
)

weights = []
if approx:
from sdcflows.utils.tools import deoblique_and_zooms

Expand All @@ -321,16 +320,12 @@ def fit(
)

# Generate tensor-product B-Spline weights
weights = []
coeffs_data = []
for level in coeffs:
wmat = grid_bspline_weights(target_reference, level)
weights.append(wmat)

# Coefficients must be zero-padded because Scipy's BSpline pads the knots.
coeffs_data.append(np.pad(
level.get_fdata(dtype="float32"),
((1, 1), (1, 1), (1, 1)),
).reshape(-1))
coeffs_data.append(level.get_fdata(dtype="float32").reshape(-1))

# Reconstruct the fieldmap (in Hz) from coefficients
fmap = np.zeros(projected_reference.shape[:3], dtype="float32")
Expand Down Expand Up @@ -743,10 +738,17 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
# Size of knots is K + 6 so that all locations are fully covered by basis
knots = np.arange(-3, knots_shape[axis] + 3, dtype=dtype)

# Our original design matrices had size (K, L)
# However, BSpline.design_matrix() generates a size of (L, K + 2),
# hence the transposition (and zero-padding of 1 at every face when using these)
wd.append(BSpline.design_matrix(locs, knots, 3).T)
bspl = BSpline(knots, np.eye(len(knots) - 3 - 1), 3)

# Construct a sparse design matrix (L, K)
distance = np.abs(locs[..., np.newaxis] - knots[np.newaxis, 3:-3])
within_support = distance < 2.0

matrix = lil_array(distance.shape, dtype=dtype)
matrix[within_support] = bspl(locs)[:, 1:-1][within_support]

# Transpose to (K, L) and convert to CSR for efficient multiplication
wd.append(matrix.T.tocsr())

# Calculate the tensor product of the three design matrices
return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype)
Expand Down

0 comments on commit 253ad91

Please sign in to comment.