Skip to content

Commit

Permalink
fix: shape and order of resampled array
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Nov 16, 2023
1 parent 2c36d08 commit 9772710
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,11 @@ def __init__(self, transforms, reference=None):
)
self._inverse = np.linalg.inv(self._matrix)

def __iter__(self):
"""Enable iterating over the series of transforms."""
for _m in self.matrix:
yield Affine(_m, reference=self._reference)

def __getitem__(self, i):
"""Enable indexed access to the series of matrices."""
return Affine(self.matrix[i, ...], reference=self._reference)
Expand Down Expand Up @@ -458,42 +463,37 @@ def apply(
# Invert target's (moving) affine once
ras2vox = ~Affine(spatialimage.affine)

if spatialimage.ndim == 4:
if len(self) != spatialimage.shape[-1]:
raise ValueError(
"Attempting to apply %d transforms on a file with "
"%d timepoints" % (len(self), spatialimage.shape[-1])
)

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(xcoords.T.shape[0], ) + spatialimage.shape[-1:], dtype=output_dtype, order="F"
if spatialimage.ndim == 4 and (len(self) != spatialimage.shape[-1]):
raise ValueError(
"Attempting to apply %d transforms on a file with "
"%d timepoints" % (len(self), spatialimage.shape[-1])
)

for t in range(spatialimage.shape[-1]):
# Map the input coordinates on to timepoint t of the target (moving)
ycoords = Affine(self.matrix[t]).map(xcoords.T)[..., : _ref.ndim]

# Calculate corresponding voxel coordinates
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]

# Interpolate
resampled[..., t] = ndi.map_coordinates(
spatialimage.dataobj[..., t].astype(input_dtype, copy=False),
yvoxels.T,
output=output_dtype,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)
elif spatialimage.ndim in (2, 3):
ycoords = self.map(xcoords.T)[..., : _ref.ndim]
# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(xcoords.T.shape[0], len(self)), dtype=output_dtype, order="F"
)

dataobj = (
np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
if spatialimage.ndim in (2, 3)
else None
)

for t, xfm_t in enumerate(self):
# Map the input coordinates on to timepoint t of the target (moving)
ycoords = xfm_t.map(xcoords.T)[..., : _ref.ndim]

# Calculate corresponding voxel coordinates
yvoxels = ras2vox.map(ycoords)[..., : _ref.ndim]

resampled = ndi.map_coordinates(
spatialimage.dataobj.astype(input_dtype, copy=False),
# Interpolate
resampled[..., t] = ndi.map_coordinates(
(
dataobj if dataobj is not None
else np.asanyarray(spatialimage.dataobj[..., t], dtype=input_dtype)
),
yvoxels.T,
output=output_dtype,
order=order,
Expand All @@ -503,9 +503,9 @@ def apply(
)

if isinstance(_ref, ImageGrid): # If reference is grid, reshape
newdata = resampled.reshape((len(self), *_ref.shape))
newdata = resampled.reshape(_ref.shape + (len(self), ))
moved = spatialimage.__class__(
np.moveaxis(newdata, 0, -1), _ref.affine, spatialimage.header
newdata, _ref.affine, spatialimage.header
)
moved.header.set_data_dtype(output_dtype)
return moved
Expand Down

0 comments on commit 9772710

Please sign in to comment.