Skip to content

Commit

Permalink
Merge pull request #59 from oesteban/enh/transform-map
Browse files Browse the repository at this point in the history
ENH: Support for transforms mappings (e.g., head-motion correction)
  • Loading branch information
oesteban authored Feb 25, 2020
2 parents 391f28d + da44997 commit 955cd38
Show file tree
Hide file tree
Showing 12 changed files with 504 additions and 61 deletions.
33 changes: 19 additions & 14 deletions nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@
import numpy as np
import h5py
import warnings
from nibabel.loadsave import load
from nibabel.loadsave import load as _nbload
from nibabel import funcs as _nbfuncs
from nibabel.nifti1 import intent_codes as INTENT_CODES
from nibabel.cifti2 import Cifti2Image
from scipy import ndimage as ndi

EQUALITY_TOL = 1e-5


class TransformError(ValueError):
class TransformError(TypeError):
"""A custom exception for transforms."""


Expand Down Expand Up @@ -51,7 +52,7 @@ def __init__(self, dataset):
return

if isinstance(dataset, (str, Path)):
dataset = load(str(dataset))
dataset = _nbload(str(dataset))

if hasattr(dataset, 'numDA'): # Looks like a Gifti file
_das = dataset.get_arrays_from_intent(INTENT_CODES['pointset'])
Expand Down Expand Up @@ -96,14 +97,18 @@ class ImageGrid(SampledSpatialData):
def __init__(self, image):
"""Create a gridded sampling reference."""
if isinstance(image, (str, Path)):
image = load(str(image))
image = _nbfuncs.squeeze_image(_nbload(str(image)))

self._affine = image.affine
self._shape = image.shape

self._ndim = getattr(image, 'ndim', len(image.shape))
if self._ndim == 4:
self._shape = image.shape[:3]
self._ndim = 3

self._npoints = getattr(image, 'npoints',
np.prod(image.shape))
np.prod(self._shape))
self._ndindex = None
self._coords = None
self._inverse = getattr(image, 'inverse',
Expand Down Expand Up @@ -168,13 +173,15 @@ class TransformBase(object):

__slots__ = ['_reference']

def __init__(self):
def __init__(self, reference=None):
"""Instantiate a transform."""
self._reference = None
if reference:
self.reference = reference

def __call__(self, x, inverse=False, index=0):
def __call__(self, x, inverse=False):
"""Apply y = f(x)."""
return self.map(x, inverse=inverse, index=index)
return self.map(x, inverse=inverse)

def __add__(self, b):
"""
Expand Down Expand Up @@ -246,13 +253,13 @@ def apply(self, spatialimage, reference=None,
"""
if reference is not None and isinstance(reference, (str, Path)):
reference = load(str(reference))
reference = _nbload(str(reference))

_ref = self.reference if reference is None \
else SpatialReference.factory(reference)

if isinstance(spatialimage, (str, Path)):
spatialimage = load(str(spatialimage))
spatialimage = _nbload(str(spatialimage))

data = np.asanyarray(spatialimage.dataobj)
output_dtype = output_dtype or data.dtype
Expand All @@ -279,7 +286,7 @@ def apply(self, spatialimage, reference=None,

return resampled

def map(self, x, inverse=False, index=0):
def map(self, x, inverse=False):
r"""
Apply :math:`y = f(x)`.
Expand All @@ -291,8 +298,6 @@ def map(self, x, inverse=False, index=0):
Input RAS+ coordinates (i.e., physical coordinates).
inverse : bool
If ``True``, apply the inverse transform :math:`x = f^{-1}(y)`.
index : int, optional
Transformation index
Returns
-------
Expand Down Expand Up @@ -407,7 +412,7 @@ def insert(self, i, x):
"""
self.transforms = self.transforms[:i] + _as_chain(x) + self.transforms[i:]

def map(self, x, inverse=False, index=0):
def map(self, x, inverse=False):
"""
Apply a succession of transforms, e.g., :math:`y = f_3(f_2(f_1(f_0(x))))`.
Expand Down
11 changes: 5 additions & 6 deletions nitransforms/io/fsl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Read/write FSL's transforms."""
import os
import numpy as np
from pathlib import Path
from nibabel.affines import voxel_sizes

from .base import BaseLinearTransformList, LinearParameters, TransformFileError
Expand Down Expand Up @@ -63,13 +64,11 @@ class FSLLinearTransformArray(BaseLinearTransformList):

def to_filename(self, filename):
"""Store this transform to a file with the appropriate format."""
if len(self.xforms) == 1:
self.xforms[0].to_filename(filename)
return

output_dir = Path(filename).parent
output_dir.mkdir(exist_ok=True, parents=True)
for i, xfm in enumerate(self.xforms):
with open('%s.%03d' % (filename, i), 'w') as f:
f.write(xfm.to_string())
(output_dir / '.'.join((str(filename), '%03d' % i))).write_text(
xfm.to_string())

def to_ras(self, moving=None, reference=None):
"""Return a nitransforms' internal RAS matrix."""
Expand Down
Loading

0 comments on commit 955cd38

Please sign in to comment.