Skip to content

Commit

Permalink
ENH: Add digitization class (#6369)
Browse files Browse the repository at this point in the history
Digitization is a list-like object of Digpoints with comparison functionalities.
  • Loading branch information
massich authored Jun 4, 2019
1 parent 4348125 commit dbda584
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 6 deletions.
1 change: 1 addition & 0 deletions doc/python_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ File I/O
io.read_info
io.show_fiff
digitization.DigPoint
digitization.Digitization

Base class:

Expand Down
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ Current
Changelog
~~~~~~~~~

- Add :class:`mne.digitization.Digitization` class to simplify montage by `Joan Massich`_

- Add support for showing head surface (to visualize digitization fit) while showing a single-layer BEM to :func:`mne.viz.plot_alignment` by `Eric Larson`_

Bug
Expand Down
2 changes: 1 addition & 1 deletion mne/digitization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .base import DigPoint
from .base import DigPoint, Digitization

__all__ = [
'DigPoint',
Expand Down
32 changes: 31 additions & 1 deletion mne/digitization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#
# License: BSD (3-clause)
import numpy as np
from copy import deepcopy

from ..transforms import _coord_frame_name
from ..io.constants import FIFF

Expand All @@ -26,7 +28,8 @@

def _format_dig_points(dig):
"""Format the dig points nicely."""
return [DigPoint(d) for d in dig] if dig is not None else dig
dig_points = [DigPoint(d) for d in dig] if dig is not None else dig
return Digitization(dig_points)


class DigPoint(dict):
Expand Down Expand Up @@ -73,3 +76,30 @@ def __eq__(self, other): # noqa: D105
return False
else:
return np.allclose(self['r'], other['r'])


class Digitization(list):
"""Represent a list of DigPoint objects.
Parameters
----------
elements : list | None
A list of DigPoint objects.
"""

def __init__(self, elements=None):

elements = list() if elements is None else elements

if not all([isinstance(_, DigPoint) for _ in elements]):
_msg = 'Digitization expected a iterable of DigPoint objects.'
raise ValueError(_msg)
else:
super(Digitization, self).__init__(deepcopy(elements))

def __eq__(self, other): # noqa: D105
if not isinstance(other, (Digitization, list)) or \
len(self) != len(other):
return False
else:
return all([ss == oo for ss, oo in zip(self, other)])
3 changes: 3 additions & 0 deletions mne/io/ctf/ctf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
from .info import _compose_meas_info, _read_bad_chans, _annotate_bad_segments
from .constants import CTF

from ...digitization.base import _format_dig_points


@fill_doc
def read_raw_ctf(directory, system_clock='truncate', preload=False,
Expand Down Expand Up @@ -116,6 +118,7 @@ def __init__(self, directory, system_clock='truncate', preload=False,
# Compose a structure which makes fiff writing a piece of cake
info = _compose_meas_info(res4, coils, coord_trans, eeg)
info['dig'] += digs
info['dig'] = _format_dig_points(info['dig'])
info['bads'] += _read_bad_chans(directory, info)

# Determine how our data is distributed across files
Expand Down
3 changes: 2 additions & 1 deletion mne/io/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .proc_history import _read_proc_history, _write_proc_history
from ..transforms import invert_transform
from ..utils import logger, verbose, warn, object_diff, _validate_type
from ..digitization.base import _format_dig_points
from .compensator import get_current_comp

# XXX: most probably the functions needing this, should go somewhere else
Expand Down Expand Up @@ -1129,7 +1130,7 @@ def read_meas_info(fid, tree, clean_bads=False, verbose=None):
info['dev_ctf_t'] = Transform('meg', 'ctf_head', dev_ctf_trans)

# All kinds of auxliary stuff
info['dig'] = dig
info['dig'] = _format_dig_points(dig)
info['bads'] = bads
info._update_redundant()
if clean_bads:
Expand Down
3 changes: 3 additions & 0 deletions mne/io/tests/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from mne.utils import _TempDir, catch_logging, _raw_annot
from mne.io.meas_info import _get_valid_units

from mne.digitization import Digitization


def test_orig_units():
"""Test the error handling for original units."""
Expand Down Expand Up @@ -87,6 +89,7 @@ def _test_raw_reader(reader, test_preloading=True, test_kwargs=True, **kwargs):
full_data = raw._data
assert raw.__class__.__name__ in repr(raw) # to test repr
assert raw.info.__class__.__name__ in repr(raw.info)
assert isinstance(raw.info['dig'], (type(None), Digitization))

# gh-5604
assert _handle_meas_date(raw.info['meas_date']) >= 0
Expand Down
40 changes: 40 additions & 0 deletions mne/tests/test_digitization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# -*- coding: utf-8 -*-
# Authors: Alexandre Gramfort <[email protected]>
# Joan Massich <[email protected]>
#
# License: BSD (3-clause)
import pytest
import numpy as np
from mne.digitization import Digitization
from mne.digitization.base import _format_dig_points

dig_dict_list = [
dict(kind=_, ident=_, r=np.empty((3,)), coord_frame=_)
for _ in [1, 2, 42]
]

digpoints_list = _format_dig_points(dig_dict_list)


@pytest.mark.parametrize('data', [
pytest.param(digpoints_list, id='list of digpoints'),
pytest.param(dig_dict_list, id='list of digpoint dicts',
marks=pytest.mark.xfail(raises=ValueError)),
pytest.param(['foo', 'bar'], id='list of strings',
marks=pytest.mark.xfail(raises=ValueError)),
])
def test_digitization_constructor(data):
"""Test Digitization constructor."""
dig = Digitization(data)
assert dig == data

dig[0]['kind'] = data[0]['kind'] - 1 # modify something in dig
assert dig != data


def test_delete_elements():
"""Test deleting some Digitization elements."""
dig = Digitization(digpoints_list)
original_length = len(dig)
del dig[0]
assert len(dig) == original_length - 1
5 changes: 3 additions & 2 deletions mne/utils/_bunch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""Bunch-related classes."""
# Authors: Alexandre Gramfort <[email protected]>
# Eric Larson <[email protected]>
# Joan Massich <[email protected]>
#
# License: BSD (3-clause)

Expand Down Expand Up @@ -94,10 +95,10 @@ def _named_subclass(klass):
class NamedInt(_Named, int):
"""Int with a name in __repr__."""

pass
pass # noqa


class NamedFloat(_Named, float):
"""Float with a name in __repr__."""

pass
pass # noqa
2 changes: 1 addition & 1 deletion mne/utils/numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,7 @@ def object_diff(a, b, pre=''):
Currently supported: dict, list, tuple, ndarray, int, str, bytes,
float, StringIO, BytesIO.
b : object
Must be same type as x1.
Must be same type as ``a``.
pre : str
String to prepend to each line.
Expand Down

0 comments on commit dbda584

Please sign in to comment.