Skip to content

Commit

Permalink
MRG: concatenate_raws(), concatenate_epochs(), write_evokeds() gain o…
Browse files Browse the repository at this point in the history
…n_mismatch param and raise by default (#9438)

* mne.concatenate_raws() gains on_mismatch param and raises by default

Fixes #9436

* Fix doc build [skip github][skip azp]

* Repurpose existing info comparison code

* Phrasing

* Apply suggestions from code review

Co-authored-by: Eric Larson <[email protected]>

* Address review comments

* Style

* Final reviewer comment

* Use docdict; add on_mismatch to other methods too

* Changelog & style

* Fix doc build

* Fix test

* Use verbose decorator

* Update mne/utils/docs.py [skip githib] [skip azp] [ci skip]

Co-authored-by: Eric Larson <[email protected]>

* Update changelog [skip github][skip azp]

* Fix a typo

Co-authored-by: Eric Larson <[email protected]>
  • Loading branch information
hoechenberger and larsoner authored Jun 3, 2021
1 parent e009e5c commit 3b85ebb
Show file tree
Hide file tree
Showing 8 changed files with 121 additions and 53 deletions.
4 changes: 4 additions & 0 deletions doc/changes/latest.inc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ Enhancements

- Add support for interpolating oxy and deoxyhaemoglobin data types (:gh:`9431` by `Robert Luke`_)

- :func:`mne.concatenate_raws`, :func:`mne.concatenate_epochs`, and func:`mne.write_evokeds` gained a new parameter ``on_mismatch``, which controls behavior in case not all of the supplied instances share the same device-to-head transformation (:gh:`9438` by `Richard Höchenberger`_)

Bugs
~~~~
- Fix bug with :meth:`mne.Epochs.crop` and :meth:`mne.Evoked.crop` when ``include_tmax=False``, where the last sample was always cut off, even when ``tmax > epo.times[-1]`` (:gh:`9378` **by new contributor** |Jan Sosulski|_)
Expand All @@ -61,6 +63,8 @@ Bugs

- Fix bug when computing rank from info for SSS data with only gradiometers or magnetometers (:gh:`9435` by `Alex Gramfort`_)

- :func:`mne.concatenate_raws` now raises an exception if ``raw.info['dev_head_t']`` differs between files. This behavior can be controlled using the new ``on_mismatch`` parameter (:gh:`9438` by `Richard Höchenberger`_)

API changes
~~~~~~~~~~~
- Nothing yet
55 changes: 16 additions & 39 deletions mne/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
write_double_matrix, write_complex_float_matrix,
write_complex_double_matrix, write_id, write_string,
_get_split_size, _NEXT_FILE_BUFFER, INT32_MAX)
from .io.meas_info import read_meas_info, write_meas_info, _merge_info
from .io.meas_info import (read_meas_info, write_meas_info, _merge_info,
_ensure_infos_match)
from .io.open import fiff_open, _get_next_fname
from .io.tree import dir_tree_find
from .io.tag import read_tag, read_tag_info
Expand All @@ -34,7 +35,7 @@
from .io.pick import (channel_indices_by_type, channel_type,
pick_channels, pick_info, _pick_data_channels,
_DATA_CH_TYPES_SPLIT, _picks_to_idx)
from .io.proj import setup_proj, ProjMixin, _proj_equal
from .io.proj import setup_proj, ProjMixin
from .io.base import BaseRaw, TimeMixin
from .bem import _check_origin
from .evoked import EvokedArray, _check_decim
Expand Down Expand Up @@ -3295,39 +3296,6 @@ def add_channels_epochs(epochs_list, verbose=None):
return epochs


def _compare_epochs_infos(info1, info2, name):
"""Compare infos."""
if not isinstance(name, str): # passed epochs index
name = f'epochs[{name:d}]'
info1._check_consistency()
info2._check_consistency()
if info1['nchan'] != info2['nchan']:
raise ValueError(f'{name}.info[\'nchan\'] must match')
if set(info1['bads']) != set(info2['bads']):
raise ValueError(f'{name}.info[\'bads\'] must match')
if info1['sfreq'] != info2['sfreq']:
raise ValueError(f'{name}.info[\'sfreq\'] must match')
if set(info1['ch_names']) != set(info2['ch_names']):
raise ValueError(f'{name}.info[\'ch_names\'] must match')
if len(info2['projs']) != len(info1['projs']):
raise ValueError(f'SSP projectors in {name} must be the same')
if any(not _proj_equal(p1, p2) for p1, p2 in
zip(info2['projs'], info1['projs'])):
raise ValueError(f'SSP projectors in {name} must be the same')
if (info1['dev_head_t'] is None) != (info2['dev_head_t'] is None) or \
(info1['dev_head_t'] is not None and not
np.allclose(info1['dev_head_t']['trans'],
info2['dev_head_t']['trans'], rtol=1e-6)):
raise ValueError(f'{name}.info[\'dev_head_t\'] must match. The '
'instances probably come from different runs, and '
'are therefore associated with different head '
'positions. Manually change info[\'dev_head_t\'] to '
'avoid this message but beware that this means the '
'MEG sensors will not be properly spatially aligned. '
'See mne.preprocessing.maxwell_filter to realign the '
'runs to a common head position.')


def _update_offset(offset, events, shift):
if offset == 0:
return offset
Expand All @@ -3340,7 +3308,8 @@ def _update_offset(offset, events, shift):
return offset


def _concatenate_epochs(epochs_list, with_data=True, add_offset=True):
def _concatenate_epochs(epochs_list, with_data=True, add_offset=True, *,
on_mismatch='raise'):
"""Auxiliary function for concatenating epochs."""
if not isinstance(epochs_list, (list, tuple)):
raise TypeError('epochs_list must be a list or tuple, got %s'
Expand All @@ -3366,7 +3335,8 @@ def _concatenate_epochs(epochs_list, with_data=True, add_offset=True):
shift = int((10 + tmax) * out.info['sfreq'])
events_offset = _update_offset(None, out.events, shift)
for ii, epochs in enumerate(epochs_list[1:], 1):
_compare_epochs_infos(epochs.info, info, ii)
_ensure_infos_match(epochs.info, info, f'epochs[{ii}]',
on_mismatch=on_mismatch)
if not np.allclose(epochs.times, epochs_list[0].times):
raise ValueError('Epochs must have same times')

Expand Down Expand Up @@ -3443,7 +3413,9 @@ def _finish_concat(info, data, events, event_id, tmin, tmax, metadata,
return out


def concatenate_epochs(epochs_list, add_offset=True):
@verbose
def concatenate_epochs(epochs_list, add_offset=True, *, on_mismatch='raise',
verbose=None):
"""Concatenate a list of epochs into one epochs object.
Parameters
Expand All @@ -3455,6 +3427,10 @@ def concatenate_epochs(epochs_list, add_offset=True):
Epochs sets, such that they are easy to distinguish after the
concatenation.
If False, the event times are unaltered during the concatenation.
%(on_info_mismatch)s
%(verbose)s
.. versionadded:: 0.24
Returns
-------
Expand All @@ -3466,7 +3442,8 @@ def concatenate_epochs(epochs_list, add_offset=True):
.. versionadded:: 0.9.0
"""
return _finish_concat(*_concatenate_epochs(epochs_list,
add_offset=add_offset))
add_offset=add_offset,
on_mismatch=on_mismatch))


@verbose
Expand Down
19 changes: 13 additions & 6 deletions mne/evoked.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
from .io.tree import dir_tree_find
from .io.pick import pick_types, _picks_to_idx, _FNIRS_CH_TYPES_SPLIT
from .io.meas_info import (read_meas_info, write_meas_info,
_read_extended_ch_info, _rename_list)
_read_extended_ch_info, _rename_list,
_ensure_infos_match)
from .io.proj import ProjMixin
from .io.write import (start_file, start_block, end_file, end_block,
write_int, write_string, write_float_matrix,
Expand Down Expand Up @@ -1326,7 +1327,8 @@ def _read_evoked(fname, condition=None, kind='average', allow_maxshield=False):
return info, nave, aspect_kind, comment, times, data, baseline


def write_evokeds(fname, evoked):
@verbose
def write_evokeds(fname, evoked, *, on_mismatch='raise', verbose=None):
"""Write an evoked dataset to a file.
Parameters
Expand All @@ -1337,6 +1339,10 @@ def write_evokeds(fname, evoked):
The evoked dataset, or list of evoked datasets, to save in one file.
Note that the measurement info from the first evoked instance is used,
so be sure that information matches.
%(on_info_mismatch)s
%(verbose)s
.. versionadded:: 0.24
See Also
--------
Expand All @@ -1349,12 +1355,11 @@ def write_evokeds(fname, evoked):
`~mne.Evoked` object, and will be restored when reading the data again
via `mne.read_evokeds`.
"""
_write_evokeds(fname, evoked)
_write_evokeds(fname, evoked, on_mismatch=on_mismatch)


def _write_evokeds(fname, evoked, check=True):
def _write_evokeds(fname, evoked, check=True, *, on_mismatch='raise'):
"""Write evoked data."""
from .epochs import _compare_epochs_infos
from .dipole import DipoleFixed # avoid circular import

if check:
Expand All @@ -1380,7 +1385,9 @@ def _write_evokeds(fname, evoked, check=True):
start_block(fid, FIFF.FIFFB_PROCESSED_DATA)
for ei, e in enumerate(evoked):
if ei:
_compare_epochs_infos(evoked[0].info, e.info, f'evoked[{ei}]')
_ensure_infos_match(info1=evoked[0].info, info2=e.info,
name=f'evoked[{ei}]',
on_mismatch=on_mismatch)
start_block(fid, FIFF.FIFFB_EVOKED)

# Comment is optional
Expand Down
14 changes: 10 additions & 4 deletions mne/io/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .utils import _construct_bids_filename, _check_orig_units
from .pick import (pick_types, pick_channels, pick_info, _picks_to_idx,
channel_type)
from .meas_info import write_meas_info
from .meas_info import write_meas_info, _ensure_infos_match
from .proj import setup_proj, activate_proj, _proj_equal, ProjMixin
from ..channels.channels import (ContainsMixin, UpdateChannelsMixin,
SetChannelsMixin, InterpolationMixin,
Expand Down Expand Up @@ -2397,7 +2397,8 @@ def _check_raw_compatibility(raw):


@verbose
def concatenate_raws(raws, preload=None, events_list=None, verbose=None):
def concatenate_raws(raws, preload=None, events_list=None, *,
on_mismatch='raise', verbose=None):
"""Concatenate raw instances as if they were continuous.
.. note:: ``raws[0]`` is modified in-place to achieve the concatenation.
Expand All @@ -2409,10 +2410,11 @@ def concatenate_raws(raws, preload=None, events_list=None, verbose=None):
Parameters
----------
raws : list
List of Raw instances to concatenate (in order).
List of `~mne.io.Raw` instances to concatenate (in order).
%(preload_concatenate)s
events_list : None | list
The events to concatenate. Defaults to None.
The events to concatenate. Defaults to ``None``.
%(on_info_mismatch)s
%(verbose)s
Returns
Expand All @@ -2422,6 +2424,10 @@ def concatenate_raws(raws, preload=None, events_list=None, verbose=None):
events : ndarray of int, shape (n_events, 3)
The events. Only returned if ``event_list`` is not None.
"""
for idx, raw in enumerate(raws[1:], start=1):
_ensure_infos_match(info1=raws[0].info, info2=raw.info,
name=f'raws[{idx}]', on_mismatch=on_mismatch)

if events_list is not None:
if len(events_list) != len(raws):
raise ValueError('`raws` and `event_list` are required '
Expand Down
19 changes: 19 additions & 0 deletions mne/io/fiff/tests/test_raw_fiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,25 @@ def test_multiple_files(tmpdir):
assert len(raw) == raw.last_samp - raw.first_samp + 1


@testing.requires_testing_data
@pytest.mark.parametrize('on_mismatch', ('ignore', 'warn', 'raise'))
def test_concatenate_raws(on_mismatch):
"""Test error handling during raw concatenation."""
raw = read_raw_fif(fif_fname).crop(0, 10)
raws = [raw, raw.copy()]
raws[1].info['dev_head_t']['trans'] += 0.1
kws = dict(raws=raws, on_mismatch=on_mismatch)

if on_mismatch == 'ignore':
concatenate_raws(**kws)
elif on_mismatch == 'warn':
with pytest.warns(RuntimeWarning, match='different head positions'):
concatenate_raws(**kws)
elif on_mismatch == 'raise':
with pytest.raises(ValueError, match='different head positions'):
concatenate_raws(**kws)


@testing.requires_testing_data
@pytest.mark.parametrize('mod', (
'meg',
Expand Down
52 changes: 50 additions & 2 deletions mne/io/meas_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .tag import (read_tag, find_tag, _ch_coord_dict, _update_ch_info_named,
_rename_list)
from .proj import (_read_proj, _write_proj, _uniquify_projs, _normalize_proj,
Projection)
_proj_equal, Projection)
from .ctf_comp import _read_ctf_comp, write_ctf_comp
from .write import (start_file, end_file, start_block, end_block,
write_string, write_dig_points, write_float, write_int,
Expand All @@ -34,7 +34,7 @@
from ..transforms import invert_transform, Transform, _coord_frame_name
from ..utils import (logger, verbose, warn, object_diff, _validate_type,
_stamp_to_dt, _dt_to_stamp, _pl, _is_numeric,
_check_option)
_check_option, _on_missing, _check_on_missing)
from ._digitization import (_format_dig_points, _dig_kind_proper, DigPoint,
_dig_kind_rev, _dig_kind_ints, _read_dig_fif)
from ._digitization import write_dig as _dig_write_dig
Expand Down Expand Up @@ -2478,3 +2478,51 @@ def _write_ch_infos(fid, chs, reset_range, ch_names_mapping):
for (key, (const, _, write)) in _CH_INFO_MAP.items():
write(fid, const, ch[key])
end_block(fid, FIFF.FIFFB_CH_INFO)


def _ensure_infos_match(info1, info2, name, *, on_mismatch='raise'):
"""Check if infos match.
Parameters
----------
info1, info2 : instance of Info
The infos to compare.
name : str
The name of the object appearing in the error message of the comparison
fails.
on_mismatch : 'raise' | 'warn' | 'ignore'
What to do in case of a mismatch of ``dev_head_t`` between ``info1``
and ``info2``.
"""
_check_on_missing(on_missing=on_mismatch, name='on_mismatch')

info1._check_consistency()
info2._check_consistency()

if info1['nchan'] != info2['nchan']:
raise ValueError(f'{name}.info[\'nchan\'] must match')
if set(info1['bads']) != set(info2['bads']):
raise ValueError(f'{name}.info[\'bads\'] must match')
if info1['sfreq'] != info2['sfreq']:
raise ValueError(f'{name}.info[\'sfreq\'] must match')
if set(info1['ch_names']) != set(info2['ch_names']):
raise ValueError(f'{name}.info[\'ch_names\'] must match')
if len(info2['projs']) != len(info1['projs']):
raise ValueError(f'SSP projectors in {name} must be the same')
if any(not _proj_equal(p1, p2) for p1, p2 in
zip(info2['projs'], info1['projs'])):
raise ValueError(f'SSP projectors in {name} must be the same')
if (info1['dev_head_t'] is None) != (info2['dev_head_t'] is None) or \
(info1['dev_head_t'] is not None and not
np.allclose(info1['dev_head_t']['trans'],
info2['dev_head_t']['trans'], rtol=1e-6)):
msg = (f"{name}.info['dev_head_t'] differs. The "
f"instances probably come from different runs, and "
f"are therefore associated with different head "
f"positions. Manually change info['dev_head_t'] to "
f"avoid this message but beware that this means the "
f"MEG sensors will not be properly spatially aligned. "
f"See mne.preprocessing.maxwell_filter to realign the "
f"runs to a common head position.")
_on_missing(on_missing=on_mismatch, msg=msg,
name='on_mismatch')
4 changes: 2 additions & 2 deletions mne/tests/test_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2653,14 +2653,14 @@ def test_concatenate_epochs():
epochs2 = epochs.copy()
concatenate_epochs([epochs, epochs2]) # should work
epochs2.info['dev_head_t']['trans'][:3, 3] += 0.0001
with pytest.raises(ValueError, match='dev_head_t.*must match'):
with pytest.raises(ValueError, match=r"info\['dev_head_t'\] differs"):
concatenate_epochs([epochs, epochs2])
with pytest.raises(TypeError, match='must be a list or tuple'):
concatenate_epochs('foo')
with pytest.raises(TypeError, match='must be an instance of Epochs'):
concatenate_epochs([epochs, 'foo'])
epochs2.info['dev_head_t'] = None
with pytest.raises(ValueError, match='dev_head_t.*must match'):
with pytest.raises(ValueError, match=r"info\['dev_head_t'\] differs"):
concatenate_epochs([epochs, epochs2])
epochs.info['dev_head_t'] = None
concatenate_epochs([epochs, epochs2]) # should work
Expand Down
7 changes: 7 additions & 0 deletions mne/utils/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,13 @@
.. versionadded:: 0.22
""" % (_on_missing_base,)
docdict['on_info_mismatch'] = f"""
on_mismatch : 'raise' | 'warn' | 'ignore'
{_on_missing_base} the device-to-head transformation differs between
instances.
.. versionadded:: 0.24
"""
docdict['saturated'] = """\
saturated : str
Replace saturated segments of data with NaNs, can be:
Expand Down

0 comments on commit 3b85ebb

Please sign in to comment.