Skip to content

Commit

Permalink
tst: ramping up coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
oesteban committed Feb 20, 2020
1 parent 0431194 commit 9cecb8e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 24 deletions.
4 changes: 3 additions & 1 deletion nitransforms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,11 @@ class TransformBase(object):

__slots__ = ['_reference']

def __init__(self):
def __init__(self, reference=None):
"""Instantiate a transform."""
self._reference = None
if reference:
self.reference = reference

def __call__(self, x, inverse=False):
"""Apply y = f(x)."""
Expand Down
42 changes: 20 additions & 22 deletions nitransforms/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,16 @@ def __init__(self, matrix=None, reference=None):
[0, 0, 0, 1]])
"""
super().__init__()
if matrix is None:
matrix = np.eye(4)

self._matrix = np.array(matrix)
if self._matrix.ndim != 2:
raise TypeError('Affine should be 2D.')

if self._matrix.shape[0] != self._matrix.shape[1]:
raise TypeError('Matrix is not square.')
super().__init__(reference=reference)
self._matrix = np.eye(4)

if reference:
self.reference = reference
if matrix is not None:
matrix = np.array(matrix)
if matrix.ndim != 2:
raise TypeError('Affine should be 2D.')
elif matrix.shape[0] != matrix.shape[1]:
raise TypeError('Matrix is not square.')
self._matrix = matrix

def __eq__(self, other):
"""
Expand Down Expand Up @@ -162,7 +159,7 @@ def to_filename(self, filename, fmt='X5', moving=None):
# xform info
lt = io.LinearTransform()
lt['sigma'] = 1.
lt['m_L'] = [self.matrix]
lt['m_L'] = self.matrix
# Just for reference, nitransforms does not write VOX2VOX
lt['src'] = io.VolumeGeometry.from_image(moving)
lt['dst'] = io.VolumeGeometry.from_image(self.reference)
Expand Down Expand Up @@ -299,7 +296,7 @@ def map(self, x, inverse=False):

def to_filename(self, filename, fmt='X5', moving=None):
"""Store the transform in BIDS-Transforms HDF5 file format (.x5)."""
if fmt.lower() in ['itk', 'ants', 'elastix']:
if fmt.lower() in ('itk', 'ants', 'elastix'):
itkobj = io.itk.ITKLinearTransformArray.from_ras(self.matrix)
itkobj.to_filename(filename)
return filename
Expand All @@ -323,18 +320,19 @@ def to_filename(self, filename, fmt='X5', moving=None):
fslobj.to_filename(filename)
return filename

if fmt.lower() == 'fs':
if fmt.lower() in ('fs', 'lta'):
# xform info
lt = io.LinearTransform()
lt['sigma'] = 1.
lt['m_L'] = self.matrix
# Just for reference, nitransforms does not write VOX2VOX
lt['src'] = io.VolumeGeometry.from_image(moving)
lt['dst'] = io.VolumeGeometry.from_image(self.reference)
# to make LTA file format
lta = io.LinearTransformArray()
lta['type'] = 1 # RAS2RAS
lta['xforms'].append(lt)
for m in self.matrix:
lt = io.LinearTransform()
lt['sigma'] = 1.
lt['m_L'] = m
# Just for reference, nitransforms does not write VOX2VOX
lt['src'] = io.VolumeGeometry.from_image(moving)
lt['dst'] = io.VolumeGeometry.from_image(self.reference)
lta['xforms'].append(lt)

with open(filename, 'w') as f:
f.write(lta.to_string())
Expand Down
17 changes: 16 additions & 1 deletion nitransforms/tests/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,30 @@ def test_linear_typeerrors(data_path):
ntl.Affine.from_filename(data_path / 'itktflist.tfm', fmt='itk')


def test_loadsave(tmp_path, data_path):
def test_loadsave_itk(tmp_path, data_path):
"""Test idempotency."""
ref_file = data_path / 'someones_anatomy.nii.gz'
xfm = ntl.load(data_path / 'itktflist2.tfm', fmt='itk')
xfm.reference = ref_file
xfm.to_filename(tmp_path / 'writtenout.tfm', fmt='itk')

assert (data_path / 'itktflist2.tfm').read_text() \
== (tmp_path / 'writtenout.tfm').read_text()


@pytest.mark.xfail(reason="Not fully implemented")
@pytest.mark.parametrize('fmt', ['itk', 'fsl', 'afni', 'lta'])
def test_loadsave(tmp_path, data_path, fmt):
"""Test idempotency."""
ref_file = data_path / 'someones_anatomy.nii.gz'
xfm = ntl.load(data_path / 'itktflist2.tfm', fmt='itk')
xfm.reference = ref_file

fname = tmp_path / '.'.join(('wrttenout', fmt))
xfm.to_filename(fname, fmt=fmt)
xfm == ntl.load(fname, fmt=fmt, reference=ref_file)


@pytest.mark.xfail(reason="Not fully implemented")
@pytest.mark.parametrize('image_orientation', ['RAS', 'LAS', 'LPS', 'oblique'])
@pytest.mark.parametrize('sw_tool', ['itk', 'fsl', 'afni', 'fs'])
Expand Down

0 comments on commit 9cecb8e

Please sign in to comment.