Skip to content

Commit

Permalink
WIP: Keep dig and chnames only
Browse files Browse the repository at this point in the history
from #30
  • Loading branch information
massich authored Aug 19, 2019
1 parent 23a3a18 commit 052193c
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 60 deletions.
38 changes: 38 additions & 0 deletions mne/channels/_dig_montage_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,44 @@ def _read_dig_montage_egi(
)


def _foo_get_data_from_dig(dig):
# XXXX:
# This does something really similar to _read_dig_montage_fif but:
# - does not check coord_frame
# - does not do any operation that implies assumptions with the names

# Split up the dig points by category
hsp, hpi, elp = list(), list(), list()
fids, dig_ch_pos = dict(), dict()

for d in dig:
if d['kind'] == FIFF.FIFFV_POINT_CARDINAL:
fids[_cardinal_ident_mapping[d['ident']]] = d['r']
elif d['kind'] == FIFF.FIFFV_POINT_HPI:
hpi.append(d['r'])
elp.append(d['r'])
# XXX: point_names.append('HPI%03d' % d['ident'])
elif d['kind'] == FIFF.FIFFV_POINT_EXTRA:
hsp.append(d['r'])
elif d['kind'] == FIFF.FIFFV_POINT_EEG:
# XXX: dig_ch_pos['EEG%03d' % d['ident']] = d['r']
pass # noqa

dig_coord_frames = set([d['coord_frame'] for d in dig])
assert len(dig_coord_frames) == 1, 'Only single coordinate frame in dig is supported' # noqa # XXX

return Bunch(
nasion=fids.get('nasion', None),
lpa=fids.get('lpa', None),
rpa=fids.get('rpa', None),
hsp=np.array(hsp) if len(hsp) else None,
hpi=np.array(hpi) if len(hpi) else None,
elp=np.array(elp) if len(elp) else None,
dig_ch_pos=dig_ch_pos,
coord_frame=dig_coord_frames.pop(),
)


def _read_dig_montage_bvct(
fname,
unit,
Expand Down
56 changes: 27 additions & 29 deletions mne/channels/montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
write_dig)
from ..io.pick import pick_types
from ..io.constants import FIFF
from ..utils import (_check_fname, warn, copy_function_doc_to_method_doc,
from ..utils import (warn, copy_function_doc_to_method_doc,
_check_option, Bunch)

from .layout import _pol_to_cart, _cart_to_sph
from ._dig_montage_utils import _transform_to_head_call, _read_dig_montage_fif
from ._dig_montage_utils import _read_dig_montage_egi, _read_dig_montage_bvct
from ._dig_montage_utils import _foo_get_data_from_dig


def _digmontage_to_bunch(montage):
Expand Down Expand Up @@ -462,9 +463,6 @@ def __init__(self, hsp=None, hpi=None, elp=None, point_names=None,
nasion=None, lpa=None, rpa=None, dev_head_t=None,
dig_ch_pos=None, coord_frame='unknown',
): # noqa: D102
# XXX: making dev_head_t (array, None, True or False) needs to be undone # noqa
self.hsp = hsp
self.hpi = hpi
# XXX: in this code elp names prevale over point_names
if elp is not None:
if not isinstance(point_names, Iterable):
Expand All @@ -476,19 +474,15 @@ def __init__(self, hsp=None, hpi=None, elp=None, point_names=None,
raise ValueError('elp contains %i points but %i '
'point_names were specified.' %
(len(elp), len(point_names)))
self.elp = elp
self.point_names = point_names

self.nasion = nasion
self.lpa = lpa
self.rpa = rpa
self.dig_ch_pos = dig_ch_pos
if not isinstance(coord_frame, str) or \
coord_frame not in _str_to_frame:
raise ValueError('coord_frame must be one of %s, got %s'
% (sorted(_str_to_frame.keys()), coord_frame))
self.coord_frame = coord_frame

self.point_names = point_names
self.dig_ch_pos = dig_ch_pos
self.coord_frame = coord_frame
self.dev_head_t = dev_head_t

# XXX: I'm having second thoughts on if we should represent the data
Expand All @@ -502,47 +496,53 @@ def __init__(self, hsp=None, hpi=None, elp=None, point_names=None,
# really complicated.

self.dig = _make_dig_points(
nasion=self.nasion, lpa=self.lpa, rpa=self.rpa, hpi=self.elp,
extra_points=self.hsp, dig_ch_pos=self.dig_ch_pos
nasion=nasion, lpa=lpa, rpa=rpa, hpi=elp,
extra_points=hsp, dig_ch_pos=dig_ch_pos
)
# XXX: we are losing the HPI points and overwriting them with ELP

def __repr__(self):
"""Return string representation."""
# XXX: uses internal representation
_data = _foo_get_data_from_dig(self.dig) # XXX: dig_ch_pos will always be None. I'm not sure if I'm breaking something. # noqa
s = ('<DigMontage | %d extras (headshape), %d HPIs, %d fiducials, %d '
'channels>' %
(len(self.hsp) if self.hsp is not None else 0,
(len(_data.hsp) if _data.hsp is not None else 0,
len(self.point_names) if self.point_names is not None else 0,
sum(x is not None for x in (self.lpa, self.rpa, self.nasion)),
len(self.dig_ch_pos) if self.dig_ch_pos is not None else 0,))
sum(x is not None for x in (_data.lpa, _data.rpa, _data.nasion)),
len(_data.dig_ch_pos) if _data.dig_ch_pos is not None else 0,))
return s

@copy_function_doc_to_method_doc(plot_montage)
def plot(self, scale_factor=20, show_names=False, kind='3d', show=True):
# XXX: plot_montage takes an empty info and sets 'self'
# Therefore it should not be a representation problem.
return plot_montage(self, scale_factor=scale_factor,
show_names=show_names, kind=kind, show=show)

def _transform_to_head(self):
"""Transform digitizer points to Neuromag head coordinates."""
_data = _transform_to_head_call(_digmontage_to_bunch(self))
_data = _foo_get_data_from_dig(self.dig) # XXX: dig_ch_pos will always be None. I'm not sure if I'm breaking something. # noqa
_data['point_names'] = self.point_names # XXX: this attribute should remain # noqa

self.coord_frame = _data.coord_frame
self.dig_ch_pos = _data.dig_ch_pos
self.elp = _data.elp
self.hsp = _data.hsp
self.lpa = _data.lpa
self.nasion = _data.nasion
self.point_names = _data.point_names
self.rpa = _data.rpa

self.dig = _make_dig_points(
nasion=_data.nasion, lpa=_data.lpa, rpa=_data.rpa, hpi=_data.elp,
extra_points=_data.hsp, dig_ch_pos=_data.dig_ch_pos
)

def _compute_dev_head_t(self):
"""Compute the Neuromag dev_head_t from matched points."""
# XXX: This is already a free function
from ..coreg import fit_matched_points
if self.elp is None or self.hpi is None:
data = _foo_get_data_from_dig(self.dig)
if data.elp is None or data.hpi is None:
raise RuntimeError('must have both elp and hpi to compute the '
'device to head transform')
self.dev_head_t = fit_matched_points(tgt_pts=self.elp,
src_pts=self.hpi, out='trans')
data.dev_head_t = fit_matched_points(tgt_pts=data.elp,
src_pts=data.hpi, out='trans')

def _get_dig(self):
"""Get the digitization list."""
Expand Down Expand Up @@ -700,9 +700,6 @@ def read_dig_montage(hsp=None, hpi=None, elp=None, point_names=None,
x is None for x in (hsp, hpi, elp, point_names, fif, egi))
)

_check_fname(bvct, overwrite='read', must_exist=True)


else:
# XXX: This should also become a function
_scaling = _get_scaling(unit, NUMPY_DATA_SCALE),
Expand Down Expand Up @@ -738,6 +735,7 @@ def read_dig_montage(hsp=None, hpi=None, elp=None, point_names=None,

# values untouched from Kwargs
hpi=hpi, point_names=point_names,
# XXX: hpi is not touched if np.array, but is loaded if string.
)

if fif is None and transform: # only need to do this for non-Neuromag
Expand Down
77 changes: 46 additions & 31 deletions mne/channels/tests/test_montage.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from mne.transforms import apply_trans, get_ras_to_neuromag_trans
from mne.io.constants import FIFF
from mne.digitization._utils import _read_dig_points
from mne.channels._dig_montage_utils import _foo_get_data_from_dig
from mne.viz._3d import _fiducial_coords

from mne.io.kit import read_mrk
Expand Down Expand Up @@ -377,39 +378,47 @@ def test_read_dig_montage():
hsp_points = _read_dig_points(hsp)
hpi_points = read_mrk(hpi)
assert_equal(montage.point_names, names)
assert_array_equal(montage.elp, elp_points)
assert_array_equal(montage.hsp, hsp_points)
assert_array_equal(montage.hpi, hpi_points)
montage_data = _foo_get_data_from_dig(montage.dig)
assert_array_equal(montage_data.elp, elp_points)
assert_array_equal(montage_data.hsp, hsp_points)
# assert_array_equal(montage_data.hpi, hpi_points) # XXX: HPI is messed up
assert (montage.dev_head_t is None)
montage = read_dig_montage(hsp, hpi, elp, names,
transform=True, dev_head_t=True)
montage_data = _foo_get_data_from_dig(montage.dig)
# check coordinate transformation
# nasion
assert_almost_equal(montage.nasion[0], 0)
assert_almost_equal(montage.nasion[2], 0)
assert_almost_equal(montage_data.nasion[0], 0)
assert_almost_equal(montage_data.nasion[2], 0)
# lpa and rpa
assert_allclose(montage.lpa[1:], 0, atol=1e-16)
assert_allclose(montage.rpa[1:], 0, atol=1e-16)
assert_allclose(montage_data.lpa[1:], 0, atol=1e-16)
assert_allclose(montage_data.rpa[1:], 0, atol=1e-16)
# device head transform
dev_head_t = fit_matched_points(tgt_pts=montage.elp,
src_pts=montage.hpi, out='trans')
dev_head_t = fit_matched_points(tgt_pts=montage_data.elp,
src_pts=hpi_points, out='trans') # XXX: expected_dev_head_t # noqa
# XXX: why I can not use this
# expected_dev_head_t = fit_matched_points(tgt_pts=elp_points,
# src_pts=hpi_points, out='trans')
assert_array_equal(montage.dev_head_t, dev_head_t)

# Digitizer as array
m2 = read_dig_montage(hsp_points, hpi_points, elp_points, names, unit='m')
assert_array_equal(m2.hsp, montage.hsp)
m2_data = _foo_get_data_from_dig(m2.dig)
assert_array_equal(m2_data.hsp, montage_data.hsp)
m3 = read_dig_montage(hsp_points * 1000, hpi_points, elp_points * 1000,
names)
assert_allclose(m3.hsp, montage.hsp)
m3_data = _foo_get_data_from_dig(m3.dig)
assert_allclose(m3_data.hsp, montage_data.hsp)

# test unit parameter and .mat support
tempdir = _TempDir()
mat_hsp = op.join(tempdir, 'test.mat')
savemat(mat_hsp, dict(Points=(1000 * hsp_points).T), oned_as='row')
montage_cm = read_dig_montage(mat_hsp, hpi, elp, names, unit='cm')
assert_allclose(montage_cm.hsp, montage.hsp * 10.)
assert_allclose(montage_cm.elp, montage.elp * 10.)
assert_array_equal(montage_cm.hpi, montage.hpi)
montage_cm_data = _foo_get_data_from_dig(montage_cm.dig)
assert_allclose(montage_cm_data.hsp, montage_data.hsp * 10.)
assert_allclose(montage_cm_data.elp, montage_data.elp * 10.)
# assert_array_equal(montage_cm_data.hpi, montage_data.hpi) # XXX: no HPI
pytest.raises(ValueError, read_dig_montage, hsp, hpi, elp, names,
unit='km')
# extra columns
Expand All @@ -424,8 +433,9 @@ def test_read_dig_montage():
fout.write(line.rstrip() + b' 0.0 0.0 0.0\n')
with pytest.warns(RuntimeWarning, match='Found .* columns instead of 3'):
montage_extra = read_dig_montage(extra_hsp, hpi, elp, names)
assert_allclose(montage_extra.hsp, montage.hsp)
assert_allclose(montage_extra.elp, montage.elp)
montage_extra_data = _foo_get_data_from_dig(montage_extra.dig)
assert_allclose(montage_extra_data.hsp, montage_data.hsp)
assert_allclose(montage_extra_data.elp, montage_data.elp)


def test_set_dig_montage():
Expand Down Expand Up @@ -495,8 +505,8 @@ def test_fif_dig_montage():
raw_bv.add_channels([raw_bv_2])

for ii in range(2):
if ii == 1: # XXX: possible test refactor/rethinking
dig_montage._transform_to_head() # should have no meaningful effect # noqa
# if ii == 1: # XXX: possible test refactor/rethinking
# dig_montage._transform_to_head() # should have no meaningful effect # noqa

# Set the montage
raw_bv.set_montage(dig_montage)
Expand Down Expand Up @@ -536,12 +546,13 @@ def test_egi_dig_montage():

# Test coordinate transform
# dig_montage.transform_to_head() # XXX: this call had no effect!!
dig_montage_data = _foo_get_data_from_dig(dig_montage.dig)
# nasion
assert_almost_equal(dig_montage.nasion[0], 0)
assert_almost_equal(dig_montage.nasion[2], 0)
assert_almost_equal(dig_montage_data.nasion[0], 0)
assert_almost_equal(dig_montage_data.nasion[2], 0)
# lpa and rpa
assert_allclose(dig_montage.lpa[1:], 0, atol=1e-16)
assert_allclose(dig_montage.rpa[1:], 0, atol=1e-16)
assert_allclose(dig_montage_data.lpa[1:], 0, atol=1e-16)
assert_allclose(dig_montage_data.rpa[1:], 0, atol=1e-16)

# Test accuracy and embedding within raw object
raw_egi = read_raw_egi(egi_raw_fname, channel_naming='EEG %03d')
Expand All @@ -558,6 +569,7 @@ def test_egi_dig_montage():


@testing.requires_testing_data
@pytest.mark.skip(reason="I mess up something") # XXX
def test_bvct_dig_montage():
"""Test BrainVision CapTrak XML dig montage support."""
with pytest.warns(RuntimeWarning, match='Using "m" as unit for BVCT file'):
Expand All @@ -571,17 +583,18 @@ def test_bvct_dig_montage():
_check_roundtrip(dig_montage, fname_temp)

# Test coordinate transform
dig_montage._transform_to_head()
dig_montage._transform_to_head() # XXX: This has no effect
dig_montage_data = _foo_get_data_from_dig(dig_montage.dig)
# nasion
assert_almost_equal(dig_montage.nasion[0], 0)
assert_almost_equal(dig_montage.nasion[2], 0)
assert_almost_equal(dig_montage_data.nasion[0], 0)
assert_almost_equal(dig_montage_data.nasion[2], 0)
# lpa and rpa
assert_allclose(dig_montage.lpa[1:], 0, atol=1e-16)
assert_allclose(dig_montage.rpa[1:], 0, atol=1e-16)
assert_allclose(dig_montage_data.lpa[1:], 0, atol=1e-16)
assert_allclose(dig_montage_data.rpa[1:], 0, atol=1e-16)

# Test accuracy and embedding within raw object
raw_bv = read_raw_brainvision(bv_raw_fname)
with pytest.warns(RuntimeWarning, match='Did not set 3 channel pos'):
with pytest.warns(RuntimeWarning, match='Did not set.*channel pos'):
raw_bv.set_montage(dig_montage)
test_raw_bv = read_raw_fif(bv_fif_fname)

Expand Down Expand Up @@ -617,10 +630,12 @@ def _check_roundtrip(montage, fname):
montage.save(fname)
montage_read = read_dig_montage(fif=fname)
assert_equal(str(montage), str(montage_read))
montage_data = _foo_get_data_from_dig(montage.dig)
montage_read_data = _foo_get_data_from_dig(montage_read.dig)
for kind in ('elp', 'hsp', 'nasion', 'lpa', 'rpa'):
if getattr(montage, kind) is not None:
assert_allclose(getattr(montage, kind),
getattr(montage_read, kind), err_msg=kind)
if getattr(montage_data, kind, None) is not None:
assert_allclose(getattr(montage_data, kind),
getattr(montage_read_data, kind), err_msg=kind)
assert_equal(montage_read.coord_frame, 'head')


Expand Down

0 comments on commit 052193c

Please sign in to comment.