Skip to content

Commit

Permalink
fix: clean up scipy.signal.cubic deprecation
Browse files Browse the repository at this point in the history
Resolves: #370.
  • Loading branch information
oesteban authored and effigies committed Sep 26, 2023
1 parent c90c4ed commit 9c474f1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
11 changes: 10 additions & 1 deletion sdcflows/tests/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,16 @@ def test_grid_bspline_weights():
nb.Nifti1Image(np.zeros(target_shape), target_aff),
nb.Nifti1Image(np.zeros(ctrl_shape), ctrl_aff),
).tocsr()
assert weights.shape == (64, 1000)
assert weights.shape == (np.prod(np.array(ctrl_shape) + 2), np.prod(target_shape))

# Calculate the legacy mask to index the weights below
legacy_mask = np.pad(
np.ones(ctrl_shape, dtype=bool),
((1, 1), (1, 1), (1, 1)),
).reshape(-1)

# Drop scipy's padding
weights = weights[legacy_mask]
# Empirically determined numbers intended to indicate that something
# significant has changed. If it turns out we've been doing this wrong,
# these numbers will probably change.
Expand Down
27 changes: 15 additions & 12 deletions sdcflows/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@
import numpy as np
from warnings import warn
from scipy import ndimage as ndi
from scipy.signal import cubic
from scipy.sparse import vstack as sparse_vstack, kron, lil_array
from scipy.interpolate import BSpline
from scipy.sparse import vstack as sparse_vstack, kron

import nibabel as nb
import nitransforms as nt
Expand Down Expand Up @@ -325,7 +325,12 @@ def fit(
for level in coeffs:
wmat = grid_bspline_weights(target_reference, level)
weights.append(wmat)
coeffs_data.append(level.get_fdata(dtype="float32").reshape(-1))

# Coefficients must be zero-padded because Scipy's BSpline pads the knots.
coeffs_data.append(np.pad(
level.get_fdata(dtype="float32"),
((1, 1), (1, 1), (1, 1)),
).reshape(-1))

# Reconstruct the fieldmap (in Hz) from coefficients
fmap = np.zeros(projected_reference.shape[:3], dtype="float32")
Expand Down Expand Up @@ -732,18 +737,16 @@ def grid_bspline_weights(target_nii, ctrl_nii, dtype="float32"):
coords[axis] = np.arange(sample_shape[axis], dtype=dtype)

# Calculate the index component of samples w.r.t. B-Spline knots along current axis
# Size of locations is L
locs = nb.affines.apply_affine(target_to_grid, coords.T)[:, axis]
knots = np.arange(knots_shape[axis], dtype=dtype)

distance = np.abs(locs[np.newaxis, ...] - knots[..., np.newaxis])
within_support = distance < 2.0
d_vals, d_idxs = np.unique(distance[within_support], return_inverse=True)
bs_w = cubic(d_vals)

colloc_ax = lil_array((knots_shape[axis], sample_shape[axis]), dtype=dtype)
colloc_ax[within_support] = bs_w[d_idxs]
# Size of knots is K + 6 so that all locations are fully covered by basis
knots = np.arange(-3, knots_shape[axis] + 3, dtype=dtype)

wd.append(colloc_ax)
# Our original design matrices had size (K, L)
# However, BSpline.design_matrix() generates a size of (L, K + 2),
# hence the trasposition (and zero-padding of 1 at every face when using these)
wd.append(BSpline.design_matrix(locs, knots, 3).T)

# Calculate the tensor product of the three design matrices
return kron(kron(wd[0], wd[1]), wd[2]).astype(dtype)
Expand Down

0 comments on commit 9c474f1

Please sign in to comment.